diff --git a/cli/src/json_rpc.rs b/cli/src/json_rpc.rs index 38088e21589..9cc5ad1ade1 100644 --- a/cli/src/json_rpc.rs +++ b/cli/src/json_rpc.rs @@ -33,10 +33,12 @@ impl Serialization for JsonRpcSerializer { } /// Creates a new RPC Builder that serializes to JSON. +#[allow(dead_code)] pub fn new_json_rpc() -> rpc::RpcBuilder { rpc::RpcBuilder::new(JsonRpcSerializer {}) } +#[allow(dead_code)] pub async fn start_json_rpc( dispatcher: rpc::RpcDispatcher, read: impl AsyncRead + Unpin, diff --git a/cli/src/lib.rs b/cli/src/lib.rs index ff3d3662853..fd0917843ba 100644 --- a/cli/src/lib.rs +++ b/cli/src/lib.rs @@ -20,3 +20,4 @@ pub mod util; mod rpc; mod json_rpc; +mod msgpack_rpc; diff --git a/cli/src/msgpack_rpc.rs b/cli/src/msgpack_rpc.rs new file mode 100644 index 00000000000..535fb31bbeb --- /dev/null +++ b/cli/src/msgpack_rpc.rs @@ -0,0 +1,80 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +use tokio::{ + io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader}, + sync::mpsc, +}; + +use crate::{ + rpc::{self, MaybeSync, Serialization}, + util::errors::{AnyError, InvalidRpcDataError}, +}; +use std::io; + +#[derive(Copy, Clone)] +pub struct MsgPackSerializer {} + +impl Serialization for MsgPackSerializer { + fn serialize(&self, value: impl serde::Serialize) -> Vec { + rmp_serde::to_vec_named(&value).expect("expected to serialize") + } + + fn deserialize(&self, b: &[u8]) -> Result { + rmp_serde::from_slice(b).map_err(|e| InvalidRpcDataError(e.to_string()).into()) + } +} + +pub type MsgPackCaller = rpc::RpcCaller; + +/// Creates a new RPC Builder that serializes to JSON. +pub fn new_msgpack_rpc() -> rpc::RpcBuilder { + rpc::RpcBuilder::new(MsgPackSerializer {}) +} + +pub async fn start_msgpack_rpc( + dispatcher: rpc::RpcDispatcher, + read: impl AsyncRead + Unpin, + mut write: impl AsyncWrite + Unpin, + mut msg_rx: mpsc::UnboundedReceiver>, + mut shutdown_rx: mpsc::UnboundedReceiver, +) -> io::Result> { + let (write_tx, mut write_rx) = mpsc::unbounded_channel::>(); + let mut read = BufReader::new(read); + let mut decode_buf = vec![]; + + loop { + tokio::select! { + u = read.read_u32() => { + let msg_length = u? as usize; + decode_buf.resize(msg_length, 0); + tokio::select! { + r = read.read_exact(&mut decode_buf) => match dispatcher.dispatch(&decode_buf[..r?]) { + MaybeSync::Sync(Some(v)) => { + write_tx.send(v).ok(); + }, + MaybeSync::Sync(None) => continue, + MaybeSync::Future(fut) => { + let write_tx = write_tx.clone(); + tokio::spawn(async move { + if let Some(v) = fut.await { + write_tx.send(v).ok(); + } + }); + } + }, + r = shutdown_rx.recv() => return Ok(r), + }; + }, + Some(m) = write_rx.recv() => { + write.write_all(&m).await?; + }, + Some(m) = msg_rx.recv() => { + write.write_all(&m).await?; + }, + r = shutdown_rx.recv() => return Ok(r), + } + } +} diff --git a/cli/src/tunnels.rs b/cli/src/tunnels.rs index 0f112c74734..683df012248 100644 --- a/cli/src/tunnels.rs +++ b/cli/src/tunnels.rs @@ -9,7 +9,6 @@ pub mod legal; pub mod paths; pub mod shutdown_signal; -mod wsl_server; mod control_server; mod name_generator; mod port_forwarder; @@ -17,6 +16,7 @@ mod protocol; #[cfg_attr(unix, path = "tunnels/server_bridge_unix.rs")] #[cfg_attr(windows, path = "tunnels/server_bridge_windows.rs")] mod server_bridge; +mod server_multiplexer; mod service; #[cfg(target_os = "linux")] mod service_linux; @@ -25,9 +25,10 @@ mod service_macos; #[cfg(target_os = "windows")] mod service_windows; mod socket_signal; +mod wsl_server; pub use control_server::serve; -pub use wsl_server::serve_wsl; pub use service::{ create_service_manager, ServiceContainer, ServiceManager, SERVICE_LOG_FILE_NAME, }; +pub use wsl_server::serve_wsl; diff --git a/cli/src/tunnels/control_server.rs b/cli/src/tunnels/control_server.rs index cf14b49f5d6..c6fa03f5e56 100644 --- a/cli/src/tunnels/control_server.rs +++ b/cli/src/tunnels/control_server.rs @@ -48,20 +48,15 @@ use super::protocol::{ VersionParams, }; use super::server_bridge::{get_socket_rw_stream, ServerBridge}; +use super::server_multiplexer::ServerMultiplexer; use super::shutdown_signal::ShutdownSignal; -use super::socket_signal::{ClientMessageDecoder, ServerMessageSink, SocketSignal}; +use super::socket_signal::{ + ClientMessageDecoder, ServerMessageDestination, ServerMessageSink, SocketSignal, +}; -type ServerBridgeListLock = Arc>>>; type HttpRequestsMap = Arc>>; type CodeServerCell = Arc>>; -struct ServerBridgeRec { - id: u16, - // bridge is removed when there's a write loop currently active - bridge: Option, - write_queue: Vec>, -} - struct HandlerContext { /// Log handle for the server log: log::Logger, @@ -74,7 +69,7 @@ struct HandlerContext { /// Connected VS Code Server code_server: CodeServerCell, /// Potentially many "websocket" connections to client - server_bridges: ServerBridgeListLock, + server_bridges: ServerMultiplexer, // the cli arguments used to start the code server code_server_args: CodeServerArgs, /// port forwarding functionality @@ -96,28 +91,7 @@ pub fn next_message_id() -> u32 { impl HandlerContext { async fn dispose(&self) { - let bridges = { - let mut lock = self.server_bridges.lock().unwrap(); - lock.take() - }; - - let bridges = match bridges { - Some(b) => b, - None => return, - }; - - for rec in bridges { - if let Some(b) = rec.bridge { - if let Err(e) = b.close().await { - warning!( - self.log, - "Could not properly dispose of connection context: {}", - e - ) - } - } - } - + self.server_bridges.dispose().await; info!(self.log, "Disposed of connection to running server."); } } @@ -295,7 +269,7 @@ async fn process_socket( let (socket_tx, mut socket_rx) = mpsc::channel(4); let rx_counter = Arc::new(AtomicUsize::new(0)); let http_requests = Arc::new(std::sync::Mutex::new(HashMap::new())); - let server_bridges = Arc::new(std::sync::Mutex::new(Some(vec![]))); + let server_bridges = ServerMultiplexer::new(); let (http_delegated, mut http_rx) = DelegatedSimpleHttp::new(log.clone()); let mut rpc = RpcBuilder::new(MsgPackSerializer {}).methods(HandlerContext { did_update: Arc::new(AtomicBool::new(false)), @@ -426,13 +400,6 @@ async fn process_socket( debug!(log, "Closing connection: {}", reason.0); break; } - SocketSignal::CloseServerBridge(id) => { - let mut lock = server_bridges.lock().unwrap(); - match &mut *lock { - Some(bridges) => bridges.retain(|sb| sb.id != id), - None => {} - } - } } } } @@ -618,37 +585,34 @@ async fn attach_server_bridge( log: &log::Logger, code_server: SocketCodeServer, socket_tx: mpsc::Sender, - server_bridges: ServerBridgeListLock, + multiplexer: ServerMultiplexer, socket_id: u16, compress: bool, ) -> Result { let (server_messages, decoder) = if compress { ( - ServerMessageSink::new_compressed(socket_tx), + ServerMessageSink::new_compressed( + multiplexer.clone(), + socket_id, + ServerMessageDestination::Channel(socket_tx), + ), ClientMessageDecoder::new_compressed(), ) } else { ( - ServerMessageSink::new_plain(socket_tx), + ServerMessageSink::new_plain( + multiplexer.clone(), + socket_id, + ServerMessageDestination::Channel(socket_tx), + ), ClientMessageDecoder::new_plain(), ) }; - let attached_fut = - ServerBridge::new(&code_server.socket, socket_id, server_messages, decoder).await; - + let attached_fut = ServerBridge::new(&code_server.socket, server_messages, decoder).await; match attached_fut { Ok(a) => { - let mut lock = server_bridges.lock().unwrap(); - let bridge_rec = ServerBridgeRec { - id: socket_id, - bridge: Some(a), - write_queue: vec![], - }; - match &mut *lock { - Some(server_bridges) => (*server_bridges).push(bridge_rec), - None => *lock = Some(vec![bridge_rec]), - } + multiplexer.register(socket_id, a); trace!(log, "Attached to server"); Ok(socket_id) } @@ -660,71 +624,14 @@ async fn attach_server_bridge( /// to ensure message order is preserved exactly, which is necessary for compression. fn handle_server_message( log: &log::Logger, - bridges_lock: &ServerBridgeListLock, + multiplexer: &ServerMultiplexer, params: ServerMessageParams, ) -> Result { - let mut lock = bridges_lock.lock().unwrap(); - - match &mut *lock { - Some(server_bridges) => match server_bridges.iter_mut().find(|b| b.id == params.i) { - Some(sb) => { - sb.write_queue.push(params.body); - if let Some(bridge) = sb.bridge.take() { - let bridges_lock = bridges_lock.clone(); - let log = log.clone(); - tokio::spawn(start_bridge_write_loop(log, sb.id, bridge, bridges_lock)); - } - } - None => return Err(AnyError::from(NoAttachedServerError())), - }, - None => return Err(AnyError::from(NoAttachedServerError())), + if multiplexer.write_message(log, params.i, params.body) { + Ok(EmptyObject {}) + } else { + Err(AnyError::from(NoAttachedServerError())) } - - Ok(EmptyObject {}) -} - -/// Write loop started by `handle_server_message`. It take sthe ServerBridge, and -/// runs until there's no more items in the 'write queue'. At that point, if the -/// record still exists in the bridges_lock (i.e. we haven't shut down), it'll -/// return the ServerBridge so that the next handle_server_message call starts -/// the loop again. Otherwise, it'll close the bridge. -async fn start_bridge_write_loop( - log: log::Logger, - id: u16, - mut bridge: ServerBridge, - bridges_lock: ServerBridgeListLock, -) { - let mut items_vec = vec![]; - loop { - { - let mut lock = bridges_lock.lock().unwrap(); - let server_bridges = match &mut *lock { - Some(sb) => sb, - None => break, - }; - - let bridge_rec = match server_bridges.iter_mut().find(|b| id == b.id) { - Some(b) => b, - None => break, - }; - - if bridge_rec.write_queue.is_empty() { - bridge_rec.bridge = Some(bridge); - return; - } - - std::mem::swap(&mut bridge_rec.write_queue, &mut items_vec); - } - - for item in items_vec.drain(..) { - if let Err(e) = bridge.write(item).await { - warning!(log, "Error writing to server: {:?}", e); - break; - } - } - } - - bridge.close().await.ok(); // got here from `break` above, meaning our record got cleared. Close the bridge if so } fn handle_prune(paths: &LauncherPaths) -> Result, AnyError> { diff --git a/cli/src/tunnels/protocol.rs b/cli/src/tunnels/protocol.rs index cc271323791..b10bba7a32a 100644 --- a/cli/src/tunnels/protocol.rs +++ b/cli/src/tunnels/protocol.rs @@ -57,16 +57,9 @@ pub struct ForwardResult { /// The `install_local` method in the wsl control server #[derive(Deserialize, Debug)] pub struct InstallFromLocalFolderParams { - pub commit_id: String, - pub quality: Quality, pub archive_path: String, - #[serde(default)] - pub extensions: Vec, -} - -#[derive(Serialize, Debug)] -pub struct InstallPortServerResult { - pub port: u16, + #[serde(flatten)] + pub inner: ServeParams, } #[derive(Deserialize, Debug)] diff --git a/cli/src/tunnels/server_bridge_unix.rs b/cli/src/tunnels/server_bridge_unix.rs index 9f06223ccbb..c7be34cf5d0 100644 --- a/cli/src/tunnels/server_bridge_unix.rs +++ b/cli/src/tunnels/server_bridge_unix.rs @@ -37,7 +37,6 @@ const BUFFER_SIZE: usize = 65536; impl ServerBridge { pub async fn new( path: &Path, - index: u16, mut target: ServerMessageSink, decoder: ClientMessageDecoder, ) -> Result { @@ -50,11 +49,10 @@ impl ServerBridge { match read.read(&mut read_buf).await { Err(_) => return, Ok(0) => { - let _ = target.closed_server_bridge(index).await; return; // EOF } Ok(s) => { - let send = target.server_message(index, &read_buf[..s]).await; + let send = target.server_message(&read_buf[..s]).await; if send.is_err() { return; } diff --git a/cli/src/tunnels/server_bridge_windows.rs b/cli/src/tunnels/server_bridge_windows.rs index c7ac242fa6c..ca604468518 100644 --- a/cli/src/tunnels/server_bridge_windows.rs +++ b/cli/src/tunnels/server_bridge_windows.rs @@ -49,7 +49,6 @@ pub async fn get_socket_rw_stream(path: &Path) -> Result Result { @@ -88,7 +87,7 @@ impl ServerBridge { match client.try_read(&mut read_buf) { Ok(0) => return, // EOF Ok(s) => { - let send = target.server_message(index, &read_buf[..s]).await; + let send = target.server_message(&read_buf[..s]).await; if send.is_err() { return; } diff --git a/cli/src/tunnels/server_multiplexer.rs b/cli/src/tunnels/server_multiplexer.rs new file mode 100644 index 00000000000..34782ff375b --- /dev/null +++ b/cli/src/tunnels/server_multiplexer.rs @@ -0,0 +1,145 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +use std::sync::Arc; + +use futures::future::join_all; + +use crate::log; + +use super::server_bridge::ServerBridge; + +type Inner = Arc>>>; + +struct ServerBridgeRec { + id: u16, + // bridge is removed when there's a write loop currently active + bridge: Option, + write_queue: Vec>, +} + +/// The ServerMultiplexer manages multiple server bridges and allows writing +/// to them in a thread-safe way. It is copy, sync, and clone. +#[derive(Clone)] +pub struct ServerMultiplexer { + inner: Inner, +} + +impl ServerMultiplexer { + pub fn new() -> Self { + Self { + inner: Arc::new(std::sync::Mutex::new(Some(Vec::new()))), + } + } + + /// Adds a new bridge to the multiplexer. + pub fn register(&self, id: u16, bridge: ServerBridge) { + let bridge_rec = ServerBridgeRec { + id, + bridge: Some(bridge), + write_queue: vec![], + }; + + let mut lock = self.inner.lock().unwrap(); + match &mut *lock { + Some(server_bridges) => (*server_bridges).push(bridge_rec), + None => *lock = Some(vec![bridge_rec]), + } + } + + /// Removes a server bridge by ID. + pub fn remove(&self, id: u16) { + let mut lock = self.inner.lock().unwrap(); + if let Some(bridges) = &mut *lock { + bridges.retain(|sb| sb.id != id); + } + } + + /// Handle an incoming server message. This is synchronous and uses a 'write loop' + /// to ensure message order is preserved exactly, which is necessary for compression. + /// Returns false if there was no server with the given bridge_id. + pub fn write_message(&self, log: &log::Logger, bridge_id: u16, message: Vec) -> bool { + let mut lock = self.inner.lock().unwrap(); + + let bridges = match &mut *lock { + Some(sb) => sb, + None => return false, + }; + + let record = match bridges.iter_mut().find(|b| b.id == bridge_id) { + Some(sb) => sb, + None => return false, + }; + + record.write_queue.push(message); + if let Some(bridge) = record.bridge.take() { + let bridges_lock = self.inner.clone(); + let log = log.clone(); + tokio::spawn(write_loop(log, record.id, bridge, bridges_lock)); + } + + true + } + + /// Disposes all running server bridges. + pub async fn dispose(&self) { + let bridges = { + let mut lock = self.inner.lock().unwrap(); + lock.take() + }; + + let bridges = match bridges { + Some(b) => b, + None => return, + }; + + join_all( + bridges + .into_iter() + .filter_map(|b| b.bridge) + .map(|b| b.close()), + ) + .await; + } +} + +/// Write loop started by `handle_server_message`. It take sthe ServerBridge, and +/// runs until there's no more items in the 'write queue'. At that point, if the +/// record still exists in the bridges_lock (i.e. we haven't shut down), it'll +/// return the ServerBridge so that the next handle_server_message call starts +/// the loop again. Otherwise, it'll close the bridge. +async fn write_loop(log: log::Logger, id: u16, mut bridge: ServerBridge, bridges_lock: Inner) { + let mut items_vec = vec![]; + loop { + { + let mut lock = bridges_lock.lock().unwrap(); + let server_bridges = match &mut *lock { + Some(sb) => sb, + None => break, + }; + + let bridge_rec = match server_bridges.iter_mut().find(|b| id == b.id) { + Some(b) => b, + None => break, + }; + + if bridge_rec.write_queue.is_empty() { + bridge_rec.bridge = Some(bridge); + return; + } + + std::mem::swap(&mut bridge_rec.write_queue, &mut items_vec); + } + + for item in items_vec.drain(..) { + if let Err(e) = bridge.write(item).await { + warning!(log, "Error writing to server: {:?}", e); + break; + } + } + } + + bridge.close().await.ok(); // got here from `break` above, meaning our record got cleared. Close the bridge if so +} diff --git a/cli/src/tunnels/socket_signal.rs b/cli/src/tunnels/socket_signal.rs index 95ed0bc3e0e..0445db96e8f 100644 --- a/cli/src/tunnels/socket_signal.rs +++ b/cli/src/tunnels/socket_signal.rs @@ -6,7 +6,12 @@ use serde::Serialize; use tokio::sync::mpsc; -use super::protocol::{ClientRequestMethod, RefServerMessageParams, ToClientRequest}; +use crate::msgpack_rpc::MsgPackCaller; + +use super::{ + protocol::{ClientRequestMethod, RefServerMessageParams, ToClientRequest}, + server_multiplexer::ServerMultiplexer, +}; pub struct CloseReason(pub String); @@ -15,8 +20,6 @@ pub enum SocketSignal { Send(Vec), /// Closes the socket (e.g. as a result of an error) CloseWith(CloseReason), - /// Disposes ServerBridge corresponding to an ID - CloseServerBridge(u16), } impl SocketSignal { @@ -28,20 +31,43 @@ impl SocketSignal { } } +/// todo@connor4312: cleanup once everything is moved to rpc standard interfaces +pub enum ServerMessageDestination { + Channel(mpsc::Sender), + Rpc(MsgPackCaller), +} + /// Struct that handling sending or closing a connected server socket. pub struct ServerMessageSink { - tx: mpsc::Sender, + id: u16, + tx: Option, + multiplexer: ServerMultiplexer, flate: Option>, } impl ServerMessageSink { - pub fn new_plain(tx: mpsc::Sender) -> Self { - Self { tx, flate: None } + pub fn new_plain( + multiplexer: ServerMultiplexer, + id: u16, + tx: ServerMessageDestination, + ) -> Self { + Self { + tx: Some(tx), + id, + multiplexer, + flate: None, + } } - pub fn new_compressed(tx: mpsc::Sender) -> Self { + pub fn new_compressed( + multiplexer: ServerMultiplexer, + id: u16, + tx: ServerMessageDestination, + ) -> Self { Self { - tx, + tx: Some(tx), + id, + multiplexer, flate: Some(FlateStream::new(CompressFlateAlgorithm( flate2::Compress::new(flate2::Compression::new(2), false), ))), @@ -50,18 +76,30 @@ impl ServerMessageSink { pub async fn server_message( &mut self, - i: u16, body: &[u8], ) -> Result<(), mpsc::error::SendError> { - let msg = { - let body = self.get_server_msg_content(body); - SocketSignal::from_message(&ToClientRequest { - id: None, - params: ClientRequestMethod::servermsg(RefServerMessageParams { i, body }), - }) + let id = self.id; + let mut tx = self.tx.take().unwrap(); + let body = self.get_server_msg_content(body); + let msg = RefServerMessageParams { i: id, body }; + + let r = match &mut tx { + ServerMessageDestination::Channel(tx) => { + tx.send(SocketSignal::from_message(&ToClientRequest { + id: None, + params: ClientRequestMethod::servermsg(msg), + })) + .await + } + ServerMessageDestination::Rpc(caller) => { + caller.notify("servermsg", msg); + Ok(()) + } }; - self.tx.send(msg).await + drop(body); + self.tx = Some(tx); + r } pub(crate) fn get_server_msg_content<'a: 'b, 'b>(&'a mut self, body: &'b [u8]) -> &'b [u8] { @@ -73,13 +111,11 @@ impl ServerMessageSink { body } +} - #[allow(dead_code)] - pub async fn closed_server_bridge( - &mut self, - i: u16, - ) -> Result<(), mpsc::error::SendError> { - self.tx.send(SocketSignal::CloseServerBridge(i)).await +impl Drop for ServerMessageSink { + fn drop(&mut self) { + self.multiplexer.remove(self.id); } } @@ -228,7 +264,11 @@ mod tests { #[test] fn test_round_trips_compression() { let (tx, _) = mpsc::channel(1); - let mut sink = ServerMessageSink::new_compressed(tx); + let mut sink = ServerMessageSink::new_compressed( + ServerMultiplexer::new(), + 0, + ServerMessageDestination::Channel(tx), + ); let mut decompress = ClientMessageDecoder::new_compressed(); // 3000 and 30000 test resizing the buffer diff --git a/cli/src/tunnels/wsl_server.rs b/cli/src/tunnels/wsl_server.rs index b2d2f35a319..22dab1a7fc0 100644 --- a/cli/src/tunnels/wsl_server.rs +++ b/cli/src/tunnels/wsl_server.rs @@ -3,27 +3,29 @@ * Licensed under the MIT License. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ -use std::path::PathBuf; - use tokio::sync::mpsc; use crate::{ - json_rpc::{new_json_rpc, start_json_rpc, JsonRpcSerializer}, log, - rpc::RpcCaller, + msgpack_rpc::{new_msgpack_rpc, start_msgpack_rpc, MsgPackCaller}, state::LauncherPaths, tunnels::code_server::ServerBuilder, update_service::{Platform, Release, TargetKind}, util::{ - errors::{wrap, AnyError, MismatchedLaunchModeError}, + errors::{ + wrap, AnyError, InvalidRpcDataError, MismatchedLaunchModeError, NoAttachedServerError, + }, http::ReqwestSimpleHttp, }, }; use super::{ code_server::{AnyCodeServer, CodeServerArgs, ResolvedServerParams}, - protocol::{InstallFromLocalFolderParams, InstallPortServerResult}, + protocol::{EmptyObject, InstallFromLocalFolderParams, ServerMessageParams}, + server_bridge::ServerBridge, + server_multiplexer::ServerMultiplexer, shutdown_signal::ShutdownSignal, + socket_signal::{ClientMessageDecoder, ServerMessageDestination, ServerMessageSink}, }; struct HandlerContext { @@ -32,12 +34,14 @@ struct HandlerContext { launcher_paths: LauncherPaths, platform: Platform, http: ReqwestSimpleHttp, + caller: MsgPackCaller, + multiplexer: ServerMultiplexer, } #[derive(Clone)] -struct JsonRpcLogSink(RpcCaller); +struct RpcLogSink(MsgPackCaller); -impl JsonRpcLogSink { +impl RpcLogSink { fn write_json(&self, level: String, message: &str) { self.0.notify( "log", @@ -49,7 +53,7 @@ impl JsonRpcLogSink { } } -impl log::LogSink for JsonRpcLogSink { +impl log::LogSink for RpcLogSink { fn write_log(&self, level: log::Level, _prefix: &str, message: &str) { self.write_json(level.to_string(), message); } @@ -68,24 +72,33 @@ pub async fn serve_wsl( shutdown_rx: mpsc::UnboundedReceiver, ) -> Result { let (caller_tx, caller_rx) = mpsc::unbounded_channel(); - let mut rpc = new_json_rpc(); + let mut rpc = new_msgpack_rpc(); let caller = rpc.get_caller(caller_tx); - let log = log.with_sink(JsonRpcLogSink(caller)); + let log = log.with_sink(RpcLogSink(caller.clone())); let mut rpc = rpc.methods(HandlerContext { log: log.clone(), + caller, code_server_args, launcher_paths, platform, + multiplexer: ServerMultiplexer::new(), http: ReqwestSimpleHttp::with_client(http), }); rpc.register_async( - "install_local", - move |params: InstallFromLocalFolderParams, c| async move { install_local(&c, params).await }, + "serve", + move |m: InstallFromLocalFolderParams, c| async move { handle_serve(&c, m).await }, ); + rpc.register_sync("servermsg", move |m: ServerMessageParams, c| { + if c.multiplexer.write_message(&c.log, m.i, m.body) { + Ok(EmptyObject {}) + } else { + Err(NoAttachedServerError().into()) + } + }); - start_json_rpc( + start_msgpack_rpc( rpc.build(log), tokio::io::stdin(), tokio::io::stdout(), @@ -98,35 +111,51 @@ pub async fn serve_wsl( Ok(0) } -async fn install_local( +async fn handle_serve( c: &HandlerContext, params: InstallFromLocalFolderParams, -) -> Result { +) -> Result { // fill params.extensions into code_server_args.install_extensions let mut csa = c.code_server_args.clone(); - csa.install_extensions.extend(params.extensions.into_iter()); + csa.install_extensions + .extend(params.inner.extensions.into_iter()); let resolved = ResolvedServerParams { code_server_args: csa, release: Release { name: String::new(), - commit: params.commit_id, + commit: params + .inner + .commit_id + .ok_or_else(|| InvalidRpcDataError("commit_id is required".to_string()))?, platform: c.platform, target: TargetKind::Server, - quality: params.quality, + quality: params.inner.quality, }, }; let sb = ServerBuilder::new(&c.log, &resolved, &c.launcher_paths, c.http.clone()); - - let s = match sb.get_running().await? { - Some(AnyCodeServer::Port(s)) => s, + let code_server = match sb.get_running().await? { + Some(AnyCodeServer::Socket(s)) => s, Some(_) => return Err(MismatchedLaunchModeError().into()), None => { - sb.setup(Some(PathBuf::from(params.archive_path))).await?; - sb.listen_on_port(0).await? + sb.setup(None).await?; + sb.listen_on_default_socket().await? } }; - Ok(InstallPortServerResult { port: s.port }) + let bridge = ServerBridge::new( + &code_server.socket, + ServerMessageSink::new_plain( + c.multiplexer.clone(), + params.inner.socket_id, + ServerMessageDestination::Rpc(c.caller.clone()), + ), + ClientMessageDecoder::new_plain(), + ) + .await?; + + c.multiplexer.register(params.inner.socket_id, bridge); + trace!(c.log, "Attached to server"); + Ok(EmptyObject {}) }