cli: add streams to rpc, generic 'spawn' command (#179732)

* cli: apply improvements from integrated wsl branch

* cli: add streams to rpc, generic 'spawn' command

For the "exec server" concept, fyi @aeschli.

* update clippy and apply fixes

* fix unused imports :(
This commit is contained in:
Connor Peet
2023-04-12 08:51:29 -07:00
committed by GitHub
parent bb7570f4f8
commit 2d8ff25c85
23 changed files with 572 additions and 184 deletions

View File

@@ -15,17 +15,26 @@ use std::{
use crate::log;
use futures::{future::BoxFuture, Future, FutureExt};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use tokio::sync::{mpsc, oneshot};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt, DuplexStream, WriteHalf},
sync::{mpsc, oneshot},
};
use crate::util::errors::AnyError;
pub type SyncMethod = Arc<dyn Send + Sync + Fn(Option<u32>, &[u8]) -> Option<Vec<u8>>>;
pub type AsyncMethod =
Arc<dyn Send + Sync + Fn(Option<u32>, &[u8]) -> BoxFuture<'static, Option<Vec<u8>>>>;
pub type Duplex = Arc<
dyn Send
+ Sync
+ Fn(Option<u32>, &[u8]) -> (Option<StreamDto>, BoxFuture<'static, Option<Vec<u8>>>),
>;
pub enum Method {
Sync(SyncMethod),
Async(AsyncMethod),
Duplex(Duplex),
}
/// Serialization is given to the RpcBuilder and defines how data gets serialized
@@ -81,6 +90,12 @@ pub struct RpcMethodBuilder<S, C> {
calls: Arc<Mutex<HashMap<u32, DispatchMethod>>>,
}
#[derive(Serialize)]
struct DuplexStreamStarted {
pub for_request_id: u32,
pub stream_id: u32,
}
impl<S: Serialization, C: Send + Sync + 'static> RpcMethodBuilder<S, C> {
/// Registers a synchronous rpc call that returns its result directly.
pub fn register_sync<P, R, F>(&mut self, method_name: &'static str, callback: F)
@@ -179,14 +194,105 @@ impl<S: Serialization, C: Send + Sync + 'static> RpcMethodBuilder<S, C> {
);
}
/// Registers an async rpc call that returns a Future containing a duplex
/// stream that should be handled by the client.
pub fn register_duplex<P, R, Fut, F>(&mut self, method_name: &'static str, callback: F)
where
P: DeserializeOwned + Send + 'static,
R: Serialize + Send + Sync + 'static,
Fut: Future<Output = Result<R, AnyError>> + Send,
F: (Fn(DuplexStream, P, Arc<C>) -> Fut) + Clone + Send + Sync + 'static,
{
let serial = self.serializer.clone();
let context = self.context.clone();
self.methods.insert(
method_name,
Method::Duplex(Arc::new(move |id, body| {
let param = match serial.deserialize::<RequestParams<P>>(body) {
Ok(p) => p,
Err(err) => {
return (
None,
future::ready(id.map(|id| {
serial.serialize(&ErrorResponse {
id,
error: ResponseError {
code: 0,
message: format!("{:?}", err),
},
})
}))
.boxed(),
);
}
};
let callback = callback.clone();
let serial = serial.clone();
let context = context.clone();
let stream_id = next_message_id();
let (client, server) = tokio::io::duplex(8192);
let fut = async move {
match callback(server, param.params, context).await {
Ok(r) => id.map(|id| serial.serialize(&SuccessResponse { id, result: r })),
Err(err) => id.map(|id| {
serial.serialize(&ErrorResponse {
id,
error: ResponseError {
code: -1,
message: format!("{:?}", err),
},
})
}),
}
};
(
Some(StreamDto {
req_id: id.unwrap_or(0),
stream_id,
duplex: client,
}),
fut.boxed(),
)
})),
);
}
/// Builds into a usable, sync rpc dispatcher.
pub fn build(self, log: log::Logger) -> RpcDispatcher<S, C> {
pub fn build(mut self, log: log::Logger) -> RpcDispatcher<S, C> {
let streams: Arc<tokio::sync::Mutex<HashMap<u32, WriteHalf<DuplexStream>>>> =
Arc::new(tokio::sync::Mutex::new(HashMap::new()));
let s1 = streams.clone();
self.register_async(METHOD_STREAM_ENDED, move |m: StreamEndedParams, _| {
let s1 = s1.clone();
async move {
s1.lock().await.remove(&m.stream);
Ok(())
}
});
let s2 = streams.clone();
self.register_async(METHOD_STREAM_DATA, move |m: StreamDataIncomingParams, _| {
let s2 = s2.clone();
async move {
let mut lock = s2.lock().await;
if let Some(stream) = lock.get_mut(&m.stream) {
let _ = stream.write_all(&m.segment).await;
}
Ok(())
}
});
RpcDispatcher {
log,
context: self.context,
calls: self.calls,
serializer: self.serializer,
methods: Arc::new(self.methods),
streams,
}
}
}
@@ -281,6 +387,7 @@ pub struct RpcDispatcher<S, C> {
serializer: Arc<S>,
methods: Arc<HashMap<&'static str, Method>>,
calls: Arc<Mutex<HashMap<u32, DispatchMethod>>>,
streams: Arc<tokio::sync::Mutex<HashMap<u32, WriteHalf<DuplexStream>>>>,
}
static MESSAGE_ID_COUNTER: AtomicU32 = AtomicU32::new(0);
@@ -310,6 +417,7 @@ impl<S: Serialization, C: Send + Sync> RpcDispatcher<S, C> {
match method {
Some(Method::Sync(callback)) => MaybeSync::Sync(callback(id, body)),
Some(Method::Async(callback)) => MaybeSync::Future(callback(id, body)),
Some(Method::Duplex(callback)) => MaybeSync::Stream(callback(id, body)),
None => MaybeSync::Sync(id.map(|id| {
self.serializer.serialize(&ErrorResponse {
id,
@@ -333,11 +441,91 @@ impl<S: Serialization, C: Send + Sync> RpcDispatcher<S, C> {
}
}
/// Registers a stream call returned from dispatch().
pub async fn register_stream(
&self,
write_tx: mpsc::Sender<impl 'static + From<Vec<u8>> + Send>,
dto: StreamDto,
) {
let stream_id = dto.stream_id;
let for_request_id = dto.req_id;
let (mut read, write) = tokio::io::split(dto.duplex);
let serial = self.serializer.clone();
self.streams.lock().await.insert(dto.stream_id, write);
tokio::spawn(async move {
let r = write_tx
.send(
serial
.serialize(&FullRequest {
id: None,
method: METHOD_STREAM_STARTED,
params: DuplexStreamStarted {
stream_id,
for_request_id,
},
})
.into(),
)
.await;
if r.is_err() {
return;
}
let mut buf = Vec::with_capacity(4096);
loop {
match read.read_buf(&mut buf).await {
Ok(0) | Err(_) => break,
Ok(n) => {
let r = write_tx
.send(
serial
.serialize(&FullRequest {
id: None,
method: METHOD_STREAM_DATA,
params: StreamDataParams {
segment: &buf[..n],
stream: stream_id,
},
})
.into(),
)
.await;
if r.is_err() {
return;
}
buf.truncate(0);
}
}
}
let _ = write_tx
.send(
serial
.serialize(&FullRequest {
id: None,
method: METHOD_STREAM_ENDED,
params: StreamEndedParams { stream: stream_id },
})
.into(),
)
.await;
});
}
pub fn context(&self) -> Arc<C> {
self.context.clone()
}
}
const METHOD_STREAM_STARTED: &str = "stream_started";
const METHOD_STREAM_DATA: &str = "stream_data";
const METHOD_STREAM_ENDED: &str = "stream_ended";
trait AssertIsSync: Sync {}
impl<S: Serialization, C: Send + Sync> AssertIsSync for RpcDispatcher<S, C> {}
@@ -349,6 +537,25 @@ struct PartialIncoming {
pub error: Option<ResponseError>,
}
#[derive(Deserialize)]
struct StreamDataIncomingParams {
#[serde(with = "serde_bytes")]
pub segment: Vec<u8>,
pub stream: u32,
}
#[derive(Serialize, Deserialize)]
struct StreamDataParams<'a> {
#[serde(with = "serde_bytes")]
pub segment: &'a [u8],
pub stream: u32,
}
#[derive(Serialize, Deserialize)]
struct StreamEndedParams {
pub stream: u32,
}
#[derive(Serialize)]
pub struct FullRequest<M: AsRef<str>, P> {
pub id: Option<u32>,
@@ -384,7 +591,14 @@ enum Outcome {
Error(ResponseError),
}
pub struct StreamDto {
stream_id: u32,
req_id: u32,
duplex: DuplexStream,
}
pub enum MaybeSync {
Stream((Option<StreamDto>, BoxFuture<'static, Option<Vec<u8>>>)),
Future(BoxFuture<'static, Option<Vec<u8>>>),
Sync(Option<Vec<u8>>),
}