cli: implement local download fallback

Implements an automatic local download fallback, similar to SSH
(cc @roblourens). If the initial download results in an error, either
in making the request or a 5xx, it'll try to fall back to making the
request locally and streaming it over the tunnel.

This abstracts the request client behing a "SimpleHttp" trait which
either uses to the native reqwest or uses the 'delegated' mode over the
socket.
This commit is contained in:
Connor Peet
2022-11-09 15:27:47 -08:00
parent 489d16dff3
commit d31573550f
8 changed files with 577 additions and 84 deletions

View File

@@ -11,15 +11,19 @@ use crate::update_service::{Platform, UpdateService};
use crate::util::errors::{
wrap, AnyError, MismatchedLaunchModeError, NoAttachedServerError, ServerWriteError,
};
use crate::util::http::{
DelegatedHttpRequest, DelegatedSimpleHttp, FallbackSimpleHttp, ReqwestSimpleHttp,
};
use crate::util::io::SilentCopyProgress;
use crate::util::sync::{new_barrier, Barrier};
use opentelemetry::trace::SpanKind;
use opentelemetry::KeyValue;
use serde::Serialize;
use std::collections::HashMap;
use std::convert::Infallible;
use std::env;
use std::path::PathBuf;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::atomic::{AtomicU32, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Instant;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader};
@@ -34,14 +38,16 @@ use super::paths::prune_stopped_servers;
use super::port_forwarder::{PortForwarding, PortForwardingProcessor};
use super::protocol::{
CallServerHttpParams, CallServerHttpResult, ClientRequestMethod, EmptyResult, ErrorResponse,
ForwardParams, ForwardResult, GetHostnameResponse, RefServerMessageParams, ResponseError,
ServeParams, ServerLog, ServerMessageParams, ServerRequestMethod, SuccessResponse,
ToClientRequest, ToServerRequest, UnforwardParams, UpdateParams, UpdateResult, VersionParams,
ForwardParams, ForwardResult, GetHostnameResponse, HttpRequestParams, RefServerMessageParams,
ResponseError, ServeParams, ServerLog, ServerMessageParams, ServerRequestMethod,
SuccessResponse, ToClientRequest, ToServerRequest, UnforwardParams, UpdateParams, UpdateResult,
VersionParams,
};
use super::server_bridge::{get_socket_rw_stream, FromServerMessage, ServerBridge};
type ServerBridgeList = Option<Vec<(u16, ServerBridge)>>;
type ServerBridgeListLock = Arc<Mutex<ServerBridgeList>>;
type HttpRequestsMap = Arc<std::sync::Mutex<HashMap<u32, DelegatedHttpRequest>>>;
type CodeServerCell = Arc<Mutex<Option<SocketCodeServer>>>;
struct HandlerContext {
@@ -67,6 +73,17 @@ struct HandlerContext {
port_forwarding: PortForwarding,
/// install platform for the VS Code server
platform: Platform,
/// http client to make download/update requests
http: FallbackSimpleHttp,
/// requests being served by the client
http_requests: HttpRequestsMap,
}
static MESSAGE_ID_COUNTER: AtomicU32 = AtomicU32::new(0);
// Gets a next incrementing number that can be used in logs
pub fn next_message_id() -> u32 {
MESSAGE_ID_COUNTER.fetch_add(1, Ordering::SeqCst)
}
impl HandlerContext {
@@ -279,7 +296,7 @@ async fn process_socket(
platform: Platform,
) -> SocketStats {
let (socket_tx, mut socket_rx) = mpsc::channel(4);
let http_requests = Arc::new(std::sync::Mutex::new(HashMap::new()));
let rx_counter = Arc::new(AtomicUsize::new(0));
let server_bridges: ServerBridgeListLock = Arc::new(Mutex::new(Some(vec![])));
@@ -287,6 +304,8 @@ async fn process_socket(
let barrier_ctx = exit_barrier.clone();
let log_ctx = log.clone();
let rx_counter_ctx = rx_counter.clone();
let http_requests_ctx = http_requests.clone();
let (http_delegated, mut http_rx) = DelegatedSimpleHttp::new(log_ctx.clone());
tokio::spawn(async move {
let mut ctx = HandlerContext {
@@ -301,6 +320,8 @@ async fn process_socket(
server_bridges: server_bridges_lock,
port_forwarding,
platform,
http: FallbackSimpleHttp::new(ReqwestSimpleHttp::new(), http_delegated),
http_requests: http_requests_ctx,
};
send_version(&ctx.socket_tx).await;
@@ -324,6 +345,25 @@ async fn process_socket(
writehalf.shutdown().await.ok();
break;
},
Some(r) = http_rx.recv() => {
let id = next_message_id();
let serialized = rmp_serde::to_vec_named(&ToClientRequest {
id: None,
params: ClientRequestMethod::makehttpreq(HttpRequestParams {
url: &r.url,
method: r.method,
req_id: id,
}),
})
.unwrap();
http_requests.lock().unwrap().insert(id, r);
tx_counter += serialized.len();
if let Err(e) = writehalf.write_all(&serialized).await {
debug!(log, "Closing connection: {}", e);
break;
}
}
recv = socket_rx.recv() => match recv {
None => break,
Some(message) => match message {
@@ -509,6 +549,7 @@ async fn dispatch_next(req: ToServerRequest, ctx: &mut HandlerContext, did_updat
}
ServerRequestMethod::serve(params) => {
let log = ctx.log.clone();
let http = ctx.http.clone();
let server_bridges = ctx.server_bridges.clone();
let code_server_args = ctx.code_server_args.clone();
let code_server = ctx.code_server.clone();
@@ -519,6 +560,7 @@ async fn dispatch_next(req: ToServerRequest, ctx: &mut HandlerContext, did_updat
"serve",
handle_serve(
log,
http,
server_bridges,
code_server_args,
platform,
@@ -538,7 +580,7 @@ async fn dispatch_next(req: ToServerRequest, ctx: &mut HandlerContext, did_updat
}
ServerRequestMethod::update(p) => {
dispatch_blocking!("update", async {
let r = handle_update(&ctx.log, &p).await;
let r = handle_update(&ctx.http, &ctx.log, &p).await;
if matches!(&r, Ok(u) if u.did_update) {
*did_update = true;
}
@@ -567,6 +609,26 @@ async fn dispatch_next(req: ToServerRequest, ctx: &mut HandlerContext, did_updat
let port_forwarding = ctx.port_forwarding.clone();
dispatch_async!("unforward", handle_unforward(log, port_forwarding, p));
}
ServerRequestMethod::httpheaders(p) => {
if let Some(req) = ctx.http_requests.lock().unwrap().get(&p.req_id) {
req.initial_response(p.status_code, p.headers);
}
success!(ctx.socket_tx, EmptyResult {});
}
ServerRequestMethod::httpbody(p) => {
{
let mut reqs = ctx.http_requests.lock().unwrap();
if let Some(req) = reqs.get(&p.req_id) {
if !p.segment.is_empty() {
req.body(p.segment);
}
if p.complete {
reqs.remove(&p.req_id);
}
}
}
success!(ctx.socket_tx, EmptyResult {});
}
};
}
@@ -594,6 +656,7 @@ impl log::LogSink for ServerOutputSink {
#[allow(clippy::too_many_arguments)]
async fn handle_serve(
log: log::Logger,
http: FallbackSimpleHttp,
server_bridges: ServerBridgeListLock,
mut code_server_args: CodeServerArgs,
platform: Platform,
@@ -607,15 +670,19 @@ async fn handle_serve(
.install_extensions
.extend(params.extensions.into_iter());
let resolved = ServerParamsRaw {
let params_raw = ServerParamsRaw {
commit_id: params.commit_id,
quality: params.quality,
code_server_args,
headless: true,
platform,
}
.resolve(&log)
.await?;
};
let resolved = if params.use_local_download {
params_raw.resolve(&log, http.delegated()).await
} else {
params_raw.resolve(&log, http.clone()).await
}?;
let mut server_ref = code_server.lock().await;
let server = match &*server_ref {
@@ -624,15 +691,27 @@ async fn handle_serve(
let install_log = log.tee(ServerOutputSink {
tx: socket_tx.clone(),
});
let sb = ServerBuilder::new(&install_log, &resolved, &launcher_paths);
let server = match sb.get_running().await? {
Some(AnyCodeServer::Socket(s)) => s,
Some(_) => return Err(AnyError::from(MismatchedLaunchModeError())),
None => {
sb.setup().await?;
sb.listen_on_default_socket().await?
}
macro_rules! do_setup {
($sb:expr) => {
match $sb.get_running().await? {
Some(AnyCodeServer::Socket(s)) => s,
Some(_) => return Err(AnyError::from(MismatchedLaunchModeError())),
None => {
$sb.setup().await?;
$sb.listen_on_default_socket().await?
}
}
};
}
let server = if params.use_local_download {
let sb =
ServerBuilder::new(&install_log, &resolved, &launcher_paths, http.delegated());
do_setup!(sb)
} else {
let sb = ServerBuilder::new(&install_log, &resolved, &launcher_paths, http);
do_setup!(sb)
};
server_ref.replace(server.clone());
@@ -699,8 +778,12 @@ async fn handle_prune(paths: &LauncherPaths) -> Result<Vec<String>, AnyError> {
})
}
async fn handle_update(log: &log::Logger, params: &UpdateParams) -> Result<UpdateResult, AnyError> {
let update_service = UpdateService::new(log.clone(), reqwest::Client::new());
async fn handle_update(
http: &FallbackSimpleHttp,
log: &log::Logger,
params: &UpdateParams,
) -> Result<UpdateResult, AnyError> {
let update_service = UpdateService::new(log.clone(), http.clone());
let updater = SelfUpdate::new(&update_service)?;
let latest_release = updater.get_current_release().await?;
let up_to_date = updater.is_up_to_date_with(&latest_release);