mirror of
https://github.com/microsoft/vscode.git
synced 2026-04-28 12:33:35 +01:00
cli: add stdio control server
* signing: implement signing service on the web * wip * cli: implement stdio service This is used to implement the exec server for WSL. Guarded behind a signed handshake. * update distro * rm debug * address pr comments
This commit is contained in:
@@ -5,16 +5,15 @@
|
||||
use crate::async_pipe::get_socket_rw_stream;
|
||||
use crate::constants::{CONTROL_PORT, PRODUCT_NAME_LONG};
|
||||
use crate::log;
|
||||
use crate::msgpack_rpc::U32PrefixedCodec;
|
||||
use crate::rpc::{MaybeSync, RpcBuilder, RpcDispatcher, Serialization};
|
||||
use crate::msgpack_rpc::{new_msgpack_rpc, start_msgpack_rpc, MsgPackCodec, MsgPackSerializer};
|
||||
use crate::rpc::{MaybeSync, RpcBuilder, RpcCaller, RpcDispatcher};
|
||||
use crate::self_update::SelfUpdate;
|
||||
use crate::state::LauncherPaths;
|
||||
use crate::tunnels::protocol::HttpRequestParams;
|
||||
use crate::tunnels::protocol::{HttpRequestParams, METHOD_CHALLENGE_ISSUE};
|
||||
use crate::tunnels::socket_signal::CloseReason;
|
||||
use crate::update_service::{Platform, Release, TargetKind, UpdateService};
|
||||
use crate::util::errors::{
|
||||
wrap, AnyError, CodeError, InvalidRpcDataError, MismatchedLaunchModeError,
|
||||
NoAttachedServerError,
|
||||
wrap, AnyError, CodeError, MismatchedLaunchModeError, NoAttachedServerError,
|
||||
};
|
||||
use crate::util::http::{
|
||||
DelegatedHttpRequest, DelegatedSimpleHttp, FallbackSimpleHttp, ReqwestSimpleHttp,
|
||||
@@ -22,7 +21,7 @@ use crate::util::http::{
|
||||
use crate::util::io::SilentCopyProgress;
|
||||
use crate::util::is_integrated_cli;
|
||||
use crate::util::os::os_release;
|
||||
use crate::util::sync::{new_barrier, Barrier};
|
||||
use crate::util::sync::{new_barrier, Barrier, BarrierOpener};
|
||||
|
||||
use futures::stream::FuturesUnordered;
|
||||
use futures::FutureExt;
|
||||
@@ -31,6 +30,7 @@ use opentelemetry::KeyValue;
|
||||
use std::collections::HashMap;
|
||||
use std::process::Stdio;
|
||||
use tokio::pin;
|
||||
use tokio::process::{ChildStderr, ChildStdin};
|
||||
use tokio_util::codec::Decoder;
|
||||
|
||||
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering};
|
||||
@@ -39,6 +39,7 @@ use std::time::Instant;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, DuplexStream};
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
|
||||
use super::challenge::{create_challenge, sign_challenge, verify_challenge};
|
||||
use super::code_server::{
|
||||
download_cli_into_cache, AnyCodeServer, CodeServerArgs, ServerBuilder, ServerParamsRaw,
|
||||
SocketCodeServer,
|
||||
@@ -47,11 +48,12 @@ use super::dev_tunnels::ActiveTunnel;
|
||||
use super::paths::prune_stopped_servers;
|
||||
use super::port_forwarder::{PortForwarding, PortForwardingProcessor};
|
||||
use super::protocol::{
|
||||
AcquireCliParams, CallServerHttpParams, CallServerHttpResult, ClientRequestMethod, EmptyObject,
|
||||
ForwardParams, ForwardResult, FsStatRequest, FsStatResponse, GetEnvResponse,
|
||||
GetHostnameResponse, HttpBodyParams, HttpHeadersParams, ServeParams, ServerLog,
|
||||
ServerMessageParams, SpawnParams, SpawnResult, ToClientRequest, UnforwardParams, UpdateParams,
|
||||
UpdateResult, VersionParams,
|
||||
AcquireCliParams, CallServerHttpParams, CallServerHttpResult, ChallengeIssueResponse,
|
||||
ChallengeVerifyParams, ClientRequestMethod, EmptyObject, ForwardParams, ForwardResult,
|
||||
FsStatRequest, FsStatResponse, GetEnvResponse, GetHostnameResponse, HttpBodyParams,
|
||||
HttpHeadersParams, ServeParams, ServerLog, ServerMessageParams, SpawnParams, SpawnResult,
|
||||
ToClientRequest, UnforwardParams, UpdateParams, UpdateResult, VersionResponse,
|
||||
METHOD_CHALLENGE_VERIFY,
|
||||
};
|
||||
use super::server_bridge::ServerBridge;
|
||||
use super::server_multiplexer::ServerMultiplexer;
|
||||
@@ -68,6 +70,8 @@ struct HandlerContext {
|
||||
log: log::Logger,
|
||||
/// Whether the server update during the handler session.
|
||||
did_update: Arc<AtomicBool>,
|
||||
/// Whether authentication is still required on the socket.
|
||||
auth_state: Arc<std::sync::Mutex<AuthState>>,
|
||||
/// A loopback channel to talk to the socket server task.
|
||||
socket_tx: mpsc::Sender<SocketSignal>,
|
||||
/// Configured launcher paths.
|
||||
@@ -79,7 +83,7 @@ struct HandlerContext {
|
||||
// the cli arguments used to start the code server
|
||||
code_server_args: CodeServerArgs,
|
||||
/// port forwarding functionality
|
||||
port_forwarding: PortForwarding,
|
||||
port_forwarding: Option<PortForwarding>,
|
||||
/// install platform for the VS Code server
|
||||
platform: Platform,
|
||||
/// http client to make download/update requests
|
||||
@@ -88,6 +92,16 @@ struct HandlerContext {
|
||||
http_requests: HttpRequestsMap,
|
||||
}
|
||||
|
||||
/// Handler auth state.
|
||||
enum AuthState {
|
||||
/// Auth is required, we're waiting for the client to send its challenge.
|
||||
WaitingForChallenge,
|
||||
/// A challenge has been issued. Waiting for a verification.
|
||||
ChallengeIssued(String),
|
||||
/// Auth is no longer required.
|
||||
Authenticated,
|
||||
}
|
||||
|
||||
static MESSAGE_ID_COUNTER: AtomicU32 = AtomicU32::new(0);
|
||||
|
||||
// Gets a next incrementing number that can be used in logs
|
||||
@@ -195,7 +209,14 @@ pub async fn serve(
|
||||
debug!(own_log, "Serving new connection");
|
||||
|
||||
let (writehalf, readhalf) = socket.into_split();
|
||||
let stats = process_socket(own_exit, readhalf, writehalf, own_log, own_tx, own_paths, own_code_server_args, own_forwarding, platform).with_context(cx.clone()).await;
|
||||
let stats = process_socket(readhalf, writehalf, own_tx, Some(own_forwarding), ServeStreamParams {
|
||||
log: own_log,
|
||||
launcher_paths: own_paths,
|
||||
code_server_args: own_code_server_args,
|
||||
platform,
|
||||
exit_barrier: own_exit,
|
||||
requires_auth: false,
|
||||
}).with_context(cx.clone()).await;
|
||||
|
||||
cx.span().add_event(
|
||||
"socket.bandwidth",
|
||||
@@ -206,69 +227,91 @@ pub async fn serve(
|
||||
],
|
||||
);
|
||||
cx.span().end();
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct SocketStats {
|
||||
pub struct ServeStreamParams {
|
||||
pub log: log::Logger,
|
||||
pub launcher_paths: LauncherPaths,
|
||||
pub code_server_args: CodeServerArgs,
|
||||
pub platform: Platform,
|
||||
pub requires_auth: bool,
|
||||
pub exit_barrier: Barrier<ShutdownSignal>,
|
||||
}
|
||||
|
||||
pub async fn serve_stream(
|
||||
readhalf: impl AsyncRead + Send + Unpin + 'static,
|
||||
writehalf: impl AsyncWrite + Unpin,
|
||||
params: ServeStreamParams,
|
||||
) -> SocketStats {
|
||||
// Currently the only server signal is respawn, that doesn't have much meaning
|
||||
// when serving a stream, so make an ignored channel.
|
||||
let (server_rx, server_tx) = mpsc::channel(1);
|
||||
drop(server_tx);
|
||||
|
||||
process_socket(readhalf, writehalf, server_rx, None, params).await
|
||||
}
|
||||
|
||||
pub struct SocketStats {
|
||||
rx: usize,
|
||||
tx: usize,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
struct MsgPackSerializer {}
|
||||
|
||||
impl Serialization for MsgPackSerializer {
|
||||
fn serialize(&self, value: impl serde::Serialize) -> Vec<u8> {
|
||||
rmp_serde::to_vec_named(&value).expect("expected to serialize")
|
||||
}
|
||||
|
||||
fn deserialize<P: serde::de::DeserializeOwned>(&self, b: &[u8]) -> Result<P, AnyError> {
|
||||
rmp_serde::from_slice(b).map_err(|e| InvalidRpcDataError(e.to_string()).into())
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)] // necessary here
|
||||
async fn process_socket(
|
||||
mut exit_barrier: Barrier<()>,
|
||||
readhalf: impl AsyncRead + Send + Unpin + 'static,
|
||||
mut writehalf: impl AsyncWrite + Unpin,
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn make_socket_rpc(
|
||||
log: log::Logger,
|
||||
server_tx: mpsc::Sender<ServerSignal>,
|
||||
socket_tx: mpsc::Sender<SocketSignal>,
|
||||
http_delegated: DelegatedSimpleHttp,
|
||||
launcher_paths: LauncherPaths,
|
||||
code_server_args: CodeServerArgs,
|
||||
port_forwarding: PortForwarding,
|
||||
port_forwarding: Option<PortForwarding>,
|
||||
requires_auth: bool,
|
||||
platform: Platform,
|
||||
) -> SocketStats {
|
||||
let (socket_tx, mut socket_rx) = mpsc::channel(4);
|
||||
let rx_counter = Arc::new(AtomicUsize::new(0));
|
||||
) -> RpcDispatcher<MsgPackSerializer, HandlerContext> {
|
||||
let http_requests = Arc::new(std::sync::Mutex::new(HashMap::new()));
|
||||
let server_bridges = ServerMultiplexer::new();
|
||||
let (http_delegated, mut http_rx) = DelegatedSimpleHttp::new(log.clone());
|
||||
let mut rpc = RpcBuilder::new(MsgPackSerializer {}).methods(HandlerContext {
|
||||
did_update: Arc::new(AtomicBool::new(false)),
|
||||
socket_tx: socket_tx.clone(),
|
||||
auth_state: Arc::new(std::sync::Mutex::new(match requires_auth {
|
||||
true => AuthState::WaitingForChallenge,
|
||||
false => AuthState::Authenticated,
|
||||
})),
|
||||
socket_tx,
|
||||
log: log.clone(),
|
||||
launcher_paths,
|
||||
code_server_args,
|
||||
code_server: Arc::new(Mutex::new(None)),
|
||||
server_bridges: server_bridges.clone(),
|
||||
server_bridges,
|
||||
port_forwarding,
|
||||
platform,
|
||||
http: Arc::new(FallbackSimpleHttp::new(
|
||||
ReqwestSimpleHttp::new(),
|
||||
http_delegated,
|
||||
)),
|
||||
http_requests: http_requests.clone(),
|
||||
http_requests,
|
||||
});
|
||||
|
||||
rpc.register_sync("ping", |_: EmptyObject, _| Ok(EmptyObject {}));
|
||||
rpc.register_sync("gethostname", |_: EmptyObject, _| handle_get_hostname());
|
||||
rpc.register_sync("fs_stat", |p: FsStatRequest, _| handle_stat(p.path));
|
||||
rpc.register_sync("get_env", |_: EmptyObject, _| handle_get_env());
|
||||
rpc.register_sync("fs_stat", |p: FsStatRequest, c| {
|
||||
ensure_auth(&c.auth_state)?;
|
||||
handle_stat(p.path)
|
||||
});
|
||||
rpc.register_sync("get_env", |_: EmptyObject, c| {
|
||||
ensure_auth(&c.auth_state)?;
|
||||
handle_get_env()
|
||||
});
|
||||
rpc.register_sync(METHOD_CHALLENGE_ISSUE, |_: EmptyObject, c| {
|
||||
handle_challenge_issue(&c.auth_state)
|
||||
});
|
||||
rpc.register_sync(METHOD_CHALLENGE_VERIFY, |p: ChallengeVerifyParams, c| {
|
||||
handle_challenge_verify(p.response, &c.auth_state)
|
||||
});
|
||||
rpc.register_async("serve", move |params: ServeParams, c| async move {
|
||||
ensure_auth(&c.auth_state)?;
|
||||
handle_serve(c, params).await
|
||||
});
|
||||
rpc.register_async("update", |p: UpdateParams, c| async move {
|
||||
@@ -286,15 +329,19 @@ async fn process_socket(
|
||||
handle_call_server_http(code_server, p).await
|
||||
});
|
||||
rpc.register_async("forward", |p: ForwardParams, c| async move {
|
||||
ensure_auth(&c.auth_state)?;
|
||||
handle_forward(&c.log, &c.port_forwarding, p).await
|
||||
});
|
||||
rpc.register_async("unforward", |p: UnforwardParams, c| async move {
|
||||
ensure_auth(&c.auth_state)?;
|
||||
handle_unforward(&c.log, &c.port_forwarding, p).await
|
||||
});
|
||||
rpc.register_async("acquire_cli", |p: AcquireCliParams, c| async move {
|
||||
ensure_auth(&c.auth_state)?;
|
||||
handle_acquire_cli(&c.launcher_paths, &c.http, &c.log, p).await
|
||||
});
|
||||
rpc.register_duplex("spawn", 3, |mut streams, p: SpawnParams, c| async move {
|
||||
ensure_auth(&c.auth_state)?;
|
||||
handle_spawn(
|
||||
&c.log,
|
||||
p,
|
||||
@@ -304,13 +351,28 @@ async fn process_socket(
|
||||
)
|
||||
.await
|
||||
});
|
||||
rpc.register_duplex(
|
||||
"spawn_cli",
|
||||
3,
|
||||
|mut streams, p: SpawnParams, c| async move {
|
||||
ensure_auth(&c.auth_state)?;
|
||||
handle_spawn_cli(
|
||||
&c.log,
|
||||
p,
|
||||
streams.remove(0),
|
||||
streams.remove(0),
|
||||
streams.remove(0),
|
||||
)
|
||||
.await
|
||||
},
|
||||
);
|
||||
rpc.register_sync("httpheaders", |p: HttpHeadersParams, c| {
|
||||
if let Some(req) = c.http_requests.lock().unwrap().get(&p.req_id) {
|
||||
req.initial_response(p.status_code, p.headers);
|
||||
}
|
||||
Ok(EmptyObject {})
|
||||
});
|
||||
rpc.register_sync("unforward", move |p: HttpBodyParams, c| {
|
||||
rpc.register_sync("httpbody", move |p: HttpBodyParams, c| {
|
||||
let mut reqs = c.http_requests.lock().unwrap();
|
||||
if let Some(req) = reqs.get(&p.req_id) {
|
||||
if !p.segment.is_empty() {
|
||||
@@ -322,15 +384,64 @@ async fn process_socket(
|
||||
}
|
||||
Ok(EmptyObject {})
|
||||
});
|
||||
rpc.register_sync(
|
||||
"version",
|
||||
|_: EmptyObject, _| Ok(VersionResponse::default()),
|
||||
);
|
||||
|
||||
rpc.build(log)
|
||||
}
|
||||
|
||||
fn ensure_auth(is_authed: &Arc<std::sync::Mutex<AuthState>>) -> Result<(), AnyError> {
|
||||
if let AuthState::Authenticated = &*is_authed.lock().unwrap() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(CodeError::ServerAuthRequired.into())
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)] // necessary here
|
||||
async fn process_socket(
|
||||
readhalf: impl AsyncRead + Send + Unpin + 'static,
|
||||
mut writehalf: impl AsyncWrite + Unpin,
|
||||
server_tx: mpsc::Sender<ServerSignal>,
|
||||
port_forwarding: Option<PortForwarding>,
|
||||
params: ServeStreamParams,
|
||||
) -> SocketStats {
|
||||
let ServeStreamParams {
|
||||
mut exit_barrier,
|
||||
log,
|
||||
launcher_paths,
|
||||
code_server_args,
|
||||
platform,
|
||||
requires_auth,
|
||||
} = params;
|
||||
|
||||
let (http_delegated, mut http_rx) = DelegatedSimpleHttp::new(log.clone());
|
||||
let (socket_tx, mut socket_rx) = mpsc::channel(4);
|
||||
let rx_counter = Arc::new(AtomicUsize::new(0));
|
||||
let http_requests = Arc::new(std::sync::Mutex::new(HashMap::new()));
|
||||
|
||||
let rpc = make_socket_rpc(
|
||||
log.clone(),
|
||||
socket_tx.clone(),
|
||||
http_delegated,
|
||||
launcher_paths,
|
||||
code_server_args,
|
||||
port_forwarding,
|
||||
requires_auth,
|
||||
platform,
|
||||
);
|
||||
|
||||
{
|
||||
let log = log.clone();
|
||||
let rx_counter = rx_counter.clone();
|
||||
let socket_tx = socket_tx.clone();
|
||||
let exit_barrier = exit_barrier.clone();
|
||||
let rpc = rpc.build(log.clone());
|
||||
tokio::spawn(async move {
|
||||
send_version(&socket_tx).await;
|
||||
if !requires_auth {
|
||||
send_version(&socket_tx).await;
|
||||
}
|
||||
|
||||
if let Err(e) =
|
||||
handle_socket_read(&log, readhalf, exit_barrier, &socket_tx, rx_counter, &rpc).await
|
||||
@@ -350,6 +461,10 @@ async fn process_socket(
|
||||
}
|
||||
|
||||
ctx.dispose().await;
|
||||
|
||||
let _ = socket_tx
|
||||
.send(SocketSignal::CloseWith(CloseReason("eof".to_string())))
|
||||
.await;
|
||||
});
|
||||
}
|
||||
|
||||
@@ -408,7 +523,7 @@ async fn process_socket(
|
||||
async fn send_version(tx: &mpsc::Sender<SocketSignal>) {
|
||||
tx.send(SocketSignal::from_message(&ToClientRequest {
|
||||
id: None,
|
||||
params: ClientRequestMethod::version(VersionParams::default()),
|
||||
params: ClientRequestMethod::version(VersionResponse::default()),
|
||||
}))
|
||||
.await
|
||||
.ok();
|
||||
@@ -416,13 +531,13 @@ async fn send_version(tx: &mpsc::Sender<SocketSignal>) {
|
||||
async fn handle_socket_read(
|
||||
_log: &log::Logger,
|
||||
readhalf: impl AsyncRead + Unpin,
|
||||
mut closer: Barrier<()>,
|
||||
mut closer: Barrier<ShutdownSignal>,
|
||||
socket_tx: &mpsc::Sender<SocketSignal>,
|
||||
rx_counter: Arc<AtomicUsize>,
|
||||
rpc: &RpcDispatcher<MsgPackSerializer, HandlerContext>,
|
||||
) -> Result<(), std::io::Error> {
|
||||
let mut readhalf = BufReader::new(readhalf);
|
||||
let mut decoder = U32PrefixedCodec {};
|
||||
let mut decoder = MsgPackCodec::new();
|
||||
let mut decoder_buf = bytes::BytesMut::new();
|
||||
|
||||
loop {
|
||||
@@ -431,10 +546,14 @@ async fn handle_socket_read(
|
||||
_ = closer.wait() => Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "eof")),
|
||||
}?;
|
||||
|
||||
if read_len == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
rx_counter.fetch_add(read_len, Ordering::Relaxed);
|
||||
|
||||
while let Some(frame) = decoder.decode(&mut decoder_buf)? {
|
||||
match rpc.dispatch(&frame) {
|
||||
match rpc.dispatch_with_partial(&frame.vec, frame.obj) {
|
||||
MaybeSync::Sync(Some(v)) => {
|
||||
if socket_tx.send(SocketSignal::Send(v)).await.is_err() {
|
||||
return Ok(());
|
||||
@@ -704,11 +823,44 @@ fn handle_get_env() -> Result<GetEnvResponse, AnyError> {
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_challenge_issue(
|
||||
auth_state: &Arc<std::sync::Mutex<AuthState>>,
|
||||
) -> Result<ChallengeIssueResponse, AnyError> {
|
||||
let challenge = create_challenge();
|
||||
|
||||
let mut auth_state = auth_state.lock().unwrap();
|
||||
*auth_state = AuthState::ChallengeIssued(challenge.clone());
|
||||
|
||||
Ok(ChallengeIssueResponse { challenge })
|
||||
}
|
||||
|
||||
fn handle_challenge_verify(
|
||||
response: String,
|
||||
auth_state: &Arc<std::sync::Mutex<AuthState>>,
|
||||
) -> Result<EmptyObject, AnyError> {
|
||||
let mut auth_state = auth_state.lock().unwrap();
|
||||
|
||||
match &*auth_state {
|
||||
AuthState::Authenticated => Ok(EmptyObject {}),
|
||||
AuthState::WaitingForChallenge => Err(CodeError::AuthChallengeNotIssued.into()),
|
||||
AuthState::ChallengeIssued(c) => match verify_challenge(c, &response) {
|
||||
false => Err(CodeError::AuthChallengeNotIssued.into()),
|
||||
true => {
|
||||
*auth_state = AuthState::Authenticated;
|
||||
Ok(EmptyObject {})
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_forward(
|
||||
log: &log::Logger,
|
||||
port_forwarding: &PortForwarding,
|
||||
port_forwarding: &Option<PortForwarding>,
|
||||
params: ForwardParams,
|
||||
) -> Result<ForwardResult, AnyError> {
|
||||
let port_forwarding = port_forwarding
|
||||
.as_ref()
|
||||
.ok_or(CodeError::PortForwardingNotAvailable)?;
|
||||
info!(log, "Forwarding port {}", params.port);
|
||||
let uri = port_forwarding.forward(params.port).await?;
|
||||
Ok(ForwardResult { uri })
|
||||
@@ -716,9 +868,12 @@ async fn handle_forward(
|
||||
|
||||
async fn handle_unforward(
|
||||
log: &log::Logger,
|
||||
port_forwarding: &PortForwarding,
|
||||
port_forwarding: &Option<PortForwarding>,
|
||||
params: UnforwardParams,
|
||||
) -> Result<EmptyObject, AnyError> {
|
||||
let port_forwarding = port_forwarding
|
||||
.as_ref()
|
||||
.ok_or(CodeError::PortForwardingNotAvailable)?;
|
||||
info!(log, "Unforwarding port {}", params.port);
|
||||
port_forwarding.unforward(params.port).await?;
|
||||
Ok(EmptyObject {})
|
||||
@@ -818,17 +973,17 @@ async fn handle_spawn<Stdin, StdoutAndErr>(
|
||||
stderr: Option<StdoutAndErr>,
|
||||
) -> Result<SpawnResult, AnyError>
|
||||
where
|
||||
Stdin: AsyncRead + Unpin + Send,
|
||||
StdoutAndErr: AsyncWrite + Unpin + Send,
|
||||
Stdin: AsyncRead + Unpin + Send + 'static,
|
||||
StdoutAndErr: AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
debug!(
|
||||
log,
|
||||
"requested to spawn {} with args {:?}", params.command, params.args
|
||||
);
|
||||
|
||||
macro_rules! pipe_if_some {
|
||||
macro_rules! pipe_if {
|
||||
($e: expr) => {
|
||||
if $e.is_some() {
|
||||
if $e {
|
||||
Stdio::piped()
|
||||
} else {
|
||||
Stdio::null()
|
||||
@@ -839,9 +994,9 @@ where
|
||||
let mut p = tokio::process::Command::new(¶ms.command);
|
||||
p.args(¶ms.args);
|
||||
p.envs(¶ms.env);
|
||||
p.stdin(pipe_if_some!(stdin));
|
||||
p.stdout(pipe_if_some!(stdout));
|
||||
p.stderr(pipe_if_some!(stderr));
|
||||
p.stdin(pipe_if!(stdin.is_some()));
|
||||
p.stdout(pipe_if!(stdin.is_some()));
|
||||
p.stderr(pipe_if!(stderr.is_some()));
|
||||
if let Some(cwd) = ¶ms.cwd {
|
||||
p.current_dir(cwd);
|
||||
}
|
||||
@@ -859,7 +1014,72 @@ where
|
||||
futs.push(async move { tokio::io::copy(&mut a, &mut b).await }.boxed());
|
||||
}
|
||||
|
||||
let closed = p.wait();
|
||||
wait_for_process_exit(log, ¶ms.command, p, futs).await
|
||||
}
|
||||
|
||||
async fn handle_spawn_cli(
|
||||
log: &log::Logger,
|
||||
params: SpawnParams,
|
||||
mut protocol_in: DuplexStream,
|
||||
mut protocol_out: DuplexStream,
|
||||
mut log_out: DuplexStream,
|
||||
) -> Result<SpawnResult, AnyError> {
|
||||
debug!(
|
||||
log,
|
||||
"requested to spawn cli {} with args {:?}", params.command, params.args
|
||||
);
|
||||
|
||||
let mut p = tokio::process::Command::new(¶ms.command);
|
||||
p.args(¶ms.args);
|
||||
|
||||
// CLI args to spawn a server; contracted with clients that they should _not_ provide these.
|
||||
p.arg("--verbose");
|
||||
p.arg("tunnel");
|
||||
p.arg("stdio");
|
||||
|
||||
p.envs(¶ms.env);
|
||||
p.stdin(Stdio::piped());
|
||||
p.stdout(Stdio::piped());
|
||||
p.stderr(Stdio::piped());
|
||||
if let Some(cwd) = ¶ms.cwd {
|
||||
p.current_dir(cwd);
|
||||
}
|
||||
|
||||
let mut p = p.spawn().map_err(CodeError::ProcessSpawnFailed)?;
|
||||
|
||||
let mut stdin = p.stdin.take().unwrap();
|
||||
let mut stdout = p.stdout.take().unwrap();
|
||||
let mut stderr = p.stderr.take().unwrap();
|
||||
|
||||
// Start handling logs while doing the handshake in case there's some kind of error
|
||||
let log_pump = tokio::spawn(async move { tokio::io::copy(&mut stdout, &mut log_out).await });
|
||||
|
||||
// note: intentionally do not wrap stdin in a bufreader, since we don't
|
||||
// want to read anything other than our handshake messages.
|
||||
if let Err(e) = spawn_do_child_authentication(log, &mut stdin, &mut stderr).await {
|
||||
warning!(log, "failed to authenticate with child process {}", e);
|
||||
let _ = p.kill().await;
|
||||
return Err(e.into());
|
||||
}
|
||||
|
||||
debug!(log, "cli authenticated, attaching stdio");
|
||||
let futs = FuturesUnordered::new();
|
||||
futs.push(async move { tokio::io::copy(&mut protocol_in, &mut stdin).await }.boxed());
|
||||
futs.push(async move { tokio::io::copy(&mut stderr, &mut protocol_out).await }.boxed());
|
||||
futs.push(async move { log_pump.await.unwrap() }.boxed());
|
||||
|
||||
wait_for_process_exit(log, ¶ms.command, p, futs).await
|
||||
}
|
||||
|
||||
type TokioCopyFuture = dyn futures::Future<Output = Result<u64, std::io::Error>> + Send;
|
||||
|
||||
async fn wait_for_process_exit(
|
||||
log: &log::Logger,
|
||||
command: &str,
|
||||
mut process: tokio::process::Child,
|
||||
futs: FuturesUnordered<std::pin::Pin<Box<TokioCopyFuture>>>,
|
||||
) -> Result<SpawnResult, AnyError> {
|
||||
let closed = process.wait();
|
||||
pin!(closed);
|
||||
|
||||
let r = tokio::select! {
|
||||
@@ -880,8 +1100,69 @@ where
|
||||
|
||||
debug!(
|
||||
log,
|
||||
"spawned command {} exited with code {}", params.command, r.exit_code
|
||||
"spawned cli {} exited with code {}", command, r.exit_code
|
||||
);
|
||||
|
||||
Ok(r)
|
||||
}
|
||||
|
||||
async fn spawn_do_child_authentication(
|
||||
log: &log::Logger,
|
||||
stdin: &mut ChildStdin,
|
||||
stdout: &mut ChildStderr,
|
||||
) -> Result<(), CodeError> {
|
||||
let (msg_tx, msg_rx) = mpsc::unbounded_channel();
|
||||
let (shutdown_rx, shutdown) = new_barrier();
|
||||
let mut rpc = new_msgpack_rpc();
|
||||
let caller = rpc.get_caller(msg_tx);
|
||||
|
||||
let challenge_response = do_challenge_response_flow(caller, shutdown);
|
||||
let rpc = start_msgpack_rpc(
|
||||
rpc.methods(()).build(log.prefixed("client-auth")),
|
||||
stdout,
|
||||
stdin,
|
||||
msg_rx,
|
||||
shutdown_rx,
|
||||
);
|
||||
pin!(rpc);
|
||||
|
||||
tokio::select! {
|
||||
r = &mut rpc => {
|
||||
match r {
|
||||
// means shutdown happened cleanly already, we're good
|
||||
Ok(_) => Ok(()),
|
||||
Err(e) => Err(CodeError::ProcessSpawnHandshakeFailed(e))
|
||||
}
|
||||
},
|
||||
r = challenge_response => {
|
||||
r?;
|
||||
rpc.await.map(|_| ()).map_err(CodeError::ProcessSpawnFailed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn do_challenge_response_flow(
|
||||
caller: RpcCaller<MsgPackSerializer>,
|
||||
shutdown: BarrierOpener<()>,
|
||||
) -> Result<(), CodeError> {
|
||||
let challenge: ChallengeIssueResponse = caller
|
||||
.call(METHOD_CHALLENGE_ISSUE, EmptyObject {})
|
||||
.await
|
||||
.unwrap()
|
||||
.map_err(CodeError::TunnelRpcCallFailed)?;
|
||||
|
||||
let _: EmptyObject = caller
|
||||
.call(
|
||||
METHOD_CHALLENGE_VERIFY,
|
||||
ChallengeVerifyParams {
|
||||
response: sign_challenge(&challenge.challenge),
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.map_err(CodeError::TunnelRpcCallFailed)?;
|
||||
|
||||
shutdown.open(());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user