cli: add streams to rpc, generic 'spawn' command (#179732)

* cli: apply improvements from integrated wsl branch

* cli: add streams to rpc, generic 'spawn' command

For the "exec server" concept, fyi @aeschli.

* update clippy and apply fixes

* fix unused imports :(
This commit is contained in:
Connor Peet
2023-04-12 08:51:29 -07:00
committed by GitHub
parent bb7570f4f8
commit 2d8ff25c85
23 changed files with 572 additions and 184 deletions

View File

@@ -5,6 +5,7 @@
use crate::async_pipe::get_socket_rw_stream;
use crate::constants::CONTROL_PORT;
use crate::log;
use crate::msgpack_rpc::U32PrefixedCodec;
use crate::rpc::{MaybeSync, RpcBuilder, RpcDispatcher, Serialization};
use crate::self_update::SelfUpdate;
use crate::state::LauncherPaths;
@@ -12,7 +13,8 @@ use crate::tunnels::protocol::HttpRequestParams;
use crate::tunnels::socket_signal::CloseReason;
use crate::update_service::{Platform, UpdateService};
use crate::util::errors::{
wrap, AnyError, InvalidRpcDataError, MismatchedLaunchModeError, NoAttachedServerError,
wrap, AnyError, CodeError, InvalidRpcDataError, MismatchedLaunchModeError,
NoAttachedServerError,
};
use crate::util::http::{
DelegatedHttpRequest, DelegatedSimpleHttp, FallbackSimpleHttp, ReqwestSimpleHttp,
@@ -24,11 +26,14 @@ use crate::util::sync::{new_barrier, Barrier};
use opentelemetry::trace::SpanKind;
use opentelemetry::KeyValue;
use std::collections::HashMap;
use std::process::Stdio;
use tokio::pin;
use tokio_util::codec::Decoder;
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Instant;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, DuplexStream};
use tokio::sync::{mpsc, Mutex};
use super::code_server::{
@@ -40,8 +45,8 @@ use super::port_forwarder::{PortForwarding, PortForwardingProcessor};
use super::protocol::{
CallServerHttpParams, CallServerHttpResult, ClientRequestMethod, EmptyObject, ForwardParams,
ForwardResult, GetHostnameResponse, HttpBodyParams, HttpHeadersParams, ServeParams, ServerLog,
ServerMessageParams, ToClientRequest, UnforwardParams, UpdateParams, UpdateResult,
VersionParams,
ServerMessageParams, SpawnParams, SpawnResult, ToClientRequest, UnforwardParams, UpdateParams,
UpdateResult, VersionParams,
};
use super::server_bridge::ServerBridge;
use super::server_multiplexer::ServerMultiplexer;
@@ -73,7 +78,7 @@ struct HandlerContext {
/// install platform for the VS Code server
platform: Platform,
/// http client to make download/update requests
http: FallbackSimpleHttp,
http: Arc<FallbackSimpleHttp>,
/// requests being served by the client
http_requests: HttpRequestsMap,
}
@@ -196,7 +201,7 @@ pub async fn serve(
],
);
cx.span().end();
});
});
}
}
}
@@ -247,7 +252,10 @@ async fn process_socket(
server_bridges: server_bridges.clone(),
port_forwarding,
platform,
http: FallbackSimpleHttp::new(ReqwestSimpleHttp::new(), http_delegated),
http: Arc::new(FallbackSimpleHttp::new(
ReqwestSimpleHttp::new(),
http_delegated,
)),
http_requests: http_requests.clone(),
});
@@ -276,6 +284,9 @@ async fn process_socket(
rpc.register_async("unforward", |p: UnforwardParams, c| async move {
handle_unforward(&c.log, &c.port_forwarding, p).await
});
rpc.register_duplex("spawn", |stream, p: SpawnParams, c| async move {
handle_spawn(&c.log, stream, p).await
});
rpc.register_sync("httpheaders", |p: HttpHeadersParams, c| {
if let Some(req) = c.http_requests.lock().unwrap().get(&p.req_id) {
req.initial_response(p.status_code, p.headers);
@@ -393,20 +404,20 @@ async fn handle_socket_read(
rx_counter: Arc<AtomicUsize>,
rpc: &RpcDispatcher<MsgPackSerializer, HandlerContext>,
) -> Result<(), std::io::Error> {
let mut socket_reader = BufReader::new(readhalf);
let mut decode_buf = vec![];
let mut readhalf = BufReader::new(readhalf);
let mut decoder = U32PrefixedCodec {};
let mut decoder_buf = bytes::BytesMut::new();
loop {
let read = read_next(
&mut socket_reader,
&rx_counter,
&mut closer,
&mut decode_buf,
)
.await;
let read_len = tokio::select! {
r = readhalf.read_buf(&mut decoder_buf) => r,
_ = closer.wait() => Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "eof")),
}?;
match read {
Ok(len) => match rpc.dispatch(&decode_buf[..len]) {
rx_counter.fetch_add(read_len, Ordering::Relaxed);
while let Some(frame) = decoder.decode(&mut decoder_buf)? {
match rpc.dispatch(&frame) {
MaybeSync::Sync(Some(v)) => {
if socket_tx.send(SocketSignal::Send(v)).await.is_err() {
return Ok(());
@@ -421,34 +432,22 @@ async fn handle_socket_read(
}
});
}
},
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(()),
Err(e) => return Err(e),
MaybeSync::Stream((stream, fut)) => {
if let Some(stream) = stream {
rpc.register_stream(socket_tx.clone(), stream).await;
}
let socket_tx = socket_tx.clone();
tokio::spawn(async move {
if let Some(v) = fut.await {
socket_tx.send(SocketSignal::Send(v)).await.ok();
}
});
}
}
}
}
}
/// Reads and handles the next data packet. Returns the next packet to dispatch,
/// or an error (including EOF).
async fn read_next(
socket_reader: &mut BufReader<impl AsyncRead + Unpin>,
rx_counter: &Arc<AtomicUsize>,
closer: &mut Barrier<()>,
decode_buf: &mut Vec<u8>,
) -> Result<usize, std::io::Error> {
let msg_length = tokio::select! {
u = socket_reader.read_u32() => u? as usize,
_ = closer.wait() => return Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "eof")),
};
decode_buf.resize(msg_length, 0);
rx_counter.fetch_add(msg_length + 4 /* u32 */, Ordering::Relaxed);
tokio::select! {
r = socket_reader.read_exact(decode_buf) => r,
_ = closer.wait() => Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "eof")),
}
}
#[derive(Clone)]
struct ServerOutputSink {
tx: mpsc::Sender<SocketSignal>,
@@ -487,7 +486,9 @@ async fn handle_serve(
};
let resolved = if params.use_local_download {
params_raw.resolve(&c.log, c.http.delegated()).await
params_raw
.resolve(&c.log, Arc::new(c.http.delegated()))
.await
} else {
params_raw.resolve(&c.log, c.http.clone()).await
}?;
@@ -518,7 +519,7 @@ async fn handle_serve(
&install_log,
&resolved,
&c.launcher_paths,
c.http.delegated(),
Arc::new(c.http.delegated()),
);
do_setup!(sb)
} else {
@@ -606,7 +607,7 @@ fn handle_prune(paths: &LauncherPaths) -> Result<Vec<String>, AnyError> {
}
async fn handle_update(
http: &FallbackSimpleHttp,
http: &Arc<FallbackSimpleHttp>,
log: &log::Logger,
did_update: &AtomicBool,
params: &UpdateParams,
@@ -732,3 +733,83 @@ async fn handle_call_server_http(
.to_vec(),
})
}
async fn handle_spawn(
log: &log::Logger,
mut duplex: DuplexStream,
params: SpawnParams,
) -> Result<SpawnResult, AnyError> {
debug!(
log,
"requested to spawn {} with args {:?}", params.command, params.args
);
let mut p = tokio::process::Command::new(&params.command)
.args(&params.args)
.envs(&params.env)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.map_err(CodeError::ProcessSpawnFailed)?;
let mut stdout = p.stdout.take().unwrap();
let mut stderr = p.stderr.take().unwrap();
let mut stdin = p.stdin.take().unwrap();
let (tx, mut rx) = mpsc::channel(4);
macro_rules! copy_stream_to {
($target:expr) => {
let tx = tx.clone();
tokio::spawn(async move {
let mut buf = vec![0; 4096];
loop {
let n = match $target.read(&mut buf).await {
Ok(0) | Err(_) => return,
Ok(n) => n,
};
if !tx.send(buf[..n].to_vec()).await.is_ok() {
return;
}
}
});
};
}
copy_stream_to!(stdout);
copy_stream_to!(stderr);
let mut stdin_buf = vec![0; 4096];
let closed = p.wait();
pin!(closed);
loop {
tokio::select! {
Ok(n) = duplex.read(&mut stdin_buf) => {
let _ = stdin.write_all(&stdin_buf[..n]).await;
},
Some(m) = rx.recv() => {
let _ = duplex.write_all(&m).await;
},
r = &mut closed => {
let r = match r {
Ok(e) => SpawnResult {
message: e.to_string(),
exit_code: e.code().unwrap_or(-1),
},
Err(e) => SpawnResult {
message: e.to_string(),
exit_code: -1,
},
};
debug!(
log,
"spawned command {} exited with code {}", params.command, r.exit_code
);
return Ok(r)
},
}
}
}