This commit is contained in:
Connor Peet
2023-01-18 16:15:43 -08:00
parent d89738a0e2
commit 8965c48d30
11 changed files with 377 additions and 182 deletions

View File

@@ -48,20 +48,15 @@ use super::protocol::{
VersionParams,
};
use super::server_bridge::{get_socket_rw_stream, ServerBridge};
use super::server_multiplexer::ServerMultiplexer;
use super::shutdown_signal::ShutdownSignal;
use super::socket_signal::{ClientMessageDecoder, ServerMessageSink, SocketSignal};
use super::socket_signal::{
ClientMessageDecoder, ServerMessageDestination, ServerMessageSink, SocketSignal,
};
type ServerBridgeListLock = Arc<std::sync::Mutex<Option<Vec<ServerBridgeRec>>>>;
type HttpRequestsMap = Arc<std::sync::Mutex<HashMap<u32, DelegatedHttpRequest>>>;
type CodeServerCell = Arc<Mutex<Option<SocketCodeServer>>>;
struct ServerBridgeRec {
id: u16,
// bridge is removed when there's a write loop currently active
bridge: Option<ServerBridge>,
write_queue: Vec<Vec<u8>>,
}
struct HandlerContext {
/// Log handle for the server
log: log::Logger,
@@ -74,7 +69,7 @@ struct HandlerContext {
/// Connected VS Code Server
code_server: CodeServerCell,
/// Potentially many "websocket" connections to client
server_bridges: ServerBridgeListLock,
server_bridges: ServerMultiplexer,
// the cli arguments used to start the code server
code_server_args: CodeServerArgs,
/// port forwarding functionality
@@ -96,28 +91,7 @@ pub fn next_message_id() -> u32 {
impl HandlerContext {
async fn dispose(&self) {
let bridges = {
let mut lock = self.server_bridges.lock().unwrap();
lock.take()
};
let bridges = match bridges {
Some(b) => b,
None => return,
};
for rec in bridges {
if let Some(b) = rec.bridge {
if let Err(e) = b.close().await {
warning!(
self.log,
"Could not properly dispose of connection context: {}",
e
)
}
}
}
self.server_bridges.dispose().await;
info!(self.log, "Disposed of connection to running server.");
}
}
@@ -295,7 +269,7 @@ async fn process_socket(
let (socket_tx, mut socket_rx) = mpsc::channel(4);
let rx_counter = Arc::new(AtomicUsize::new(0));
let http_requests = Arc::new(std::sync::Mutex::new(HashMap::new()));
let server_bridges = Arc::new(std::sync::Mutex::new(Some(vec![])));
let server_bridges = ServerMultiplexer::new();
let (http_delegated, mut http_rx) = DelegatedSimpleHttp::new(log.clone());
let mut rpc = RpcBuilder::new(MsgPackSerializer {}).methods(HandlerContext {
did_update: Arc::new(AtomicBool::new(false)),
@@ -426,13 +400,6 @@ async fn process_socket(
debug!(log, "Closing connection: {}", reason.0);
break;
}
SocketSignal::CloseServerBridge(id) => {
let mut lock = server_bridges.lock().unwrap();
match &mut *lock {
Some(bridges) => bridges.retain(|sb| sb.id != id),
None => {}
}
}
}
}
}
@@ -618,37 +585,34 @@ async fn attach_server_bridge(
log: &log::Logger,
code_server: SocketCodeServer,
socket_tx: mpsc::Sender<SocketSignal>,
server_bridges: ServerBridgeListLock,
multiplexer: ServerMultiplexer,
socket_id: u16,
compress: bool,
) -> Result<u16, AnyError> {
let (server_messages, decoder) = if compress {
(
ServerMessageSink::new_compressed(socket_tx),
ServerMessageSink::new_compressed(
multiplexer.clone(),
socket_id,
ServerMessageDestination::Channel(socket_tx),
),
ClientMessageDecoder::new_compressed(),
)
} else {
(
ServerMessageSink::new_plain(socket_tx),
ServerMessageSink::new_plain(
multiplexer.clone(),
socket_id,
ServerMessageDestination::Channel(socket_tx),
),
ClientMessageDecoder::new_plain(),
)
};
let attached_fut =
ServerBridge::new(&code_server.socket, socket_id, server_messages, decoder).await;
let attached_fut = ServerBridge::new(&code_server.socket, server_messages, decoder).await;
match attached_fut {
Ok(a) => {
let mut lock = server_bridges.lock().unwrap();
let bridge_rec = ServerBridgeRec {
id: socket_id,
bridge: Some(a),
write_queue: vec![],
};
match &mut *lock {
Some(server_bridges) => (*server_bridges).push(bridge_rec),
None => *lock = Some(vec![bridge_rec]),
}
multiplexer.register(socket_id, a);
trace!(log, "Attached to server");
Ok(socket_id)
}
@@ -660,71 +624,14 @@ async fn attach_server_bridge(
/// to ensure message order is preserved exactly, which is necessary for compression.
fn handle_server_message(
log: &log::Logger,
bridges_lock: &ServerBridgeListLock,
multiplexer: &ServerMultiplexer,
params: ServerMessageParams,
) -> Result<EmptyObject, AnyError> {
let mut lock = bridges_lock.lock().unwrap();
match &mut *lock {
Some(server_bridges) => match server_bridges.iter_mut().find(|b| b.id == params.i) {
Some(sb) => {
sb.write_queue.push(params.body);
if let Some(bridge) = sb.bridge.take() {
let bridges_lock = bridges_lock.clone();
let log = log.clone();
tokio::spawn(start_bridge_write_loop(log, sb.id, bridge, bridges_lock));
}
}
None => return Err(AnyError::from(NoAttachedServerError())),
},
None => return Err(AnyError::from(NoAttachedServerError())),
if multiplexer.write_message(log, params.i, params.body) {
Ok(EmptyObject {})
} else {
Err(AnyError::from(NoAttachedServerError()))
}
Ok(EmptyObject {})
}
/// Write loop started by `handle_server_message`. It take sthe ServerBridge, and
/// runs until there's no more items in the 'write queue'. At that point, if the
/// record still exists in the bridges_lock (i.e. we haven't shut down), it'll
/// return the ServerBridge so that the next handle_server_message call starts
/// the loop again. Otherwise, it'll close the bridge.
async fn start_bridge_write_loop(
log: log::Logger,
id: u16,
mut bridge: ServerBridge,
bridges_lock: ServerBridgeListLock,
) {
let mut items_vec = vec![];
loop {
{
let mut lock = bridges_lock.lock().unwrap();
let server_bridges = match &mut *lock {
Some(sb) => sb,
None => break,
};
let bridge_rec = match server_bridges.iter_mut().find(|b| id == b.id) {
Some(b) => b,
None => break,
};
if bridge_rec.write_queue.is_empty() {
bridge_rec.bridge = Some(bridge);
return;
}
std::mem::swap(&mut bridge_rec.write_queue, &mut items_vec);
}
for item in items_vec.drain(..) {
if let Err(e) = bridge.write(item).await {
warning!(log, "Error writing to server: {:?}", e);
break;
}
}
}
bridge.close().await.ok(); // got here from `break` above, meaning our record got cleared. Close the bridge if so
}
fn handle_prune(paths: &LauncherPaths) -> Result<Vec<String>, AnyError> {