mirror of
https://github.com/microsoft/vscode.git
synced 2026-04-20 00:28:52 +01:00
cli: add streams to rpc, generic 'spawn' command (#179732)
* cli: apply improvements from integrated wsl branch * cli: add streams to rpc, generic 'spawn' command For the "exec server" concept, fyi @aeschli. * update clippy and apply fixes * fix unused imports :(
This commit is contained in:
218
cli/src/rpc.rs
218
cli/src/rpc.rs
@@ -15,17 +15,26 @@ use std::{
|
||||
use crate::log;
|
||||
use futures::{future::BoxFuture, Future, FutureExt};
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tokio::{
|
||||
io::{AsyncReadExt, AsyncWriteExt, DuplexStream, WriteHalf},
|
||||
sync::{mpsc, oneshot},
|
||||
};
|
||||
|
||||
use crate::util::errors::AnyError;
|
||||
|
||||
pub type SyncMethod = Arc<dyn Send + Sync + Fn(Option<u32>, &[u8]) -> Option<Vec<u8>>>;
|
||||
pub type AsyncMethod =
|
||||
Arc<dyn Send + Sync + Fn(Option<u32>, &[u8]) -> BoxFuture<'static, Option<Vec<u8>>>>;
|
||||
pub type Duplex = Arc<
|
||||
dyn Send
|
||||
+ Sync
|
||||
+ Fn(Option<u32>, &[u8]) -> (Option<StreamDto>, BoxFuture<'static, Option<Vec<u8>>>),
|
||||
>;
|
||||
|
||||
pub enum Method {
|
||||
Sync(SyncMethod),
|
||||
Async(AsyncMethod),
|
||||
Duplex(Duplex),
|
||||
}
|
||||
|
||||
/// Serialization is given to the RpcBuilder and defines how data gets serialized
|
||||
@@ -81,6 +90,12 @@ pub struct RpcMethodBuilder<S, C> {
|
||||
calls: Arc<Mutex<HashMap<u32, DispatchMethod>>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct DuplexStreamStarted {
|
||||
pub for_request_id: u32,
|
||||
pub stream_id: u32,
|
||||
}
|
||||
|
||||
impl<S: Serialization, C: Send + Sync + 'static> RpcMethodBuilder<S, C> {
|
||||
/// Registers a synchronous rpc call that returns its result directly.
|
||||
pub fn register_sync<P, R, F>(&mut self, method_name: &'static str, callback: F)
|
||||
@@ -179,14 +194,105 @@ impl<S: Serialization, C: Send + Sync + 'static> RpcMethodBuilder<S, C> {
|
||||
);
|
||||
}
|
||||
|
||||
/// Registers an async rpc call that returns a Future containing a duplex
|
||||
/// stream that should be handled by the client.
|
||||
pub fn register_duplex<P, R, Fut, F>(&mut self, method_name: &'static str, callback: F)
|
||||
where
|
||||
P: DeserializeOwned + Send + 'static,
|
||||
R: Serialize + Send + Sync + 'static,
|
||||
Fut: Future<Output = Result<R, AnyError>> + Send,
|
||||
F: (Fn(DuplexStream, P, Arc<C>) -> Fut) + Clone + Send + Sync + 'static,
|
||||
{
|
||||
let serial = self.serializer.clone();
|
||||
let context = self.context.clone();
|
||||
self.methods.insert(
|
||||
method_name,
|
||||
Method::Duplex(Arc::new(move |id, body| {
|
||||
let param = match serial.deserialize::<RequestParams<P>>(body) {
|
||||
Ok(p) => p,
|
||||
Err(err) => {
|
||||
return (
|
||||
None,
|
||||
future::ready(id.map(|id| {
|
||||
serial.serialize(&ErrorResponse {
|
||||
id,
|
||||
error: ResponseError {
|
||||
code: 0,
|
||||
message: format!("{:?}", err),
|
||||
},
|
||||
})
|
||||
}))
|
||||
.boxed(),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let callback = callback.clone();
|
||||
let serial = serial.clone();
|
||||
let context = context.clone();
|
||||
let stream_id = next_message_id();
|
||||
let (client, server) = tokio::io::duplex(8192);
|
||||
|
||||
let fut = async move {
|
||||
match callback(server, param.params, context).await {
|
||||
Ok(r) => id.map(|id| serial.serialize(&SuccessResponse { id, result: r })),
|
||||
Err(err) => id.map(|id| {
|
||||
serial.serialize(&ErrorResponse {
|
||||
id,
|
||||
error: ResponseError {
|
||||
code: -1,
|
||||
message: format!("{:?}", err),
|
||||
},
|
||||
})
|
||||
}),
|
||||
}
|
||||
};
|
||||
|
||||
(
|
||||
Some(StreamDto {
|
||||
req_id: id.unwrap_or(0),
|
||||
stream_id,
|
||||
duplex: client,
|
||||
}),
|
||||
fut.boxed(),
|
||||
)
|
||||
})),
|
||||
);
|
||||
}
|
||||
|
||||
/// Builds into a usable, sync rpc dispatcher.
|
||||
pub fn build(self, log: log::Logger) -> RpcDispatcher<S, C> {
|
||||
pub fn build(mut self, log: log::Logger) -> RpcDispatcher<S, C> {
|
||||
let streams: Arc<tokio::sync::Mutex<HashMap<u32, WriteHalf<DuplexStream>>>> =
|
||||
Arc::new(tokio::sync::Mutex::new(HashMap::new()));
|
||||
|
||||
let s1 = streams.clone();
|
||||
self.register_async(METHOD_STREAM_ENDED, move |m: StreamEndedParams, _| {
|
||||
let s1 = s1.clone();
|
||||
async move {
|
||||
s1.lock().await.remove(&m.stream);
|
||||
Ok(())
|
||||
}
|
||||
});
|
||||
|
||||
let s2 = streams.clone();
|
||||
self.register_async(METHOD_STREAM_DATA, move |m: StreamDataIncomingParams, _| {
|
||||
let s2 = s2.clone();
|
||||
async move {
|
||||
let mut lock = s2.lock().await;
|
||||
if let Some(stream) = lock.get_mut(&m.stream) {
|
||||
let _ = stream.write_all(&m.segment).await;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
});
|
||||
|
||||
RpcDispatcher {
|
||||
log,
|
||||
context: self.context,
|
||||
calls: self.calls,
|
||||
serializer: self.serializer,
|
||||
methods: Arc::new(self.methods),
|
||||
streams,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -281,6 +387,7 @@ pub struct RpcDispatcher<S, C> {
|
||||
serializer: Arc<S>,
|
||||
methods: Arc<HashMap<&'static str, Method>>,
|
||||
calls: Arc<Mutex<HashMap<u32, DispatchMethod>>>,
|
||||
streams: Arc<tokio::sync::Mutex<HashMap<u32, WriteHalf<DuplexStream>>>>,
|
||||
}
|
||||
|
||||
static MESSAGE_ID_COUNTER: AtomicU32 = AtomicU32::new(0);
|
||||
@@ -310,6 +417,7 @@ impl<S: Serialization, C: Send + Sync> RpcDispatcher<S, C> {
|
||||
match method {
|
||||
Some(Method::Sync(callback)) => MaybeSync::Sync(callback(id, body)),
|
||||
Some(Method::Async(callback)) => MaybeSync::Future(callback(id, body)),
|
||||
Some(Method::Duplex(callback)) => MaybeSync::Stream(callback(id, body)),
|
||||
None => MaybeSync::Sync(id.map(|id| {
|
||||
self.serializer.serialize(&ErrorResponse {
|
||||
id,
|
||||
@@ -333,11 +441,91 @@ impl<S: Serialization, C: Send + Sync> RpcDispatcher<S, C> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Registers a stream call returned from dispatch().
|
||||
pub async fn register_stream(
|
||||
&self,
|
||||
write_tx: mpsc::Sender<impl 'static + From<Vec<u8>> + Send>,
|
||||
dto: StreamDto,
|
||||
) {
|
||||
let stream_id = dto.stream_id;
|
||||
let for_request_id = dto.req_id;
|
||||
let (mut read, write) = tokio::io::split(dto.duplex);
|
||||
let serial = self.serializer.clone();
|
||||
|
||||
self.streams.lock().await.insert(dto.stream_id, write);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let r = write_tx
|
||||
.send(
|
||||
serial
|
||||
.serialize(&FullRequest {
|
||||
id: None,
|
||||
method: METHOD_STREAM_STARTED,
|
||||
params: DuplexStreamStarted {
|
||||
stream_id,
|
||||
for_request_id,
|
||||
},
|
||||
})
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
|
||||
if r.is_err() {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut buf = Vec::with_capacity(4096);
|
||||
loop {
|
||||
match read.read_buf(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => {
|
||||
let r = write_tx
|
||||
.send(
|
||||
serial
|
||||
.serialize(&FullRequest {
|
||||
id: None,
|
||||
method: METHOD_STREAM_DATA,
|
||||
params: StreamDataParams {
|
||||
segment: &buf[..n],
|
||||
stream: stream_id,
|
||||
},
|
||||
})
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
|
||||
if r.is_err() {
|
||||
return;
|
||||
}
|
||||
|
||||
buf.truncate(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let _ = write_tx
|
||||
.send(
|
||||
serial
|
||||
.serialize(&FullRequest {
|
||||
id: None,
|
||||
method: METHOD_STREAM_ENDED,
|
||||
params: StreamEndedParams { stream: stream_id },
|
||||
})
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
});
|
||||
}
|
||||
|
||||
pub fn context(&self) -> Arc<C> {
|
||||
self.context.clone()
|
||||
}
|
||||
}
|
||||
|
||||
const METHOD_STREAM_STARTED: &str = "stream_started";
|
||||
const METHOD_STREAM_DATA: &str = "stream_data";
|
||||
const METHOD_STREAM_ENDED: &str = "stream_ended";
|
||||
|
||||
trait AssertIsSync: Sync {}
|
||||
impl<S: Serialization, C: Send + Sync> AssertIsSync for RpcDispatcher<S, C> {}
|
||||
|
||||
@@ -349,6 +537,25 @@ struct PartialIncoming {
|
||||
pub error: Option<ResponseError>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct StreamDataIncomingParams {
|
||||
#[serde(with = "serde_bytes")]
|
||||
pub segment: Vec<u8>,
|
||||
pub stream: u32,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct StreamDataParams<'a> {
|
||||
#[serde(with = "serde_bytes")]
|
||||
pub segment: &'a [u8],
|
||||
pub stream: u32,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct StreamEndedParams {
|
||||
pub stream: u32,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct FullRequest<M: AsRef<str>, P> {
|
||||
pub id: Option<u32>,
|
||||
@@ -384,7 +591,14 @@ enum Outcome {
|
||||
Error(ResponseError),
|
||||
}
|
||||
|
||||
pub struct StreamDto {
|
||||
stream_id: u32,
|
||||
req_id: u32,
|
||||
duplex: DuplexStream,
|
||||
}
|
||||
|
||||
pub enum MaybeSync {
|
||||
Stream((Option<StreamDto>, BoxFuture<'static, Option<Vec<u8>>>)),
|
||||
Future(BoxFuture<'static, Option<Vec<u8>>>),
|
||||
Sync(Option<Vec<u8>>),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user