diff --git a/cli/src/tunnels/control_server.rs b/cli/src/tunnels/control_server.rs index ef1a8d03284..67a0bcf64ac 100644 --- a/cli/src/tunnels/control_server.rs +++ b/cli/src/tunnels/control_server.rs @@ -21,6 +21,7 @@ use crate::util::http::{ }; use crate::util::io::SilentCopyProgress; use crate::util::is_integrated_cli; +use crate::util::os::os_release; use crate::util::sync::{new_barrier, Barrier}; use futures::stream::FuturesUnordered; @@ -47,9 +48,10 @@ use super::paths::prune_stopped_servers; use super::port_forwarder::{PortForwarding, PortForwardingProcessor}; use super::protocol::{ AcquireCliParams, CallServerHttpParams, CallServerHttpResult, ClientRequestMethod, EmptyObject, - ForwardParams, ForwardResult, GetHostnameResponse, HttpBodyParams, HttpHeadersParams, - ServeParams, ServerLog, ServerMessageParams, SpawnParams, SpawnResult, ToClientRequest, - UnforwardParams, UpdateParams, UpdateResult, VersionParams, + ForwardParams, ForwardResult, FsStatRequest, FsStatResponse, GetEnvResponse, + GetHostnameResponse, HttpBodyParams, HttpHeadersParams, ServeParams, ServerLog, + ServerMessageParams, SpawnParams, SpawnResult, ToClientRequest, UnforwardParams, UpdateParams, + UpdateResult, VersionParams, }; use super::server_bridge::ServerBridge; use super::server_multiplexer::ServerMultiplexer; @@ -264,6 +266,8 @@ async fn process_socket( rpc.register_sync("ping", |_: EmptyObject, _| Ok(EmptyObject {})); rpc.register_sync("gethostname", |_: EmptyObject, _| handle_get_hostname()); + rpc.register_sync("fs_stat", |p: FsStatRequest, _| handle_stat(p.path)); + rpc.register_sync("get_env", |_: EmptyObject, _| handle_get_env()); rpc.register_async("serve", move |params: ServeParams, c| async move { handle_serve(c, params).await }); @@ -672,6 +676,34 @@ fn handle_get_hostname() -> Result { }) } +fn handle_stat(path: String) -> Result { + Ok(std::fs::metadata(path) + .map(|m| FsStatResponse { + exists: true, + size: Some(m.len()), + kind: Some(match m.file_type() { + t if t.is_dir() => "dir", + t if t.is_file() => "file", + t if t.is_symlink() => "link", + _ => "unknown", + }), + }) + .unwrap_or_default()) +} + +fn handle_get_env() -> Result { + Ok(GetEnvResponse { + env: std::env::vars().collect(), + os_release: os_release().unwrap_or_else(|_| "unknown".to_string()), + #[cfg(windows)] + os_platform: "win32", + #[cfg(target_os = "linux")] + os_platform: "linux", + #[cfg(target_os = "macos")] + os_platform: "darwin", + }) +} + async fn handle_forward( log: &log::Logger, port_forwarding: &PortForwarding, @@ -804,14 +836,17 @@ where }; } - let mut p = tokio::process::Command::new(¶ms.command) - .args(¶ms.args) - .envs(¶ms.env) - .stdin(pipe_if_some!(stdin)) - .stdout(pipe_if_some!(stdout)) - .stderr(pipe_if_some!(stderr)) - .spawn() - .map_err(CodeError::ProcessSpawnFailed)?; + let mut p = tokio::process::Command::new(¶ms.command); + p.args(¶ms.args); + p.envs(¶ms.env); + p.stdin(pipe_if_some!(stdin)); + p.stdout(pipe_if_some!(stdout)); + p.stderr(pipe_if_some!(stderr)); + if let Some(cwd) = ¶ms.cwd { + p.current_dir(cwd); + } + + let mut p = p.spawn().map_err(CodeError::ProcessSpawnFailed)?; let futs = FuturesUnordered::new(); if let (Some(mut a), Some(mut b)) = (p.stdout.take(), stdout) { diff --git a/cli/src/tunnels/protocol.rs b/cli/src/tunnels/protocol.rs index 2093b1d1896..17282381c55 100644 --- a/cli/src/tunnels/protocol.rs +++ b/cli/src/tunnels/protocol.rs @@ -128,6 +128,26 @@ pub struct GetHostnameResponse { pub value: String, } +#[derive(Serialize)] +pub struct GetEnvResponse { + pub env: HashMap, + pub os_platform: &'static str, + pub os_release: String, +} + +#[derive(Deserialize)] +pub struct FsStatRequest { + pub path: String, +} + +#[derive(Serialize, Default)] +pub struct FsStatResponse { + pub exists: bool, + pub size: Option, + #[serde(rename = "type")] + pub kind: Option<&'static str>, +} + #[derive(Deserialize, Debug)] pub struct CallServerHttpParams { pub path: String, @@ -164,6 +184,8 @@ pub struct SpawnParams { pub command: String, pub args: Vec, #[serde(default)] + pub cwd: Option, + #[serde(default)] pub env: HashMap, } diff --git a/cli/src/util.rs b/cli/src/util.rs index f45f74de13c..3acd046f5f9 100644 --- a/cli/src/util.rs +++ b/cli/src/util.rs @@ -17,5 +17,6 @@ pub mod sync; pub use is_integrated::*; pub mod app_lock; pub mod file_lock; +pub mod os; pub mod tar; pub mod zipper; diff --git a/cli/src/util/os.rs b/cli/src/util/os.rs new file mode 100644 index 00000000000..d8105baf65a --- /dev/null +++ b/cli/src/util/os.rs @@ -0,0 +1,39 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +#[cfg(windows)] +pub fn os_release() -> Result { + // The windows API *had* nice GetVersionEx/A APIs, but these were deprecated + // in Winodws 8 and there's no newer win API to get version numbers. So + // instead read the registry. + + use winreg::{enums::HKEY_LOCAL_MACHINE, RegKey}; + + let key = RegKey::predef(HKEY_LOCAL_MACHINE) + .open_subkey(r"SOFTWARE\Microsoft\Windows NT\CurrentVersion")?; + + let major: u32 = key.get_value("CurrentMajorVersionNumber")?; + let minor: u32 = key.get_value("CurrentMinorVersionNumber")?; + let build: String = key.get_value("CurrentBuild")?; + + Ok(format!("{}.{}.{}", major, minor, build)) +} + +#[cfg(unix)] +pub fn os_release() -> Result { + use std::{ffi::CStr, mem}; + + unsafe { + let mut ret = mem::MaybeUninit::zeroed(); + + if libc::uname(ret.as_mut_ptr()) != 0 { + return Err(std::io::Error::last_os_error()); + } + + let ret = ret.assume_init(); + let c_str: &CStr = CStr::from_ptr(ret.release.as_ptr()); + Ok(c_str.to_string_lossy().into_owned()) + } +}