From b784bcdd32cd3b467a9ce05c5402ede34efa5644 Mon Sep 17 00:00:00 2001 From: Connor Peet Date: Fri, 23 Sep 2022 14:17:01 -0700 Subject: [PATCH] cli: use hard tabs to align with vscode style --- .vscode/settings.json | 1 - cli/build.rs | 61 +- cli/rustfmt.toml | 1 + cli/src/auth.rs | 856 +++++++-------- cli/src/bin/code-tunnel/main.rs | 150 +-- cli/src/bin/code/legacy_args.rs | 402 +++---- cli/src/bin/code/main.rs | 240 ++-- cli/src/commands/args.rs | 748 ++++++------- cli/src/commands/context.rs | 8 +- cli/src/commands/output.rs | 168 +-- cli/src/commands/tunnels.rs | 374 +++---- cli/src/commands/version.rs | 80 +- cli/src/constants.rs | 20 +- cli/src/desktop/version_manager.rs | 736 ++++++------- cli/src/lib.rs | 2 +- cli/src/log.rs | 420 +++---- cli/src/options.rs | 144 +-- cli/src/state.rs | 194 ++-- cli/src/tunnels.rs | 2 +- cli/src/tunnels/code_server.rs | 1064 +++++++++--------- cli/src/tunnels/control_server.rs | 1072 +++++++++--------- cli/src/tunnels/dev_tunnels.rs | 1269 +++++++++++----------- cli/src/tunnels/legal.rs | 64 +- cli/src/tunnels/name_generator.rs | 402 +++---- cli/src/tunnels/paths.rs | 306 +++--- cli/src/tunnels/port_forwarder.rs | 172 +-- cli/src/tunnels/protocol.rs | 108 +- cli/src/tunnels/server_bridge_unix.rs | 102 +- cli/src/tunnels/server_bridge_windows.rs | 200 ++-- cli/src/tunnels/service.rs | 72 +- cli/src/tunnels/service_windows.rs | 408 +++---- cli/src/update.rs | 148 +-- cli/src/update_service.rs | 416 +++---- cli/src/util/command.rs | 96 +- cli/src/util/errors.rs | 366 +++---- cli/src/util/http.rs | 34 +- cli/src/util/input.rs | 72 +- cli/src/util/io.rs | 58 +- cli/src/util/machine.rs | 90 +- cli/src/util/prereqs.rs | 394 +++---- cli/src/util/sync.rs | 100 +- cli/src/util/tar.rs | 60 +- cli/src/util/zipper.rs | 206 ++-- 43 files changed, 5942 insertions(+), 5944 deletions(-) create mode 100644 cli/rustfmt.toml diff --git a/.vscode/settings.json b/.vscode/settings.json index 184042b4f2e..33056829dde 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -90,7 +90,6 @@ "[rust]": { "editor.defaultFormatter": "rust-lang.rust-analyzer", "editor.formatOnSave": true, - "editor.insertSpaces": true }, "typescript.tsc.autoDetect": "off", "testing.autoRun.mode": "rerun", diff --git a/cli/build.rs b/cli/build.rs index 278b60fe1c4..797e612ec4b 100644 --- a/cli/build.rs +++ b/cli/build.rs @@ -3,57 +3,56 @@ * Licensed under the MIT License. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ - const FILE_HEADER: &str = "/*---------------------------------------------------------------------------------------------\n * Copyright (c) Microsoft Corporation. All rights reserved.\n * Licensed under the MIT License. See License.txt in the project root for license information.\n *--------------------------------------------------------------------------------------------*/"; use std::{env, fs, io, path::PathBuf, process}; fn main() { - let files = enumerate_source_files().expect("expected to enumerate files"); - ensure_file_headers(&files).expect("expected to ensure file headers"); + let files = enumerate_source_files().expect("expected to enumerate files"); + ensure_file_headers(&files).expect("expected to ensure file headers"); } fn ensure_file_headers(files: &[PathBuf]) -> Result<(), io::Error> { - let mut ok = true; + let mut ok = true; let crlf_header_str = str::replace(FILE_HEADER, "\n", "\r\n"); let crlf_header = crlf_header_str.as_bytes(); let lf_header = FILE_HEADER.as_bytes(); - for file in files { - let contents = fs::read(file)?; + for file in files { + let contents = fs::read(file)?; - if !(contents.starts_with(lf_header) || contents.starts_with(crlf_header)) { - eprintln!("File missing copyright header: {}", file.display()); - ok = false; - } - } + if !(contents.starts_with(lf_header) || contents.starts_with(crlf_header)) { + eprintln!("File missing copyright header: {}", file.display()); + ok = false; + } + } - if !ok { - process::exit(1); - } + if !ok { + process::exit(1); + } - Ok(()) + Ok(()) } /// Gets all "rs" files in the source directory fn enumerate_source_files() -> Result, io::Error> { - let mut files = vec![]; - let mut queue = vec![]; + let mut files = vec![]; + let mut queue = vec![]; - let current_dir = env::current_dir()?.join("src"); - queue.push(current_dir); + let current_dir = env::current_dir()?.join("src"); + queue.push(current_dir); - while !queue.is_empty() { - for entry in fs::read_dir(queue.pop().unwrap())? { - let entry = entry?; - let ftype = entry.file_type()?; - if ftype.is_dir() { - queue.push(entry.path()); - } else if ftype.is_file() && entry.file_name().to_string_lossy().ends_with(".rs") { - files.push(entry.path()); - } - } - } + while !queue.is_empty() { + for entry in fs::read_dir(queue.pop().unwrap())? { + let entry = entry?; + let ftype = entry.file_type()?; + if ftype.is_dir() { + queue.push(entry.path()); + } else if ftype.is_file() && entry.file_name().to_string_lossy().ends_with(".rs") { + files.push(entry.path()); + } + } + } - Ok(files) + Ok(files) } diff --git a/cli/rustfmt.toml b/cli/rustfmt.toml new file mode 100644 index 00000000000..218e203215e --- /dev/null +++ b/cli/rustfmt.toml @@ -0,0 +1 @@ +hard_tabs = true diff --git a/cli/src/auth.rs b/cli/src/auth.rs index a9b1e7bee08..55bc56c4817 100644 --- a/cli/src/auth.rs +++ b/cli/src/auth.rs @@ -4,15 +4,15 @@ *--------------------------------------------------------------------------------------------*/ use crate::{ - constants::get_default_user_agent, - info, log, - state::{LauncherPaths, PersistedState}, - trace, - util::{ - errors::{wrap, AnyError, RefreshTokenNotAvailableError, StatusError, WrappedError}, - input::prompt_options, - }, - warning, + constants::get_default_user_agent, + info, log, + state::{LauncherPaths, PersistedState}, + trace, + util::{ + errors::{wrap, AnyError, RefreshTokenNotAvailableError, StatusError, WrappedError}, + input::prompt_options, + }, + warning, }; use async_trait::async_trait; use chrono::{DateTime, Duration, Utc}; @@ -21,160 +21,160 @@ use serde::{de::DeserializeOwned, Deserialize, Serialize}; use std::{cell::Cell, fmt::Display, path::PathBuf, sync::Arc}; use tokio::time::sleep; use tunnels::{ - contracts::PROD_FIRST_PARTY_APP_ID, - management::{Authorization, AuthorizationProvider, HttpError}, + contracts::PROD_FIRST_PARTY_APP_ID, + management::{Authorization, AuthorizationProvider, HttpError}, }; #[derive(Deserialize)] struct DeviceCodeResponse { - device_code: String, - user_code: String, - message: Option, - verification_uri: String, - expires_in: i64, + device_code: String, + user_code: String, + message: Option, + verification_uri: String, + expires_in: i64, } #[derive(Deserialize)] struct AuthenticationResponse { - access_token: String, - refresh_token: Option, - expires_in: Option, + access_token: String, + refresh_token: Option, + expires_in: Option, } #[derive(clap::ArgEnum, Serialize, Deserialize, Debug, Clone, Copy)] pub enum AuthProvider { - Microsoft, - Github, + Microsoft, + Github, } impl Display for AuthProvider { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - AuthProvider::Microsoft => write!(f, "Microsoft Account"), - AuthProvider::Github => write!(f, "Github Account"), - } - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + AuthProvider::Microsoft => write!(f, "Microsoft Account"), + AuthProvider::Github => write!(f, "Github Account"), + } + } } impl AuthProvider { - pub fn client_id(&self) -> &'static str { - match self { - AuthProvider::Microsoft => "aebc6443-996d-45c2-90f0-388ff96faa56", - AuthProvider::Github => "01ab8ac9400c4e429b23", - } - } + pub fn client_id(&self) -> &'static str { + match self { + AuthProvider::Microsoft => "aebc6443-996d-45c2-90f0-388ff96faa56", + AuthProvider::Github => "01ab8ac9400c4e429b23", + } + } - pub fn code_uri(&self) -> &'static str { - match self { - AuthProvider::Microsoft => { - "https://login.microsoftonline.com/common/oauth2/v2.0/devicecode" - } - AuthProvider::Github => "https://github.com/login/device/code", - } - } + pub fn code_uri(&self) -> &'static str { + match self { + AuthProvider::Microsoft => { + "https://login.microsoftonline.com/common/oauth2/v2.0/devicecode" + } + AuthProvider::Github => "https://github.com/login/device/code", + } + } - pub fn grant_uri(&self) -> &'static str { - match self { - AuthProvider::Microsoft => "https://login.microsoftonline.com/common/oauth2/v2.0/token", - AuthProvider::Github => "https://github.com/login/oauth/access_token", - } - } + pub fn grant_uri(&self) -> &'static str { + match self { + AuthProvider::Microsoft => "https://login.microsoftonline.com/common/oauth2/v2.0/token", + AuthProvider::Github => "https://github.com/login/oauth/access_token", + } + } - pub fn get_default_scopes(&self) -> String { - match self { - AuthProvider::Microsoft => format!( - "{}/.default+offline_access+profile+openid", - PROD_FIRST_PARTY_APP_ID - ), - AuthProvider::Github => "read:user+read:org".to_string(), - } - } + pub fn get_default_scopes(&self) -> String { + match self { + AuthProvider::Microsoft => format!( + "{}/.default+offline_access+profile+openid", + PROD_FIRST_PARTY_APP_ID + ), + AuthProvider::Github => "read:user+read:org".to_string(), + } + } } #[derive(Serialize, Deserialize, Debug, Clone)] pub struct StoredCredential { - #[serde(rename = "p")] - provider: AuthProvider, - #[serde(rename = "a")] - access_token: String, - #[serde(rename = "r")] - refresh_token: Option, - #[serde(rename = "e")] - expires_at: Option>, + #[serde(rename = "p")] + provider: AuthProvider, + #[serde(rename = "a")] + access_token: String, + #[serde(rename = "r")] + refresh_token: Option, + #[serde(rename = "e")] + expires_at: Option>, } impl StoredCredential { - pub async fn is_expired(&self, client: &reqwest::Client) -> bool { - match self.provider { - AuthProvider::Microsoft => self - .expires_at - .map(|e| Utc::now() + chrono::Duration::minutes(5) > e) - .unwrap_or(false), + pub async fn is_expired(&self, client: &reqwest::Client) -> bool { + match self.provider { + AuthProvider::Microsoft => self + .expires_at + .map(|e| Utc::now() + chrono::Duration::minutes(5) > e) + .unwrap_or(false), - // Make an auth request to Github. Mark the credential as expired - // only on a verifiable 4xx code. We don't error on any failed - // request since then a drop in connection could "require" a refresh - AuthProvider::Github => client - .get("https://api.github.com/user") - .header("Authorization", format!("token {}", self.access_token)) - .header("User-Agent", get_default_user_agent()) - .send() - .await - .map(|r| r.status().is_client_error()) - .unwrap_or(false), - } - } + // Make an auth request to Github. Mark the credential as expired + // only on a verifiable 4xx code. We don't error on any failed + // request since then a drop in connection could "require" a refresh + AuthProvider::Github => client + .get("https://api.github.com/user") + .header("Authorization", format!("token {}", self.access_token)) + .header("User-Agent", get_default_user_agent()) + .send() + .await + .map(|r| r.status().is_client_error()) + .unwrap_or(false), + } + } - fn from_response(auth: AuthenticationResponse, provider: AuthProvider) -> Self { - StoredCredential { - provider, - access_token: auth.access_token, - refresh_token: auth.refresh_token, - expires_at: auth.expires_in.map(|e| Utc::now() + Duration::seconds(e)), - } - } + fn from_response(auth: AuthenticationResponse, provider: AuthProvider) -> Self { + StoredCredential { + provider, + access_token: auth.access_token, + refresh_token: auth.refresh_token, + expires_at: auth.expires_in.map(|e| Utc::now() + Duration::seconds(e)), + } + } } struct StorageWithLastRead { - storage: Box, - last_read: Cell, WrappedError>>, + storage: Box, + last_read: Cell, WrappedError>>, } #[derive(Clone)] pub struct Auth { - client: reqwest::Client, - log: log::Logger, - file_storage_path: PathBuf, - storage: Arc>>, + client: reqwest::Client, + log: log::Logger, + file_storage_path: PathBuf, + storage: Arc>>, } trait StorageImplementation: Send + Sync { - fn read(&mut self) -> Result, WrappedError>; - fn store(&mut self, value: StoredCredential) -> Result<(), WrappedError>; - fn clear(&mut self) -> Result<(), WrappedError>; + fn read(&mut self) -> Result, WrappedError>; + fn store(&mut self, value: StoredCredential) -> Result<(), WrappedError>; + fn clear(&mut self) -> Result<(), WrappedError>; } // unseal decrypts and deserializes the value fn seal(value: &T) -> String where - T: Serialize + ?Sized, + T: Serialize + ?Sized, { - let dec = serde_json::to_string(value).expect("expected to serialize"); - encrypt(&dec) + let dec = serde_json::to_string(value).expect("expected to serialize"); + encrypt(&dec) } // unseal decrypts and deserializes the value fn unseal(value: &str) -> Option where - T: DeserializeOwned, + T: DeserializeOwned, { - // small back-compat for old unencrypted values - if let Ok(v) = serde_json::from_str::(value) { - return Some(v); - } + // small back-compat for old unencrypted values + if let Ok(v) = serde_json::from_str::(value) { + return Some(v); + } - let dec = decrypt(value)?; - serde_json::from_str::(&dec).ok() + let dec = decrypt(value)?; + serde_json::from_str::(&dec).ok() } #[cfg(target_os = "windows")] @@ -186,409 +186,409 @@ const CONTINUE_MARKER: &str = ""; #[derive(Default)] struct KeyringStorage { - // keywring storage can be split into multiple entries due to entry length limits - // on Windows https://github.com/microsoft/vscode-cli/issues/358 - entries: Vec, + // keywring storage can be split into multiple entries due to entry length limits + // on Windows https://github.com/microsoft/vscode-cli/issues/358 + entries: Vec, } macro_rules! get_next_entry { - ($self: expr, $i: expr) => { - match $self.entries.get($i) { - Some(e) => e, - None => { - let e = keyring::Entry::new("vscode-cli", &format!("vscode-cli-{}", $i)); - $self.entries.push(e); - $self.entries.last().unwrap() - } - } - }; + ($self: expr, $i: expr) => { + match $self.entries.get($i) { + Some(e) => e, + None => { + let e = keyring::Entry::new("vscode-cli", &format!("vscode-cli-{}", $i)); + $self.entries.push(e); + $self.entries.last().unwrap() + } + } + }; } impl StorageImplementation for KeyringStorage { - fn read(&mut self) -> Result, WrappedError> { - let mut str = String::new(); + fn read(&mut self) -> Result, WrappedError> { + let mut str = String::new(); - for i in 0.. { - let entry = get_next_entry!(self, i); - let next_chunk = match entry.get_password() { - Ok(value) => value, - Err(keyring::Error::NoEntry) => return Ok(None), // missing entries? - Err(e) => return Err(wrap(e, "error reading keyring")), - }; + for i in 0.. { + let entry = get_next_entry!(self, i); + let next_chunk = match entry.get_password() { + Ok(value) => value, + Err(keyring::Error::NoEntry) => return Ok(None), // missing entries? + Err(e) => return Err(wrap(e, "error reading keyring")), + }; - if next_chunk.ends_with(CONTINUE_MARKER) { - str.push_str(&next_chunk[..next_chunk.len() - CONTINUE_MARKER.len()]); - } else { - str.push_str(&next_chunk); - break; - } - } + if next_chunk.ends_with(CONTINUE_MARKER) { + str.push_str(&next_chunk[..next_chunk.len() - CONTINUE_MARKER.len()]); + } else { + str.push_str(&next_chunk); + break; + } + } - Ok(unseal(&str)) - } + Ok(unseal(&str)) + } - fn store(&mut self, value: StoredCredential) -> Result<(), WrappedError> { - let sealed = seal(&value); - let step_size = KEYCHAIN_ENTRY_LIMIT - CONTINUE_MARKER.len(); + fn store(&mut self, value: StoredCredential) -> Result<(), WrappedError> { + let sealed = seal(&value); + let step_size = KEYCHAIN_ENTRY_LIMIT - CONTINUE_MARKER.len(); - for i in (0..sealed.len()).step_by(step_size) { - let entry = get_next_entry!(self, i / step_size); + for i in (0..sealed.len()).step_by(step_size) { + let entry = get_next_entry!(self, i / step_size); - let cutoff = i + step_size; - let stored = if cutoff <= sealed.len() { - let mut part = sealed[i..cutoff].to_string(); - part.push_str(CONTINUE_MARKER); - entry.set_password(&part) - } else { - entry.set_password(&sealed[i..]) - }; + let cutoff = i + step_size; + let stored = if cutoff <= sealed.len() { + let mut part = sealed[i..cutoff].to_string(); + part.push_str(CONTINUE_MARKER); + entry.set_password(&part) + } else { + entry.set_password(&sealed[i..]) + }; - if let Err(e) = stored { - return Err(wrap(e, "error updating keyring")); - } - } + if let Err(e) = stored { + return Err(wrap(e, "error updating keyring")); + } + } - Ok(()) - } + Ok(()) + } - fn clear(&mut self) -> Result<(), WrappedError> { - self.read().ok(); // make sure component parts are available - for entry in self.entries.iter() { - entry - .delete_password() - .map_err(|e| wrap(e, "error updating keyring"))?; - } - self.entries.clear(); + fn clear(&mut self) -> Result<(), WrappedError> { + self.read().ok(); // make sure component parts are available + for entry in self.entries.iter() { + entry + .delete_password() + .map_err(|e| wrap(e, "error updating keyring"))?; + } + self.entries.clear(); - Ok(()) - } + Ok(()) + } } struct FileStorage(PersistedState>); impl StorageImplementation for FileStorage { - fn read(&mut self) -> Result, WrappedError> { - Ok(self.0.load().and_then(|s| unseal(&s))) - } + fn read(&mut self) -> Result, WrappedError> { + Ok(self.0.load().and_then(|s| unseal(&s))) + } - fn store(&mut self, value: StoredCredential) -> Result<(), WrappedError> { - self.0.save(Some(seal(&value))) - } + fn store(&mut self, value: StoredCredential) -> Result<(), WrappedError> { + self.0.save(Some(seal(&value))) + } - fn clear(&mut self) -> Result<(), WrappedError> { - self.0.save(None) - } + fn clear(&mut self) -> Result<(), WrappedError> { + self.0.save(None) + } } impl Auth { - pub fn new(paths: &LauncherPaths, log: log::Logger) -> Auth { - Auth { - log, - client: reqwest::Client::new(), - file_storage_path: paths.root().join("token.json"), - storage: Arc::new(std::sync::Mutex::new(None)), - } - } + pub fn new(paths: &LauncherPaths, log: log::Logger) -> Auth { + Auth { + log, + client: reqwest::Client::new(), + file_storage_path: paths.root().join("token.json"), + storage: Arc::new(std::sync::Mutex::new(None)), + } + } - fn with_storage(&self, op: F) -> T - where - F: FnOnce(&mut StorageWithLastRead) -> T, - { - let mut opt = self.storage.lock().unwrap(); - if let Some(s) = opt.as_mut() { - return op(s); - } + fn with_storage(&self, op: F) -> T + where + F: FnOnce(&mut StorageWithLastRead) -> T, + { + let mut opt = self.storage.lock().unwrap(); + if let Some(s) = opt.as_mut() { + return op(s); + } - let mut keyring_storage = KeyringStorage::default(); - let mut file_storage = FileStorage(PersistedState::new(self.file_storage_path.clone())); + let mut keyring_storage = KeyringStorage::default(); + let mut file_storage = FileStorage(PersistedState::new(self.file_storage_path.clone())); - let keyring_storage_result = match std::env::var("VSCODE_CLI_USE_FILE_KEYCHAIN") { - Ok(_) => Err(wrap("", "user prefers file storage")), - _ => keyring_storage.read(), - }; + let keyring_storage_result = match std::env::var("VSCODE_CLI_USE_FILE_KEYCHAIN") { + Ok(_) => Err(wrap("", "user prefers file storage")), + _ => keyring_storage.read(), + }; - let mut storage = match keyring_storage_result { - Ok(v) => StorageWithLastRead { - last_read: Cell::new(Ok(v)), - storage: Box::new(keyring_storage), - }, - Err(_) => StorageWithLastRead { - last_read: Cell::new(file_storage.read()), - storage: Box::new(file_storage), - }, - }; + let mut storage = match keyring_storage_result { + Ok(v) => StorageWithLastRead { + last_read: Cell::new(Ok(v)), + storage: Box::new(keyring_storage), + }, + Err(_) => StorageWithLastRead { + last_read: Cell::new(file_storage.read()), + storage: Box::new(file_storage), + }, + }; - let out = op(&mut storage); - *opt = Some(storage); - out - } + let out = op(&mut storage); + *opt = Some(storage); + out + } - /// Gets a tunnel Authentication for use in the tunnel management API. - pub async fn get_tunnel_authentication(&self) -> Result { - let cred = self.get_credential().await?; - let auth = match cred.provider { - AuthProvider::Microsoft => Authorization::Bearer(cred.access_token), - AuthProvider::Github => Authorization::Github(format!( - "client_id={} {}", - cred.provider.client_id(), - cred.access_token - )), - }; + /// Gets a tunnel Authentication for use in the tunnel management API. + pub async fn get_tunnel_authentication(&self) -> Result { + let cred = self.get_credential().await?; + let auth = match cred.provider { + AuthProvider::Microsoft => Authorization::Bearer(cred.access_token), + AuthProvider::Github => Authorization::Github(format!( + "client_id={} {}", + cred.provider.client_id(), + cred.access_token + )), + }; - Ok(auth) - } + Ok(auth) + } - /// Reads the current details from the keyring. - pub fn get_current_credential(&self) -> Result, WrappedError> { - self.with_storage(|storage| { - let value = storage.last_read.replace(Ok(None)); - storage.last_read.set(value.clone()); - value - }) - } + /// Reads the current details from the keyring. + pub fn get_current_credential(&self) -> Result, WrappedError> { + self.with_storage(|storage| { + let value = storage.last_read.replace(Ok(None)); + storage.last_read.set(value.clone()); + value + }) + } - /// Clears login info from the keyring. - pub fn clear_credentials(&self) -> Result<(), WrappedError> { - self.with_storage(|storage| { - storage.storage.clear()?; - storage.last_read.set(Ok(None)); - Ok(()) - }) - } + /// Clears login info from the keyring. + pub fn clear_credentials(&self) -> Result<(), WrappedError> { + self.with_storage(|storage| { + storage.storage.clear()?; + storage.last_read.set(Ok(None)); + Ok(()) + }) + } - /// Runs the login flow, optionally pre-filling a provider and/or access token. - pub async fn login( - &self, - provider: Option, - access_token: Option, - ) -> Result { - let provider = match provider { - Some(p) => p, - None => self.prompt_for_provider().await?, - }; + /// Runs the login flow, optionally pre-filling a provider and/or access token. + pub async fn login( + &self, + provider: Option, + access_token: Option, + ) -> Result { + let provider = match provider { + Some(p) => p, + None => self.prompt_for_provider().await?, + }; - let credentials = match access_token { - Some(t) => StoredCredential { - provider, - access_token: t, - refresh_token: None, - expires_at: None, - }, - None => self.do_device_code_flow_with_provider(provider).await?, - }; + let credentials = match access_token { + Some(t) => StoredCredential { + provider, + access_token: t, + refresh_token: None, + expires_at: None, + }, + None => self.do_device_code_flow_with_provider(provider).await?, + }; - self.store_credentials(credentials.clone()); - Ok(credentials) - } + self.store_credentials(credentials.clone()); + Ok(credentials) + } - /// Gets the currently stored credentials, or asks the user to log in. - pub async fn get_credential(&self) -> Result { - let entry = match self.get_current_credential() { - Ok(Some(old_creds)) => { - trace!(self.log, "Found token in keyring"); - match self.get_refreshed_token(&old_creds).await { - Ok(Some(new_creds)) => { - self.store_credentials(new_creds.clone()); - new_creds - } - Ok(None) => old_creds, - Err(e) => { - info!(self.log, "error refreshing token: {}", e); - let new_creds = self - .do_device_code_flow_with_provider(old_creds.provider) - .await?; - self.store_credentials(new_creds.clone()); - new_creds - } - } - } + /// Gets the currently stored credentials, or asks the user to log in. + pub async fn get_credential(&self) -> Result { + let entry = match self.get_current_credential() { + Ok(Some(old_creds)) => { + trace!(self.log, "Found token in keyring"); + match self.get_refreshed_token(&old_creds).await { + Ok(Some(new_creds)) => { + self.store_credentials(new_creds.clone()); + new_creds + } + Ok(None) => old_creds, + Err(e) => { + info!(self.log, "error refreshing token: {}", e); + let new_creds = self + .do_device_code_flow_with_provider(old_creds.provider) + .await?; + self.store_credentials(new_creds.clone()); + new_creds + } + } + } - Ok(None) => { - trace!(self.log, "No token in keyring, getting a new one"); - let creds = self.do_device_code_flow().await?; - self.store_credentials(creds.clone()); - creds - } + Ok(None) => { + trace!(self.log, "No token in keyring, getting a new one"); + let creds = self.do_device_code_flow().await?; + self.store_credentials(creds.clone()); + creds + } - Err(e) => { - warning!( - self.log, - "Error reading token from keyring, getting a new one: {}", - e - ); - let creds = self.do_device_code_flow().await?; - self.store_credentials(creds.clone()); - creds - } - }; + Err(e) => { + warning!( + self.log, + "Error reading token from keyring, getting a new one: {}", + e + ); + let creds = self.do_device_code_flow().await?; + self.store_credentials(creds.clone()); + creds + } + }; - Ok(entry) - } + Ok(entry) + } - /// Stores credentials, logging a warning if it fails. - fn store_credentials(&self, creds: StoredCredential) { - self.with_storage(|storage| { - if let Err(e) = storage.storage.store(creds.clone()) { - warning!( - self.log, - "Failed to update keyring with new credentials: {}", - e - ); - } - storage.last_read.set(Ok(Some(creds))); - }) - } + /// Stores credentials, logging a warning if it fails. + fn store_credentials(&self, creds: StoredCredential) { + self.with_storage(|storage| { + if let Err(e) = storage.storage.store(creds.clone()) { + warning!( + self.log, + "Failed to update keyring with new credentials: {}", + e + ); + } + storage.last_read.set(Ok(Some(creds))); + }) + } - /// Refreshes the token in the credentials if necessary. Returns None if - /// the token is up to date, or Some new token otherwise. - async fn get_refreshed_token( - &self, - creds: &StoredCredential, - ) -> Result, AnyError> { - if !creds.is_expired(&self.client).await { - return Ok(None); - } + /// Refreshes the token in the credentials if necessary. Returns None if + /// the token is up to date, or Some new token otherwise. + async fn get_refreshed_token( + &self, + creds: &StoredCredential, + ) -> Result, AnyError> { + if !creds.is_expired(&self.client).await { + return Ok(None); + } - let refresh_token = match &creds.refresh_token { - Some(t) => t, - None => return Err(AnyError::from(RefreshTokenNotAvailableError())), - }; + let refresh_token = match &creds.refresh_token { + Some(t) => t, + None => return Err(AnyError::from(RefreshTokenNotAvailableError())), + }; - self.do_grant( - creds.provider, - format!( - "client_id={}&grant_type=refresh_token&refresh_token={}", - creds.provider.client_id(), - refresh_token - ), - ) - .await - .map(Some) - } + self.do_grant( + creds.provider, + format!( + "client_id={}&grant_type=refresh_token&refresh_token={}", + creds.provider.client_id(), + refresh_token + ), + ) + .await + .map(Some) + } - /// Does a "grant token" request. - async fn do_grant( - &self, - provider: AuthProvider, - body: String, - ) -> Result { - let response = self - .client - .post(provider.grant_uri()) - .body(body) - .header("Accept", "application/json") - .send() - .await?; + /// Does a "grant token" request. + async fn do_grant( + &self, + provider: AuthProvider, + body: String, + ) -> Result { + let response = self + .client + .post(provider.grant_uri()) + .body(body) + .header("Accept", "application/json") + .send() + .await?; - if !response.status().is_success() { - return Err(StatusError::from_res(response).await?.into()); - } + if !response.status().is_success() { + return Err(StatusError::from_res(response).await?.into()); + } - let body = response.json::().await?; - Ok(StoredCredential::from_response(body, provider)) - } + let body = response.json::().await?; + Ok(StoredCredential::from_response(body, provider)) + } - /// Implements the device code flow, returning the credentials upon success. - async fn do_device_code_flow(&self) -> Result { - let provider = self.prompt_for_provider().await?; - self.do_device_code_flow_with_provider(provider).await - } + /// Implements the device code flow, returning the credentials upon success. + async fn do_device_code_flow(&self) -> Result { + let provider = self.prompt_for_provider().await?; + self.do_device_code_flow_with_provider(provider).await + } - async fn prompt_for_provider(&self) -> Result { - if std::env::var("VSCODE_CLI_ALLOW_MS_AUTH").is_err() { - return Ok(AuthProvider::Github); - } + async fn prompt_for_provider(&self) -> Result { + if std::env::var("VSCODE_CLI_ALLOW_MS_AUTH").is_err() { + return Ok(AuthProvider::Github); + } - let provider = prompt_options( - "How would you like to log in to VS Code?", - &[AuthProvider::Microsoft, AuthProvider::Github], - )?; + let provider = prompt_options( + "How would you like to log in to VS Code?", + &[AuthProvider::Microsoft, AuthProvider::Github], + )?; - Ok(provider) - } + Ok(provider) + } - async fn do_device_code_flow_with_provider( - &self, - provider: AuthProvider, - ) -> Result { - loop { - let init_code = self - .client - .post(provider.code_uri()) - .header("Accept", "application/json") - .body(format!( - "client_id={}&scope={}", - provider.client_id(), - provider.get_default_scopes(), - )) - .send() - .await?; + async fn do_device_code_flow_with_provider( + &self, + provider: AuthProvider, + ) -> Result { + loop { + let init_code = self + .client + .post(provider.code_uri()) + .header("Accept", "application/json") + .body(format!( + "client_id={}&scope={}", + provider.client_id(), + provider.get_default_scopes(), + )) + .send() + .await?; - if !init_code.status().is_success() { - return Err(StatusError::from_res(init_code).await?.into()); - } + if !init_code.status().is_success() { + return Err(StatusError::from_res(init_code).await?.into()); + } - let init_code_json = init_code.json::().await?; - let expires_at = Utc::now() + chrono::Duration::seconds(init_code_json.expires_in); + let init_code_json = init_code.json::().await?; + let expires_at = Utc::now() + chrono::Duration::seconds(init_code_json.expires_in); - match &init_code_json.message { - Some(m) => self.log.result(m), - None => self.log.result(&format!( - "To grant access to the server, please log into {} and use code {}", - init_code_json.verification_uri, init_code_json.user_code - )), - }; + match &init_code_json.message { + Some(m) => self.log.result(m), + None => self.log.result(&format!( + "To grant access to the server, please log into {} and use code {}", + init_code_json.verification_uri, init_code_json.user_code + )), + }; - let body = format!( + let body = format!( "client_id={}&grant_type=urn:ietf:params:oauth:grant-type:device_code&device_code={}", provider.client_id(), init_code_json.device_code ); - while Utc::now() < expires_at { - sleep(std::time::Duration::from_secs(5)).await; + while Utc::now() < expires_at { + sleep(std::time::Duration::from_secs(5)).await; - match self.do_grant(provider, body.clone()).await { - Ok(creds) => return Ok(creds), - Err(e) => { - trace!(self.log, "refresh poll failed, retrying: {}", e); - } - } - } - } - } + match self.do_grant(provider, body.clone()).await { + Ok(creds) => return Ok(creds), + Err(e) => { + trace!(self.log, "refresh poll failed, retrying: {}", e); + } + } + } + } + } } #[async_trait] impl AuthorizationProvider for Auth { - async fn get_authorization(&self) -> Result { - self.get_tunnel_authentication() - .await - .map_err(|e| HttpError::AuthorizationError(e.to_string())) - } + async fn get_authorization(&self) -> Result { + self.get_tunnel_authentication() + .await + .map_err(|e| HttpError::AuthorizationError(e.to_string())) + } } lazy_static::lazy_static! { - static ref HOSTNAME: Vec = gethostname().to_string_lossy().bytes().collect(); + static ref HOSTNAME: Vec = gethostname().to_string_lossy().bytes().collect(); } #[cfg(feature = "vscode-encrypt")] fn encrypt(value: &str) -> String { - vscode_encrypt::encrypt(&HOSTNAME, value.as_bytes()).expect("expected to encrypt") + vscode_encrypt::encrypt(&HOSTNAME, value.as_bytes()).expect("expected to encrypt") } #[cfg(feature = "vscode-encrypt")] fn decrypt(value: &str) -> Option { - let b = vscode_encrypt::decrypt(&HOSTNAME, value).ok()?; - String::from_utf8(b).ok() + let b = vscode_encrypt::decrypt(&HOSTNAME, value).ok()?; + String::from_utf8(b).ok() } #[cfg(not(feature = "vscode-encrypt"))] fn encrypt(value: &str) -> String { - value.to_owned() + value.to_owned() } #[cfg(not(feature = "vscode-encrypt"))] fn decrypt(value: &str) -> Option { - Some(value.to_owned()) + Some(value.to_owned()) } diff --git a/cli/src/bin/code-tunnel/main.rs b/cli/src/bin/code-tunnel/main.rs index 05fdbcb8bb3..5e4a574be2d 100644 --- a/cli/src/bin/code-tunnel/main.rs +++ b/cli/src/bin/code-tunnel/main.rs @@ -5,9 +5,9 @@ use clap::Parser; use cli::{ - commands::{args, tunnels, CommandContext}, - constants, log as own_log, - state::LauncherPaths, + commands::{args, tunnels, CommandContext}, + constants, log as own_log, + state::LauncherPaths, }; use opentelemetry::sdk::trace::TracerProvider as SdkTracerProvider; use opentelemetry::trace::TracerProvider; @@ -21,11 +21,11 @@ use log::{Level, Metadata, Record}; version = match constants::VSCODE_CLI_VERSION { Some(v) => v, None => "dev" }, )] pub struct TunnelCli { - #[clap(flatten, next_help_heading = Some("GLOBAL OPTIONS"))] - pub global_options: args::GlobalOptions, + #[clap(flatten, next_help_heading = Some("GLOBAL OPTIONS"))] + pub global_options: args::GlobalOptions, - #[clap(flatten, next_help_heading = Some("TUNNEL OPTIONS"))] - pub tunnel_options: args::TunnelArgs, + #[clap(flatten, next_help_heading = Some("TUNNEL OPTIONS"))] + pub tunnel_options: args::TunnelArgs, } /// Entrypoint for a standalone "code-tunnel" subcommand. This is a temporary @@ -33,56 +33,56 @@ pub struct TunnelCli { /// code in here is duplicated from `src/bin/code/main.rs` #[tokio::main] async fn main() -> Result<(), std::convert::Infallible> { - let parsed = TunnelCli::parse(); - let context = CommandContext { - http: reqwest::Client::new(), - paths: LauncherPaths::new(&parsed.global_options.cli_data_dir).unwrap(), - log: own_log::Logger::new( - SdkTracerProvider::builder().build().tracer("codecli"), - if parsed.global_options.verbose { - own_log::Level::Trace - } else { - parsed.global_options.log.unwrap_or(own_log::Level::Info) - }, - ), - args: args::Cli { - global_options: parsed.global_options, - subcommand: Some(args::Commands::Tunnel(parsed.tunnel_options.clone())), - ..Default::default() - }, - }; + let parsed = TunnelCli::parse(); + let context = CommandContext { + http: reqwest::Client::new(), + paths: LauncherPaths::new(&parsed.global_options.cli_data_dir).unwrap(), + log: own_log::Logger::new( + SdkTracerProvider::builder().build().tracer("codecli"), + if parsed.global_options.verbose { + own_log::Level::Trace + } else { + parsed.global_options.log.unwrap_or(own_log::Level::Info) + }, + ), + args: args::Cli { + global_options: parsed.global_options, + subcommand: Some(args::Commands::Tunnel(parsed.tunnel_options.clone())), + ..Default::default() + }, + }; - log::set_logger(Box::leak(Box::new(RustyLogger(context.log.clone())))) - .map(|()| log::set_max_level(log::LevelFilter::Debug)) - .expect("expected to make logger"); + log::set_logger(Box::leak(Box::new(RustyLogger(context.log.clone())))) + .map(|()| log::set_max_level(log::LevelFilter::Debug)) + .expect("expected to make logger"); - let result = match parsed.tunnel_options.subcommand { - Some(args::TunnelSubcommand::Prune) => tunnels::prune(context).await, - Some(args::TunnelSubcommand::Unregister) => tunnels::unregister(context).await, - Some(args::TunnelSubcommand::Rename(rename_args)) => { - tunnels::rename(context, rename_args).await - } - Some(args::TunnelSubcommand::User(user_command)) => { - tunnels::user(context, user_command).await - } - Some(args::TunnelSubcommand::Service(service_args)) => { - tunnels::service(context, service_args).await - } - None => tunnels::serve(context, parsed.tunnel_options.serve_args).await, - }; + let result = match parsed.tunnel_options.subcommand { + Some(args::TunnelSubcommand::Prune) => tunnels::prune(context).await, + Some(args::TunnelSubcommand::Unregister) => tunnels::unregister(context).await, + Some(args::TunnelSubcommand::Rename(rename_args)) => { + tunnels::rename(context, rename_args).await + } + Some(args::TunnelSubcommand::User(user_command)) => { + tunnels::user(context, user_command).await + } + Some(args::TunnelSubcommand::Service(service_args)) => { + tunnels::service(context, service_args).await + } + None => tunnels::serve(context, parsed.tunnel_options.serve_args).await, + }; - match result { - Err(e) => print_and_exit(e), - Ok(code) => std::process::exit(code), - } + match result { + Err(e) => print_and_exit(e), + Ok(code) => std::process::exit(code), + } } fn print_and_exit(err: E) -> ! where - E: std::fmt::Display, + E: std::fmt::Display, { - own_log::emit(own_log::Level::Error, "", &format!("{}", err)); - std::process::exit(1); + own_log::emit(own_log::Level::Error, "", &format!("{}", err)); + std::process::exit(1); } /// Logger that uses the common rust "log" crate and directs back to one of @@ -90,34 +90,34 @@ where struct RustyLogger(own_log::Logger); impl log::Log for RustyLogger { - fn enabled(&self, metadata: &Metadata) -> bool { - metadata.level() <= Level::Debug - } + fn enabled(&self, metadata: &Metadata) -> bool { + metadata.level() <= Level::Debug + } - fn log(&self, record: &Record) { - if !self.enabled(record.metadata()) { - return; - } + fn log(&self, record: &Record) { + if !self.enabled(record.metadata()) { + return; + } - // exclude noisy log modules: - let src = match record.module_path() { - Some("russh::cipher") => return, - Some("russh::negotiation") => return, - Some(s) => s, - None => "", - }; + // exclude noisy log modules: + let src = match record.module_path() { + Some("russh::cipher") => return, + Some("russh::negotiation") => return, + Some(s) => s, + None => "", + }; - self.0.emit( - match record.level() { - log::Level::Debug => own_log::Level::Debug, - log::Level::Error => own_log::Level::Error, - log::Level::Info => own_log::Level::Info, - log::Level::Trace => own_log::Level::Trace, - log::Level::Warn => own_log::Level::Warn, - }, - &format!("[{}] {}", src, record.args()), - ); - } + self.0.emit( + match record.level() { + log::Level::Debug => own_log::Level::Debug, + log::Level::Error => own_log::Level::Error, + log::Level::Info => own_log::Level::Info, + log::Level::Trace => own_log::Level::Trace, + log::Level::Warn => own_log::Level::Warn, + }, + &format!("[{}] {}", src, record.args()), + ); + } - fn flush(&self) {} + fn flush(&self) {} } diff --git a/cli/src/bin/code/legacy_args.rs b/cli/src/bin/code/legacy_args.rs index 0cc5be3d4c6..54dbb9ac32d 100644 --- a/cli/src/bin/code/legacy_args.rs +++ b/cli/src/bin/code/legacy_args.rs @@ -6,229 +6,229 @@ use std::collections::HashMap; use cli::commands::args::{ - Cli, Commands, DesktopCodeOptions, ExtensionArgs, ExtensionSubcommand, InstallExtensionArgs, - ListExtensionArgs, UninstallExtensionArgs, + Cli, Commands, DesktopCodeOptions, ExtensionArgs, ExtensionSubcommand, InstallExtensionArgs, + ListExtensionArgs, UninstallExtensionArgs, }; /// Tries to parse the argv using the legacy CLI interface, looking for its /// flags and generating a CLI with subcommands if those don't exist. pub fn try_parse_legacy( - iter: impl IntoIterator>, + iter: impl IntoIterator>, ) -> Option { - let raw = clap_lex::RawArgs::new(iter); - let mut cursor = raw.cursor(); - raw.next(&mut cursor); // Skip the bin + let raw = clap_lex::RawArgs::new(iter); + let mut cursor = raw.cursor(); + raw.next(&mut cursor); // Skip the bin - // First make a hashmap of all flags and capture positional arguments. - let mut args: HashMap> = HashMap::new(); - let mut last_arg = None; - while let Some(arg) = raw.next(&mut cursor) { - if let Some((long, value)) = arg.to_long() { - if let Ok(long) = long { - last_arg = Some(long.to_string()); - match args.get_mut(long) { - Some(prev) => { - if let Some(v) = value { - prev.push(v.to_str_lossy().to_string()); - } - } - None => { - if let Some(v) = value { - args.insert(long.to_string(), vec![v.to_str_lossy().to_string()]); - } else { - args.insert(long.to_string(), vec![]); - } - } - } - } - } else if let Ok(value) = arg.to_value() { - if let Some(last_arg) = &last_arg { - args.get_mut(last_arg) - .expect("expected to have last arg") - .push(value.to_string()); - } - } - } + // First make a hashmap of all flags and capture positional arguments. + let mut args: HashMap> = HashMap::new(); + let mut last_arg = None; + while let Some(arg) = raw.next(&mut cursor) { + if let Some((long, value)) = arg.to_long() { + if let Ok(long) = long { + last_arg = Some(long.to_string()); + match args.get_mut(long) { + Some(prev) => { + if let Some(v) = value { + prev.push(v.to_str_lossy().to_string()); + } + } + None => { + if let Some(v) = value { + args.insert(long.to_string(), vec![v.to_str_lossy().to_string()]); + } else { + args.insert(long.to_string(), vec![]); + } + } + } + } + } else if let Ok(value) = arg.to_value() { + if let Some(last_arg) = &last_arg { + args.get_mut(last_arg) + .expect("expected to have last arg") + .push(value.to_string()); + } + } + } - let get_first_arg_value = - |key: &str| args.get(key).and_then(|v| v.first()).map(|s| s.to_string()); - let desktop_code_options = DesktopCodeOptions { - extensions_dir: get_first_arg_value("extensions-dir"), - user_data_dir: get_first_arg_value("user-data-dir"), - use_version: None, - }; + let get_first_arg_value = + |key: &str| args.get(key).and_then(|v| v.first()).map(|s| s.to_string()); + let desktop_code_options = DesktopCodeOptions { + extensions_dir: get_first_arg_value("extensions-dir"), + user_data_dir: get_first_arg_value("user-data-dir"), + use_version: None, + }; - // Now translate them to subcommands. - // --list-extensions -> ext list - // --install-extension=id -> ext install - // --uninstall-extension=id -> ext uninstall - // --status -> status + // Now translate them to subcommands. + // --list-extensions -> ext list + // --install-extension=id -> ext install + // --uninstall-extension=id -> ext uninstall + // --status -> status - if args.contains_key("list-extensions") { - Some(Cli { - subcommand: Some(Commands::Extension(ExtensionArgs { - subcommand: ExtensionSubcommand::List(ListExtensionArgs { - category: get_first_arg_value("category"), - show_versions: args.contains_key("show-versions"), - }), - desktop_code_options, - })), - ..Default::default() - }) - } else if let Some(exts) = args.remove("install-extension") { - Some(Cli { - subcommand: Some(Commands::Extension(ExtensionArgs { - subcommand: ExtensionSubcommand::Install(InstallExtensionArgs { - id_or_path: exts, - pre_release: args.contains_key("pre-release"), - force: args.contains_key("force"), - }), - desktop_code_options, - })), - ..Default::default() - }) - } else if let Some(exts) = args.remove("uninstall-extension") { - Some(Cli { - subcommand: Some(Commands::Extension(ExtensionArgs { - subcommand: ExtensionSubcommand::Uninstall(UninstallExtensionArgs { id: exts }), - desktop_code_options, - })), - ..Default::default() - }) - } else if args.contains_key("status") { - Some(Cli { - subcommand: Some(Commands::Status), - ..Default::default() - }) - } else { - None - } + if args.contains_key("list-extensions") { + Some(Cli { + subcommand: Some(Commands::Extension(ExtensionArgs { + subcommand: ExtensionSubcommand::List(ListExtensionArgs { + category: get_first_arg_value("category"), + show_versions: args.contains_key("show-versions"), + }), + desktop_code_options, + })), + ..Default::default() + }) + } else if let Some(exts) = args.remove("install-extension") { + Some(Cli { + subcommand: Some(Commands::Extension(ExtensionArgs { + subcommand: ExtensionSubcommand::Install(InstallExtensionArgs { + id_or_path: exts, + pre_release: args.contains_key("pre-release"), + force: args.contains_key("force"), + }), + desktop_code_options, + })), + ..Default::default() + }) + } else if let Some(exts) = args.remove("uninstall-extension") { + Some(Cli { + subcommand: Some(Commands::Extension(ExtensionArgs { + subcommand: ExtensionSubcommand::Uninstall(UninstallExtensionArgs { id: exts }), + desktop_code_options, + })), + ..Default::default() + }) + } else if args.contains_key("status") { + Some(Cli { + subcommand: Some(Commands::Status), + ..Default::default() + }) + } else { + None + } } #[cfg(test)] mod tests { - use super::*; + use super::*; - #[test] - fn test_parses_list_extensions() { - let args = vec![ - "code", - "--list-extensions", - "--category", - "themes", - "--show-versions", - ]; - let cli = try_parse_legacy(args.into_iter()).unwrap(); + #[test] + fn test_parses_list_extensions() { + let args = vec![ + "code", + "--list-extensions", + "--category", + "themes", + "--show-versions", + ]; + let cli = try_parse_legacy(args.into_iter()).unwrap(); - if let Some(Commands::Extension(extension_args)) = cli.subcommand { - if let ExtensionSubcommand::List(list_args) = extension_args.subcommand { - assert_eq!(list_args.category, Some("themes".to_string())); - assert!(list_args.show_versions); - } else { - panic!( - "Expected list subcommand, got {:?}", - extension_args.subcommand - ); - } - } else { - panic!("Expected extension subcommand, got {:?}", cli.subcommand); - } - } + if let Some(Commands::Extension(extension_args)) = cli.subcommand { + if let ExtensionSubcommand::List(list_args) = extension_args.subcommand { + assert_eq!(list_args.category, Some("themes".to_string())); + assert!(list_args.show_versions); + } else { + panic!( + "Expected list subcommand, got {:?}", + extension_args.subcommand + ); + } + } else { + panic!("Expected extension subcommand, got {:?}", cli.subcommand); + } + } - #[test] - fn test_parses_install_extension() { - let args = vec![ - "code", - "--install-extension", - "connor4312.codesong", - "connor4312.hello-world", - "--pre-release", - "--force", - ]; - let cli = try_parse_legacy(args.into_iter()).unwrap(); + #[test] + fn test_parses_install_extension() { + let args = vec![ + "code", + "--install-extension", + "connor4312.codesong", + "connor4312.hello-world", + "--pre-release", + "--force", + ]; + let cli = try_parse_legacy(args.into_iter()).unwrap(); - if let Some(Commands::Extension(extension_args)) = cli.subcommand { - if let ExtensionSubcommand::Install(install_args) = extension_args.subcommand { - assert_eq!( - install_args.id_or_path, - vec!["connor4312.codesong", "connor4312.hello-world"] - ); - assert!(install_args.pre_release); - assert!(install_args.force); - } else { - panic!( - "Expected install subcommand, got {:?}", - extension_args.subcommand - ); - } - } else { - panic!("Expected extension subcommand, got {:?}", cli.subcommand); - } - } + if let Some(Commands::Extension(extension_args)) = cli.subcommand { + if let ExtensionSubcommand::Install(install_args) = extension_args.subcommand { + assert_eq!( + install_args.id_or_path, + vec!["connor4312.codesong", "connor4312.hello-world"] + ); + assert!(install_args.pre_release); + assert!(install_args.force); + } else { + panic!( + "Expected install subcommand, got {:?}", + extension_args.subcommand + ); + } + } else { + panic!("Expected extension subcommand, got {:?}", cli.subcommand); + } + } - #[test] - fn test_parses_uninstall_extension() { - let args = vec!["code", "--uninstall-extension", "connor4312.codesong"]; - let cli = try_parse_legacy(args.into_iter()).unwrap(); + #[test] + fn test_parses_uninstall_extension() { + let args = vec!["code", "--uninstall-extension", "connor4312.codesong"]; + let cli = try_parse_legacy(args.into_iter()).unwrap(); - if let Some(Commands::Extension(extension_args)) = cli.subcommand { - if let ExtensionSubcommand::Uninstall(uninstall_args) = extension_args.subcommand { - assert_eq!(uninstall_args.id, vec!["connor4312.codesong"]); - } else { - panic!( - "Expected uninstall subcommand, got {:?}", - extension_args.subcommand - ); - } - } else { - panic!("Expected extension subcommand, got {:?}", cli.subcommand); - } - } + if let Some(Commands::Extension(extension_args)) = cli.subcommand { + if let ExtensionSubcommand::Uninstall(uninstall_args) = extension_args.subcommand { + assert_eq!(uninstall_args.id, vec!["connor4312.codesong"]); + } else { + panic!( + "Expected uninstall subcommand, got {:?}", + extension_args.subcommand + ); + } + } else { + panic!("Expected extension subcommand, got {:?}", cli.subcommand); + } + } - #[test] - fn test_parses_user_data_dir_and_extensions_dir() { - let args = vec![ - "code", - "--uninstall-extension", - "connor4312.codesong", - "--user-data-dir", - "foo", - "--extensions-dir", - "bar", - ]; - let cli = try_parse_legacy(args.into_iter()).unwrap(); + #[test] + fn test_parses_user_data_dir_and_extensions_dir() { + let args = vec![ + "code", + "--uninstall-extension", + "connor4312.codesong", + "--user-data-dir", + "foo", + "--extensions-dir", + "bar", + ]; + let cli = try_parse_legacy(args.into_iter()).unwrap(); - if let Some(Commands::Extension(extension_args)) = cli.subcommand { - assert_eq!( - extension_args.desktop_code_options.user_data_dir, - Some("foo".to_string()) - ); - assert_eq!( - extension_args.desktop_code_options.extensions_dir, - Some("bar".to_string()) - ); - if let ExtensionSubcommand::Uninstall(uninstall_args) = extension_args.subcommand { - assert_eq!(uninstall_args.id, vec!["connor4312.codesong"]); - } else { - panic!( - "Expected uninstall subcommand, got {:?}", - extension_args.subcommand - ); - } - } else { - panic!("Expected extension subcommand, got {:?}", cli.subcommand); - } - } + if let Some(Commands::Extension(extension_args)) = cli.subcommand { + assert_eq!( + extension_args.desktop_code_options.user_data_dir, + Some("foo".to_string()) + ); + assert_eq!( + extension_args.desktop_code_options.extensions_dir, + Some("bar".to_string()) + ); + if let ExtensionSubcommand::Uninstall(uninstall_args) = extension_args.subcommand { + assert_eq!(uninstall_args.id, vec!["connor4312.codesong"]); + } else { + panic!( + "Expected uninstall subcommand, got {:?}", + extension_args.subcommand + ); + } + } else { + panic!("Expected extension subcommand, got {:?}", cli.subcommand); + } + } - #[test] - fn test_status() { - let args = vec!["code", "--status"]; - let cli = try_parse_legacy(args.into_iter()).unwrap(); + #[test] + fn test_status() { + let args = vec!["code", "--status"]; + let cli = try_parse_legacy(args.into_iter()).unwrap(); - if let Some(Commands::Status) = cli.subcommand { - // no-op - } else { - panic!("Expected extension subcommand, got {:?}", cli.subcommand); - } - } + if let Some(Commands::Status) = cli.subcommand { + // no-op + } else { + panic!("Expected extension subcommand, got {:?}", cli.subcommand); + } + } } diff --git a/cli/src/bin/code/main.rs b/cli/src/bin/code/main.rs index d05db59b48e..b483ad17cb8 100644 --- a/cli/src/bin/code/main.rs +++ b/cli/src/bin/code/main.rs @@ -8,14 +8,14 @@ use std::process::Command; use clap::Parser; use cli::{ - commands::{args, tunnels, version, CommandContext}, - desktop, log as own_log, - state::LauncherPaths, - update_service::UpdateService, - util::{ - errors::{wrap, AnyError}, - prereqs::PreReqChecker, - }, + commands::{args, tunnels, version, CommandContext}, + desktop, log as own_log, + state::LauncherPaths, + update_service::UpdateService, + util::{ + errors::{wrap, AnyError}, + prereqs::PreReqChecker, + }, }; use legacy_args::try_parse_legacy; use opentelemetry::sdk::trace::TracerProvider as SdkTracerProvider; @@ -25,110 +25,110 @@ use log::{Level, Metadata, Record}; #[tokio::main] async fn main() -> Result<(), std::convert::Infallible> { - let raw_args = std::env::args_os().collect::>(); - let parsed = try_parse_legacy(&raw_args).unwrap_or_else(|| args::Cli::parse_from(&raw_args)); - let context = CommandContext { - http: reqwest::Client::new(), - paths: LauncherPaths::new(&parsed.global_options.cli_data_dir).unwrap(), - log: own_log::Logger::new( - SdkTracerProvider::builder().build().tracer("codecli"), - if parsed.global_options.verbose { - own_log::Level::Trace - } else { - parsed.global_options.log.unwrap_or(own_log::Level::Info) - }, - ), - args: parsed, - }; + let raw_args = std::env::args_os().collect::>(); + let parsed = try_parse_legacy(&raw_args).unwrap_or_else(|| args::Cli::parse_from(&raw_args)); + let context = CommandContext { + http: reqwest::Client::new(), + paths: LauncherPaths::new(&parsed.global_options.cli_data_dir).unwrap(), + log: own_log::Logger::new( + SdkTracerProvider::builder().build().tracer("codecli"), + if parsed.global_options.verbose { + own_log::Level::Trace + } else { + parsed.global_options.log.unwrap_or(own_log::Level::Info) + }, + ), + args: parsed, + }; - log::set_logger(Box::leak(Box::new(RustyLogger(context.log.clone())))) - .map(|()| log::set_max_level(log::LevelFilter::Debug)) - .expect("expected to make logger"); + log::set_logger(Box::leak(Box::new(RustyLogger(context.log.clone())))) + .map(|()| log::set_max_level(log::LevelFilter::Debug)) + .expect("expected to make logger"); - let result = match context.args.subcommand.clone() { - None => { - let ca = context.args.get_base_code_args(); - start_code(context, ca).await - } + let result = match context.args.subcommand.clone() { + None => { + let ca = context.args.get_base_code_args(); + start_code(context, ca).await + } - Some(args::Commands::Extension(extension_args)) => { - let mut ca = context.args.get_base_code_args(); - extension_args.add_code_args(&mut ca); - start_code(context, ca).await - } + Some(args::Commands::Extension(extension_args)) => { + let mut ca = context.args.get_base_code_args(); + extension_args.add_code_args(&mut ca); + start_code(context, ca).await + } - Some(args::Commands::Status) => { - let mut ca = context.args.get_base_code_args(); - ca.push("--status".to_string()); - start_code(context, ca).await - } + Some(args::Commands::Status) => { + let mut ca = context.args.get_base_code_args(); + ca.push("--status".to_string()); + start_code(context, ca).await + } - Some(args::Commands::Version(version_args)) => match version_args.subcommand { - args::VersionSubcommand::Use(use_version_args) => { - version::switch_to(context, use_version_args).await - } - args::VersionSubcommand::Uninstall(uninstall_version_args) => { - version::uninstall(context, uninstall_version_args).await - } - args::VersionSubcommand::List(list_version_args) => { - version::list(context, list_version_args).await - } - }, + Some(args::Commands::Version(version_args)) => match version_args.subcommand { + args::VersionSubcommand::Use(use_version_args) => { + version::switch_to(context, use_version_args).await + } + args::VersionSubcommand::Uninstall(uninstall_version_args) => { + version::uninstall(context, uninstall_version_args).await + } + args::VersionSubcommand::List(list_version_args) => { + version::list(context, list_version_args).await + } + }, - Some(args::Commands::Tunnel(tunnel_args)) => match tunnel_args.subcommand { - Some(args::TunnelSubcommand::Prune) => tunnels::prune(context).await, - Some(args::TunnelSubcommand::Unregister) => tunnels::unregister(context).await, - Some(args::TunnelSubcommand::Rename(rename_args)) => { - tunnels::rename(context, rename_args).await - } - Some(args::TunnelSubcommand::User(user_command)) => { - tunnels::user(context, user_command).await - } - Some(args::TunnelSubcommand::Service(service_args)) => { - tunnels::service(context, service_args).await - } - None => tunnels::serve(context, tunnel_args.serve_args).await, - }, - }; + Some(args::Commands::Tunnel(tunnel_args)) => match tunnel_args.subcommand { + Some(args::TunnelSubcommand::Prune) => tunnels::prune(context).await, + Some(args::TunnelSubcommand::Unregister) => tunnels::unregister(context).await, + Some(args::TunnelSubcommand::Rename(rename_args)) => { + tunnels::rename(context, rename_args).await + } + Some(args::TunnelSubcommand::User(user_command)) => { + tunnels::user(context, user_command).await + } + Some(args::TunnelSubcommand::Service(service_args)) => { + tunnels::service(context, service_args).await + } + None => tunnels::serve(context, tunnel_args.serve_args).await, + }, + }; - match result { - Err(e) => print_and_exit(e), - Ok(code) => std::process::exit(code), - } + match result { + Err(e) => print_and_exit(e), + Ok(code) => std::process::exit(code), + } } fn print_and_exit(err: E) -> ! where - E: std::fmt::Display, + E: std::fmt::Display, { - own_log::emit(own_log::Level::Error, "", &format!("{}", err)); - std::process::exit(1); + own_log::emit(own_log::Level::Error, "", &format!("{}", err)); + std::process::exit(1); } async fn start_code(context: CommandContext, args: Vec) -> Result { - let platform = PreReqChecker::new().verify().await?; - let version_manager = desktop::CodeVersionManager::new(&context.paths, platform); - let update_service = UpdateService::new(context.log.clone(), context.http.clone()); - let version = match &context.args.editor_options.code_options.use_version { - Some(v) => desktop::RequestedVersion::try_from(v.as_str())?, - None => version_manager.get_preferred_version(), - }; + let platform = PreReqChecker::new().verify().await?; + let version_manager = desktop::CodeVersionManager::new(&context.paths, platform); + let update_service = UpdateService::new(context.log.clone(), context.http.clone()); + let version = match &context.args.editor_options.code_options.use_version { + Some(v) => desktop::RequestedVersion::try_from(v.as_str())?, + None => version_manager.get_preferred_version(), + }; - let binary = match version_manager.try_get_entrypoint(&version).await { - Some(ep) => ep, - None => { - desktop::prompt_to_install(&version)?; - version_manager.install(&update_service, &version).await? - } - }; + let binary = match version_manager.try_get_entrypoint(&version).await { + Some(ep) => ep, + None => { + desktop::prompt_to_install(&version)?; + version_manager.install(&update_service, &version).await? + } + }; - let code = Command::new(binary) - .args(args) - .status() - .map(|s| s.code().unwrap_or(1)) - .map_err(|e| wrap(e, "error running VS Code"))?; + let code = Command::new(binary) + .args(args) + .status() + .map(|s| s.code().unwrap_or(1)) + .map_err(|e| wrap(e, "error running VS Code"))?; - Ok(code) + Ok(code) } /// Logger that uses the common rust "log" crate and directs back to one of @@ -136,34 +136,34 @@ async fn start_code(context: CommandContext, args: Vec) -> Result bool { - metadata.level() <= Level::Debug - } + fn enabled(&self, metadata: &Metadata) -> bool { + metadata.level() <= Level::Debug + } - fn log(&self, record: &Record) { - if !self.enabled(record.metadata()) { - return; - } + fn log(&self, record: &Record) { + if !self.enabled(record.metadata()) { + return; + } - // exclude noisy log modules: - let src = match record.module_path() { - Some("russh::cipher") => return, - Some("russh::negotiation") => return, - Some(s) => s, - None => "", - }; + // exclude noisy log modules: + let src = match record.module_path() { + Some("russh::cipher") => return, + Some("russh::negotiation") => return, + Some(s) => s, + None => "", + }; - self.0.emit( - match record.level() { - log::Level::Debug => own_log::Level::Debug, - log::Level::Error => own_log::Level::Error, - log::Level::Info => own_log::Level::Info, - log::Level::Trace => own_log::Level::Trace, - log::Level::Warn => own_log::Level::Warn, - }, - &format!("[{}] {}", src, record.args()), - ); - } + self.0.emit( + match record.level() { + log::Level::Debug => own_log::Level::Debug, + log::Level::Error => own_log::Level::Error, + log::Level::Info => own_log::Level::Info, + log::Level::Trace => own_log::Level::Trace, + log::Level::Warn => own_log::Level::Warn, + }, + &format!("[{}] {}", src, record.args()), + ); + } - fn flush(&self) {} + fn flush(&self) {} } diff --git a/cli/src/commands/args.rs b/cli/src/commands/args.rs index 09e9090bdf2..8b4c1923048 100644 --- a/cli/src/commands/args.rs +++ b/cli/src/commands/args.rs @@ -25,566 +25,566 @@ const TEMPLATE: &str = " version = match constants::VSCODE_CLI_VERSION { Some(v) => v, None => "dev" }, )] pub struct Cli { - /// One or more files, folders, or URIs to open. - #[clap(name = "paths")] - pub open_paths: Vec, + /// One or more files, folders, or URIs to open. + #[clap(name = "paths")] + pub open_paths: Vec, - #[clap(flatten, next_help_heading = Some("EDITOR OPTIONS"))] - pub editor_options: EditorOptions, + #[clap(flatten, next_help_heading = Some("EDITOR OPTIONS"))] + pub editor_options: EditorOptions, - #[clap(flatten, next_help_heading = Some("EDITOR TROUBLESHOOTING"))] - pub troubleshooting: EditorTroubleshooting, + #[clap(flatten, next_help_heading = Some("EDITOR TROUBLESHOOTING"))] + pub troubleshooting: EditorTroubleshooting, - #[clap(flatten, next_help_heading = Some("GLOBAL OPTIONS"))] - pub global_options: GlobalOptions, + #[clap(flatten, next_help_heading = Some("GLOBAL OPTIONS"))] + pub global_options: GlobalOptions, - #[clap(subcommand)] - pub subcommand: Option, + #[clap(subcommand)] + pub subcommand: Option, } impl Cli { - pub fn get_base_code_args(&self) -> Vec { - let mut args = self.open_paths.clone(); - self.editor_options.add_code_args(&mut args); - self.troubleshooting.add_code_args(&mut args); - self.global_options.add_code_args(&mut args); - args - } + pub fn get_base_code_args(&self) -> Vec { + let mut args = self.open_paths.clone(); + self.editor_options.add_code_args(&mut args); + self.troubleshooting.add_code_args(&mut args); + self.global_options.add_code_args(&mut args); + args + } } impl<'a> From<&'a Cli> for CodeServerArgs { - fn from(cli: &'a Cli) -> Self { - let mut args = CodeServerArgs { - log: cli.global_options.log, - accept_server_license_terms: true, - ..Default::default() - }; + fn from(cli: &'a Cli) -> Self { + let mut args = CodeServerArgs { + log: cli.global_options.log, + accept_server_license_terms: true, + ..Default::default() + }; - args.log = cli.global_options.log; - args.accept_server_license_terms = true; + args.log = cli.global_options.log; + args.accept_server_license_terms = true; - if cli.global_options.verbose { - args.verbose = true; - } + if cli.global_options.verbose { + args.verbose = true; + } - if cli.global_options.disable_telemetry { - args.telemetry_level = Some(options::TelemetryLevel::Off); - } else if cli.global_options.telemetry_level.is_some() { - args.telemetry_level = cli.global_options.telemetry_level; - } + if cli.global_options.disable_telemetry { + args.telemetry_level = Some(options::TelemetryLevel::Off); + } else if cli.global_options.telemetry_level.is_some() { + args.telemetry_level = cli.global_options.telemetry_level; + } - args - } + args + } } #[derive(Subcommand, Debug, Clone)] pub enum Commands { - /// Create a tunnel that's accessible on vscode.dev from anywhere. - /// Run `code tunnel --help` for more usage info. - Tunnel(TunnelArgs), + /// Create a tunnel that's accessible on vscode.dev from anywhere. + /// Run `code tunnel --help` for more usage info. + Tunnel(TunnelArgs), - /// Manage VS Code extensions. - #[clap(name = "ext")] - Extension(ExtensionArgs), + /// Manage VS Code extensions. + #[clap(name = "ext")] + Extension(ExtensionArgs), - /// Print process usage and diagnostics information. - Status, + /// Print process usage and diagnostics information. + Status, - /// Changes the version of VS Code you're using. - Version(VersionArgs), + /// Changes the version of VS Code you're using. + Version(VersionArgs), } #[derive(Args, Debug, Clone)] pub struct ExtensionArgs { - #[clap(subcommand)] - pub subcommand: ExtensionSubcommand, + #[clap(subcommand)] + pub subcommand: ExtensionSubcommand, - #[clap(flatten)] - pub desktop_code_options: DesktopCodeOptions, + #[clap(flatten)] + pub desktop_code_options: DesktopCodeOptions, } impl ExtensionArgs { - pub fn add_code_args(&self, target: &mut Vec) { - if let Some(ed) = &self.desktop_code_options.extensions_dir { - target.push(ed.to_string()); - } + pub fn add_code_args(&self, target: &mut Vec) { + if let Some(ed) = &self.desktop_code_options.extensions_dir { + target.push(ed.to_string()); + } - self.subcommand.add_code_args(target); - } + self.subcommand.add_code_args(target); + } } #[derive(Subcommand, Debug, Clone)] pub enum ExtensionSubcommand { - /// List installed extensions. - List(ListExtensionArgs), - /// Install an extension. - Install(InstallExtensionArgs), - /// Uninstall an extension. - Uninstall(UninstallExtensionArgs), + /// List installed extensions. + List(ListExtensionArgs), + /// Install an extension. + Install(InstallExtensionArgs), + /// Uninstall an extension. + Uninstall(UninstallExtensionArgs), } impl ExtensionSubcommand { - pub fn add_code_args(&self, target: &mut Vec) { - match self { - ExtensionSubcommand::List(args) => { - target.push("--list-extensions".to_string()); - if args.show_versions { - target.push("--show-versions".to_string()); - } - if let Some(category) = &args.category { - target.push(format!("--category={}", category)); - } - } - ExtensionSubcommand::Install(args) => { - for id in args.id_or_path.iter() { - target.push(format!("--install-extension={}", id)); - } - if args.pre_release { - target.push("--pre-release".to_string()); - } - if args.force { - target.push("--force".to_string()); - } - } - ExtensionSubcommand::Uninstall(args) => { - for id in args.id.iter() { - target.push(format!("--uninstall-extension={}", id)); - } - } - } - } + pub fn add_code_args(&self, target: &mut Vec) { + match self { + ExtensionSubcommand::List(args) => { + target.push("--list-extensions".to_string()); + if args.show_versions { + target.push("--show-versions".to_string()); + } + if let Some(category) = &args.category { + target.push(format!("--category={}", category)); + } + } + ExtensionSubcommand::Install(args) => { + for id in args.id_or_path.iter() { + target.push(format!("--install-extension={}", id)); + } + if args.pre_release { + target.push("--pre-release".to_string()); + } + if args.force { + target.push("--force".to_string()); + } + } + ExtensionSubcommand::Uninstall(args) => { + for id in args.id.iter() { + target.push(format!("--uninstall-extension={}", id)); + } + } + } + } } #[derive(Args, Debug, Clone)] pub struct ListExtensionArgs { - /// Filters installed extensions by provided category, when using --list-extensions. - #[clap(long, value_name = "category")] - pub category: Option, + /// Filters installed extensions by provided category, when using --list-extensions. + #[clap(long, value_name = "category")] + pub category: Option, - /// Show versions of installed extensions, when using --list-extensions. - #[clap(long)] - pub show_versions: bool, + /// Show versions of installed extensions, when using --list-extensions. + #[clap(long)] + pub show_versions: bool, } #[derive(Args, Debug, Clone)] pub struct InstallExtensionArgs { - /// Either an extension id or a path to a VSIX. The identifier of an - /// extension is '${publisher}.${name}'. Use '--force' argument to update - /// to latest version. To install a specific version provide '@${version}'. - /// For example: 'vscode.csharp@1.2.3'. - #[clap(name = "ext-id | id")] - pub id_or_path: Vec, + /// Either an extension id or a path to a VSIX. The identifier of an + /// extension is '${publisher}.${name}'. Use '--force' argument to update + /// to latest version. To install a specific version provide '@${version}'. + /// For example: 'vscode.csharp@1.2.3'. + #[clap(name = "ext-id | id")] + pub id_or_path: Vec, - /// Installs the pre-release version of the extension - #[clap(long)] - pub pre_release: bool, + /// Installs the pre-release version of the extension + #[clap(long)] + pub pre_release: bool, - /// Update to the latest version of the extension if it's already installed. - #[clap(long)] - pub force: bool, + /// Update to the latest version of the extension if it's already installed. + #[clap(long)] + pub force: bool, } #[derive(Args, Debug, Clone)] pub struct UninstallExtensionArgs { - /// One or more extension identifiers to uninstall. The identifier of an - /// extension is '${publisher}.${name}'. Use '--force' argument to update - /// to latest version. - #[clap(name = "ext-id")] - pub id: Vec, + /// One or more extension identifiers to uninstall. The identifier of an + /// extension is '${publisher}.${name}'. Use '--force' argument to update + /// to latest version. + #[clap(name = "ext-id")] + pub id: Vec, } #[derive(Args, Debug, Clone)] pub struct VersionArgs { - #[clap(subcommand)] - pub subcommand: VersionSubcommand, + #[clap(subcommand)] + pub subcommand: VersionSubcommand, } #[derive(Subcommand, Debug, Clone)] pub enum VersionSubcommand { - /// Switches the instance of VS Code in use. - Use(UseVersionArgs), - /// Uninstalls a instance of VS Code. - Uninstall(UninstallVersionArgs), - /// Lists installed VS Code instances. - List(OutputFormatOptions), + /// Switches the instance of VS Code in use. + Use(UseVersionArgs), + /// Uninstalls a instance of VS Code. + Uninstall(UninstallVersionArgs), + /// Lists installed VS Code instances. + List(OutputFormatOptions), } #[derive(Args, Debug, Clone)] pub struct UseVersionArgs { - /// The version of VS Code you want to use. Can be "stable", "insiders", - /// a version number, or an absolute path to an existing install. - #[clap(value_name = "stable | insiders | x.y.z | path")] - pub name: String, + /// The version of VS Code you want to use. Can be "stable", "insiders", + /// a version number, or an absolute path to an existing install. + #[clap(value_name = "stable | insiders | x.y.z | path")] + pub name: String, - /// The directory the version should be installed into, if it's not already installed. - #[clap(long, value_name = "path")] - pub install_dir: Option, + /// The directory the version should be installed into, if it's not already installed. + #[clap(long, value_name = "path")] + pub install_dir: Option, - /// Reinstall the version even if it's already installed. - #[clap(long)] - pub reinstall: bool, + /// Reinstall the version even if it's already installed. + #[clap(long)] + pub reinstall: bool, } #[derive(Args, Debug, Clone)] pub struct UninstallVersionArgs { - /// The version of VS Code to uninstall. Can be "stable", "insiders", or a - /// version number previous passed to `code version use `. - #[clap(value_name = "stable | insiders | x.y.z")] - pub name: String, + /// The version of VS Code to uninstall. Can be "stable", "insiders", or a + /// version number previous passed to `code version use `. + #[clap(value_name = "stable | insiders | x.y.z")] + pub name: String, } #[derive(Args, Debug, Default)] pub struct EditorOptions { - /// Compare two files with each other. - #[clap(short, long, value_names = &["file", "file"])] - pub diff: Vec, + /// Compare two files with each other. + #[clap(short, long, value_names = &["file", "file"])] + pub diff: Vec, - /// Add folder(s) to the last active window. - #[clap(short, long, value_name = "folder")] - pub add: Option, + /// Add folder(s) to the last active window. + #[clap(short, long, value_name = "folder")] + pub add: Option, - /// Open a file at the path on the specified line and character position. - #[clap(short, long, value_name = "file:line[:character]")] - pub goto: Option, + /// Open a file at the path on the specified line and character position. + #[clap(short, long, value_name = "file:line[:character]")] + pub goto: Option, - /// Force to open a new window. - #[clap(short, long)] - pub new_window: bool, + /// Force to open a new window. + #[clap(short, long)] + pub new_window: bool, - /// Force to open a file or folder in an - #[clap(short, long)] - pub reuse_window: bool, + /// Force to open a file or folder in an + #[clap(short, long)] + pub reuse_window: bool, - /// Wait for the files to be closed before returning. - #[clap(short, long)] - pub wait: bool, + /// Wait for the files to be closed before returning. + #[clap(short, long)] + pub wait: bool, - /// The locale to use (e.g. en-US or zh-TW). - #[clap(long, value_name = "locale")] - pub locale: Option, + /// The locale to use (e.g. en-US or zh-TW). + #[clap(long, value_name = "locale")] + pub locale: Option, - /// Enables proposed API features for extensions. Can receive one or - /// more extension IDs to enable individually. - #[clap(long, value_name = "ext-id")] - pub enable_proposed_api: Vec, + /// Enables proposed API features for extensions. Can receive one or + /// more extension IDs to enable individually. + #[clap(long, value_name = "ext-id")] + pub enable_proposed_api: Vec, - #[clap(flatten)] - pub code_options: DesktopCodeOptions, + #[clap(flatten)] + pub code_options: DesktopCodeOptions, } impl EditorOptions { - pub fn add_code_args(&self, target: &mut Vec) { - if !self.diff.is_empty() { - target.push("--diff".to_string()); - for file in self.diff.iter() { - target.push(file.clone()); - } - } - if let Some(add) = &self.add { - target.push("--add".to_string()); - target.push(add.clone()); - } - if let Some(goto) = &self.goto { - target.push("--goto".to_string()); - target.push(goto.clone()); - } - if self.new_window { - target.push("--new-window".to_string()); - } - if self.reuse_window { - target.push("--reuse-window".to_string()); - } - if self.wait { - target.push("--wait".to_string()); - } - if let Some(locale) = &self.locale { - target.push(format!("--locale={}", locale)); - } - if !self.enable_proposed_api.is_empty() { - for id in self.enable_proposed_api.iter() { - target.push(format!("--enable-proposed-api={}", id)); - } - } - self.code_options.add_code_args(target); - } + pub fn add_code_args(&self, target: &mut Vec) { + if !self.diff.is_empty() { + target.push("--diff".to_string()); + for file in self.diff.iter() { + target.push(file.clone()); + } + } + if let Some(add) = &self.add { + target.push("--add".to_string()); + target.push(add.clone()); + } + if let Some(goto) = &self.goto { + target.push("--goto".to_string()); + target.push(goto.clone()); + } + if self.new_window { + target.push("--new-window".to_string()); + } + if self.reuse_window { + target.push("--reuse-window".to_string()); + } + if self.wait { + target.push("--wait".to_string()); + } + if let Some(locale) = &self.locale { + target.push(format!("--locale={}", locale)); + } + if !self.enable_proposed_api.is_empty() { + for id in self.enable_proposed_api.iter() { + target.push(format!("--enable-proposed-api={}", id)); + } + } + self.code_options.add_code_args(target); + } } /// Arguments applicable whenever VS Code desktop is launched #[derive(Args, Debug, Default, Clone)] pub struct DesktopCodeOptions { - /// Set the root path for extensions. - #[clap(long, value_name = "dir")] - pub extensions_dir: Option, + /// Set the root path for extensions. + #[clap(long, value_name = "dir")] + pub extensions_dir: Option, - /// Specifies the directory that user data is kept in. Can be used to - /// open multiple distinct instances of Code. - #[clap(long, value_name = "dir")] - pub user_data_dir: Option, + /// Specifies the directory that user data is kept in. Can be used to + /// open multiple distinct instances of Code. + #[clap(long, value_name = "dir")] + pub user_data_dir: Option, - /// Sets the VS Code version to use for this command. The preferred version - /// can be persisted with `code version use `. Can be "stable", - /// "insiders", a version number, or an absolute path to an existing install. - #[clap(long, value_name = "stable | insiders | x.y.z | path")] - pub use_version: Option, + /// Sets the VS Code version to use for this command. The preferred version + /// can be persisted with `code version use `. Can be "stable", + /// "insiders", a version number, or an absolute path to an existing install. + #[clap(long, value_name = "stable | insiders | x.y.z | path")] + pub use_version: Option, } /// Argument specifying the output format. #[derive(Args, Debug, Clone)] pub struct OutputFormatOptions { - /// Set the data output formats. - #[clap(arg_enum, long, value_name = "format", default_value_t = OutputFormat::Text)] - pub format: OutputFormat, + /// Set the data output formats. + #[clap(arg_enum, long, value_name = "format", default_value_t = OutputFormat::Text)] + pub format: OutputFormat, } impl DesktopCodeOptions { - pub fn add_code_args(&self, target: &mut Vec) { - if let Some(extensions_dir) = &self.extensions_dir { - target.push(format!("--extensions-dir={}", extensions_dir)); - } - if let Some(user_data_dir) = &self.user_data_dir { - target.push(format!("--user-data-dir={}", user_data_dir)); - } - } + pub fn add_code_args(&self, target: &mut Vec) { + if let Some(extensions_dir) = &self.extensions_dir { + target.push(format!("--extensions-dir={}", extensions_dir)); + } + if let Some(user_data_dir) = &self.user_data_dir { + target.push(format!("--user-data-dir={}", user_data_dir)); + } + } } #[derive(Args, Debug, Default)] pub struct GlobalOptions { - /// Directory where CLI metadata, such as VS Code installations, should be stored. - #[clap(long, env = "VSCODE_CLI_DATA_DIR", global = true)] - pub cli_data_dir: Option, + /// Directory where CLI metadata, such as VS Code installations, should be stored. + #[clap(long, env = "VSCODE_CLI_DATA_DIR", global = true)] + pub cli_data_dir: Option, - /// Print verbose output (implies --wait). - #[clap(long, global = true)] - pub verbose: bool, + /// Print verbose output (implies --wait). + #[clap(long, global = true)] + pub verbose: bool, - /// Log level to use. - #[clap(long, arg_enum, value_name = "level", global = true)] - pub log: Option, + /// Log level to use. + #[clap(long, arg_enum, value_name = "level", global = true)] + pub log: Option, - /// Disable telemetry for the current command, even if it was previously - /// accepted as part of the license prompt or specified in '--telemetry-level' - #[clap(long, global = true, hide = true)] - pub disable_telemetry: bool, + /// Disable telemetry for the current command, even if it was previously + /// accepted as part of the license prompt or specified in '--telemetry-level' + #[clap(long, global = true, hide = true)] + pub disable_telemetry: bool, - /// Sets the initial telemetry level - #[clap(arg_enum, long, global = true, hide = true)] - pub telemetry_level: Option, + /// Sets the initial telemetry level + #[clap(arg_enum, long, global = true, hide = true)] + pub telemetry_level: Option, } impl GlobalOptions { - pub fn add_code_args(&self, target: &mut Vec) { - if self.verbose { - target.push("--verbose".to_string()); - } - if let Some(log) = self.log { - target.push(format!("--log={}", log)); - } - if self.disable_telemetry { - target.push("--disable-telemetry".to_string()); - } - if let Some(telemetry_level) = &self.telemetry_level { - target.push(format!("--telemetry-level={}", telemetry_level)); - } - } + pub fn add_code_args(&self, target: &mut Vec) { + if self.verbose { + target.push("--verbose".to_string()); + } + if let Some(log) = self.log { + target.push(format!("--log={}", log)); + } + if self.disable_telemetry { + target.push("--disable-telemetry".to_string()); + } + if let Some(telemetry_level) = &self.telemetry_level { + target.push(format!("--telemetry-level={}", telemetry_level)); + } + } } #[derive(Args, Debug, Default)] pub struct EditorTroubleshooting { - /// Run CPU profiler during startup. - #[clap(long)] - pub prof_startup: bool, + /// Run CPU profiler during startup. + #[clap(long)] + pub prof_startup: bool, - /// Disable all installed extensions. - #[clap(long)] - pub disable_extensions: bool, + /// Disable all installed extensions. + #[clap(long)] + pub disable_extensions: bool, - /// Disable an extension. - #[clap(long, value_name = "ext-id")] - pub disable_extension: Vec, + /// Disable an extension. + #[clap(long, value_name = "ext-id")] + pub disable_extension: Vec, - /// Turn sync on or off. - #[clap(arg_enum, long, value_name = "on | off")] - pub sync: Option, + /// Turn sync on or off. + #[clap(arg_enum, long, value_name = "on | off")] + pub sync: Option, - /// Allow debugging and profiling of extensions. Check the developer tools for the connection URI. - #[clap(long, value_name = "port")] - pub inspect_extensions: Option, + /// Allow debugging and profiling of extensions. Check the developer tools for the connection URI. + #[clap(long, value_name = "port")] + pub inspect_extensions: Option, - /// Allow debugging and profiling of extensions with the extension host - /// being paused after start. Check the developer tools for the connection URI. - #[clap(long, value_name = "port")] - pub inspect_brk_extensions: Option, + /// Allow debugging and profiling of extensions with the extension host + /// being paused after start. Check the developer tools for the connection URI. + #[clap(long, value_name = "port")] + pub inspect_brk_extensions: Option, - /// Disable GPU hardware acceleration. - #[clap(long)] - pub disable_gpu: bool, + /// Disable GPU hardware acceleration. + #[clap(long)] + pub disable_gpu: bool, - /// Max memory size for a window (in Mbytes). - #[clap(long, value_name = "memory")] - pub max_memory: Option, + /// Max memory size for a window (in Mbytes). + #[clap(long, value_name = "memory")] + pub max_memory: Option, - /// Shows all telemetry events which VS code collects. - #[clap(long)] - pub telemetry: bool, + /// Shows all telemetry events which VS code collects. + #[clap(long)] + pub telemetry: bool, } impl EditorTroubleshooting { - pub fn add_code_args(&self, target: &mut Vec) { - if self.prof_startup { - target.push("--prof-startup".to_string()); - } - if self.disable_extensions { - target.push("--disable-extensions".to_string()); - } - for id in self.disable_extension.iter() { - target.push(format!("--disable-extension={}", id)); - } - if let Some(sync) = &self.sync { - target.push(format!("--sync={}", sync)); - } - if let Some(port) = &self.inspect_extensions { - target.push(format!("--inspect-extensions={}", port)); - } - if let Some(port) = &self.inspect_brk_extensions { - target.push(format!("--inspect-brk-extensions={}", port)); - } - if self.disable_gpu { - target.push("--disable-gpu".to_string()); - } - if let Some(memory) = &self.max_memory { - target.push(format!("--max-memory={}", memory)); - } - if self.telemetry { - target.push("--telemetry".to_string()); - } - } + pub fn add_code_args(&self, target: &mut Vec) { + if self.prof_startup { + target.push("--prof-startup".to_string()); + } + if self.disable_extensions { + target.push("--disable-extensions".to_string()); + } + for id in self.disable_extension.iter() { + target.push(format!("--disable-extension={}", id)); + } + if let Some(sync) = &self.sync { + target.push(format!("--sync={}", sync)); + } + if let Some(port) = &self.inspect_extensions { + target.push(format!("--inspect-extensions={}", port)); + } + if let Some(port) = &self.inspect_brk_extensions { + target.push(format!("--inspect-brk-extensions={}", port)); + } + if self.disable_gpu { + target.push("--disable-gpu".to_string()); + } + if let Some(memory) = &self.max_memory { + target.push(format!("--max-memory={}", memory)); + } + if self.telemetry { + target.push("--telemetry".to_string()); + } + } } #[derive(ArgEnum, Clone, Copy, Debug)] pub enum SyncState { - On, - Off, + On, + Off, } impl fmt::Display for SyncState { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - SyncState::Off => write!(f, "off"), - SyncState::On => write!(f, "on"), - } - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + SyncState::Off => write!(f, "off"), + SyncState::On => write!(f, "on"), + } + } } #[derive(ArgEnum, Clone, Copy, Debug)] pub enum OutputFormat { - Json, - Text, + Json, + Text, } #[derive(Args, Clone, Debug, Default)] pub struct ExistingTunnelArgs { - /// Name you'd like to assign preexisting tunnel to use to connect the tunnel - #[clap(long, hide = true)] - pub tunnel_name: Option, + /// Name you'd like to assign preexisting tunnel to use to connect the tunnel + #[clap(long, hide = true)] + pub tunnel_name: Option, - /// Token to authenticate and use preexisting tunnel - #[clap(long, hide = true)] - pub host_token: Option, + /// Token to authenticate and use preexisting tunnel + #[clap(long, hide = true)] + pub host_token: Option, - /// ID of preexisting tunnel to use to connect the tunnel - #[clap(long, hide = true)] - pub tunnel_id: Option, + /// ID of preexisting tunnel to use to connect the tunnel + #[clap(long, hide = true)] + pub tunnel_id: Option, - /// Cluster of preexisting tunnel to use to connect the tunnel - #[clap(long, hide = true)] - pub cluster: Option, + /// Cluster of preexisting tunnel to use to connect the tunnel + #[clap(long, hide = true)] + pub cluster: Option, } #[derive(Args, Debug, Clone, Default)] pub struct TunnelServeArgs { - /// Optional details to connect to an existing tunnel - #[clap(flatten, next_help_heading = Some("ADVANCED OPTIONS"))] - pub tunnel: ExistingTunnelArgs, + /// Optional details to connect to an existing tunnel + #[clap(flatten, next_help_heading = Some("ADVANCED OPTIONS"))] + pub tunnel: ExistingTunnelArgs, - /// Randomly name machine for port forwarding service - #[clap(long)] - pub random_name: bool, + /// Randomly name machine for port forwarding service + #[clap(long)] + pub random_name: bool, } #[derive(Args, Debug, Clone)] pub struct TunnelArgs { - #[clap(subcommand)] - pub subcommand: Option, + #[clap(subcommand)] + pub subcommand: Option, - #[clap(flatten)] - pub serve_args: TunnelServeArgs, + #[clap(flatten)] + pub serve_args: TunnelServeArgs, } #[derive(Subcommand, Debug, Clone)] pub enum TunnelSubcommand { - /// Delete all servers which are currently not running. - Prune, + /// Delete all servers which are currently not running. + Prune, - /// Rename the name of this machine associated with port forwarding service. - Rename(TunnelRenameArgs), + /// Rename the name of this machine associated with port forwarding service. + Rename(TunnelRenameArgs), - /// Remove this machine's association with the port forwarding service. - Unregister, + /// Remove this machine's association with the port forwarding service. + Unregister, - #[clap(subcommand)] - User(TunnelUserSubCommands), + #[clap(subcommand)] + User(TunnelUserSubCommands), - /// Manages the tunnel when installed as a system service, - #[clap(subcommand)] - Service(TunnelServiceSubCommands), + /// Manages the tunnel when installed as a system service, + #[clap(subcommand)] + Service(TunnelServiceSubCommands), } #[derive(Subcommand, Debug, Clone)] pub enum TunnelServiceSubCommands { - /// Installs or re-installs the tunnel service on the machine. - Install, + /// Installs or re-installs the tunnel service on the machine. + Install, - /// Uninstalls and stops the tunnel service. - Uninstall, + /// Uninstalls and stops the tunnel service. + Uninstall, - /// Internal command for running the service - #[clap(hide = true)] - InternalRun, + /// Internal command for running the service + #[clap(hide = true)] + InternalRun, } #[derive(Args, Debug, Clone)] pub struct TunnelRenameArgs { - /// The name you'd like to rename your machine to. - pub name: String, + /// The name you'd like to rename your machine to. + pub name: String, } #[derive(Subcommand, Debug, Clone)] pub enum TunnelUserSubCommands { - /// Log in to port forwarding service - Login(LoginArgs), + /// Log in to port forwarding service + Login(LoginArgs), - /// Log out of port forwarding service - Logout, + /// Log out of port forwarding service + Logout, - /// Show the account that's logged into port forwarding service - Show, + /// Show the account that's logged into port forwarding service + Show, } #[derive(Args, Debug, Clone)] pub struct LoginArgs { - /// An access token to store for authentication. Note: this will not be - /// refreshed if it expires! - #[clap(long, requires = "provider")] - pub access_token: Option, + /// An access token to store for authentication. Note: this will not be + /// refreshed if it expires! + #[clap(long, requires = "provider")] + pub access_token: Option, - /// The auth provider to use. If not provided, a prompt will be shown. - #[clap(arg_enum, long)] - pub provider: Option, + /// The auth provider to use. If not provided, a prompt will be shown. + #[clap(arg_enum, long)] + pub provider: Option, } #[derive(clap::ArgEnum, Debug, Clone, Copy)] pub enum AuthProvider { - Microsoft, - Github, + Microsoft, + Github, } diff --git a/cli/src/commands/context.rs b/cli/src/commands/context.rs index 9b45f253ee9..506ee0f57d0 100644 --- a/cli/src/commands/context.rs +++ b/cli/src/commands/context.rs @@ -8,8 +8,8 @@ use crate::{log, state::LauncherPaths}; use super::args::Cli; pub struct CommandContext { - pub log: log::Logger, - pub paths: LauncherPaths, - pub args: Cli, - pub http: reqwest::Client, + pub log: log::Logger, + pub paths: LauncherPaths, + pub args: Cli, + pub http: reqwest::Client, } diff --git a/cli/src/commands/output.rs b/cli/src/commands/output.rs index d66ec130f88..8747457889b 100644 --- a/cli/src/commands/output.rs +++ b/cli/src/commands/output.rs @@ -10,126 +10,126 @@ use std::io::{BufWriter, Write}; use super::args::OutputFormat; pub struct Column { - max_width: usize, - heading: &'static str, - data: Vec, + max_width: usize, + heading: &'static str, + data: Vec, } impl Column { - pub fn new(heading: &'static str) -> Self { - Column { - max_width: heading.len(), - heading, - data: vec![], - } - } + pub fn new(heading: &'static str) -> Self { + Column { + max_width: heading.len(), + heading, + data: vec![], + } + } - pub fn add_row(&mut self, row: String) { - self.max_width = std::cmp::max(self.max_width, row.len()); - self.data.push(row); - } + pub fn add_row(&mut self, row: String) { + self.max_width = std::cmp::max(self.max_width, row.len()); + self.data.push(row); + } } impl OutputFormat { - pub fn print_table(&self, table: OutputTable) -> Result<(), std::io::Error> { - match *self { - OutputFormat::Json => JsonTablePrinter().print(table, &mut std::io::stdout()), - OutputFormat::Text => TextTablePrinter().print(table, &mut std::io::stdout()), - } - } + pub fn print_table(&self, table: OutputTable) -> Result<(), std::io::Error> { + match *self { + OutputFormat::Json => JsonTablePrinter().print(table, &mut std::io::stdout()), + OutputFormat::Text => TextTablePrinter().print(table, &mut std::io::stdout()), + } + } } pub struct OutputTable { - cols: Vec, + cols: Vec, } impl OutputTable { - pub fn new(cols: Vec) -> Self { - OutputTable { cols } - } + pub fn new(cols: Vec) -> Self { + OutputTable { cols } + } } trait TablePrinter { - fn print(&self, table: OutputTable, out: &mut dyn std::io::Write) - -> Result<(), std::io::Error>; + fn print(&self, table: OutputTable, out: &mut dyn std::io::Write) + -> Result<(), std::io::Error>; } pub struct JsonTablePrinter(); impl TablePrinter for JsonTablePrinter { - fn print( - &self, - table: OutputTable, - out: &mut dyn std::io::Write, - ) -> Result<(), std::io::Error> { - let mut bw = BufWriter::new(out); - bw.write_all(b"[")?; + fn print( + &self, + table: OutputTable, + out: &mut dyn std::io::Write, + ) -> Result<(), std::io::Error> { + let mut bw = BufWriter::new(out); + bw.write_all(b"[")?; - if !table.cols.is_empty() { - let data_len = table.cols[0].data.len(); - for i in 0..data_len { - if i > 0 { - bw.write_all(b",{")?; - } else { - bw.write_all(b"{")?; - } - for col in &table.cols { - serde_json::to_writer(&mut bw, col.heading)?; - bw.write_all(b":")?; - serde_json::to_writer(&mut bw, &col.data[i])?; - } - } - } + if !table.cols.is_empty() { + let data_len = table.cols[0].data.len(); + for i in 0..data_len { + if i > 0 { + bw.write_all(b",{")?; + } else { + bw.write_all(b"{")?; + } + for col in &table.cols { + serde_json::to_writer(&mut bw, col.heading)?; + bw.write_all(b":")?; + serde_json::to_writer(&mut bw, &col.data[i])?; + } + } + } - bw.write_all(b"]")?; - bw.flush() - } + bw.write_all(b"]")?; + bw.flush() + } } /// Type that prints the output as an ASCII, markdown-style table. pub struct TextTablePrinter(); impl TablePrinter for TextTablePrinter { - fn print( - &self, - table: OutputTable, - out: &mut dyn std::io::Write, - ) -> Result<(), std::io::Error> { - let mut bw = BufWriter::new(out); + fn print( + &self, + table: OutputTable, + out: &mut dyn std::io::Write, + ) -> Result<(), std::io::Error> { + let mut bw = BufWriter::new(out); - let sizes = table.cols.iter().map(|c| c.max_width).collect::>(); + let sizes = table.cols.iter().map(|c| c.max_width).collect::>(); - // print headers - write_columns(&mut bw, table.cols.iter().map(|c| c.heading), &sizes)?; - // print --- separators - write_columns( - &mut bw, - table.cols.iter().map(|c| "-".repeat(c.max_width)), - &sizes, - )?; - // print each column - if !table.cols.is_empty() { - let data_len = table.cols[0].data.len(); - for i in 0..data_len { - write_columns(&mut bw, table.cols.iter().map(|c| &c.data[i]), &sizes)?; - } - } + // print headers + write_columns(&mut bw, table.cols.iter().map(|c| c.heading), &sizes)?; + // print --- separators + write_columns( + &mut bw, + table.cols.iter().map(|c| "-".repeat(c.max_width)), + &sizes, + )?; + // print each column + if !table.cols.is_empty() { + let data_len = table.cols[0].data.len(); + for i in 0..data_len { + write_columns(&mut bw, table.cols.iter().map(|c| &c.data[i]), &sizes)?; + } + } - bw.flush() - } + bw.flush() + } } fn write_columns( - mut w: impl Write, - cols: impl Iterator, - sizes: &[usize], + mut w: impl Write, + cols: impl Iterator, + sizes: &[usize], ) -> Result<(), std::io::Error> where - T: Display, + T: Display, { - w.write_all(b"|")?; - for (i, col) in cols.enumerate() { - write!(w, " {:width$} |", col, width = sizes[i])?; - } - w.write_all(b"\r\n") + w.write_all(b"|")?; + for (i, col) in cols.enumerate() { + write!(w, " {:width$} |", col, width = sizes[i])?; + } + w.write_all(b"\r\n") } diff --git a/cli/src/commands/tunnels.rs b/cli/src/commands/tunnels.rs index 287322e9667..0f75688791b 100644 --- a/cli/src/commands/tunnels.rs +++ b/cli/src/commands/tunnels.rs @@ -9,253 +9,253 @@ use async_trait::async_trait; use tokio::sync::oneshot; use super::{ - args::{ - AuthProvider, Cli, ExistingTunnelArgs, TunnelRenameArgs, TunnelServeArgs, - TunnelServiceSubCommands, TunnelUserSubCommands, - }, - CommandContext, + args::{ + AuthProvider, Cli, ExistingTunnelArgs, TunnelRenameArgs, TunnelServeArgs, + TunnelServiceSubCommands, TunnelUserSubCommands, + }, + CommandContext, }; use crate::{ - auth::Auth, - log::{self, Logger}, - state::LauncherPaths, - tunnels::{ - code_server::CodeServerArgs, create_service_manager, dev_tunnels, legal, - paths::get_all_servers, ServiceContainer, ServiceManager, - }, - util::{ - errors::{wrap, AnyError}, - prereqs::PreReqChecker, - }, + auth::Auth, + log::{self, Logger}, + state::LauncherPaths, + tunnels::{ + code_server::CodeServerArgs, create_service_manager, dev_tunnels, legal, + paths::get_all_servers, ServiceContainer, ServiceManager, + }, + util::{ + errors::{wrap, AnyError}, + prereqs::PreReqChecker, + }, }; impl From for crate::auth::AuthProvider { - fn from(auth_provider: AuthProvider) -> Self { - match auth_provider { - AuthProvider::Github => crate::auth::AuthProvider::Github, - AuthProvider::Microsoft => crate::auth::AuthProvider::Microsoft, - } - } + fn from(auth_provider: AuthProvider) -> Self { + match auth_provider { + AuthProvider::Github => crate::auth::AuthProvider::Github, + AuthProvider::Microsoft => crate::auth::AuthProvider::Microsoft, + } + } } impl From for Option { - fn from(d: ExistingTunnelArgs) -> Option { - if let (Some(tunnel_id), Some(tunnel_name), Some(cluster), Some(host_token)) = - (d.tunnel_id, d.tunnel_name, d.cluster, d.host_token) - { - Some(dev_tunnels::ExistingTunnel { - tunnel_id, - tunnel_name, - host_token, - cluster, - }) - } else { - None - } - } + fn from(d: ExistingTunnelArgs) -> Option { + if let (Some(tunnel_id), Some(tunnel_name), Some(cluster), Some(host_token)) = + (d.tunnel_id, d.tunnel_name, d.cluster, d.host_token) + { + Some(dev_tunnels::ExistingTunnel { + tunnel_id, + tunnel_name, + host_token, + cluster, + }) + } else { + None + } + } } struct TunnelServiceContainer { - args: Cli, + args: Cli, } impl TunnelServiceContainer { - fn new(args: Cli) -> Self { - Self { args } - } + fn new(args: Cli) -> Self { + Self { args } + } } #[async_trait] impl ServiceContainer for TunnelServiceContainer { - async fn run_service( - &mut self, - log: log::Logger, - launcher_paths: LauncherPaths, - shutdown_rx: oneshot::Receiver<()>, - ) -> Result<(), AnyError> { - let csa = (&self.args).into(); - serve_with_csa( - launcher_paths, - log, - TunnelServeArgs { - random_name: true, // avoid prompting - ..Default::default() - }, - csa, - Some(shutdown_rx), - ) - .await?; - Ok(()) - } + async fn run_service( + &mut self, + log: log::Logger, + launcher_paths: LauncherPaths, + shutdown_rx: oneshot::Receiver<()>, + ) -> Result<(), AnyError> { + let csa = (&self.args).into(); + serve_with_csa( + launcher_paths, + log, + TunnelServeArgs { + random_name: true, // avoid prompting + ..Default::default() + }, + csa, + Some(shutdown_rx), + ) + .await?; + Ok(()) + } } pub async fn service( - ctx: CommandContext, - service_args: TunnelServiceSubCommands, + ctx: CommandContext, + service_args: TunnelServiceSubCommands, ) -> Result { - let manager = create_service_manager(ctx.log.clone()); - match service_args { - TunnelServiceSubCommands::Install => { - // ensure logged in, otherwise subsequent serving will fail - Auth::new(&ctx.paths, ctx.log.clone()) - .get_credential() - .await?; + let manager = create_service_manager(ctx.log.clone()); + match service_args { + TunnelServiceSubCommands::Install => { + // ensure logged in, otherwise subsequent serving will fail + Auth::new(&ctx.paths, ctx.log.clone()) + .get_credential() + .await?; - // likewise for license consent - legal::require_consent(&ctx.paths)?; + // likewise for license consent + legal::require_consent(&ctx.paths)?; - let current_exe = - std::env::current_exe().map_err(|e| wrap(e, "could not get current exe"))?; + let current_exe = + std::env::current_exe().map_err(|e| wrap(e, "could not get current exe"))?; - manager.register( - current_exe, - &[ - "--cli-data-dir", - ctx.paths.root().as_os_str().to_string_lossy().as_ref(), - "tunnel", - "service", - "internal-run", - ], - )?; - ctx.log.result("Service successfully installed! You can use `code tunnel service log` to monitor it, and `code tunnel service uninstall` to remove it."); - } - TunnelServiceSubCommands::Uninstall => { - manager.unregister()?; - } - TunnelServiceSubCommands::InternalRun => { - manager.run(ctx.paths.clone(), TunnelServiceContainer::new(ctx.args))?; - } - } + manager.register( + current_exe, + &[ + "--cli-data-dir", + ctx.paths.root().as_os_str().to_string_lossy().as_ref(), + "tunnel", + "service", + "internal-run", + ], + )?; + ctx.log.result("Service successfully installed! You can use `code tunnel service log` to monitor it, and `code tunnel service uninstall` to remove it."); + } + TunnelServiceSubCommands::Uninstall => { + manager.unregister()?; + } + TunnelServiceSubCommands::InternalRun => { + manager.run(ctx.paths.clone(), TunnelServiceContainer::new(ctx.args))?; + } + } - Ok(0) + Ok(0) } pub async fn user(ctx: CommandContext, user_args: TunnelUserSubCommands) -> Result { - let auth = Auth::new(&ctx.paths, ctx.log.clone()); - match user_args { - TunnelUserSubCommands::Login(login_args) => { - auth.login( - login_args.provider.map(|p| p.into()), - login_args.access_token.to_owned(), - ) - .await?; - } - TunnelUserSubCommands::Logout => { - auth.clear_credentials()?; - } - TunnelUserSubCommands::Show => { - if let Ok(Some(_)) = auth.get_current_credential() { - ctx.log.result("logged in"); - } else { - ctx.log.result("not logged in"); - return Ok(1); - } - } - } + let auth = Auth::new(&ctx.paths, ctx.log.clone()); + match user_args { + TunnelUserSubCommands::Login(login_args) => { + auth.login( + login_args.provider.map(|p| p.into()), + login_args.access_token.to_owned(), + ) + .await?; + } + TunnelUserSubCommands::Logout => { + auth.clear_credentials()?; + } + TunnelUserSubCommands::Show => { + if let Ok(Some(_)) = auth.get_current_credential() { + ctx.log.result("logged in"); + } else { + ctx.log.result("not logged in"); + return Ok(1); + } + } + } - Ok(0) + Ok(0) } /// Remove the tunnel used by this gateway, if any. pub async fn rename(ctx: CommandContext, rename_args: TunnelRenameArgs) -> Result { - let auth = Auth::new(&ctx.paths, ctx.log.clone()); - let mut dt = dev_tunnels::DevTunnels::new(&ctx.log, auth, &ctx.paths); - dt.rename_tunnel(&rename_args.name).await?; - ctx.log.result(&format!( - "Successfully renamed this gateway to {}", - &rename_args.name - )); + let auth = Auth::new(&ctx.paths, ctx.log.clone()); + let mut dt = dev_tunnels::DevTunnels::new(&ctx.log, auth, &ctx.paths); + dt.rename_tunnel(&rename_args.name).await?; + ctx.log.result(&format!( + "Successfully renamed this gateway to {}", + &rename_args.name + )); - Ok(0) + Ok(0) } /// Remove the tunnel used by this gateway, if any. pub async fn unregister(ctx: CommandContext) -> Result { - let auth = Auth::new(&ctx.paths, ctx.log.clone()); - let mut dt = dev_tunnels::DevTunnels::new(&ctx.log, auth, &ctx.paths); - dt.remove_tunnel().await?; - Ok(0) + let auth = Auth::new(&ctx.paths, ctx.log.clone()); + let mut dt = dev_tunnels::DevTunnels::new(&ctx.log, auth, &ctx.paths); + dt.remove_tunnel().await?; + Ok(0) } /// Removes unused servers. pub async fn prune(ctx: CommandContext) -> Result { - get_all_servers(&ctx.paths) - .into_iter() - .map(|s| s.server_paths(&ctx.paths)) - .filter(|s| s.get_running_pid().is_none()) - .try_for_each(|s| { - ctx.log - .result(&format!("Deleted {}", s.server_dir.display())); - s.delete() - }) - .map_err(AnyError::from)?; + get_all_servers(&ctx.paths) + .into_iter() + .map(|s| s.server_paths(&ctx.paths)) + .filter(|s| s.get_running_pid().is_none()) + .try_for_each(|s| { + ctx.log + .result(&format!("Deleted {}", s.server_dir.display())); + s.delete() + }) + .map_err(AnyError::from)?; - ctx.log.result("Successfully removed all unused servers"); + ctx.log.result("Successfully removed all unused servers"); - Ok(0) + Ok(0) } /// Starts the gateway server. pub async fn serve(ctx: CommandContext, gateway_args: TunnelServeArgs) -> Result { - let CommandContext { - log, paths, args, .. - } = ctx; + let CommandContext { + log, paths, args, .. + } = ctx; - legal::require_consent(&paths)?; + legal::require_consent(&paths)?; - let csa = (&args).into(); - serve_with_csa(paths, log, gateway_args, csa, None).await + let csa = (&args).into(); + serve_with_csa(paths, log, gateway_args, csa, None).await } async fn serve_with_csa( - paths: LauncherPaths, - log: Logger, - gateway_args: TunnelServeArgs, - csa: CodeServerArgs, - shutdown_rx: Option>, + paths: LauncherPaths, + log: Logger, + gateway_args: TunnelServeArgs, + csa: CodeServerArgs, + shutdown_rx: Option>, ) -> Result { - let platform = spanf!(log, log.span("prereq"), PreReqChecker::new().verify())?; + let platform = spanf!(log, log.span("prereq"), PreReqChecker::new().verify())?; - let auth = Auth::new(&paths, log.clone()); - let mut dt = dev_tunnels::DevTunnels::new(&log, auth, &paths); - let tunnel = if let Some(d) = gateway_args.tunnel.clone().into() { - dt.start_existing_tunnel(d).await - } else { - dt.start_new_launcher_tunnel(gateway_args.random_name).await - }?; + let auth = Auth::new(&paths, log.clone()); + let mut dt = dev_tunnels::DevTunnels::new(&log, auth, &paths); + let tunnel = if let Some(d) = gateway_args.tunnel.clone().into() { + dt.start_existing_tunnel(d).await + } else { + dt.start_new_launcher_tunnel(gateway_args.random_name).await + }?; - let shutdown_tx = if let Some(tx) = shutdown_rx { - tx - } else { - let (tx, rx) = oneshot::channel(); - tokio::spawn(async move { - tokio::signal::ctrl_c().await.ok(); - tx.send(()).ok(); - }); - rx - }; + let shutdown_tx = if let Some(tx) = shutdown_rx { + tx + } else { + let (tx, rx) = oneshot::channel(); + tokio::spawn(async move { + tokio::signal::ctrl_c().await.ok(); + tx.send(()).ok(); + }); + rx + }; - let mut r = crate::tunnels::serve(&log, tunnel, &paths, &csa, platform, shutdown_tx).await?; - r.tunnel.close().await.ok(); + let mut r = crate::tunnels::serve(&log, tunnel, &paths, &csa, platform, shutdown_tx).await?; + r.tunnel.close().await.ok(); - if r.respawn { - warning!(log, "respawn requested, starting new server"); - // reuse current args, but specify no-forward since tunnels will - // already be running in this process, and we cannot do a login - let args = std::env::args().skip(1).collect::>(); - let exit = std::process::Command::new(std::env::current_exe().unwrap()) - .args(args) - .stdout(Stdio::inherit()) - .stderr(Stdio::inherit()) - .stdin(Stdio::inherit()) - .spawn() - .map_err(|e| wrap(e, "error respawning after update"))? - .wait() - .map_err(|e| wrap(e, "error waiting for child"))?; + if r.respawn { + warning!(log, "respawn requested, starting new server"); + // reuse current args, but specify no-forward since tunnels will + // already be running in this process, and we cannot do a login + let args = std::env::args().skip(1).collect::>(); + let exit = std::process::Command::new(std::env::current_exe().unwrap()) + .args(args) + .stdout(Stdio::inherit()) + .stderr(Stdio::inherit()) + .stdin(Stdio::inherit()) + .spawn() + .map_err(|e| wrap(e, "error respawning after update"))? + .wait() + .map_err(|e| wrap(e, "error waiting for child"))?; - return Ok(exit.code().unwrap_or(1)); - } + return Ok(exit.code().unwrap_or(1)); + } - Ok(0) + Ok(0) } diff --git a/cli/src/commands/version.rs b/cli/src/commands/version.rs index 3b29c382278..07568a15e66 100644 --- a/cli/src/commands/version.rs +++ b/cli/src/commands/version.rs @@ -4,63 +4,63 @@ *--------------------------------------------------------------------------------------------*/ use crate::{ - desktop::{CodeVersionManager, RequestedVersion}, - log, - update_service::UpdateService, - util::{errors::AnyError, prereqs::PreReqChecker}, + desktop::{CodeVersionManager, RequestedVersion}, + log, + update_service::UpdateService, + util::{errors::AnyError, prereqs::PreReqChecker}, }; use super::{ - args::{OutputFormatOptions, UninstallVersionArgs, UseVersionArgs}, - output::{Column, OutputTable}, - CommandContext, + args::{OutputFormatOptions, UninstallVersionArgs, UseVersionArgs}, + output::{Column, OutputTable}, + CommandContext, }; pub async fn switch_to(ctx: CommandContext, args: UseVersionArgs) -> Result { - let platform = PreReqChecker::new().verify().await?; - let vm = CodeVersionManager::new(&ctx.paths, platform); - let version = RequestedVersion::try_from(args.name.as_str())?; + let platform = PreReqChecker::new().verify().await?; + let vm = CodeVersionManager::new(&ctx.paths, platform); + let version = RequestedVersion::try_from(args.name.as_str())?; - if !args.reinstall && vm.try_get_entrypoint(&version).await.is_some() { - vm.set_preferred_version(&version)?; - print_now_using(&ctx.log, &version); - return Ok(0); - } + if !args.reinstall && vm.try_get_entrypoint(&version).await.is_some() { + vm.set_preferred_version(&version)?; + print_now_using(&ctx.log, &version); + return Ok(0); + } - let update_service = UpdateService::new(ctx.log.clone(), ctx.http.clone()); - vm.install(&update_service, &version).await?; - vm.set_preferred_version(&version)?; - print_now_using(&ctx.log, &version); - Ok(0) + let update_service = UpdateService::new(ctx.log.clone(), ctx.http.clone()); + vm.install(&update_service, &version).await?; + vm.set_preferred_version(&version)?; + print_now_using(&ctx.log, &version); + Ok(0) } pub async fn list(ctx: CommandContext, args: OutputFormatOptions) -> Result { - let platform = PreReqChecker::new().verify().await?; - let vm = CodeVersionManager::new(&ctx.paths, platform); + let platform = PreReqChecker::new().verify().await?; + let vm = CodeVersionManager::new(&ctx.paths, platform); - let mut name = Column::new("Installation"); - let mut command = Column::new("Command"); - for version in vm.list() { - name.add_row(version.to_string()); - command.add_row(version.get_command()); - } - args.format - .print_table(OutputTable::new(vec![name, command])) - .ok(); + let mut name = Column::new("Installation"); + let mut command = Column::new("Command"); + for version in vm.list() { + name.add_row(version.to_string()); + command.add_row(version.get_command()); + } + args.format + .print_table(OutputTable::new(vec![name, command])) + .ok(); - Ok(0) + Ok(0) } pub async fn uninstall(ctx: CommandContext, args: UninstallVersionArgs) -> Result { - let platform = PreReqChecker::new().verify().await?; - let vm = CodeVersionManager::new(&ctx.paths, platform); - let version = RequestedVersion::try_from(args.name.as_str())?; - vm.uninstall(&version).await?; - ctx.log - .result(&format!("VS Code {} uninstalled successfully", version)); - Ok(0) + let platform = PreReqChecker::new().verify().await?; + let vm = CodeVersionManager::new(&ctx.paths, platform); + let version = RequestedVersion::try_from(args.name.as_str())?; + vm.uninstall(&version).await?; + ctx.log + .result(&format!("VS Code {} uninstalled successfully", version)); + Ok(0) } fn print_now_using(log: &log::Logger, version: &RequestedVersion) { - log.result(&format!("Now using VS Code {}", version)); + log.result(&format!("Now using VS Code {}", version)); } diff --git a/cli/src/constants.rs b/cli/src/constants.rs index a194f9ab9a1..6f357bc5d69 100644 --- a/cli/src/constants.rs +++ b/cli/src/constants.rs @@ -13,21 +13,21 @@ pub const VSCODE_CLI_ASSET_NAME: Option<&'static str> = option_env!("VSCODE_CLI_ pub const VSCODE_CLI_AI_KEY: Option<&'static str> = option_env!("VSCODE_CLI_AI_KEY"); pub const VSCODE_CLI_AI_ENDPOINT: Option<&'static str> = option_env!("VSCODE_CLI_AI_ENDPOINT"); pub const VSCODE_CLI_UPDATE_ENDPOINT: Option<&'static str> = - option_env!("VSCODE_CLI_UPDATE_ENDPOINT"); + option_env!("VSCODE_CLI_UPDATE_ENDPOINT"); pub const TUNNEL_SERVICE_USER_AGENT_ENV_VAR: &str = "TUNNEL_SERVICE_USER_AGENT"; pub fn get_default_user_agent() -> String { - format!( - "vscode-server-launcher/{}", - VSCODE_CLI_VERSION.unwrap_or("dev") - ) + format!( + "vscode-server-launcher/{}", + VSCODE_CLI_VERSION.unwrap_or("dev") + ) } lazy_static! { - pub static ref TUNNEL_SERVICE_USER_AGENT: String = - match std::env::var(TUNNEL_SERVICE_USER_AGENT_ENV_VAR) { - Ok(ua) if !ua.is_empty() => format!("{} {}", ua, get_default_user_agent()), - _ => get_default_user_agent(), - }; + pub static ref TUNNEL_SERVICE_USER_AGENT: String = + match std::env::var(TUNNEL_SERVICE_USER_AGENT_ENV_VAR) { + Ok(ua) if !ua.is_empty() => format!("{} {}", ua, get_default_user_agent()), + _ => get_default_user_agent(), + }; } diff --git a/cli/src/desktop/version_manager.rs b/cli/src/desktop/version_manager.rs index 7138adfbc72..7dc32bf7ded 100644 --- a/cli/src/desktop/version_manager.rs +++ b/cli/src/desktop/version_manager.rs @@ -4,8 +4,8 @@ *--------------------------------------------------------------------------------------------*/ use std::{ - fmt, - path::{Path, PathBuf}, + fmt, + path::{Path, PathBuf}, }; use indicatif::ProgressBar; @@ -15,478 +15,478 @@ use serde::{Deserialize, Serialize}; use tokio::fs::remove_dir_all; use crate::{ - options, - state::{LauncherPaths, PersistedState}, - update_service::{unzip_downloaded_release, Platform, Release, TargetKind, UpdateService}, - util::{ - errors::{ - wrap, AnyError, InvalidRequestedVersion, MissingEntrypointError, - NoInstallInUserProvidedPath, UserCancelledInstallation, WrappedError, - }, - http, - input::{prompt_yn, ProgressBarReporter}, - }, + options, + state::{LauncherPaths, PersistedState}, + update_service::{unzip_downloaded_release, Platform, Release, TargetKind, UpdateService}, + util::{ + errors::{ + wrap, AnyError, InvalidRequestedVersion, MissingEntrypointError, + NoInstallInUserProvidedPath, UserCancelledInstallation, WrappedError, + }, + http, + input::{prompt_yn, ProgressBarReporter}, + }, }; /// Parsed instance that a user can request. #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] #[serde(tag = "t", content = "c")] pub enum RequestedVersion { - Quality(options::Quality), - Version { - version: String, - quality: options::Quality, - }, - Commit { - commit: String, - quality: options::Quality, - }, - Path(String), + Quality(options::Quality), + Version { + version: String, + quality: options::Quality, + }, + Commit { + commit: String, + quality: options::Quality, + }, + Path(String), } lazy_static! { - static ref SEMVER_RE: Regex = Regex::new(r"^\d+\.\d+\.\d+(-insider)?$").unwrap(); - static ref COMMIT_RE: Regex = Regex::new(r"^[a-z]+/[a-e0-f]{40}$").unwrap(); + static ref SEMVER_RE: Regex = Regex::new(r"^\d+\.\d+\.\d+(-insider)?$").unwrap(); + static ref COMMIT_RE: Regex = Regex::new(r"^[a-z]+/[a-e0-f]{40}$").unwrap(); } impl RequestedVersion { - pub fn get_command(&self) -> String { - match self { - RequestedVersion::Quality(quality) => { - format!("code version use {}", quality.get_machine_name()) - } - RequestedVersion::Version { version, .. } => { - format!("code version use {}", version) - } - RequestedVersion::Commit { commit, quality } => { - format!("code version use {}/{}", quality.get_machine_name(), commit) - } - RequestedVersion::Path(path) => { - format!("code version use {}", path) - } - } - } + pub fn get_command(&self) -> String { + match self { + RequestedVersion::Quality(quality) => { + format!("code version use {}", quality.get_machine_name()) + } + RequestedVersion::Version { version, .. } => { + format!("code version use {}", version) + } + RequestedVersion::Commit { commit, quality } => { + format!("code version use {}/{}", quality.get_machine_name(), commit) + } + RequestedVersion::Path(path) => { + format!("code version use {}", path) + } + } + } } impl std::fmt::Display for RequestedVersion { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - RequestedVersion::Quality(quality) => write!(f, "{}", quality.get_capitalized_name()), - RequestedVersion::Version { version, .. } => { - write!(f, "{}", version) - } - RequestedVersion::Commit { commit, quality } => { - write!(f, "{}/{}", quality, commit) - } - RequestedVersion::Path(path) => write!(f, "{}", path), - } - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + RequestedVersion::Quality(quality) => write!(f, "{}", quality.get_capitalized_name()), + RequestedVersion::Version { version, .. } => { + write!(f, "{}", version) + } + RequestedVersion::Commit { commit, quality } => { + write!(f, "{}/{}", quality, commit) + } + RequestedVersion::Path(path) => write!(f, "{}", path), + } + } } impl TryFrom<&str> for RequestedVersion { - type Error = InvalidRequestedVersion; + type Error = InvalidRequestedVersion; - fn try_from(s: &str) -> Result { - if let Ok(quality) = options::Quality::try_from(s) { - return Ok(RequestedVersion::Quality(quality)); - } + fn try_from(s: &str) -> Result { + if let Ok(quality) = options::Quality::try_from(s) { + return Ok(RequestedVersion::Quality(quality)); + } - if SEMVER_RE.is_match(s) { - return Ok(RequestedVersion::Version { - quality: if s.ends_with("-insider") { - options::Quality::Insiders - } else { - options::Quality::Stable - }, - version: s.to_string(), - }); - } + if SEMVER_RE.is_match(s) { + return Ok(RequestedVersion::Version { + quality: if s.ends_with("-insider") { + options::Quality::Insiders + } else { + options::Quality::Stable + }, + version: s.to_string(), + }); + } - if Path::is_absolute(&PathBuf::from(s)) { - return Ok(RequestedVersion::Path(s.to_string())); - } + if Path::is_absolute(&PathBuf::from(s)) { + return Ok(RequestedVersion::Path(s.to_string())); + } - if COMMIT_RE.is_match(s) { - let idx = s.find('/').expect("expected a /"); - if let Ok(quality) = options::Quality::try_from(&s[0..idx]) { - return Ok(RequestedVersion::Commit { - commit: s[idx + 1..].to_string(), - quality, - }); - } - } + if COMMIT_RE.is_match(s) { + let idx = s.find('/').expect("expected a /"); + if let Ok(quality) = options::Quality::try_from(&s[0..idx]) { + return Ok(RequestedVersion::Commit { + commit: s[idx + 1..].to_string(), + quality, + }); + } + } - Err(InvalidRequestedVersion()) - } + Err(InvalidRequestedVersion()) + } } #[derive(Serialize, Deserialize, Clone, Default)] struct Stored { - versions: Vec, - current: usize, + versions: Vec, + current: usize, } pub struct CodeVersionManager { - state: PersistedState, - platform: Platform, - storage_dir: PathBuf, + state: PersistedState, + platform: Platform, + storage_dir: PathBuf, } impl CodeVersionManager { - pub fn new(lp: &LauncherPaths, platform: Platform) -> Self { - CodeVersionManager { - state: PersistedState::new(lp.root().join("versions.json")), - storage_dir: lp.root().join("desktop"), - platform, - } - } + pub fn new(lp: &LauncherPaths, platform: Platform) -> Self { + CodeVersionManager { + state: PersistedState::new(lp.root().join("versions.json")), + storage_dir: lp.root().join("desktop"), + platform, + } + } - /// Sets the "version" as the persisted one for the user. - pub fn set_preferred_version(&self, version: &RequestedVersion) -> Result<(), AnyError> { - let mut stored = self.state.load(); - if let Some(i) = stored.versions.iter().position(|v| v == version) { - stored.current = i; - } else { - stored.current = stored.versions.len(); - stored.versions.push(version.clone()); - } + /// Sets the "version" as the persisted one for the user. + pub fn set_preferred_version(&self, version: &RequestedVersion) -> Result<(), AnyError> { + let mut stored = self.state.load(); + if let Some(i) = stored.versions.iter().position(|v| v == version) { + stored.current = i; + } else { + stored.current = stored.versions.len(); + stored.versions.push(version.clone()); + } - self.state.save(stored)?; + self.state.save(stored)?; - Ok(()) - } + Ok(()) + } - /// Lists installed versions. - pub fn list(&self) -> Vec { - self.state.load().versions - } + /// Lists installed versions. + pub fn list(&self) -> Vec { + self.state.load().versions + } - /// Uninstalls a previously installed version. - pub async fn uninstall(&self, version: &RequestedVersion) -> Result<(), AnyError> { - let mut stored = self.state.load(); - if let Some(i) = stored.versions.iter().position(|v| v == version) { - if i > stored.current && i > 0 { - stored.current -= 1; - } - stored.versions.remove(i); - self.state.save(stored)?; - } + /// Uninstalls a previously installed version. + pub async fn uninstall(&self, version: &RequestedVersion) -> Result<(), AnyError> { + let mut stored = self.state.load(); + if let Some(i) = stored.versions.iter().position(|v| v == version) { + if i > stored.current && i > 0 { + stored.current -= 1; + } + stored.versions.remove(i); + self.state.save(stored)?; + } - remove_dir_all(self.get_install_dir(version)) - .await - .map_err(|e| wrap(e, "error deleting vscode directory"))?; + remove_dir_all(self.get_install_dir(version)) + .await + .map_err(|e| wrap(e, "error deleting vscode directory"))?; - Ok(()) - } + Ok(()) + } - pub fn get_preferred_version(&self) -> RequestedVersion { - let stored = self.state.load(); - stored - .versions - .get(stored.current) - .unwrap_or(&RequestedVersion::Quality(options::Quality::Stable)) - .clone() - } + pub fn get_preferred_version(&self) -> RequestedVersion { + let stored = self.state.load(); + stored + .versions + .get(stored.current) + .unwrap_or(&RequestedVersion::Quality(options::Quality::Stable)) + .clone() + } - /// Installs the release for the given request. This always runs and does not - /// prompt, so you may want to use `try_get_entrypoint` first. - pub async fn install( - &self, - update_service: &UpdateService, - version: &RequestedVersion, - ) -> Result { - let target_dir = self.get_install_dir(version); - let release = get_release_for_request(update_service, version, self.platform).await?; - install_release_into(update_service, &target_dir, &release).await?; + /// Installs the release for the given request. This always runs and does not + /// prompt, so you may want to use `try_get_entrypoint` first. + pub async fn install( + &self, + update_service: &UpdateService, + version: &RequestedVersion, + ) -> Result { + let target_dir = self.get_install_dir(version); + let release = get_release_for_request(update_service, version, self.platform).await?; + install_release_into(update_service, &target_dir, &release).await?; - if let Some(p) = try_get_entrypoint(&target_dir).await { - return Ok(p); - } + if let Some(p) = try_get_entrypoint(&target_dir).await { + return Ok(p); + } - Err(MissingEntrypointError().into()) - } + Err(MissingEntrypointError().into()) + } - /// Tries to get the entrypoint in the installed version, if one exists. - pub async fn try_get_entrypoint(&self, version: &RequestedVersion) -> Option { - try_get_entrypoint(&self.get_install_dir(version)).await - } + /// Tries to get the entrypoint in the installed version, if one exists. + pub async fn try_get_entrypoint(&self, version: &RequestedVersion) -> Option { + try_get_entrypoint(&self.get_install_dir(version)).await + } - fn get_install_dir(&self, version: &RequestedVersion) -> PathBuf { - let (name, quality) = match version { - RequestedVersion::Path(path) => return PathBuf::from(path), - RequestedVersion::Quality(quality) => (quality.get_machine_name(), quality), - RequestedVersion::Version { - quality, - version: number, - } => (number.as_str(), quality), - RequestedVersion::Commit { commit, quality } => (commit.as_str(), quality), - }; + fn get_install_dir(&self, version: &RequestedVersion) -> PathBuf { + let (name, quality) = match version { + RequestedVersion::Path(path) => return PathBuf::from(path), + RequestedVersion::Quality(quality) => (quality.get_machine_name(), quality), + RequestedVersion::Version { + quality, + version: number, + } => (number.as_str(), quality), + RequestedVersion::Commit { commit, quality } => (commit.as_str(), quality), + }; - let mut dir = self.storage_dir.join(name); - if cfg!(target_os = "macos") { - dir.push(format!("{}.app", quality.get_app_name())) - } + let mut dir = self.storage_dir.join(name); + if cfg!(target_os = "macos") { + dir.push(format!("{}.app", quality.get_app_name())) + } - dir - } + dir + } } /// Shows a nice UI prompt to users asking them if they want to install the /// requested version. pub fn prompt_to_install(version: &RequestedVersion) -> Result<(), AnyError> { - if let RequestedVersion::Path(path) = version { - return Err(NoInstallInUserProvidedPath(path.clone()).into()); - } + if let RequestedVersion::Path(path) = version { + return Err(NoInstallInUserProvidedPath(path.clone()).into()); + } - if !prompt_yn(&format!( - "VS Code {} is not installed yet, install it now?", - version - ))? { - return Err(UserCancelledInstallation().into()); - } + if !prompt_yn(&format!( + "VS Code {} is not installed yet, install it now?", + version + ))? { + return Err(UserCancelledInstallation().into()); + } - Ok(()) + Ok(()) } async fn get_release_for_request( - update_service: &UpdateService, - request: &RequestedVersion, - platform: Platform, + update_service: &UpdateService, + request: &RequestedVersion, + platform: Platform, ) -> Result { - match request { - RequestedVersion::Version { - quality, - version: number, - } => update_service - .get_release_by_semver_version(platform, TargetKind::Archive, *quality, number) - .await - .map_err(|e| wrap(e, "Could not get release")), - RequestedVersion::Commit { commit, quality } => Ok(Release { - platform, - commit: commit.clone(), - quality: *quality, - target: TargetKind::Archive, - }), - RequestedVersion::Quality(quality) => update_service - .get_latest_commit(platform, TargetKind::Archive, *quality) - .await - .map_err(|e| wrap(e, "Could not get release")), - _ => panic!("cannot get release info for a path"), - } + match request { + RequestedVersion::Version { + quality, + version: number, + } => update_service + .get_release_by_semver_version(platform, TargetKind::Archive, *quality, number) + .await + .map_err(|e| wrap(e, "Could not get release")), + RequestedVersion::Commit { commit, quality } => Ok(Release { + platform, + commit: commit.clone(), + quality: *quality, + target: TargetKind::Archive, + }), + RequestedVersion::Quality(quality) => update_service + .get_latest_commit(platform, TargetKind::Archive, *quality) + .await + .map_err(|e| wrap(e, "Could not get release")), + _ => panic!("cannot get release info for a path"), + } } async fn install_release_into( - update_service: &UpdateService, - path: &Path, - release: &Release, + update_service: &UpdateService, + path: &Path, + release: &Release, ) -> Result<(), AnyError> { - let tempdir = - tempfile::tempdir().map_err(|e| wrap(e, "error creating temporary download dir"))?; - let save_path = tempdir.path().join("vscode"); + let tempdir = + tempfile::tempdir().map_err(|e| wrap(e, "error creating temporary download dir"))?; + let save_path = tempdir.path().join("vscode"); - let stream = update_service.get_download_stream(release).await?; - let pb = ProgressBar::new(1); - pb.set_message("Downloading..."); - let progress = ProgressBarReporter::from(pb); - http::download_into_file(&save_path, progress, stream).await?; + let stream = update_service.get_download_stream(release).await?; + let pb = ProgressBar::new(1); + pb.set_message("Downloading..."); + let progress = ProgressBarReporter::from(pb); + http::download_into_file(&save_path, progress, stream).await?; - let pb = ProgressBar::new(1); - pb.set_message("Unzipping..."); - let progress = ProgressBarReporter::from(pb); - unzip_downloaded_release(&save_path, path, progress)?; + let pb = ProgressBar::new(1); + pb.set_message("Unzipping..."); + let progress = ProgressBarReporter::from(pb); + unzip_downloaded_release(&save_path, path, progress)?; - drop(tempdir); + drop(tempdir); - Ok(()) + Ok(()) } /// Tries to find the binary entrypoint for VS Code installed in the path. async fn try_get_entrypoint(path: &Path) -> Option { - use tokio::sync::mpsc; + use tokio::sync::mpsc; - let (tx, mut rx) = mpsc::channel(1); + let (tx, mut rx) = mpsc::channel(1); - // Look for all the possible paths in parallel - for entry in DESKTOP_CLI_RELATIVE_PATH.split(',') { - let my_path = path.join(entry); - let my_tx = tx.clone(); - tokio::spawn(async move { - if tokio::fs::metadata(&my_path).await.is_ok() { - my_tx.send(my_path).await.ok(); - } - }); - } + // Look for all the possible paths in parallel + for entry in DESKTOP_CLI_RELATIVE_PATH.split(',') { + let my_path = path.join(entry); + let my_tx = tx.clone(); + tokio::spawn(async move { + if tokio::fs::metadata(&my_path).await.is_ok() { + my_tx.send(my_path).await.ok(); + } + }); + } - drop(tx); // drop so rx gets None if no sender emits + drop(tx); // drop so rx gets None if no sender emits - rx.recv().await + rx.recv().await } const DESKTOP_CLI_RELATIVE_PATH: &str = if cfg!(target_os = "macos") { - "Contents/Resources/app/bin/code" + "Contents/Resources/app/bin/code" } else if cfg!(target_os = "windows") { - "bin/code.cmd,bin/code-insiders.cmd,bin/code-exploration.cmd" + "bin/code.cmd,bin/code-insiders.cmd,bin/code-exploration.cmd" } else { - "bin/code,bin/code-insiders,bin/code-exploration" + "bin/code,bin/code-insiders,bin/code-exploration" }; #[cfg(test)] mod tests { - use std::{ - fs::{create_dir_all, File}, - io::Write, - }; + use std::{ + fs::{create_dir_all, File}, + io::Write, + }; - use super::*; + use super::*; - fn make_fake_vscode_install(path: &Path, quality: options::Quality) { - let bin = DESKTOP_CLI_RELATIVE_PATH - .split(',') - .next() - .expect("expected exe path"); + fn make_fake_vscode_install(path: &Path, quality: options::Quality) { + let bin = DESKTOP_CLI_RELATIVE_PATH + .split(',') + .next() + .expect("expected exe path"); - let binary_file_path = if cfg!(target_os = "macos") { - path.join(format!("{}.app/{}", quality.get_app_name(), bin)) - } else { - path.join(bin) - }; + let binary_file_path = if cfg!(target_os = "macos") { + path.join(format!("{}.app/{}", quality.get_app_name(), bin)) + } else { + path.join(bin) + }; - let parent_dir_path = binary_file_path.parent().expect("expected parent path"); + let parent_dir_path = binary_file_path.parent().expect("expected parent path"); - create_dir_all(parent_dir_path).expect("expected to create parent dir"); + create_dir_all(parent_dir_path).expect("expected to create parent dir"); - let mut binary_file = File::create(binary_file_path).expect("expected to make file"); - binary_file - .write_all(b"") - .expect("expected to write binary"); - } + let mut binary_file = File::create(binary_file_path).expect("expected to make file"); + binary_file + .write_all(b"") + .expect("expected to write binary"); + } - fn make_multiple_vscode_install() -> tempfile::TempDir { - let dir = tempfile::tempdir().expect("expected to make temp dir"); - make_fake_vscode_install(&dir.path().join("desktop/stable"), options::Quality::Stable); - make_fake_vscode_install(&dir.path().join("desktop/1.68.2"), options::Quality::Stable); - dir - } + fn make_multiple_vscode_install() -> tempfile::TempDir { + let dir = tempfile::tempdir().expect("expected to make temp dir"); + make_fake_vscode_install(&dir.path().join("desktop/stable"), options::Quality::Stable); + make_fake_vscode_install(&dir.path().join("desktop/1.68.2"), options::Quality::Stable); + dir + } - #[test] - fn test_requested_version_parses() { - assert_eq!( - RequestedVersion::try_from("1.2.3").unwrap(), - RequestedVersion::Version { - quality: options::Quality::Stable, - version: "1.2.3".to_string(), - } - ); + #[test] + fn test_requested_version_parses() { + assert_eq!( + RequestedVersion::try_from("1.2.3").unwrap(), + RequestedVersion::Version { + quality: options::Quality::Stable, + version: "1.2.3".to_string(), + } + ); - assert_eq!( - RequestedVersion::try_from("1.2.3-insider").unwrap(), - RequestedVersion::Version { - quality: options::Quality::Insiders, - version: "1.2.3-insider".to_string(), - } - ); + assert_eq!( + RequestedVersion::try_from("1.2.3-insider").unwrap(), + RequestedVersion::Version { + quality: options::Quality::Insiders, + version: "1.2.3-insider".to_string(), + } + ); - assert_eq!( - RequestedVersion::try_from("stable").unwrap(), - RequestedVersion::Quality(options::Quality::Stable) - ); + assert_eq!( + RequestedVersion::try_from("stable").unwrap(), + RequestedVersion::Quality(options::Quality::Stable) + ); - assert_eq!( - RequestedVersion::try_from("insiders").unwrap(), - RequestedVersion::Quality(options::Quality::Insiders) - ); + assert_eq!( + RequestedVersion::try_from("insiders").unwrap(), + RequestedVersion::Quality(options::Quality::Insiders) + ); - assert_eq!( - RequestedVersion::try_from("insiders/92fd228156aafeb326b23f6604028d342152313b") - .unwrap(), - RequestedVersion::Commit { - commit: "92fd228156aafeb326b23f6604028d342152313b".to_string(), - quality: options::Quality::Insiders - } - ); + assert_eq!( + RequestedVersion::try_from("insiders/92fd228156aafeb326b23f6604028d342152313b") + .unwrap(), + RequestedVersion::Commit { + commit: "92fd228156aafeb326b23f6604028d342152313b".to_string(), + quality: options::Quality::Insiders + } + ); - assert_eq!( - RequestedVersion::try_from("stable/92fd228156aafeb326b23f6604028d342152313b").unwrap(), - RequestedVersion::Commit { - commit: "92fd228156aafeb326b23f6604028d342152313b".to_string(), - quality: options::Quality::Stable - } - ); + assert_eq!( + RequestedVersion::try_from("stable/92fd228156aafeb326b23f6604028d342152313b").unwrap(), + RequestedVersion::Commit { + commit: "92fd228156aafeb326b23f6604028d342152313b".to_string(), + quality: options::Quality::Stable + } + ); - let exe = std::env::current_exe() - .expect("expected to get exe") - .to_string_lossy() - .to_string(); - assert_eq!( - RequestedVersion::try_from((&exe).as_str()).unwrap(), - RequestedVersion::Path(exe), - ); - } + let exe = std::env::current_exe() + .expect("expected to get exe") + .to_string_lossy() + .to_string(); + assert_eq!( + RequestedVersion::try_from((&exe).as_str()).unwrap(), + RequestedVersion::Path(exe), + ); + } - #[test] - fn test_set_preferred_version() { - let dir = make_multiple_vscode_install(); - let lp = LauncherPaths::new_without_replacements(dir.path().to_owned()); - let vm1 = CodeVersionManager::new(&lp, Platform::LinuxARM64); + #[test] + fn test_set_preferred_version() { + let dir = make_multiple_vscode_install(); + let lp = LauncherPaths::new_without_replacements(dir.path().to_owned()); + let vm1 = CodeVersionManager::new(&lp, Platform::LinuxARM64); - assert_eq!( - vm1.get_preferred_version(), - RequestedVersion::Quality(options::Quality::Stable) - ); - vm1.set_preferred_version(&RequestedVersion::Quality(options::Quality::Exploration)) - .expect("expected to store"); - vm1.set_preferred_version(&RequestedVersion::Quality(options::Quality::Insiders)) - .expect("expected to store"); - assert_eq!( - vm1.get_preferred_version(), - RequestedVersion::Quality(options::Quality::Insiders) - ); + assert_eq!( + vm1.get_preferred_version(), + RequestedVersion::Quality(options::Quality::Stable) + ); + vm1.set_preferred_version(&RequestedVersion::Quality(options::Quality::Exploration)) + .expect("expected to store"); + vm1.set_preferred_version(&RequestedVersion::Quality(options::Quality::Insiders)) + .expect("expected to store"); + assert_eq!( + vm1.get_preferred_version(), + RequestedVersion::Quality(options::Quality::Insiders) + ); - let vm2 = CodeVersionManager::new(&lp, Platform::LinuxARM64); - assert_eq!( - vm2.get_preferred_version(), - RequestedVersion::Quality(options::Quality::Insiders) - ); + let vm2 = CodeVersionManager::new(&lp, Platform::LinuxARM64); + assert_eq!( + vm2.get_preferred_version(), + RequestedVersion::Quality(options::Quality::Insiders) + ); - assert_eq!( - vm2.list(), - vec![ - RequestedVersion::Quality(options::Quality::Exploration), - RequestedVersion::Quality(options::Quality::Insiders) - ] - ); - } + assert_eq!( + vm2.list(), + vec![ + RequestedVersion::Quality(options::Quality::Exploration), + RequestedVersion::Quality(options::Quality::Insiders) + ] + ); + } - #[tokio::test] - async fn test_gets_entrypoint() { - let dir = make_multiple_vscode_install(); - let lp = LauncherPaths::new_without_replacements(dir.path().to_owned()); - let vm = CodeVersionManager::new(&lp, Platform::LinuxARM64); + #[tokio::test] + async fn test_gets_entrypoint() { + let dir = make_multiple_vscode_install(); + let lp = LauncherPaths::new_without_replacements(dir.path().to_owned()); + let vm = CodeVersionManager::new(&lp, Platform::LinuxARM64); - assert!(vm - .try_get_entrypoint(&RequestedVersion::Quality(options::Quality::Stable)) - .await - .is_some()); + assert!(vm + .try_get_entrypoint(&RequestedVersion::Quality(options::Quality::Stable)) + .await + .is_some()); - assert!(vm - .try_get_entrypoint(&RequestedVersion::Quality(options::Quality::Exploration)) - .await - .is_none()); - } + assert!(vm + .try_get_entrypoint(&RequestedVersion::Quality(options::Quality::Exploration)) + .await + .is_none()); + } - #[tokio::test] - async fn test_uninstall() { - let dir = make_multiple_vscode_install(); - let lp = LauncherPaths::new_without_replacements(dir.path().to_owned()); - let vm = CodeVersionManager::new(&lp, Platform::LinuxARM64); + #[tokio::test] + async fn test_uninstall() { + let dir = make_multiple_vscode_install(); + let lp = LauncherPaths::new_without_replacements(dir.path().to_owned()); + let vm = CodeVersionManager::new(&lp, Platform::LinuxARM64); - vm.uninstall(&RequestedVersion::Quality(options::Quality::Stable)) - .await - .expect("expected to uninsetall"); + vm.uninstall(&RequestedVersion::Quality(options::Quality::Stable)) + .await + .expect("expected to uninsetall"); - assert!(vm - .try_get_entrypoint(&RequestedVersion::Quality(options::Quality::Stable)) - .await - .is_none()); - } + assert!(vm + .try_get_entrypoint(&RequestedVersion::Quality(options::Quality::Stable)) + .await + .is_none()); + } } diff --git a/cli/src/lib.rs b/cli/src/lib.rs index a4b11bece9e..98b8c4f9755 100644 --- a/cli/src/lib.rs +++ b/cli/src/lib.rs @@ -12,8 +12,8 @@ pub mod log; pub mod commands; pub mod desktop; pub mod options; -pub mod tunnels; pub mod state; +pub mod tunnels; pub mod update; pub mod update_service; pub mod util; diff --git a/cli/src/log.rs b/cli/src/log.rs index 64118e426ed..d7a250b45ac 100644 --- a/cli/src/log.rs +++ b/cli/src/log.rs @@ -5,14 +5,14 @@ use chrono::Local; use opentelemetry::{ - sdk::trace::Tracer, - trace::{SpanBuilder, Tracer as TraitTracer}, + sdk::trace::Tracer, + trace::{SpanBuilder, Tracer as TraitTracer}, }; use std::fmt; use std::{env, path::Path, sync::Arc}; use std::{ - io::Write, - sync::atomic::{AtomicU32, Ordering}, + io::Write, + sync::atomic::{AtomicU32, Ordering}, }; const NO_COLOR_ENV: &str = "NO_COLOR"; @@ -21,282 +21,282 @@ static INSTANCE_COUNTER: AtomicU32 = AtomicU32::new(0); // Gets a next incrementing number that can be used in logs pub fn next_counter() -> u32 { - INSTANCE_COUNTER.fetch_add(1, Ordering::SeqCst) + INSTANCE_COUNTER.fetch_add(1, Ordering::SeqCst) } // Log level #[derive(clap::ArgEnum, PartialEq, Eq, PartialOrd, Clone, Copy, Debug)] pub enum Level { - Trace = 0, - Debug, - Info, - Warn, - Error, - Critical, - Off, + Trace = 0, + Debug, + Info, + Warn, + Error, + Critical, + Off, } impl Default for Level { - fn default() -> Self { - Level::Info - } + fn default() -> Self { + Level::Info + } } impl fmt::Display for Level { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Level::Critical => write!(f, "critical"), - Level::Debug => write!(f, "debug"), - Level::Error => write!(f, "error"), - Level::Info => write!(f, "info"), - Level::Off => write!(f, "off"), - Level::Trace => write!(f, "trace"), - Level::Warn => write!(f, "warn"), - } - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Level::Critical => write!(f, "critical"), + Level::Debug => write!(f, "debug"), + Level::Error => write!(f, "error"), + Level::Info => write!(f, "info"), + Level::Off => write!(f, "off"), + Level::Trace => write!(f, "trace"), + Level::Warn => write!(f, "warn"), + } + } } impl Level { - pub fn name(&self) -> Option<&str> { - match self { - Level::Trace => Some("trace"), - Level::Debug => Some("debug"), - Level::Info => Some("info"), - Level::Warn => Some("warn"), - Level::Error => Some("error"), - Level::Critical => Some("critical"), - Level::Off => None, - } - } + pub fn name(&self) -> Option<&str> { + match self { + Level::Trace => Some("trace"), + Level::Debug => Some("debug"), + Level::Info => Some("info"), + Level::Warn => Some("warn"), + Level::Error => Some("error"), + Level::Critical => Some("critical"), + Level::Off => None, + } + } - pub fn color_code(&self) -> Option<&str> { - if env::var(NO_COLOR_ENV).is_ok() || !atty::is(atty::Stream::Stdout) { - return None; - } + pub fn color_code(&self) -> Option<&str> { + if env::var(NO_COLOR_ENV).is_ok() || !atty::is(atty::Stream::Stdout) { + return None; + } - match self { - Level::Trace => None, - Level::Debug => Some("\x1b[36m"), - Level::Info => Some("\x1b[35m"), - Level::Warn => Some("\x1b[33m"), - Level::Error => Some("\x1b[31m"), - Level::Critical => Some("\x1b[31m"), - Level::Off => None, - } - } + match self { + Level::Trace => None, + Level::Debug => Some("\x1b[36m"), + Level::Info => Some("\x1b[35m"), + Level::Warn => Some("\x1b[33m"), + Level::Error => Some("\x1b[31m"), + Level::Critical => Some("\x1b[31m"), + Level::Off => None, + } + } - pub fn to_u8(self) -> u8 { - self as u8 - } + pub fn to_u8(self) -> u8 { + self as u8 + } } pub fn new_tunnel_prefix() -> String { - format!("[tunnel.{}]", next_counter()) + format!("[tunnel.{}]", next_counter()) } pub fn new_code_server_prefix() -> String { - format!("[codeserver.{}]", next_counter()) + format!("[codeserver.{}]", next_counter()) } pub fn new_rpc_prefix() -> String { - format!("[rpc.{}]", next_counter()) + format!("[rpc.{}]", next_counter()) } // Base logger implementation #[derive(Clone)] pub struct Logger { - tracer: Tracer, - sink: Vec>, - prefix: Option, + tracer: Tracer, + sink: Vec>, + prefix: Option, } // Copy trick from https://stackoverflow.com/a/30353928 pub trait LogSinkClone { - fn clone_box(&self) -> Box; + fn clone_box(&self) -> Box; } impl LogSinkClone for T where - T: 'static + LogSink + Clone, + T: 'static + LogSink + Clone, { - fn clone_box(&self) -> Box { - Box::new(self.clone()) - } + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } } pub trait LogSink: LogSinkClone + Sync + Send { - fn write_log(&self, level: Level, prefix: &str, message: &str); - fn write_result(&self, message: &str); + fn write_log(&self, level: Level, prefix: &str, message: &str); + fn write_result(&self, message: &str); } impl Clone for Box { - fn clone(&self) -> Box { - self.clone_box() - } + fn clone(&self) -> Box { + self.clone_box() + } } #[derive(Clone)] pub struct StdioLogSink { - level: Level, + level: Level, } impl LogSink for StdioLogSink { - fn write_log(&self, level: Level, prefix: &str, message: &str) { - if level < self.level { - return; - } + fn write_log(&self, level: Level, prefix: &str, message: &str) { + if level < self.level { + return; + } - emit(level, prefix, message); - } + emit(level, prefix, message); + } - fn write_result(&self, message: &str) { - println!("{}", message); - } + fn write_result(&self, message: &str) { + println!("{}", message); + } } #[derive(Clone)] pub struct FileLogSink { - level: Level, - file: Arc>, + level: Level, + file: Arc>, } impl FileLogSink { - pub fn new(level: Level, path: &Path) -> std::io::Result { - let file = std::fs::File::create(path)?; - Ok(Self { - level, - file: Arc::new(std::sync::Mutex::new(file)), - }) - } + pub fn new(level: Level, path: &Path) -> std::io::Result { + let file = std::fs::File::create(path)?; + Ok(Self { + level, + file: Arc::new(std::sync::Mutex::new(file)), + }) + } } impl LogSink for FileLogSink { - fn write_log(&self, level: Level, prefix: &str, message: &str) { - if level < self.level { - return; - } + fn write_log(&self, level: Level, prefix: &str, message: &str) { + if level < self.level { + return; + } - let line = format(level, prefix, message); + let line = format(level, prefix, message); - // ignore any errors, not much we can do if logging fails... - self.file.lock().unwrap().write_all(line.as_bytes()).ok(); - } + // ignore any errors, not much we can do if logging fails... + self.file.lock().unwrap().write_all(line.as_bytes()).ok(); + } - fn write_result(&self, _message: &str) {} + fn write_result(&self, _message: &str) {} } impl Logger { - pub fn new(tracer: Tracer, level: Level) -> Self { - Self { - tracer, - sink: vec![Box::new(StdioLogSink { level })], - prefix: None, - } - } + pub fn new(tracer: Tracer, level: Level) -> Self { + Self { + tracer, + sink: vec![Box::new(StdioLogSink { level })], + prefix: None, + } + } - pub fn span(&self, name: &str) -> SpanBuilder { - self.tracer.span_builder(format!("serverlauncher/{}", name)) - } + pub fn span(&self, name: &str) -> SpanBuilder { + self.tracer.span_builder(format!("serverlauncher/{}", name)) + } - pub fn tracer(&self) -> &Tracer { - &self.tracer - } + pub fn tracer(&self) -> &Tracer { + &self.tracer + } - pub fn emit(&self, level: Level, message: &str) { - let prefix = self.prefix.as_deref().unwrap_or(""); - for sink in &self.sink { - sink.write_log(level, prefix, message); - } - } + pub fn emit(&self, level: Level, message: &str) { + let prefix = self.prefix.as_deref().unwrap_or(""); + for sink in &self.sink { + sink.write_log(level, prefix, message); + } + } - pub fn result(&self, message: &str) { - for sink in &self.sink { - sink.write_result(message); - } - } + pub fn result(&self, message: &str) { + for sink in &self.sink { + sink.write_result(message); + } + } - pub fn prefixed(&self, prefix: &str) -> Logger { - Logger { - prefix: Some(match &self.prefix { - Some(p) => format!("{}{} ", p, prefix), - None => format!("{} ", prefix), - }), - ..self.clone() - } - } + pub fn prefixed(&self, prefix: &str) -> Logger { + Logger { + prefix: Some(match &self.prefix { + Some(p) => format!("{}{} ", p, prefix), + None => format!("{} ", prefix), + }), + ..self.clone() + } + } - /// Creates a new logger with the additional log sink added. - pub fn tee(&self, sink: T) -> Logger - where - T: LogSink + 'static, - { - let mut new_sinks = self.sink.clone(); - new_sinks.push(Box::new(sink)); + /// Creates a new logger with the additional log sink added. + pub fn tee(&self, sink: T) -> Logger + where + T: LogSink + 'static, + { + let mut new_sinks = self.sink.clone(); + new_sinks.push(Box::new(sink)); - Logger { - sink: new_sinks, - ..self.clone() - } - } + Logger { + sink: new_sinks, + ..self.clone() + } + } - pub fn get_download_logger<'a>(&'a self, prefix: &'static str) -> DownloadLogger<'a> { - DownloadLogger { - prefix, - logger: self, - } - } + pub fn get_download_logger<'a>(&'a self, prefix: &'static str) -> DownloadLogger<'a> { + DownloadLogger { + prefix, + logger: self, + } + } } pub struct DownloadLogger<'a> { - prefix: &'static str, - logger: &'a Logger, + prefix: &'static str, + logger: &'a Logger, } impl<'a> crate::util::io::ReportCopyProgress for DownloadLogger<'a> { - fn report_progress(&mut self, bytes_so_far: u64, total_bytes: u64) { - if total_bytes > 0 { - self.logger.emit( - Level::Trace, - &format!( - "{} {}/{} ({:.0}%)", - self.prefix, - bytes_so_far, - total_bytes, - (bytes_so_far as f64 / total_bytes as f64) * 100.0, - ), - ); - } else { - self.logger.emit( - Level::Trace, - &format!("{} {}/{}", self.prefix, bytes_so_far, total_bytes,), - ); - } - } + fn report_progress(&mut self, bytes_so_far: u64, total_bytes: u64) { + if total_bytes > 0 { + self.logger.emit( + Level::Trace, + &format!( + "{} {}/{} ({:.0}%)", + self.prefix, + bytes_so_far, + total_bytes, + (bytes_so_far as f64 / total_bytes as f64) * 100.0, + ), + ); + } else { + self.logger.emit( + Level::Trace, + &format!("{} {}/{}", self.prefix, bytes_so_far, total_bytes,), + ); + } + } } pub fn format(level: Level, prefix: &str, message: &str) -> String { - let current = Local::now(); - let timestamp = current.format("%Y-%m-%d %H:%M:%S").to_string(); + let current = Local::now(); + let timestamp = current.format("%Y-%m-%d %H:%M:%S").to_string(); - let name = level.name().unwrap(); + let name = level.name().unwrap(); - if let Some(c) = level.color_code() { - format!( - "\x1b[2m[{}]\x1b[0m {}{}\x1b[0m {}{}\n", - timestamp, c, name, prefix, message - ) - } else { - format!("[{}] {} {}{}\n", timestamp, name, prefix, message) - } + if let Some(c) = level.color_code() { + format!( + "\x1b[2m[{}]\x1b[0m {}{}\x1b[0m {}{}\n", + timestamp, c, name, prefix, message + ) + } else { + format!("[{}] {} {}{}\n", timestamp, name, prefix, message) + } } pub fn emit(level: Level, prefix: &str, message: &str) { - let line = format(level, prefix, message); - if level == Level::Trace { - print!("\x1b[2m{}\x1b[0m", line); - } else { - print!("{}", line); - } + let line = format(level, prefix, message); + if level == Level::Trace { + print!("\x1b[2m{}\x1b[0m", line); + } else { + print!("{}", line); + } } #[macro_export] @@ -351,39 +351,39 @@ macro_rules! warning { #[macro_export] macro_rules! span { - ($logger:expr, $span:expr, $func:expr) => {{ - use opentelemetry::trace::TraceContextExt; + ($logger:expr, $span:expr, $func:expr) => {{ + use opentelemetry::trace::TraceContextExt; - let span = $span.start($logger.tracer()); - let cx = opentelemetry::Context::current_with_span(span); - let guard = cx.clone().attach(); - let t = $func; + let span = $span.start($logger.tracer()); + let cx = opentelemetry::Context::current_with_span(span); + let guard = cx.clone().attach(); + let t = $func; - if let Err(e) = &t { - cx.span().record_error(e); - } + if let Err(e) = &t { + cx.span().record_error(e); + } - std::mem::drop(guard); + std::mem::drop(guard); - t - }}; + t + }}; } #[macro_export] macro_rules! spanf { - ($logger:expr, $span:expr, $func:expr) => {{ - use opentelemetry::trace::{FutureExt, TraceContextExt}; + ($logger:expr, $span:expr, $func:expr) => {{ + use opentelemetry::trace::{FutureExt, TraceContextExt}; - let span = $span.start($logger.tracer()); - let cx = opentelemetry::Context::current_with_span(span); - let t = $func.with_context(cx.clone()).await; + let span = $span.start($logger.tracer()); + let cx = opentelemetry::Context::current_with_span(span); + let t = $func.with_context(cx.clone()).await; - if let Err(e) = &t { - cx.span().record_error(e); - } + if let Err(e) = &t { + cx.span().record_error(e); + } - cx.span().end(); + cx.span().end(); - t - }}; + t + }}; } diff --git a/cli/src/options.rs b/cli/src/options.rs index 7da2636e582..81009f72029 100644 --- a/cli/src/options.rs +++ b/cli/src/options.rs @@ -9,96 +9,96 @@ use serde::{Deserialize, Serialize}; #[derive(clap::ArgEnum, Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub enum Quality { - #[serde(rename = "stable")] - Stable, - #[serde(rename = "exploration")] - Exploration, - #[serde(other)] - Insiders, + #[serde(rename = "stable")] + Stable, + #[serde(rename = "exploration")] + Exploration, + #[serde(other)] + Insiders, } impl Quality { - /// Lowercased name in paths and protocol - pub fn get_machine_name(&self) -> &'static str { - match self { - Quality::Insiders => "insiders", - Quality::Exploration => "exploration", - Quality::Stable => "stable", - } - } + /// Lowercased name in paths and protocol + pub fn get_machine_name(&self) -> &'static str { + match self { + Quality::Insiders => "insiders", + Quality::Exploration => "exploration", + Quality::Stable => "stable", + } + } - /// Uppercased display name for humans - pub fn get_capitalized_name(&self) -> &'static str { - match self { - Quality::Insiders => "Insiders", - Quality::Exploration => "Exploration", - Quality::Stable => "Stable", - } - } + /// Uppercased display name for humans + pub fn get_capitalized_name(&self) -> &'static str { + match self { + Quality::Insiders => "Insiders", + Quality::Exploration => "Exploration", + Quality::Stable => "Stable", + } + } - pub fn get_app_name(&self) -> &'static str { - match self { - Quality::Insiders => "Visual Studio Code Insiders", - Quality::Exploration => "Visual Studio Code Exploration", - Quality::Stable => "Visual Studio Code", - } - } + pub fn get_app_name(&self) -> &'static str { + match self { + Quality::Insiders => "Visual Studio Code Insiders", + Quality::Exploration => "Visual Studio Code Exploration", + Quality::Stable => "Visual Studio Code", + } + } - #[cfg(target_os = "windows")] - pub fn server_entrypoint(&self) -> &'static str { - match self { - Quality::Insiders => "code-server-insiders.cmd", - Quality::Exploration => "code-server-exploration.cmd", - Quality::Stable => "code-server.cmd", - } - } - #[cfg(not(target_os = "windows"))] - pub fn server_entrypoint(&self) -> &'static str { - match self { - Quality::Insiders => "code-server-insiders", - Quality::Exploration => "code-server-exploration", - Quality::Stable => "code-server", - } - } + #[cfg(target_os = "windows")] + pub fn server_entrypoint(&self) -> &'static str { + match self { + Quality::Insiders => "code-server-insiders.cmd", + Quality::Exploration => "code-server-exploration.cmd", + Quality::Stable => "code-server.cmd", + } + } + #[cfg(not(target_os = "windows"))] + pub fn server_entrypoint(&self) -> &'static str { + match self { + Quality::Insiders => "code-server-insiders", + Quality::Exploration => "code-server-exploration", + Quality::Stable => "code-server", + } + } } impl fmt::Display for Quality { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.get_capitalized_name()) - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.get_capitalized_name()) + } } impl TryFrom<&str> for Quality { - type Error = String; + type Error = String; - fn try_from(s: &str) -> Result { - match s { - "stable" => Ok(Quality::Stable), - "insiders" => Ok(Quality::Insiders), - "exploration" => Ok(Quality::Exploration), - _ => Err(format!( - "Unknown quality: {}. Must be one of stable, insiders, or exploration.", - s - )), - } - } + fn try_from(s: &str) -> Result { + match s { + "stable" => Ok(Quality::Stable), + "insiders" => Ok(Quality::Insiders), + "exploration" => Ok(Quality::Exploration), + _ => Err(format!( + "Unknown quality: {}. Must be one of stable, insiders, or exploration.", + s + )), + } + } } #[derive(clap::ArgEnum, Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub enum TelemetryLevel { - Off, - Crash, - Error, - All, + Off, + Crash, + Error, + All, } impl fmt::Display for TelemetryLevel { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - TelemetryLevel::Off => write!(f, "off"), - TelemetryLevel::Crash => write!(f, "crash"), - TelemetryLevel::Error => write!(f, "error"), - TelemetryLevel::All => write!(f, "all"), - } - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + TelemetryLevel::Off => write!(f, "off"), + TelemetryLevel::Crash => write!(f, "crash"), + TelemetryLevel::Error => write!(f, "error"), + TelemetryLevel::All => write!(f, "all"), + } + } } diff --git a/cli/src/state.rs b/cli/src/state.rs index e3f0510c9d2..1a33ea755f8 100644 --- a/cli/src/state.rs +++ b/cli/src/state.rs @@ -6,9 +6,9 @@ extern crate dirs; use std::{ - fs::{create_dir, read_to_string, remove_dir_all, write}, - path::{Path, PathBuf}, - sync::{Arc, Mutex}, + fs::{create_dir, read_to_string, remove_dir_all, write}, + path::{Path, PathBuf}, + sync::{Arc, Mutex}, }; use serde::{de::DeserializeOwned, Serialize}; @@ -19,134 +19,134 @@ const HOME_DIR_ALTS: [&str; 2] = ["$HOME", "~"]; #[derive(Clone)] pub struct LauncherPaths { - root: PathBuf, + root: PathBuf, } struct PersistedStateContainer where - T: Clone + Serialize + DeserializeOwned + Default, + T: Clone + Serialize + DeserializeOwned + Default, { - path: PathBuf, - state: Option, + path: PathBuf, + state: Option, } impl PersistedStateContainer where - T: Clone + Serialize + DeserializeOwned + Default, + T: Clone + Serialize + DeserializeOwned + Default, { - fn load_or_get(&mut self) -> T { - if let Some(state) = &self.state { - return state.clone(); - } + fn load_or_get(&mut self) -> T { + if let Some(state) = &self.state { + return state.clone(); + } - let state = if let Ok(s) = read_to_string(&self.path) { - serde_json::from_str::(&s).unwrap_or_default() - } else { - T::default() - }; + let state = if let Ok(s) = read_to_string(&self.path) { + serde_json::from_str::(&s).unwrap_or_default() + } else { + T::default() + }; - self.state = Some(state.clone()); - state - } + self.state = Some(state.clone()); + state + } - fn save(&mut self, state: T) -> Result<(), WrappedError> { - let s = serde_json::to_string(&state).unwrap(); - self.state = Some(state); - write(&self.path, s).map_err(|e| { - wrap( - e, - format!("error saving launcher state into {}", self.path.display()), - ) - }) - } + fn save(&mut self, state: T) -> Result<(), WrappedError> { + let s = serde_json::to_string(&state).unwrap(); + self.state = Some(state); + write(&self.path, s).map_err(|e| { + wrap( + e, + format!("error saving launcher state into {}", self.path.display()), + ) + }) + } } /// Container that holds some state value that is persisted to disk. #[derive(Clone)] pub struct PersistedState where - T: Clone + Serialize + DeserializeOwned + Default, + T: Clone + Serialize + DeserializeOwned + Default, { - container: Arc>>, + container: Arc>>, } impl PersistedState where - T: Clone + Serialize + DeserializeOwned + Default, + T: Clone + Serialize + DeserializeOwned + Default, { - /// Creates a new state container that persists to the given path. - pub fn new(path: PathBuf) -> PersistedState { - PersistedState { - container: Arc::new(Mutex::new(PersistedStateContainer { path, state: None })), - } - } + /// Creates a new state container that persists to the given path. + pub fn new(path: PathBuf) -> PersistedState { + PersistedState { + container: Arc::new(Mutex::new(PersistedStateContainer { path, state: None })), + } + } - /// Loads persisted state. - pub fn load(&self) -> T { - self.container.lock().unwrap().load_or_get() - } + /// Loads persisted state. + pub fn load(&self) -> T { + self.container.lock().unwrap().load_or_get() + } - /// Saves persisted state. - pub fn save(&self, state: T) -> Result<(), WrappedError> { - self.container.lock().unwrap().save(state) - } + /// Saves persisted state. + pub fn save(&self, state: T) -> Result<(), WrappedError> { + self.container.lock().unwrap().save(state) + } - /// Mutates persisted state. - pub fn update_with( - &self, - v: V, - mutator: fn(v: V, state: &mut T) -> R, - ) -> Result { - let mut container = self.container.lock().unwrap(); - let mut state = container.load_or_get(); - let r = mutator(v, &mut state); - container.save(state).map(|_| r) - } + /// Mutates persisted state. + pub fn update_with( + &self, + v: V, + mutator: fn(v: V, state: &mut T) -> R, + ) -> Result { + let mut container = self.container.lock().unwrap(); + let mut state = container.load_or_get(); + let r = mutator(v, &mut state); + container.save(state).map(|_| r) + } } impl LauncherPaths { - pub fn new(root: &Option) -> Result { - let root = root.as_deref().unwrap_or("~/.vscode-cli"); - let mut replaced = root.to_owned(); - for token in HOME_DIR_ALTS { - if root.contains(token) { - if let Some(home) = dirs::home_dir() { - replaced = root.replace(token, &home.to_string_lossy()) - } else { - return Err(AnyError::from(NoHomeForLauncherError())); - } - } - } + pub fn new(root: &Option) -> Result { + let root = root.as_deref().unwrap_or("~/.vscode-cli"); + let mut replaced = root.to_owned(); + for token in HOME_DIR_ALTS { + if root.contains(token) { + if let Some(home) = dirs::home_dir() { + replaced = root.replace(token, &home.to_string_lossy()) + } else { + return Err(AnyError::from(NoHomeForLauncherError())); + } + } + } - if !Path::new(&replaced).exists() { - create_dir(&replaced) - .map_err(|e| wrap(e, format!("error creating directory {}", &replaced)))?; - } + if !Path::new(&replaced).exists() { + create_dir(&replaced) + .map_err(|e| wrap(e, format!("error creating directory {}", &replaced)))?; + } - Ok(LauncherPaths::new_without_replacements(PathBuf::from( - replaced, - ))) - } + Ok(LauncherPaths::new_without_replacements(PathBuf::from( + replaced, + ))) + } - pub fn new_without_replacements(root: PathBuf) -> LauncherPaths { - LauncherPaths { root } - } + pub fn new_without_replacements(root: PathBuf) -> LauncherPaths { + LauncherPaths { root } + } - /// Root directory for the server launcher - pub fn root(&self) -> &Path { - &self.root - } + /// Root directory for the server launcher + pub fn root(&self) -> &Path { + &self.root + } - /// Removes the launcher data directory. - pub fn remove(&self) -> Result<(), WrappedError> { - remove_dir_all(&self.root).map_err(|e| { - wrap( - e, - format!( - "error removing launcher data directory {}", - self.root.display() - ), - ) - }) - } + /// Removes the launcher data directory. + pub fn remove(&self) -> Result<(), WrappedError> { + remove_dir_all(&self.root).map_err(|e| { + wrap( + e, + format!( + "error removing launcher data directory {}", + self.root.display() + ), + ) + }) + } } diff --git a/cli/src/tunnels.rs b/cli/src/tunnels.rs index f2e93da6a25..d94e47addf3 100644 --- a/cli/src/tunnels.rs +++ b/cli/src/tunnels.rs @@ -21,5 +21,5 @@ mod service_windows; pub use control_server::serve; pub use service::{ - create_service_manager, ServiceContainer, ServiceManager, SERVICE_LOG_FILE_NAME, + create_service_manager, ServiceContainer, ServiceManager, SERVICE_LOG_FILE_NAME, }; diff --git a/cli/src/tunnels/code_server.rs b/cli/src/tunnels/code_server.rs index 98e008dd11d..989b6b5f40d 100644 --- a/cli/src/tunnels/code_server.rs +++ b/cli/src/tunnels/code_server.rs @@ -6,11 +6,11 @@ use super::paths::{InstalledServer, LastUsedServers, ServerPaths}; use crate::options::{Quality, TelemetryLevel}; use crate::state::LauncherPaths; use crate::update_service::{ - unzip_downloaded_release, Platform, Release, TargetKind, UpdateService, + unzip_downloaded_release, Platform, Release, TargetKind, UpdateService, }; use crate::util::command::{capture_command, kill_tree}; use crate::util::errors::{ - wrap, AnyError, ExtensionInstallFailed, MissingEntrypointError, WrappedError, + wrap, AnyError, ExtensionInstallFailed, MissingEntrypointError, WrappedError, }; use crate::util::http; use crate::util::io::SilentCopyProgress; @@ -33,198 +33,198 @@ use tokio::time::{interval, timeout}; use uuid::Uuid; lazy_static! { - static ref LISTENING_PORT_RE: Regex = - Regex::new(r"Extension host agent listening on (.+)").unwrap(); - static ref WEB_UI_RE: Regex = Regex::new(r"Web UI available at (.+)").unwrap(); + static ref LISTENING_PORT_RE: Regex = + Regex::new(r"Extension host agent listening on (.+)").unwrap(); + static ref WEB_UI_RE: Regex = Regex::new(r"Web UI available at (.+)").unwrap(); } const MAX_RETAINED_SERVERS: usize = 5; #[derive(Clone, Debug, Default)] pub struct CodeServerArgs { - pub host: Option, - pub port: Option, - pub socket_path: Option, + pub host: Option, + pub port: Option, + pub socket_path: Option, - // common argument - pub telemetry_level: Option, - pub log: Option, - pub accept_server_license_terms: bool, - pub verbose: bool, - // extension management - pub install_extensions: Vec, - pub uninstall_extensions: Vec, - pub list_extensions: bool, - pub show_versions: bool, - pub category: Option, - pub pre_release: bool, - pub force: bool, - pub start_server: bool, - // connection tokens - pub connection_token: Option, - pub connection_token_file: Option, - pub without_connection_token: bool, + // common argument + pub telemetry_level: Option, + pub log: Option, + pub accept_server_license_terms: bool, + pub verbose: bool, + // extension management + pub install_extensions: Vec, + pub uninstall_extensions: Vec, + pub list_extensions: bool, + pub show_versions: bool, + pub category: Option, + pub pre_release: bool, + pub force: bool, + pub start_server: bool, + // connection tokens + pub connection_token: Option, + pub connection_token_file: Option, + pub without_connection_token: bool, } impl CodeServerArgs { - pub fn log_level(&self) -> log::Level { - if self.verbose { - log::Level::Trace - } else { - self.log.unwrap_or(log::Level::Info) - } - } + pub fn log_level(&self) -> log::Level { + if self.verbose { + log::Level::Trace + } else { + self.log.unwrap_or(log::Level::Info) + } + } - pub fn telemetry_disabled(&self) -> bool { - self.telemetry_level == Some(TelemetryLevel::Off) - } + pub fn telemetry_disabled(&self) -> bool { + self.telemetry_level == Some(TelemetryLevel::Off) + } - pub fn command_arguments(&self) -> Vec { - let mut args = Vec::new(); - if let Some(i) = &self.socket_path { - args.push(format!("--socket-path={}", i)); - } else { - if let Some(i) = &self.host { - args.push(format!("--host={}", i)); - } - if let Some(i) = &self.port { - args.push(format!("--port={}", i)); - } - } + pub fn command_arguments(&self) -> Vec { + let mut args = Vec::new(); + if let Some(i) = &self.socket_path { + args.push(format!("--socket-path={}", i)); + } else { + if let Some(i) = &self.host { + args.push(format!("--host={}", i)); + } + if let Some(i) = &self.port { + args.push(format!("--port={}", i)); + } + } - if let Some(i) = &self.connection_token { - args.push(format!("--connection-token={}", i)); - } - if let Some(i) = &self.connection_token_file { - args.push(format!("--connection-token-file={}", i)); - } - if self.without_connection_token { - args.push(String::from("--without-connection-token")); - } - if self.accept_server_license_terms { - args.push(String::from("--accept-server-license-terms")); - } - if let Some(i) = self.telemetry_level { - args.push(format!("--telemetry-level={}", i)); - } - if let Some(i) = self.log { - args.push(format!("--log={}", i)); - } + if let Some(i) = &self.connection_token { + args.push(format!("--connection-token={}", i)); + } + if let Some(i) = &self.connection_token_file { + args.push(format!("--connection-token-file={}", i)); + } + if self.without_connection_token { + args.push(String::from("--without-connection-token")); + } + if self.accept_server_license_terms { + args.push(String::from("--accept-server-license-terms")); + } + if let Some(i) = self.telemetry_level { + args.push(format!("--telemetry-level={}", i)); + } + if let Some(i) = self.log { + args.push(format!("--log={}", i)); + } - for extension in &self.install_extensions { - args.push(format!("--install-extension={}", extension)); - } - if !&self.install_extensions.is_empty() { - if self.pre_release { - args.push(String::from("--pre-release")); - } - if self.force { - args.push(String::from("--force")); - } - } - for extension in &self.uninstall_extensions { - args.push(format!("--uninstall-extension={}", extension)); - } - if self.list_extensions { - args.push(String::from("--list-extensions")); - if self.show_versions { - args.push(String::from("--show-versions")); - } - if let Some(i) = &self.category { - args.push(format!("--category={}", i)); - } - } - if self.start_server { - args.push(String::from("--start-server")); - } - args - } + for extension in &self.install_extensions { + args.push(format!("--install-extension={}", extension)); + } + if !&self.install_extensions.is_empty() { + if self.pre_release { + args.push(String::from("--pre-release")); + } + if self.force { + args.push(String::from("--force")); + } + } + for extension in &self.uninstall_extensions { + args.push(format!("--uninstall-extension={}", extension)); + } + if self.list_extensions { + args.push(String::from("--list-extensions")); + if self.show_versions { + args.push(String::from("--show-versions")); + } + if let Some(i) = &self.category { + args.push(format!("--category={}", i)); + } + } + if self.start_server { + args.push(String::from("--start-server")); + } + args + } } /// Base server params that can be `resolve()`d to a `ResolvedServerParams`. /// Doing so fetches additional information like a commit ID if previously /// unspecified. pub struct ServerParamsRaw { - pub commit_id: Option, - pub quality: Quality, - pub code_server_args: CodeServerArgs, - pub headless: bool, - pub platform: Platform, + pub commit_id: Option, + pub quality: Quality, + pub code_server_args: CodeServerArgs, + pub headless: bool, + pub platform: Platform, } /// Server params that can be used to start a VS Code server. pub struct ResolvedServerParams { - pub release: Release, - pub code_server_args: CodeServerArgs, + pub release: Release, + pub code_server_args: CodeServerArgs, } impl ResolvedServerParams { - fn as_installed_server(&self) -> InstalledServer { - InstalledServer { - commit: self.release.commit.clone(), - quality: self.release.quality, - headless: self.release.target == TargetKind::Server, - } - } + fn as_installed_server(&self) -> InstalledServer { + InstalledServer { + commit: self.release.commit.clone(), + quality: self.release.quality, + headless: self.release.target == TargetKind::Server, + } + } } impl ServerParamsRaw { - pub async fn resolve(self, log: &log::Logger) -> Result { - Ok(ResolvedServerParams { - release: self.get_or_fetch_commit_id(log).await?, - code_server_args: self.code_server_args, - }) - } + pub async fn resolve(self, log: &log::Logger) -> Result { + Ok(ResolvedServerParams { + release: self.get_or_fetch_commit_id(log).await?, + code_server_args: self.code_server_args, + }) + } - async fn get_or_fetch_commit_id(&self, log: &log::Logger) -> Result { - let target = match self.headless { - true => TargetKind::Server, - false => TargetKind::Web, - }; + async fn get_or_fetch_commit_id(&self, log: &log::Logger) -> Result { + let target = match self.headless { + true => TargetKind::Server, + false => TargetKind::Web, + }; - if let Some(c) = &self.commit_id { - return Ok(Release { - commit: c.clone(), - quality: self.quality, - target, - platform: self.platform, - }); - } + if let Some(c) = &self.commit_id { + return Ok(Release { + commit: c.clone(), + quality: self.quality, + target, + platform: self.platform, + }); + } - UpdateService::new(log.clone(), reqwest::Client::new()) - .get_latest_commit(self.platform, target, self.quality) - .await - } + UpdateService::new(log.clone(), reqwest::Client::new()) + .get_latest_commit(self.platform, target, self.quality) + .await + } } #[derive(Deserialize)] #[serde(rename_all = "camelCase")] #[allow(dead_code)] struct UpdateServerVersion { - pub name: String, - pub version: String, - pub product_version: String, - pub timestamp: i64, + pub name: String, + pub version: String, + pub product_version: String, + pub timestamp: i64, } /// Code server listening on a port address. pub struct SocketCodeServer { - pub commit_id: String, - pub socket: PathBuf, - pub origin: CodeServerOrigin, + pub commit_id: String, + pub socket: PathBuf, + pub origin: CodeServerOrigin, } /// Code server listening on a socket address. pub struct PortCodeServer { - pub commit_id: String, - pub port: u16, - pub origin: CodeServerOrigin, + pub commit_id: String, + pub port: u16, + pub origin: CodeServerOrigin, } /// A server listening on any address/location. pub enum AnyCodeServer { - Socket(SocketCodeServer), - Port(PortCodeServer), + Socket(SocketCodeServer), + Port(PortCodeServer), } // impl AnyCodeServer { @@ -237,520 +237,520 @@ pub enum AnyCodeServer { // } pub enum CodeServerOrigin { - /// A new code server, that opens the barrier when it exits. - New(Child), - /// An existing code server with a PID. - Existing(u32), + /// A new code server, that opens the barrier when it exits. + New(Child), + /// An existing code server with a PID. + Existing(u32), } impl CodeServerOrigin { - pub async fn wait_for_exit(&mut self) { - match self { - CodeServerOrigin::New(child) => { - child.wait().await.ok(); - } - CodeServerOrigin::Existing(pid) => { - let mut interval = interval(Duration::from_secs(30)); - while process_exists(*pid) { - interval.tick().await; - } - } - } - } + pub async fn wait_for_exit(&mut self) { + match self { + CodeServerOrigin::New(child) => { + child.wait().await.ok(); + } + CodeServerOrigin::Existing(pid) => { + let mut interval = interval(Duration::from_secs(30)); + while process_exists(*pid) { + interval.tick().await; + } + } + } + } - pub async fn kill(&mut self) { - match self { - CodeServerOrigin::New(child) => { - child.kill().await.ok(); - } - CodeServerOrigin::Existing(pid) => { - kill_tree(*pid).await.ok(); - } - } - } + pub async fn kill(&mut self) { + match self { + CodeServerOrigin::New(child) => { + child.kill().await.ok(); + } + CodeServerOrigin::Existing(pid) => { + kill_tree(*pid).await.ok(); + } + } + } } async fn check_and_create_dir(path: &Path) -> Result<(), WrappedError> { - tokio::fs::create_dir_all(path) - .await - .map_err(|e| wrap(e, "error creating server directory"))?; - Ok(()) + tokio::fs::create_dir_all(path) + .await + .map_err(|e| wrap(e, "error creating server directory"))?; + Ok(()) } async fn install_server_if_needed( - log: &log::Logger, - paths: &ServerPaths, - release: &Release, + log: &log::Logger, + paths: &ServerPaths, + release: &Release, ) -> Result<(), AnyError> { - if paths.executable.exists() { - info!( - log, - "Found existing installation at {}", - paths.server_dir.display() - ); - return Ok(()); - } + if paths.executable.exists() { + info!( + log, + "Found existing installation at {}", + paths.server_dir.display() + ); + return Ok(()); + } - let tar_file_path = spanf!( - log, - log.span("server.download"), - download_server(&paths.server_dir, release, log) - )?; + let tar_file_path = spanf!( + log, + log.span("server.download"), + download_server(&paths.server_dir, release, log) + )?; - span!( - log, - log.span("server.extract"), - install_server(&tar_file_path, paths, log) - )?; + span!( + log, + log.span("server.extract"), + install_server(&tar_file_path, paths, log) + )?; - Ok(()) + Ok(()) } async fn download_server( - path: &Path, - release: &Release, - log: &log::Logger, + path: &Path, + release: &Release, + log: &log::Logger, ) -> Result { - let response = UpdateService::new(log.clone(), reqwest::Client::new()) - .get_download_stream(release) - .await?; + let response = UpdateService::new(log.clone(), reqwest::Client::new()) + .get_download_stream(release) + .await?; - let mut save_path = path.to_owned(); + let mut save_path = path.to_owned(); - let fname = response - .url() - .path_segments() - .and_then(|segments| segments.last()) - .and_then(|name| if name.is_empty() { None } else { Some(name) }) - .unwrap_or("tmp.zip"); + let fname = response + .url() + .path_segments() + .and_then(|segments| segments.last()) + .and_then(|name| if name.is_empty() { None } else { Some(name) }) + .unwrap_or("tmp.zip"); - info!( - log, - "Downloading VS Code server {} -> {}", - response.url(), - save_path.display() - ); + info!( + log, + "Downloading VS Code server {} -> {}", + response.url(), + save_path.display() + ); - save_path.push(fname); - http::download_into_file( - &save_path, - log.get_download_logger("server download progress:"), - response, - ) - .await?; + save_path.push(fname); + http::download_into_file( + &save_path, + log.get_download_logger("server download progress:"), + response, + ) + .await?; - Ok(save_path) + Ok(save_path) } fn install_server( - compressed_file: &Path, - paths: &ServerPaths, - log: &log::Logger, + compressed_file: &Path, + paths: &ServerPaths, + log: &log::Logger, ) -> Result<(), AnyError> { - info!(log, "Setting up server..."); + info!(log, "Setting up server..."); - unzip_downloaded_release(compressed_file, &paths.server_dir, SilentCopyProgress())?; + unzip_downloaded_release(compressed_file, &paths.server_dir, SilentCopyProgress())?; - match fs::remove_file(&compressed_file) { - Ok(()) => {} - Err(e) => { - if e.kind() != ErrorKind::NotFound { - return Err(AnyError::from(wrap(e, "error removing downloaded file"))); - } - } - } + match fs::remove_file(&compressed_file) { + Ok(()) => {} + Err(e) => { + if e.kind() != ErrorKind::NotFound { + return Err(AnyError::from(wrap(e, "error removing downloaded file"))); + } + } + } - if !paths.executable.exists() { - return Err(AnyError::from(MissingEntrypointError())); - } + if !paths.executable.exists() { + return Err(AnyError::from(MissingEntrypointError())); + } - Ok(()) + Ok(()) } /// Ensures the given list of extensions are installed on the running server. async fn do_extension_install_on_running_server( - start_script_path: &Path, - extensions: &[String], - log: &log::Logger, + start_script_path: &Path, + extensions: &[String], + log: &log::Logger, ) -> Result<(), AnyError> { - if extensions.is_empty() { - return Ok(()); - } + if extensions.is_empty() { + return Ok(()); + } - debug!(log, "Installing extensions..."); - let command = format!( - "{} {}", - start_script_path.display(), - extensions - .iter() - .map(|s| get_extensions_flag(s)) - .collect::>() - .join(" ") - ); + debug!(log, "Installing extensions..."); + let command = format!( + "{} {}", + start_script_path.display(), + extensions + .iter() + .map(|s| get_extensions_flag(s)) + .collect::>() + .join(" ") + ); - let result = capture_command("bash", &["-c", &command]).await?; - if !result.status.success() { - Err(AnyError::from(ExtensionInstallFailed( - String::from_utf8_lossy(&result.stderr).to_string(), - ))) - } else { - Ok(()) - } + let result = capture_command("bash", &["-c", &command]).await?; + if !result.status.success() { + Err(AnyError::from(ExtensionInstallFailed( + String::from_utf8_lossy(&result.stderr).to_string(), + ))) + } else { + Ok(()) + } } pub struct ServerBuilder<'a> { - logger: &'a log::Logger, - server_params: &'a ResolvedServerParams, - last_used: LastUsedServers<'a>, - server_paths: ServerPaths, + logger: &'a log::Logger, + server_params: &'a ResolvedServerParams, + last_used: LastUsedServers<'a>, + server_paths: ServerPaths, } impl<'a> ServerBuilder<'a> { - pub fn new( - logger: &'a log::Logger, - server_params: &'a ResolvedServerParams, - launcher_paths: &'a LauncherPaths, - ) -> Self { - Self { - logger, - server_params, - last_used: LastUsedServers::new(launcher_paths), - server_paths: server_params - .as_installed_server() - .server_paths(launcher_paths), - } - } + pub fn new( + logger: &'a log::Logger, + server_params: &'a ResolvedServerParams, + launcher_paths: &'a LauncherPaths, + ) -> Self { + Self { + logger, + server_params, + last_used: LastUsedServers::new(launcher_paths), + server_paths: server_params + .as_installed_server() + .server_paths(launcher_paths), + } + } - /// Gets any already-running server from this directory. - pub async fn get_running(&self) -> Result, AnyError> { - info!( - self.logger, - "Checking {} and {} for a running server...", - self.server_paths.logfile.display(), - self.server_paths.pidfile.display() - ); + /// Gets any already-running server from this directory. + pub async fn get_running(&self) -> Result, AnyError> { + info!( + self.logger, + "Checking {} and {} for a running server...", + self.server_paths.logfile.display(), + self.server_paths.pidfile.display() + ); - let pid = match self.server_paths.get_running_pid() { - Some(pid) => pid, - None => return Ok(None), - }; - info!(self.logger, "Found running server (pid={})", pid); - if !Path::new(&self.server_paths.logfile).exists() { - warning!(self.logger, "VS Code Server is running but its logfile is missing. Don't delete the VS Code Server manually, run the command 'code-server prune'."); - return Ok(None); - } + let pid = match self.server_paths.get_running_pid() { + Some(pid) => pid, + None => return Ok(None), + }; + info!(self.logger, "Found running server (pid={})", pid); + if !Path::new(&self.server_paths.logfile).exists() { + warning!(self.logger, "VS Code Server is running but its logfile is missing. Don't delete the VS Code Server manually, run the command 'code-server prune'."); + return Ok(None); + } - do_extension_install_on_running_server( - &self.server_paths.executable, - &self.server_params.code_server_args.install_extensions, - self.logger, - ) - .await?; + do_extension_install_on_running_server( + &self.server_paths.executable, + &self.server_params.code_server_args.install_extensions, + self.logger, + ) + .await?; - let origin = CodeServerOrigin::Existing(pid); - let contents = fs::read_to_string(&self.server_paths.logfile) - .expect("Something went wrong reading log file"); + let origin = CodeServerOrigin::Existing(pid); + let contents = fs::read_to_string(&self.server_paths.logfile) + .expect("Something went wrong reading log file"); - if let Some(port) = parse_port_from(&contents) { - Ok(Some(AnyCodeServer::Port(PortCodeServer { - commit_id: self.server_params.release.commit.to_owned(), - port, - origin, - }))) - } else if let Some(socket) = parse_socket_from(&contents) { - Ok(Some(AnyCodeServer::Socket(SocketCodeServer { - commit_id: self.server_params.release.commit.to_owned(), - socket, - origin, - }))) - } else { - Ok(None) - } - } + if let Some(port) = parse_port_from(&contents) { + Ok(Some(AnyCodeServer::Port(PortCodeServer { + commit_id: self.server_params.release.commit.to_owned(), + port, + origin, + }))) + } else if let Some(socket) = parse_socket_from(&contents) { + Ok(Some(AnyCodeServer::Socket(SocketCodeServer { + commit_id: self.server_params.release.commit.to_owned(), + socket, + origin, + }))) + } else { + Ok(None) + } + } - /// Ensures the server is set up in the configured directory. - pub async fn setup(&self) -> Result<(), AnyError> { - debug!(self.logger, "Installing and setting up VS Code Server..."); - check_and_create_dir(&self.server_paths.server_dir).await?; - install_server_if_needed(self.logger, &self.server_paths, &self.server_params.release) - .await?; - debug!(self.logger, "Server setup complete"); + /// Ensures the server is set up in the configured directory. + pub async fn setup(&self) -> Result<(), AnyError> { + debug!(self.logger, "Installing and setting up VS Code Server..."); + check_and_create_dir(&self.server_paths.server_dir).await?; + install_server_if_needed(self.logger, &self.server_paths, &self.server_params.release) + .await?; + debug!(self.logger, "Server setup complete"); - match self.last_used.add(self.server_params.as_installed_server()) { - Err(e) => warning!(self.logger, "Error adding server to last used: {}", e), - Ok(count) if count > MAX_RETAINED_SERVERS => { - if let Err(e) = self.last_used.trim(self.logger, MAX_RETAINED_SERVERS) { - warning!(self.logger, "Error trimming old servers: {}", e); - } - } - Ok(_) => {} - } + match self.last_used.add(self.server_params.as_installed_server()) { + Err(e) => warning!(self.logger, "Error adding server to last used: {}", e), + Ok(count) if count > MAX_RETAINED_SERVERS => { + if let Err(e) = self.last_used.trim(self.logger, MAX_RETAINED_SERVERS) { + warning!(self.logger, "Error trimming old servers: {}", e); + } + } + Ok(_) => {} + } - Ok(()) - } + Ok(()) + } - pub async fn listen_on_default_socket(&self) -> Result { - let requested_file = if cfg!(target_os = "windows") { - PathBuf::from(format!(r"\\.\pipe\vscode-server-{}", Uuid::new_v4())) - } else { - std::env::temp_dir().join(format!("vscode-server-{}", Uuid::new_v4())) - }; + pub async fn listen_on_default_socket(&self) -> Result { + let requested_file = if cfg!(target_os = "windows") { + PathBuf::from(format!(r"\\.\pipe\vscode-server-{}", Uuid::new_v4())) + } else { + std::env::temp_dir().join(format!("vscode-server-{}", Uuid::new_v4())) + }; - self.listen_on_socket(&requested_file).await - } + self.listen_on_socket(&requested_file).await + } - pub async fn listen_on_socket(&self, socket: &Path) -> Result { - Ok(spanf!( - self.logger, - self.logger.span("server.start").with_attributes(vec! { - KeyValue::new("commit_id", self.server_params.release.commit.to_string()), - KeyValue::new("quality", format!("{}", self.server_params.release.quality)), - }), - self._listen_on_socket(socket) - )?) - } + pub async fn listen_on_socket(&self, socket: &Path) -> Result { + Ok(spanf!( + self.logger, + self.logger.span("server.start").with_attributes(vec! { + KeyValue::new("commit_id", self.server_params.release.commit.to_string()), + KeyValue::new("quality", format!("{}", self.server_params.release.quality)), + }), + self._listen_on_socket(socket) + )?) + } - async fn _listen_on_socket(&self, socket: &Path) -> Result { - remove_file(&socket).await.ok(); // ignore any error if it doesn't exist + async fn _listen_on_socket(&self, socket: &Path) -> Result { + remove_file(&socket).await.ok(); // ignore any error if it doesn't exist - let mut cmd = self.get_base_command(); - cmd.arg("--start-server") - .arg("--without-connection-token") - .arg("--enable-remote-auto-shutdown") - .arg(format!("--socket-path={}", socket.display())); + let mut cmd = self.get_base_command(); + cmd.arg("--start-server") + .arg("--without-connection-token") + .arg("--enable-remote-auto-shutdown") + .arg(format!("--socket-path={}", socket.display())); - let child = self.spawn_server_process(cmd)?; - let log_file = self.get_logfile()?; - let plog = self.logger.prefixed(&log::new_code_server_prefix()); + let child = self.spawn_server_process(cmd)?; + let log_file = self.get_logfile()?; + let plog = self.logger.prefixed(&log::new_code_server_prefix()); - let (mut origin, listen_rx) = - monitor_server::(child, Some(log_file), plog, false); + let (mut origin, listen_rx) = + monitor_server::(child, Some(log_file), plog, false); - let socket = match timeout(Duration::from_secs(8), listen_rx).await { - Err(e) => { - origin.kill().await; - Err(wrap(e, "timed out looking for socket")) - } - Ok(Err(e)) => { - origin.kill().await; - Err(wrap(e, "server exited without writing socket")) - } - Ok(Ok(socket)) => Ok(socket), - }?; + let socket = match timeout(Duration::from_secs(8), listen_rx).await { + Err(e) => { + origin.kill().await; + Err(wrap(e, "timed out looking for socket")) + } + Ok(Err(e)) => { + origin.kill().await; + Err(wrap(e, "server exited without writing socket")) + } + Ok(Ok(socket)) => Ok(socket), + }?; - info!(self.logger, "Server started"); + info!(self.logger, "Server started"); - Ok(SocketCodeServer { - commit_id: self.server_params.release.commit.to_owned(), - socket, - origin, - }) - } + Ok(SocketCodeServer { + commit_id: self.server_params.release.commit.to_owned(), + socket, + origin, + }) + } - /// Starts with a given opaque set of args. Does not set up any port or - /// socket, but does return one if present, in the form of a channel. - pub async fn start_opaque_with_args( - &self, - args: &[String], - ) -> Result<(CodeServerOrigin, Receiver), AnyError> - where - M: ServerOutputMatcher, - R: 'static + Send + std::fmt::Debug, - { - let mut cmd = self.get_base_command(); - cmd.args(args); + /// Starts with a given opaque set of args. Does not set up any port or + /// socket, but does return one if present, in the form of a channel. + pub async fn start_opaque_with_args( + &self, + args: &[String], + ) -> Result<(CodeServerOrigin, Receiver), AnyError> + where + M: ServerOutputMatcher, + R: 'static + Send + std::fmt::Debug, + { + let mut cmd = self.get_base_command(); + cmd.args(args); - let child = self.spawn_server_process(cmd)?; - let plog = self.logger.prefixed(&log::new_code_server_prefix()); + let child = self.spawn_server_process(cmd)?; + let plog = self.logger.prefixed(&log::new_code_server_prefix()); - Ok(monitor_server::(child, None, plog, true)) - } + Ok(monitor_server::(child, None, plog, true)) + } - fn spawn_server_process(&self, mut cmd: Command) -> Result { - info!(self.logger, "Starting server..."); + fn spawn_server_process(&self, mut cmd: Command) -> Result { + info!(self.logger, "Starting server..."); - debug!(self.logger, "Starting server with command... {:?}", cmd); + debug!(self.logger, "Starting server with command... {:?}", cmd); - let child = cmd - .stderr(std::process::Stdio::piped()) - .stdout(std::process::Stdio::piped()) - .spawn() - .map_err(|e| wrap(e, "error spawning server"))?; + let child = cmd + .stderr(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .spawn() + .map_err(|e| wrap(e, "error spawning server"))?; - self.server_paths - .write_pid(child.id().expect("expected server to have pid"))?; + self.server_paths + .write_pid(child.id().expect("expected server to have pid"))?; - Ok(child) - } + Ok(child) + } - fn get_logfile(&self) -> Result { - File::create(&self.server_paths.logfile).map_err(|e| { - wrap( - e, - format!( - "error creating log file {}", - self.server_paths.logfile.display() - ), - ) - }) - } + fn get_logfile(&self) -> Result { + File::create(&self.server_paths.logfile).map_err(|e| { + wrap( + e, + format!( + "error creating log file {}", + self.server_paths.logfile.display() + ), + ) + }) + } - fn get_base_command(&self) -> Command { - let mut cmd = Command::new(&self.server_paths.executable); - cmd.stdin(std::process::Stdio::null()) - .args(self.server_params.code_server_args.command_arguments()); - cmd - } + fn get_base_command(&self) -> Command { + let mut cmd = Command::new(&self.server_paths.executable); + cmd.stdin(std::process::Stdio::null()) + .args(self.server_params.code_server_args.command_arguments()); + cmd + } } fn monitor_server( - mut child: Child, - log_file: Option, - plog: log::Logger, - write_directly: bool, + mut child: Child, + log_file: Option, + plog: log::Logger, + write_directly: bool, ) -> (CodeServerOrigin, Receiver) where - M: ServerOutputMatcher, - R: 'static + Send + std::fmt::Debug, + M: ServerOutputMatcher, + R: 'static + Send + std::fmt::Debug, { - let stdout = child - .stdout - .take() - .expect("child did not have a handle to stdout"); + let stdout = child + .stdout + .take() + .expect("child did not have a handle to stdout"); - let stderr = child - .stderr - .take() - .expect("child did not have a handle to stdout"); + let stderr = child + .stderr + .take() + .expect("child did not have a handle to stdout"); - let (listen_tx, listen_rx) = tokio::sync::oneshot::channel(); + let (listen_tx, listen_rx) = tokio::sync::oneshot::channel(); - // Handle stderr and stdout in a separate task. Initially scan lines looking - // for the listening port. Afterwards, just scan and write out to the file. - tokio::spawn(async move { - let mut stdout_reader = BufReader::new(stdout).lines(); - let mut stderr_reader = BufReader::new(stderr).lines(); - let write_line = |line: &str| -> std::io::Result<()> { - if let Some(mut f) = log_file.as_ref() { - f.write_all(line.as_bytes())?; - f.write_all(&[b'\n'])?; - } - if write_directly { - println!("{}", line); - } else { - trace!(plog, line); - } - Ok(()) - }; + // Handle stderr and stdout in a separate task. Initially scan lines looking + // for the listening port. Afterwards, just scan and write out to the file. + tokio::spawn(async move { + let mut stdout_reader = BufReader::new(stdout).lines(); + let mut stderr_reader = BufReader::new(stderr).lines(); + let write_line = |line: &str| -> std::io::Result<()> { + if let Some(mut f) = log_file.as_ref() { + f.write_all(line.as_bytes())?; + f.write_all(&[b'\n'])?; + } + if write_directly { + println!("{}", line); + } else { + trace!(plog, line); + } + Ok(()) + }; - loop { - let line = tokio::select! { - l = stderr_reader.next_line() => l, - l = stdout_reader.next_line() => l, - }; + loop { + let line = tokio::select! { + l = stderr_reader.next_line() => l, + l = stdout_reader.next_line() => l, + }; - match line { - Err(e) => { - trace!(plog, "error reading from stdout/stderr: {}", e); - return; - } - Ok(None) => break, - Ok(Some(l)) => { - write_line(&l).ok(); + match line { + Err(e) => { + trace!(plog, "error reading from stdout/stderr: {}", e); + return; + } + Ok(None) => break, + Ok(Some(l)) => { + write_line(&l).ok(); - if let Some(listen_on) = M::match_line(&l) { - trace!(plog, "parsed location: {:?}", listen_on); - listen_tx.send(listen_on).ok(); - break; - } - } - } - } + if let Some(listen_on) = M::match_line(&l) { + trace!(plog, "parsed location: {:?}", listen_on); + listen_tx.send(listen_on).ok(); + break; + } + } + } + } - loop { - let line = tokio::select! { - l = stderr_reader.next_line() => l, - l = stdout_reader.next_line() => l, - }; + loop { + let line = tokio::select! { + l = stderr_reader.next_line() => l, + l = stdout_reader.next_line() => l, + }; - match line { - Err(e) => { - trace!(plog, "error reading from stdout/stderr: {}", e); - break; - } - Ok(None) => break, - Ok(Some(l)) => { - write_line(&l).ok(); - } - } - } - }); + match line { + Err(e) => { + trace!(plog, "error reading from stdout/stderr: {}", e); + break; + } + Ok(None) => break, + Ok(Some(l)) => { + write_line(&l).ok(); + } + } + } + }); - let origin = CodeServerOrigin::New(child); - (origin, listen_rx) + let origin = CodeServerOrigin::New(child); + (origin, listen_rx) } fn get_extensions_flag(extension_id: &str) -> String { - format!("--install-extension={}", extension_id) + format!("--install-extension={}", extension_id) } /// A type that can be used to scan stdout from the VS Code server. Returns /// some other type that, in turn, is returned from starting the server. pub trait ServerOutputMatcher where - R: Send, + R: Send, { - fn match_line(line: &str) -> Option; + fn match_line(line: &str) -> Option; } /// Parses a line like "Extension host agent listening on /tmp/foo.sock" struct SocketMatcher(); impl ServerOutputMatcher for SocketMatcher { - fn match_line(line: &str) -> Option { - parse_socket_from(line) - } + fn match_line(line: &str) -> Option { + parse_socket_from(line) + } } /// Parses a line like "Extension host agent listening on 9000" pub struct PortMatcher(); impl ServerOutputMatcher for PortMatcher { - fn match_line(line: &str) -> Option { - parse_port_from(line) - } + fn match_line(line: &str) -> Option { + parse_port_from(line) + } } /// Parses a line like "Web UI available at http://localhost:9000/?tkn=..." pub struct WebUiMatcher(); impl ServerOutputMatcher for WebUiMatcher { - fn match_line(line: &str) -> Option { - WEB_UI_RE.captures(line).and_then(|cap| { - cap.get(1) - .and_then(|uri| reqwest::Url::parse(uri.as_str()).ok()) - }) - } + fn match_line(line: &str) -> Option { + WEB_UI_RE.captures(line).and_then(|cap| { + cap.get(1) + .and_then(|uri| reqwest::Url::parse(uri.as_str()).ok()) + }) + } } /// Does not do any parsing and just immediately returns an empty result. pub struct NoOpMatcher(); impl ServerOutputMatcher<()> for NoOpMatcher { - fn match_line(_: &str) -> Option<()> { - Some(()) - } + fn match_line(_: &str) -> Option<()> { + Some(()) + } } fn parse_socket_from(text: &str) -> Option { - LISTENING_PORT_RE - .captures(text) - .and_then(|cap| cap.get(1).map(|path| PathBuf::from(path.as_str()))) + LISTENING_PORT_RE + .captures(text) + .and_then(|cap| cap.get(1).map(|path| PathBuf::from(path.as_str()))) } fn parse_port_from(text: &str) -> Option { - LISTENING_PORT_RE.captures(text).and_then(|cap| { - cap.get(1) - .and_then(|path| path.as_str().parse::().ok()) - }) + LISTENING_PORT_RE.captures(text).and_then(|cap| { + cap.get(1) + .and_then(|path| path.as_str().parse::().ok()) + }) } diff --git a/cli/src/tunnels/control_server.rs b/cli/src/tunnels/control_server.rs index 3122b479593..6a021a72de3 100644 --- a/cli/src/tunnels/control_server.rs +++ b/cli/src/tunnels/control_server.rs @@ -2,13 +2,13 @@ * Copyright (c) Microsoft Corporation. All rights reserved. * Licensed under the MIT License. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ -use crate::constants::{CONTROL_PORT, VSCODE_CLI_VERSION, PROTOCOL_VERSION}; +use crate::constants::{CONTROL_PORT, PROTOCOL_VERSION, VSCODE_CLI_VERSION}; use crate::log; use crate::state::LauncherPaths; use crate::update::Update; use crate::update_service::Platform; use crate::util::errors::{ - wrap, AnyError, MismatchedLaunchModeError, NoAttachedServerError, ServerWriteError, + wrap, AnyError, MismatchedLaunchModeError, NoAttachedServerError, ServerWriteError, }; use crate::util::sync::{new_barrier, Barrier}; use opentelemetry::trace::SpanKind; @@ -25,16 +25,16 @@ use tokio::pin; use tokio::sync::{mpsc, oneshot, Mutex}; use super::code_server::{ - AnyCodeServer, CodeServerArgs, ServerBuilder, ServerParamsRaw, SocketCodeServer, + AnyCodeServer, CodeServerArgs, ServerBuilder, ServerParamsRaw, SocketCodeServer, }; use super::dev_tunnels::ActiveTunnel; use super::paths::prune_stopped_servers; use super::port_forwarder::{PortForwarding, PortForwardingProcessor}; use super::protocol::{ - CallServerHttpParams, CallServerHttpResult, ClientRequestMethod, EmptyResult, ErrorResponse, - ForwardParams, ForwardResult, GetHostnameResponse, RefServerMessageParams, ResponseError, - ServeParams, ServerLog, ServerMessageParams, ServerRequestMethod, SuccessResponse, - ToClientRequest, ToServerRequest, UnforwardParams, UpdateParams, UpdateResult, VersionParams, + CallServerHttpParams, CallServerHttpResult, ClientRequestMethod, EmptyResult, ErrorResponse, + ForwardParams, ForwardResult, GetHostnameResponse, RefServerMessageParams, ResponseError, + ServeParams, ServerLog, ServerMessageParams, ServerRequestMethod, SuccessResponse, + ToClientRequest, ToServerRequest, UnforwardParams, UpdateParams, UpdateResult, VersionParams, }; use super::server_bridge::{get_socket_rw_stream, FromServerMessage, ServerBridge}; @@ -42,685 +42,685 @@ type ServerBridgeList = Option>; type ServerBridgeListLock = Arc>; struct HandlerContext { - /// Exit barrier for the socket. - closer: Barrier<()>, - /// Log handle for the server - log: log::Logger, - /// A loopback channel to talk to the TCP server task. - server_tx: mpsc::Sender, - /// A loopback channel to talk to the socket server task. - socket_tx: mpsc::Sender, - /// Configured launcher paths. - launcher_paths: LauncherPaths, - /// Connected VS Code Server - code_server: Option, - /// Potentially many "websocket" connections to client - server_bridges: ServerBridgeListLock, - // the cli arguments used to start the code server - code_server_args: CodeServerArgs, - /// counter for the number of bytes received from the socket - rx_counter: Arc, - /// port forwarding functionality - port_forwarding: PortForwarding, - /// install platform for the VS Code server - platform: Platform, + /// Exit barrier for the socket. + closer: Barrier<()>, + /// Log handle for the server + log: log::Logger, + /// A loopback channel to talk to the TCP server task. + server_tx: mpsc::Sender, + /// A loopback channel to talk to the socket server task. + socket_tx: mpsc::Sender, + /// Configured launcher paths. + launcher_paths: LauncherPaths, + /// Connected VS Code Server + code_server: Option, + /// Potentially many "websocket" connections to client + server_bridges: ServerBridgeListLock, + // the cli arguments used to start the code server + code_server_args: CodeServerArgs, + /// counter for the number of bytes received from the socket + rx_counter: Arc, + /// port forwarding functionality + port_forwarding: PortForwarding, + /// install platform for the VS Code server + platform: Platform, } impl HandlerContext { - async fn dispose(self) { - let bridges: ServerBridgeList = { - let mut lock = self.server_bridges.lock().await; - let bridges = lock.take(); - *lock = None; - bridges - }; + async fn dispose(self) { + let bridges: ServerBridgeList = { + let mut lock = self.server_bridges.lock().await; + let bridges = lock.take(); + *lock = None; + bridges + }; - if let Some(b) = bridges { - for (_, bridge) in b { - if let Err(e) = bridge.close().await { - warning!( - self.log, - "Could not properly dispose of connection context: {}", - e - ) - } else { - debug!(self.log, "Closed server bridge."); - } - } - } + if let Some(b) = bridges { + for (_, bridge) in b { + if let Err(e) = bridge.close().await { + warning!( + self.log, + "Could not properly dispose of connection context: {}", + e + ) + } else { + debug!(self.log, "Closed server bridge."); + } + } + } - info!(self.log, "Disposed of connection to running server."); - } + info!(self.log, "Disposed of connection to running server."); + } } enum ServerSignal { - /// Signalled when the server has been updated and we want to respawn. - /// We'd generally need to stop and then restart the launcher, but the - /// program might be managed by a supervisor like systemd. Instead, we - /// will stop the TCP listener and spawn the launcher again as a subprocess - /// with the same arguments we used. - Respawn, + /// Signalled when the server has been updated and we want to respawn. + /// We'd generally need to stop and then restart the launcher, but the + /// program might be managed by a supervisor like systemd. Instead, we + /// will stop the TCP listener and spawn the launcher again as a subprocess + /// with the same arguments we used. + Respawn, } struct CloseReason(String); enum SocketSignal { - /// Signals bytes to send to the socket. - Send(Vec), - /// Closes the socket (e.g. as a result of an error) - CloseWith(CloseReason), - /// Disposes ServerBridge corresponding to an ID - CloseServerBridge(u16), + /// Signals bytes to send to the socket. + Send(Vec), + /// Closes the socket (e.g. as a result of an error) + CloseWith(CloseReason), + /// Disposes ServerBridge corresponding to an ID + CloseServerBridge(u16), } impl SocketSignal { - fn from_message(msg: &T) -> Self - where - T: Serialize + ?Sized, - { - SocketSignal::Send(rmp_serde::to_vec_named(msg).unwrap()) - } + fn from_message(msg: &T) -> Self + where + T: Serialize + ?Sized, + { + SocketSignal::Send(rmp_serde::to_vec_named(msg).unwrap()) + } } impl FromServerMessage for SocketSignal { - fn from_server_message(i: u16, body: &[u8]) -> Self { - SocketSignal::from_message(&ToClientRequest { - id: None, - params: ClientRequestMethod::servermsg(RefServerMessageParams { i, body }), - }) - } + fn from_server_message(i: u16, body: &[u8]) -> Self { + SocketSignal::from_message(&ToClientRequest { + id: None, + params: ClientRequestMethod::servermsg(RefServerMessageParams { i, body }), + }) + } - fn from_closed_server_bridge(i: u16) -> Self { - SocketSignal::CloseServerBridge(i) - } + fn from_closed_server_bridge(i: u16) -> Self { + SocketSignal::CloseServerBridge(i) + } } pub struct ServerTermination { - /// Whether the server should be respawned in a new binary (see ServerSignal.Respawn). - pub respawn: bool, - pub tunnel: ActiveTunnel, + /// Whether the server should be respawned in a new binary (see ServerSignal.Respawn). + pub respawn: bool, + pub tunnel: ActiveTunnel, } fn print_listening(log: &log::Logger, tunnel_name: &str) { - debug!(log, "VS Code Server is listening for incoming connections"); + debug!(log, "VS Code Server is listening for incoming connections"); - let extension_name = "+ms-vscode.remote-server"; - let home_dir = dirs::home_dir().unwrap_or_else(|| PathBuf::from("")); - let current_dir = env::current_dir().unwrap_or_else(|_| PathBuf::from("")); + let extension_name = "+ms-vscode.remote-server"; + let home_dir = dirs::home_dir().unwrap_or_else(|| PathBuf::from("")); + let current_dir = env::current_dir().unwrap_or_else(|_| PathBuf::from("")); - let dir = if home_dir == current_dir { - PathBuf::from("") - } else { - current_dir - }; + let dir = if home_dir == current_dir { + PathBuf::from("") + } else { + current_dir + }; - let mut addr = url::Url::parse("https://insiders.vscode.dev").unwrap(); - { - let mut ps = addr.path_segments_mut().unwrap(); - ps.push(extension_name); - ps.push(tunnel_name); - for segment in &dir { - let as_str = segment.to_string_lossy(); - if !(as_str.len() == 1 && as_str.starts_with(std::path::MAIN_SEPARATOR)) { - ps.push(as_str.as_ref()); - } - } - } + let mut addr = url::Url::parse("https://insiders.vscode.dev").unwrap(); + { + let mut ps = addr.path_segments_mut().unwrap(); + ps.push(extension_name); + ps.push(tunnel_name); + for segment in &dir { + let as_str = segment.to_string_lossy(); + if !(as_str.len() == 1 && as_str.starts_with(std::path::MAIN_SEPARATOR)) { + ps.push(as_str.as_ref()); + } + } + } - let message = &format!("\nOpen this link in your browser {}\n", addr); - log.result(message); + let message = &format!("\nOpen this link in your browser {}\n", addr); + log.result(message); } // Runs the launcher server. Exits on a ctrl+c or when requested by a user. // Note that client connections may not be closed when this returns; use // `close_all_clients()` on the ServerTermination to make this happen. pub async fn serve( - log: &log::Logger, - mut tunnel: ActiveTunnel, - launcher_paths: &LauncherPaths, - code_server_args: &CodeServerArgs, - platform: Platform, - shutdown_rx: oneshot::Receiver<()>, + log: &log::Logger, + mut tunnel: ActiveTunnel, + launcher_paths: &LauncherPaths, + code_server_args: &CodeServerArgs, + platform: Platform, + shutdown_rx: oneshot::Receiver<()>, ) -> Result { - let mut port = tunnel.add_port_direct(CONTROL_PORT).await?; - print_listening(log, &tunnel.name); + let mut port = tunnel.add_port_direct(CONTROL_PORT).await?; + print_listening(log, &tunnel.name); - let mut forwarding = PortForwardingProcessor::new(); - let (tx, mut rx) = mpsc::channel::(4); - let (exit_barrier, signal_exit) = new_barrier(); + let mut forwarding = PortForwardingProcessor::new(); + let (tx, mut rx) = mpsc::channel::(4); + let (exit_barrier, signal_exit) = new_barrier(); - pin!(shutdown_rx); + pin!(shutdown_rx); - loop { - tokio::select! { - _ = &mut shutdown_rx => { - info!(log, "Received interrupt, shutting down..."); - drop(signal_exit); - return Ok(ServerTermination { - respawn: false, - tunnel, - }); - }, - c = rx.recv() => { - if let Some(ServerSignal::Respawn) = c { - drop(signal_exit); - return Ok(ServerTermination { - respawn: true, - tunnel, - }); - } - }, - Some(w) = forwarding.recv() => { - forwarding.process(w, &mut tunnel).await; - }, - l = port.recv() => { - let socket = match l { - Some(p) => p, - None => { - warning!(log, "ssh tunnel disposed, tearing down"); - return Ok(ServerTermination { - respawn: false, - tunnel, - }); - } - }; + loop { + tokio::select! { + _ = &mut shutdown_rx => { + info!(log, "Received interrupt, shutting down..."); + drop(signal_exit); + return Ok(ServerTermination { + respawn: false, + tunnel, + }); + }, + c = rx.recv() => { + if let Some(ServerSignal::Respawn) = c { + drop(signal_exit); + return Ok(ServerTermination { + respawn: true, + tunnel, + }); + } + }, + Some(w) = forwarding.recv() => { + forwarding.process(w, &mut tunnel).await; + }, + l = port.recv() => { + let socket = match l { + Some(p) => p, + None => { + warning!(log, "ssh tunnel disposed, tearing down"); + return Ok(ServerTermination { + respawn: false, + tunnel, + }); + } + }; - let own_log = log.prefixed(&log::new_rpc_prefix()); - let own_tx = tx.clone(); - let own_paths = launcher_paths.clone(); - let own_exit = exit_barrier.clone(); - let own_code_server_args = code_server_args.clone(); - let own_forwarding = forwarding.handle(); + let own_log = log.prefixed(&log::new_rpc_prefix()); + let own_tx = tx.clone(); + let own_paths = launcher_paths.clone(); + let own_exit = exit_barrier.clone(); + let own_code_server_args = code_server_args.clone(); + let own_forwarding = forwarding.handle(); - tokio::spawn(async move { - use opentelemetry::trace::{FutureExt, TraceContextExt}; + tokio::spawn(async move { + use opentelemetry::trace::{FutureExt, TraceContextExt}; - let span = own_log.span("server.socket").with_kind(SpanKind::Consumer).start(own_log.tracer()); - let cx = opentelemetry::Context::current_with_span(span); - let serve_at = Instant::now(); + let span = own_log.span("server.socket").with_kind(SpanKind::Consumer).start(own_log.tracer()); + let cx = opentelemetry::Context::current_with_span(span); + let serve_at = Instant::now(); - debug!(own_log, "Serving new connection"); + debug!(own_log, "Serving new connection"); - let (writehalf, readhalf) = socket.into_split(); - let stats = process_socket(own_exit, readhalf, writehalf, own_log, own_tx, own_paths, own_code_server_args, own_forwarding, platform).with_context(cx.clone()).await; + let (writehalf, readhalf) = socket.into_split(); + let stats = process_socket(own_exit, readhalf, writehalf, own_log, own_tx, own_paths, own_code_server_args, own_forwarding, platform).with_context(cx.clone()).await; - cx.span().add_event( - "socket.bandwidth", - vec![ - KeyValue::new("tx", stats.tx as f64), - KeyValue::new("rx", stats.rx as f64), - KeyValue::new("duration_ms", serve_at.elapsed().as_millis() as f64), - ], - ); - cx.span().end(); - }); - } - } - } + cx.span().add_event( + "socket.bandwidth", + vec![ + KeyValue::new("tx", stats.tx as f64), + KeyValue::new("rx", stats.rx as f64), + KeyValue::new("duration_ms", serve_at.elapsed().as_millis() as f64), + ], + ); + cx.span().end(); + }); + } + } + } } struct SocketStats { - rx: usize, - tx: usize, + rx: usize, + tx: usize, } #[allow(clippy::too_many_arguments)] // necessary here async fn process_socket( - mut exit_barrier: Barrier<()>, - readhalf: impl AsyncRead + Send + Unpin + 'static, - mut writehalf: impl AsyncWrite + Unpin, - log: log::Logger, - server_tx: mpsc::Sender, - launcher_paths: LauncherPaths, - code_server_args: CodeServerArgs, - port_forwarding: PortForwarding, - platform: Platform, + mut exit_barrier: Barrier<()>, + readhalf: impl AsyncRead + Send + Unpin + 'static, + mut writehalf: impl AsyncWrite + Unpin, + log: log::Logger, + server_tx: mpsc::Sender, + launcher_paths: LauncherPaths, + code_server_args: CodeServerArgs, + port_forwarding: PortForwarding, + platform: Platform, ) -> SocketStats { - let (socket_tx, mut socket_rx) = mpsc::channel(4); + let (socket_tx, mut socket_rx) = mpsc::channel(4); - let rx_counter = Arc::new(AtomicUsize::new(0)); + let rx_counter = Arc::new(AtomicUsize::new(0)); - let server_bridges: ServerBridgeListLock = Arc::new(Mutex::new(Some(vec![]))); - let server_bridges_lock = Arc::clone(&server_bridges); - let barrier_ctx = exit_barrier.clone(); - let log_ctx = log.clone(); - let rx_counter_ctx = rx_counter.clone(); + let server_bridges: ServerBridgeListLock = Arc::new(Mutex::new(Some(vec![]))); + let server_bridges_lock = Arc::clone(&server_bridges); + let barrier_ctx = exit_barrier.clone(); + let log_ctx = log.clone(); + let rx_counter_ctx = rx_counter.clone(); - tokio::spawn(async move { - let mut ctx = HandlerContext { - closer: barrier_ctx, - server_tx, - socket_tx, - log: log_ctx, - launcher_paths, - code_server_args, - rx_counter: rx_counter_ctx, - code_server: None, - server_bridges: server_bridges_lock, - port_forwarding, - platform, - }; + tokio::spawn(async move { + let mut ctx = HandlerContext { + closer: barrier_ctx, + server_tx, + socket_tx, + log: log_ctx, + launcher_paths, + code_server_args, + rx_counter: rx_counter_ctx, + code_server: None, + server_bridges: server_bridges_lock, + port_forwarding, + platform, + }; - send_version(&ctx.socket_tx).await; + send_version(&ctx.socket_tx).await; - if let Err(e) = handle_socket_read(readhalf, &mut ctx).await { - debug!(ctx.log, "closing socket reader: {}", e); - ctx.socket_tx - .send(SocketSignal::CloseWith(CloseReason(format!("{}", e)))) - .await - .ok(); - } + if let Err(e) = handle_socket_read(readhalf, &mut ctx).await { + debug!(ctx.log, "closing socket reader: {}", e); + ctx.socket_tx + .send(SocketSignal::CloseWith(CloseReason(format!("{}", e)))) + .await + .ok(); + } - ctx.dispose().await; - }); + ctx.dispose().await; + }); - let mut tx_counter = 0; + let mut tx_counter = 0; - loop { - tokio::select! { - _ = exit_barrier.wait() => { - writehalf.shutdown().await.ok(); - break; - }, - recv = socket_rx.recv() => match recv { - None => break, - Some(message) => match message { - SocketSignal::Send(bytes) => { - tx_counter += bytes.len(); - if let Err(e) = writehalf.write_all(&bytes).await { - debug!(log, "Closing connection: {}", e); - break; - } - } - SocketSignal::CloseWith(reason) => { - debug!(log, "Closing connection: {}", reason.0); - break; - } - SocketSignal::CloseServerBridge(id) => { - let mut lock = server_bridges.lock().await; - match &mut *lock { - Some(bridges) => { - if let Some(index) = bridges.iter().position(|(i, _)| *i == id) { - (*bridges).remove(index as usize); - } - }, - None => {} - } - } - } - } - } - } + loop { + tokio::select! { + _ = exit_barrier.wait() => { + writehalf.shutdown().await.ok(); + break; + }, + recv = socket_rx.recv() => match recv { + None => break, + Some(message) => match message { + SocketSignal::Send(bytes) => { + tx_counter += bytes.len(); + if let Err(e) = writehalf.write_all(&bytes).await { + debug!(log, "Closing connection: {}", e); + break; + } + } + SocketSignal::CloseWith(reason) => { + debug!(log, "Closing connection: {}", reason.0); + break; + } + SocketSignal::CloseServerBridge(id) => { + let mut lock = server_bridges.lock().await; + match &mut *lock { + Some(bridges) => { + if let Some(index) = bridges.iter().position(|(i, _)| *i == id) { + (*bridges).remove(index as usize); + } + }, + None => {} + } + } + } + } + } + } - SocketStats { - tx: tx_counter, - rx: rx_counter.load(Ordering::Acquire), - } + SocketStats { + tx: tx_counter, + rx: rx_counter.load(Ordering::Acquire), + } } async fn send_version(tx: &mpsc::Sender) { - tx.send(SocketSignal::from_message(&ToClientRequest { - id: None, - params: ClientRequestMethod::version(VersionParams { - version: VSCODE_CLI_VERSION.unwrap_or("dev"), - protocol_version: PROTOCOL_VERSION, - }), - })) - .await - .ok(); + tx.send(SocketSignal::from_message(&ToClientRequest { + id: None, + params: ClientRequestMethod::version(VersionParams { + version: VSCODE_CLI_VERSION.unwrap_or("dev"), + protocol_version: PROTOCOL_VERSION, + }), + })) + .await + .ok(); } async fn handle_socket_read( - readhalf: impl AsyncRead + Unpin, - ctx: &mut HandlerContext, + readhalf: impl AsyncRead + Unpin, + ctx: &mut HandlerContext, ) -> Result<(), std::io::Error> { - let mut socket_reader = BufReader::new(readhalf); - let mut decode_buf = vec![]; - let mut did_update = false; + let mut socket_reader = BufReader::new(readhalf); + let mut decode_buf = vec![]; + let mut did_update = false; - let result = loop { - match read_next(&mut socket_reader, ctx, &mut decode_buf, &mut did_update).await { - Ok(false) => break Ok(()), - Ok(true) => { /* continue */ } - Err(e) => break Err(e), - } - }; + let result = loop { + match read_next(&mut socket_reader, ctx, &mut decode_buf, &mut did_update).await { + Ok(false) => break Ok(()), + Ok(true) => { /* continue */ } + Err(e) => break Err(e), + } + }; - // The connection is now closed, asked to respawn if needed - if did_update { - ctx.server_tx.send(ServerSignal::Respawn).await.ok(); - } + // The connection is now closed, asked to respawn if needed + if did_update { + ctx.server_tx.send(ServerSignal::Respawn).await.ok(); + } - result + result } /// Reads and handles the next data packet, returns true if the read loop should continue. async fn read_next( - socket_reader: &mut BufReader, - ctx: &mut HandlerContext, - decode_buf: &mut Vec, - did_update: &mut bool, + socket_reader: &mut BufReader, + ctx: &mut HandlerContext, + decode_buf: &mut Vec, + did_update: &mut bool, ) -> Result { - let msg_length = tokio::select! { - u = socket_reader.read_u32() => u? as usize, - _ = ctx.closer.wait() => return Ok(false), - }; - decode_buf.resize(msg_length, 0); - ctx.rx_counter - .fetch_add(msg_length + 4 /* u32 */, Ordering::Relaxed); + let msg_length = tokio::select! { + u = socket_reader.read_u32() => u? as usize, + _ = ctx.closer.wait() => return Ok(false), + }; + decode_buf.resize(msg_length, 0); + ctx.rx_counter + .fetch_add(msg_length + 4 /* u32 */, Ordering::Relaxed); - tokio::select! { - r = socket_reader.read_exact(decode_buf) => r?, - _ = ctx.closer.wait() => return Ok(false), - }; + tokio::select! { + r = socket_reader.read_exact(decode_buf) => r?, + _ = ctx.closer.wait() => return Ok(false), + }; - let req = match rmp_serde::from_slice::(decode_buf) { - Ok(req) => req, - Err(e) => { - warning!(ctx.log, "Error decoding message: {}", e); - return Ok(true); // not fatal - } - }; + let req = match rmp_serde::from_slice::(decode_buf) { + Ok(req) => req, + Err(e) => { + warning!(ctx.log, "Error decoding message: {}", e); + return Ok(true); // not fatal + } + }; - let log = ctx.log.prefixed( - req.id - .map(|id| format!("[call.{}]", id)) - .as_deref() - .unwrap_or("notify"), - ); + let log = ctx.log.prefixed( + req.id + .map(|id| format!("[call.{}]", id)) + .as_deref() + .unwrap_or("notify"), + ); - macro_rules! success { - ($r:expr) => { - req.id - .map(|id| rmp_serde::to_vec_named(&SuccessResponse { id, result: &$r })) - }; - } + macro_rules! success { + ($r:expr) => { + req.id + .map(|id| rmp_serde::to_vec_named(&SuccessResponse { id, result: &$r })) + }; + } - macro_rules! tj { - ($name:expr, $e:expr) => { - match (spanf!( - log, - log.span(&format!("call.{}", $name)) - .with_kind(opentelemetry::trace::SpanKind::Server), - $e - )) { - Ok(r) => success!(r), - Err(e) => { - warning!(log, "error handling call: {:?}", e); - req.id.map(|id| { - rmp_serde::to_vec_named(&ErrorResponse { - id, - error: ResponseError { - code: -1, - message: format!("{:?}", e), - }, - }) - }) - } - } - }; - } + macro_rules! tj { + ($name:expr, $e:expr) => { + match (spanf!( + log, + log.span(&format!("call.{}", $name)) + .with_kind(opentelemetry::trace::SpanKind::Server), + $e + )) { + Ok(r) => success!(r), + Err(e) => { + warning!(log, "error handling call: {:?}", e); + req.id.map(|id| { + rmp_serde::to_vec_named(&ErrorResponse { + id, + error: ResponseError { + code: -1, + message: format!("{:?}", e), + }, + }) + }) + } + } + }; + } - let response = match req.params { - ServerRequestMethod::ping(_) => success!(EmptyResult {}), - ServerRequestMethod::serve(p) => tj!("serve", handle_serve(ctx, &log, p)), - ServerRequestMethod::prune => tj!("prune", handle_prune(ctx)), - ServerRequestMethod::gethostname(_) => tj!("gethostname", handle_get_hostname()), - ServerRequestMethod::update(p) => tj!("update", async { - let r = handle_update(ctx, &p).await; - if matches!(&r, Ok(u) if u.did_update) { - *did_update = true; - } - r - }), - ServerRequestMethod::servermsg(m) => { - if let Err(e) = handle_server_message(ctx, m).await { - warning!(log, "error handling call: {:?}", e); - } - None - } - ServerRequestMethod::callserverhttp(p) => { - tj!("callserverhttp", handle_call_server_http(ctx, p)) - } - ServerRequestMethod::forward(p) => tj!("forward", handle_forward(ctx, p)), - ServerRequestMethod::unforward(p) => tj!("unforward", handle_unforward(ctx, p)), - }; + let response = match req.params { + ServerRequestMethod::ping(_) => success!(EmptyResult {}), + ServerRequestMethod::serve(p) => tj!("serve", handle_serve(ctx, &log, p)), + ServerRequestMethod::prune => tj!("prune", handle_prune(ctx)), + ServerRequestMethod::gethostname(_) => tj!("gethostname", handle_get_hostname()), + ServerRequestMethod::update(p) => tj!("update", async { + let r = handle_update(ctx, &p).await; + if matches!(&r, Ok(u) if u.did_update) { + *did_update = true; + } + r + }), + ServerRequestMethod::servermsg(m) => { + if let Err(e) = handle_server_message(ctx, m).await { + warning!(log, "error handling call: {:?}", e); + } + None + } + ServerRequestMethod::callserverhttp(p) => { + tj!("callserverhttp", handle_call_server_http(ctx, p)) + } + ServerRequestMethod::forward(p) => tj!("forward", handle_forward(ctx, p)), + ServerRequestMethod::unforward(p) => tj!("unforward", handle_unforward(ctx, p)), + }; - if let Some(Ok(res)) = response { - if ctx.socket_tx.send(SocketSignal::Send(res)).await.is_err() { - return Ok(false); - } - } + if let Some(Ok(res)) = response { + if ctx.socket_tx.send(SocketSignal::Send(res)).await.is_err() { + return Ok(false); + } + } - Ok(true) + Ok(true) } #[derive(Clone)] struct ServerOutputSink { - tx: mpsc::Sender, + tx: mpsc::Sender, } impl log::LogSink for ServerOutputSink { - fn write_log(&self, level: log::Level, _prefix: &str, message: &str) { - let s = SocketSignal::from_message(&ToClientRequest { - id: None, - params: ClientRequestMethod::serverlog(ServerLog { - line: message, - level: level.to_u8(), - }), - }); + fn write_log(&self, level: log::Level, _prefix: &str, message: &str) { + let s = SocketSignal::from_message(&ToClientRequest { + id: None, + params: ClientRequestMethod::serverlog(ServerLog { + line: message, + level: level.to_u8(), + }), + }); - self.tx.try_send(s).ok(); - } + self.tx.try_send(s).ok(); + } - fn write_result(&self, _message: &str) {} + fn write_result(&self, _message: &str) {} } async fn handle_serve( - ctx: &mut HandlerContext, - log: &log::Logger, - params: ServeParams, + ctx: &mut HandlerContext, + log: &log::Logger, + params: ServeParams, ) -> Result { - let mut code_server_args = ctx.code_server_args.clone(); + let mut code_server_args = ctx.code_server_args.clone(); - // fill params.extensions into code_server_args.install_extensions - code_server_args - .install_extensions - .extend(params.extensions.into_iter()); + // fill params.extensions into code_server_args.install_extensions + code_server_args + .install_extensions + .extend(params.extensions.into_iter()); - let resolved = ServerParamsRaw { - commit_id: params.commit_id, - quality: params.quality, - code_server_args, - headless: true, - platform: ctx.platform, - } - .resolve(log) - .await?; + let resolved = ServerParamsRaw { + commit_id: params.commit_id, + quality: params.quality, + code_server_args, + headless: true, + platform: ctx.platform, + } + .resolve(log) + .await?; - if ctx.code_server.is_none() { - let install_log = log.tee(ServerOutputSink { - tx: ctx.socket_tx.clone(), - }); - let sb = ServerBuilder::new(&install_log, &resolved, &ctx.launcher_paths); + if ctx.code_server.is_none() { + let install_log = log.tee(ServerOutputSink { + tx: ctx.socket_tx.clone(), + }); + let sb = ServerBuilder::new(&install_log, &resolved, &ctx.launcher_paths); - let server = match sb.get_running().await? { - Some(AnyCodeServer::Socket(s)) => s, - Some(_) => return Err(AnyError::from(MismatchedLaunchModeError())), - None => { - sb.setup().await?; - sb.listen_on_default_socket().await? - } - }; + let server = match sb.get_running().await? { + Some(AnyCodeServer::Socket(s)) => s, + Some(_) => return Err(AnyError::from(MismatchedLaunchModeError())), + None => { + sb.setup().await?; + sb.listen_on_default_socket().await? + } + }; - ctx.code_server = Some(server); - } + ctx.code_server = Some(server); + } - attach_server_bridge(ctx, params.socket_id).await?; - Ok(EmptyResult {}) + attach_server_bridge(ctx, params.socket_id).await?; + Ok(EmptyResult {}) } async fn attach_server_bridge(ctx: &mut HandlerContext, socket_id: u16) -> Result { - let attached_fut = ServerBridge::new( - &ctx.code_server.as_ref().unwrap().socket, - socket_id, - &ctx.socket_tx, - ) - .await; + let attached_fut = ServerBridge::new( + &ctx.code_server.as_ref().unwrap().socket, + socket_id, + &ctx.socket_tx, + ) + .await; - match attached_fut { - Ok(a) => { - let mut lock = ctx.server_bridges.lock().await; - match &mut *lock { - Some(server_bridges) => (*server_bridges).push((socket_id, a)), - None => *lock = Some(vec![(socket_id, a)]), - } - trace!(ctx.log, "Attached to server"); - Ok(socket_id) - } - Err(e) => Err(e), - } + match attached_fut { + Ok(a) => { + let mut lock = ctx.server_bridges.lock().await; + match &mut *lock { + Some(server_bridges) => (*server_bridges).push((socket_id, a)), + None => *lock = Some(vec![(socket_id, a)]), + } + trace!(ctx.log, "Attached to server"); + Ok(socket_id) + } + Err(e) => Err(e), + } } async fn handle_server_message( - ctx: &mut HandlerContext, - params: ServerMessageParams, + ctx: &mut HandlerContext, + params: ServerMessageParams, ) -> Result { - let mut lock = ctx.server_bridges.lock().await; + let mut lock = ctx.server_bridges.lock().await; - match &mut *lock { - Some(server_bridges) => { - let matched_bridge = server_bridges.iter_mut().find(|(id, _)| *id == params.i); + match &mut *lock { + Some(server_bridges) => { + let matched_bridge = server_bridges.iter_mut().find(|(id, _)| *id == params.i); - match matched_bridge { - Some((_, sb)) => sb - .write(params.body) - .await - .map_err(|_| AnyError::from(ServerWriteError()))?, - None => return Err(AnyError::from(NoAttachedServerError())), - } - } - None => return Err(AnyError::from(NoAttachedServerError())), - } + match matched_bridge { + Some((_, sb)) => sb + .write(params.body) + .await + .map_err(|_| AnyError::from(ServerWriteError()))?, + None => return Err(AnyError::from(NoAttachedServerError())), + } + } + None => return Err(AnyError::from(NoAttachedServerError())), + } - Ok(EmptyResult {}) + Ok(EmptyResult {}) } async fn handle_prune(ctx: &HandlerContext) -> Result, AnyError> { - prune_stopped_servers(&ctx.launcher_paths).map(|v| { - v.iter() - .map(|p| p.server_dir.display().to_string()) - .collect() - }) + prune_stopped_servers(&ctx.launcher_paths).map(|v| { + v.iter() + .map(|p| p.server_dir.display().to_string()) + .collect() + }) } async fn handle_update( - ctx: &HandlerContext, - params: &UpdateParams, + ctx: &HandlerContext, + params: &UpdateParams, ) -> Result { - let updater = Update::new(); - let latest_release = updater.get_latest_release().await?; + let updater = Update::new(); + let latest_release = updater.get_latest_release().await?; - let up_to_date = match VSCODE_CLI_VERSION { - Some(v) => v == latest_release.version, - None => true, - }; + let up_to_date = match VSCODE_CLI_VERSION { + Some(v) => v == latest_release.version, + None => true, + }; - if !params.do_update || up_to_date { - return Ok(UpdateResult { - up_to_date, - did_update: false, - }); - } + if !params.do_update || up_to_date { + return Ok(UpdateResult { + up_to_date, + did_update: false, + }); + } - info!(ctx.log, "Updating CLI from {}", latest_release.version); + info!(ctx.log, "Updating CLI from {}", latest_release.version); - let current_exe = std::env::current_exe().map_err(|e| wrap(e, "could not get current exe"))?; + let current_exe = std::env::current_exe().map_err(|e| wrap(e, "could not get current exe"))?; - updater - .switch_to_release(&latest_release, ¤t_exe) - .await?; + updater + .switch_to_release(&latest_release, ¤t_exe) + .await?; - Ok(UpdateResult { - up_to_date: true, - did_update: true, - }) + Ok(UpdateResult { + up_to_date: true, + did_update: true, + }) } async fn handle_get_hostname() -> Result { - Ok(GetHostnameResponse { - value: gethostname::gethostname().to_string_lossy().into_owned(), - }) + Ok(GetHostnameResponse { + value: gethostname::gethostname().to_string_lossy().into_owned(), + }) } async fn handle_forward( - ctx: &HandlerContext, - params: ForwardParams, + ctx: &HandlerContext, + params: ForwardParams, ) -> Result { - info!(ctx.log, "Forwarding port {}", params.port); - let uri = ctx.port_forwarding.forward(params.port).await?; - Ok(ForwardResult { uri }) + info!(ctx.log, "Forwarding port {}", params.port); + let uri = ctx.port_forwarding.forward(params.port).await?; + Ok(ForwardResult { uri }) } async fn handle_unforward( - ctx: &HandlerContext, - params: UnforwardParams, + ctx: &HandlerContext, + params: UnforwardParams, ) -> Result { - info!(ctx.log, "Unforwarding port {}", params.port); - ctx.port_forwarding.unforward(params.port).await?; - Ok(EmptyResult {}) + info!(ctx.log, "Unforwarding port {}", params.port); + ctx.port_forwarding.unforward(params.port).await?; + Ok(EmptyResult {}) } async fn handle_call_server_http( - ctx: &HandlerContext, - params: CallServerHttpParams, + ctx: &HandlerContext, + params: CallServerHttpParams, ) -> Result { - use hyper::{body, client::conn::Builder, Body, Request}; + use hyper::{body, client::conn::Builder, Body, Request}; - // We use Hyper directly here since reqwest doesn't support sockets/pipes. - // See https://github.com/seanmonstar/reqwest/issues/39 + // We use Hyper directly here since reqwest doesn't support sockets/pipes. + // See https://github.com/seanmonstar/reqwest/issues/39 - let socket = match &ctx.code_server { - Some(cs) => &cs.socket, - None => return Err(AnyError::from(NoAttachedServerError())), - }; + let socket = match &ctx.code_server { + Some(cs) => &cs.socket, + None => return Err(AnyError::from(NoAttachedServerError())), + }; - let rw = get_socket_rw_stream(socket).await?; + let rw = get_socket_rw_stream(socket).await?; - let (mut request_sender, connection) = Builder::new() - .handshake(rw) - .await - .map_err(|e| wrap(e, "error establishing connection"))?; + let (mut request_sender, connection) = Builder::new() + .handshake(rw) + .await + .map_err(|e| wrap(e, "error establishing connection"))?; - // start the connection processing; it's shut down when the sender is dropped - tokio::spawn(connection); + // start the connection processing; it's shut down when the sender is dropped + tokio::spawn(connection); - let mut request_builder = Request::builder() - .method::<&str>(params.method.as_ref()) - .uri(format!("http://127.0.0.1{}", params.path)) - .header("Host", "127.0.0.1"); + let mut request_builder = Request::builder() + .method::<&str>(params.method.as_ref()) + .uri(format!("http://127.0.0.1{}", params.path)) + .header("Host", "127.0.0.1"); - for (k, v) in params.headers { - request_builder = request_builder.header(k, v); - } - let request = request_builder - .body(Body::from(params.body.unwrap_or_default())) - .map_err(|e| wrap(e, "invalid request"))?; + for (k, v) in params.headers { + request_builder = request_builder.header(k, v); + } + let request = request_builder + .body(Body::from(params.body.unwrap_or_default())) + .map_err(|e| wrap(e, "invalid request"))?; - let response = request_sender - .send_request(request) - .await - .map_err(|e| wrap(e, "error sending request"))?; + let response = request_sender + .send_request(request) + .await + .map_err(|e| wrap(e, "error sending request"))?; - Ok(CallServerHttpResult { - status: response.status().as_u16(), - headers: response - .headers() - .into_iter() - .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string())) - .collect(), - body: body::to_bytes(response) - .await - .map_err(|e| wrap(e, "error reading response body"))? - .to_vec(), - }) + Ok(CallServerHttpResult { + status: response.status().as_u16(), + headers: response + .headers() + .into_iter() + .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string())) + .collect(), + body: body::to_bytes(response) + .await + .map_err(|e| wrap(e, "error reading response body"))? + .to_vec(), + }) } diff --git a/cli/src/tunnels/dev_tunnels.rs b/cli/src/tunnels/dev_tunnels.rs index 55604108788..224e0a43b62 100644 --- a/cli/src/tunnels/dev_tunnels.rs +++ b/cli/src/tunnels/dev_tunnels.rs @@ -6,7 +6,7 @@ use crate::auth; use crate::constants::{CONTROL_PORT, TUNNEL_SERVICE_USER_AGENT}; use crate::state::{LauncherPaths, PersistedState}; use crate::util::errors::{ - wrap, AnyError, DevTunnelError, InvalidTunnelName, TunnelCreationFailed, WrappedError, + wrap, AnyError, DevTunnelError, InvalidTunnelName, TunnelCreationFailed, WrappedError, }; use crate::util::input::prompt_placeholder; use crate::{debug, info, log, spanf, trace, warning}; @@ -21,802 +21,801 @@ use std::time::Duration; use tokio::sync::{mpsc, watch}; use tunnels::connections::{ForwardedPortConnection, HostRelay}; use tunnels::contracts::{ - Tunnel, TunnelPort, TunnelRelayTunnelEndpoint, PORT_TOKEN, TUNNEL_PROTOCOL_AUTO, + Tunnel, TunnelPort, TunnelRelayTunnelEndpoint, PORT_TOKEN, TUNNEL_PROTOCOL_AUTO, }; use tunnels::management::{ - new_tunnel_management, HttpError, TunnelLocator, TunnelManagementClient, TunnelRequestOptions, - NO_REQUEST_OPTIONS, + new_tunnel_management, HttpError, TunnelLocator, TunnelManagementClient, TunnelRequestOptions, + NO_REQUEST_OPTIONS, }; use super::name_generator; #[derive(Clone, Serialize, Deserialize)] pub struct PersistedTunnel { - pub name: String, - pub id: String, - pub cluster: String, + pub name: String, + pub id: String, + pub cluster: String, } impl PersistedTunnel { - pub fn into_locator(self) -> TunnelLocator { - TunnelLocator::ID { - cluster: self.cluster, - id: self.id, - } - } - pub fn locator(&self) -> TunnelLocator { - TunnelLocator::ID { - cluster: self.cluster.clone(), - id: self.id.clone(), - } - } + pub fn into_locator(self) -> TunnelLocator { + TunnelLocator::ID { + cluster: self.cluster, + id: self.id, + } + } + pub fn locator(&self) -> TunnelLocator { + TunnelLocator::ID { + cluster: self.cluster.clone(), + id: self.id.clone(), + } + } } #[async_trait] trait AccessTokenProvider: Send + Sync { - /// Gets the current access token. - async fn refresh_token(&self) -> Result; + /// Gets the current access token. + async fn refresh_token(&self) -> Result; } /// Access token provider that provides a fixed token without refreshing. struct StaticAccessTokenProvider(String); impl StaticAccessTokenProvider { - pub fn new(token: String) -> Self { - Self(token) - } + pub fn new(token: String) -> Self { + Self(token) + } } #[async_trait] impl AccessTokenProvider for StaticAccessTokenProvider { - async fn refresh_token(&self) -> Result { - Ok(self.0.clone()) - } + async fn refresh_token(&self) -> Result { + Ok(self.0.clone()) + } } /// Access token provider that looks up the token from the tunnels API. struct LookupAccessTokenProvider { - client: TunnelManagementClient, - locator: TunnelLocator, - log: log::Logger, - initial_token: Arc>>, + client: TunnelManagementClient, + locator: TunnelLocator, + log: log::Logger, + initial_token: Arc>>, } impl LookupAccessTokenProvider { - pub fn new( - client: TunnelManagementClient, - locator: TunnelLocator, - log: log::Logger, - initial_token: Option, - ) -> Self { - Self { - client, - locator, - log, - initial_token: Arc::new(Mutex::new(initial_token)), - } - } + pub fn new( + client: TunnelManagementClient, + locator: TunnelLocator, + log: log::Logger, + initial_token: Option, + ) -> Self { + Self { + client, + locator, + log, + initial_token: Arc::new(Mutex::new(initial_token)), + } + } } #[async_trait] impl AccessTokenProvider for LookupAccessTokenProvider { - async fn refresh_token(&self) -> Result { - if let Some(token) = self.initial_token.lock().unwrap().take() { - return Ok(token); - } + async fn refresh_token(&self) -> Result { + if let Some(token) = self.initial_token.lock().unwrap().take() { + return Ok(token); + } - let tunnel_lookup = spanf!( - self.log, - self.log.span("dev-tunnel.tag.get"), - self.client.get_tunnel( - &self.locator, - &TunnelRequestOptions { - token_scopes: vec!["host".to_string()], - ..Default::default() - } - ) - ); + let tunnel_lookup = spanf!( + self.log, + self.log.span("dev-tunnel.tag.get"), + self.client.get_tunnel( + &self.locator, + &TunnelRequestOptions { + token_scopes: vec!["host".to_string()], + ..Default::default() + } + ) + ); - trace!(self.log, "Successfully refreshed access token"); + trace!(self.log, "Successfully refreshed access token"); - match tunnel_lookup { - Ok(tunnel) => Ok(get_host_token_from_tunnel(&tunnel)), - Err(e) => Err(wrap(e, "failed to lookup tunnel")), - } - } + match tunnel_lookup { + Ok(tunnel) => Ok(get_host_token_from_tunnel(&tunnel)), + Err(e) => Err(wrap(e, "failed to lookup tunnel")), + } + } } #[derive(Clone)] pub struct DevTunnels { - log: log::Logger, - launcher_tunnel: PersistedState>, - client: TunnelManagementClient, + log: log::Logger, + launcher_tunnel: PersistedState>, + client: TunnelManagementClient, } /// Representation of a tunnel returned from the `start` methods. pub struct ActiveTunnel { - /// Name of the tunnel - pub name: String, - manager: ActiveTunnelManager, + /// Name of the tunnel + pub name: String, + manager: ActiveTunnelManager, } impl ActiveTunnel { - /// Closes and unregisters the tunnel. - pub async fn close(&mut self) -> Result<(), AnyError> { - self.manager.kill().await?; - Ok(()) - } + /// Closes and unregisters the tunnel. + pub async fn close(&mut self) -> Result<(), AnyError> { + self.manager.kill().await?; + Ok(()) + } - /// Forwards a port to local connections. - pub async fn add_port_direct( - &mut self, - port_number: u16, - ) -> Result, AnyError> { - let port = self.manager.add_port_direct(port_number).await?; - Ok(port) - } + /// Forwards a port to local connections. + pub async fn add_port_direct( + &mut self, + port_number: u16, + ) -> Result, AnyError> { + let port = self.manager.add_port_direct(port_number).await?; + Ok(port) + } - /// Forwards a port over TCP. - pub async fn add_port_tcp(&mut self, port_number: u16) -> Result<(), AnyError> { - self.manager.add_port_tcp(port_number).await?; - Ok(()) - } + /// Forwards a port over TCP. + pub async fn add_port_tcp(&mut self, port_number: u16) -> Result<(), AnyError> { + self.manager.add_port_tcp(port_number).await?; + Ok(()) + } - /// Removes a forwarded port TCP. - pub async fn remove_port(&mut self, port_number: u16) -> Result<(), AnyError> { - self.manager.remove_port(port_number).await?; - Ok(()) - } + /// Removes a forwarded port TCP. + pub async fn remove_port(&mut self, port_number: u16) -> Result<(), AnyError> { + self.manager.remove_port(port_number).await?; + Ok(()) + } - /// Gets the public URI on which a forwarded port can be access in browser. - pub async fn get_port_uri(&mut self, port: u16) -> Result { - let endpoint = self.manager.get_endpoint().await?; - let format = endpoint - .base - .port_uri_format - .expect("expected to have port format"); + /// Gets the public URI on which a forwarded port can be access in browser. + pub async fn get_port_uri(&mut self, port: u16) -> Result { + let endpoint = self.manager.get_endpoint().await?; + let format = endpoint + .base + .port_uri_format + .expect("expected to have port format"); - Ok(format.replace(PORT_TOKEN, &port.to_string())) - } + Ok(format.replace(PORT_TOKEN, &port.to_string())) + } } const VSCODE_CLI_TUNNEL_TAG: &str = "vscode-server-launcher"; const MAX_TUNNEL_NAME_LENGTH: usize = 20; fn get_host_token_from_tunnel(tunnel: &Tunnel) -> String { - tunnel - .access_tokens - .as_ref() - .expect("expected to have access tokens") - .get("host") - .expect("expected to have host token") - .to_string() + tunnel + .access_tokens + .as_ref() + .expect("expected to have access tokens") + .get("host") + .expect("expected to have host token") + .to_string() } fn is_valid_name(name: &str) -> Result<(), InvalidTunnelName> { - if name.len() > MAX_TUNNEL_NAME_LENGTH { - return Err(InvalidTunnelName(format!( - "Names cannot be longer than {} characters. Please try a different name.", - MAX_TUNNEL_NAME_LENGTH - ))); - } + if name.len() > MAX_TUNNEL_NAME_LENGTH { + return Err(InvalidTunnelName(format!( + "Names cannot be longer than {} characters. Please try a different name.", + MAX_TUNNEL_NAME_LENGTH + ))); + } - let re = Regex::new(r"^([\w-]+)$").unwrap(); + let re = Regex::new(r"^([\w-]+)$").unwrap(); - if !re.is_match(name) { - return Err(InvalidTunnelName( + if !re.is_match(name) { + return Err(InvalidTunnelName( "Names can only contain letters, numbers, and '-'. Spaces, commas, and all other special characters are not allowed. Please try a different name.".to_string() )); - } + } - Ok(()) + Ok(()) } /// Structure optionally passed into `start_existing_tunnel` to forward an existing tunnel. #[derive(Clone, Debug)] pub struct ExistingTunnel { - /// Name you'd like to assign preexisting tunnel to use to connect to the VS Code Server - pub tunnel_name: String, + /// Name you'd like to assign preexisting tunnel to use to connect to the VS Code Server + pub tunnel_name: String, - /// Token to authenticate and use preexisting tunnel - pub host_token: String, + /// Token to authenticate and use preexisting tunnel + pub host_token: String, - /// Id of preexisting tunnel to use to connect to the VS Code Server - pub tunnel_id: String, + /// Id of preexisting tunnel to use to connect to the VS Code Server + pub tunnel_id: String, - /// Cluster of preexisting tunnel to use to connect to the VS Code Server - pub cluster: String, + /// Cluster of preexisting tunnel to use to connect to the VS Code Server + pub cluster: String, } impl DevTunnels { - pub fn new(log: &log::Logger, auth: auth::Auth, paths: &LauncherPaths) -> DevTunnels { - let mut client = new_tunnel_management(&TUNNEL_SERVICE_USER_AGENT); - client.authorization_provider(auth); + pub fn new(log: &log::Logger, auth: auth::Auth, paths: &LauncherPaths) -> DevTunnels { + let mut client = new_tunnel_management(&TUNNEL_SERVICE_USER_AGENT); + client.authorization_provider(auth); - DevTunnels { - log: log.clone(), - client: client.into(), - launcher_tunnel: PersistedState::new(paths.root().join("code_tunnel.json")), - } - } + DevTunnels { + log: log.clone(), + client: client.into(), + launcher_tunnel: PersistedState::new(paths.root().join("code_tunnel.json")), + } + } - pub async fn remove_tunnel(&mut self) -> Result<(), AnyError> { - let tunnel = match self.launcher_tunnel.load() { - Some(t) => t, - None => { - return Ok(()); - } - }; + pub async fn remove_tunnel(&mut self) -> Result<(), AnyError> { + let tunnel = match self.launcher_tunnel.load() { + Some(t) => t, + None => { + return Ok(()); + } + }; - spanf!( - self.log, - self.log.span("dev-tunnel.delete"), - self.client - .delete_tunnel(&tunnel.into_locator(), NO_REQUEST_OPTIONS) - ) - .map_err(|e| wrap(e, "failed to execute `tunnel delete`"))?; + spanf!( + self.log, + self.log.span("dev-tunnel.delete"), + self.client + .delete_tunnel(&tunnel.into_locator(), NO_REQUEST_OPTIONS) + ) + .map_err(|e| wrap(e, "failed to execute `tunnel delete`"))?; - self.launcher_tunnel.save(None)?; - Ok(()) - } + self.launcher_tunnel.save(None)?; + Ok(()) + } - pub async fn rename_tunnel(&mut self, name: &str) -> Result<(), AnyError> { - is_valid_name(name)?; + pub async fn rename_tunnel(&mut self, name: &str) -> Result<(), AnyError> { + is_valid_name(name)?; - let existing = spanf!( - self.log, - self.log.span("dev-tunnel.rename.search"), - self.client.list_all_tunnels(&TunnelRequestOptions { - tags: vec![VSCODE_CLI_TUNNEL_TAG.to_string(), name.to_string()], - require_all_tags: true, - ..Default::default() - }) - ) - .map_err(|e| wrap(e, "failed to list existing tunnels"))?; + let existing = spanf!( + self.log, + self.log.span("dev-tunnel.rename.search"), + self.client.list_all_tunnels(&TunnelRequestOptions { + tags: vec![VSCODE_CLI_TUNNEL_TAG.to_string(), name.to_string()], + require_all_tags: true, + ..Default::default() + }) + ) + .map_err(|e| wrap(e, "failed to list existing tunnels"))?; - if !existing.is_empty() { - return Err(AnyError::from(TunnelCreationFailed( - name.to_string(), - "tunnel name already in use".to_string(), - ))); - } + if !existing.is_empty() { + return Err(AnyError::from(TunnelCreationFailed( + name.to_string(), + "tunnel name already in use".to_string(), + ))); + } - let mut tunnel = match self.launcher_tunnel.load() { - Some(t) => t, - None => { - debug!(self.log, "No code server tunnel found, creating new one"); - let (persisted, _) = self.create_tunnel(name).await?; - self.launcher_tunnel.save(Some(persisted))?; - return Ok(()); - } - }; + let mut tunnel = match self.launcher_tunnel.load() { + Some(t) => t, + None => { + debug!(self.log, "No code server tunnel found, creating new one"); + let (persisted, _) = self.create_tunnel(name).await?; + self.launcher_tunnel.save(Some(persisted))?; + return Ok(()); + } + }; - let locator = tunnel.locator(); + let locator = tunnel.locator(); - let mut full_tunnel = spanf!( - self.log, - self.log.span("dev-tunnel.tag.get"), - self.client.get_tunnel(&locator, NO_REQUEST_OPTIONS) - ) - .map_err(|e| wrap(e, "failed to lookup tunnel"))?; + let mut full_tunnel = spanf!( + self.log, + self.log.span("dev-tunnel.tag.get"), + self.client.get_tunnel(&locator, NO_REQUEST_OPTIONS) + ) + .map_err(|e| wrap(e, "failed to lookup tunnel"))?; - full_tunnel.tags = vec![name.to_string(), VSCODE_CLI_TUNNEL_TAG.to_string()]; - spanf!( - self.log, - self.log.span("dev-tunnel.tag.update"), - self.client.update_tunnel(&full_tunnel, NO_REQUEST_OPTIONS) - ) - .map_err(|e| wrap(e, "failed to update tunnel tags"))?; + full_tunnel.tags = vec![name.to_string(), VSCODE_CLI_TUNNEL_TAG.to_string()]; + spanf!( + self.log, + self.log.span("dev-tunnel.tag.update"), + self.client.update_tunnel(&full_tunnel, NO_REQUEST_OPTIONS) + ) + .map_err(|e| wrap(e, "failed to update tunnel tags"))?; - tunnel.name = name.to_string(); - self.launcher_tunnel.save(Some(tunnel.clone()))?; - Ok(()) - } + tunnel.name = name.to_string(); + self.launcher_tunnel.save(Some(tunnel.clone()))?; + Ok(()) + } - /// Starts a new tunnel for the code server on the port. Unlike `start_new_tunnel`, - /// this attempts to reuse or generate a friendly tunnel name. - pub async fn start_new_launcher_tunnel( - &mut self, - use_random_name: bool, - ) -> Result { - let (tunnel, persisted) = match self.launcher_tunnel.load() { - Some(persisted) => { - let tunnel_lookup = spanf!( - self.log, - self.log.span("dev-tunnel.tag.get"), - self.client.get_tunnel( - &persisted.locator(), - &TunnelRequestOptions { - include_ports: true, - token_scopes: vec!["host".to_string()], - ..Default::default() - } - ) - ); + /// Starts a new tunnel for the code server on the port. Unlike `start_new_tunnel`, + /// this attempts to reuse or generate a friendly tunnel name. + pub async fn start_new_launcher_tunnel( + &mut self, + use_random_name: bool, + ) -> Result { + let (tunnel, persisted) = match self.launcher_tunnel.load() { + Some(persisted) => { + let tunnel_lookup = spanf!( + self.log, + self.log.span("dev-tunnel.tag.get"), + self.client.get_tunnel( + &persisted.locator(), + &TunnelRequestOptions { + include_ports: true, + token_scopes: vec!["host".to_string()], + ..Default::default() + } + ) + ); - match tunnel_lookup { - Ok(ft) => (ft, persisted), - Err(HttpError::ResponseError(e)) - if e.status_code == StatusCode::NOT_FOUND - || e.status_code == StatusCode::FORBIDDEN => - { - let (persisted, tunnel) = self.create_tunnel(&persisted.name).await?; - self.launcher_tunnel.save(Some(persisted.clone()))?; - (tunnel, persisted) - } - Err(e) => return Err(AnyError::from(wrap(e, "failed to lookup tunnel"))), - } - } - None => { - debug!(self.log, "No code server tunnel found, creating new one"); - let name = self.get_name_for_tunnel(use_random_name).await?; - let (persisted, full_tunnel) = self.create_tunnel(&name).await?; - self.launcher_tunnel.save(Some(persisted.clone()))?; - (full_tunnel, persisted) - } - }; + match tunnel_lookup { + Ok(ft) => (ft, persisted), + Err(HttpError::ResponseError(e)) + if e.status_code == StatusCode::NOT_FOUND + || e.status_code == StatusCode::FORBIDDEN => + { + let (persisted, tunnel) = self.create_tunnel(&persisted.name).await?; + self.launcher_tunnel.save(Some(persisted.clone()))?; + (tunnel, persisted) + } + Err(e) => return Err(AnyError::from(wrap(e, "failed to lookup tunnel"))), + } + } + None => { + debug!(self.log, "No code server tunnel found, creating new one"); + let name = self.get_name_for_tunnel(use_random_name).await?; + let (persisted, full_tunnel) = self.create_tunnel(&name).await?; + self.launcher_tunnel.save(Some(persisted.clone()))?; + (full_tunnel, persisted) + } + }; - let locator = TunnelLocator::try_from(&tunnel).unwrap(); - let host_token = get_host_token_from_tunnel(&tunnel); + let locator = TunnelLocator::try_from(&tunnel).unwrap(); + let host_token = get_host_token_from_tunnel(&tunnel); - for port_to_delete in tunnel - .ports - .iter() - .filter(|p| p.port_number != CONTROL_PORT) - { - let output_fut = self.client.delete_tunnel_port( - &locator, - port_to_delete.port_number, - NO_REQUEST_OPTIONS, - ); - spanf!( - self.log, - self.log.span("dev-tunnel.port.delete"), - output_fut - ) - .map_err(|e| wrap(e, "failed to delete port"))?; - } + for port_to_delete in tunnel + .ports + .iter() + .filter(|p| p.port_number != CONTROL_PORT) + { + let output_fut = self.client.delete_tunnel_port( + &locator, + port_to_delete.port_number, + NO_REQUEST_OPTIONS, + ); + spanf!( + self.log, + self.log.span("dev-tunnel.port.delete"), + output_fut + ) + .map_err(|e| wrap(e, "failed to delete port"))?; + } - // cleanup any old trailing tunnel endpoints - for endpoint in tunnel.endpoints { - let fut = self.client.delete_tunnel_endpoints( - &locator, - &endpoint.host_id, - None, - NO_REQUEST_OPTIONS, - ); + // cleanup any old trailing tunnel endpoints + for endpoint in tunnel.endpoints { + let fut = self.client.delete_tunnel_endpoints( + &locator, + &endpoint.host_id, + None, + NO_REQUEST_OPTIONS, + ); - spanf!(self.log, self.log.span("dev-tunnel.endpoint.prune"), fut) - .map_err(|e| wrap(e, "failed to prune tunnel endpoint"))?; - } + spanf!(self.log, self.log.span("dev-tunnel.endpoint.prune"), fut) + .map_err(|e| wrap(e, "failed to prune tunnel endpoint"))?; + } - self.start_tunnel( - locator.clone(), - &persisted, - self.client.clone(), - LookupAccessTokenProvider::new( - self.client.clone(), - locator, - self.log.clone(), - Some(host_token), - ), - ) - .await - } + self.start_tunnel( + locator.clone(), + &persisted, + self.client.clone(), + LookupAccessTokenProvider::new( + self.client.clone(), + locator, + self.log.clone(), + Some(host_token), + ), + ) + .await + } - async fn create_tunnel(&mut self, name: &str) -> Result<(PersistedTunnel, Tunnel), AnyError> { - info!(self.log, "Creating tunnel with the name: {}", name); + async fn create_tunnel(&mut self, name: &str) -> Result<(PersistedTunnel, Tunnel), AnyError> { + info!(self.log, "Creating tunnel with the name: {}", name); - let mut tried_recycle = false; + let mut tried_recycle = false; - let new_tunnel = Tunnel { - tags: vec![name.to_string(), VSCODE_CLI_TUNNEL_TAG.to_string()], - ..Default::default() - }; + let new_tunnel = Tunnel { + tags: vec![name.to_string(), VSCODE_CLI_TUNNEL_TAG.to_string()], + ..Default::default() + }; - loop { - let result = spanf!( - self.log, - self.log.span("dev-tunnel.create"), - self.client.create_tunnel(&new_tunnel, NO_REQUEST_OPTIONS) - ); + loop { + let result = spanf!( + self.log, + self.log.span("dev-tunnel.create"), + self.client.create_tunnel(&new_tunnel, NO_REQUEST_OPTIONS) + ); - match result { - Err(HttpError::ResponseError(e)) - if e.status_code == StatusCode::TOO_MANY_REQUESTS => - { - if !tried_recycle && self.try_recycle_tunnel().await? { - tried_recycle = true; - continue; - } + match result { + Err(HttpError::ResponseError(e)) + if e.status_code == StatusCode::TOO_MANY_REQUESTS => + { + if !tried_recycle && self.try_recycle_tunnel().await? { + tried_recycle = true; + continue; + } - return Err(AnyError::from(TunnelCreationFailed( + return Err(AnyError::from(TunnelCreationFailed( name.to_string(), "You've exceeded the 10 machine limit for the port fowarding service. Please remove other machines before trying to add this machine.".to_string(), ))); - } - Err(e) => { - return Err(AnyError::from(TunnelCreationFailed( - name.to_string(), - format!("{:?}", e), - ))) - } - Ok(t) => { - return Ok(( - PersistedTunnel { - cluster: t.cluster_id.clone().unwrap(), - id: t.tunnel_id.clone().unwrap(), - name: name.to_string(), - }, - t, - )) - } - } - } - } + } + Err(e) => { + return Err(AnyError::from(TunnelCreationFailed( + name.to_string(), + format!("{:?}", e), + ))) + } + Ok(t) => { + return Ok(( + PersistedTunnel { + cluster: t.cluster_id.clone().unwrap(), + id: t.tunnel_id.clone().unwrap(), + name: name.to_string(), + }, + t, + )) + } + } + } + } - /// Tries to delete an unused tunnel, and then creates a tunnel with the - /// given `new_name`. - async fn try_recycle_tunnel(&mut self) -> Result { - trace!( - self.log, - "Tunnel limit hit, trying to recycle an old tunnel" - ); + /// Tries to delete an unused tunnel, and then creates a tunnel with the + /// given `new_name`. + async fn try_recycle_tunnel(&mut self) -> Result { + trace!( + self.log, + "Tunnel limit hit, trying to recycle an old tunnel" + ); - let existing_tunnels = self.list_all_server_tunnels().await?; + let existing_tunnels = self.list_all_server_tunnels().await?; - let recyclable = existing_tunnels - .iter() - .filter(|t| { - t.status - .as_ref() - .and_then(|s| s.host_connection_count.as_ref()) - .map(|c| c.get_count()) - .unwrap_or(0) - == 0 - }) - .choose(&mut rand::thread_rng()); + let recyclable = existing_tunnels + .iter() + .filter(|t| { + t.status + .as_ref() + .and_then(|s| s.host_connection_count.as_ref()) + .map(|c| c.get_count()) + .unwrap_or(0) == 0 + }) + .choose(&mut rand::thread_rng()); - match recyclable { - Some(tunnel) => { - trace!(self.log, "Recycling tunnel ID {:?}", tunnel.tunnel_id); - spanf!( - self.log, - self.log.span("dev-tunnel.delete"), - self.client - .delete_tunnel(&tunnel.try_into().unwrap(), NO_REQUEST_OPTIONS) - ) - .map_err(|e| wrap(e, "failed to execute `tunnel delete`"))?; - Ok(true) - } - None => { - trace!(self.log, "No tunnels available to recycle"); - Ok(false) - } - } - } + match recyclable { + Some(tunnel) => { + trace!(self.log, "Recycling tunnel ID {:?}", tunnel.tunnel_id); + spanf!( + self.log, + self.log.span("dev-tunnel.delete"), + self.client + .delete_tunnel(&tunnel.try_into().unwrap(), NO_REQUEST_OPTIONS) + ) + .map_err(|e| wrap(e, "failed to execute `tunnel delete`"))?; + Ok(true) + } + None => { + trace!(self.log, "No tunnels available to recycle"); + Ok(false) + } + } + } - async fn list_all_server_tunnels(&mut self) -> Result, AnyError> { - let tunnels = spanf!( - self.log, - self.log.span("dev-tunnel.listall"), - self.client.list_all_tunnels(&TunnelRequestOptions { - tags: vec![VSCODE_CLI_TUNNEL_TAG.to_string()], - require_all_tags: true, - ..Default::default() - }) - ) - .map_err(|e| wrap(e, "error listing current tunnels"))?; + async fn list_all_server_tunnels(&mut self) -> Result, AnyError> { + let tunnels = spanf!( + self.log, + self.log.span("dev-tunnel.listall"), + self.client.list_all_tunnels(&TunnelRequestOptions { + tags: vec![VSCODE_CLI_TUNNEL_TAG.to_string()], + require_all_tags: true, + ..Default::default() + }) + ) + .map_err(|e| wrap(e, "error listing current tunnels"))?; - Ok(tunnels) - } + Ok(tunnels) + } - async fn get_name_for_tunnel(&mut self, use_random_name: bool) -> Result { - let mut placeholder_name = name_generator::generate_name(MAX_TUNNEL_NAME_LENGTH); + async fn get_name_for_tunnel(&mut self, use_random_name: bool) -> Result { + let mut placeholder_name = name_generator::generate_name(MAX_TUNNEL_NAME_LENGTH); - let existing_tunnels = self.list_all_server_tunnels().await?; - let is_name_free = |n: &str| { - !existing_tunnels - .iter() - .any(|v| v.tags.iter().any(|t| t == n)) - }; + let existing_tunnels = self.list_all_server_tunnels().await?; + let is_name_free = |n: &str| { + !existing_tunnels + .iter() + .any(|v| v.tags.iter().any(|t| t == n)) + }; - if use_random_name { - while !is_name_free(&placeholder_name) { - placeholder_name = name_generator::generate_name(MAX_TUNNEL_NAME_LENGTH); - } - return Ok(placeholder_name); - } + if use_random_name { + while !is_name_free(&placeholder_name) { + placeholder_name = name_generator::generate_name(MAX_TUNNEL_NAME_LENGTH); + } + return Ok(placeholder_name); + } - loop { - let name = prompt_placeholder( - "What would you like to call this machine?", - &placeholder_name, - )?; + loop { + let name = prompt_placeholder( + "What would you like to call this machine?", + &placeholder_name, + )?; - if let Err(e) = is_valid_name(&name) { - info!(self.log, "{}", e); - continue; - } + if let Err(e) = is_valid_name(&name) { + info!(self.log, "{}", e); + continue; + } - if is_name_free(&name) { - return Ok(name); - } + if is_name_free(&name) { + return Ok(name); + } - info!(self.log, "The name {} is already in use", name); - } - } + info!(self.log, "The name {} is already in use", name); + } + } - /// Hosts an existing tunnel, where the tunnel ID and host token are given. - pub async fn start_existing_tunnel( - &mut self, - tunnel: ExistingTunnel, - ) -> Result { - let tunnel_details = PersistedTunnel { - name: tunnel.tunnel_name, - id: tunnel.tunnel_id, - cluster: tunnel.cluster, - }; + /// Hosts an existing tunnel, where the tunnel ID and host token are given. + pub async fn start_existing_tunnel( + &mut self, + tunnel: ExistingTunnel, + ) -> Result { + let tunnel_details = PersistedTunnel { + name: tunnel.tunnel_name, + id: tunnel.tunnel_id, + cluster: tunnel.cluster, + }; - let mut mgmt = self.client.build(); - mgmt.authorization(tunnels::management::Authorization::Tunnel( - tunnel.host_token.clone(), - )); + let mut mgmt = self.client.build(); + mgmt.authorization(tunnels::management::Authorization::Tunnel( + tunnel.host_token.clone(), + )); - self.start_tunnel( - tunnel_details.locator(), - &tunnel_details, - mgmt.into(), - StaticAccessTokenProvider::new(tunnel.host_token), - ) - .await - } + self.start_tunnel( + tunnel_details.locator(), + &tunnel_details, + mgmt.into(), + StaticAccessTokenProvider::new(tunnel.host_token), + ) + .await + } - async fn start_tunnel( - &mut self, - locator: TunnelLocator, - tunnel_details: &PersistedTunnel, - client: TunnelManagementClient, - access_token: impl AccessTokenProvider + 'static, - ) -> Result { - let mut manager = ActiveTunnelManager::new(self.log.clone(), client, locator, access_token); + async fn start_tunnel( + &mut self, + locator: TunnelLocator, + tunnel_details: &PersistedTunnel, + client: TunnelManagementClient, + access_token: impl AccessTokenProvider + 'static, + ) -> Result { + let mut manager = ActiveTunnelManager::new(self.log.clone(), client, locator, access_token); - let endpoint_result = spanf!( - self.log, - self.log.span("dev-tunnel.serve.callback"), - manager.get_endpoint() - ); + let endpoint_result = spanf!( + self.log, + self.log.span("dev-tunnel.serve.callback"), + manager.get_endpoint() + ); - let endpoint = match endpoint_result { - Ok(endpoint) => endpoint, - Err(e) => { - error!(self.log, "Error connecting to tunnel endpoint: {}", e); - manager.kill().await.ok(); - return Err(e); - } - }; + let endpoint = match endpoint_result { + Ok(endpoint) => endpoint, + Err(e) => { + error!(self.log, "Error connecting to tunnel endpoint: {}", e); + manager.kill().await.ok(); + return Err(e); + } + }; - debug!(self.log, "Connected to tunnel endpoint: {:?}", endpoint); + debug!(self.log, "Connected to tunnel endpoint: {:?}", endpoint); - Ok(ActiveTunnel { - name: tunnel_details.name.clone(), - manager, - }) - } + Ok(ActiveTunnel { + name: tunnel_details.name.clone(), + manager, + }) + } } struct ActiveTunnelManager { - close_tx: Option>, - endpoint_rx: watch::Receiver>>, - relay: Arc>, + close_tx: Option>, + endpoint_rx: watch::Receiver>>, + relay: Arc>, } impl ActiveTunnelManager { - pub fn new( - log: log::Logger, - mgmt: TunnelManagementClient, - locator: TunnelLocator, - access_token: impl AccessTokenProvider + 'static, - ) -> ActiveTunnelManager { - let (endpoint_tx, endpoint_rx) = watch::channel(None); - let (close_tx, close_rx) = mpsc::channel(1); + pub fn new( + log: log::Logger, + mgmt: TunnelManagementClient, + locator: TunnelLocator, + access_token: impl AccessTokenProvider + 'static, + ) -> ActiveTunnelManager { + let (endpoint_tx, endpoint_rx) = watch::channel(None); + let (close_tx, close_rx) = mpsc::channel(1); - let relay = Arc::new(tokio::sync::Mutex::new(HostRelay::new(locator, mgmt))); - let relay_spawned = relay.clone(); + let relay = Arc::new(tokio::sync::Mutex::new(HostRelay::new(locator, mgmt))); + let relay_spawned = relay.clone(); - tokio::spawn(async move { - ActiveTunnelManager::spawn_tunnel( - log, - relay_spawned, - close_rx, - endpoint_tx, - access_token, - ) - .await; - }); + tokio::spawn(async move { + ActiveTunnelManager::spawn_tunnel( + log, + relay_spawned, + close_rx, + endpoint_tx, + access_token, + ) + .await; + }); - ActiveTunnelManager { - endpoint_rx, - relay, - close_tx: Some(close_tx), - } - } + ActiveTunnelManager { + endpoint_rx, + relay, + close_tx: Some(close_tx), + } + } - /// Adds a port for TCP/IP forwarding. - #[allow(dead_code)] // todo: port forwarding - pub async fn add_port_tcp(&self, port_number: u16) -> Result<(), WrappedError> { - self.relay - .lock() - .await - .add_port(&TunnelPort { - port_number, - protocol: Some(TUNNEL_PROTOCOL_AUTO.to_owned()), - ..Default::default() - }) - .await - .map_err(|e| wrap(e, "error adding port to relay"))?; - Ok(()) - } + /// Adds a port for TCP/IP forwarding. + #[allow(dead_code)] // todo: port forwarding + pub async fn add_port_tcp(&self, port_number: u16) -> Result<(), WrappedError> { + self.relay + .lock() + .await + .add_port(&TunnelPort { + port_number, + protocol: Some(TUNNEL_PROTOCOL_AUTO.to_owned()), + ..Default::default() + }) + .await + .map_err(|e| wrap(e, "error adding port to relay"))?; + Ok(()) + } - /// Adds a port for TCP/IP forwarding. - pub async fn add_port_direct( - &self, - port_number: u16, - ) -> Result, WrappedError> { - self.relay - .lock() - .await - .add_port_raw(&TunnelPort { - port_number, - protocol: Some(TUNNEL_PROTOCOL_AUTO.to_owned()), - ..Default::default() - }) - .await - .map_err(|e| wrap(e, "error adding port to relay")) - } + /// Adds a port for TCP/IP forwarding. + pub async fn add_port_direct( + &self, + port_number: u16, + ) -> Result, WrappedError> { + self.relay + .lock() + .await + .add_port_raw(&TunnelPort { + port_number, + protocol: Some(TUNNEL_PROTOCOL_AUTO.to_owned()), + ..Default::default() + }) + .await + .map_err(|e| wrap(e, "error adding port to relay")) + } - /// Removes a port from TCP/IP forwarding. - pub async fn remove_port(&self, port_number: u16) -> Result<(), WrappedError> { - self.relay - .lock() - .await - .remove_port(port_number) - .await - .map_err(|e| wrap(e, "error remove port from relay")) - } + /// Removes a port from TCP/IP forwarding. + pub async fn remove_port(&self, port_number: u16) -> Result<(), WrappedError> { + self.relay + .lock() + .await + .remove_port(port_number) + .await + .map_err(|e| wrap(e, "error remove port from relay")) + } - /// Gets the most recent details from the tunnel process. Returns None if - /// the process exited before providing details. - pub async fn get_endpoint(&mut self) -> Result { - loop { - if let Some(details) = &*self.endpoint_rx.borrow() { - return details.clone().map_err(AnyError::from); - } + /// Gets the most recent details from the tunnel process. Returns None if + /// the process exited before providing details. + pub async fn get_endpoint(&mut self) -> Result { + loop { + if let Some(details) = &*self.endpoint_rx.borrow() { + return details.clone().map_err(AnyError::from); + } - if self.endpoint_rx.changed().await.is_err() { - return Err(DevTunnelError("tunnel creation cancelled".to_string()).into()); - } - } - } + if self.endpoint_rx.changed().await.is_err() { + return Err(DevTunnelError("tunnel creation cancelled".to_string()).into()); + } + } + } - /// Kills the process, and waits for it to exit. - /// See https://tokio.rs/tokio/topics/shutdown#waiting-for-things-to-finish-shutting-down for how this works - pub async fn kill(&mut self) -> Result<(), AnyError> { - if let Some(tx) = self.close_tx.take() { - drop(tx); - } + /// Kills the process, and waits for it to exit. + /// See https://tokio.rs/tokio/topics/shutdown#waiting-for-things-to-finish-shutting-down for how this works + pub async fn kill(&mut self) -> Result<(), AnyError> { + if let Some(tx) = self.close_tx.take() { + drop(tx); + } - self.relay - .lock() - .await - .unregister() - .await - .map_err(|e| wrap(e, "error unregistering relay"))?; + self.relay + .lock() + .await + .unregister() + .await + .map_err(|e| wrap(e, "error unregistering relay"))?; - while self.endpoint_rx.changed().await.is_ok() {} + while self.endpoint_rx.changed().await.is_ok() {} - Ok(()) - } + Ok(()) + } - async fn spawn_tunnel( - log: log::Logger, - relay: Arc>, - mut close_rx: mpsc::Receiver<()>, - endpoint_tx: watch::Sender>>, - access_token_provider: impl AccessTokenProvider + 'static, - ) { - let mut backoff = Backoff::new(Duration::from_secs(5), Duration::from_secs(120)); + async fn spawn_tunnel( + log: log::Logger, + relay: Arc>, + mut close_rx: mpsc::Receiver<()>, + endpoint_tx: watch::Sender>>, + access_token_provider: impl AccessTokenProvider + 'static, + ) { + let mut backoff = Backoff::new(Duration::from_secs(5), Duration::from_secs(120)); - macro_rules! fail { - ($e: expr, $msg: expr) => { - warning!(log, "{}: {}", $msg, $e); - endpoint_tx.send(Some(Err($e))).ok(); - backoff.delay().await; - }; - } + macro_rules! fail { + ($e: expr, $msg: expr) => { + warning!(log, "{}: {}", $msg, $e); + endpoint_tx.send(Some(Err($e))).ok(); + backoff.delay().await; + }; + } - loop { - debug!(log, "Starting tunnel to server..."); + loop { + debug!(log, "Starting tunnel to server..."); - let access_token = match access_token_provider.refresh_token().await { - Ok(t) => t, - Err(e) => { - fail!(e, "Error refreshing access token, will retry"); - continue; - } - }; + let access_token = match access_token_provider.refresh_token().await { + Ok(t) => t, + Err(e) => { + fail!(e, "Error refreshing access token, will retry"); + continue; + } + }; - // we don't bother making a client that can refresh the token, since - // the tunnel won't be able to host as soon as the access token expires. - let handle_res = { - let mut relay = relay.lock().await; - relay - .connect(&access_token) - .await - .map_err(|e| wrap(e, "error connecting to tunnel")) - }; + // we don't bother making a client that can refresh the token, since + // the tunnel won't be able to host as soon as the access token expires. + let handle_res = { + let mut relay = relay.lock().await; + relay + .connect(&access_token) + .await + .map_err(|e| wrap(e, "error connecting to tunnel")) + }; - let mut handle = match handle_res { - Ok(handle) => handle, - Err(e) => { - fail!(e, "Error connecting to relay, will retry"); - continue; - } - }; + let mut handle = match handle_res { + Ok(handle) => handle, + Err(e) => { + fail!(e, "Error connecting to relay, will retry"); + continue; + } + }; - backoff.reset(); - endpoint_tx.send(Some(Ok(handle.endpoint().clone()))).ok(); + backoff.reset(); + endpoint_tx.send(Some(Ok(handle.endpoint().clone()))).ok(); - tokio::select! { - // error is mapped like this prevent it being used across an await, - // which Rust dislikes since there's a non-sendable dyn Error in there - res = (&mut handle).map_err(|e| wrap(e, "error from tunnel connection")) => { - if let Err(e) = res { - fail!(e, "Tunnel exited unexpectedly, reconnecting"); - } else { - warning!(log, "Tunnel exited unexpectedly but gracefully, reconnecting"); - backoff.delay().await; - } - }, - _ = close_rx.recv() => { - trace!(log, "Tunnel closing gracefully"); - trace!(log, "Tunnel closed with result: {:?}", handle.close().await); - break; - } - } - } - } + tokio::select! { + // error is mapped like this prevent it being used across an await, + // which Rust dislikes since there's a non-sendable dyn Error in there + res = (&mut handle).map_err(|e| wrap(e, "error from tunnel connection")) => { + if let Err(e) = res { + fail!(e, "Tunnel exited unexpectedly, reconnecting"); + } else { + warning!(log, "Tunnel exited unexpectedly but gracefully, reconnecting"); + backoff.delay().await; + } + }, + _ = close_rx.recv() => { + trace!(log, "Tunnel closing gracefully"); + trace!(log, "Tunnel closed with result: {:?}", handle.close().await); + break; + } + } + } + } } struct Backoff { - failures: u32, - base_duration: Duration, - max_duration: Duration, + failures: u32, + base_duration: Duration, + max_duration: Duration, } impl Backoff { - pub fn new(base_duration: Duration, max_duration: Duration) -> Self { - Self { - failures: 0, - base_duration, - max_duration, - } - } + pub fn new(base_duration: Duration, max_duration: Duration) -> Self { + Self { + failures: 0, + base_duration, + max_duration, + } + } - pub async fn delay(&mut self) { - tokio::time::sleep(self.next()).await - } + pub async fn delay(&mut self) { + tokio::time::sleep(self.next()).await + } - pub fn next(&mut self) -> Duration { - self.failures += 1; - let duration = self - .base_duration - .checked_mul(self.failures) - .unwrap_or(self.max_duration); - std::cmp::min(duration, self.max_duration) - } + pub fn next(&mut self) -> Duration { + self.failures += 1; + let duration = self + .base_duration + .checked_mul(self.failures) + .unwrap_or(self.max_duration); + std::cmp::min(duration, self.max_duration) + } - pub fn reset(&mut self) { - self.failures = 0; - } + pub fn reset(&mut self) { + self.failures = 0; + } } diff --git a/cli/src/tunnels/legal.rs b/cli/src/tunnels/legal.rs index 6e0fb0f1cf5..0e11a8bacd4 100644 --- a/cli/src/tunnels/legal.rs +++ b/cli/src/tunnels/legal.rs @@ -12,45 +12,45 @@ const LICENSE_PROMPT: Option<&'static str> = option_env!("VSCODE_CLI_REMOTE_LICE #[derive(Clone, Default, Serialize, Deserialize)] struct PersistedConsent { - pub consented: Option, + pub consented: Option, } pub fn require_consent(paths: &LauncherPaths) -> Result<(), AnyError> { - match LICENSE_TEXT { - Some(t) => println!("{}", t), - None => return Ok(()), - } + match LICENSE_TEXT { + Some(t) => println!("{}", t), + None => return Ok(()), + } - let prompt = match LICENSE_PROMPT { - Some(p) => p, - None => return Ok(()), - }; + let prompt = match LICENSE_PROMPT { + Some(p) => p, + None => return Ok(()), + }; - let license: PersistedState = - PersistedState::new(paths.root().join("license_consent.json")); + let license: PersistedState = + PersistedState::new(paths.root().join("license_consent.json")); - let mut save = false; - let mut load = license.load(); + let mut save = false; + let mut load = license.load(); - if !load.consented.unwrap_or(false) { - match prompt_yn(prompt) { - Ok(true) => { - save = true; - load.consented = Some(true); - } - Ok(false) => { - return Err(AnyError::from(MissingLegalConsent( - "Sorry you cannot use VS Code Server CLI without accepting the terms." - .to_string(), - ))) - } - Err(e) => return Err(AnyError::from(MissingLegalConsent(e.to_string()))), - } - } + if !load.consented.unwrap_or(false) { + match prompt_yn(prompt) { + Ok(true) => { + save = true; + load.consented = Some(true); + } + Ok(false) => { + return Err(AnyError::from(MissingLegalConsent( + "Sorry you cannot use VS Code Server CLI without accepting the terms." + .to_string(), + ))) + } + Err(e) => return Err(AnyError::from(MissingLegalConsent(e.to_string()))), + } + } - if save { - license.save(load)?; - } + if save { + license.save(load)?; + } - Ok(()) + Ok(()) } diff --git a/cli/src/tunnels/name_generator.rs b/cli/src/tunnels/name_generator.rs index 17df9f01533..f7c8cc92441 100644 --- a/cli/src/tunnels/name_generator.rs +++ b/cli/src/tunnels/name_generator.rs @@ -6,213 +6,213 @@ use rand::prelude::*; // Adjectives in LEFT from Moby : static LEFT: &[&str] = &[ - "admiring", - "adoring", - "affectionate", - "agitated", - "amazing", - "angry", - "awesome", - "beautiful", - "blissful", - "bold", - "boring", - "brave", - "busy", - "charming", - "clever", - "cool", - "compassionate", - "competent", - "condescending", - "confident", - "cranky", - "crazy", - "dazzling", - "determined", - "distracted", - "dreamy", - "eager", - "ecstatic", - "elastic", - "elated", - "elegant", - "eloquent", - "epic", - "exciting", - "fervent", - "festive", - "flamboyant", - "focused", - "friendly", - "frosty", - "funny", - "gallant", - "gifted", - "goofy", - "gracious", - "great", - "happy", - "hardcore", - "heuristic", - "hopeful", - "hungry", - "infallible", - "inspiring", - "interesting", - "intelligent", - "jolly", - "jovial", - "keen", - "kind", - "laughing", - "loving", - "lucid", - "magical", - "mystifying", - "modest", - "musing", - "naughty", - "nervous", - "nice", - "nifty", - "nostalgic", - "objective", - "optimistic", - "peaceful", - "pedantic", - "pensive", - "practical", - "priceless", - "quirky", - "quizzical", - "recursing", - "relaxed", - "reverent", - "romantic", - "sad", - "serene", - "sharp", - "silly", - "sleepy", - "stoic", - "strange", - "stupefied", - "suspicious", - "sweet", - "tender", - "thirsty", - "trusting", - "unruffled", - "upbeat", - "vibrant", - "vigilant", - "vigorous", - "wizardly", - "wonderful", - "xenodochial", - "youthful", - "zealous", - "zen", + "admiring", + "adoring", + "affectionate", + "agitated", + "amazing", + "angry", + "awesome", + "beautiful", + "blissful", + "bold", + "boring", + "brave", + "busy", + "charming", + "clever", + "cool", + "compassionate", + "competent", + "condescending", + "confident", + "cranky", + "crazy", + "dazzling", + "determined", + "distracted", + "dreamy", + "eager", + "ecstatic", + "elastic", + "elated", + "elegant", + "eloquent", + "epic", + "exciting", + "fervent", + "festive", + "flamboyant", + "focused", + "friendly", + "frosty", + "funny", + "gallant", + "gifted", + "goofy", + "gracious", + "great", + "happy", + "hardcore", + "heuristic", + "hopeful", + "hungry", + "infallible", + "inspiring", + "interesting", + "intelligent", + "jolly", + "jovial", + "keen", + "kind", + "laughing", + "loving", + "lucid", + "magical", + "mystifying", + "modest", + "musing", + "naughty", + "nervous", + "nice", + "nifty", + "nostalgic", + "objective", + "optimistic", + "peaceful", + "pedantic", + "pensive", + "practical", + "priceless", + "quirky", + "quizzical", + "recursing", + "relaxed", + "reverent", + "romantic", + "sad", + "serene", + "sharp", + "silly", + "sleepy", + "stoic", + "strange", + "stupefied", + "suspicious", + "sweet", + "tender", + "thirsty", + "trusting", + "unruffled", + "upbeat", + "vibrant", + "vigilant", + "vigorous", + "wizardly", + "wonderful", + "xenodochial", + "youthful", + "zealous", + "zen", ]; static RIGHT: &[&str] = &[ - "albatross", - "antbird", - "antpitta", - "antshrike", - "antwren", - "babbler", - "barbet", - "blackbird", - "brushfinch", - "bulbul", - "bunting", - "cisticola", - "cormorant", - "crow", - "cuckoo", - "dove", - "drongo", - "duck", - "eagle", - "falcon", - "fantail", - "finch", - "flowerpecker", - "flycatcher", - "goose", - "goshawk", - "greenbul", - "grosbeak", - "gull", - "hawk", - "heron", - "honeyeater", - "hornbill", - "hummingbird", - "ibis", - "jay", - "kestrel", - "kingfisher", - "kite", - "lark", - "lorikeet", - "magpie", - "mockingbird", - "monarch", - "nightjar", - "oriole", - "owl", - "parakeet", - "parrot", - "partridge", - "penguin", - "petrel", - "pheasant", - "piculet", - "pigeon", - "pitta", - "prinia", - "puffin", - "quail", - "robin", - "sandpiper", - "seedeater", - "shearwater", - "sparrow", - "spinetail", - "starling", - "sunbird", - "swallow", - "swift", - "swiftlet", - "tanager", - "tapaculo", - "tern", - "thornbill", - "tinamou", - "trogon", - "tyrannulet", - "vireo", - "warbler", - "waxbill", - "weaver", - "whistler", - "woodpecker", - "wren", + "albatross", + "antbird", + "antpitta", + "antshrike", + "antwren", + "babbler", + "barbet", + "blackbird", + "brushfinch", + "bulbul", + "bunting", + "cisticola", + "cormorant", + "crow", + "cuckoo", + "dove", + "drongo", + "duck", + "eagle", + "falcon", + "fantail", + "finch", + "flowerpecker", + "flycatcher", + "goose", + "goshawk", + "greenbul", + "grosbeak", + "gull", + "hawk", + "heron", + "honeyeater", + "hornbill", + "hummingbird", + "ibis", + "jay", + "kestrel", + "kingfisher", + "kite", + "lark", + "lorikeet", + "magpie", + "mockingbird", + "monarch", + "nightjar", + "oriole", + "owl", + "parakeet", + "parrot", + "partridge", + "penguin", + "petrel", + "pheasant", + "piculet", + "pigeon", + "pitta", + "prinia", + "puffin", + "quail", + "robin", + "sandpiper", + "seedeater", + "shearwater", + "sparrow", + "spinetail", + "starling", + "sunbird", + "swallow", + "swift", + "swiftlet", + "tanager", + "tapaculo", + "tern", + "thornbill", + "tinamou", + "trogon", + "tyrannulet", + "vireo", + "warbler", + "waxbill", + "weaver", + "whistler", + "woodpecker", + "wren", ]; /// Generates a random avian name, with the optional extra_random_length added /// to reduce chance of in-flight collisions. pub fn generate_name(max_length: usize) -> String { - let mut rng = rand::thread_rng(); - loop { - let left = LEFT[rng.gen_range(0..LEFT.len())]; - let right = RIGHT[rng.gen_range(0..RIGHT.len())]; - let s = format!("{}-{}", left, right); - if s.len() < max_length { - return s; - } - } + let mut rng = rand::thread_rng(); + loop { + let left = LEFT[rng.gen_range(0..LEFT.len())]; + let right = RIGHT[rng.gen_range(0..RIGHT.len())]; + let s = format!("{}-{}", left, right); + if s.len() < max_length { + return s; + } + } } diff --git a/cli/src/tunnels/paths.rs b/cli/src/tunnels/paths.rs index 0d4982346e5..3c47b2575d7 100644 --- a/cli/src/tunnels/paths.rs +++ b/cli/src/tunnels/paths.rs @@ -4,19 +4,19 @@ *--------------------------------------------------------------------------------------------*/ use std::{ - fs::{read_dir, read_to_string, remove_dir_all, write}, - path::PathBuf, + fs::{read_dir, read_to_string, remove_dir_all, write}, + path::PathBuf, }; use serde::{Deserialize, Serialize}; use crate::{ - log, options, - state::{LauncherPaths, PersistedState}, - util::{ - errors::{wrap, AnyError, WrappedError}, - machine, - }, + log, options, + state::{LauncherPaths, PersistedState}, + util::{ + errors::{wrap, AnyError, WrappedError}, + machine, + }, }; const INSIDERS_INSTALL_FOLDER: &str = "server-insiders"; @@ -26,191 +26,191 @@ const PIDFILE_SUFFIX: &str = ".pid"; const LOGFILE_SUFFIX: &str = ".log"; pub struct ServerPaths { - // Directory into which the server is downloaded - pub server_dir: PathBuf, - // Executable path, within the server_id - pub executable: PathBuf, - // File where logs for the server should be written. - pub logfile: PathBuf, - // File where the process ID for the server should be written. - pub pidfile: PathBuf, + // Directory into which the server is downloaded + pub server_dir: PathBuf, + // Executable path, within the server_id + pub executable: PathBuf, + // File where logs for the server should be written. + pub logfile: PathBuf, + // File where the process ID for the server should be written. + pub pidfile: PathBuf, } impl ServerPaths { - // Queries the system to determine the process ID of the running server. - // Returns the process ID, if the server is running. - pub fn get_running_pid(&self) -> Option { - if let Some(pid) = self.read_pid() { - return match machine::process_at_path_exists(pid, &self.executable) { - true => Some(pid), - false => None, - }; - } + // Queries the system to determine the process ID of the running server. + // Returns the process ID, if the server is running. + pub fn get_running_pid(&self) -> Option { + if let Some(pid) = self.read_pid() { + return match machine::process_at_path_exists(pid, &self.executable) { + true => Some(pid), + false => None, + }; + } - if let Some(pid) = machine::find_running_process(&self.executable) { - // attempt to backfill process ID: - self.write_pid(pid).ok(); - return Some(pid); - } + if let Some(pid) = machine::find_running_process(&self.executable) { + // attempt to backfill process ID: + self.write_pid(pid).ok(); + return Some(pid); + } - None - } + None + } - /// Delete the server directory - pub fn delete(&self) -> Result<(), WrappedError> { - remove_dir_all(&self.server_dir).map_err(|e| { - wrap( - e, - format!("error deleting server dir {}", self.server_dir.display()), - ) - }) - } + /// Delete the server directory + pub fn delete(&self) -> Result<(), WrappedError> { + remove_dir_all(&self.server_dir).map_err(|e| { + wrap( + e, + format!("error deleting server dir {}", self.server_dir.display()), + ) + }) + } - // VS Code Server pid - pub fn write_pid(&self, pid: u32) -> Result<(), WrappedError> { - write(&self.pidfile, &format!("{}", pid)).map_err(|e| { - wrap( - e, - format!("error writing process id into {}", self.pidfile.display()), - ) - }) - } + // VS Code Server pid + pub fn write_pid(&self, pid: u32) -> Result<(), WrappedError> { + write(&self.pidfile, &format!("{}", pid)).map_err(|e| { + wrap( + e, + format!("error writing process id into {}", self.pidfile.display()), + ) + }) + } - fn read_pid(&self) -> Option { - read_to_string(&self.pidfile) - .ok() - .and_then(|s| s.parse::().ok()) - } + fn read_pid(&self) -> Option { + read_to_string(&self.pidfile) + .ok() + .and_then(|s| s.parse::().ok()) + } } #[derive(Serialize, Deserialize, Clone, PartialEq, Eq)] pub struct InstalledServer { - pub quality: options::Quality, - pub commit: String, - pub headless: bool, + pub quality: options::Quality, + pub commit: String, + pub headless: bool, } impl InstalledServer { - /// Gets path information about where a specific server should be stored. - pub fn server_paths(&self, p: &LauncherPaths) -> ServerPaths { - let base_folder = self.get_install_folder(p); - let server_dir = base_folder.join("bin").join(&self.commit); - ServerPaths { - executable: server_dir - .join("bin") - .join(self.quality.server_entrypoint()), - server_dir, - logfile: base_folder.join(format!(".{}{}", self.commit, LOGFILE_SUFFIX)), - pidfile: base_folder.join(format!(".{}{}", self.commit, PIDFILE_SUFFIX)), - } - } + /// Gets path information about where a specific server should be stored. + pub fn server_paths(&self, p: &LauncherPaths) -> ServerPaths { + let base_folder = self.get_install_folder(p); + let server_dir = base_folder.join("bin").join(&self.commit); + ServerPaths { + executable: server_dir + .join("bin") + .join(self.quality.server_entrypoint()), + server_dir, + logfile: base_folder.join(format!(".{}{}", self.commit, LOGFILE_SUFFIX)), + pidfile: base_folder.join(format!(".{}{}", self.commit, PIDFILE_SUFFIX)), + } + } - fn get_install_folder(&self, p: &LauncherPaths) -> PathBuf { - let name = match self.quality { - options::Quality::Insiders => INSIDERS_INSTALL_FOLDER, - options::Quality::Exploration => EXPLORATION_INSTALL_FOLDER, - options::Quality::Stable => STABLE_INSTALL_FOLDER, - }; + fn get_install_folder(&self, p: &LauncherPaths) -> PathBuf { + let name = match self.quality { + options::Quality::Insiders => INSIDERS_INSTALL_FOLDER, + options::Quality::Exploration => EXPLORATION_INSTALL_FOLDER, + options::Quality::Stable => STABLE_INSTALL_FOLDER, + }; - p.root().join(if !self.headless { - format!("{}-web", name) - } else { - name.to_string() - }) - } + p.root().join(if !self.headless { + format!("{}-web", name) + } else { + name.to_string() + }) + } } pub struct LastUsedServers<'a> { - state: PersistedState>, - paths: &'a LauncherPaths, + state: PersistedState>, + paths: &'a LauncherPaths, } impl<'a> LastUsedServers<'a> { - pub fn new(paths: &'a LauncherPaths) -> LastUsedServers { - LastUsedServers { - state: PersistedState::new(paths.root().join("last-used-servers.json")), - paths, - } - } + pub fn new(paths: &'a LauncherPaths) -> LastUsedServers { + LastUsedServers { + state: PersistedState::new(paths.root().join("last-used-servers.json")), + paths, + } + } - /// Adds a server as having been used most recently. Returns the number of retained server. - pub fn add(&self, server: InstalledServer) -> Result { - self.state.update_with(server, |server, l| { - if let Some(index) = l.iter().position(|s| s == &server) { - l.remove(index); - } - l.insert(0, server); - l.len() - }) - } + /// Adds a server as having been used most recently. Returns the number of retained server. + pub fn add(&self, server: InstalledServer) -> Result { + self.state.update_with(server, |server, l| { + if let Some(index) = l.iter().position(|s| s == &server) { + l.remove(index); + } + l.insert(0, server); + l.len() + }) + } - /// Trims so that at most `max_servers` are saved on disk. - pub fn trim(&self, log: &log::Logger, max_servers: usize) -> Result<(), WrappedError> { - let mut servers = self.state.load(); - while servers.len() > max_servers { - let server = servers.pop().unwrap(); - debug!( - log, - "Removing old server {}/{}", - server.quality.get_machine_name(), - server.commit - ); - let server_paths = server.server_paths(self.paths); - server_paths.delete()?; - } - self.state.save(servers)?; - Ok(()) - } + /// Trims so that at most `max_servers` are saved on disk. + pub fn trim(&self, log: &log::Logger, max_servers: usize) -> Result<(), WrappedError> { + let mut servers = self.state.load(); + while servers.len() > max_servers { + let server = servers.pop().unwrap(); + debug!( + log, + "Removing old server {}/{}", + server.quality.get_machine_name(), + server.commit + ); + let server_paths = server.server_paths(self.paths); + server_paths.delete()?; + } + self.state.save(servers)?; + Ok(()) + } } /// Prunes servers not currently running, and returns the deleted servers. pub fn prune_stopped_servers(launcher_paths: &LauncherPaths) -> Result, AnyError> { - get_all_servers(launcher_paths) - .into_iter() - .map(|s| s.server_paths(launcher_paths)) - .filter(|s| s.get_running_pid().is_none()) - .map(|s| s.delete().map(|_| s)) - .collect::>() - .map_err(AnyError::from) + get_all_servers(launcher_paths) + .into_iter() + .map(|s| s.server_paths(launcher_paths)) + .filter(|s| s.get_running_pid().is_none()) + .map(|s| s.delete().map(|_| s)) + .collect::>() + .map_err(AnyError::from) } // Gets a list of all servers which look like they might be running. pub fn get_all_servers(lp: &LauncherPaths) -> Vec { - let mut servers: Vec = vec![]; - let mut server = InstalledServer { - commit: "".to_owned(), - headless: false, - quality: options::Quality::Stable, - }; + let mut servers: Vec = vec![]; + let mut server = InstalledServer { + commit: "".to_owned(), + headless: false, + quality: options::Quality::Stable, + }; - add_server_paths_in_folder(lp, &server, &mut servers); + add_server_paths_in_folder(lp, &server, &mut servers); - server.headless = true; - add_server_paths_in_folder(lp, &server, &mut servers); + server.headless = true; + add_server_paths_in_folder(lp, &server, &mut servers); - server.headless = false; - server.quality = options::Quality::Insiders; - add_server_paths_in_folder(lp, &server, &mut servers); + server.headless = false; + server.quality = options::Quality::Insiders; + add_server_paths_in_folder(lp, &server, &mut servers); - server.headless = true; - add_server_paths_in_folder(lp, &server, &mut servers); + server.headless = true; + add_server_paths_in_folder(lp, &server, &mut servers); - servers + servers } fn add_server_paths_in_folder( - lp: &LauncherPaths, - server: &InstalledServer, - servers: &mut Vec, + lp: &LauncherPaths, + server: &InstalledServer, + servers: &mut Vec, ) { - let dir = server.get_install_folder(lp).join("bin"); - if let Ok(children) = read_dir(dir) { - for bin in children.flatten() { - servers.push(InstalledServer { - quality: server.quality, - headless: server.headless, - commit: bin.file_name().to_string_lossy().into(), - }); - } - } + let dir = server.get_install_folder(lp).join("bin"); + if let Ok(children) = read_dir(dir) { + for bin in children.flatten() { + servers.push(InstalledServer { + quality: server.quality, + headless: server.headless, + commit: bin.file_name().to_string_lossy().into(), + }); + } + } } diff --git a/cli/src/tunnels/port_forwarder.rs b/cli/src/tunnels/port_forwarder.rs index 627c1f77a73..9c79bebd22f 100644 --- a/cli/src/tunnels/port_forwarder.rs +++ b/cli/src/tunnels/port_forwarder.rs @@ -8,123 +8,123 @@ use std::collections::HashSet; use tokio::sync::{mpsc, oneshot}; use crate::{ - constants::CONTROL_PORT, - util::errors::{AnyError, CannotForwardControlPort, ServerHasClosed}, + constants::CONTROL_PORT, + util::errors::{AnyError, CannotForwardControlPort, ServerHasClosed}, }; use super::dev_tunnels::ActiveTunnel; pub enum PortForwardingRec { - Forward(u16, oneshot::Sender>), - Unforward(u16, oneshot::Sender>), + Forward(u16, oneshot::Sender>), + Unforward(u16, oneshot::Sender>), } /// Provides a port forwarding service for connected clients. Clients can make /// requests on it, which are (and *must be*) processed by calling the `.process()` /// method on the forwarder. pub struct PortForwardingProcessor { - tx: mpsc::Sender, - rx: mpsc::Receiver, - forwarded: HashSet, + tx: mpsc::Sender, + rx: mpsc::Receiver, + forwarded: HashSet, } impl PortForwardingProcessor { - pub fn new() -> Self { - let (tx, rx) = mpsc::channel(8); - Self { - tx, - rx, - forwarded: HashSet::new(), - } - } + pub fn new() -> Self { + let (tx, rx) = mpsc::channel(8); + Self { + tx, + rx, + forwarded: HashSet::new(), + } + } - /// Gets a handle that can be passed off to consumers of port forwarding. - pub fn handle(&self) -> PortForwarding { - PortForwarding { - tx: self.tx.clone(), - } - } + /// Gets a handle that can be passed off to consumers of port forwarding. + pub fn handle(&self) -> PortForwarding { + PortForwarding { + tx: self.tx.clone(), + } + } - /// Receives port forwarding requests. Consumers MUST call `process()` - /// with the received requests. - pub async fn recv(&mut self) -> Option { - self.rx.recv().await - } + /// Receives port forwarding requests. Consumers MUST call `process()` + /// with the received requests. + pub async fn recv(&mut self) -> Option { + self.rx.recv().await + } - /// Processes the incoming forwarding request. - pub async fn process(&mut self, req: PortForwardingRec, tunnel: &mut ActiveTunnel) { - match req { - PortForwardingRec::Forward(port, tx) => { - tx.send(self.process_forward(port, tunnel).await).ok(); - } - PortForwardingRec::Unforward(port, tx) => { - tx.send(self.process_unforward(port, tunnel).await).ok(); - } - } - } + /// Processes the incoming forwarding request. + pub async fn process(&mut self, req: PortForwardingRec, tunnel: &mut ActiveTunnel) { + match req { + PortForwardingRec::Forward(port, tx) => { + tx.send(self.process_forward(port, tunnel).await).ok(); + } + PortForwardingRec::Unforward(port, tx) => { + tx.send(self.process_unforward(port, tunnel).await).ok(); + } + } + } - async fn process_unforward( - &mut self, - port: u16, - tunnel: &mut ActiveTunnel, - ) -> Result<(), AnyError> { - if port == CONTROL_PORT { - return Err(CannotForwardControlPort().into()); - } + async fn process_unforward( + &mut self, + port: u16, + tunnel: &mut ActiveTunnel, + ) -> Result<(), AnyError> { + if port == CONTROL_PORT { + return Err(CannotForwardControlPort().into()); + } - tunnel.remove_port(port).await?; - self.forwarded.remove(&port); - Ok(()) - } + tunnel.remove_port(port).await?; + self.forwarded.remove(&port); + Ok(()) + } - async fn process_forward( - &mut self, - port: u16, - tunnel: &mut ActiveTunnel, - ) -> Result { - if port == CONTROL_PORT { - return Err(CannotForwardControlPort().into()); - } + async fn process_forward( + &mut self, + port: u16, + tunnel: &mut ActiveTunnel, + ) -> Result { + if port == CONTROL_PORT { + return Err(CannotForwardControlPort().into()); + } - if !self.forwarded.contains(&port) { - tunnel.add_port_tcp(port).await?; - self.forwarded.insert(port); - } + if !self.forwarded.contains(&port) { + tunnel.add_port_tcp(port).await?; + self.forwarded.insert(port); + } - tunnel.get_port_uri(port).await - } + tunnel.get_port_uri(port).await + } } pub struct PortForwarding { - tx: mpsc::Sender, + tx: mpsc::Sender, } impl PortForwarding { - pub async fn forward(&self, port: u16) -> Result { - let (tx, rx) = oneshot::channel(); - let req = PortForwardingRec::Forward(port, tx); + pub async fn forward(&self, port: u16) -> Result { + let (tx, rx) = oneshot::channel(); + let req = PortForwardingRec::Forward(port, tx); - if self.tx.send(req).await.is_err() { - return Err(ServerHasClosed().into()); - } + if self.tx.send(req).await.is_err() { + return Err(ServerHasClosed().into()); + } - match rx.await { - Ok(r) => r, - Err(_) => Err(ServerHasClosed().into()), - } - } + match rx.await { + Ok(r) => r, + Err(_) => Err(ServerHasClosed().into()), + } + } - pub async fn unforward(&self, port: u16) -> Result<(), AnyError> { - let (tx, rx) = oneshot::channel(); - let req = PortForwardingRec::Unforward(port, tx); + pub async fn unforward(&self, port: u16) -> Result<(), AnyError> { + let (tx, rx) = oneshot::channel(); + let req = PortForwardingRec::Unforward(port, tx); - if self.tx.send(req).await.is_err() { - return Err(ServerHasClosed().into()); - } + if self.tx.send(req).await.is_err() { + return Err(ServerHasClosed().into()); + } - match rx.await { - Ok(r) => r, - Err(_) => Err(ServerHasClosed().into()), - } - } + match rx.await { + Ok(r) => r, + Err(_) => Err(ServerHasClosed().into()), + } + } } diff --git a/cli/src/tunnels/protocol.rs b/cli/src/tunnels/protocol.rs index 94d28232193..e4751ffb45c 100644 --- a/cli/src/tunnels/protocol.rs +++ b/cli/src/tunnels/protocol.rs @@ -11,47 +11,47 @@ use serde::{Deserialize, Serialize}; #[serde(tag = "method", content = "params")] #[allow(non_camel_case_types)] pub enum ServerRequestMethod { - serve(ServeParams), - prune, - ping(EmptyResult), - forward(ForwardParams), - unforward(UnforwardParams), - gethostname(EmptyResult), - update(UpdateParams), - servermsg(ServerMessageParams), - callserverhttp(CallServerHttpParams), + serve(ServeParams), + prune, + ping(EmptyResult), + forward(ForwardParams), + unforward(UnforwardParams), + gethostname(EmptyResult), + update(UpdateParams), + servermsg(ServerMessageParams), + callserverhttp(CallServerHttpParams), } #[derive(Serialize, Debug)] #[serde(tag = "method", content = "params", rename_all = "camelCase")] #[allow(non_camel_case_types)] pub enum ClientRequestMethod<'a> { - servermsg(RefServerMessageParams<'a>), - serverlog(ServerLog<'a>), - version(VersionParams), + servermsg(RefServerMessageParams<'a>), + serverlog(ServerLog<'a>), + version(VersionParams), } #[derive(Deserialize, Debug)] pub struct ForwardParams { - pub port: u16, + pub port: u16, } #[derive(Deserialize, Debug)] pub struct UnforwardParams { - pub port: u16, + pub port: u16, } #[derive(Serialize)] pub struct ForwardResult { - pub uri: String, + pub uri: String, } #[derive(Deserialize, Debug)] pub struct ServeParams { - pub socket_id: u16, - pub commit_id: Option, - pub quality: Quality, - pub extensions: Vec, + pub socket_id: u16, + pub commit_id: Option, + pub quality: Quality, + pub extensions: Vec, } #[derive(Deserialize, Serialize, Debug)] @@ -59,93 +59,93 @@ pub struct EmptyResult {} #[derive(Serialize, Deserialize, Debug)] pub struct UpdateParams { - pub do_update: bool, + pub do_update: bool, } #[derive(Deserialize, Debug)] pub struct ServerMessageParams { - pub i: u16, - #[serde(with = "serde_bytes")] - pub body: Vec, + pub i: u16, + #[serde(with = "serde_bytes")] + pub body: Vec, } #[derive(Serialize, Debug)] pub struct RefServerMessageParams<'a> { - pub i: u16, - #[serde(with = "serde_bytes")] - pub body: &'a [u8], + pub i: u16, + #[serde(with = "serde_bytes")] + pub body: &'a [u8], } #[derive(Serialize)] pub struct UpdateResult { - pub up_to_date: bool, - pub did_update: bool, + pub up_to_date: bool, + pub did_update: bool, } #[derive(Deserialize, Debug)] pub struct ToServerRequest { - pub id: Option, - #[serde(flatten)] - pub params: ServerRequestMethod, + pub id: Option, + #[serde(flatten)] + pub params: ServerRequestMethod, } #[derive(Serialize, Debug)] pub struct ToClientRequest<'a> { - pub id: Option, - #[serde(flatten)] - pub params: ClientRequestMethod<'a>, + pub id: Option, + #[serde(flatten)] + pub params: ClientRequestMethod<'a>, } #[derive(Serialize, Deserialize)] pub struct SuccessResponse where - T: Serialize, + T: Serialize, { - pub id: u8, - pub result: T, + pub id: u8, + pub result: T, } #[derive(Serialize, Deserialize)] pub struct ErrorResponse { - pub id: u8, - pub error: ResponseError, + pub id: u8, + pub error: ResponseError, } #[derive(Serialize, Deserialize)] pub struct ResponseError { - pub code: i32, - pub message: String, + pub code: i32, + pub message: String, } #[derive(Debug, Default, Serialize)] pub struct ServerLog<'a> { - pub line: &'a str, - pub level: u8, + pub line: &'a str, + pub level: u8, } #[derive(Serialize)] pub struct GetHostnameResponse { - pub value: String, + pub value: String, } #[derive(Deserialize, Debug)] pub struct CallServerHttpParams { - pub path: String, - pub method: String, - pub headers: HashMap, - pub body: Option>, + pub path: String, + pub method: String, + pub headers: HashMap, + pub body: Option>, } #[derive(Serialize)] pub struct CallServerHttpResult { - pub status: u16, - #[serde(with = "serde_bytes")] - pub body: Vec, - pub headers: HashMap, + pub status: u16, + #[serde(with = "serde_bytes")] + pub body: Vec, + pub headers: HashMap, } #[derive(Serialize, Debug)] pub struct VersionParams { - pub version: &'static str, - pub protocol_version: u32, + pub version: &'static str, + pub protocol_version: u32, } diff --git a/cli/src/tunnels/server_bridge_unix.rs b/cli/src/tunnels/server_bridge_unix.rs index 38da8229c17..f584ddfddc9 100644 --- a/cli/src/tunnels/server_bridge_unix.rs +++ b/cli/src/tunnels/server_bridge_unix.rs @@ -5,76 +5,76 @@ use std::path::Path; use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - net::{unix::OwnedWriteHalf, UnixStream}, - sync::mpsc::Sender, + io::{AsyncReadExt, AsyncWriteExt}, + net::{unix::OwnedWriteHalf, UnixStream}, + sync::mpsc::Sender, }; use crate::util::errors::{wrap, AnyError}; pub struct ServerBridge { - write: OwnedWriteHalf, + write: OwnedWriteHalf, } pub trait FromServerMessage { - fn from_server_message(index: u16, message: &[u8]) -> Self; - fn from_closed_server_bridge(i: u16) -> Self; + fn from_server_message(index: u16, message: &[u8]) -> Self; + fn from_closed_server_bridge(i: u16) -> Self; } pub async fn get_socket_rw_stream(path: &Path) -> Result { - let s = UnixStream::connect(path).await.map_err(|e| { - wrap( - e, - format!( - "error connecting to vscode server socket in {}", - path.display() - ), - ) - })?; + let s = UnixStream::connect(path).await.map_err(|e| { + wrap( + e, + format!( + "error connecting to vscode server socket in {}", + path.display() + ), + ) + })?; - Ok(s) + Ok(s) } const BUFFER_SIZE: usize = 65536; impl ServerBridge { - pub async fn new(path: &Path, index: u16, target: &Sender) -> Result - where - T: 'static + FromServerMessage + Send, - { - let stream = get_socket_rw_stream(path).await?; - let (mut read, write) = stream.into_split(); + pub async fn new(path: &Path, index: u16, target: &Sender) -> Result + where + T: 'static + FromServerMessage + Send, + { + let stream = get_socket_rw_stream(path).await?; + let (mut read, write) = stream.into_split(); - let tx = target.clone(); - tokio::spawn(async move { - let mut read_buf = vec![0; BUFFER_SIZE]; - loop { - match read.read(&mut read_buf).await { - Err(_) => return, - Ok(0) => { - let _ = tx.send(T::from_closed_server_bridge(index)).await; - return; // EOF - } - Ok(s) => { - let send = tx.send(T::from_server_message(index, &read_buf[..s])).await; - if send.is_err() { - return; - } - } - } - } - }); + let tx = target.clone(); + tokio::spawn(async move { + let mut read_buf = vec![0; BUFFER_SIZE]; + loop { + match read.read(&mut read_buf).await { + Err(_) => return, + Ok(0) => { + let _ = tx.send(T::from_closed_server_bridge(index)).await; + return; // EOF + } + Ok(s) => { + let send = tx.send(T::from_server_message(index, &read_buf[..s])).await; + if send.is_err() { + return; + } + } + } + } + }); - Ok(ServerBridge { write }) - } + Ok(ServerBridge { write }) + } - pub async fn write(&mut self, b: Vec) -> std::io::Result<()> { - self.write.write_all(&b).await?; - Ok(()) - } + pub async fn write(&mut self, b: Vec) -> std::io::Result<()> { + self.write.write_all(&b).await?; + Ok(()) + } - pub async fn close(mut self) -> std::io::Result<()> { - self.write.shutdown().await?; - Ok(()) - } + pub async fn close(mut self) -> std::io::Result<()> { + self.write.shutdown().await?; + Ok(()) + } } diff --git a/cli/src/tunnels/server_bridge_windows.rs b/cli/src/tunnels/server_bridge_windows.rs index 5116b61af4d..fb4b2b321f0 100644 --- a/cli/src/tunnels/server_bridge_windows.rs +++ b/cli/src/tunnels/server_bridge_windows.rs @@ -6,128 +6,128 @@ use std::{path::Path, time::Duration}; use tokio::{ - io::{self, Interest}, - net::windows::named_pipe::{ClientOptions, NamedPipeClient}, - sync::mpsc, - time::sleep, + io::{self, Interest}, + net::windows::named_pipe::{ClientOptions, NamedPipeClient}, + sync::mpsc, + time::sleep, }; use crate::util::errors::{wrap, AnyError}; pub struct ServerBridge { - write_tx: mpsc::Sender>, + write_tx: mpsc::Sender>, } pub trait FromServerMessage { - fn from_server_message(index: u16, message: &[u8]) -> Self; - fn from_closed_server_bridge(i: u16) -> Self; + fn from_server_message(index: u16, message: &[u8]) -> Self; + fn from_closed_server_bridge(i: u16) -> Self; } const BUFFER_SIZE: usize = 65536; pub async fn get_socket_rw_stream(path: &Path) -> Result { - // Tokio says we can need to try in a loop. Do so. - // https://docs.rs/tokio/latest/tokio/net/windows/named_pipe/struct.NamedPipeClient.html - let client = loop { - match ClientOptions::new().open(path) { - Ok(client) => break client, - // ERROR_PIPE_BUSY https://docs.microsoft.com/en-us/windows/win32/debug/system-error-codes--0-499- - Err(e) if e.raw_os_error() == Some(231) => sleep(Duration::from_millis(100)).await, - Err(e) => { - return Err(AnyError::WrappedError(wrap( - e, - format!( - "error connecting to vscode server socket in {}", - path.display() - ), - ))) - } - } - }; + // Tokio says we can need to try in a loop. Do so. + // https://docs.rs/tokio/latest/tokio/net/windows/named_pipe/struct.NamedPipeClient.html + let client = loop { + match ClientOptions::new().open(path) { + Ok(client) => break client, + // ERROR_PIPE_BUSY https://docs.microsoft.com/en-us/windows/win32/debug/system-error-codes--0-499- + Err(e) if e.raw_os_error() == Some(231) => sleep(Duration::from_millis(100)).await, + Err(e) => { + return Err(AnyError::WrappedError(wrap( + e, + format!( + "error connecting to vscode server socket in {}", + path.display() + ), + ))) + } + } + }; - Ok(client) + Ok(client) } impl ServerBridge { - pub async fn new(path: &Path, index: u16, target: &mpsc::Sender) -> Result - where - T: 'static + FromServerMessage + Send, - { - let client = get_socket_rw_stream(path).await?; - let (write_tx, mut write_rx) = mpsc::channel(4); - let read_tx = target.clone(); - tokio::spawn(async move { - let mut read_buf = vec![0; BUFFER_SIZE]; - let mut pending_recv: Option> = None; + pub async fn new(path: &Path, index: u16, target: &mpsc::Sender) -> Result + where + T: 'static + FromServerMessage + Send, + { + let client = get_socket_rw_stream(path).await?; + let (write_tx, mut write_rx) = mpsc::channel(4); + let read_tx = target.clone(); + tokio::spawn(async move { + let mut read_buf = vec![0; BUFFER_SIZE]; + let mut pending_recv: Option> = None; - // See https://docs.rs/tokio/1.17.0/tokio/net/windows/named_pipe/struct.NamedPipeClient.html#method.ready - // With additional complications. If there's nothing queued to write, we wait for the - // pipe to be readable, or for something to come in. If there is something to - // write, wait until the pipe is either readable or writable. - loop { - let ready_result = if pending_recv.is_none() { - tokio::select! { - msg = write_rx.recv() => match msg { - Some(msg) => { - pending_recv = Some(msg); - client.ready(Interest::READABLE | Interest::WRITABLE).await - }, - None => return - }, - r = client.ready(Interest::READABLE) => r, - } - } else { - client.ready(Interest::READABLE | Interest::WRITABLE).await - }; + // See https://docs.rs/tokio/1.17.0/tokio/net/windows/named_pipe/struct.NamedPipeClient.html#method.ready + // With additional complications. If there's nothing queued to write, we wait for the + // pipe to be readable, or for something to come in. If there is something to + // write, wait until the pipe is either readable or writable. + loop { + let ready_result = if pending_recv.is_none() { + tokio::select! { + msg = write_rx.recv() => match msg { + Some(msg) => { + pending_recv = Some(msg); + client.ready(Interest::READABLE | Interest::WRITABLE).await + }, + None => return + }, + r = client.ready(Interest::READABLE) => r, + } + } else { + client.ready(Interest::READABLE | Interest::WRITABLE).await + }; - let ready = match ready_result { - Ok(r) => r, - Err(_) => return, - }; + let ready = match ready_result { + Ok(r) => r, + Err(_) => return, + }; - if ready.is_readable() { - match client.try_read(&mut read_buf) { - Ok(0) => return, // EOF - Ok(s) => { - let send = read_tx - .send(T::from_server_message(index, &read_buf[..s])) - .await; - if send.is_err() { - return; - } - } - Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => { - continue; - } - Err(_) => return, - } - } + if ready.is_readable() { + match client.try_read(&mut read_buf) { + Ok(0) => return, // EOF + Ok(s) => { + let send = read_tx + .send(T::from_server_message(index, &read_buf[..s])) + .await; + if send.is_err() { + return; + } + } + Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => { + continue; + } + Err(_) => return, + } + } - if let Some(msg) = &pending_recv { - if ready.is_writable() { - match client.try_write(msg) { - Ok(n) if n == msg.len() => pending_recv = None, - Ok(n) => pending_recv = Some(msg[n..].to_vec()), - Err(e) if e.kind() == io::ErrorKind::WouldBlock => { - continue; - } - Err(_) => return, - } - } - } - } - }); + if let Some(msg) = &pending_recv { + if ready.is_writable() { + match client.try_write(msg) { + Ok(n) if n == msg.len() => pending_recv = None, + Ok(n) => pending_recv = Some(msg[n..].to_vec()), + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + continue; + } + Err(_) => return, + } + } + } + } + }); - Ok(ServerBridge { write_tx }) - } + Ok(ServerBridge { write_tx }) + } - pub async fn write(&self, b: Vec) -> std::io::Result<()> { - self.write_tx.send(b).await.ok(); - Ok(()) - } + pub async fn write(&self, b: Vec) -> std::io::Result<()> { + self.write_tx.send(b).await.ok(); + Ok(()) + } - pub async fn close(self) -> std::io::Result<()> { - drop(self.write_tx); - Ok(()) - } + pub async fn close(self) -> std::io::Result<()> { + drop(self.write_tx); + Ok(()) + } } diff --git a/cli/src/tunnels/service.rs b/cli/src/tunnels/service.rs index 2cc3b9c4c89..85a26d64c01 100644 --- a/cli/src/tunnels/service.rs +++ b/cli/src/tunnels/service.rs @@ -16,30 +16,30 @@ pub const SERVICE_LOG_FILE_NAME: &str = "tunnel-service.log"; #[async_trait] pub trait ServiceContainer: Send { - async fn run_service( - &mut self, - log: log::Logger, - launcher_paths: LauncherPaths, - shutdown_rx: oneshot::Receiver<()>, - ) -> Result<(), AnyError>; + async fn run_service( + &mut self, + log: log::Logger, + launcher_paths: LauncherPaths, + shutdown_rx: oneshot::Receiver<()>, + ) -> Result<(), AnyError>; } pub trait ServiceManager { - /// Registers the current executable as a service to run with the given set - /// of arguments. - fn register(&self, exe: PathBuf, args: &[&str]) -> Result<(), AnyError>; + /// Registers the current executable as a service to run with the given set + /// of arguments. + fn register(&self, exe: PathBuf, args: &[&str]) -> Result<(), AnyError>; - /// Runs the service using the given handle. The executable *must not* take - /// any action which may fail prior to calling this to ensure service - /// states may update. - fn run( - &self, - launcher_paths: LauncherPaths, - handle: impl 'static + ServiceContainer, - ) -> Result<(), AnyError>; + /// Runs the service using the given handle. The executable *must not* take + /// any action which may fail prior to calling this to ensure service + /// states may update. + fn run( + &self, + launcher_paths: LauncherPaths, + handle: impl 'static + ServiceContainer, + ) -> Result<(), AnyError>; - /// Unregisters the current executable as a service. - fn unregister(&self) -> Result<(), AnyError>; + /// Unregisters the current executable as a service. + fn unregister(&self) -> Result<(), AnyError>; } #[cfg(target_os = "windows")] @@ -50,32 +50,32 @@ pub type ServiceManagerImpl = UnimplementedServiceManager; #[allow(unreachable_code)] pub fn create_service_manager(log: log::Logger) -> ServiceManagerImpl { - ServiceManagerImpl::new(log) + ServiceManagerImpl::new(log) } pub struct UnimplementedServiceManager(); #[allow(dead_code)] impl UnimplementedServiceManager { - fn new(_log: log::Logger) -> Self { - Self() - } + fn new(_log: log::Logger) -> Self { + Self() + } } impl ServiceManager for UnimplementedServiceManager { - fn register(&self, _exe: PathBuf, _args: &[&str]) -> Result<(), AnyError> { - unimplemented!("Service management is not supported on this platform"); - } + fn register(&self, _exe: PathBuf, _args: &[&str]) -> Result<(), AnyError> { + unimplemented!("Service management is not supported on this platform"); + } - fn run( - &self, - _launcher_paths: LauncherPaths, - _handle: impl 'static + ServiceContainer, - ) -> Result<(), AnyError> { - unimplemented!("Service management is not supported on this platform"); - } + fn run( + &self, + _launcher_paths: LauncherPaths, + _handle: impl 'static + ServiceContainer, + ) -> Result<(), AnyError> { + unimplemented!("Service management is not supported on this platform"); + } - fn unregister(&self) -> Result<(), AnyError> { - unimplemented!("Service management is not supported on this platform"); - } + fn unregister(&self) -> Result<(), AnyError> { + unimplemented!("Service management is not supported on this platform"); + } } diff --git a/cli/src/tunnels/service_windows.rs b/cli/src/tunnels/service_windows.rs index bc118cef5dc..ee467bdccb9 100644 --- a/cli/src/tunnels/service_windows.rs +++ b/cli/src/tunnels/service_windows.rs @@ -8,271 +8,271 @@ use lazy_static::lazy_static; use std::{ffi::OsString, sync::Mutex, thread, time::Duration}; use tokio::sync::oneshot; use windows_service::{ - define_windows_service, - service::{ - ServiceAccess, ServiceControl, ServiceControlAccept, ServiceErrorControl, ServiceExitCode, - ServiceInfo, ServiceStartType, ServiceState, ServiceStatus, ServiceType, - }, - service_control_handler::{self, ServiceControlHandlerResult}, - service_dispatcher, - service_manager::{ServiceManager, ServiceManagerAccess}, + define_windows_service, + service::{ + ServiceAccess, ServiceControl, ServiceControlAccept, ServiceErrorControl, ServiceExitCode, + ServiceInfo, ServiceStartType, ServiceState, ServiceStatus, ServiceType, + }, + service_control_handler::{self, ServiceControlHandlerResult}, + service_dispatcher, + service_manager::{ServiceManager, ServiceManagerAccess}, }; use crate::util::errors::{wrap, AnyError, WindowsNeedsElevation}; use crate::{ - log::{self, FileLogSink}, - state::LauncherPaths, + log::{self, FileLogSink}, + state::LauncherPaths, }; use super::service::{ - ServiceContainer, ServiceManager as CliServiceManager, SERVICE_LOG_FILE_NAME, + ServiceContainer, ServiceManager as CliServiceManager, SERVICE_LOG_FILE_NAME, }; pub struct WindowsService { - log: log::Logger, + log: log::Logger, } const SERVICE_NAME: &str = "code_tunnel"; const SERVICE_TYPE: ServiceType = ServiceType::OWN_PROCESS; impl WindowsService { - pub fn new(log: log::Logger) -> Self { - Self { log } - } + pub fn new(log: log::Logger) -> Self { + Self { log } + } } impl CliServiceManager for WindowsService { - fn register(&self, exe: std::path::PathBuf, args: &[&str]) -> Result<(), AnyError> { - let service_manager = ServiceManager::local_computer( - None::<&str>, - ServiceManagerAccess::CONNECT | ServiceManagerAccess::CREATE_SERVICE, - ) - .map_err(|e| WindowsNeedsElevation(format!("error getting service manager: {}", e)))?; + fn register(&self, exe: std::path::PathBuf, args: &[&str]) -> Result<(), AnyError> { + let service_manager = ServiceManager::local_computer( + None::<&str>, + ServiceManagerAccess::CONNECT | ServiceManagerAccess::CREATE_SERVICE, + ) + .map_err(|e| WindowsNeedsElevation(format!("error getting service manager: {}", e)))?; - let mut service_info = ServiceInfo { - name: OsString::from(SERVICE_NAME), - display_name: OsString::from("VS Code Tunnel"), - service_type: SERVICE_TYPE, - start_type: ServiceStartType::AutoStart, - error_control: ServiceErrorControl::Normal, - executable_path: exe, - launch_arguments: args.iter().map(OsString::from).collect(), - dependencies: vec![], - account_name: None, - account_password: None, - }; + let mut service_info = ServiceInfo { + name: OsString::from(SERVICE_NAME), + display_name: OsString::from("VS Code Tunnel"), + service_type: SERVICE_TYPE, + start_type: ServiceStartType::AutoStart, + error_control: ServiceErrorControl::Normal, + executable_path: exe, + launch_arguments: args.iter().map(OsString::from).collect(), + dependencies: vec![], + account_name: None, + account_password: None, + }; - let existing_service = service_manager.open_service( - SERVICE_NAME, - ServiceAccess::QUERY_STATUS | ServiceAccess::START | ServiceAccess::CHANGE_CONFIG, - ); - let service = if let Ok(service) = existing_service { - service - .change_config(&service_info) - .map_err(|e| wrap(e, "error updating existing service"))?; - service - } else { - loop { - let (username, password) = prompt_credentials()?; - service_info.account_name = Some(format!(".\\{}", username).into()); - service_info.account_password = Some(password.into()); + let existing_service = service_manager.open_service( + SERVICE_NAME, + ServiceAccess::QUERY_STATUS | ServiceAccess::START | ServiceAccess::CHANGE_CONFIG, + ); + let service = if let Ok(service) = existing_service { + service + .change_config(&service_info) + .map_err(|e| wrap(e, "error updating existing service"))?; + service + } else { + loop { + let (username, password) = prompt_credentials()?; + service_info.account_name = Some(format!(".\\{}", username).into()); + service_info.account_password = Some(password.into()); - match service_manager.create_service( - &service_info, - ServiceAccess::CHANGE_CONFIG | ServiceAccess::START, - ) { - Ok(service) => break service, - Err(windows_service::Error::Winapi(e)) if Some(1057) == e.raw_os_error() => { - error!( - self.log, - "Invalid username or password, please try again..." - ); - } - Err(e) => return Err(wrap(e, "error registering service").into()), - } - } - }; + match service_manager.create_service( + &service_info, + ServiceAccess::CHANGE_CONFIG | ServiceAccess::START, + ) { + Ok(service) => break service, + Err(windows_service::Error::Winapi(e)) if Some(1057) == e.raw_os_error() => { + error!( + self.log, + "Invalid username or password, please try again..." + ); + } + Err(e) => return Err(wrap(e, "error registering service").into()), + } + } + }; - service - .set_description("Service that runs `code tunnel` for access on vscode.dev") - .ok(); + service + .set_description("Service that runs `code tunnel` for access on vscode.dev") + .ok(); - info!(self.log, "Successfully registered service..."); + info!(self.log, "Successfully registered service..."); - let status = service - .query_status() - .map(|s| s.current_state) - .unwrap_or(ServiceState::Stopped); + let status = service + .query_status() + .map(|s| s.current_state) + .unwrap_or(ServiceState::Stopped); - if status == ServiceState::Stopped { - service - .start::<&str>(&[]) - .map_err(|e| wrap(e, "error starting service"))?; - } + if status == ServiceState::Stopped { + service + .start::<&str>(&[]) + .map_err(|e| wrap(e, "error starting service"))?; + } - info!(self.log, "Tunnel service successfully started"); - Ok(()) - } + info!(self.log, "Tunnel service successfully started"); + Ok(()) + } - #[allow(unused_must_use)] // triggers incorrectly on `define_windows_service!` - fn run( - &self, - launcher_paths: LauncherPaths, - handle: impl 'static + ServiceContainer, - ) -> Result<(), AnyError> { - let log = match FileLogSink::new( - log::Level::Debug, - &launcher_paths.root().join(SERVICE_LOG_FILE_NAME), - ) { - Ok(sink) => self.log.tee(sink), - Err(e) => { - warning!(self.log, "Failed to create service log file: {}", e); - self.log.clone() - } - }; + #[allow(unused_must_use)] // triggers incorrectly on `define_windows_service!` + fn run( + &self, + launcher_paths: LauncherPaths, + handle: impl 'static + ServiceContainer, + ) -> Result<(), AnyError> { + let log = match FileLogSink::new( + log::Level::Debug, + &launcher_paths.root().join(SERVICE_LOG_FILE_NAME), + ) { + Ok(sink) => self.log.tee(sink), + Err(e) => { + warning!(self.log, "Failed to create service log file: {}", e); + self.log.clone() + } + }; - // We put the handle into the global "impl" type and then take it out in - // my_service_main. This is needed just since we have to have that - // function at the root level, but need to pass in data later here... - SERVICE_IMPL.lock().unwrap().replace(ServiceImpl { - container: Box::new(handle), - launcher_paths, - log, - }); + // We put the handle into the global "impl" type and then take it out in + // my_service_main. This is needed just since we have to have that + // function at the root level, but need to pass in data later here... + SERVICE_IMPL.lock().unwrap().replace(ServiceImpl { + container: Box::new(handle), + launcher_paths, + log, + }); - define_windows_service!(ffi_service_main, service_main); + define_windows_service!(ffi_service_main, service_main); - service_dispatcher::start(SERVICE_NAME, ffi_service_main) - .map_err(|e| wrap(e, "error starting service dispatcher").into()) - } + service_dispatcher::start(SERVICE_NAME, ffi_service_main) + .map_err(|e| wrap(e, "error starting service dispatcher").into()) + } - fn unregister(&self) -> Result<(), AnyError> { - let service_manager = - ServiceManager::local_computer(None::<&str>, ServiceManagerAccess::CONNECT) - .map_err(|e| wrap(e, "error getting service manager"))?; + fn unregister(&self) -> Result<(), AnyError> { + let service_manager = + ServiceManager::local_computer(None::<&str>, ServiceManagerAccess::CONNECT) + .map_err(|e| wrap(e, "error getting service manager"))?; - let service = service_manager.open_service( - SERVICE_NAME, - ServiceAccess::QUERY_STATUS | ServiceAccess::STOP | ServiceAccess::DELETE, - ); + let service = service_manager.open_service( + SERVICE_NAME, + ServiceAccess::QUERY_STATUS | ServiceAccess::STOP | ServiceAccess::DELETE, + ); - let service = match service { - Ok(service) => service, - // Service does not exist: - Err(windows_service::Error::Winapi(e)) if Some(1060) == e.raw_os_error() => { - return Ok(()) - } - Err(e) => return Err(wrap(e, "error getting service handle").into()), - }; + let service = match service { + Ok(service) => service, + // Service does not exist: + Err(windows_service::Error::Winapi(e)) if Some(1060) == e.raw_os_error() => { + return Ok(()) + } + Err(e) => return Err(wrap(e, "error getting service handle").into()), + }; - let service_status = service - .query_status() - .map_err(|e| wrap(e, "error getting service status"))?; + let service_status = service + .query_status() + .map_err(|e| wrap(e, "error getting service status"))?; - if service_status.current_state != ServiceState::Stopped { - service - .stop() - .map_err(|e| wrap(e, "error getting stopping service"))?; + if service_status.current_state != ServiceState::Stopped { + service + .stop() + .map_err(|e| wrap(e, "error getting stopping service"))?; - while let Ok(ServiceState::Stopped) = service.query_status().map(|s| s.current_state) { - info!(self.log, "Polling for service to stop..."); - thread::sleep(Duration::from_secs(1)); - } - } + while let Ok(ServiceState::Stopped) = service.query_status().map(|s| s.current_state) { + info!(self.log, "Polling for service to stop..."); + thread::sleep(Duration::from_secs(1)); + } + } - service - .delete() - .map_err(|e| wrap(e, "error deleting service"))?; + service + .delete() + .map_err(|e| wrap(e, "error deleting service"))?; - Ok(()) - } + Ok(()) + } } struct ServiceImpl { - container: Box, - launcher_paths: LauncherPaths, - log: log::Logger, + container: Box, + launcher_paths: LauncherPaths, + log: log::Logger, } lazy_static! { - static ref SERVICE_IMPL: Mutex> = Mutex::new(None); + static ref SERVICE_IMPL: Mutex> = Mutex::new(None); } /// "main" function that the service calls in its own thread. fn service_main(_arguments: Vec) -> Result<(), AnyError> { - let mut service = SERVICE_IMPL.lock().unwrap().take().unwrap(); + let mut service = SERVICE_IMPL.lock().unwrap().take().unwrap(); - // Create a channel to be able to poll a stop event from the service worker loop. - let (shutdown_tx, shutdown_rx) = oneshot::channel(); - let mut shutdown_tx = Some(shutdown_tx); + // Create a channel to be able to poll a stop event from the service worker loop. + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + let mut shutdown_tx = Some(shutdown_tx); - // Define system service event handler that will be receiving service events. - let event_handler = move |control_event| -> ServiceControlHandlerResult { - match control_event { - ServiceControl::Interrogate => ServiceControlHandlerResult::NoError, - ServiceControl::Stop => { - shutdown_tx.take().and_then(|tx| tx.send(()).ok()); - ServiceControlHandlerResult::NoError - } + // Define system service event handler that will be receiving service events. + let event_handler = move |control_event| -> ServiceControlHandlerResult { + match control_event { + ServiceControl::Interrogate => ServiceControlHandlerResult::NoError, + ServiceControl::Stop => { + shutdown_tx.take().and_then(|tx| tx.send(()).ok()); + ServiceControlHandlerResult::NoError + } - _ => ServiceControlHandlerResult::NotImplemented, - } - }; + _ => ServiceControlHandlerResult::NotImplemented, + } + }; - let status_handle = service_control_handler::register(SERVICE_NAME, event_handler) - .map_err(|e| wrap(e, "error registering service event handler"))?; + let status_handle = service_control_handler::register(SERVICE_NAME, event_handler) + .map_err(|e| wrap(e, "error registering service event handler"))?; - // Tell the system that service is running - status_handle - .set_service_status(ServiceStatus { - service_type: SERVICE_TYPE, - current_state: ServiceState::Running, - controls_accepted: ServiceControlAccept::STOP, - exit_code: ServiceExitCode::Win32(0), - checkpoint: 0, - wait_hint: Duration::default(), - process_id: None, - }) - .map_err(|e| wrap(e, "error marking service as running"))?; + // Tell the system that service is running + status_handle + .set_service_status(ServiceStatus { + service_type: SERVICE_TYPE, + current_state: ServiceState::Running, + controls_accepted: ServiceControlAccept::STOP, + exit_code: ServiceExitCode::Win32(0), + checkpoint: 0, + wait_hint: Duration::default(), + process_id: None, + }) + .map_err(|e| wrap(e, "error marking service as running"))?; - let result = tokio::runtime::Builder::new_multi_thread() - .enable_all() - .build() - .unwrap() - .block_on( - service - .container - .run_service(service.log, service.launcher_paths, shutdown_rx), - ); + let result = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap() + .block_on( + service + .container + .run_service(service.log, service.launcher_paths, shutdown_rx), + ); - status_handle - .set_service_status(ServiceStatus { - service_type: SERVICE_TYPE, - current_state: ServiceState::Stopped, - controls_accepted: ServiceControlAccept::empty(), - exit_code: ServiceExitCode::Win32(0), - checkpoint: 0, - wait_hint: Duration::default(), - process_id: None, - }) - .map_err(|e| wrap(e, "error marking service as stopped"))?; + status_handle + .set_service_status(ServiceStatus { + service_type: SERVICE_TYPE, + current_state: ServiceState::Stopped, + controls_accepted: ServiceControlAccept::empty(), + exit_code: ServiceExitCode::Win32(0), + checkpoint: 0, + wait_hint: Duration::default(), + process_id: None, + }) + .map_err(|e| wrap(e, "error marking service as stopped"))?; - result + result } fn prompt_credentials() -> Result<(String, String), AnyError> { - println!("Running a Windows service under your user requires your username and password."); - println!("These are sent to the Windows Service Manager and are not stored by VS Code."); + println!("Running a Windows service under your user requires your username and password."); + println!("These are sent to the Windows Service Manager and are not stored by VS Code."); - let username: String = Input::with_theme(&ColorfulTheme::default()) - .with_prompt("Windows username:") - .interact_text() - .map_err(|e| wrap(e, "Failed to read username"))?; + let username: String = Input::with_theme(&ColorfulTheme::default()) + .with_prompt("Windows username:") + .interact_text() + .map_err(|e| wrap(e, "Failed to read username"))?; - let password = Password::with_theme(&ColorfulTheme::default()) - .with_prompt("Windows password:") - .interact() - .map_err(|e| wrap(e, "Failed to read password"))?; + let password = Password::with_theme(&ColorfulTheme::default()) + .with_prompt("Windows password:") + .interact() + .map_err(|e| wrap(e, "Failed to read password"))?; - Ok((username, password)) + Ok((username, password)) } diff --git a/cli/src/update.rs b/cli/src/update.rs index b99b0eb44c7..2863b632490 100644 --- a/cli/src/update.rs +++ b/cli/src/update.rs @@ -7,114 +7,114 @@ use crate::constants::{VSCODE_CLI_ASSET_NAME, VSCODE_CLI_VERSION}; use crate::util::{errors, http, io::SilentCopyProgress}; use serde::Deserialize; use std::{ - fs::{rename, set_permissions}, - path::Path, + fs::{rename, set_permissions}, + path::Path, }; pub struct Update { - client: reqwest::Client, + client: reqwest::Client, } const LATEST_URL: &str = "https://aka.ms/vscode-server-launcher/update"; impl Default for Update { - fn default() -> Self { - Self::new() - } + fn default() -> Self { + Self::new() + } } impl Update { - // Creates a new Update instance without authentication - pub fn new() -> Update { - Update { - client: reqwest::Client::new(), - } - } + // Creates a new Update instance without authentication + pub fn new() -> Update { + Update { + client: reqwest::Client::new(), + } + } - // Gets the asset to update to, or None if the current launcher is up to date. - pub async fn get_latest_release(&self) -> Result { - let res = self - .client - .get(LATEST_URL) - .header( - "User-Agent", - format!( - "vscode-server-launcher/{}", - VSCODE_CLI_VERSION.unwrap_or("dev") - ), - ) - .send() - .await?; + // Gets the asset to update to, or None if the current launcher is up to date. + pub async fn get_latest_release(&self) -> Result { + let res = self + .client + .get(LATEST_URL) + .header( + "User-Agent", + format!( + "vscode-server-launcher/{}", + VSCODE_CLI_VERSION.unwrap_or("dev") + ), + ) + .send() + .await?; - if !res.status().is_success() { - return Err(errors::StatusError::from_res(res).await?.into()); - } + if !res.status().is_success() { + return Err(errors::StatusError::from_res(res).await?.into()); + } - Ok(res.json::().await?) - } + Ok(res.json::().await?) + } - pub async fn switch_to_release( - &self, - update: &LauncherRelease, - target_path: &Path, - ) -> Result<(), errors::AnyError> { - let mut staging_path = target_path.to_owned(); - staging_path.set_file_name(format!( - "{}.next", - target_path.file_name().unwrap().to_string_lossy() - )); + pub async fn switch_to_release( + &self, + update: &LauncherRelease, + target_path: &Path, + ) -> Result<(), errors::AnyError> { + let mut staging_path = target_path.to_owned(); + staging_path.set_file_name(format!( + "{}.next", + target_path.file_name().unwrap().to_string_lossy() + )); - let an = VSCODE_CLI_ASSET_NAME.unwrap(); - let mut url = format!("{}/{}/{}", update.url, an, an); - if cfg!(target_os = "windows") { - url += ".exe"; - } + let an = VSCODE_CLI_ASSET_NAME.unwrap(); + let mut url = format!("{}/{}/{}", update.url, an, an); + if cfg!(target_os = "windows") { + url += ".exe"; + } - let res = self.client.get(url).send().await?; + let res = self.client.get(url).send().await?; - if !res.status().is_success() { - return Err(errors::StatusError::from_res(res).await?.into()); - } + if !res.status().is_success() { + return Err(errors::StatusError::from_res(res).await?.into()); + } - http::download_into_file(&staging_path, SilentCopyProgress(), res).await?; + http::download_into_file(&staging_path, SilentCopyProgress(), res).await?; - copy_file_metadata(target_path, &staging_path) - .map_err(|e| errors::wrap(e, "failed to set file permissions"))?; + copy_file_metadata(target_path, &staging_path) + .map_err(|e| errors::wrap(e, "failed to set file permissions"))?; - rename(&staging_path, &target_path) - .map_err(|e| errors::wrap(e, "failed to copy new launcher version"))?; + rename(&staging_path, &target_path) + .map_err(|e| errors::wrap(e, "failed to copy new launcher version"))?; - Ok(()) - } + Ok(()) + } } #[derive(Deserialize, Clone)] pub struct LauncherRelease { - pub version: String, - pub url: String, - pub released_at: u64, + pub version: String, + pub url: String, + pub released_at: u64, } #[cfg(target_os = "windows")] fn copy_file_metadata(from: &Path, to: &Path) -> Result<(), std::io::Error> { - let permissions = from.metadata()?.permissions(); - set_permissions(&to, permissions)?; - Ok(()) + let permissions = from.metadata()?.permissions(); + set_permissions(&to, permissions)?; + Ok(()) } #[cfg(not(target_os = "windows"))] fn copy_file_metadata(from: &Path, to: &Path) -> Result<(), std::io::Error> { - use std::os::unix::ffi::OsStrExt; - use std::os::unix::fs::MetadataExt; + use std::os::unix::ffi::OsStrExt; + use std::os::unix::fs::MetadataExt; - let metadata = from.metadata()?; - set_permissions(&to, metadata.permissions())?; + let metadata = from.metadata()?; + set_permissions(&to, metadata.permissions())?; - // based on coreutils' chown https://github.com/uutils/coreutils/blob/72b4629916abe0852ad27286f4e307fbca546b6e/src/chown/chown.rs#L266-L281 - let s = std::ffi::CString::new(to.as_os_str().as_bytes()).unwrap(); - let ret = unsafe { libc::chown(s.as_ptr(), metadata.uid(), metadata.gid()) }; - if ret != 0 { - return Err(std::io::Error::last_os_error()); - } + // based on coreutils' chown https://github.com/uutils/coreutils/blob/72b4629916abe0852ad27286f4e307fbca546b6e/src/chown/chown.rs#L266-L281 + let s = std::ffi::CString::new(to.as_os_str().as_bytes()).unwrap(); + let ret = unsafe { libc::chown(s.as_ptr(), metadata.uid(), metadata.gid()) }; + if ret != 0 { + return Err(std::io::Error::last_os_error()); + } - Ok(()) + Ok(()) } diff --git a/cli/src/update_service.rs b/cli/src/update_service.rs index e2513f6ab80..f0d0acfeb6c 100644 --- a/cli/src/update_service.rs +++ b/cli/src/update_service.rs @@ -8,266 +8,266 @@ use std::path::Path; use serde::Deserialize; use crate::{ - constants::VSCODE_CLI_UPDATE_ENDPOINT, - debug, log, options, spanf, - util::{ - errors::{ - AnyError, StatusError, UnsupportedPlatformError, UpdatesNotConfigured, WrappedError, - }, - io::ReportCopyProgress, - }, + constants::VSCODE_CLI_UPDATE_ENDPOINT, + debug, log, options, spanf, + util::{ + errors::{ + AnyError, StatusError, UnsupportedPlatformError, UpdatesNotConfigured, WrappedError, + }, + io::ReportCopyProgress, + }, }; /// Implementation of the VS Code Update service for use in the CLI. pub struct UpdateService { - client: reqwest::Client, - log: log::Logger, + client: reqwest::Client, + log: log::Logger, } /// Describes a specific release, can be created manually or returned from the update service. pub struct Release { - pub platform: Platform, - pub target: TargetKind, - pub quality: options::Quality, - pub commit: String, + pub platform: Platform, + pub target: TargetKind, + pub quality: options::Quality, + pub commit: String, } #[derive(Deserialize)] struct UpdateServerVersion { - pub version: String, + pub version: String, } fn quality_download_segment(quality: options::Quality) -> &'static str { - match quality { - options::Quality::Stable => "stable", - options::Quality::Insiders => "insider", - options::Quality::Exploration => "exploration", - } + match quality { + options::Quality::Stable => "stable", + options::Quality::Insiders => "insider", + options::Quality::Exploration => "exploration", + } } impl UpdateService { - pub fn new(log: log::Logger, client: reqwest::Client) -> Self { - UpdateService { client, log } - } + pub fn new(log: log::Logger, client: reqwest::Client) -> Self { + UpdateService { client, log } + } - pub async fn get_release_by_semver_version( - &self, - platform: Platform, - target: TargetKind, - quality: options::Quality, - version: &str, - ) -> Result { - let update_endpoint = VSCODE_CLI_UPDATE_ENDPOINT.ok_or(UpdatesNotConfigured())?; - let download_segment = target - .download_segment(platform) - .ok_or(UnsupportedPlatformError())?; - let download_url = format!( - "{}/api/versions/{}/{}/{}", - update_endpoint, - version, - download_segment, - quality_download_segment(quality), - ); + pub async fn get_release_by_semver_version( + &self, + platform: Platform, + target: TargetKind, + quality: options::Quality, + version: &str, + ) -> Result { + let update_endpoint = VSCODE_CLI_UPDATE_ENDPOINT.ok_or(UpdatesNotConfigured())?; + let download_segment = target + .download_segment(platform) + .ok_or(UnsupportedPlatformError())?; + let download_url = format!( + "{}/api/versions/{}/{}/{}", + update_endpoint, + version, + download_segment, + quality_download_segment(quality), + ); - let response = spanf!( - self.log, - self.log.span("server.version.resolve"), - self.client.get(download_url).send() - )?; + let response = spanf!( + self.log, + self.log.span("server.version.resolve"), + self.client.get(download_url).send() + )?; - if !response.status().is_success() { - return Err(StatusError::from_res(response).await?.into()); - } + if !response.status().is_success() { + return Err(StatusError::from_res(response).await?.into()); + } - let res = response.json::().await?; - debug!(self.log, "Resolved version {} to {}", version, res.version); + let res = response.json::().await?; + debug!(self.log, "Resolved version {} to {}", version, res.version); - Ok(Release { - target, - platform, - quality, - commit: res.version, - }) - } + Ok(Release { + target, + platform, + quality, + commit: res.version, + }) + } - /// Gets the latest commit for the target of the given quality. - pub async fn get_latest_commit( - &self, - platform: Platform, - target: TargetKind, - quality: options::Quality, - ) -> Result { - let update_endpoint = VSCODE_CLI_UPDATE_ENDPOINT.ok_or(UpdatesNotConfigured())?; - let download_segment = target - .download_segment(platform) - .ok_or(UnsupportedPlatformError())?; - let download_url = format!( - "{}/api/latest/{}/{}", - update_endpoint, - download_segment, - quality_download_segment(quality), - ); + /// Gets the latest commit for the target of the given quality. + pub async fn get_latest_commit( + &self, + platform: Platform, + target: TargetKind, + quality: options::Quality, + ) -> Result { + let update_endpoint = VSCODE_CLI_UPDATE_ENDPOINT.ok_or(UpdatesNotConfigured())?; + let download_segment = target + .download_segment(platform) + .ok_or(UnsupportedPlatformError())?; + let download_url = format!( + "{}/api/latest/{}/{}", + update_endpoint, + download_segment, + quality_download_segment(quality), + ); - let response = spanf!( - self.log, - self.log.span("server.version.resolve"), - self.client.get(download_url).send() - )?; + let response = spanf!( + self.log, + self.log.span("server.version.resolve"), + self.client.get(download_url).send() + )?; - if !response.status().is_success() { - return Err(StatusError::from_res(response).await?.into()); - } + if !response.status().is_success() { + return Err(StatusError::from_res(response).await?.into()); + } - let res = response.json::().await?; - debug!(self.log, "Resolved quality {} to {}", quality, res.version); + let res = response.json::().await?; + debug!(self.log, "Resolved quality {} to {}", quality, res.version); - Ok(Release { - target, - platform, - quality, - commit: res.version, - }) - } + Ok(Release { + target, + platform, + quality, + commit: res.version, + }) + } - /// Gets the download stream for the release. - pub async fn get_download_stream( - &self, - release: &Release, - ) -> Result { - let update_endpoint = VSCODE_CLI_UPDATE_ENDPOINT.ok_or(UpdatesNotConfigured())?; - let download_segment = release - .target - .download_segment(release.platform) - .ok_or(UnsupportedPlatformError())?; + /// Gets the download stream for the release. + pub async fn get_download_stream( + &self, + release: &Release, + ) -> Result { + let update_endpoint = VSCODE_CLI_UPDATE_ENDPOINT.ok_or(UpdatesNotConfigured())?; + let download_segment = release + .target + .download_segment(release.platform) + .ok_or(UnsupportedPlatformError())?; - let download_url = format!( - "{}/commit:{}/{}/{}", - update_endpoint, - release.commit, - download_segment, - quality_download_segment(release.quality), - ); + let download_url = format!( + "{}/commit:{}/{}/{}", + update_endpoint, + release.commit, + download_segment, + quality_download_segment(release.quality), + ); - let response = reqwest::get(&download_url).await?; - if !response.status().is_success() { - return Err(StatusError::from_res(response).await?.into()); - } + let response = reqwest::get(&download_url).await?; + if !response.status().is_success() { + return Err(StatusError::from_res(response).await?.into()); + } - Ok(response) - } + Ok(response) + } } pub fn unzip_downloaded_release( - compressed_file: &Path, - target_dir: &Path, - reporter: T, + compressed_file: &Path, + target_dir: &Path, + reporter: T, ) -> Result<(), WrappedError> where - T: ReportCopyProgress, + T: ReportCopyProgress, { - #[cfg(any(target_os = "windows", target_os = "macos"))] - { - use crate::util::zipper; - zipper::unzip_file(compressed_file, target_dir, reporter) - } - #[cfg(target_os = "linux")] - { - use crate::util::tar; - tar::decompress_tarball(compressed_file, target_dir, reporter) - } + #[cfg(any(target_os = "windows", target_os = "macos"))] + { + use crate::util::zipper; + zipper::unzip_file(compressed_file, target_dir, reporter) + } + #[cfg(target_os = "linux")] + { + use crate::util::tar; + tar::decompress_tarball(compressed_file, target_dir, reporter) + } } #[derive(Eq, PartialEq, Copy, Clone)] pub enum TargetKind { - Server, - Archive, - Web, + Server, + Archive, + Web, } impl TargetKind { - fn download_segment(&self, platform: Platform) -> Option { - match *self { - TargetKind::Server => Some(platform.headless()), - TargetKind::Archive => platform.archive(), - TargetKind::Web => Some(platform.web()), - } - } + fn download_segment(&self, platform: Platform) -> Option { + match *self { + TargetKind::Server => Some(platform.headless()), + TargetKind::Archive => platform.archive(), + TargetKind::Web => Some(platform.web()), + } + } } #[derive(Debug, Copy, Clone)] pub enum Platform { - LinuxAlpineX64, - LinuxAlpineARM64, - LinuxX64, - LinuxARM64, - LinuxARM32, - DarwinX64, - DarwinARM64, - WindowsX64, - WindowsX86, + LinuxAlpineX64, + LinuxAlpineARM64, + LinuxX64, + LinuxARM64, + LinuxARM32, + DarwinX64, + DarwinARM64, + WindowsX64, + WindowsX86, } impl Platform { - pub fn archive(&self) -> Option { - match self { - Platform::LinuxX64 => Some("linux-x64".to_owned()), - Platform::LinuxARM64 => Some("linux-arm64".to_owned()), - Platform::LinuxARM32 => Some("linux-armhf".to_owned()), - Platform::DarwinX64 => Some("darwin".to_owned()), - Platform::DarwinARM64 => Some("darwin-arm64".to_owned()), - Platform::WindowsX64 => Some("win32-x64-archive".to_owned()), - Platform::WindowsX86 => Some("win32-archive".to_owned()), - _ => None, - } - } - pub fn headless(&self) -> String { - match self { - Platform::LinuxAlpineARM64 => "server-alpine-arm64", - Platform::LinuxAlpineX64 => "server-linux-alpine", - Platform::LinuxX64 => "server-linux-x64", - Platform::LinuxARM64 => "server-linux-arm64", - Platform::LinuxARM32 => "server-linux-armhf", - Platform::DarwinX64 => "server-darwin", - Platform::DarwinARM64 => "server-darwin-arm64", - Platform::WindowsX64 => "server-win32-x64", - Platform::WindowsX86 => "server-win32", - } - .to_owned() - } + pub fn archive(&self) -> Option { + match self { + Platform::LinuxX64 => Some("linux-x64".to_owned()), + Platform::LinuxARM64 => Some("linux-arm64".to_owned()), + Platform::LinuxARM32 => Some("linux-armhf".to_owned()), + Platform::DarwinX64 => Some("darwin".to_owned()), + Platform::DarwinARM64 => Some("darwin-arm64".to_owned()), + Platform::WindowsX64 => Some("win32-x64-archive".to_owned()), + Platform::WindowsX86 => Some("win32-archive".to_owned()), + _ => None, + } + } + pub fn headless(&self) -> String { + match self { + Platform::LinuxAlpineARM64 => "server-alpine-arm64", + Platform::LinuxAlpineX64 => "server-linux-alpine", + Platform::LinuxX64 => "server-linux-x64", + Platform::LinuxARM64 => "server-linux-arm64", + Platform::LinuxARM32 => "server-linux-armhf", + Platform::DarwinX64 => "server-darwin", + Platform::DarwinARM64 => "server-darwin-arm64", + Platform::WindowsX64 => "server-win32-x64", + Platform::WindowsX86 => "server-win32", + } + .to_owned() + } - pub fn web(&self) -> String { - format!("{}-web", self.headless()) - } + pub fn web(&self) -> String { + format!("{}-web", self.headless()) + } - pub fn env_default() -> Option { - if cfg!(all( - target_os = "linux", - target_arch = "x86_64", - target_env = "musl" - )) { - Some(Platform::LinuxAlpineX64) - } else if cfg!(all( - target_os = "linux", - target_arch = "aarch64", - target_env = "musl" - )) { - Some(Platform::LinuxAlpineARM64) - } else if cfg!(all(target_os = "linux", target_arch = "x86_64")) { - Some(Platform::LinuxX64) - } else if cfg!(all(target_os = "linux", target_arch = "armhf")) { - Some(Platform::LinuxARM32) - } else if cfg!(all(target_os = "linux", target_arch = "aarch64")) { - Some(Platform::LinuxARM64) - } else if cfg!(all(target_os = "macos", target_arch = "x86_64")) { - Some(Platform::DarwinX64) - } else if cfg!(all(target_os = "macos", target_arch = "aarch64")) { - Some(Platform::DarwinARM64) - } else if cfg!(all(target_os = "windows", target_arch = "x86_64")) { - Some(Platform::WindowsX64) - } else if cfg!(all(target_os = "windows", target_arch = "x86")) { - Some(Platform::WindowsX86) - } else { - None - } - } + pub fn env_default() -> Option { + if cfg!(all( + target_os = "linux", + target_arch = "x86_64", + target_env = "musl" + )) { + Some(Platform::LinuxAlpineX64) + } else if cfg!(all( + target_os = "linux", + target_arch = "aarch64", + target_env = "musl" + )) { + Some(Platform::LinuxAlpineARM64) + } else if cfg!(all(target_os = "linux", target_arch = "x86_64")) { + Some(Platform::LinuxX64) + } else if cfg!(all(target_os = "linux", target_arch = "armhf")) { + Some(Platform::LinuxARM32) + } else if cfg!(all(target_os = "linux", target_arch = "aarch64")) { + Some(Platform::LinuxARM64) + } else if cfg!(all(target_os = "macos", target_arch = "x86_64")) { + Some(Platform::DarwinX64) + } else if cfg!(all(target_os = "macos", target_arch = "aarch64")) { + Some(Platform::DarwinARM64) + } else if cfg!(all(target_os = "windows", target_arch = "x86_64")) { + Some(Platform::WindowsX64) + } else if cfg!(all(target_os = "windows", target_arch = "x86")) { + Some(Platform::WindowsX86) + } else { + None + } + } } diff --git a/cli/src/util/command.rs b/cli/src/util/command.rs index 61ead045a20..25b2116268b 100644 --- a/cli/src/util/command.rs +++ b/cli/src/util/command.rs @@ -7,71 +7,71 @@ use std::{ffi::OsStr, process::Stdio}; use tokio::process::Command; pub async fn capture_command( - command_str: A, - args: I, + command_str: A, + args: I, ) -> Result where - A: AsRef, - I: IntoIterator, - S: AsRef, + A: AsRef, + I: IntoIterator, + S: AsRef, { - Command::new(&command_str) - .args(args) - .stdin(Stdio::null()) - .stdout(Stdio::piped()) - .output() - .await - .map_err(|e| { - wrap( - e, - format!( - "failed to execute command '{}'", - (&command_str).as_ref().to_string_lossy() - ), - ) - }) + Command::new(&command_str) + .args(args) + .stdin(Stdio::null()) + .stdout(Stdio::piped()) + .output() + .await + .map_err(|e| { + wrap( + e, + format!( + "failed to execute command '{}'", + (&command_str).as_ref().to_string_lossy() + ), + ) + }) } /// Kills and processes and all of its children. #[cfg(target_os = "windows")] pub async fn kill_tree(process_id: u32) -> Result<(), WrappedError> { - capture_command("taskkill", &["/t", "/pid", &process_id.to_string()]).await?; - Ok(()) + capture_command("taskkill", &["/t", "/pid", &process_id.to_string()]).await?; + Ok(()) } /// Kills and processes and all of its children. #[cfg(not(target_os = "windows"))] pub async fn kill_tree(process_id: u32) -> Result<(), WrappedError> { - use futures::future::join_all; - use tokio::io::{AsyncBufReadExt, BufReader}; + use futures::future::join_all; + use tokio::io::{AsyncBufReadExt, BufReader}; - async fn kill_single_pid(process_id_str: String) { - capture_command("kill", &[&process_id_str]).await.ok(); - } + async fn kill_single_pid(process_id_str: String) { + capture_command("kill", &[&process_id_str]).await.ok(); + } - // Rusty version of https://github.com/microsoft/vscode-js-debug/blob/main/src/targets/node/terminateProcess.sh + // Rusty version of https://github.com/microsoft/vscode-js-debug/blob/main/src/targets/node/terminateProcess.sh - let parent_id = process_id.to_string(); - let mut prgrep_cmd = Command::new("pgrep") - .arg("-P") - .arg(&parent_id) - .stdin(Stdio::null()) - .stdout(Stdio::piped()) - .spawn() - .map_err(|e| wrap(e, "error enumerating process tree"))?; + let parent_id = process_id.to_string(); + let mut prgrep_cmd = Command::new("pgrep") + .arg("-P") + .arg(&parent_id) + .stdin(Stdio::null()) + .stdout(Stdio::piped()) + .spawn() + .map_err(|e| wrap(e, "error enumerating process tree"))?; - let mut kill_futures = vec![tokio::spawn( - async move { kill_single_pid(parent_id).await }, - )]; + let mut kill_futures = vec![tokio::spawn( + async move { kill_single_pid(parent_id).await }, + )]; - if let Some(stdout) = prgrep_cmd.stdout.take() { - let mut reader = BufReader::new(stdout).lines(); - while let Some(line) = reader.next_line().await.unwrap_or(None) { - kill_futures.push(tokio::spawn(async move { kill_single_pid(line).await })) - } - } + if let Some(stdout) = prgrep_cmd.stdout.take() { + let mut reader = BufReader::new(stdout).lines(); + while let Some(line) = reader.next_line().await.unwrap_or(None) { + kill_futures.push(tokio::spawn(async move { kill_single_pid(line).await })) + } + } - join_all(kill_futures).await; - prgrep_cmd.kill().await.ok(); - Ok(()) + join_all(kill_futures).await; + prgrep_cmd.kill().await.ok(); + Ok(()) } diff --git a/cli/src/util/errors.rs b/cli/src/util/errors.rs index 47f6a64e359..e7abac0d224 100644 --- a/cli/src/util/errors.rs +++ b/cli/src/util/errors.rs @@ -9,89 +9,89 @@ use crate::constants::CONTROL_PORT; // Wraps another error with additional info. #[derive(Debug, Clone)] pub struct WrappedError { - message: String, - original: String, + message: String, + original: String, } impl std::fmt::Display for WrappedError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{}: {}", self.message, self.original) - } + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}: {}", self.message, self.original) + } } impl std::error::Error for WrappedError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - None - } + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + None + } } impl WrappedError { - // fn new(original: Box, message: String) -> WrappedError { - // WrappedError { message, original } - // } + // fn new(original: Box, message: String) -> WrappedError { + // WrappedError { message, original } + // } } impl From for WrappedError { - fn from(e: reqwest::Error) -> WrappedError { - WrappedError { - message: format!( - "error requesting {}", - e.url().map_or("", |u| u.as_str()) - ), - original: format!("{}", e), - } - } + fn from(e: reqwest::Error) -> WrappedError { + WrappedError { + message: format!( + "error requesting {}", + e.url().map_or("", |u| u.as_str()) + ), + original: format!("{}", e), + } + } } pub fn wrap(original: T, message: S) -> WrappedError where - T: Display, - S: Into, + T: Display, + S: Into, { - WrappedError { - message: message.into(), - original: format!("{}", original), - } + WrappedError { + message: message.into(), + original: format!("{}", original), + } } // Error generated by an unsuccessful HTTP response #[derive(Debug)] pub struct StatusError { - url: String, - status_code: u16, - body: String, + url: String, + status_code: u16, + body: String, } impl std::fmt::Display for StatusError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!( - f, - "error requesting {}: {} {}", - self.url, self.status_code, self.body - ) - } + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "error requesting {}: {} {}", + self.url, self.status_code, self.body + ) + } } impl StatusError { - pub async fn from_res(res: reqwest::Response) -> Result { - let status_code = res.status().as_u16(); - let url = res.url().to_string(); - let body = res.text().await.map_err(|e| { - wrap( - e, - format!( - "failed to read response body on {} code from {}", - status_code, url - ), - ) - })?; + pub async fn from_res(res: reqwest::Response) -> Result { + let status_code = res.status().as_u16(); + let url = res.url().to_string(); + let body = res.text().await.map_err(|e| { + wrap( + e, + format!( + "failed to read response body on {} code from {}", + status_code, url + ), + ) + })?; - Ok(StatusError { - url, - status_code, - body, - }) - } + Ok(StatusError { + url, + status_code, + body, + }) + } } // When the user has not consented to the licensing terms in using the Launcher @@ -99,9 +99,9 @@ impl StatusError { pub struct MissingLegalConsent(pub String); impl std::fmt::Display for MissingLegalConsent { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{}", self.0) - } + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}", self.0) + } } // When the provided connection token doesn't match the one used to set up the original VS Code Server @@ -110,9 +110,9 @@ impl std::fmt::Display for MissingLegalConsent { pub struct MismatchConnectionToken(pub String); impl std::fmt::Display for MismatchConnectionToken { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{}", self.0) - } + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}", self.0) + } } // When the VS Code server has an unrecognized extension (rather than zip or gz) @@ -120,9 +120,9 @@ impl std::fmt::Display for MismatchConnectionToken { pub struct InvalidServerExtensionError(pub String); impl std::fmt::Display for InvalidServerExtensionError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "invalid server extension '{}'", self.0) - } + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "invalid server extension '{}'", self.0) + } } // When the tunnel fails to open @@ -130,15 +130,15 @@ impl std::fmt::Display for InvalidServerExtensionError { pub struct DevTunnelError(pub String); impl std::fmt::Display for DevTunnelError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "could not open tunnel: {}", self.0) - } + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "could not open tunnel: {}", self.0) + } } impl std::error::Error for DevTunnelError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - None - } + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + None + } } // When the server was downloaded, but the entrypoint scripts don't exist. @@ -146,214 +146,214 @@ impl std::error::Error for DevTunnelError { pub struct MissingEntrypointError(); impl std::fmt::Display for MissingEntrypointError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "Missing entrypoints in server download. Most likely this is a corrupted download. Please retry") - } + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "Missing entrypoints in server download. Most likely this is a corrupted download. Please retry") + } } #[derive(Debug)] pub struct SetupError(pub String); impl std::fmt::Display for SetupError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!( - f, - "{}\r\n\r\nMore info at https://code.visualstudio.com/docs/remote/linux", - self.0 - ) - } + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "{}\r\n\r\nMore info at https://code.visualstudio.com/docs/remote/linux", + self.0 + ) + } } #[derive(Debug)] pub struct NoHomeForLauncherError(); impl std::fmt::Display for NoHomeForLauncherError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!( + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( f, "No $HOME variable was found in your environment. Either set it, or specify a `--data-dir` manually when invoking the launcher.", ) - } + } } #[derive(Debug)] pub struct InvalidTunnelName(pub String); impl std::fmt::Display for InvalidTunnelName { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{}", &self.0) - } + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}", &self.0) + } } #[derive(Debug)] pub struct TunnelCreationFailed(pub String, pub String); impl std::fmt::Display for TunnelCreationFailed { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!( - f, - "Could not create tunnel with name: {}\nReason: {}", - &self.0, &self.1 - ) - } + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "Could not create tunnel with name: {}\nReason: {}", + &self.0, &self.1 + ) + } } #[derive(Debug)] pub struct TunnelHostFailed(pub String); impl std::fmt::Display for TunnelHostFailed { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{}", &self.0) - } + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}", &self.0) + } } #[derive(Debug)] pub struct ExtensionInstallFailed(pub String); impl std::fmt::Display for ExtensionInstallFailed { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "Extension install failed: {}", &self.0) - } + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "Extension install failed: {}", &self.0) + } } #[derive(Debug)] pub struct MismatchedLaunchModeError(); impl std::fmt::Display for MismatchedLaunchModeError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "A server is already running, but it was not launched in the same listening mode (port vs. socket) as this request") - } + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "A server is already running, but it was not launched in the same listening mode (port vs. socket) as this request") + } } #[derive(Debug)] pub struct NoAttachedServerError(); impl std::fmt::Display for NoAttachedServerError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "No server is running") - } + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "No server is running") + } } #[derive(Debug)] pub struct ServerWriteError(); impl std::fmt::Display for ServerWriteError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "Error writing to the server, it should be restarted") - } + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "Error writing to the server, it should be restarted") + } } #[derive(Debug)] pub struct RefreshTokenNotAvailableError(); impl std::fmt::Display for RefreshTokenNotAvailableError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "Refresh token not available, authentication is required") - } + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "Refresh token not available, authentication is required") + } } #[derive(Debug)] pub struct UnsupportedPlatformError(); impl std::fmt::Display for UnsupportedPlatformError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!( - f, - "This operation is not supported on your current platform" - ) - } + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "This operation is not supported on your current platform" + ) + } } #[derive(Debug)] pub struct NoInstallInUserProvidedPath(pub String); impl std::fmt::Display for NoInstallInUserProvidedPath { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!( + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( f, "No VS Code installation could be found in {}. You can run `code --use-quality=stable` to switch to the latest stable version of VS Code.", self.0 ) - } + } } #[derive(Debug)] pub struct InvalidRequestedVersion(); impl std::fmt::Display for InvalidRequestedVersion { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!( + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( f, "The reqested version is invalid, expected one of 'stable', 'insiders', version number (x.y.z), or absolute path.", ) - } + } } #[derive(Debug)] pub struct UserCancelledInstallation(); impl std::fmt::Display for UserCancelledInstallation { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "Installation aborted.") - } + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "Installation aborted.") + } } #[derive(Debug)] pub struct CannotForwardControlPort(); impl std::fmt::Display for CannotForwardControlPort { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "Cannot forward or unforward port {}.", CONTROL_PORT) - } + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "Cannot forward or unforward port {}.", CONTROL_PORT) + } } #[derive(Debug)] pub struct ServerHasClosed(); impl std::fmt::Display for ServerHasClosed { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "Request cancelled because the server has closed") - } + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "Request cancelled because the server has closed") + } } #[derive(Debug)] pub struct UpdatesNotConfigured(); impl std::fmt::Display for UpdatesNotConfigured { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "Update service is not configured") - } + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "Update service is not configured") + } } #[derive(Debug)] pub struct ServiceAlreadyRegistered(); impl std::fmt::Display for ServiceAlreadyRegistered { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "Already registered the service. Run `code tunnel service uninstall` to unregister it first") - } + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "Already registered the service. Run `code tunnel service uninstall` to unregister it first") + } } #[derive(Debug)] pub struct WindowsNeedsElevation(pub String); impl std::fmt::Display for WindowsNeedsElevation { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - writeln!(f, "{}", self.0)?; - writeln!(f)?; - writeln!(f, "You may need to run this command as an administrator:")?; - writeln!(f, " 1. Open the start menu and search for Powershell")?; - writeln!(f, " 2. Right click and 'Run as administrator'")?; - if let Ok(exe) = std::env::current_exe() { - writeln!( - f, - " 3. Run &'{}' '{}'", - exe.display(), - std::env::args().skip(1).collect::>().join("' '") - ) - } else { - writeln!(f, " 3. Run the same command again",) - } - } + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + writeln!(f, "{}", self.0)?; + writeln!(f)?; + writeln!(f, "You may need to run this command as an administrator:")?; + writeln!(f, " 1. Open the start menu and search for Powershell")?; + writeln!(f, " 2. Right click and 'Run as administrator'")?; + if let Ok(exe) = std::env::current_exe() { + writeln!( + f, + " 3. Run &'{}' '{}'", + exe.display(), + std::env::args().skip(1).collect::>().join("' '") + ) + } else { + writeln!(f, " 3. Run the same command again",) + } + } } // Makes an "AnyError" enum that contains any of the given errors, in the form @@ -392,36 +392,36 @@ macro_rules! makeAnyError { } makeAnyError!( - MissingLegalConsent, - MismatchConnectionToken, - DevTunnelError, - StatusError, - WrappedError, - InvalidServerExtensionError, - MissingEntrypointError, - SetupError, - NoHomeForLauncherError, - TunnelCreationFailed, - TunnelHostFailed, - InvalidTunnelName, - ExtensionInstallFailed, - MismatchedLaunchModeError, - NoAttachedServerError, - ServerWriteError, - UnsupportedPlatformError, - RefreshTokenNotAvailableError, - NoInstallInUserProvidedPath, - UserCancelledInstallation, - InvalidRequestedVersion, - CannotForwardControlPort, - ServerHasClosed, - ServiceAlreadyRegistered, - WindowsNeedsElevation, - UpdatesNotConfigured + MissingLegalConsent, + MismatchConnectionToken, + DevTunnelError, + StatusError, + WrappedError, + InvalidServerExtensionError, + MissingEntrypointError, + SetupError, + NoHomeForLauncherError, + TunnelCreationFailed, + TunnelHostFailed, + InvalidTunnelName, + ExtensionInstallFailed, + MismatchedLaunchModeError, + NoAttachedServerError, + ServerWriteError, + UnsupportedPlatformError, + RefreshTokenNotAvailableError, + NoInstallInUserProvidedPath, + UserCancelledInstallation, + InvalidRequestedVersion, + CannotForwardControlPort, + ServerHasClosed, + ServiceAlreadyRegistered, + WindowsNeedsElevation, + UpdatesNotConfigured ); impl From for AnyError { - fn from(e: reqwest::Error) -> AnyError { - AnyError::WrappedError(WrappedError::from(e)) - } + fn from(e: reqwest::Error) -> AnyError { + AnyError::WrappedError(WrappedError::from(e)) + } } diff --git a/cli/src/util/http.rs b/cli/src/util/http.rs index 91573221fce..2dfd5cbe65a 100644 --- a/cli/src/util/http.rs +++ b/cli/src/util/http.rs @@ -10,27 +10,27 @@ use tokio_util::compat::FuturesAsyncReadCompatExt; use super::io::{copy_async_progress, ReportCopyProgress}; pub async fn download_into_file( - filename: &std::path::Path, - progress: T, - res: reqwest::Response, + filename: &std::path::Path, + progress: T, + res: reqwest::Response, ) -> Result where - T: ReportCopyProgress, + T: ReportCopyProgress, { - let mut file = fs::File::create(filename) - .await - .map_err(|e| errors::wrap(e, "failed to create file"))?; + let mut file = fs::File::create(filename) + .await + .map_err(|e| errors::wrap(e, "failed to create file"))?; - let content_length = res.content_length().unwrap_or(0); - let mut read = res - .bytes_stream() - .map_err(|e| futures::io::Error::new(futures::io::ErrorKind::Other, e)) - .into_async_read() - .compat(); + let content_length = res.content_length().unwrap_or(0); + let mut read = res + .bytes_stream() + .map_err(|e| futures::io::Error::new(futures::io::ErrorKind::Other, e)) + .into_async_read() + .compat(); - copy_async_progress(progress, &mut read, &mut file, content_length) - .await - .map_err(|e| errors::wrap(e, "failed to download file"))?; + copy_async_progress(progress, &mut read, &mut file, content_length) + .await + .map_err(|e| errors::wrap(e, "failed to download file"))?; - Ok(file) + Ok(file) } diff --git a/cli/src/util/input.rs b/cli/src/util/input.rs index ee25286f802..5f1acd7a266 100644 --- a/cli/src/util/input.rs +++ b/cli/src/util/input.rs @@ -11,59 +11,59 @@ use super::{errors::WrappedError, io::ReportCopyProgress}; /// Wrapper around indicatif::ProgressBar that implements ReportCopyProgress. pub struct ProgressBarReporter { - bar: ProgressBar, - has_set_total: bool, + bar: ProgressBar, + has_set_total: bool, } impl From for ProgressBarReporter { - fn from(bar: ProgressBar) -> Self { - ProgressBarReporter { - bar, - has_set_total: false, - } - } + fn from(bar: ProgressBar) -> Self { + ProgressBarReporter { + bar, + has_set_total: false, + } + } } impl ReportCopyProgress for ProgressBarReporter { - fn report_progress(&mut self, bytes_so_far: u64, total_bytes: u64) { - if !self.has_set_total { - self.bar.set_length(total_bytes); - } + fn report_progress(&mut self, bytes_so_far: u64, total_bytes: u64) { + if !self.has_set_total { + self.bar.set_length(total_bytes); + } - if bytes_so_far == total_bytes { - self.bar.finish_and_clear(); - } else { - self.bar.set_position(bytes_so_far); - } - } + if bytes_so_far == total_bytes { + self.bar.finish_and_clear(); + } else { + self.bar.set_position(bytes_so_far); + } + } } pub fn prompt_yn(text: &str) -> Result { - Confirm::with_theme(&ColorfulTheme::default()) - .with_prompt(text) - .default(true) - .interact() - .map_err(|e| wrap(e, "Failed to read confirm input")) + Confirm::with_theme(&ColorfulTheme::default()) + .with_prompt(text) + .default(true) + .interact() + .map_err(|e| wrap(e, "Failed to read confirm input")) } pub fn prompt_options(text: &str, options: &[T]) -> Result where - T: Display + Copy, + T: Display + Copy, { - let chosen = Select::with_theme(&ColorfulTheme::default()) - .with_prompt(text) - .items(options) - .default(0) - .interact() - .map_err(|e| wrap(e, "Failed to read select input"))?; + let chosen = Select::with_theme(&ColorfulTheme::default()) + .with_prompt(text) + .items(options) + .default(0) + .interact() + .map_err(|e| wrap(e, "Failed to read select input"))?; - Ok(options[chosen]) + Ok(options[chosen]) } pub fn prompt_placeholder(question: &str, placeholder: &str) -> Result { - Input::with_theme(&ColorfulTheme::default()) - .with_prompt(question) - .default(placeholder.to_string()) - .interact_text() - .map_err(|e| wrap(e, "Failed to read confirm input")) + Input::with_theme(&ColorfulTheme::default()) + .with_prompt(question) + .default(placeholder.to_string()) + .interact_text() + .map_err(|e| wrap(e, "Failed to read confirm input")) } diff --git a/cli/src/util/io.rs b/cli/src/util/io.rs index ed2c41dbcc0..c55a9135e12 100644 --- a/cli/src/util/io.rs +++ b/cli/src/util/io.rs @@ -7,53 +7,53 @@ use std::io; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; pub trait ReportCopyProgress { - fn report_progress(&mut self, bytes_so_far: u64, total_bytes: u64); + fn report_progress(&mut self, bytes_so_far: u64, total_bytes: u64); } /// Type that doesn't emit anything for download progress. pub struct SilentCopyProgress(); impl ReportCopyProgress for SilentCopyProgress { - fn report_progress(&mut self, _bytes_so_far: u64, _total_bytes: u64) {} + fn report_progress(&mut self, _bytes_so_far: u64, _total_bytes: u64) {} } /// Copies from the reader to the writer, reporting progress to the provided /// reporter every so often. pub async fn copy_async_progress( - mut reporter: T, - reader: &mut R, - writer: &mut W, - total_bytes: u64, + mut reporter: T, + reader: &mut R, + writer: &mut W, + total_bytes: u64, ) -> io::Result where - R: AsyncRead + Unpin, - W: AsyncWrite + Unpin, - T: ReportCopyProgress, + R: AsyncRead + Unpin, + W: AsyncWrite + Unpin, + T: ReportCopyProgress, { - let mut buf = vec![0; 8 * 1024]; - let mut bytes_so_far = 0; - let mut bytes_last_reported = 0; - let report_granularity = std::cmp::min(total_bytes / 10, 2 * 1024 * 1024); + let mut buf = vec![0; 8 * 1024]; + let mut bytes_so_far = 0; + let mut bytes_last_reported = 0; + let report_granularity = std::cmp::min(total_bytes / 10, 2 * 1024 * 1024); - reporter.report_progress(0, total_bytes); + reporter.report_progress(0, total_bytes); - loop { - let read_buf = match reader.read(&mut buf).await { - Ok(0) => break, - Ok(n) => &buf[..n], - Err(e) => return Err(e), - }; + loop { + let read_buf = match reader.read(&mut buf).await { + Ok(0) => break, + Ok(n) => &buf[..n], + Err(e) => return Err(e), + }; - writer.write_all(read_buf).await?; + writer.write_all(read_buf).await?; - bytes_so_far += read_buf.len() as u64; - if bytes_so_far - bytes_last_reported > report_granularity { - bytes_last_reported = bytes_so_far; - reporter.report_progress(bytes_so_far, total_bytes); - } - } + bytes_so_far += read_buf.len() as u64; + if bytes_so_far - bytes_last_reported > report_granularity { + bytes_last_reported = bytes_so_far; + reporter.report_progress(bytes_so_far, total_bytes); + } + } - reporter.report_progress(bytes_so_far, total_bytes); + reporter.report_progress(bytes_so_far, total_bytes); - Ok(bytes_so_far) + Ok(bytes_so_far) } diff --git a/cli/src/util/machine.rs b/cli/src/util/machine.rs index b64c2b8029d..c35a73f222e 100644 --- a/cli/src/util/machine.rs +++ b/cli/src/util/machine.rs @@ -7,72 +7,72 @@ use std::path::Path; use sysinfo::{Pid, PidExt, ProcessExt, System, SystemExt}; pub fn process_at_path_exists(pid: u32, name: &Path) -> bool { - // TODO https://docs.rs/sysinfo/latest/sysinfo/index.html#usage - let mut sys = System::new_all(); - sys.refresh_processes(); + // TODO https://docs.rs/sysinfo/latest/sysinfo/index.html#usage + let mut sys = System::new_all(); + sys.refresh_processes(); - let name_str = format!("{}", name.display()); - match sys.process(Pid::from_u32(pid)) { - Some(process) => { - for cmd in process.cmd() { - if cmd.contains(&name_str) { - return true; - } - } - } - None => { - return false; - } - } + let name_str = format!("{}", name.display()); + match sys.process(Pid::from_u32(pid)) { + Some(process) => { + for cmd in process.cmd() { + if cmd.contains(&name_str) { + return true; + } + } + } + None => { + return false; + } + } - false + false } pub fn process_exists(pid: u32) -> bool { - let mut sys = System::new_all(); - sys.refresh_processes(); - sys.process(Pid::from_u32(pid)).is_some() + let mut sys = System::new_all(); + sys.refresh_processes(); + sys.process(Pid::from_u32(pid)).is_some() } pub fn find_running_process(name: &Path) -> Option { - // TODO https://docs.rs/sysinfo/latest/sysinfo/index.html#usage - let mut sys = System::new_all(); - sys.refresh_processes(); + // TODO https://docs.rs/sysinfo/latest/sysinfo/index.html#usage + let mut sys = System::new_all(); + sys.refresh_processes(); - let name_str = format!("{}", name.display()); + let name_str = format!("{}", name.display()); - for (pid, process) in sys.processes() { - for cmd in process.cmd() { - if cmd.contains(&name_str) { - return Some(pid.as_u32()); - } - } - } - None + for (pid, process) in sys.processes() { + for cmd in process.cmd() { + if cmd.contains(&name_str) { + return Some(pid.as_u32()); + } + } + } + None } #[cfg(not(target_family = "unix"))] pub async fn set_executable_permission>( - _file: P, + _file: P, ) -> Result<(), errors::WrappedError> { - Ok(()) + Ok(()) } #[cfg(target_family = "unix")] pub async fn set_executable_permission>( - file: P, + file: P, ) -> Result<(), errors::WrappedError> { - use std::os::unix::prelude::PermissionsExt; + use std::os::unix::prelude::PermissionsExt; - let mut permissions = tokio::fs::metadata(&file) - .await - .map_err(|e| errors::wrap(e, "failed to read executable file metadata"))? - .permissions(); + let mut permissions = tokio::fs::metadata(&file) + .await + .map_err(|e| errors::wrap(e, "failed to read executable file metadata"))? + .permissions(); - permissions.set_mode(0o750); + permissions.set_mode(0o750); - tokio::fs::set_permissions(&file, permissions) - .await - .map_err(|e| errors::wrap(e, "failed to set executable permissions"))?; + tokio::fs::set_permissions(&file, permissions) + .await + .map_err(|e| errors::wrap(e, "failed to set executable permissions"))?; - Ok(()) + Ok(()) } diff --git a/cli/src/util/prereqs.rs b/cli/src/util/prereqs.rs index 1918a3a5e0c..86d5020efa1 100644 --- a/cli/src/util/prereqs.rs +++ b/cli/src/util/prereqs.rs @@ -15,245 +15,245 @@ use tokio::fs; use super::errors::AnyError; lazy_static! { - static ref LDCONFIG_STDC_RE: Regex = Regex::new(r"libstdc\+\+.* => (.+)").unwrap(); - static ref LDD_VERSION_RE: BinRegex = BinRegex::new(r"^ldd.*(.+)\.(.+)\s").unwrap(); - static ref LIBSTD_CXX_VERSION_RE: BinRegex = - BinRegex::new(r"GLIBCXX_([0-9]+)\.([0-9]+)(?:\.([0-9]+))?").unwrap(); - static ref MIN_CXX_VERSION: SimpleSemver = SimpleSemver::new(3, 4, 18); - static ref MIN_LDD_VERSION: SimpleSemver = SimpleSemver::new(2, 17, 0); + static ref LDCONFIG_STDC_RE: Regex = Regex::new(r"libstdc\+\+.* => (.+)").unwrap(); + static ref LDD_VERSION_RE: BinRegex = BinRegex::new(r"^ldd.*(.+)\.(.+)\s").unwrap(); + static ref LIBSTD_CXX_VERSION_RE: BinRegex = + BinRegex::new(r"GLIBCXX_([0-9]+)\.([0-9]+)(?:\.([0-9]+))?").unwrap(); + static ref MIN_CXX_VERSION: SimpleSemver = SimpleSemver::new(3, 4, 18); + static ref MIN_LDD_VERSION: SimpleSemver = SimpleSemver::new(2, 17, 0); } pub struct PreReqChecker {} impl Default for PreReqChecker { - fn default() -> Self { - Self::new() - } + fn default() -> Self { + Self::new() + } } impl PreReqChecker { - pub fn new() -> PreReqChecker { - PreReqChecker {} - } + pub fn new() -> PreReqChecker { + PreReqChecker {} + } - #[cfg(not(target_os = "linux"))] - pub async fn verify(&self) -> Result { - Platform::env_default().ok_or_else(|| { - SetupError("VS Code it not supported on this platform".to_owned()).into() - }) - } + #[cfg(not(target_os = "linux"))] + pub async fn verify(&self) -> Result { + Platform::env_default().ok_or_else(|| { + SetupError("VS Code it not supported on this platform".to_owned()).into() + }) + } - #[cfg(target_os = "linux")] - pub async fn verify(&self) -> Result { - let (gnu_a, gnu_b, or_musl) = tokio::join!( - check_glibc_version(), - check_glibcxx_version(), - check_musl_interpreter() - ); + #[cfg(target_os = "linux")] + pub async fn verify(&self) -> Result { + let (gnu_a, gnu_b, or_musl) = tokio::join!( + check_glibc_version(), + check_glibcxx_version(), + check_musl_interpreter() + ); - if gnu_a.is_ok() && gnu_b.is_ok() { - return Ok(if cfg!(target_arch = "x86_64") { - Platform::LinuxX64 - } else if cfg!(target_arch = "armhf") { - Platform::LinuxARM32 - } else { - Platform::LinuxARM64 - }); - } + if gnu_a.is_ok() && gnu_b.is_ok() { + return Ok(if cfg!(target_arch = "x86_64") { + Platform::LinuxX64 + } else if cfg!(target_arch = "armhf") { + Platform::LinuxARM32 + } else { + Platform::LinuxARM64 + }); + } - if or_musl.is_ok() { - return Ok(if cfg!(target_arch = "x86_64") { - Platform::LinuxAlpineX64 - } else { - Platform::LinuxAlpineARM64 - }); - } + if or_musl.is_ok() { + return Ok(if cfg!(target_arch = "x86_64") { + Platform::LinuxAlpineX64 + } else { + Platform::LinuxAlpineARM64 + }); + } - let mut errors: Vec = vec![]; - if let Err(e) = gnu_a { - errors.push(e); - } else if let Err(e) = gnu_b { - errors.push(e); - } + let mut errors: Vec = vec![]; + if let Err(e) = gnu_a { + errors.push(e); + } else if let Err(e) = gnu_b { + errors.push(e); + } - if let Err(e) = or_musl { - errors.push(e); - } + if let Err(e) = or_musl { + errors.push(e); + } - let bullets = errors - .iter() - .map(|e| format!(" - {}", e)) - .collect::>() - .join("\n"); + let bullets = errors + .iter() + .map(|e| format!(" - {}", e)) + .collect::>() + .join("\n"); - Err(AnyError::from(SetupError(format!( - "This machine not meet VS Code Server's prerequisites, expected either...\n{}", - bullets, - )))) - } + Err(AnyError::from(SetupError(format!( + "This machine not meet VS Code Server's prerequisites, expected either...\n{}", + bullets, + )))) + } } #[allow(dead_code)] async fn check_musl_interpreter() -> Result<(), String> { - const MUSL_PATH: &str = if cfg!(target_platform = "aarch64") { - "/lib/ld-musl-aarch64.so.1" - } else { - "/lib/ld-musl-x86_64.so.1" - }; + const MUSL_PATH: &str = if cfg!(target_platform = "aarch64") { + "/lib/ld-musl-aarch64.so.1" + } else { + "/lib/ld-musl-x86_64.so.1" + }; - if fs::metadata(MUSL_PATH).await.is_err() { - return Err(format!( - "find {}, which is required to run the VS Code Server in musl environments", - MUSL_PATH - )); - } + if fs::metadata(MUSL_PATH).await.is_err() { + return Err(format!( + "find {}, which is required to run the VS Code Server in musl environments", + MUSL_PATH + )); + } - Ok(()) + Ok(()) } #[allow(dead_code)] async fn check_glibc_version() -> Result<(), String> { - let ldd_version = capture_command("ldd", ["--version"]) - .await - .ok() - .and_then(|o| extract_ldd_version(&o.stdout)); + let ldd_version = capture_command("ldd", ["--version"]) + .await + .ok() + .and_then(|o| extract_ldd_version(&o.stdout)); - if let Some(v) = ldd_version { - return if v.gte(&MIN_LDD_VERSION) { - Ok(()) - } else { - Err(format!( - "find GLIBC >= 2.17 (but found {} instead) for GNU environments", - v - )) - }; - } + if let Some(v) = ldd_version { + return if v.gte(&MIN_LDD_VERSION) { + Ok(()) + } else { + Err(format!( + "find GLIBC >= 2.17 (but found {} instead) for GNU environments", + v + )) + }; + } - Ok(()) + Ok(()) } #[allow(dead_code)] async fn check_glibcxx_version() -> Result<(), String> { - let mut libstdc_path: Option = None; + let mut libstdc_path: Option = None; - const DEFAULT_LIB_PATH: &str = "/usr/lib64/libstdc++.so.6"; - const LDCONFIG_PATH: &str = "/sbin/ldconfig"; + const DEFAULT_LIB_PATH: &str = "/usr/lib64/libstdc++.so.6"; + const LDCONFIG_PATH: &str = "/sbin/ldconfig"; - if fs::metadata(DEFAULT_LIB_PATH).await.is_ok() { - libstdc_path = Some(DEFAULT_LIB_PATH.to_owned()); - } else if fs::metadata(LDCONFIG_PATH).await.is_ok() { - libstdc_path = capture_command(LDCONFIG_PATH, ["-p"]) - .await - .ok() - .and_then(|o| extract_libstd_from_ldconfig(&o.stdout)); - } + if fs::metadata(DEFAULT_LIB_PATH).await.is_ok() { + libstdc_path = Some(DEFAULT_LIB_PATH.to_owned()); + } else if fs::metadata(LDCONFIG_PATH).await.is_ok() { + libstdc_path = capture_command(LDCONFIG_PATH, ["-p"]) + .await + .ok() + .and_then(|o| extract_libstd_from_ldconfig(&o.stdout)); + } - match libstdc_path { - Some(path) => match fs::read(&path).await { - Ok(contents) => check_for_sufficient_glibcxx_versions(contents), - Err(e) => Err(format!( - "validate GLIBCXX version for GNU environments, but could not: {}", - e - )), - }, - None => Err("find libstdc++.so or ldconfig for GNU environments".to_owned()), - } + match libstdc_path { + Some(path) => match fs::read(&path).await { + Ok(contents) => check_for_sufficient_glibcxx_versions(contents), + Err(e) => Err(format!( + "validate GLIBCXX version for GNU environments, but could not: {}", + e + )), + }, + None => Err("find libstdc++.so or ldconfig for GNU environments".to_owned()), + } } #[allow(dead_code)] fn check_for_sufficient_glibcxx_versions(contents: Vec) -> Result<(), String> { - let all_versions: Vec = LIBSTD_CXX_VERSION_RE - .captures_iter(&contents) - .map(|m| SimpleSemver { - major: m.get(1).map_or(0, |s| u32_from_bytes(s.as_bytes())), - minor: m.get(2).map_or(0, |s| u32_from_bytes(s.as_bytes())), - patch: m.get(3).map_or(0, |s| u32_from_bytes(s.as_bytes())), - }) - .collect(); + let all_versions: Vec = LIBSTD_CXX_VERSION_RE + .captures_iter(&contents) + .map(|m| SimpleSemver { + major: m.get(1).map_or(0, |s| u32_from_bytes(s.as_bytes())), + minor: m.get(2).map_or(0, |s| u32_from_bytes(s.as_bytes())), + patch: m.get(3).map_or(0, |s| u32_from_bytes(s.as_bytes())), + }) + .collect(); - if !all_versions.iter().any(|v| MIN_CXX_VERSION.gte(v)) { - return Err(format!( - "find GLIBCXX >= 3.4.18 (but found {} instead) for GNU environments", - all_versions - .iter() - .map(String::from) - .collect::>() - .join(", ") - )); - } + if !all_versions.iter().any(|v| MIN_CXX_VERSION.gte(v)) { + return Err(format!( + "find GLIBCXX >= 3.4.18 (but found {} instead) for GNU environments", + all_versions + .iter() + .map(String::from) + .collect::>() + .join(", ") + )); + } - Ok(()) + Ok(()) } fn extract_ldd_version(output: &[u8]) -> Option { - LDD_VERSION_RE.captures(output).map(|m| SimpleSemver { - major: m.get(1).map_or(0, |s| u32_from_bytes(s.as_bytes())), - minor: m.get(2).map_or(0, |s| u32_from_bytes(s.as_bytes())), - patch: 0, - }) + LDD_VERSION_RE.captures(output).map(|m| SimpleSemver { + major: m.get(1).map_or(0, |s| u32_from_bytes(s.as_bytes())), + minor: m.get(2).map_or(0, |s| u32_from_bytes(s.as_bytes())), + patch: 0, + }) } fn extract_libstd_from_ldconfig(output: &[u8]) -> Option { - String::from_utf8_lossy(output) - .lines() - .find_map(|l| LDCONFIG_STDC_RE.captures(l)) - .and_then(|cap| cap.get(1)) - .map(|cap| cap.as_str().to_owned()) + String::from_utf8_lossy(output) + .lines() + .find_map(|l| LDCONFIG_STDC_RE.captures(l)) + .and_then(|cap| cap.get(1)) + .map(|cap| cap.as_str().to_owned()) } fn u32_from_bytes(b: &[u8]) -> u32 { - String::from_utf8_lossy(b).parse::().unwrap_or(0) + String::from_utf8_lossy(b).parse::().unwrap_or(0) } #[derive(Debug, PartialEq)] struct SimpleSemver { - major: u32, - minor: u32, - patch: u32, + major: u32, + minor: u32, + patch: u32, } impl From<&SimpleSemver> for String { - fn from(s: &SimpleSemver) -> Self { - format!("v{}.{}.{}", s.major, s.minor, s.patch) - } + fn from(s: &SimpleSemver) -> Self { + format!("v{}.{}.{}", s.major, s.minor, s.patch) + } } impl std::fmt::Display for SimpleSemver { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{}", String::from(self)) - } + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}", String::from(self)) + } } #[allow(dead_code)] impl SimpleSemver { - fn new(major: u32, minor: u32, patch: u32) -> SimpleSemver { - SimpleSemver { - major, - minor, - patch, - } - } + fn new(major: u32, minor: u32, patch: u32) -> SimpleSemver { + SimpleSemver { + major, + minor, + patch, + } + } - fn gte(&self, other: &SimpleSemver) -> bool { - match self.major.cmp(&other.major) { - Ordering::Greater => true, - Ordering::Less => false, - Ordering::Equal => match self.minor.cmp(&other.minor) { - Ordering::Greater => true, - Ordering::Less => false, - Ordering::Equal => self.patch >= other.patch, - }, - } - } + fn gte(&self, other: &SimpleSemver) -> bool { + match self.major.cmp(&other.major) { + Ordering::Greater => true, + Ordering::Less => false, + Ordering::Equal => match self.minor.cmp(&other.minor) { + Ordering::Greater => true, + Ordering::Less => false, + Ordering::Equal => self.patch >= other.patch, + }, + } + } } #[cfg(test)] mod tests { - use super::*; + use super::*; - #[test] - fn test_extract_libstd_from_ldconfig() { - let actual = " + #[test] + fn test_extract_libstd_from_ldconfig() { + let actual = " libstoken.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libstoken.so.1 libstemmer.so.0d (libc6,x86-64) => /lib/x86_64-linux-gnu/libstemmer.so.0d libstdc++.so.6 (libc6,x86-64) => /lib/x86_64-linux-gnu/libstdc++.so.6 @@ -261,41 +261,41 @@ mod tests { libssl3.so (libc6,x86-64) => /lib/x86_64-linux-gnu/libssl3.so ".to_owned().into_bytes(); - assert_eq!( - extract_libstd_from_ldconfig(&actual), - Some("/lib/x86_64-linux-gnu/libstdc++.so.6".to_owned()), - ); + assert_eq!( + extract_libstd_from_ldconfig(&actual), + Some("/lib/x86_64-linux-gnu/libstdc++.so.6".to_owned()), + ); - assert_eq!( - extract_libstd_from_ldconfig(&"nothing here!".to_owned().into_bytes()), - None, - ); - } + assert_eq!( + extract_libstd_from_ldconfig(&"nothing here!".to_owned().into_bytes()), + None, + ); + } - #[test] - fn test_gte() { - assert!(SimpleSemver::new(1, 2, 3).gte(&SimpleSemver::new(1, 2, 3))); - assert!(SimpleSemver::new(1, 2, 3).gte(&SimpleSemver::new(0, 10, 10))); - assert!(SimpleSemver::new(1, 2, 3).gte(&SimpleSemver::new(1, 1, 10))); + #[test] + fn test_gte() { + assert!(SimpleSemver::new(1, 2, 3).gte(&SimpleSemver::new(1, 2, 3))); + assert!(SimpleSemver::new(1, 2, 3).gte(&SimpleSemver::new(0, 10, 10))); + assert!(SimpleSemver::new(1, 2, 3).gte(&SimpleSemver::new(1, 1, 10))); - assert!(!SimpleSemver::new(1, 2, 3).gte(&SimpleSemver::new(1, 2, 10))); - assert!(!SimpleSemver::new(1, 2, 3).gte(&SimpleSemver::new(1, 3, 1))); - assert!(!SimpleSemver::new(1, 2, 3).gte(&SimpleSemver::new(2, 2, 1))); - } + assert!(!SimpleSemver::new(1, 2, 3).gte(&SimpleSemver::new(1, 2, 10))); + assert!(!SimpleSemver::new(1, 2, 3).gte(&SimpleSemver::new(1, 3, 1))); + assert!(!SimpleSemver::new(1, 2, 3).gte(&SimpleSemver::new(2, 2, 1))); + } - #[test] - fn check_for_sufficient_glibcxx_versions() { - let actual = "ldd (Ubuntu GLIBC 2.31-0ubuntu9.7) 2.31 + #[test] + fn check_for_sufficient_glibcxx_versions() { + let actual = "ldd (Ubuntu GLIBC 2.31-0ubuntu9.7) 2.31 Copyright (C) 2020 Free Software Foundation, Inc. This is free software; see the source for copying conditions. There is NO warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. Written by Roland McGrath and Ulrich Drepper." - .to_owned() - .into_bytes(); + .to_owned() + .into_bytes(); - assert_eq!( - extract_ldd_version(&actual), - Some(SimpleSemver::new(2, 31, 0)), - ); - } + assert_eq!( + extract_ldd_version(&actual), + Some(SimpleSemver::new(2, 31, 0)), + ); + } } diff --git a/cli/src/util/sync.rs b/cli/src/util/sync.rs index 9dc521f44c5..6e069215e74 100644 --- a/cli/src/util/sync.rs +++ b/cli/src/util/sync.rs @@ -3,87 +3,87 @@ * Licensed under the MIT License. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ use tokio::sync::watch::{ - self, - error::{RecvError, SendError}, + self, + error::{RecvError, SendError}, }; #[derive(Clone)] pub struct Barrier(watch::Receiver>) where - T: Copy; + T: Copy; impl Barrier where - T: Copy, + T: Copy, { - /// Waits for the barrier to be closed, returning a value if one was sent. - pub async fn wait(&mut self) -> Result { - loop { - if let Err(e) = self.0.changed().await { - return Err(e); - } + /// Waits for the barrier to be closed, returning a value if one was sent. + pub async fn wait(&mut self) -> Result { + loop { + if let Err(e) = self.0.changed().await { + return Err(e); + } - if let Some(v) = *(self.0.borrow()) { - return Ok(v); - } - } - } + if let Some(v) = *(self.0.borrow()) { + return Ok(v); + } + } + } } pub struct BarrierOpener(watch::Sender>); impl BarrierOpener { - /// Closes the barrier. - pub fn open(self, value: T) -> Result<(), SendError>> { - self.0.send(Some(value)) - } + /// Closes the barrier. + pub fn open(self, value: T) -> Result<(), SendError>> { + self.0.send(Some(value)) + } } /// The Barrier is something that can be opened once from one side, /// and is thereafter permanently closed. It can contain a value. pub fn new_barrier() -> (Barrier, BarrierOpener) where - T: Copy, + T: Copy, { - let (closed_tx, closed_rx) = watch::channel(None); - (Barrier(closed_rx), BarrierOpener(closed_tx)) + let (closed_tx, closed_rx) = watch::channel(None); + (Barrier(closed_rx), BarrierOpener(closed_tx)) } #[cfg(test)] mod tests { - use super::*; + use super::*; - #[tokio::test] - async fn test_barrier_close_after_spawn() { - let (mut barrier, opener) = new_barrier::(); - let (tx, rx) = tokio::sync::oneshot::channel::(); + #[tokio::test] + async fn test_barrier_close_after_spawn() { + let (mut barrier, opener) = new_barrier::(); + let (tx, rx) = tokio::sync::oneshot::channel::(); - tokio::spawn(async move { - tx.send(barrier.wait().await.unwrap()).unwrap(); - }); + tokio::spawn(async move { + tx.send(barrier.wait().await.unwrap()).unwrap(); + }); - opener.open(42).unwrap(); + opener.open(42).unwrap(); - assert!(rx.await.unwrap() == 42); - } + assert!(rx.await.unwrap() == 42); + } - #[tokio::test] - async fn test_barrier_close_before_spawn() { - let (barrier, opener) = new_barrier::(); - let (tx1, rx1) = tokio::sync::oneshot::channel::(); - let (tx2, rx2) = tokio::sync::oneshot::channel::(); + #[tokio::test] + async fn test_barrier_close_before_spawn() { + let (barrier, opener) = new_barrier::(); + let (tx1, rx1) = tokio::sync::oneshot::channel::(); + let (tx2, rx2) = tokio::sync::oneshot::channel::(); - opener.open(42).unwrap(); - let mut b1 = barrier.clone(); - tokio::spawn(async move { - tx1.send(b1.wait().await.unwrap()).unwrap(); - }); - let mut b2 = barrier.clone(); - tokio::spawn(async move { - tx2.send(b2.wait().await.unwrap()).unwrap(); - }); + opener.open(42).unwrap(); + let mut b1 = barrier.clone(); + tokio::spawn(async move { + tx1.send(b1.wait().await.unwrap()).unwrap(); + }); + let mut b2 = barrier.clone(); + tokio::spawn(async move { + tx2.send(b2.wait().await.unwrap()).unwrap(); + }); - assert!(rx1.await.unwrap() == 42); - assert!(rx2.await.unwrap() == 42); - } + assert!(rx1.await.unwrap() == 42); + assert!(rx2.await.unwrap() == 42); + } } diff --git a/cli/src/util/tar.rs b/cli/src/util/tar.rs index c55f3555e1b..e18927b16c3 100644 --- a/cli/src/util/tar.rs +++ b/cli/src/util/tar.rs @@ -12,41 +12,41 @@ use tar::Archive; use super::io::ReportCopyProgress; pub fn decompress_tarball( - path: &Path, - parent_path: &Path, - mut reporter: T, + path: &Path, + parent_path: &Path, + mut reporter: T, ) -> Result<(), WrappedError> where - T: ReportCopyProgress, + T: ReportCopyProgress, { - let tar_gz = File::open(path).map_err(|e| { - wrap( - Box::new(e), - format!("error opening file {}", path.display()), - ) - })?; - let tar = GzDecoder::new(tar_gz); - let mut archive = Archive::new(tar); + let tar_gz = File::open(path).map_err(|e| { + wrap( + Box::new(e), + format!("error opening file {}", path.display()), + ) + })?; + let tar = GzDecoder::new(tar_gz); + let mut archive = Archive::new(tar); - let results = archive - .entries() - .map_err(|e| wrap(e, format!("error opening archive {}", path.display())))? - .filter_map(|e| e.ok()) - .map(|mut entry| { - let entry_path = entry - .path() - .map_err(|e| wrap(e, "error reading entry path"))?; + let results = archive + .entries() + .map_err(|e| wrap(e, format!("error opening archive {}", path.display())))? + .filter_map(|e| e.ok()) + .map(|mut entry| { + let entry_path = entry + .path() + .map_err(|e| wrap(e, "error reading entry path"))?; - let path = parent_path.join(entry_path.iter().skip(1).collect::()); - entry - .unpack(&path) - .map_err(|e| wrap(e, format!("error unpacking {}", path.display())))?; - Ok(path) - }) - .collect::, WrappedError>>()?; + let path = parent_path.join(entry_path.iter().skip(1).collect::()); + entry + .unpack(&path) + .map_err(|e| wrap(e, format!("error unpacking {}", path.display())))?; + Ok(path) + }) + .collect::, WrappedError>>()?; - // Tarballs don't have a way to get the number of entries ahead of time - reporter.report_progress(results.len() as u64, results.len() as u64); + // Tarballs don't have a way to get the number of entries ahead of time + reporter.report_progress(results.len() as u64, results.len() as u64); - Ok(()) + Ok(()) } diff --git a/cli/src/util/zipper.rs b/cli/src/util/zipper.rs index 15ac0e60c86..a9106fd6b6c 100644 --- a/cli/src/util/zipper.rs +++ b/cli/src/util/zipper.rs @@ -16,140 +16,140 @@ use zip::{self, ZipArchive}; /// Returns whether all files in the archive start with the same path segment. /// If so, it's an indication we should skip that segment when extracting. fn should_skip_first_segment(archive: &mut ZipArchive) -> bool { - let first_name = { - let file = archive - .by_index_raw(0) - .expect("expected not to have an empty archive"); + let first_name = { + let file = archive + .by_index_raw(0) + .expect("expected not to have an empty archive"); - let path = file - .enclosed_name() - .expect("expected to have path") - .iter() - .next() - .expect("expected to have non-empty name"); + let path = file + .enclosed_name() + .expect("expected to have path") + .iter() + .next() + .expect("expected to have non-empty name"); - path.to_owned() - }; + path.to_owned() + }; - for i in 1..archive.len() { - if let Ok(file) = archive.by_index_raw(i) { - if let Some(name) = file.enclosed_name() { - if name.iter().next() != Some(&first_name) { - return false; - } - } - } - } + for i in 1..archive.len() { + if let Ok(file) = archive.by_index_raw(i) { + if let Some(name) = file.enclosed_name() { + if name.iter().next() != Some(&first_name) { + return false; + } + } + } + } - true + true } pub fn unzip_file(path: &Path, parent_path: &Path, mut reporter: T) -> Result<(), WrappedError> where - T: ReportCopyProgress, + T: ReportCopyProgress, { - let file = fs::File::open(path) - .map_err(|e| wrap(e, format!("unable to open file {}", path.display())))?; + let file = fs::File::open(path) + .map_err(|e| wrap(e, format!("unable to open file {}", path.display())))?; - let mut archive = zip::ZipArchive::new(file) - .map_err(|e| wrap(e, format!("failed to open zip archive {}", path.display())))?; + let mut archive = zip::ZipArchive::new(file) + .map_err(|e| wrap(e, format!("failed to open zip archive {}", path.display())))?; - let skip_segments_no = if should_skip_first_segment(&mut archive) { - 1 - } else { - 0 - }; + let skip_segments_no = if should_skip_first_segment(&mut archive) { + 1 + } else { + 0 + }; - for i in 0..archive.len() { - reporter.report_progress(i as u64, archive.len() as u64); - let mut file = archive - .by_index(i) - .map_err(|e| wrap(e, format!("could not open zip entry {}", i)))?; + for i in 0..archive.len() { + reporter.report_progress(i as u64, archive.len() as u64); + let mut file = archive + .by_index(i) + .map_err(|e| wrap(e, format!("could not open zip entry {}", i)))?; - let outpath: PathBuf = match file.enclosed_name() { - Some(path) => { - let mut full_path = PathBuf::from(parent_path); - full_path.push(PathBuf::from_iter(path.iter().skip(skip_segments_no))); - full_path - } - None => continue, - }; + let outpath: PathBuf = match file.enclosed_name() { + Some(path) => { + let mut full_path = PathBuf::from(parent_path); + full_path.push(PathBuf::from_iter(path.iter().skip(skip_segments_no))); + full_path + } + None => continue, + }; - if file.is_dir() || file.name().ends_with('/') { - fs::create_dir_all(&outpath) - .map_err(|e| wrap(e, format!("could not create dir for {}", outpath.display())))?; - apply_permissions(&file, &outpath)?; - continue; - } + if file.is_dir() || file.name().ends_with('/') { + fs::create_dir_all(&outpath) + .map_err(|e| wrap(e, format!("could not create dir for {}", outpath.display())))?; + apply_permissions(&file, &outpath)?; + continue; + } - if let Some(p) = outpath.parent() { - fs::create_dir_all(&p) - .map_err(|e| wrap(e, format!("could not create dir for {}", outpath.display())))?; - } + if let Some(p) = outpath.parent() { + fs::create_dir_all(&p) + .map_err(|e| wrap(e, format!("could not create dir for {}", outpath.display())))?; + } - #[cfg(unix)] - { - use libc::S_IFLNK; - use std::io::Read; - use std::os::unix::ffi::OsStringExt; + #[cfg(unix)] + { + use libc::S_IFLNK; + use std::io::Read; + use std::os::unix::ffi::OsStringExt; - if matches!(file.unix_mode(), Some(mode) if mode & (S_IFLNK as u32) == (S_IFLNK as u32)) - { - let mut link_to = Vec::new(); - file.read_to_end(&mut link_to).map_err(|e| { - wrap( - e, - format!("could not read symlink linkpath {}", outpath.display()), - ) - })?; + if matches!(file.unix_mode(), Some(mode) if mode & (S_IFLNK as u32) == (S_IFLNK as u32)) + { + let mut link_to = Vec::new(); + file.read_to_end(&mut link_to).map_err(|e| { + wrap( + e, + format!("could not read symlink linkpath {}", outpath.display()), + ) + })?; - let link_path = PathBuf::from(std::ffi::OsString::from_vec(link_to)); - std::os::unix::fs::symlink(link_path, &outpath).map_err(|e| { - wrap(e, format!("could not create symlink {}", outpath.display())) - })?; - continue; - } - } + let link_path = PathBuf::from(std::ffi::OsString::from_vec(link_to)); + std::os::unix::fs::symlink(link_path, &outpath).map_err(|e| { + wrap(e, format!("could not create symlink {}", outpath.display())) + })?; + continue; + } + } - let mut outfile = fs::File::create(&outpath).map_err(|e| { - wrap( - e, - format!( - "unable to open file to write {} (from {:?})", - outpath.display(), - file.enclosed_name().map(|p| p.to_string_lossy()), - ), - ) - })?; + let mut outfile = fs::File::create(&outpath).map_err(|e| { + wrap( + e, + format!( + "unable to open file to write {} (from {:?})", + outpath.display(), + file.enclosed_name().map(|p| p.to_string_lossy()), + ), + ) + })?; - io::copy(&mut file, &mut outfile) - .map_err(|e| wrap(e, format!("error copying file {}", outpath.display())))?; + io::copy(&mut file, &mut outfile) + .map_err(|e| wrap(e, format!("error copying file {}", outpath.display())))?; - apply_permissions(&file, &outpath)?; - } + apply_permissions(&file, &outpath)?; + } - reporter.report_progress(archive.len() as u64, archive.len() as u64); + reporter.report_progress(archive.len() as u64, archive.len() as u64); - Ok(()) + Ok(()) } #[cfg(unix)] fn apply_permissions(file: &ZipFile, outpath: &Path) -> Result<(), WrappedError> { - use std::os::unix::fs::PermissionsExt; + use std::os::unix::fs::PermissionsExt; - if let Some(mode) = file.unix_mode() { - fs::set_permissions(&outpath, fs::Permissions::from_mode(mode)).map_err(|e| { - wrap( - e, - format!("error setting permissions on {}", outpath.display()), - ) - })?; - } + if let Some(mode) = file.unix_mode() { + fs::set_permissions(&outpath, fs::Permissions::from_mode(mode)).map_err(|e| { + wrap( + e, + format!("error setting permissions on {}", outpath.display()), + ) + })?; + } - Ok(()) + Ok(()) } #[cfg(windows)] fn apply_permissions(_file: &ZipFile, _outpath: &Path) -> Result<(), WrappedError> { - Ok(()) + Ok(()) }