From b5038f81d146ab9b49a58a7dace7dfd9afbbd136 Mon Sep 17 00:00:00 2001 From: Connor Peet Date: Fri, 21 Jul 2023 09:32:20 -0700 Subject: [PATCH] cli: allow exec server to listen on a port and require token authentication (#188434) * cli: allow exec server to listen on a port and require token authentication For remote ssh on Windows where pipe forwarding doesn't work * fix linux build --- cli/build.rs | 2 +- cli/src/async_pipe.rs | 36 ++++++++++++++++++++- cli/src/commands/args.rs | 6 ++++ cli/src/commands/tunnels.rs | 54 ++++++++++++++++++++----------- cli/src/msgpack_rpc.rs | 2 +- cli/src/tunnels.rs | 2 +- cli/src/tunnels/control_server.rs | 52 +++++++++++++++++++---------- cli/src/tunnels/protocol.rs | 5 +++ cli/src/util/errors.rs | 2 ++ 9 files changed, 121 insertions(+), 40 deletions(-) diff --git a/cli/build.rs b/cli/build.rs index bcf1bf27e0a..41e289774e9 100644 --- a/cli/build.rs +++ b/cli/build.rs @@ -25,7 +25,7 @@ fn apply_build_environment_variables() { } let pkg_dir = env::var("CARGO_MANIFEST_DIR").unwrap(); - let mut cmd = Command::new("node"); + let mut cmd = Command::new(env::var("NODE_PATH").unwrap_or_else(|_| "node".to_string())); cmd.arg("../build/azure-pipelines/cli/prepare.js"); cmd.current_dir(&pkg_dir); cmd.env("VSCODE_CLI_PREPARE_OUTPUT", "json"); diff --git a/cli/src/async_pipe.rs b/cli/src/async_pipe.rs index dcbe0d16017..6c7c918967a 100644 --- a/cli/src/async_pipe.rs +++ b/cli/src/async_pipe.rs @@ -4,7 +4,10 @@ *--------------------------------------------------------------------------------------------*/ use crate::{constants::APPLICATION_NAME, util::errors::CodeError}; +use async_trait::async_trait; use std::path::{Path, PathBuf}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::net::TcpListener; use uuid::Uuid; // todo: we could probably abstract this into some crate, if one doesn't already exist @@ -39,7 +42,7 @@ cfg_if::cfg_if! { pipe.into_split() } } else { - use tokio::{time::sleep, io::{AsyncRead, AsyncWrite, ReadBuf}}; + use tokio::{time::sleep, io::ReadBuf}; use tokio::net::windows::named_pipe::{ClientOptions, ServerOptions, NamedPipeClient, NamedPipeServer}; use std::{time::Duration, pin::Pin, task::{Context, Poll}, io}; use pin_project::pin_project; @@ -181,3 +184,34 @@ pub fn get_socket_name() -> PathBuf { } } } + +pub type AcceptedRW = ( + Box, + Box, +); + +#[async_trait] +pub trait AsyncRWAccepter { + async fn accept_rw(&mut self) -> Result; +} + +#[async_trait] +impl AsyncRWAccepter for AsyncPipeListener { + async fn accept_rw(&mut self) -> Result { + let pipe = self.accept().await?; + let (read, write) = socket_stream_split(pipe); + Ok((Box::new(read), Box::new(write))) + } +} + +#[async_trait] +impl AsyncRWAccepter for TcpListener { + async fn accept_rw(&mut self) -> Result { + let (stream, _) = self + .accept() + .await + .map_err(CodeError::AsyncPipeListenerFailed)?; + let (read, write) = tokio::io::split(stream); + Ok((Box::new(read), Box::new(write))) + } +} diff --git a/cli/src/commands/args.rs b/cli/src/commands/args.rs index d34519d6810..e253130573b 100644 --- a/cli/src/commands/args.rs +++ b/cli/src/commands/args.rs @@ -182,6 +182,12 @@ pub struct CommandShellArgs { /// Listen on a socket instead of stdin/stdout. #[clap(long)] pub on_socket: bool, + /// Listen on a port instead of stdin/stdout. + #[clap(long)] + pub on_port: bool, + /// Require the given token string to be given in the handshake. + #[clap(long)] + pub require_token: Option, } #[derive(Args, Debug, Clone)] diff --git a/cli/src/commands/tunnels.rs b/cli/src/commands/tunnels.rs index 9831de6e426..578114611c0 100644 --- a/cli/src/commands/tunnels.rs +++ b/cli/src/commands/tunnels.rs @@ -20,7 +20,7 @@ use super::{ }; use crate::{ - async_pipe::{get_socket_name, listen_socket_rw_stream, socket_stream_split}, + async_pipe::{get_socket_name, listen_socket_rw_stream, AsyncRWAccepter}, auth::Auth, constants::{APPLICATION_NAME, TUNNEL_CLI_LOCK_NAME, TUNNEL_SERVICE_LOCK_NAME}, log, @@ -35,7 +35,7 @@ use crate::{ singleton_server::{ make_singleton_server, start_singleton_server, BroadcastLogSink, SingletonServerArgs, }, - Next, ServeStreamParams, ServiceContainer, ServiceManager, + AuthRequired, Next, ServeStreamParams, ServiceContainer, ServiceManager, }, util::{ app_lock::AppMutex, @@ -128,36 +128,52 @@ pub async fn command_shell(ctx: CommandContext, args: CommandShellArgs) -> Resul log: ctx.log, launcher_paths: ctx.paths, platform, - requires_auth: true, + requires_auth: args + .require_token + .map(AuthRequired::VSDAWithToken) + .unwrap_or(AuthRequired::VSDA), exit_barrier: ShutdownRequest::create_rx([ShutdownRequest::CtrlC]), code_server_args: (&ctx.args).into(), }; - if !args.on_socket { - serve_stream(tokio::io::stdin(), tokio::io::stderr(), params).await; - return Ok(0); - } + let mut listener: Box = match (args.on_port, args.on_socket) { + (_, true) => { + let socket = get_socket_name(); + let listener = listen_socket_rw_stream(&socket) + .await + .map_err(|e| wrap(e, "error listening on socket"))?; - let socket = get_socket_name(); - let mut listener = listen_socket_rw_stream(&socket) - .await - .map_err(|e| wrap(e, "error listening on socket"))?; + params + .log + .result(format!("Listening on {}", socket.display())); - params - .log - .result(format!("Listening on {}", socket.display())); + Box::new(listener) + } + (true, _) => { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .map_err(|e| wrap(e, "error listening on port"))?; + + params + .log + .result(format!("Listening on {}", listener.local_addr().unwrap())); + + Box::new(listener) + } + _ => { + serve_stream(tokio::io::stdin(), tokio::io::stderr(), params).await; + return Ok(0); + } + }; let mut servers = FuturesUnordered::new(); loop { tokio::select! { Some(_) = servers.next() => {}, - socket = listener.accept() => { + socket = listener.accept_rw() => { match socket { - Ok(s) => { - let (read, write) = socket_stream_split(s); - servers.push(serve_stream(read, write, params.clone())); - }, + Ok((read, write)) => servers.push(serve_stream(read, write, params.clone())), Err(e) => { error!(params.log, &format!("Error accepting connection: {}", e)); return Ok(1); diff --git a/cli/src/msgpack_rpc.rs b/cli/src/msgpack_rpc.rs index 219c923cdf2..ef6b7782074 100644 --- a/cli/src/msgpack_rpc.rs +++ b/cli/src/msgpack_rpc.rs @@ -122,7 +122,7 @@ pub struct MsgPackCodec { impl MsgPackCodec { pub fn new() -> Self { Self { - _marker: std::marker::PhantomData::default(), + _marker: std::marker::PhantomData, } } } diff --git a/cli/src/tunnels.rs b/cli/src/tunnels.rs index 5d97b757afc..63ccad22382 100644 --- a/cli/src/tunnels.rs +++ b/cli/src/tunnels.rs @@ -34,7 +34,7 @@ mod service_macos; mod service_windows; mod socket_signal; -pub use control_server::{serve, serve_stream, Next, ServeStreamParams}; +pub use control_server::{serve, serve_stream, Next, ServeStreamParams, AuthRequired}; pub use nosleep::SleepInhibitor; pub use service::{ create_service_manager, ServiceContainer, ServiceManager, SERVICE_LOG_FILE_NAME, diff --git a/cli/src/tunnels/control_server.rs b/cli/src/tunnels/control_server.rs index 8577f9668e9..6f8c1060e1f 100644 --- a/cli/src/tunnels/control_server.rs +++ b/cli/src/tunnels/control_server.rs @@ -48,11 +48,11 @@ use super::dev_tunnels::ActiveTunnel; use super::paths::prune_stopped_servers; use super::port_forwarder::{PortForwarding, PortForwardingProcessor}; use super::protocol::{ - AcquireCliParams, CallServerHttpParams, CallServerHttpResult, ChallengeIssueResponse, - ChallengeVerifyParams, ClientRequestMethod, EmptyObject, ForwardParams, ForwardResult, - FsStatRequest, FsStatResponse, GetEnvResponse, GetHostnameResponse, HttpBodyParams, - HttpHeadersParams, ServeParams, ServerLog, ServerMessageParams, SpawnParams, SpawnResult, - ToClientRequest, UnforwardParams, UpdateParams, UpdateResult, VersionResponse, + AcquireCliParams, CallServerHttpParams, CallServerHttpResult, ChallengeIssueParams, + ChallengeIssueResponse, ChallengeVerifyParams, ClientRequestMethod, EmptyObject, ForwardParams, + ForwardResult, FsStatRequest, FsStatResponse, GetEnvResponse, GetHostnameResponse, + HttpBodyParams, HttpHeadersParams, ServeParams, ServerLog, ServerMessageParams, SpawnParams, + SpawnResult, ToClientRequest, UnforwardParams, UpdateParams, UpdateResult, VersionResponse, METHOD_CHALLENGE_VERIFY, }; use super::server_bridge::ServerBridge; @@ -94,8 +94,8 @@ struct HandlerContext { /// Handler auth state. enum AuthState { - /// Auth is required, we're waiting for the client to send its challenge. - WaitingForChallenge, + /// Auth is required, we're waiting for the client to send its challenge optionally bearing a token. + WaitingForChallenge(Option), /// A challenge has been issued. Waiting for a verification. ChallengeIssued(String), /// Auth is no longer required. @@ -215,7 +215,7 @@ pub async fn serve( code_server_args: own_code_server_args, platform, exit_barrier: own_exit, - requires_auth: false, + requires_auth: AuthRequired::None, }).with_context(cx.clone()).await; cx.span().add_event( @@ -233,13 +233,20 @@ pub async fn serve( } } +#[derive(Clone)] +pub enum AuthRequired { + None, + VSDA, + VSDAWithToken(String), +} + #[derive(Clone)] pub struct ServeStreamParams { pub log: log::Logger, pub launcher_paths: LauncherPaths, pub code_server_args: CodeServerArgs, pub platform: Platform, - pub requires_auth: bool, + pub requires_auth: AuthRequired, pub exit_barrier: Barrier, } @@ -269,7 +276,7 @@ fn make_socket_rpc( launcher_paths: LauncherPaths, code_server_args: CodeServerArgs, port_forwarding: Option, - requires_auth: bool, + requires_auth: AuthRequired, platform: Platform, ) -> RpcDispatcher { let http_requests = Arc::new(std::sync::Mutex::new(HashMap::new())); @@ -277,8 +284,9 @@ fn make_socket_rpc( let mut rpc = RpcBuilder::new(MsgPackSerializer {}).methods(HandlerContext { did_update: Arc::new(AtomicBool::new(false)), auth_state: Arc::new(std::sync::Mutex::new(match requires_auth { - true => AuthState::WaitingForChallenge, - false => AuthState::Authenticated, + AuthRequired::VSDAWithToken(t) => AuthState::WaitingForChallenge(Some(t)), + AuthRequired::VSDA => AuthState::WaitingForChallenge(None), + AuthRequired::None => AuthState::Authenticated, })), socket_tx, log: log.clone(), @@ -305,8 +313,8 @@ fn make_socket_rpc( ensure_auth(&c.auth_state)?; handle_get_env() }); - rpc.register_sync(METHOD_CHALLENGE_ISSUE, |_: EmptyObject, c| { - handle_challenge_issue(&c.auth_state) + rpc.register_sync(METHOD_CHALLENGE_ISSUE, |p: ChallengeIssueParams, c| { + handle_challenge_issue(p, &c.auth_state) }); rpc.register_sync(METHOD_CHALLENGE_VERIFY, |p: ChallengeVerifyParams, c| { handle_challenge_verify(p.response, &c.auth_state) @@ -423,6 +431,7 @@ async fn process_socket( let rx_counter = Arc::new(AtomicUsize::new(0)); let http_requests = Arc::new(std::sync::Mutex::new(HashMap::new())); + let already_authed = matches!(requires_auth, AuthRequired::None); let rpc = make_socket_rpc( log.clone(), socket_tx.clone(), @@ -440,7 +449,7 @@ async fn process_socket( let socket_tx = socket_tx.clone(); let exit_barrier = exit_barrier.clone(); tokio::spawn(async move { - if !requires_auth { + if already_authed { send_version(&socket_tx).await; } @@ -826,13 +835,22 @@ fn handle_get_env() -> Result { } fn handle_challenge_issue( + params: ChallengeIssueParams, auth_state: &Arc>, ) -> Result { let challenge = create_challenge(); let mut auth_state = auth_state.lock().unwrap(); - *auth_state = AuthState::ChallengeIssued(challenge.clone()); + if let AuthState::WaitingForChallenge(Some(s)) = &*auth_state { + println!("looking for token {}, got {:?}", s, params.token); + match ¶ms.token { + Some(t) if s != t => return Err(CodeError::AuthChallengeBadToken.into()), + None => return Err(CodeError::AuthChallengeBadToken.into()), + _ => {} + } + } + *auth_state = AuthState::ChallengeIssued(challenge.clone()); Ok(ChallengeIssueResponse { challenge }) } @@ -844,7 +862,7 @@ fn handle_challenge_verify( match &*auth_state { AuthState::Authenticated => Ok(EmptyObject {}), - AuthState::WaitingForChallenge => Err(CodeError::AuthChallengeNotIssued.into()), + AuthState::WaitingForChallenge(_) => Err(CodeError::AuthChallengeNotIssued.into()), AuthState::ChallengeIssued(c) => match verify_challenge(c, &response) { false => Err(CodeError::AuthChallengeNotIssued.into()), true => { diff --git a/cli/src/tunnels/protocol.rs b/cli/src/tunnels/protocol.rs index eb20afe0ce5..b9d93761364 100644 --- a/cli/src/tunnels/protocol.rs +++ b/cli/src/tunnels/protocol.rs @@ -199,6 +199,11 @@ pub struct SpawnResult { pub const METHOD_CHALLENGE_ISSUE: &str = "challenge_issue"; pub const METHOD_CHALLENGE_VERIFY: &str = "challenge_verify"; +#[derive(Serialize, Deserialize)] +pub struct ChallengeIssueParams { + pub token: Option, +} + #[derive(Serialize, Deserialize)] pub struct ChallengeIssueResponse { pub challenge: String, diff --git a/cli/src/util/errors.rs b/cli/src/util/errors.rs index ca6d4bf3d8a..abd4ef24193 100644 --- a/cli/src/util/errors.rs +++ b/cli/src/util/errors.rs @@ -509,6 +509,8 @@ pub enum CodeError { ServerAuthRequired, #[error("challenge not yet issued")] AuthChallengeNotIssued, + #[error("challenge token is invalid")] + AuthChallengeBadToken, #[error("unauthorized client refused")] AuthMismatch, #[error("keyring communication timed out after 5s")]