From 2c2ead679bfdfea755d6d383bab628fe1fb912bf Mon Sep 17 00:00:00 2001 From: Connor Peet Date: Tue, 17 Jan 2023 16:05:25 -0800 Subject: [PATCH] cli: initial wsl control server Adds an stdin/out json rpc server for wsl. Exposes a singular install_local command to install+boot the vscode server on a port from a local archive. Also refines the common rpc layer some more. I'm decently happy with it now. --- cli/Cargo.lock | 8 +- cli/Cargo.toml | 2 +- cli/src/bin/code/main.rs | 5 +- cli/src/commands.rs | 1 + cli/src/commands/args.rs | 16 ++++ cli/src/commands/internal_wsl.rs | 32 +++++++ cli/src/commands/tunnels.rs | 58 +++--------- cli/src/json_rpc.rs | 85 +++++++++++++++++ cli/src/lib.rs | 1 + cli/src/log.rs | 12 +++ cli/src/rpc.rs | 146 ++++++++++++++++++++--------- cli/src/tunnels.rs | 3 + cli/src/tunnels/code_server.rs | 51 ++++++++-- cli/src/tunnels/control_server.rs | 35 ++++--- cli/src/tunnels/protocol.rs | 15 +++ cli/src/tunnels/service.rs | 3 +- cli/src/tunnels/service_linux.rs | 10 +- cli/src/tunnels/service_macos.rs | 10 +- cli/src/tunnels/service_windows.rs | 3 +- cli/src/tunnels/shutdown_signal.rs | 63 +++++++++++++ cli/src/tunnels/wsl_server.rs | 132 ++++++++++++++++++++++++++ 21 files changed, 554 insertions(+), 137 deletions(-) create mode 100644 cli/src/commands/internal_wsl.rs create mode 100644 cli/src/json_rpc.rs create mode 100644 cli/src/tunnels/shutdown_signal.rs create mode 100644 cli/src/tunnels/wsl_server.rs diff --git a/cli/Cargo.lock b/cli/Cargo.lock index beb9c7d59ac..fb39d8bd026 100644 --- a/cli/Cargo.lock +++ b/cli/Cargo.lock @@ -1197,9 +1197,9 @@ dependencies = [ [[package]] name = "ntapi" -version = "0.3.7" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c28774a7fd2fbb4f0babd8237ce554b73af68021b5f695a3cebd6c59bac0980f" +checksum = "bc51db7b362b205941f71232e56c625156eb9a929f8cf74a428fd5bc094a4afc" dependencies = [ "winapi", ] @@ -2135,9 +2135,9 @@ dependencies = [ [[package]] name = "sysinfo" -version = "0.23.13" +version = "0.27.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3977ec2e0520829be45c8a2df70db2bf364714d8a748316a10c3c35d4d2b01c9" +checksum = "975fe381e0ecba475d4acff52466906d95b153a40324956552e027b2a9eaa89e" dependencies = [ "cfg-if", "core-foundation-sys", diff --git a/cli/Cargo.toml b/cli/Cargo.toml index 8c7de8ced10..c3d5d492b94 100644 --- a/cli/Cargo.toml +++ b/cli/Cargo.toml @@ -22,7 +22,7 @@ flate2 = { version = "1.0.22" } zip = { version = "0.5.13", default-features = false, features = ["time", "deflate"] } regex = { version = "1.5.5" } lazy_static = { version = "1.4.0" } -sysinfo = { version = "0.23.5" } +sysinfo = { version = "0.27.7" } serde = { version = "1.0", features = ["derive"] } serde_json = { version = "1.0" } rmp-serde = "1.0" diff --git a/cli/src/bin/code/main.rs b/cli/src/bin/code/main.rs index e9e440bafd9..6bc991b4f5a 100644 --- a/cli/src/bin/code/main.rs +++ b/cli/src/bin/code/main.rs @@ -8,7 +8,7 @@ use std::process::Command; use clap::Parser; use cli::{ - commands::{args, tunnels, update, version, CommandContext}, + commands::{args, internal_wsl, tunnels, update, version, CommandContext}, constants::get_default_user_agent, desktop, log as own_log, state::LauncherPaths, @@ -58,6 +58,9 @@ async fn main() -> Result<(), std::convert::Infallible> { .. }) => match cmd { args::StandaloneCommands::Update(args) => update::update(context, args).await, + args::StandaloneCommands::Wsl(args) => match args.command { + args::WslCommands::Serve => internal_wsl::serve(context).await, + }, }, args::AnyCli::Standalone(args::StandaloneCli { core: c, .. }) | args::AnyCli::Integrated(args::IntegratedCli { core: c, .. }) => match c.subcommand { diff --git a/cli/src/commands.rs b/cli/src/commands.rs index 754729f2c04..32b1ac3592b 100644 --- a/cli/src/commands.rs +++ b/cli/src/commands.rs @@ -9,4 +9,5 @@ pub mod args; pub mod tunnels; pub mod update; pub mod version; +pub mod internal_wsl; pub use context::CommandContext; diff --git a/cli/src/commands/args.rs b/cli/src/commands/args.rs index 8bffd551e32..96f27ce7b72 100644 --- a/cli/src/commands/args.rs +++ b/cli/src/commands/args.rs @@ -146,6 +146,22 @@ impl<'a> From<&'a CliCore> for CodeServerArgs { pub enum StandaloneCommands { /// Updates the CLI. Update(StandaloneUpdateArgs), + + /// Internal commands for WSL serving. + #[clap(hide = true)] + Wsl(WslArgs), +} + +#[derive(Args, Debug, Clone)] +pub struct WslArgs { + #[clap(subcommand)] + pub command: WslCommands, +} + +#[derive(Subcommand, Debug, Clone)] +pub enum WslCommands { + /// Runs the WSL server on stdin/out + Serve, } #[derive(Args, Debug, Clone)] diff --git a/cli/src/commands/internal_wsl.rs b/cli/src/commands/internal_wsl.rs new file mode 100644 index 00000000000..9912b59428b --- /dev/null +++ b/cli/src/commands/internal_wsl.rs @@ -0,0 +1,32 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +use crate::{ + tunnels::{serve_wsl, shutdown_signal::ShutdownSignal}, + util::{errors::AnyError, prereqs::PreReqChecker}, +}; + +use super::CommandContext; + +pub async fn serve(ctx: CommandContext) -> Result { + let signal = ShutdownSignal::create_rx(&[ShutdownSignal::CtrlC]); + let platform = spanf!( + ctx.log, + ctx.log.span("prereq"), + PreReqChecker::new().verify() + )?; + + serve_wsl( + ctx.log, + ctx.paths, + (&ctx.args).into(), + platform, + ctx.http, + signal, + ) + .await?; + + Ok(0) +} diff --git a/cli/src/commands/tunnels.rs b/cli/src/commands/tunnels.rs index 5b0bc0d38fc..c59006ca50a 100644 --- a/cli/src/commands/tunnels.rs +++ b/cli/src/commands/tunnels.rs @@ -5,11 +5,9 @@ use async_trait::async_trait; use sha2::{Digest, Sha256}; -use std::fmt; use std::str::FromStr; -use sysinfo::{Pid, SystemExt}; +use sysinfo::Pid; use tokio::sync::mpsc; -use tokio::time::{sleep, Duration}; use super::{ args::{ @@ -20,6 +18,7 @@ use super::{ }; use crate::tunnels::dev_tunnels::ActiveTunnel; +use crate::tunnels::shutdown_signal::ShutdownSignal; use crate::{ auth::Auth, log::{self, Logger}, @@ -93,22 +92,6 @@ impl ServiceContainer for TunnelServiceContainer { Ok(()) } } -/// Describes the signal to manully stop the server -pub enum ShutdownSignal { - CtrlC, - ParentProcessKilled, - ServiceStopped, -} - -impl fmt::Display for ShutdownSignal { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - ShutdownSignal::CtrlC => write!(f, "Ctrl-C received"), - ShutdownSignal::ParentProcessKilled => write!(f, "Parent process no longer exists"), - ShutdownSignal::ServiceStopped => write!(f, "Service stopped"), - } - } -} pub async fn service( ctx: CommandContext, @@ -269,32 +252,17 @@ async fn serve_with_csa( let shutdown_tx = if let Some(tx) = shutdown_rx { tx - } else { - let (tx, rx) = mpsc::unbounded_channel::(); - if let Some(process_id) = gateway_args.parent_process_id { - match Pid::from_str(&process_id) { - Ok(pid) => { - let tx = tx.clone(); - info!(log, "checking for parent process {}", process_id); - tokio::spawn(async move { - let mut s = sysinfo::System::new(); - while s.refresh_process(pid) { - sleep(Duration::from_millis(2000)).await; - } - tx.send(ShutdownSignal::ParentProcessKilled).ok(); - }); - } - Err(_) => { - info!(log, "invalid parent process id: {}", process_id); - } - } - } - tokio::spawn(async move { - tokio::signal::ctrl_c().await.ok(); - tx.send(ShutdownSignal::CtrlC).ok(); - }); - rx - }; + } else if let Some(pid) = gateway_args + .parent_process_id + .and_then(|p| Pid::from_str(&p).ok()) + { + ShutdownSignal::create_rx(&[ + ShutdownSignal::CtrlC, + ShutdownSignal::ParentProcessKilled(pid), + ]) + } else { + ShutdownSignal::create_rx(&[ShutdownSignal::CtrlC]) + }; let mut r = crate::tunnels::serve(&log, tunnel, &paths, &csa, platform, shutdown_tx).await?; r.tunnel.close().await.ok(); diff --git a/cli/src/json_rpc.rs b/cli/src/json_rpc.rs new file mode 100644 index 00000000000..38088e21589 --- /dev/null +++ b/cli/src/json_rpc.rs @@ -0,0 +1,85 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +use tokio::{ + io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}, + sync::mpsc, +}; + +use crate::{ + rpc::{self, MaybeSync, Serialization}, + util::errors::InvalidRpcDataError, +}; +use std::io; + +#[derive(Clone)] +pub struct JsonRpcSerializer {} + +impl Serialization for JsonRpcSerializer { + fn serialize(&self, value: impl serde::Serialize) -> Vec { + let mut v = serde_json::to_vec(&value).unwrap(); + v.push(b'\n'); + v + } + + fn deserialize( + &self, + b: &[u8], + ) -> Result { + serde_json::from_slice(b).map_err(|e| InvalidRpcDataError(e.to_string()).into()) + } +} + +/// Creates a new RPC Builder that serializes to JSON. +pub fn new_json_rpc() -> rpc::RpcBuilder { + rpc::RpcBuilder::new(JsonRpcSerializer {}) +} + +pub async fn start_json_rpc( + dispatcher: rpc::RpcDispatcher, + read: impl AsyncRead + Unpin, + mut write: impl AsyncWrite + Unpin, + mut msg_rx: mpsc::UnboundedReceiver>, + mut shutdown_rx: mpsc::UnboundedReceiver, +) -> io::Result> { + let (write_tx, mut write_rx) = mpsc::unbounded_channel::>(); + let mut read = BufReader::new(read); + + let mut read_buf = String::new(); + + loop { + tokio::select! { + r = shutdown_rx.recv() => return Ok(r), + Some(w) = write_rx.recv() => { + write.write_all(&w).await?; + }, + Some(w) = msg_rx.recv() => { + write.write_all(&w).await?; + }, + n = read.read_line(&mut read_buf) => { + let r = match n { + Ok(0) => return Ok(None), + Ok(n) => dispatcher.dispatch(read_buf[..n].as_bytes()), + Err(e) => return Err(e) + }; + + match r { + MaybeSync::Sync(Some(v)) => { + write_tx.send(v).ok(); + }, + MaybeSync::Sync(None) => continue, + MaybeSync::Future(fut) => { + let write_tx = write_tx.clone(); + tokio::spawn(async move { + if let Some(v) = fut.await { + write_tx.send(v).ok(); + } + }); + } + } + } + } + } +} diff --git a/cli/src/lib.rs b/cli/src/lib.rs index 3186cf6a6bf..ff3d3662853 100644 --- a/cli/src/lib.rs +++ b/cli/src/lib.rs @@ -19,3 +19,4 @@ pub mod update_service; pub mod util; mod rpc; +mod json_rpc; diff --git a/cli/src/log.rs b/cli/src/log.rs index b4d291c5708..a008a1ba06d 100644 --- a/cli/src/log.rs +++ b/cli/src/log.rs @@ -135,6 +135,7 @@ impl Clone for Box { } } +/// The basic log sink that writes output to stdout, with colors when relevant. #[derive(Clone)] pub struct StdioLogSink { level: Level, @@ -247,6 +248,17 @@ impl Logger { } } + /// Creates a new logger with the sink replace with the given sink. + pub fn with_sink(&self, sink: T) -> Logger + where + T: LogSink + 'static, + { + Logger { + sink: vec![Box::new(sink)], + ..self.clone() + } + } + pub fn get_download_logger<'a>(&'a self, prefix: &'static str) -> DownloadLogger<'a> { DownloadLogger { prefix, diff --git a/cli/src/rpc.rs b/cli/src/rpc.rs index 1584b62a7af..2249e5b7cf1 100644 --- a/cli/src/rpc.rs +++ b/cli/src/rpc.rs @@ -5,18 +5,17 @@ use std::{ collections::HashMap, + future, sync::{ atomic::{AtomicU32, Ordering}, Arc, Mutex, }, }; -use futures::{ - future::{self, BoxFuture}, - Future, FutureExt, -}; +use crate::log; +use futures::{future::BoxFuture, Future, FutureExt}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use tokio::sync::oneshot; +use tokio::sync::{mpsc, oneshot}; use crate::util::errors::AnyError; @@ -38,22 +37,51 @@ pub trait Serialization: Send + Sync + 'static { /// RPC is a basic, transport-agnostic builder for RPC methods. You can /// register methods to it, then call `.build()` to get a "dispatcher" type. -pub struct RpcBuilder { - context: Arc, +pub struct RpcBuilder { serializer: Arc, methods: HashMap<&'static str, Method>, + calls: Arc>>, } -impl RpcBuilder { +impl RpcBuilder { /// Creates a new empty RPC builder. - pub fn new(serializer: S, context: C) -> Self { + pub fn new(serializer: S) -> Self { Self { - context: Arc::new(context), serializer: Arc::new(serializer), methods: HashMap::new(), + calls: Arc::new(std::sync::Mutex::new(HashMap::new())), } } + /// Creates a caller that will be connected to any eventual dispatchers, + /// and that sends data to the "tx" channel. + pub fn get_caller(&mut self, sender: mpsc::UnboundedSender>) -> RpcCaller { + RpcCaller { + serializer: self.serializer.clone(), + calls: self.calls.clone(), + sender, + } + } + + /// Gets a method builder. + pub fn methods(self, context: C) -> RpcMethodBuilder { + RpcMethodBuilder { + context: Arc::new(context), + serializer: self.serializer, + methods: self.methods, + calls: self.calls, + } + } +} + +pub struct RpcMethodBuilder { + context: Arc, + serializer: Arc, + methods: HashMap<&'static str, Method>, + calls: Arc>>, +} + +impl RpcMethodBuilder { /// Registers a synchronous rpc call that returns its result directly. pub fn register_sync(&mut self, method_name: &'static str, callback: F) where @@ -152,10 +180,13 @@ impl RpcBuilder { } /// Builds into a usable, sync rpc dispatcher. - pub fn build(self) -> RpcDispatcher { + pub fn build(self, log: log::Logger) -> RpcDispatcher { RpcDispatcher { - i: Arc::new(self), - calls: Arc::new(std::sync::Mutex::new(HashMap::new())), + log, + context: self.context, + calls: self.calls, + serializer: self.serializer, + methods: Arc::new(self.methods), } } } @@ -163,37 +194,38 @@ impl RpcBuilder { type DispatchMethod = Box; /// Dispatcher returned from a Builder that provides a transport-agnostic way to -/// deserialize and handle RPC calls. This structure may get more advanced as +/// deserialize and dispatch RPC calls. This structure may get more advanced as /// time goes on... -pub struct RpcDispatcher { - i: Arc>, +#[derive(Clone)] +pub struct RpcCaller { + serializer: Arc, calls: Arc>>, + sender: mpsc::UnboundedSender>, } -impl Clone for RpcDispatcher { - fn clone(&self) -> Self { - RpcDispatcher { - i: self.i.clone(), - calls: self.calls.clone(), - } +impl RpcCaller { + /// Enqueues an outbound call. Returns whether the message was enqueued. + pub fn notify(&self, method: M, params: A) -> bool + where + M: Into, + A: Serialize, + { + let body = self.serializer.serialize(&FullRequest { + id: None, + method: method.into(), + params, + }); + + self.sender.send(body).is_ok() } -} -static MESSAGE_ID_COUNTER: AtomicU32 = AtomicU32::new(0); -fn next_message_id() -> u32 { - MESSAGE_ID_COUNTER.fetch_add(1, Ordering::SeqCst) -} -trait AssertIsSync: Sync {} -impl AssertIsSync for RpcDispatcher {} - -impl RpcDispatcher { - /// Enqueues an outbound call, returning the bytes that should be sent to make it run. + /// Enqueues an outbound call, returning its result. #[allow(dead_code)] pub async fn call( &self, method: M, params: A, - ) -> (Vec, oneshot::Receiver>) + ) -> oneshot::Receiver> where M: Into, A: Serialize, @@ -201,13 +233,18 @@ impl RpcDispatcher { { let (tx, rx) = oneshot::channel(); let id = next_message_id(); - let body = self.i.serializer.serialize(&FullRequest { + let body = self.serializer.serialize(&FullRequest { id: Some(id), method: method.into(), params, }); - let serializer = self.i.serializer.clone(); + if self.sender.send(body).is_err() { + drop(tx); + return rx; + } + + let serializer = self.serializer.clone(); self.calls.lock().unwrap().insert( id, Box::new(move |body| { @@ -226,9 +263,28 @@ impl RpcDispatcher { }), ); - (body, rx) + rx } +} +/// Dispatcher returned from a Builder that provides a transport-agnostic way to +/// deserialize and handle RPC calls. This structure may get more advanced as +/// time goes on... +#[derive(Clone)] +pub struct RpcDispatcher { + log: log::Logger, + context: Arc, + serializer: Arc, + methods: Arc>, + calls: Arc>>, +} + +static MESSAGE_ID_COUNTER: AtomicU32 = AtomicU32::new(0); +fn next_message_id() -> u32 { + MESSAGE_ID_COUNTER.fetch_add(1, Ordering::SeqCst) +} + +impl RpcDispatcher { /// Runs the incoming request, returning the result of the call synchronously /// or in a future. (The caller can then decide whether to run the future /// sequentially in its receive loop, or not.) @@ -236,19 +292,22 @@ impl RpcDispatcher { /// The future or return result will be optional bytes that should be sent /// back to the socket. pub fn dispatch(&self, body: &[u8]) -> MaybeSync { - let partial = match self.i.serializer.deserialize::(body) { + let partial = match self.serializer.deserialize::(body) { Ok(b) => b, - Err(_err) => return MaybeSync::Sync(None), + Err(_err) => { + warning!(self.log, "Failed to deserialize request, hex: {:X?}", body); + return MaybeSync::Sync(None); + } }; let id = partial.id; if let Some(method_name) = partial.method { - let method = self.i.methods.get(method_name.as_str()); + let method = self.methods.get(method_name.as_str()); match method { Some(Method::Sync(callback)) => MaybeSync::Sync(callback(id, body)), Some(Method::Async(callback)) => MaybeSync::Future(callback(id, body)), None => MaybeSync::Sync(id.map(|id| { - self.i.serializer.serialize(&ErrorResponse { + self.serializer.serialize(&ErrorResponse { id, error: ResponseError { code: -1, @@ -273,10 +332,13 @@ impl RpcDispatcher { } pub fn context(&self) -> Arc { - self.i.context.clone() + self.context.clone() } } +trait AssertIsSync: Sync {} +impl AssertIsSync for RpcDispatcher {} + /// Approximate shape that is used to determine what kind of data is incoming. #[derive(Deserialize)] struct PartialIncoming { @@ -287,7 +349,7 @@ struct PartialIncoming { } #[derive(Serialize)] -struct FullRequest

{ +pub struct FullRequest

{ pub id: Option, pub method: String, pub params: P, diff --git a/cli/src/tunnels.rs b/cli/src/tunnels.rs index 011127e02b2..0f112c74734 100644 --- a/cli/src/tunnels.rs +++ b/cli/src/tunnels.rs @@ -7,7 +7,9 @@ pub mod code_server; pub mod dev_tunnels; pub mod legal; pub mod paths; +pub mod shutdown_signal; +mod wsl_server; mod control_server; mod name_generator; mod port_forwarder; @@ -25,6 +27,7 @@ mod service_windows; mod socket_signal; pub use control_server::serve; +pub use wsl_server::serve_wsl; pub use service::{ create_service_manager, ServiceContainer, ServiceManager, SERVICE_LOG_FILE_NAME, }; diff --git a/cli/src/tunnels/code_server.rs b/cli/src/tunnels/code_server.rs index cf0beeec2bc..357f04750ba 100644 --- a/cli/src/tunnels/code_server.rs +++ b/cli/src/tunnels/code_server.rs @@ -286,6 +286,7 @@ async fn install_server_if_needed( paths: &ServerPaths, release: &Release, http: impl SimpleHttp + Send + Sync + 'static, + existing_archive_path: Option, ) -> Result<(), AnyError> { if paths.executable.exists() { info!( @@ -296,11 +297,14 @@ async fn install_server_if_needed( return Ok(()); } - let tar_file_path = spanf!( - log, - log.span("server.download"), - download_server(&paths.server_dir, release, log, http) - )?; + let tar_file_path = match existing_archive_path { + Some(p) => p, + None => spanf!( + log, + log.span("server.download"), + download_server(&paths.server_dir, release, log, http) + )?, + }; span!( log, @@ -471,7 +475,7 @@ impl<'a, Http: SimpleHttp + Send + Sync + Clone + 'static> ServerBuilder<'a, Htt } /// Ensures the server is set up in the configured directory. - pub async fn setup(&self) -> Result<(), AnyError> { + pub async fn setup(&self, existing_archive_path: Option) -> Result<(), AnyError> { debug!( self.logger, "Installing and setting up {}...", QUALITYLESS_SERVER_NAME @@ -482,6 +486,7 @@ impl<'a, Http: SimpleHttp + Send + Sync + Clone + 'static> ServerBuilder<'a, Htt &self.server_paths, &self.server_params.release, self.http.clone(), + existing_archive_path, ) .await?; debug!(self.logger, "Server setup complete"); @@ -499,6 +504,40 @@ impl<'a, Http: SimpleHttp + Send + Sync + Clone + 'static> ServerBuilder<'a, Htt Ok(()) } + pub async fn listen_on_port(&self, port: u16) -> Result { + let mut cmd = self.get_base_command(); + cmd.arg("--start-server") + .arg("--enable-remote-auto-shutdown") + .arg(format!("--port={}", port)); + + let child = self.spawn_server_process(cmd)?; + let log_file = self.get_logfile()?; + let plog = self.logger.prefixed(&log::new_code_server_prefix()); + + let (mut origin, listen_rx) = + monitor_server::(child, Some(log_file), plog, false); + + let port = match timeout(Duration::from_secs(8), listen_rx).await { + Err(e) => { + origin.kill().await; + Err(wrap(e, "timed out looking for port")) + } + Ok(Err(e)) => { + origin.kill().await; + Err(wrap(e, "server exited without writing port")) + } + Ok(Ok(p)) => Ok(p), + }?; + + info!(self.logger, "Server started"); + + Ok(PortCodeServer { + commit_id: self.server_params.release.commit.to_owned(), + port, + origin: Arc::new(origin), + }) + } + pub async fn listen_on_default_socket(&self) -> Result { let requested_file = if cfg!(target_os = "windows") { PathBuf::from(format!(r"\\.\pipe\vscode-server-{}", Uuid::new_v4())) diff --git a/cli/src/tunnels/control_server.rs b/cli/src/tunnels/control_server.rs index 046f4f685ba..cf14b49f5d6 100644 --- a/cli/src/tunnels/control_server.rs +++ b/cli/src/tunnels/control_server.rs @@ -2,7 +2,6 @@ * Copyright (c) Microsoft Corporation. All rights reserved. * Licensed under the MIT License. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ -use crate::commands::tunnels::ShutdownSignal; use crate::constants::{ CONTROL_PORT, EDITOR_WEB_URL, PROTOCOL_VERSION, QUALITYLESS_SERVER_NAME, VSCODE_CLI_VERSION, }; @@ -49,6 +48,7 @@ use super::protocol::{ VersionParams, }; use super::server_bridge::{get_socket_rw_stream, ServerBridge}; +use super::shutdown_signal::ShutdownSignal; use super::socket_signal::{ClientMessageDecoder, ServerMessageSink, SocketSignal}; type ServerBridgeListLock = Arc>>>; @@ -297,22 +297,19 @@ async fn process_socket( let http_requests = Arc::new(std::sync::Mutex::new(HashMap::new())); let server_bridges = Arc::new(std::sync::Mutex::new(Some(vec![]))); let (http_delegated, mut http_rx) = DelegatedSimpleHttp::new(log.clone()); - let mut rpc = RpcBuilder::new( - MsgPackSerializer {}, - HandlerContext { - did_update: Arc::new(AtomicBool::new(false)), - socket_tx: socket_tx.clone(), - log: log.clone(), - launcher_paths, - code_server_args, - code_server: Arc::new(Mutex::new(None)), - server_bridges: server_bridges.clone(), - port_forwarding, - platform, - http: FallbackSimpleHttp::new(ReqwestSimpleHttp::new(), http_delegated), - http_requests: http_requests.clone(), - }, - ); + let mut rpc = RpcBuilder::new(MsgPackSerializer {}).methods(HandlerContext { + did_update: Arc::new(AtomicBool::new(false)), + socket_tx: socket_tx.clone(), + log: log.clone(), + launcher_paths, + code_server_args, + code_server: Arc::new(Mutex::new(None)), + server_bridges: server_bridges.clone(), + port_forwarding, + platform, + http: FallbackSimpleHttp::new(ReqwestSimpleHttp::new(), http_delegated), + http_requests: http_requests.clone(), + }); rpc.register_sync("ping", |_: EmptyObject, _| Ok(EmptyObject {})); rpc.register_sync("gethostname", |_: EmptyObject, _| handle_get_hostname()); @@ -363,7 +360,7 @@ async fn process_socket( let rx_counter = rx_counter.clone(); let socket_tx = socket_tx.clone(); let exit_barrier = exit_barrier.clone(); - let rpc = rpc.build(); + let rpc = rpc.build(log.clone()); tokio::spawn(async move { send_version(&socket_tx).await; @@ -579,7 +576,7 @@ async fn handle_serve( Some(AnyCodeServer::Socket(s)) => s, Some(_) => return Err(AnyError::from(MismatchedLaunchModeError())), None => { - $sb.setup().await?; + $sb.setup(None).await?; $sb.listen_on_default_socket().await? } } diff --git a/cli/src/tunnels/protocol.rs b/cli/src/tunnels/protocol.rs index dc3d06879bb..cc271323791 100644 --- a/cli/src/tunnels/protocol.rs +++ b/cli/src/tunnels/protocol.rs @@ -54,6 +54,21 @@ pub struct ForwardResult { pub uri: String, } +/// The `install_local` method in the wsl control server +#[derive(Deserialize, Debug)] +pub struct InstallFromLocalFolderParams { + pub commit_id: String, + pub quality: Quality, + pub archive_path: String, + #[serde(default)] + pub extensions: Vec, +} + +#[derive(Serialize, Debug)] +pub struct InstallPortServerResult { + pub port: u16, +} + #[derive(Deserialize, Debug)] pub struct ServeParams { pub socket_id: u16, diff --git a/cli/src/tunnels/service.rs b/cli/src/tunnels/service.rs index 63e3003d9af..a0a2129bef9 100644 --- a/cli/src/tunnels/service.rs +++ b/cli/src/tunnels/service.rs @@ -8,12 +8,13 @@ use std::path::{Path, PathBuf}; use async_trait::async_trait; use tokio::sync::mpsc; -use crate::commands::tunnels::ShutdownSignal; use crate::log; use crate::state::LauncherPaths; use crate::util::errors::{wrap, AnyError}; use crate::util::io::{tailf, TailEvent}; +use super::shutdown_signal::ShutdownSignal; + pub const SERVICE_LOG_FILE_NAME: &str = "tunnel-service.log"; #[async_trait] diff --git a/cli/src/tunnels/service_linux.rs b/cli/src/tunnels/service_linux.rs index 6ba66eb7392..022d4cee409 100644 --- a/cli/src/tunnels/service_linux.rs +++ b/cli/src/tunnels/service_linux.rs @@ -10,12 +10,11 @@ use std::{ process::Command, }; +use super::shutdown_signal::ShutdownSignal; use async_trait::async_trait; -use tokio::sync::mpsc; use zbus::{dbus_proxy, zvariant, Connection}; use crate::{ - commands::tunnels::ShutdownSignal, constants::{APPLICATION_NAME, PRODUCT_NAME_LONG}, log, state::LauncherPaths, @@ -120,12 +119,7 @@ impl ServiceManager for SystemdService { launcher_paths: crate::state::LauncherPaths, mut handle: impl 'static + super::ServiceContainer, ) -> Result<(), crate::util::errors::AnyError> { - let (tx, rx) = mpsc::unbounded_channel::(); - tokio::spawn(async move { - tokio::signal::ctrl_c().await.ok(); - tx.send(ShutdownSignal::CtrlC).ok(); - }); - + let rx = ShutdownSignal::create_rx(&[ShutdownSignal::CtrlC]); handle.run_service(self.log, launcher_paths, rx).await } diff --git a/cli/src/tunnels/service_macos.rs b/cli/src/tunnels/service_macos.rs index 47833dc31fa..6c83af0f039 100644 --- a/cli/src/tunnels/service_macos.rs +++ b/cli/src/tunnels/service_macos.rs @@ -9,11 +9,10 @@ use std::{ path::{Path, PathBuf}, }; +use super::shutdown_signal::ShutdownSignal; use async_trait::async_trait; -use tokio::sync::mpsc; use crate::{ - commands::tunnels::ShutdownSignal, constants::APPLICATION_NAME, log, state::LauncherPaths, @@ -74,12 +73,7 @@ impl ServiceManager for LaunchdService { launcher_paths: crate::state::LauncherPaths, mut handle: impl 'static + super::ServiceContainer, ) -> Result<(), crate::util::errors::AnyError> { - let (tx, rx) = mpsc::unbounded_channel::(); - tokio::spawn(async move { - tokio::signal::ctrl_c().await.ok(); - tx.send(ShutdownSignal::CtrlC).ok(); - }); - + let rx = ShutdownSignal::create_rx(&[ShutdownSignal::CtrlC]); handle.run_service(self.log, launcher_paths, rx).await } diff --git a/cli/src/tunnels/service_windows.rs b/cli/src/tunnels/service_windows.rs index 8c292db6f8d..c893a3d0b79 100644 --- a/cli/src/tunnels/service_windows.rs +++ b/cli/src/tunnels/service_windows.rs @@ -20,9 +20,8 @@ use windows_service::{ }; use crate::{ - commands::tunnels::ShutdownSignal, constants::QUALITYLESS_PRODUCT_NAME, - util::errors::{wrap, wrapdbg, AnyError, WindowsNeedsElevation}, + util::errors::{wrap, wrapdbg, AnyError, WindowsNeedsElevation}, tunnels::shutdown_signal::ShutdownSignal, }; use crate::{ log::{self, FileLogSink}, diff --git a/cli/src/tunnels/shutdown_signal.rs b/cli/src/tunnels/shutdown_signal.rs new file mode 100644 index 00000000000..9e185770058 --- /dev/null +++ b/cli/src/tunnels/shutdown_signal.rs @@ -0,0 +1,63 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +use std::{fmt, time::Duration}; + +use sysinfo::{Pid, SystemExt}; +use tokio::{sync::mpsc, time::sleep}; + +/// Describes the signal to manully stop the server +pub enum ShutdownSignal { + CtrlC, + ParentProcessKilled(Pid), + ServiceStopped, +} + +impl fmt::Display for ShutdownSignal { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ShutdownSignal::CtrlC => write!(f, "Ctrl-C received"), + ShutdownSignal::ParentProcessKilled(p) => { + write!(f, "Parent process {} no longer exists", p) + } + ShutdownSignal::ServiceStopped => write!(f, "Service stopped"), + } + } +} + +impl ShutdownSignal { + /// Creates a receiver channel sent to once any of the signals are received. + /// Note: does not handle ServiceStopped + pub fn create_rx(signals: &[ShutdownSignal]) -> mpsc::UnboundedReceiver { + let (tx, rx) = mpsc::unbounded_channel(); + for signal in signals { + let tx = tx.clone(); + match signal { + ShutdownSignal::CtrlC => { + let ctrl_c = tokio::signal::ctrl_c(); + tokio::spawn(async move { + ctrl_c.await.ok(); + tx.send(ShutdownSignal::CtrlC).ok(); + }); + } + ShutdownSignal::ParentProcessKilled(pid) => { + let pid = *pid; + let tx = tx.clone(); + tokio::spawn(async move { + let mut s = sysinfo::System::new(); + while s.refresh_process(pid) { + sleep(Duration::from_millis(2000)).await; + } + tx.send(ShutdownSignal::ParentProcessKilled(pid)).ok(); + }); + } + ShutdownSignal::ServiceStopped => { + unreachable!("Cannot use ServiceStopped in ShutdownSignal::create_rx"); + } + } + } + rx + } +} diff --git a/cli/src/tunnels/wsl_server.rs b/cli/src/tunnels/wsl_server.rs new file mode 100644 index 00000000000..b2d2f35a319 --- /dev/null +++ b/cli/src/tunnels/wsl_server.rs @@ -0,0 +1,132 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +use std::path::PathBuf; + +use tokio::sync::mpsc; + +use crate::{ + json_rpc::{new_json_rpc, start_json_rpc, JsonRpcSerializer}, + log, + rpc::RpcCaller, + state::LauncherPaths, + tunnels::code_server::ServerBuilder, + update_service::{Platform, Release, TargetKind}, + util::{ + errors::{wrap, AnyError, MismatchedLaunchModeError}, + http::ReqwestSimpleHttp, + }, +}; + +use super::{ + code_server::{AnyCodeServer, CodeServerArgs, ResolvedServerParams}, + protocol::{InstallFromLocalFolderParams, InstallPortServerResult}, + shutdown_signal::ShutdownSignal, +}; + +struct HandlerContext { + log: log::Logger, + code_server_args: CodeServerArgs, + launcher_paths: LauncherPaths, + platform: Platform, + http: ReqwestSimpleHttp, +} + +#[derive(Clone)] +struct JsonRpcLogSink(RpcCaller); + +impl JsonRpcLogSink { + fn write_json(&self, level: String, message: &str) { + self.0.notify( + "log", + serde_json::json!({ + "level": level, + "message": message, + }), + ); + } +} + +impl log::LogSink for JsonRpcLogSink { + fn write_log(&self, level: log::Level, _prefix: &str, message: &str) { + self.write_json(level.to_string(), message); + } + + fn write_result(&self, message: &str) { + self.write_json("result".to_string(), message); + } +} + +pub async fn serve_wsl( + log: log::Logger, + launcher_paths: LauncherPaths, + code_server_args: CodeServerArgs, + platform: Platform, + http: reqwest::Client, + shutdown_rx: mpsc::UnboundedReceiver, +) -> Result { + let (caller_tx, caller_rx) = mpsc::unbounded_channel(); + let mut rpc = new_json_rpc(); + let caller = rpc.get_caller(caller_tx); + + let log = log.with_sink(JsonRpcLogSink(caller)); + let mut rpc = rpc.methods(HandlerContext { + log: log.clone(), + code_server_args, + launcher_paths, + platform, + http: ReqwestSimpleHttp::with_client(http), + }); + + rpc.register_async( + "install_local", + move |params: InstallFromLocalFolderParams, c| async move { install_local(&c, params).await }, + ); + + start_json_rpc( + rpc.build(log), + tokio::io::stdin(), + tokio::io::stdout(), + caller_rx, + shutdown_rx, + ) + .await + .map_err(|e| wrap(e, "error handling server stdio"))?; + + Ok(0) +} + +async fn install_local( + c: &HandlerContext, + params: InstallFromLocalFolderParams, +) -> Result { + // fill params.extensions into code_server_args.install_extensions + let mut csa = c.code_server_args.clone(); + csa.install_extensions.extend(params.extensions.into_iter()); + + let resolved = ResolvedServerParams { + code_server_args: csa, + release: Release { + name: String::new(), + commit: params.commit_id, + platform: c.platform, + target: TargetKind::Server, + quality: params.quality, + }, + }; + + let sb = ServerBuilder::new(&c.log, &resolved, &c.launcher_paths, c.http.clone()); + + let s = match sb.get_running().await? { + Some(AnyCodeServer::Port(s)) => s, + Some(_) => return Err(MismatchedLaunchModeError().into()), + None => { + sb.setup(Some(PathBuf::from(params.archive_path))).await?; + sb.listen_on_port(0).await? + } + }; + + Ok(InstallPortServerResult { port: s.port }) +}