diff --git a/cli/Cargo.toml b/cli/Cargo.toml index 97affef0a64..6b5c8d07c3f 100644 --- a/cli/Cargo.toml +++ b/cli/Cargo.toml @@ -55,6 +55,7 @@ cfg-if = "1.0.0" pin-project = "1.0" console = "0.15" bytes = "1.4" +tar = { version = "0.4" } [build-dependencies] serde = { version = "1.0" } @@ -68,7 +69,6 @@ winapi = "0.3.9" core-foundation = "0.9.3" [target.'cfg(target_os = "linux")'.dependencies] -tar = { version = "0.4" } zbus = { version = "3.4", default-features = false, features = ["tokio"] } [patch.crates-io] diff --git a/cli/src/download_cache.rs b/cli/src/download_cache.rs new file mode 100644 index 00000000000..869fcf62357 --- /dev/null +++ b/cli/src/download_cache.rs @@ -0,0 +1,119 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +use std::{ + fs::create_dir_all, + path::{Path, PathBuf}, +}; + +use futures::Future; +use tokio::fs::remove_dir_all; + +use crate::{ + state::PersistedState, + util::errors::{wrap, AnyError, WrappedError}, +}; + +const KEEP_LRU: usize = 5; +const STAGING_SUFFIX: &str = ".staging"; + +#[derive(Clone)] +pub struct DownloadCache { + path: PathBuf, + state: PersistedState>, +} + +impl DownloadCache { + pub fn new(path: PathBuf) -> DownloadCache { + DownloadCache { + state: PersistedState::new(path.join("lru.json")), + path, + } + } + + /// Gets the download cache path. Names of cache entries can be formed by + /// joining them to the path. + pub fn path(&self) -> &Path { + &self.path + } + + /// Gets whether a cache exists with the name already. Marks it as recently + /// used if it does exist. + pub fn exists(&self, name: &str) -> Option { + let p = self.path.join(name); + if !p.exists() { + return None; + } + + let _ = self.touch(name.to_string()); + Some(p) + } + + /// Removes the item from the cache, if it exists + pub fn delete(&self, name: &str) -> Result<(), WrappedError> { + let f = self.path.join(name); + if f.exists() { + std::fs::remove_dir_all(f).map_err(|e| wrap(e, "error removing cached folder"))?; + } + + self.state.update(|l| { + l.retain(|n| n != name); + }) + } + + /// Calls the function to create the cached folder if it doesn't exist, + /// returning the path where the folder is. Note that the path passed to + /// the `do_create` method is a staging path and will not be the same as the + /// final returned path. + pub async fn create( + &self, + name: impl AsRef, + do_create: F, + ) -> Result + where + F: FnOnce(PathBuf) -> T, + T: Future> + Send, + { + let name = name.as_ref(); + let target_dir = self.path.join(name); + if target_dir.exists() { + return Ok(target_dir); + } + + let temp_dir = self.path.join(format!("{}{}", name, STAGING_SUFFIX)); + let _ = remove_dir_all(&temp_dir).await; // cleanup any existing + + create_dir_all(&temp_dir).map_err(|e| wrap(e, "error creating server directory"))?; + do_create(temp_dir.clone()).await?; + + let _ = self.touch(name.to_string()); + std::fs::rename(&temp_dir, &target_dir) + .map_err(|e| wrap(e, "error renaming downloaded server"))?; + + Ok(target_dir) + } + + fn touch(&self, name: String) -> Result<(), AnyError> { + self.state.update(|l| { + if let Some(index) = l.iter().position(|s| s == &name) { + l.remove(index); + } + l.insert(0, name); + + if l.len() <= KEEP_LRU { + return; + } + + if let Some(f) = l.last() { + let f = self.path.join(f); + if !f.exists() || std::fs::remove_dir_all(f).is_ok() { + l.pop(); + } + } + })?; + + Ok(()) + } +} diff --git a/cli/src/lib.rs b/cli/src/lib.rs index 0fe65c9d8c8..93777852df2 100644 --- a/cli/src/lib.rs +++ b/cli/src/lib.rs @@ -18,6 +18,7 @@ pub mod tunnels; pub mod update_service; pub mod util; +mod download_cache; mod async_pipe; mod json_rpc; mod msgpack_rpc; diff --git a/cli/src/rpc.rs b/cli/src/rpc.rs index f3c68321590..28dfc0efb47 100644 --- a/cli/src/rpc.rs +++ b/cli/src/rpc.rs @@ -93,7 +93,7 @@ pub struct RpcMethodBuilder { #[derive(Serialize)] struct DuplexStreamStarted { pub for_request_id: u32, - pub stream_id: u32, + pub stream_ids: Vec, } impl RpcMethodBuilder { @@ -196,12 +196,16 @@ impl RpcMethodBuilder { /// Registers an async rpc call that returns a Future containing a duplex /// stream that should be handled by the client. - pub fn register_duplex(&mut self, method_name: &'static str, callback: F) - where + pub fn register_duplex( + &mut self, + method_name: &'static str, + streams: usize, + callback: F, + ) where P: DeserializeOwned + Send + 'static, R: Serialize + Send + Sync + 'static, Fut: Future> + Send, - F: (Fn(DuplexStream, P, Arc) -> Fut) + Clone + Send + Sync + 'static, + F: (Fn(Vec, P, Arc) -> Fut) + Clone + Send + Sync + 'static, { let serial = self.serializer.clone(); let context = self.context.clone(); @@ -230,11 +234,21 @@ impl RpcMethodBuilder { let callback = callback.clone(); let serial = serial.clone(); let context = context.clone(); - let stream_id = next_message_id(); - let (client, server) = tokio::io::duplex(8192); + + let mut dto = StreamDto { + req_id: id.unwrap_or(0), + streams: Vec::with_capacity(streams), + }; + let mut servers = Vec::with_capacity(streams); + + for _ in 0..streams { + let (client, server) = tokio::io::duplex(8192); + servers.push(server); + dto.streams.push((next_message_id(), client)); + } let fut = async move { - match callback(server, param.params, context).await { + match callback(servers, param.params, context).await { Ok(r) => id.map(|id| serial.serialize(&SuccessResponse { id, result: r })), Err(err) => id.map(|id| { serial.serialize(&ErrorResponse { @@ -248,14 +262,7 @@ impl RpcMethodBuilder { } }; - ( - Some(StreamDto { - req_id: id.unwrap_or(0), - stream_id, - duplex: client, - }), - fut.boxed(), - ) + (Some(dto), fut.boxed()) })), ); } @@ -447,74 +454,73 @@ impl RpcDispatcher { write_tx: mpsc::Sender> + Send>, dto: StreamDto, ) { - let stream_id = dto.stream_id; - let for_request_id = dto.req_id; - let (mut read, write) = tokio::io::split(dto.duplex); - let serial = self.serializer.clone(); + let r = write_tx + .send( + self.serializer + .serialize(&FullRequest { + id: None, + method: METHOD_STREAMS_STARTED, + params: DuplexStreamStarted { + stream_ids: dto.streams.iter().map(|(id, _)| *id).collect(), + for_request_id: dto.req_id, + }, + }) + .into(), + ) + .await; - self.streams.lock().await.insert(dto.stream_id, write); + if r.is_err() { + return; + } - tokio::spawn(async move { - let r = write_tx - .send( - serial - .serialize(&FullRequest { - id: None, - method: METHOD_STREAM_STARTED, - params: DuplexStreamStarted { - stream_id, - for_request_id, - }, - }) - .into(), - ) - .await; + let mut streams_map = self.streams.lock().await; + for (stream_id, duplex) in dto.streams { + let (mut read, write) = tokio::io::split(duplex); + streams_map.insert(stream_id, write); - if r.is_err() { - return; - } + let write_tx = write_tx.clone(); + let serial = self.serializer.clone(); + tokio::spawn(async move { + let mut buf = vec![0; 4096]; + loop { + match read.read(&mut buf).await { + Ok(0) | Err(_) => break, + Ok(n) => { + let r = write_tx + .send( + serial + .serialize(&FullRequest { + id: None, + method: METHOD_STREAM_DATA, + params: StreamDataParams { + segment: &buf[..n], + stream: stream_id, + }, + }) + .into(), + ) + .await; - let mut buf = Vec::with_capacity(4096); - loop { - match read.read_buf(&mut buf).await { - Ok(0) | Err(_) => break, - Ok(n) => { - let r = write_tx - .send( - serial - .serialize(&FullRequest { - id: None, - method: METHOD_STREAM_DATA, - params: StreamDataParams { - segment: &buf[..n], - stream: stream_id, - }, - }) - .into(), - ) - .await; - - if r.is_err() { - return; + if r.is_err() { + return; + } } - - buf.truncate(0); } } - } - let _ = write_tx - .send( - serial - .serialize(&FullRequest { - id: None, - method: METHOD_STREAM_ENDED, - params: StreamEndedParams { stream: stream_id }, - }) - .into(), - ) - .await; - }); + let _ = write_tx + .send( + serial + .serialize(&FullRequest { + id: None, + method: METHOD_STREAM_ENDED, + params: StreamEndedParams { stream: stream_id }, + }) + .into(), + ) + .await; + }); + } } pub fn context(&self) -> Arc { @@ -522,7 +528,7 @@ impl RpcDispatcher { } } -const METHOD_STREAM_STARTED: &str = "stream_started"; +const METHOD_STREAMS_STARTED: &str = "streams_started"; const METHOD_STREAM_DATA: &str = "stream_data"; const METHOD_STREAM_ENDED: &str = "stream_ended"; @@ -592,9 +598,8 @@ enum Outcome { } pub struct StreamDto { - stream_id: u32, req_id: u32, - duplex: DuplexStream, + streams: Vec<(u32, DuplexStream)>, } pub enum MaybeSync { diff --git a/cli/src/self_update.rs b/cli/src/self_update.rs index 33201a345e3..2e95719a3b9 100644 --- a/cli/src/self_update.rs +++ b/cli/src/self_update.rs @@ -65,8 +65,8 @@ impl<'a> SelfUpdate<'a> { ) -> Result<(), AnyError> { // 1. Download the archive into a temporary directory let tempdir = tempdir().map_err(|e| wrap(e, "Failed to create temp dir"))?; - let archive_path = tempdir.path().join("archive"); let stream = self.update_service.get_download_stream(release).await?; + let archive_path = tempdir.path().join(stream.url_path_basename().unwrap()); http::download_into_file(&archive_path, progress, stream).await?; // 2. Unzip the archive and get the binary diff --git a/cli/src/state.rs b/cli/src/state.rs index 296d1b535d8..3f6ae4f227c 100644 --- a/cli/src/state.rs +++ b/cli/src/state.rs @@ -15,6 +15,7 @@ use serde::{de::DeserializeOwned, Serialize}; use crate::{ constants::VSCODE_CLI_QUALITY, + download_cache::DownloadCache, util::errors::{wrap, AnyError, NoHomeForLauncherError, WrappedError}, }; @@ -22,6 +23,8 @@ const HOME_DIR_ALTS: [&str; 2] = ["$HOME", "~"]; #[derive(Clone)] pub struct LauncherPaths { + pub server_cache: DownloadCache, + pub cli_cache: DownloadCache, root: PathBuf, } @@ -95,14 +98,10 @@ where } /// Mutates persisted state. - pub fn update_with( - &self, - v: V, - mutator: fn(v: V, state: &mut T) -> R, - ) -> Result { + pub fn update(&self, mutator: impl FnOnce(&mut T) -> R) -> Result { let mut container = self.container.lock().unwrap(); let mut state = container.load_or_get(); - let r = mutator(v, &mut state); + let r = mutator(&mut state); container.save(state).map(|_| r) } } @@ -132,7 +131,15 @@ impl LauncherPaths { } pub fn new_without_replacements(root: PathBuf) -> LauncherPaths { - LauncherPaths { root } + // cleanup folders that existed before the new LRU strategy: + let _ = std::fs::remove_dir_all(root.join("server-insiders")); + let _ = std::fs::remove_dir_all(root.join("server-stable")); + + LauncherPaths { + server_cache: DownloadCache::new(root.join("servers")), + cli_cache: DownloadCache::new(root.join("cli")), + root, + } } /// Root directory for the server launcher diff --git a/cli/src/tunnels/code_server.rs b/cli/src/tunnels/code_server.rs index 677bbfc2546..1246e1c9441 100644 --- a/cli/src/tunnels/code_server.rs +++ b/cli/src/tunnels/code_server.rs @@ -2,31 +2,31 @@ * Copyright (c) Microsoft Corporation. All rights reserved. * Licensed under the MIT License. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ -use super::paths::{InstalledServer, LastUsedServers, ServerPaths}; +use super::paths::{InstalledServer, ServerPaths}; use crate::async_pipe::get_socket_name; use crate::constants::{ APPLICATION_NAME, EDITOR_WEB_URL, QUALITYLESS_PRODUCT_NAME, QUALITYLESS_SERVER_NAME, }; +use crate::download_cache::DownloadCache; use crate::options::{Quality, TelemetryLevel}; use crate::state::LauncherPaths; +use crate::tunnels::paths::{get_server_folder_name, SERVER_FOLDER_NAME}; use crate::update_service::{ unzip_downloaded_release, Platform, Release, TargetKind, UpdateService, }; use crate::util::command::{capture_command, kill_tree}; -use crate::util::errors::{ - wrap, AnyError, ExtensionInstallFailed, MissingEntrypointError, WrappedError, -}; +use crate::util::errors::{wrap, AnyError, CodeError, ExtensionInstallFailed, WrappedError}; use crate::util::http::{self, BoxedHttp}; use crate::util::io::SilentCopyProgress; use crate::util::machine::process_exists; -use crate::{debug, info, log, span, spanf, trace, warning}; +use crate::{debug, info, log, spanf, trace, warning}; use lazy_static::lazy_static; use opentelemetry::KeyValue; use regex::Regex; use serde::Deserialize; use std::fs; use std::fs::File; -use std::io::{ErrorKind, Write}; +use std::io::Write; use std::path::{Path, PathBuf}; use std::sync::Arc; use std::time::Duration; @@ -42,8 +42,6 @@ lazy_static! { static ref WEB_UI_RE: Regex = Regex::new(r"Web UI available at (.+)").unwrap(); } -const MAX_RETAINED_SERVERS: usize = 5; - #[derive(Clone, Debug, Default)] pub struct CodeServerArgs { pub host: Option, @@ -276,102 +274,6 @@ impl CodeServerOrigin { } } -async fn check_and_create_dir(path: &Path) -> Result<(), WrappedError> { - tokio::fs::create_dir_all(path) - .await - .map_err(|e| wrap(e, "error creating server directory"))?; - Ok(()) -} - -async fn install_server_if_needed( - log: &log::Logger, - paths: &ServerPaths, - release: &Release, - http: BoxedHttp, - existing_archive_path: Option, -) -> Result<(), AnyError> { - if paths.executable.exists() { - info!( - log, - "Found existing installation at {}", - paths.server_dir.display() - ); - return Ok(()); - } - - let tar_file_path = match existing_archive_path { - Some(p) => p, - None => spanf!( - log, - log.span("server.download"), - download_server(&paths.server_dir, release, log, http) - )?, - }; - - span!( - log, - log.span("server.extract"), - install_server(&tar_file_path, paths, log) - )?; - - Ok(()) -} - -async fn download_server( - path: &Path, - release: &Release, - log: &log::Logger, - http: BoxedHttp, -) -> Result { - let response = UpdateService::new(log.clone(), http) - .get_download_stream(release) - .await?; - - let mut save_path = path.to_owned(); - save_path.push("archive"); - - info!( - log, - "Downloading {} server -> {}", - QUALITYLESS_PRODUCT_NAME, - save_path.display() - ); - - http::download_into_file( - &save_path, - log.get_download_logger("server download progress:"), - response, - ) - .await?; - - Ok(save_path) -} - -fn install_server( - compressed_file: &Path, - paths: &ServerPaths, - log: &log::Logger, -) -> Result<(), AnyError> { - info!(log, "Setting up server..."); - - unzip_downloaded_release(compressed_file, &paths.server_dir, SilentCopyProgress())?; - - match fs::remove_file(compressed_file) { - Ok(()) => {} - Err(e) => { - if e.kind() != ErrorKind::NotFound { - return Err(AnyError::from(wrap(e, "error removing downloaded file"))); - } - } - } - - if !paths.executable.exists() { - return Err(AnyError::from(MissingEntrypointError())); - } - - Ok(()) -} - /// Ensures the given list of extensions are installed on the running server. async fn do_extension_install_on_running_server( start_script_path: &Path, @@ -406,7 +308,7 @@ async fn do_extension_install_on_running_server( pub struct ServerBuilder<'a> { logger: &'a log::Logger, server_params: &'a ResolvedServerParams, - last_used: LastUsedServers<'a>, + launcher_paths: &'a LauncherPaths, server_paths: ServerPaths, http: BoxedHttp, } @@ -421,7 +323,7 @@ impl<'a> ServerBuilder<'a> { Self { logger, server_params, - last_used: LastUsedServers::new(launcher_paths), + launcher_paths, server_paths: server_params .as_installed_server() .server_paths(launcher_paths), @@ -477,31 +379,54 @@ impl<'a> ServerBuilder<'a> { } /// Ensures the server is set up in the configured directory. - pub async fn setup(&self, existing_archive_path: Option) -> Result<(), AnyError> { + pub async fn setup(&self) -> Result<(), AnyError> { debug!( self.logger, "Installing and setting up {}...", QUALITYLESS_SERVER_NAME ); - check_and_create_dir(&self.server_paths.server_dir).await?; - install_server_if_needed( - self.logger, - &self.server_paths, - &self.server_params.release, - self.http.clone(), - existing_archive_path, - ) - .await?; - debug!(self.logger, "Server setup complete"); - match self.last_used.add(self.server_params.as_installed_server()) { - Err(e) => warning!(self.logger, "Error adding server to last used: {}", e), - Ok(count) if count > MAX_RETAINED_SERVERS => { - if let Err(e) = self.last_used.trim(self.logger, MAX_RETAINED_SERVERS) { - warning!(self.logger, "Error trimming old servers: {}", e); - } - } - Ok(_) => {} - } + let update_service = UpdateService::new(self.logger.clone(), self.http.clone()); + let name = get_server_folder_name( + self.server_params.release.quality, + &self.server_params.release.commit, + ); + + self.launcher_paths + .server_cache + .create(name, |target_dir| async move { + let tmpdir = + tempfile::tempdir().map_err(|e| wrap(e, "error creating temp download dir"))?; + + let response = update_service + .get_download_stream(&self.server_params.release) + .await?; + let archive_path = tmpdir.path().join(response.url_path_basename().unwrap()); + + info!( + self.logger, + "Downloading {} server -> {}", + QUALITYLESS_PRODUCT_NAME, + archive_path.display() + ); + + http::download_into_file( + &archive_path, + self.logger.get_download_logger("server download progress:"), + response, + ) + .await?; + + unzip_downloaded_release( + &archive_path, + &target_dir.join(SERVER_FOLDER_NAME), + SilentCopyProgress(), + )?; + + Ok(()) + }) + .await?; + + debug!(self.logger, "Server setup complete"); Ok(()) } @@ -836,3 +761,39 @@ pub fn print_listening(log: &log::Logger, tunnel_name: &str) { let message = &format!("\nOpen this link in your browser {}\n", addr); log.result(message); } + +pub async fn download_cli_into_cache( + cache: &DownloadCache, + release: &Release, + update_service: &UpdateService, +) -> Result { + let cache_name = format!( + "{}-{}-{}", + release.quality, release.commit, release.platform + ); + let cli_dir = cache + .create(&cache_name, |target_dir| async move { + let tmpdir = + tempfile::tempdir().map_err(|e| wrap(e, "error creating temp download dir"))?; + let response = update_service.get_download_stream(release).await?; + + let name = response.url_path_basename().unwrap(); + let archive_path = tmpdir.path().join(name); + http::download_into_file(&archive_path, SilentCopyProgress(), response).await?; + unzip_downloaded_release(&archive_path, &target_dir, SilentCopyProgress())?; + Ok(()) + }) + .await?; + + let cli = std::fs::read_dir(cli_dir) + .map_err(|_| CodeError::CorruptDownload("could not read cli folder contents"))? + .next(); + + match cli { + Some(Ok(cli)) => Ok(cli.path()), + _ => { + let _ = cache.delete(&cache_name); + Err(CodeError::CorruptDownload("cli directory is empty").into()) + } + } +} diff --git a/cli/src/tunnels/control_server.rs b/cli/src/tunnels/control_server.rs index bf8ce00380f..ef1a8d03284 100644 --- a/cli/src/tunnels/control_server.rs +++ b/cli/src/tunnels/control_server.rs @@ -3,7 +3,7 @@ * Licensed under the MIT License. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ use crate::async_pipe::get_socket_rw_stream; -use crate::constants::CONTROL_PORT; +use crate::constants::{CONTROL_PORT, PRODUCT_NAME_LONG}; use crate::log; use crate::msgpack_rpc::U32PrefixedCodec; use crate::rpc::{MaybeSync, RpcBuilder, RpcDispatcher, Serialization}; @@ -11,7 +11,7 @@ use crate::self_update::SelfUpdate; use crate::state::LauncherPaths; use crate::tunnels::protocol::HttpRequestParams; use crate::tunnels::socket_signal::CloseReason; -use crate::update_service::{Platform, UpdateService}; +use crate::update_service::{Platform, Release, TargetKind, UpdateService}; use crate::util::errors::{ wrap, AnyError, CodeError, InvalidRpcDataError, MismatchedLaunchModeError, NoAttachedServerError, @@ -23,6 +23,8 @@ use crate::util::io::SilentCopyProgress; use crate::util::is_integrated_cli; use crate::util::sync::{new_barrier, Barrier}; +use futures::stream::FuturesUnordered; +use futures::FutureExt; use opentelemetry::trace::SpanKind; use opentelemetry::KeyValue; use std::collections::HashMap; @@ -37,16 +39,17 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, D use tokio::sync::{mpsc, Mutex}; use super::code_server::{ - AnyCodeServer, CodeServerArgs, ServerBuilder, ServerParamsRaw, SocketCodeServer, + download_cli_into_cache, AnyCodeServer, CodeServerArgs, ServerBuilder, ServerParamsRaw, + SocketCodeServer, }; use super::dev_tunnels::ActiveTunnel; use super::paths::prune_stopped_servers; use super::port_forwarder::{PortForwarding, PortForwardingProcessor}; use super::protocol::{ - CallServerHttpParams, CallServerHttpResult, ClientRequestMethod, EmptyObject, ForwardParams, - ForwardResult, GetHostnameResponse, HttpBodyParams, HttpHeadersParams, ServeParams, ServerLog, - ServerMessageParams, SpawnParams, SpawnResult, ToClientRequest, UnforwardParams, UpdateParams, - UpdateResult, VersionParams, + AcquireCliParams, CallServerHttpParams, CallServerHttpResult, ClientRequestMethod, EmptyObject, + ForwardParams, ForwardResult, GetHostnameResponse, HttpBodyParams, HttpHeadersParams, + ServeParams, ServerLog, ServerMessageParams, SpawnParams, SpawnResult, ToClientRequest, + UnforwardParams, UpdateParams, UpdateResult, VersionParams, }; use super::server_bridge::ServerBridge; use super::server_multiplexer::ServerMultiplexer; @@ -284,8 +287,18 @@ async fn process_socket( rpc.register_async("unforward", |p: UnforwardParams, c| async move { handle_unforward(&c.log, &c.port_forwarding, p).await }); - rpc.register_duplex("spawn", |stream, p: SpawnParams, c| async move { - handle_spawn(&c.log, stream, p).await + rpc.register_async("acquire_cli", |p: AcquireCliParams, c| async move { + handle_acquire_cli(&c.launcher_paths, &c.http, &c.log, p).await + }); + rpc.register_duplex("spawn", 3, |mut streams, p: SpawnParams, c| async move { + handle_spawn( + &c.log, + p, + Some(streams.remove(0)), + Some(streams.remove(0)), + Some(streams.remove(0)), + ) + .await }); rpc.register_sync("httpheaders", |p: HttpHeadersParams, c| { if let Some(req) = c.http_requests.lock().unwrap().get(&p.req_id) { @@ -507,7 +520,7 @@ async fn handle_serve( Some(AnyCodeServer::Socket(s)) => s, Some(_) => return Err(AnyError::from(MismatchedLaunchModeError())), None => { - $sb.setup(None).await?; + $sb.setup().await?; $sb.listen_on_default_socket().await? } } @@ -734,82 +747,106 @@ async fn handle_call_server_http( }) } -async fn handle_spawn( +async fn handle_acquire_cli( + paths: &LauncherPaths, + http: &Arc, log: &log::Logger, - mut duplex: DuplexStream, - params: SpawnParams, + params: AcquireCliParams, ) -> Result { + let update_service = UpdateService::new(log.clone(), http.clone()); + + let release = match params.commit_id { + Some(commit) => Release { + name: format!("{} CLI", PRODUCT_NAME_LONG), + commit, + platform: params.platform, + quality: params.quality, + target: TargetKind::Cli, + }, + None => { + update_service + .get_latest_commit(params.platform, TargetKind::Cli, params.quality) + .await? + } + }; + + let cli = download_cli_into_cache(&paths.cli_cache, &release, &update_service).await?; + let file = tokio::fs::File::open(cli) + .await + .map_err(|e| wrap(e, "error opening cli file"))?; + + handle_spawn::<_, DuplexStream>(log, params.spawn, Some(file), None, None).await +} + +async fn handle_spawn( + log: &log::Logger, + params: SpawnParams, + stdin: Option, + stdout: Option, + stderr: Option, +) -> Result +where + Stdin: AsyncRead + Unpin + Send, + StdoutAndErr: AsyncWrite + Unpin + Send, +{ debug!( log, "requested to spawn {} with args {:?}", params.command, params.args ); - let mut p = tokio::process::Command::new(¶ms.command) - .args(¶ms.args) - .envs(¶ms.env) - .stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .spawn() - .map_err(CodeError::ProcessSpawnFailed)?; - - let mut stdout = p.stdout.take().unwrap(); - let mut stderr = p.stderr.take().unwrap(); - let mut stdin = p.stdin.take().unwrap(); - let (tx, mut rx) = mpsc::channel(4); - - macro_rules! copy_stream_to { - ($target:expr) => { - let tx = tx.clone(); - tokio::spawn(async move { - let mut buf = vec![0; 4096]; - loop { - let n = match $target.read(&mut buf).await { - Ok(0) | Err(_) => return, - Ok(n) => n, - }; - if !tx.send(buf[..n].to_vec()).await.is_ok() { - return; - } - } - }); + macro_rules! pipe_if_some { + ($e: expr) => { + if $e.is_some() { + Stdio::piped() + } else { + Stdio::null() + } }; } - copy_stream_to!(stdout); - copy_stream_to!(stderr); + let mut p = tokio::process::Command::new(¶ms.command) + .args(¶ms.args) + .envs(¶ms.env) + .stdin(pipe_if_some!(stdin)) + .stdout(pipe_if_some!(stdout)) + .stderr(pipe_if_some!(stderr)) + .spawn() + .map_err(CodeError::ProcessSpawnFailed)?; + + let futs = FuturesUnordered::new(); + if let (Some(mut a), Some(mut b)) = (p.stdout.take(), stdout) { + futs.push(async move { tokio::io::copy(&mut a, &mut b).await }.boxed()); + } + if let (Some(mut a), Some(mut b)) = (p.stderr.take(), stderr) { + futs.push(async move { tokio::io::copy(&mut a, &mut b).await }.boxed()); + } + if let (Some(mut b), Some(mut a)) = (p.stdin.take(), stdin) { + futs.push(async move { tokio::io::copy(&mut a, &mut b).await }.boxed()); + } - let mut stdin_buf = vec![0; 4096]; let closed = p.wait(); pin!(closed); - loop { - tokio::select! { - Ok(n) = duplex.read(&mut stdin_buf) => { - let _ = stdin.write_all(&stdin_buf[..n]).await; - }, - Some(m) = rx.recv() => { - let _ = duplex.write_all(&m).await; - }, - r = &mut closed => { - let r = match r { - Ok(e) => SpawnResult { - message: e.to_string(), - exit_code: e.code().unwrap_or(-1), - }, - Err(e) => SpawnResult { - message: e.to_string(), - exit_code: -1, - }, - }; + let r = tokio::select! { + _ = futures::future::join_all(futs) => closed.await, + r = &mut closed => r + }; - debug!( - log, - "spawned command {} exited with code {}", params.command, r.exit_code - ); + let r = match r { + Ok(e) => SpawnResult { + message: e.to_string(), + exit_code: e.code().unwrap_or(-1), + }, + Err(e) => SpawnResult { + message: e.to_string(), + exit_code: -1, + }, + }; - return Ok(r) - }, - } - } + debug!( + log, + "spawned command {} exited with code {}", params.command, r.exit_code + ); + + Ok(r) } diff --git a/cli/src/tunnels/paths.rs b/cli/src/tunnels/paths.rs index cdf6cef6f51..fa06db5dd7a 100644 --- a/cli/src/tunnels/paths.rs +++ b/cli/src/tunnels/paths.rs @@ -11,19 +11,15 @@ use std::{ use serde::{Deserialize, Serialize}; use crate::{ - log, options, - state::{LauncherPaths, PersistedState}, + options::{self, Quality}, + state::LauncherPaths, util::{ errors::{wrap, AnyError, WrappedError}, machine, }, }; -const INSIDERS_INSTALL_FOLDER: &str = "server-insiders"; -const STABLE_INSTALL_FOLDER: &str = "server-stable"; -const EXPLORATION_INSTALL_FOLDER: &str = "server-exploration"; -const PIDFILE_SUFFIX: &str = ".pid"; -const LOGFILE_SUFFIX: &str = ".log"; +pub const SERVER_FOLDER_NAME: &str = "server"; pub struct ServerPaths { // Directory into which the server is downloaded @@ -93,76 +89,27 @@ pub struct InstalledServer { impl InstalledServer { /// Gets path information about where a specific server should be stored. pub fn server_paths(&self, p: &LauncherPaths) -> ServerPaths { - let base_folder = self.get_install_folder(p); - let server_dir = base_folder.join("bin").join(&self.commit); + let server_dir = self.get_install_folder(p); ServerPaths { executable: server_dir + .join(SERVER_FOLDER_NAME) .join("bin") .join(self.quality.server_entrypoint()), + logfile: server_dir.join("log.txt"), + pidfile: server_dir.join("pid.txt"), server_dir, - logfile: base_folder.join(format!(".{}{}", self.commit, LOGFILE_SUFFIX)), - pidfile: base_folder.join(format!(".{}{}", self.commit, PIDFILE_SUFFIX)), } } fn get_install_folder(&self, p: &LauncherPaths) -> PathBuf { - let name = match self.quality { - options::Quality::Insiders => INSIDERS_INSTALL_FOLDER, - options::Quality::Exploration => EXPLORATION_INSTALL_FOLDER, - options::Quality::Stable => STABLE_INSTALL_FOLDER, - }; - - p.root().join(if !self.headless { - format!("{}-web", name) + p.server_cache.path().join(if !self.headless { + format!("{}-web", get_server_folder_name(self.quality, &self.commit)) } else { - name.to_string() + get_server_folder_name(self.quality, &self.commit) }) } } -pub struct LastUsedServers<'a> { - state: PersistedState>, - paths: &'a LauncherPaths, -} - -impl<'a> LastUsedServers<'a> { - pub fn new(paths: &'a LauncherPaths) -> LastUsedServers { - LastUsedServers { - state: PersistedState::new(paths.root().join("last-used-servers.json")), - paths, - } - } - - /// Adds a server as having been used most recently. Returns the number of retained server. - pub fn add(&self, server: InstalledServer) -> Result { - self.state.update_with(server, |server, l| { - if let Some(index) = l.iter().position(|s| s == &server) { - l.remove(index); - } - l.insert(0, server); - l.len() - }) - } - - /// Trims so that at most `max_servers` are saved on disk. - pub fn trim(&self, log: &log::Logger, max_servers: usize) -> Result<(), WrappedError> { - let mut servers = self.state.load(); - while servers.len() > max_servers { - let server = servers.pop().unwrap(); - debug!( - log, - "Removing old server {}/{}", - server.quality.get_machine_name(), - server.commit - ); - let server_paths = server.server_paths(self.paths); - server_paths.delete()?; - } - self.state.save(servers)?; - Ok(()) - } -} - /// Prunes servers not currently running, and returns the deleted servers. pub fn prune_stopped_servers(launcher_paths: &LauncherPaths) -> Result, AnyError> { get_all_servers(launcher_paths) @@ -177,40 +124,31 @@ pub fn prune_stopped_servers(launcher_paths: &LauncherPaths) -> Result Vec { let mut servers: Vec = vec![]; - let mut server = InstalledServer { - commit: "".to_owned(), - headless: false, - quality: options::Quality::Stable, - }; + if let Ok(children) = read_dir(lp.server_cache.path()) { + for child in children.flatten() { + let fname = child.file_name(); + let fname = fname.to_string_lossy(); + let (quality, commit) = match fname.split_once('-') { + Some(r) => r, + None => continue, + }; - add_server_paths_in_folder(lp, &server, &mut servers); + let quality = match options::Quality::try_from(quality) { + Ok(q) => q, + Err(_) => continue, + }; - server.headless = true; - add_server_paths_in_folder(lp, &server, &mut servers); - - server.headless = false; - server.quality = options::Quality::Insiders; - add_server_paths_in_folder(lp, &server, &mut servers); - - server.headless = true; - add_server_paths_in_folder(lp, &server, &mut servers); + servers.push(InstalledServer { + quality, + commit: commit.to_string(), + headless: true, + }); + } + } servers } -fn add_server_paths_in_folder( - lp: &LauncherPaths, - server: &InstalledServer, - servers: &mut Vec, -) { - let dir = server.get_install_folder(lp).join("bin"); - if let Ok(children) = read_dir(dir) { - for bin in children.flatten() { - servers.push(InstalledServer { - quality: server.quality, - headless: server.headless, - commit: bin.file_name().to_string_lossy().into(), - }); - } - } +pub fn get_server_folder_name(quality: Quality, commit: &str) -> String { + format!("{}-{}", quality, commit) } diff --git a/cli/src/tunnels/protocol.rs b/cli/src/tunnels/protocol.rs index 89f9c3acb28..2093b1d1896 100644 --- a/cli/src/tunnels/protocol.rs +++ b/cli/src/tunnels/protocol.rs @@ -7,6 +7,7 @@ use std::collections::HashMap; use crate::{ constants::{PROTOCOL_VERSION, VSCODE_CLI_VERSION}, options::Quality, + update_service::Platform, }; use serde::{Deserialize, Serialize}; @@ -166,6 +167,15 @@ pub struct SpawnParams { pub env: HashMap, } +#[derive(Deserialize)] +pub struct AcquireCliParams { + pub platform: Platform, + pub quality: Quality, + pub commit_id: Option, + #[serde(flatten)] + pub spawn: SpawnParams, +} + #[derive(Serialize)] pub struct SpawnResult { pub message: String, diff --git a/cli/src/tunnels/wsl_server.rs b/cli/src/tunnels/wsl_server.rs index 69d7eb94389..3eafae92f3a 100644 --- a/cli/src/tunnels/wsl_server.rs +++ b/cli/src/tunnels/wsl_server.rs @@ -151,7 +151,7 @@ async fn handle_serve( Some(AnyCodeServer::Socket(s)) => s, Some(_) => return Err(MismatchedLaunchModeError().into()), None => { - sb.setup(Some(params.archive_path.into())).await?; + sb.setup().await?; sb.listen_on_default_socket().await? } }; diff --git a/cli/src/update_service.rs b/cli/src/update_service.rs index e56d3781804..b03d8ea5963 100644 --- a/cli/src/update_service.rs +++ b/cli/src/update_service.rs @@ -3,9 +3,9 @@ * Licensed under the MIT License. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ -use std::{fmt, path::Path}; +use std::{ffi::OsStr, fmt, path::Path}; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use crate::{ constants::VSCODE_CLI_UPDATE_ENDPOINT, @@ -14,6 +14,7 @@ use crate::{ errors::{AnyError, CodeError, UpdatesNotConfigured, WrappedError}, http::{BoxedHttp, SimpleResponse}, io::ReportCopyProgress, + tar, zipper, }, }; @@ -175,14 +176,9 @@ pub fn unzip_downloaded_release( where T: ReportCopyProgress, { - #[cfg(any(target_os = "windows", target_os = "macos"))] - { - use crate::util::zipper; + if compressed_file.extension() == Some(OsStr::new("zip")) { zipper::unzip_file(compressed_file, target_dir, reporter) - } - #[cfg(target_os = "linux")] - { - use crate::util::tar; + } else { tar::decompress_tarball(compressed_file, target_dir, reporter) } } @@ -206,7 +202,7 @@ impl TargetKind { } } -#[derive(Debug, Copy, Clone, Eq, PartialEq)] +#[derive(Debug, Copy, Clone, Eq, PartialEq, Serialize, Deserialize)] pub enum Platform { LinuxAlpineX64, LinuxAlpineARM64, diff --git a/cli/src/util.rs b/cli/src/util.rs index 1478394d696..f45f74de13c 100644 --- a/cli/src/util.rs +++ b/cli/src/util.rs @@ -17,9 +17,5 @@ pub mod sync; pub use is_integrated::*; pub mod app_lock; pub mod file_lock; - -#[cfg(target_os = "linux")] pub mod tar; - -#[cfg(any(target_os = "windows", target_os = "macos"))] pub mod zipper; diff --git a/cli/src/util/errors.rs b/cli/src/util/errors.rs index b5ca1066046..9ab421d3301 100644 --- a/cli/src/util/errors.rs +++ b/cli/src/util/errors.rs @@ -477,9 +477,11 @@ pub enum CodeError { UnsupportedPlatform(String), #[error("This machine not meet {name}'s prerequisites, expected either...: {bullets}")] PrerequisitesFailed { name: &'static str, bullets: String }, - #[error("failed to spawn process: {0:?}")] - ProcessSpawnFailed(std::io::Error) + ProcessSpawnFailed(std::io::Error), + + #[error("download appears corrupted, please retry ({0})")] + CorruptDownload(&'static str), } makeAnyError!( diff --git a/cli/src/util/http.rs b/cli/src/util/http.rs index 953dba678c3..e49120578a7 100644 --- a/cli/src/util/http.rs +++ b/cli/src/util/http.rs @@ -59,14 +59,23 @@ pub struct SimpleResponse { pub status_code: StatusCode, pub headers: HeaderMap, pub read: Pin>, - pub url: String, + pub url: Option, } impl SimpleResponse { - pub fn generic_error(url: String) -> Self { + pub fn url_path_basename(&self) -> Option { + self.url.as_ref().and_then(|u| { + u.path_segments() + .and_then(|s| s.last().map(|s| s.to_owned())) + }) + } +} + +impl SimpleResponse { + pub fn generic_error(url: &str) -> Self { let (_, rx) = mpsc::unbounded_channel(); SimpleResponse { - url, + url: url::Url::parse(url).ok(), status_code: StatusCode::INTERNAL_SERVER_ERROR, headers: HeaderMap::new(), read: Box::pin(DelegatedReader::new(rx)), @@ -79,7 +88,10 @@ impl SimpleResponse { self.read.read_to_string(&mut body).await.ok(); StatusError { - url: self.url, + url: self + .url + .map(|u| u.to_string()) + .unwrap_or_else(|| "".to_owned()), status_code: self.status_code.as_u16(), body, } @@ -97,7 +109,7 @@ impl SimpleResponse { .map_err(|e| wrap(e, "error reading response"))?; let t = serde_json::from_slice(&buf) - .map_err(|e| wrap(e, format!("error decoding json from {}", self.url)))?; + .map_err(|e| wrap(e, format!("error decoding json from {:?}", self.url)))?; Ok(t) } @@ -161,7 +173,7 @@ impl SimpleHttp for ReqwestSimpleHttp { Ok(SimpleResponse { status_code: res.status(), headers: res.headers().clone(), - url, + url: Some(res.url().clone()), read: Box::pin( res.bytes_stream() .map_err(|e| futures::io::Error::new(futures::io::ErrorKind::Other, e)) @@ -250,7 +262,7 @@ impl SimpleHttp for DelegatedSimpleHttp { .await; if sent.is_err() { - return Ok(SimpleResponse::generic_error(url)); // sender shut down + return Ok(SimpleResponse::generic_error(&url)); // sender shut down } match rx.recv().await { @@ -275,16 +287,16 @@ impl SimpleHttp for DelegatedSimpleHttp { } Ok(SimpleResponse { - url, + url: url::Url::parse(&url).ok(), status_code: StatusCode::from_u16(status_code) .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), headers: headers_map, read: Box::pin(DelegatedReader::new(rx)), }) } - Some(DelegatedHttpEvent::End) => Ok(SimpleResponse::generic_error(url)), + Some(DelegatedHttpEvent::End) => Ok(SimpleResponse::generic_error(&url)), Some(_) => panic!("expected initresponse as first message from delegated http"), - None => Ok(SimpleResponse::generic_error(url)), // sender shut down + None => Ok(SimpleResponse::generic_error(&url)), // sender shut down } } }