diff --git a/cli/Cargo.lock b/cli/Cargo.lock index 45d0f478db1..198430cee92 100644 --- a/cli/Cargo.lock +++ b/cli/Cargo.lock @@ -146,9 +146,9 @@ checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" [[package]] name = "bytes" -version = "1.2.1" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec8a7b6a70fde80372154c65702f00a0f56f3e1c36abbc6c440484be248856db" +checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" [[package]] name = "cache-padded" @@ -230,6 +230,7 @@ dependencies = [ "async-trait", "atty", "base64", + "bytes", "cfg-if", "chrono", "clap", diff --git a/cli/Cargo.toml b/cli/Cargo.toml index c7b6c3746a6..97affef0a64 100644 --- a/cli/Cargo.toml +++ b/cli/Cargo.toml @@ -17,7 +17,7 @@ clap = { version = "3.0", features = ["derive", "env"] } open = { version = "2.1.0" } reqwest = { version = "0.11.9", default-features = false, features = ["json", "stream", "native-tls"] } tokio = { version = "1.24.2", features = ["full"] } -tokio-util = { version = "0.7", features = ["compat"] } +tokio-util = { version = "0.7", features = ["compat", "codec"] } flate2 = { version = "1.0.22" } zip = { version = "0.5.13", default-features = false, features = ["time", "deflate"] } regex = { version = "1.5.5" } @@ -54,6 +54,7 @@ thiserror = "1.0" cfg-if = "1.0.0" pin-project = "1.0" console = "0.15" +bytes = "1.4" [build-dependencies] serde = { version = "1.0" } diff --git a/cli/src/commands/tunnels.rs b/cli/src/commands/tunnels.rs index 7c3771fe644..8a390930f9b 100644 --- a/cli/src/commands/tunnels.rs +++ b/cli/src/commands/tunnels.rs @@ -190,7 +190,7 @@ pub async fn rename(ctx: CommandContext, rename_args: TunnelRenameArgs) -> Resul let auth = Auth::new(&ctx.paths, ctx.log.clone()); let mut dt = dev_tunnels::DevTunnels::new(&ctx.log, auth, &ctx.paths); dt.rename_tunnel(&rename_args.name).await?; - ctx.log.result(&format!( + ctx.log.result(format!( "Successfully renamed this gateway to {}", &rename_args.name )); @@ -287,7 +287,7 @@ pub async fn prune(ctx: CommandContext) -> Result { .filter(|s| s.get_running_pid().is_none()) .try_for_each(|s| { ctx.log - .result(&format!("Deleted {}", s.server_dir.display())); + .result(format!("Deleted {}", s.server_dir.display())); s.delete() }) .map_err(AnyError::from)?; diff --git a/cli/src/commands/update.rs b/cli/src/commands/update.rs index 80a57b12bb1..0d7321a814f 100644 --- a/cli/src/commands/update.rs +++ b/cli/src/commands/update.rs @@ -3,6 +3,8 @@ * Licensed under the MIT License. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ +use std::sync::Arc; + use indicatif::ProgressBar; use crate::{ @@ -17,7 +19,7 @@ use super::{args::StandaloneUpdateArgs, CommandContext}; pub async fn update(ctx: CommandContext, args: StandaloneUpdateArgs) -> Result { let update_service = UpdateService::new( ctx.log.clone(), - ReqwestSimpleHttp::with_client(ctx.http.clone()), + Arc::new(ReqwestSimpleHttp::with_client(ctx.http.clone())), ); let update_service = SelfUpdate::new(&update_service)?; diff --git a/cli/src/commands/version.rs b/cli/src/commands/version.rs index c0d44fa1438..e80fa481c7b 100644 --- a/cli/src/commands/version.rs +++ b/cli/src/commands/version.rs @@ -58,5 +58,5 @@ pub async fn show(ctx: CommandContext) -> Result { } fn print_now_using(log: &log::Logger, version: &RequestedVersion, path: &Path) { - log.result(&format!("Now using {} from {}", version, path.display())); + log.result(format!("Now using {} from {}", version, path.display())); } diff --git a/cli/src/json_rpc.rs b/cli/src/json_rpc.rs index 083c4316542..57baac01c5e 100644 --- a/cli/src/json_rpc.rs +++ b/cli/src/json_rpc.rs @@ -50,7 +50,7 @@ pub async fn start_json_rpc( mut msg_rx: impl Receivable>, mut shutdown_rx: Barrier, ) -> io::Result> { - let (write_tx, mut write_rx) = mpsc::unbounded_channel::>(); + let (write_tx, mut write_rx) = mpsc::channel::>(8); let mut read = BufReader::new(read); let mut read_buf = String::new(); @@ -84,7 +84,18 @@ pub async fn start_json_rpc( let write_tx = write_tx.clone(); tokio::spawn(async move { if let Some(v) = fut.await { - write_tx.send(v).ok(); + let _ = write_tx.send(v).await; + } + }); + }, + MaybeSync::Stream((dto, fut)) => { + if let Some(dto) = dto { + dispatcher.register_stream(write_tx.clone(), dto).await; + } + let write_tx = write_tx.clone(); + tokio::spawn(async move { + if let Some(v) = fut.await { + let _ = write_tx.send(v).await; } }); } diff --git a/cli/src/log.rs b/cli/src/log.rs index 7162432062d..1bce766a871 100644 --- a/cli/src/log.rs +++ b/cli/src/log.rs @@ -27,21 +27,19 @@ pub fn next_counter() -> u32 { // Log level #[derive(clap::ArgEnum, PartialEq, Eq, PartialOrd, Clone, Copy, Debug, Serialize, Deserialize)] +#[derive(Default)] pub enum Level { Trace = 0, Debug, - Info, + #[default] + Info, Warn, Error, Critical, Off, } -impl Default for Level { - fn default() -> Self { - Level::Info - } -} + impl fmt::Display for Level { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { diff --git a/cli/src/msgpack_rpc.rs b/cli/src/msgpack_rpc.rs index de46e738da8..18d5d0b9d2b 100644 --- a/cli/src/msgpack_rpc.rs +++ b/cli/src/msgpack_rpc.rs @@ -8,6 +8,7 @@ use tokio::{ pin, sync::mpsc, }; +use tokio_util::codec::Decoder; use crate::{ rpc::{self, MaybeSync, Serialization}, @@ -38,7 +39,6 @@ pub fn new_msgpack_rpc() -> rpc::RpcBuilder { rpc::RpcBuilder::new(MsgPackSerializer {}) } -#[allow(clippy::read_zero_byte_vec)] // false positive pub async fn start_msgpack_rpc( dispatcher: rpc::RpcDispatcher, read: impl AsyncRead + Unpin, @@ -46,34 +46,45 @@ pub async fn start_msgpack_rpc( mut msg_rx: impl Receivable>, mut shutdown_rx: Barrier, ) -> io::Result> { - let (write_tx, mut write_rx) = mpsc::unbounded_channel::>(); + let (write_tx, mut write_rx) = mpsc::channel::>(8); let mut read = BufReader::new(read); - let mut decode_buf = vec![]; + let mut decoder = U32PrefixedCodec {}; + let mut decoder_buf = bytes::BytesMut::new(); let shutdown_fut = shutdown_rx.wait(); pin!(shutdown_fut); loop { tokio::select! { - u = read.read_u32() => { - let msg_length = u? as usize; - decode_buf.resize(msg_length, 0); - tokio::select! { - r = read.read_exact(&mut decode_buf) => match dispatcher.dispatch(&decode_buf[..r?]) { + r = read.read_buf(&mut decoder_buf) => { + r?; + + while let Some(frame) = decoder.decode(&mut decoder_buf)? { + match dispatcher.dispatch(&frame) { MaybeSync::Sync(Some(v)) => { - write_tx.send(v).ok(); + let _ = write_tx.send(v).await; }, MaybeSync::Sync(None) => continue, MaybeSync::Future(fut) => { let write_tx = write_tx.clone(); tokio::spawn(async move { if let Some(v) = fut.await { - write_tx.send(v).ok(); + let _ = write_tx.send(v).await; } }); } - }, - r = &mut shutdown_fut => return Ok(r.ok()), + MaybeSync::Stream((stream, fut)) => { + if let Some(stream) = stream { + dispatcher.register_stream(write_tx.clone(), stream).await; + } + let write_tx = write_tx.clone(); + tokio::spawn(async move { + if let Some(v) = fut.await { + let _ = write_tx.send(v).await; + } + }); + } + } }; }, Some(m) = write_rx.recv() => { @@ -88,3 +99,33 @@ pub async fn start_msgpack_rpc( write.flush().await?; } } + +/// Reader that reads length-prefixed msgpack messages in a cancellation-safe +/// way using Tokio's codecs. +pub struct U32PrefixedCodec {} + +const U32_SIZE: usize = 4; + +impl tokio_util::codec::Decoder for U32PrefixedCodec { + type Item = Vec; + type Error = io::Error; + + fn decode(&mut self, src: &mut bytes::BytesMut) -> Result, Self::Error> { + if src.len() < 4 { + src.reserve(U32_SIZE - src.len()); + return Ok(None); + } + + let mut be_bytes = [0; U32_SIZE]; + be_bytes.copy_from_slice(&src[..U32_SIZE]); + let required_len = U32_SIZE + (u32::from_be_bytes(be_bytes) as usize); + if src.len() < required_len { + src.reserve(required_len - src.len()); + return Ok(None); + } + + let msg = src[U32_SIZE..].to_vec(); + src.resize(0, 0); + Ok(Some(msg)) + } +} diff --git a/cli/src/rpc.rs b/cli/src/rpc.rs index b5e5b53ee69..f3c68321590 100644 --- a/cli/src/rpc.rs +++ b/cli/src/rpc.rs @@ -15,17 +15,26 @@ use std::{ use crate::log; use futures::{future::BoxFuture, Future, FutureExt}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use tokio::sync::{mpsc, oneshot}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt, DuplexStream, WriteHalf}, + sync::{mpsc, oneshot}, +}; use crate::util::errors::AnyError; pub type SyncMethod = Arc, &[u8]) -> Option>>; pub type AsyncMethod = Arc, &[u8]) -> BoxFuture<'static, Option>>>; +pub type Duplex = Arc< + dyn Send + + Sync + + Fn(Option, &[u8]) -> (Option, BoxFuture<'static, Option>>), +>; pub enum Method { Sync(SyncMethod), Async(AsyncMethod), + Duplex(Duplex), } /// Serialization is given to the RpcBuilder and defines how data gets serialized @@ -81,6 +90,12 @@ pub struct RpcMethodBuilder { calls: Arc>>, } +#[derive(Serialize)] +struct DuplexStreamStarted { + pub for_request_id: u32, + pub stream_id: u32, +} + impl RpcMethodBuilder { /// Registers a synchronous rpc call that returns its result directly. pub fn register_sync(&mut self, method_name: &'static str, callback: F) @@ -179,14 +194,105 @@ impl RpcMethodBuilder { ); } + /// Registers an async rpc call that returns a Future containing a duplex + /// stream that should be handled by the client. + pub fn register_duplex(&mut self, method_name: &'static str, callback: F) + where + P: DeserializeOwned + Send + 'static, + R: Serialize + Send + Sync + 'static, + Fut: Future> + Send, + F: (Fn(DuplexStream, P, Arc) -> Fut) + Clone + Send + Sync + 'static, + { + let serial = self.serializer.clone(); + let context = self.context.clone(); + self.methods.insert( + method_name, + Method::Duplex(Arc::new(move |id, body| { + let param = match serial.deserialize::>(body) { + Ok(p) => p, + Err(err) => { + return ( + None, + future::ready(id.map(|id| { + serial.serialize(&ErrorResponse { + id, + error: ResponseError { + code: 0, + message: format!("{:?}", err), + }, + }) + })) + .boxed(), + ); + } + }; + + let callback = callback.clone(); + let serial = serial.clone(); + let context = context.clone(); + let stream_id = next_message_id(); + let (client, server) = tokio::io::duplex(8192); + + let fut = async move { + match callback(server, param.params, context).await { + Ok(r) => id.map(|id| serial.serialize(&SuccessResponse { id, result: r })), + Err(err) => id.map(|id| { + serial.serialize(&ErrorResponse { + id, + error: ResponseError { + code: -1, + message: format!("{:?}", err), + }, + }) + }), + } + }; + + ( + Some(StreamDto { + req_id: id.unwrap_or(0), + stream_id, + duplex: client, + }), + fut.boxed(), + ) + })), + ); + } + /// Builds into a usable, sync rpc dispatcher. - pub fn build(self, log: log::Logger) -> RpcDispatcher { + pub fn build(mut self, log: log::Logger) -> RpcDispatcher { + let streams: Arc>>> = + Arc::new(tokio::sync::Mutex::new(HashMap::new())); + + let s1 = streams.clone(); + self.register_async(METHOD_STREAM_ENDED, move |m: StreamEndedParams, _| { + let s1 = s1.clone(); + async move { + s1.lock().await.remove(&m.stream); + Ok(()) + } + }); + + let s2 = streams.clone(); + self.register_async(METHOD_STREAM_DATA, move |m: StreamDataIncomingParams, _| { + let s2 = s2.clone(); + async move { + let mut lock = s2.lock().await; + if let Some(stream) = lock.get_mut(&m.stream) { + let _ = stream.write_all(&m.segment).await; + } + Ok(()) + } + }); + RpcDispatcher { log, context: self.context, calls: self.calls, serializer: self.serializer, methods: Arc::new(self.methods), + streams, } } } @@ -281,6 +387,7 @@ pub struct RpcDispatcher { serializer: Arc, methods: Arc>, calls: Arc>>, + streams: Arc>>>, } static MESSAGE_ID_COUNTER: AtomicU32 = AtomicU32::new(0); @@ -310,6 +417,7 @@ impl RpcDispatcher { match method { Some(Method::Sync(callback)) => MaybeSync::Sync(callback(id, body)), Some(Method::Async(callback)) => MaybeSync::Future(callback(id, body)), + Some(Method::Duplex(callback)) => MaybeSync::Stream(callback(id, body)), None => MaybeSync::Sync(id.map(|id| { self.serializer.serialize(&ErrorResponse { id, @@ -333,11 +441,91 @@ impl RpcDispatcher { } } + /// Registers a stream call returned from dispatch(). + pub async fn register_stream( + &self, + write_tx: mpsc::Sender> + Send>, + dto: StreamDto, + ) { + let stream_id = dto.stream_id; + let for_request_id = dto.req_id; + let (mut read, write) = tokio::io::split(dto.duplex); + let serial = self.serializer.clone(); + + self.streams.lock().await.insert(dto.stream_id, write); + + tokio::spawn(async move { + let r = write_tx + .send( + serial + .serialize(&FullRequest { + id: None, + method: METHOD_STREAM_STARTED, + params: DuplexStreamStarted { + stream_id, + for_request_id, + }, + }) + .into(), + ) + .await; + + if r.is_err() { + return; + } + + let mut buf = Vec::with_capacity(4096); + loop { + match read.read_buf(&mut buf).await { + Ok(0) | Err(_) => break, + Ok(n) => { + let r = write_tx + .send( + serial + .serialize(&FullRequest { + id: None, + method: METHOD_STREAM_DATA, + params: StreamDataParams { + segment: &buf[..n], + stream: stream_id, + }, + }) + .into(), + ) + .await; + + if r.is_err() { + return; + } + + buf.truncate(0); + } + } + } + + let _ = write_tx + .send( + serial + .serialize(&FullRequest { + id: None, + method: METHOD_STREAM_ENDED, + params: StreamEndedParams { stream: stream_id }, + }) + .into(), + ) + .await; + }); + } + pub fn context(&self) -> Arc { self.context.clone() } } +const METHOD_STREAM_STARTED: &str = "stream_started"; +const METHOD_STREAM_DATA: &str = "stream_data"; +const METHOD_STREAM_ENDED: &str = "stream_ended"; + trait AssertIsSync: Sync {} impl AssertIsSync for RpcDispatcher {} @@ -349,6 +537,25 @@ struct PartialIncoming { pub error: Option, } +#[derive(Deserialize)] +struct StreamDataIncomingParams { + #[serde(with = "serde_bytes")] + pub segment: Vec, + pub stream: u32, +} + +#[derive(Serialize, Deserialize)] +struct StreamDataParams<'a> { + #[serde(with = "serde_bytes")] + pub segment: &'a [u8], + pub stream: u32, +} + +#[derive(Serialize, Deserialize)] +struct StreamEndedParams { + pub stream: u32, +} + #[derive(Serialize)] pub struct FullRequest, P> { pub id: Option, @@ -384,7 +591,14 @@ enum Outcome { Error(ResponseError), } +pub struct StreamDto { + stream_id: u32, + req_id: u32, + duplex: DuplexStream, +} + pub enum MaybeSync { + Stream((Option, BoxFuture<'static, Option>>)), Future(BoxFuture<'static, Option>>), Sync(Option>), } diff --git a/cli/src/self_update.rs b/cli/src/self_update.rs index 62228a5b3d1..33201a345e3 100644 --- a/cli/src/self_update.rs +++ b/cli/src/self_update.rs @@ -86,8 +86,8 @@ impl<'a> SelfUpdate<'a> { // Try to rename the old CLI to the tempdir, where it can get cleaned up by the // OS later. However, this can fail if the tempdir is on a different drive // than the installation dir. In this case just rename it to ".old". - if fs::rename(&target_path, &tempdir.path().join("old-code-cli")).is_err() { - fs::rename(&target_path, &target_path.with_extension(".old")) + if fs::rename(&target_path, tempdir.path().join("old-code-cli")).is_err() { + fs::rename(&target_path, target_path.with_extension(".old")) .map_err(|e| wrap(e, "failed to rename old CLI"))?; } @@ -132,7 +132,7 @@ fn copy_updated_cli_to_path(unzipped_content: &Path, staging_path: &Path) -> Res let archive_file = unzipped_files[0] .as_ref() .map_err(|e| wrap(e, "error listing update files"))?; - fs::copy(&archive_file.path(), staging_path) + fs::copy(archive_file.path(), staging_path) .map_err(|e| wrap(e, "error copying to staging file"))?; Ok(()) } @@ -140,7 +140,7 @@ fn copy_updated_cli_to_path(unzipped_content: &Path, staging_path: &Path) -> Res #[cfg(target_os = "windows")] fn copy_file_metadata(from: &Path, to: &Path) -> Result<(), std::io::Error> { let permissions = from.metadata()?.permissions(); - fs::set_permissions(&to, permissions)?; + fs::set_permissions(to, permissions)?; Ok(()) } diff --git a/cli/src/tunnels/code_server.rs b/cli/src/tunnels/code_server.rs index 1f9eb4ea85c..677bbfc2546 100644 --- a/cli/src/tunnels/code_server.rs +++ b/cli/src/tunnels/code_server.rs @@ -16,7 +16,7 @@ use crate::util::command::{capture_command, kill_tree}; use crate::util::errors::{ wrap, AnyError, ExtensionInstallFailed, MissingEntrypointError, WrappedError, }; -use crate::util::http::{self, SimpleHttp}; +use crate::util::http::{self, BoxedHttp}; use crate::util::io::SilentCopyProgress; use crate::util::machine::process_exists; use crate::{debug, info, log, span, spanf, trace, warning}; @@ -176,7 +176,7 @@ impl ServerParamsRaw { pub async fn resolve( self, log: &log::Logger, - http: impl SimpleHttp + Send + Sync + 'static, + http: BoxedHttp, ) -> Result { Ok(ResolvedServerParams { release: self.get_or_fetch_commit_id(log, http).await?, @@ -187,7 +187,7 @@ impl ServerParamsRaw { async fn get_or_fetch_commit_id( &self, log: &log::Logger, - http: impl SimpleHttp + Send + Sync + 'static, + http: BoxedHttp, ) -> Result { let target = match self.headless { true => TargetKind::Server, @@ -287,7 +287,7 @@ async fn install_server_if_needed( log: &log::Logger, paths: &ServerPaths, release: &Release, - http: impl SimpleHttp + Send + Sync + 'static, + http: BoxedHttp, existing_archive_path: Option, ) -> Result<(), AnyError> { if paths.executable.exists() { @@ -321,7 +321,7 @@ async fn download_server( path: &Path, release: &Release, log: &log::Logger, - http: impl SimpleHttp + Send + Sync + 'static, + http: BoxedHttp, ) -> Result { let response = UpdateService::new(log.clone(), http) .get_download_stream(release) @@ -403,20 +403,20 @@ async fn do_extension_install_on_running_server( } } -pub struct ServerBuilder<'a, Http: SimpleHttp + Send + Sync + Clone> { +pub struct ServerBuilder<'a> { logger: &'a log::Logger, server_params: &'a ResolvedServerParams, last_used: LastUsedServers<'a>, server_paths: ServerPaths, - http: Http, + http: BoxedHttp, } -impl<'a, Http: SimpleHttp + Send + Sync + Clone + 'static> ServerBuilder<'a, Http> { +impl<'a> ServerBuilder<'a> { pub fn new( logger: &'a log::Logger, server_params: &'a ResolvedServerParams, launcher_paths: &'a LauncherPaths, - http: Http, + http: BoxedHttp, ) -> Self { Self { logger, diff --git a/cli/src/tunnels/control_server.rs b/cli/src/tunnels/control_server.rs index 8c2d1a76d1a..bf8ce00380f 100644 --- a/cli/src/tunnels/control_server.rs +++ b/cli/src/tunnels/control_server.rs @@ -5,6 +5,7 @@ use crate::async_pipe::get_socket_rw_stream; use crate::constants::CONTROL_PORT; use crate::log; +use crate::msgpack_rpc::U32PrefixedCodec; use crate::rpc::{MaybeSync, RpcBuilder, RpcDispatcher, Serialization}; use crate::self_update::SelfUpdate; use crate::state::LauncherPaths; @@ -12,7 +13,8 @@ use crate::tunnels::protocol::HttpRequestParams; use crate::tunnels::socket_signal::CloseReason; use crate::update_service::{Platform, UpdateService}; use crate::util::errors::{ - wrap, AnyError, InvalidRpcDataError, MismatchedLaunchModeError, NoAttachedServerError, + wrap, AnyError, CodeError, InvalidRpcDataError, MismatchedLaunchModeError, + NoAttachedServerError, }; use crate::util::http::{ DelegatedHttpRequest, DelegatedSimpleHttp, FallbackSimpleHttp, ReqwestSimpleHttp, @@ -24,11 +26,14 @@ use crate::util::sync::{new_barrier, Barrier}; use opentelemetry::trace::SpanKind; use opentelemetry::KeyValue; use std::collections::HashMap; +use std::process::Stdio; +use tokio::pin; +use tokio_util::codec::Decoder; use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering}; use std::sync::Arc; use std::time::Instant; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, DuplexStream}; use tokio::sync::{mpsc, Mutex}; use super::code_server::{ @@ -40,8 +45,8 @@ use super::port_forwarder::{PortForwarding, PortForwardingProcessor}; use super::protocol::{ CallServerHttpParams, CallServerHttpResult, ClientRequestMethod, EmptyObject, ForwardParams, ForwardResult, GetHostnameResponse, HttpBodyParams, HttpHeadersParams, ServeParams, ServerLog, - ServerMessageParams, ToClientRequest, UnforwardParams, UpdateParams, UpdateResult, - VersionParams, + ServerMessageParams, SpawnParams, SpawnResult, ToClientRequest, UnforwardParams, UpdateParams, + UpdateResult, VersionParams, }; use super::server_bridge::ServerBridge; use super::server_multiplexer::ServerMultiplexer; @@ -73,7 +78,7 @@ struct HandlerContext { /// install platform for the VS Code server platform: Platform, /// http client to make download/update requests - http: FallbackSimpleHttp, + http: Arc, /// requests being served by the client http_requests: HttpRequestsMap, } @@ -196,7 +201,7 @@ pub async fn serve( ], ); cx.span().end(); - }); + }); } } } @@ -247,7 +252,10 @@ async fn process_socket( server_bridges: server_bridges.clone(), port_forwarding, platform, - http: FallbackSimpleHttp::new(ReqwestSimpleHttp::new(), http_delegated), + http: Arc::new(FallbackSimpleHttp::new( + ReqwestSimpleHttp::new(), + http_delegated, + )), http_requests: http_requests.clone(), }); @@ -276,6 +284,9 @@ async fn process_socket( rpc.register_async("unforward", |p: UnforwardParams, c| async move { handle_unforward(&c.log, &c.port_forwarding, p).await }); + rpc.register_duplex("spawn", |stream, p: SpawnParams, c| async move { + handle_spawn(&c.log, stream, p).await + }); rpc.register_sync("httpheaders", |p: HttpHeadersParams, c| { if let Some(req) = c.http_requests.lock().unwrap().get(&p.req_id) { req.initial_response(p.status_code, p.headers); @@ -393,20 +404,20 @@ async fn handle_socket_read( rx_counter: Arc, rpc: &RpcDispatcher, ) -> Result<(), std::io::Error> { - let mut socket_reader = BufReader::new(readhalf); - let mut decode_buf = vec![]; + let mut readhalf = BufReader::new(readhalf); + let mut decoder = U32PrefixedCodec {}; + let mut decoder_buf = bytes::BytesMut::new(); loop { - let read = read_next( - &mut socket_reader, - &rx_counter, - &mut closer, - &mut decode_buf, - ) - .await; + let read_len = tokio::select! { + r = readhalf.read_buf(&mut decoder_buf) => r, + _ = closer.wait() => Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "eof")), + }?; - match read { - Ok(len) => match rpc.dispatch(&decode_buf[..len]) { + rx_counter.fetch_add(read_len, Ordering::Relaxed); + + while let Some(frame) = decoder.decode(&mut decoder_buf)? { + match rpc.dispatch(&frame) { MaybeSync::Sync(Some(v)) => { if socket_tx.send(SocketSignal::Send(v)).await.is_err() { return Ok(()); @@ -421,34 +432,22 @@ async fn handle_socket_read( } }); } - }, - Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(()), - Err(e) => return Err(e), + MaybeSync::Stream((stream, fut)) => { + if let Some(stream) = stream { + rpc.register_stream(socket_tx.clone(), stream).await; + } + let socket_tx = socket_tx.clone(); + tokio::spawn(async move { + if let Some(v) = fut.await { + socket_tx.send(SocketSignal::Send(v)).await.ok(); + } + }); + } + } } } } -/// Reads and handles the next data packet. Returns the next packet to dispatch, -/// or an error (including EOF). -async fn read_next( - socket_reader: &mut BufReader, - rx_counter: &Arc, - closer: &mut Barrier<()>, - decode_buf: &mut Vec, -) -> Result { - let msg_length = tokio::select! { - u = socket_reader.read_u32() => u? as usize, - _ = closer.wait() => return Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "eof")), - }; - decode_buf.resize(msg_length, 0); - rx_counter.fetch_add(msg_length + 4 /* u32 */, Ordering::Relaxed); - - tokio::select! { - r = socket_reader.read_exact(decode_buf) => r, - _ = closer.wait() => Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "eof")), - } -} - #[derive(Clone)] struct ServerOutputSink { tx: mpsc::Sender, @@ -487,7 +486,9 @@ async fn handle_serve( }; let resolved = if params.use_local_download { - params_raw.resolve(&c.log, c.http.delegated()).await + params_raw + .resolve(&c.log, Arc::new(c.http.delegated())) + .await } else { params_raw.resolve(&c.log, c.http.clone()).await }?; @@ -518,7 +519,7 @@ async fn handle_serve( &install_log, &resolved, &c.launcher_paths, - c.http.delegated(), + Arc::new(c.http.delegated()), ); do_setup!(sb) } else { @@ -606,7 +607,7 @@ fn handle_prune(paths: &LauncherPaths) -> Result, AnyError> { } async fn handle_update( - http: &FallbackSimpleHttp, + http: &Arc, log: &log::Logger, did_update: &AtomicBool, params: &UpdateParams, @@ -732,3 +733,83 @@ async fn handle_call_server_http( .to_vec(), }) } + +async fn handle_spawn( + log: &log::Logger, + mut duplex: DuplexStream, + params: SpawnParams, +) -> Result { + debug!( + log, + "requested to spawn {} with args {:?}", params.command, params.args + ); + + let mut p = tokio::process::Command::new(¶ms.command) + .args(¶ms.args) + .envs(¶ms.env) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .map_err(CodeError::ProcessSpawnFailed)?; + + let mut stdout = p.stdout.take().unwrap(); + let mut stderr = p.stderr.take().unwrap(); + let mut stdin = p.stdin.take().unwrap(); + let (tx, mut rx) = mpsc::channel(4); + + macro_rules! copy_stream_to { + ($target:expr) => { + let tx = tx.clone(); + tokio::spawn(async move { + let mut buf = vec![0; 4096]; + loop { + let n = match $target.read(&mut buf).await { + Ok(0) | Err(_) => return, + Ok(n) => n, + }; + if !tx.send(buf[..n].to_vec()).await.is_ok() { + return; + } + } + }); + }; + } + + copy_stream_to!(stdout); + copy_stream_to!(stderr); + + let mut stdin_buf = vec![0; 4096]; + let closed = p.wait(); + pin!(closed); + + loop { + tokio::select! { + Ok(n) = duplex.read(&mut stdin_buf) => { + let _ = stdin.write_all(&stdin_buf[..n]).await; + }, + Some(m) = rx.recv() => { + let _ = duplex.write_all(&m).await; + }, + r = &mut closed => { + let r = match r { + Ok(e) => SpawnResult { + message: e.to_string(), + exit_code: e.code().unwrap_or(-1), + }, + Err(e) => SpawnResult { + message: e.to_string(), + exit_code: -1, + }, + }; + + debug!( + log, + "spawned command {} exited with code {}", params.command, r.exit_code + ); + + return Ok(r) + }, + } + } +} diff --git a/cli/src/tunnels/paths.rs b/cli/src/tunnels/paths.rs index 3c47b2575d7..cdf6cef6f51 100644 --- a/cli/src/tunnels/paths.rs +++ b/cli/src/tunnels/paths.rs @@ -68,7 +68,7 @@ impl ServerPaths { // VS Code Server pid pub fn write_pid(&self, pid: u32) -> Result<(), WrappedError> { - write(&self.pidfile, &format!("{}", pid)).map_err(|e| { + write(&self.pidfile, format!("{}", pid)).map_err(|e| { wrap( e, format!("error writing process id into {}", self.pidfile.display()), diff --git a/cli/src/tunnels/protocol.rs b/cli/src/tunnels/protocol.rs index ef033415d93..89f9c3acb28 100644 --- a/cli/src/tunnels/protocol.rs +++ b/cli/src/tunnels/protocol.rs @@ -158,6 +158,20 @@ impl Default for VersionParams { } } +#[derive(Deserialize)] +pub struct SpawnParams { + pub command: String, + pub args: Vec, + #[serde(default)] + pub env: HashMap, +} + +#[derive(Serialize)] +pub struct SpawnResult { + pub message: String, + pub exit_code: i32, +} + pub mod singleton { use crate::log; use serde::{Deserialize, Serialize}; diff --git a/cli/src/tunnels/service_windows.rs b/cli/src/tunnels/service_windows.rs index d230d2e454f..e557499364c 100644 --- a/cli/src/tunnels/service_windows.rs +++ b/cli/src/tunnels/service_windows.rs @@ -59,7 +59,7 @@ impl CliServiceManager for WindowsService { }; for arg in args { - add_arg(*arg); + add_arg(arg); } add_arg("--log-to-file"); diff --git a/cli/src/tunnels/socket_signal.rs b/cli/src/tunnels/socket_signal.rs index c625593a21a..a3d3b08a5d4 100644 --- a/cli/src/tunnels/socket_signal.rs +++ b/cli/src/tunnels/socket_signal.rs @@ -22,6 +22,12 @@ pub enum SocketSignal { CloseWith(CloseReason), } +impl From> for SocketSignal { + fn from(v: Vec) -> Self { + SocketSignal::Send(v) + } +} + impl SocketSignal { pub fn from_message(msg: &T) -> Self where diff --git a/cli/src/tunnels/wsl_server.rs b/cli/src/tunnels/wsl_server.rs index b6250c8247f..69d7eb94389 100644 --- a/cli/src/tunnels/wsl_server.rs +++ b/cli/src/tunnels/wsl_server.rs @@ -3,6 +3,8 @@ * Licensed under the MIT License. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ +use std::sync::Arc; + use tokio::sync::mpsc; use crate::{ @@ -139,7 +141,12 @@ async fn handle_serve( }, }; - let sb = ServerBuilder::new(&c.log, &resolved, &c.launcher_paths, c.http.clone()); + let sb = ServerBuilder::new( + &c.log, + &resolved, + &c.launcher_paths, + Arc::new(c.http.clone()), + ); let code_server = match sb.get_running().await? { Some(AnyCodeServer::Socket(s)) => s, Some(_) => return Err(MismatchedLaunchModeError().into()), diff --git a/cli/src/update_service.rs b/cli/src/update_service.rs index 9dcfc0f5107..e56d3781804 100644 --- a/cli/src/update_service.rs +++ b/cli/src/update_service.rs @@ -3,7 +3,7 @@ * Licensed under the MIT License. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ -use std::path::Path; +use std::{fmt, path::Path}; use serde::Deserialize; @@ -11,19 +11,20 @@ use crate::{ constants::VSCODE_CLI_UPDATE_ENDPOINT, debug, log, options, spanf, util::{ - errors::{AnyError, UnsupportedPlatformError, UpdatesNotConfigured, WrappedError}, - http::{SimpleHttp, SimpleResponse}, + errors::{AnyError, CodeError, UpdatesNotConfigured, WrappedError}, + http::{BoxedHttp, SimpleResponse}, io::ReportCopyProgress, }, }; /// Implementation of the VS Code Update service for use in the CLI. pub struct UpdateService { - client: Box, + client: BoxedHttp, log: log::Logger, } /// Describes a specific release, can be created manually or returned from the update service. +#[derive(Clone, Eq, PartialEq)] pub struct Release { pub name: String, pub platform: Platform, @@ -53,11 +54,8 @@ fn quality_download_segment(quality: options::Quality) -> &'static str { } impl UpdateService { - pub fn new(log: log::Logger, http: impl SimpleHttp + Send + Sync + 'static) -> Self { - UpdateService { - client: Box::new(http), - log, - } + pub fn new(log: log::Logger, http: BoxedHttp) -> Self { + UpdateService { client: http, log } } pub async fn get_release_by_semver_version( @@ -71,7 +69,7 @@ impl UpdateService { VSCODE_CLI_UPDATE_ENDPOINT.ok_or_else(UpdatesNotConfigured::no_url)?; let download_segment = target .download_segment(platform) - .ok_or(UnsupportedPlatformError())?; + .ok_or_else(|| CodeError::UnsupportedPlatform(platform.to_string()))?; let download_url = format!( "{}/api/versions/{}/{}/{}", update_endpoint, @@ -113,7 +111,7 @@ impl UpdateService { VSCODE_CLI_UPDATE_ENDPOINT.ok_or_else(UpdatesNotConfigured::no_url)?; let download_segment = target .download_segment(platform) - .ok_or(UnsupportedPlatformError())?; + .ok_or_else(|| CodeError::UnsupportedPlatform(platform.to_string()))?; let download_url = format!( "{}/api/latest/{}/{}", update_endpoint, @@ -150,7 +148,7 @@ impl UpdateService { let download_segment = release .target .download_segment(release.platform) - .ok_or(UnsupportedPlatformError())?; + .ok_or_else(|| CodeError::UnsupportedPlatform(release.platform.to_string()))?; let download_url = format!( "{}/commit:{}/{}/{}", @@ -208,7 +206,7 @@ impl TargetKind { } } -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum Platform { LinuxAlpineX64, LinuxAlpineARM64, @@ -306,3 +304,20 @@ impl Platform { } } } + +impl fmt::Display for Platform { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str(match self { + Platform::LinuxAlpineARM64 => "LinuxAlpineARM64", + Platform::LinuxAlpineX64 => "LinuxAlpineX64", + Platform::LinuxX64 => "LinuxX64", + Platform::LinuxARM64 => "LinuxARM64", + Platform::LinuxARM32 => "LinuxARM32", + Platform::DarwinX64 => "DarwinX64", + Platform::DarwinARM64 => "DarwinARM64", + Platform::WindowsX64 => "WindowsX64", + Platform::WindowsX86 => "WindowsX86", + Platform::WindowsARM64 => "WindowsARM64", + }) + } +} diff --git a/cli/src/util/command.rs b/cli/src/util/command.rs index c0434b10647..ad1f3a1d13e 100644 --- a/cli/src/util/command.rs +++ b/cli/src/util/command.rs @@ -2,29 +2,47 @@ * Copyright (c) Microsoft Corporation. All rights reserved. * Licensed under the MIT License. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ -use super::errors::{wrap, AnyError, CommandFailed, WrappedError}; -use std::{borrow::Cow, ffi::OsStr, process::Stdio}; +use super::errors::CodeError; +use std::{ + borrow::Cow, + ffi::OsStr, + process::{Output, Stdio}, +}; use tokio::process::Command; pub async fn capture_command_and_check_status( command_str: impl AsRef, args: &[impl AsRef], -) -> Result { +) -> Result { let output = capture_command(&command_str, args).await?; + check_output_status(output, || { + format!( + "{} {}", + command_str.as_ref().to_string_lossy(), + args.iter() + .map(|a| a.as_ref().to_string_lossy()) + .collect::>>() + .join(" ") + ) + }) +} + +pub fn check_output_status( + output: Output, + cmd_str: impl FnOnce() -> String, +) -> Result { if !output.status.success() { - return Err(CommandFailed { - command: format!( - "{} {}", - command_str.as_ref().to_string_lossy(), - args.iter() - .map(|a| a.as_ref().to_string_lossy()) - .collect::>>() - .join(" ") - ), - output, - } - .into()); + return Err(CodeError::CommandFailed { + command: cmd_str(), + code: output.status.code().unwrap_or(-1), + output: String::from_utf8_lossy(if output.stderr.is_empty() { + &output.stdout + } else { + &output.stderr + }) + .into(), + }); } Ok(output) @@ -33,7 +51,7 @@ pub async fn capture_command_and_check_status( pub async fn capture_command( command_str: A, args: I, -) -> Result +) -> Result where A: AsRef, I: IntoIterator, @@ -45,27 +63,23 @@ where .stdout(Stdio::piped()) .output() .await - .map_err(|e| { - wrap( - e, - format!( - "failed to execute command '{}'", - command_str.as_ref().to_string_lossy() - ), - ) + .map_err(|e| CodeError::CommandFailed { + command: command_str.as_ref().to_string_lossy().to_string(), + code: -1, + output: e.to_string(), }) } /// Kills and processes and all of its children. #[cfg(target_os = "windows")] -pub async fn kill_tree(process_id: u32) -> Result<(), WrappedError> { +pub async fn kill_tree(process_id: u32) -> Result<(), CodeError> { capture_command("taskkill", &["/t", "/pid", &process_id.to_string()]).await?; Ok(()) } /// Kills and processes and all of its children. #[cfg(not(target_os = "windows"))] -pub async fn kill_tree(process_id: u32) -> Result<(), WrappedError> { +pub async fn kill_tree(process_id: u32) -> Result<(), CodeError> { use futures::future::join_all; use tokio::io::{AsyncBufReadExt, BufReader}; @@ -82,7 +96,11 @@ pub async fn kill_tree(process_id: u32) -> Result<(), WrappedError> { .stdin(Stdio::null()) .stdout(Stdio::piped()) .spawn() - .map_err(|e| wrap(e, "error enumerating process tree"))?; + .map_err(|e| CodeError::CommandFailed { + command: format!("pgrep -P {}", parent_id), + code: -1, + output: e.to_string(), + })?; let mut kill_futures = vec![tokio::spawn( async move { kill_single_pid(parent_id).await }, diff --git a/cli/src/util/errors.rs b/cli/src/util/errors.rs index 0ad128d7cf0..b5ca1066046 100644 --- a/cli/src/util/errors.rs +++ b/cli/src/util/errors.rs @@ -258,18 +258,6 @@ impl std::fmt::Display for RefreshTokenNotAvailableError { } } -#[derive(Debug)] -pub struct UnsupportedPlatformError(); - -impl std::fmt::Display for UnsupportedPlatformError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!( - f, - "This operation is not supported on your current platform" - ) - } -} - #[derive(Debug)] pub struct NoInstallInUserProvidedPath(pub String); @@ -419,28 +407,6 @@ impl std::fmt::Display for OAuthError { } } -#[derive(Debug)] -pub struct CommandFailed { - pub output: std::process::Output, - pub command: String, -} - -impl std::fmt::Display for CommandFailed { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!( - f, - "Failed to run command \"{}\" (code {}): {}", - self.command, - self.output.status, - String::from_utf8_lossy(if self.output.stderr.is_empty() { - &self.output.stdout - } else { - &self.output.stderr - }) - ) - } -} - // Makes an "AnyError" enum that contains any of the given errors, in the form // `enum AnyError { FooError(FooError) }` (when given `makeAnyError!(FooError)`). // Useful to easily deal with application error types without making tons of "From" @@ -500,6 +466,20 @@ pub enum CodeError { #[cfg(windows)] #[error("could not get windows app lock: {0:?}")] AppLockFailed(std::io::Error), + #[error("failed to run command \"{command}\" (code {code}): {output}")] + CommandFailed { + command: String, + code: i32, + output: String, + }, + + #[error("platform not currently supported: {0}")] + UnsupportedPlatform(String), + #[error("This machine not meet {name}'s prerequisites, expected either...: {bullets}")] + PrerequisitesFailed { name: &'static str, bullets: String }, + + #[error("failed to spawn process: {0:?}")] + ProcessSpawnFailed(std::io::Error) } makeAnyError!( @@ -518,7 +498,6 @@ makeAnyError!( ExtensionInstallFailed, MismatchedLaunchModeError, NoAttachedServerError, - UnsupportedPlatformError, RefreshTokenNotAvailableError, NoInstallInUserProvidedPath, UserCancelledInstallation, @@ -530,7 +509,6 @@ makeAnyError!( UpdatesNotConfigured, CorruptDownload, MissingHomeDirectory, - CommandFailed, OAuthError, InvalidRpcDataError, CodeError diff --git a/cli/src/util/http.rs b/cli/src/util/http.rs index 16681c07596..953dba678c3 100644 --- a/cli/src/util/http.rs +++ b/cli/src/util/http.rs @@ -16,7 +16,7 @@ use hyper::{ HeaderMap, StatusCode, }; use serde::de::DeserializeOwned; -use std::{io, pin::Pin, str::FromStr, task::Poll}; +use std::{io, pin::Pin, str::FromStr, sync::Arc, task::Poll}; use tokio::{ fs, io::{AsyncRead, AsyncReadExt}, @@ -116,6 +116,8 @@ pub trait SimpleHttp { ) -> Result; } +pub type BoxedHttp = Arc; + // Implementation of SimpleHttp that uses a reqwest client. #[derive(Clone)] pub struct ReqwestSimpleHttp { @@ -324,7 +326,6 @@ impl AsyncRead for DelegatedReader { /// Simple http implementation that falls back to delegated http if /// making a direct reqwest fails. -#[derive(Clone)] pub struct FallbackSimpleHttp { native: ReqwestSimpleHttp, delegated: DelegatedSimpleHttp, diff --git a/cli/src/util/prereqs.rs b/cli/src/util/prereqs.rs index 5e5c6db7c15..d8cbd1b91dd 100644 --- a/cli/src/util/prereqs.rs +++ b/cli/src/util/prereqs.rs @@ -7,13 +7,12 @@ use std::cmp::Ordering; use super::command::capture_command; use crate::constants::QUALITYLESS_SERVER_NAME; use crate::update_service::Platform; -use crate::util::errors::SetupError; use lazy_static::lazy_static; use regex::bytes::Regex as BinRegex; use regex::Regex; use tokio::fs; -use super::errors::AnyError; +use super::errors::CodeError; lazy_static! { static ref LDCONFIG_STDC_RE: Regex = Regex::new(r"libstdc\+\+.* => (.+)").unwrap(); @@ -41,19 +40,18 @@ impl PreReqChecker { } #[cfg(not(target_os = "linux"))] - pub async fn verify(&self) -> Result { - use crate::constants::QUALITYLESS_PRODUCT_NAME; + pub async fn verify(&self) -> Result { Platform::env_default().ok_or_else(|| { - SetupError(format!( - "{} is not supported on this platform", - QUALITYLESS_PRODUCT_NAME + CodeError::UnsupportedPlatform(format!( + "{} {}", + std::env::consts::OS, + std::env::consts::ARCH )) - .into() }) } #[cfg(target_os = "linux")] - pub async fn verify(&self) -> Result { + pub async fn verify(&self) -> Result { let (is_nixos, gnu_a, gnu_b, or_musl) = tokio::join!( check_is_nixos(), check_glibc_version(), @@ -96,10 +94,10 @@ impl PreReqChecker { .collect::>() .join("\n"); - Err(AnyError::from(SetupError(format!( - "This machine not meet {}'s prerequisites, expected either...\n{}", - QUALITYLESS_SERVER_NAME, bullets, - )))) + Err(CodeError::PrerequisitesFailed { + bullets, + name: QUALITYLESS_SERVER_NAME, + }) } } diff --git a/cli/src/util/sync.rs b/cli/src/util/sync.rs index 8b653cd2d53..2b506bd54e3 100644 --- a/cli/src/util/sync.rs +++ b/cli/src/util/sync.rs @@ -4,9 +4,11 @@ *--------------------------------------------------------------------------------------------*/ use async_trait::async_trait; use std::{marker::PhantomData, sync::Arc}; -use tokio::sync::{ - broadcast, mpsc, - watch::{self, error::RecvError}, +use tokio::{ + sync::{ + broadcast, mpsc, + watch::{self, error::RecvError}, + }, }; #[derive(Clone)]