cli: add streams to rpc, generic 'spawn' command (#179732)

* cli: apply improvements from integrated wsl branch

* cli: add streams to rpc, generic 'spawn' command

For the "exec server" concept, fyi @aeschli.

* update clippy and apply fixes

* fix unused imports :(
This commit is contained in:
Connor Peet
2023-04-12 08:51:29 -07:00
committed by GitHub
parent bb7570f4f8
commit 2d8ff25c85
23 changed files with 572 additions and 184 deletions

View File

@@ -8,6 +8,7 @@ use tokio::{
pin,
sync::mpsc,
};
use tokio_util::codec::Decoder;
use crate::{
rpc::{self, MaybeSync, Serialization},
@@ -38,7 +39,6 @@ pub fn new_msgpack_rpc() -> rpc::RpcBuilder<MsgPackSerializer> {
rpc::RpcBuilder::new(MsgPackSerializer {})
}
#[allow(clippy::read_zero_byte_vec)] // false positive
pub async fn start_msgpack_rpc<C: Send + Sync + 'static, S: Clone>(
dispatcher: rpc::RpcDispatcher<MsgPackSerializer, C>,
read: impl AsyncRead + Unpin,
@@ -46,34 +46,45 @@ pub async fn start_msgpack_rpc<C: Send + Sync + 'static, S: Clone>(
mut msg_rx: impl Receivable<Vec<u8>>,
mut shutdown_rx: Barrier<S>,
) -> io::Result<Option<S>> {
let (write_tx, mut write_rx) = mpsc::unbounded_channel::<Vec<u8>>();
let (write_tx, mut write_rx) = mpsc::channel::<Vec<u8>>(8);
let mut read = BufReader::new(read);
let mut decode_buf = vec![];
let mut decoder = U32PrefixedCodec {};
let mut decoder_buf = bytes::BytesMut::new();
let shutdown_fut = shutdown_rx.wait();
pin!(shutdown_fut);
loop {
tokio::select! {
u = read.read_u32() => {
let msg_length = u? as usize;
decode_buf.resize(msg_length, 0);
tokio::select! {
r = read.read_exact(&mut decode_buf) => match dispatcher.dispatch(&decode_buf[..r?]) {
r = read.read_buf(&mut decoder_buf) => {
r?;
while let Some(frame) = decoder.decode(&mut decoder_buf)? {
match dispatcher.dispatch(&frame) {
MaybeSync::Sync(Some(v)) => {
write_tx.send(v).ok();
let _ = write_tx.send(v).await;
},
MaybeSync::Sync(None) => continue,
MaybeSync::Future(fut) => {
let write_tx = write_tx.clone();
tokio::spawn(async move {
if let Some(v) = fut.await {
write_tx.send(v).ok();
let _ = write_tx.send(v).await;
}
});
}
},
r = &mut shutdown_fut => return Ok(r.ok()),
MaybeSync::Stream((stream, fut)) => {
if let Some(stream) = stream {
dispatcher.register_stream(write_tx.clone(), stream).await;
}
let write_tx = write_tx.clone();
tokio::spawn(async move {
if let Some(v) = fut.await {
let _ = write_tx.send(v).await;
}
});
}
}
};
},
Some(m) = write_rx.recv() => {
@@ -88,3 +99,33 @@ pub async fn start_msgpack_rpc<C: Send + Sync + 'static, S: Clone>(
write.flush().await?;
}
}
/// Reader that reads length-prefixed msgpack messages in a cancellation-safe
/// way using Tokio's codecs.
pub struct U32PrefixedCodec {}
const U32_SIZE: usize = 4;
impl tokio_util::codec::Decoder for U32PrefixedCodec {
type Item = Vec<u8>;
type Error = io::Error;
fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
if src.len() < 4 {
src.reserve(U32_SIZE - src.len());
return Ok(None);
}
let mut be_bytes = [0; U32_SIZE];
be_bytes.copy_from_slice(&src[..U32_SIZE]);
let required_len = U32_SIZE + (u32::from_be_bytes(be_bytes) as usize);
if src.len() < required_len {
src.reserve(required_len - src.len());
return Ok(None);
}
let msg = src[U32_SIZE..].to_vec();
src.resize(0, 0);
Ok(Some(msg))
}
}