From f99981ed56f3adca0d446e9e2ec69fdff563d2c0 Mon Sep 17 00:00:00 2001 From: Connor Peet Date: Wed, 28 Feb 2024 11:47:31 -0800 Subject: [PATCH] cli: fix compressor not draining and leading to truncated responses (#206464) * cli: fix compressor not draining and leading to truncated responses Fixes https://github.com/microsoft/vscode-remote-release/issues/9594 * fix lint --- cli/src/tunnels/socket_signal.rs | 115 +++++++++++++++++++++---------- 1 file changed, 79 insertions(+), 36 deletions(-) diff --git a/cli/src/tunnels/socket_signal.rs b/cli/src/tunnels/socket_signal.rs index 9036c6ae3f9..2227f323852 100644 --- a/cli/src/tunnels/socket_signal.rs +++ b/cli/src/tunnels/socket_signal.rs @@ -94,41 +94,42 @@ impl ServerMessageSink { async fn server_message_or_closed( &mut self, - body: Option<&[u8]>, + body_or_end: Option<&[u8]>, ) -> Result<(), mpsc::error::SendError> { let i = self.id; let mut tx = self.tx.take().unwrap(); - let msg = body - .map(|b| self.get_server_msg_content(b)) - .map(|body| RefServerMessageParams { i, body }); - let r = match &mut tx { - ServerMessageDestination::Channel(tx) => { - tx.send(SocketSignal::from_message(&ToClientRequest { - id: None, - params: match msg { - Some(msg) => ClientRequestMethod::servermsg(msg), - None => ClientRequestMethod::serverclose(ServerClosedParams { i }), - }, - })) - .await - } - ServerMessageDestination::Rpc(caller) => { - match msg { - Some(msg) => caller.notify("servermsg", msg), - None => caller.notify("serverclose", ServerClosedParams { i }), - }; - Ok(()) - } - }; + if let Some(b) = body_or_end { + let body = self.get_server_msg_content(b, false); + let r = + send_data_or_close_if_none(i, &mut tx, Some(RefServerMessageParams { i, body })) + .await; + self.tx = Some(tx); + return r; + } + let tail = self.get_server_msg_content(&[], true); + if !tail.is_empty() { + let _ = send_data_or_close_if_none( + i, + &mut tx, + Some(RefServerMessageParams { i, body: tail }), + ) + .await; + } + + let r = send_data_or_close_if_none(i, &mut tx, None).await; self.tx = Some(tx); r } - pub(crate) fn get_server_msg_content<'a: 'b, 'b>(&'a mut self, body: &'b [u8]) -> &'b [u8] { + pub(crate) fn get_server_msg_content<'a: 'b, 'b>( + &'a mut self, + body: &'b [u8], + finish: bool, + ) -> &'b [u8] { if let Some(flate) = &mut self.flate { - if let Ok(compressed) = flate.process(body) { + if let Ok(compressed) = flate.process(body, finish) { return compressed; } } @@ -137,6 +138,32 @@ impl ServerMessageSink { } } +async fn send_data_or_close_if_none( + i: u16, + tx: &mut ServerMessageDestination, + msg: Option>, +) -> Result<(), mpsc::error::SendError> { + match tx { + ServerMessageDestination::Channel(tx) => { + tx.send(SocketSignal::from_message(&ToClientRequest { + id: None, + params: match msg { + Some(msg) => ClientRequestMethod::servermsg(msg), + None => ClientRequestMethod::serverclose(ServerClosedParams { i }), + }, + })) + .await + } + ServerMessageDestination::Rpc(caller) => { + match msg { + Some(msg) => caller.notify("servermsg", msg), + None => caller.notify("serverclose", ServerClosedParams { i }), + }; + Ok(()) + } + } +} + impl Drop for ServerMessageSink { fn drop(&mut self) { self.multiplexer.remove(self.id); @@ -162,7 +189,8 @@ impl ClientMessageDecoder { pub fn decode<'a: 'b, 'b>(&'a mut self, message: &'b [u8]) -> std::io::Result<&'b [u8]> { match &mut self.dec { - Some(d) => d.process(message), + // todo@connor4312 do we ever need to actually 'finish' the client message stream? + Some(d) => d.process(message, false), None => Ok(message), } } @@ -175,6 +203,7 @@ trait FlateAlgorithm { &mut self, contents: &[u8], output: &mut [u8], + finish: bool, ) -> Result; } @@ -193,9 +222,15 @@ impl FlateAlgorithm for DecompressFlateAlgorithm { &mut self, contents: &[u8], output: &mut [u8], + finish: bool, ) -> Result { + let mode = match finish { + true => flate2::FlushDecompress::Finish, + false => flate2::FlushDecompress::None, + }; + self.0 - .decompress(contents, output, flate2::FlushDecompress::None) + .decompress(contents, output, mode) .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e)) } } @@ -215,9 +250,15 @@ impl FlateAlgorithm for CompressFlateAlgorithm { &mut self, contents: &[u8], output: &mut [u8], + finish: bool, ) -> Result { + let mode = match finish { + true => flate2::FlushCompress::Finish, + false => flate2::FlushCompress::Sync, + }; + self.0 - .compress(contents, output, flate2::FlushCompress::Sync) + .compress(contents, output, mode) .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e)) } } @@ -241,23 +282,25 @@ where } } - pub fn process(&mut self, contents: &[u8]) -> std::io::Result<&[u8]> { + pub fn process(&mut self, contents: &[u8], finish: bool) -> std::io::Result<&[u8]> { let mut out_offset = 0; let mut in_offset = 0; loop { let in_before = self.flate.total_in(); let out_before = self.flate.total_out(); - match self - .flate - .process(&contents[in_offset..], &mut self.output[out_offset..]) - { + match self.flate.process( + &contents[in_offset..], + &mut self.output[out_offset..], + finish, + ) { Ok(flate2::Status::Ok | flate2::Status::BufError) => { let processed_len = in_offset + (self.flate.total_in() - in_before) as usize; let output_len = out_offset + (self.flate.total_out() - out_before) as usize; - if processed_len < contents.len() { + if processed_len < contents.len() || output_len == self.output.len() { // If we filled the output buffer but there's more data to compress, - // extend the output buffer and keep compressing. + // or the output got filled after processing all input, extend + // the output buffer and keep compressing. out_offset = output_len; in_offset = processed_len; if output_len == self.output.len() { @@ -298,7 +341,7 @@ mod tests { // 3000 and 30000 test resizing the buffer for msg_len in [3, 30, 300, 3000, 30000] { let vals = (0..msg_len).map(|v| v as u8).collect::>(); - let compressed = sink.get_server_msg_content(&vals); + let compressed = sink.get_server_msg_content(&vals, false); assert_ne!(compressed, vals); let decompressed = decompress.decode(compressed).unwrap(); assert_eq!(decompressed.len(), vals.len());