cli: fix compressor not draining and leading to truncated responses (#206464)

* cli: fix compressor not draining and leading to truncated responses

Fixes https://github.com/microsoft/vscode-remote-release/issues/9594

* fix lint
This commit is contained in:
Connor Peet
2024-02-28 11:47:31 -08:00
committed by GitHub
parent eb4e516a8a
commit f99981ed56

View File

@@ -94,41 +94,42 @@ impl ServerMessageSink {
async fn server_message_or_closed(
&mut self,
body: Option<&[u8]>,
body_or_end: Option<&[u8]>,
) -> Result<(), mpsc::error::SendError<SocketSignal>> {
let i = self.id;
let mut tx = self.tx.take().unwrap();
let msg = body
.map(|b| self.get_server_msg_content(b))
.map(|body| RefServerMessageParams { i, body });
let r = match &mut tx {
ServerMessageDestination::Channel(tx) => {
tx.send(SocketSignal::from_message(&ToClientRequest {
id: None,
params: match msg {
Some(msg) => ClientRequestMethod::servermsg(msg),
None => ClientRequestMethod::serverclose(ServerClosedParams { i }),
},
}))
.await
}
ServerMessageDestination::Rpc(caller) => {
match msg {
Some(msg) => caller.notify("servermsg", msg),
None => caller.notify("serverclose", ServerClosedParams { i }),
};
Ok(())
}
};
if let Some(b) = body_or_end {
let body = self.get_server_msg_content(b, false);
let r =
send_data_or_close_if_none(i, &mut tx, Some(RefServerMessageParams { i, body }))
.await;
self.tx = Some(tx);
return r;
}
let tail = self.get_server_msg_content(&[], true);
if !tail.is_empty() {
let _ = send_data_or_close_if_none(
i,
&mut tx,
Some(RefServerMessageParams { i, body: tail }),
)
.await;
}
let r = send_data_or_close_if_none(i, &mut tx, None).await;
self.tx = Some(tx);
r
}
pub(crate) fn get_server_msg_content<'a: 'b, 'b>(&'a mut self, body: &'b [u8]) -> &'b [u8] {
pub(crate) fn get_server_msg_content<'a: 'b, 'b>(
&'a mut self,
body: &'b [u8],
finish: bool,
) -> &'b [u8] {
if let Some(flate) = &mut self.flate {
if let Ok(compressed) = flate.process(body) {
if let Ok(compressed) = flate.process(body, finish) {
return compressed;
}
}
@@ -137,6 +138,32 @@ impl ServerMessageSink {
}
}
async fn send_data_or_close_if_none(
i: u16,
tx: &mut ServerMessageDestination,
msg: Option<RefServerMessageParams<'_>>,
) -> Result<(), mpsc::error::SendError<SocketSignal>> {
match tx {
ServerMessageDestination::Channel(tx) => {
tx.send(SocketSignal::from_message(&ToClientRequest {
id: None,
params: match msg {
Some(msg) => ClientRequestMethod::servermsg(msg),
None => ClientRequestMethod::serverclose(ServerClosedParams { i }),
},
}))
.await
}
ServerMessageDestination::Rpc(caller) => {
match msg {
Some(msg) => caller.notify("servermsg", msg),
None => caller.notify("serverclose", ServerClosedParams { i }),
};
Ok(())
}
}
}
impl Drop for ServerMessageSink {
fn drop(&mut self) {
self.multiplexer.remove(self.id);
@@ -162,7 +189,8 @@ impl ClientMessageDecoder {
pub fn decode<'a: 'b, 'b>(&'a mut self, message: &'b [u8]) -> std::io::Result<&'b [u8]> {
match &mut self.dec {
Some(d) => d.process(message),
// todo@connor4312 do we ever need to actually 'finish' the client message stream?
Some(d) => d.process(message, false),
None => Ok(message),
}
}
@@ -175,6 +203,7 @@ trait FlateAlgorithm {
&mut self,
contents: &[u8],
output: &mut [u8],
finish: bool,
) -> Result<flate2::Status, std::io::Error>;
}
@@ -193,9 +222,15 @@ impl FlateAlgorithm for DecompressFlateAlgorithm {
&mut self,
contents: &[u8],
output: &mut [u8],
finish: bool,
) -> Result<flate2::Status, std::io::Error> {
let mode = match finish {
true => flate2::FlushDecompress::Finish,
false => flate2::FlushDecompress::None,
};
self.0
.decompress(contents, output, flate2::FlushDecompress::None)
.decompress(contents, output, mode)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))
}
}
@@ -215,9 +250,15 @@ impl FlateAlgorithm for CompressFlateAlgorithm {
&mut self,
contents: &[u8],
output: &mut [u8],
finish: bool,
) -> Result<flate2::Status, std::io::Error> {
let mode = match finish {
true => flate2::FlushCompress::Finish,
false => flate2::FlushCompress::Sync,
};
self.0
.compress(contents, output, flate2::FlushCompress::Sync)
.compress(contents, output, mode)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))
}
}
@@ -241,23 +282,25 @@ where
}
}
pub fn process(&mut self, contents: &[u8]) -> std::io::Result<&[u8]> {
pub fn process(&mut self, contents: &[u8], finish: bool) -> std::io::Result<&[u8]> {
let mut out_offset = 0;
let mut in_offset = 0;
loop {
let in_before = self.flate.total_in();
let out_before = self.flate.total_out();
match self
.flate
.process(&contents[in_offset..], &mut self.output[out_offset..])
{
match self.flate.process(
&contents[in_offset..],
&mut self.output[out_offset..],
finish,
) {
Ok(flate2::Status::Ok | flate2::Status::BufError) => {
let processed_len = in_offset + (self.flate.total_in() - in_before) as usize;
let output_len = out_offset + (self.flate.total_out() - out_before) as usize;
if processed_len < contents.len() {
if processed_len < contents.len() || output_len == self.output.len() {
// If we filled the output buffer but there's more data to compress,
// extend the output buffer and keep compressing.
// or the output got filled after processing all input, extend
// the output buffer and keep compressing.
out_offset = output_len;
in_offset = processed_len;
if output_len == self.output.len() {
@@ -298,7 +341,7 @@ mod tests {
// 3000 and 30000 test resizing the buffer
for msg_len in [3, 30, 300, 3000, 30000] {
let vals = (0..msg_len).map(|v| v as u8).collect::<Vec<u8>>();
let compressed = sink.get_server_msg_content(&vals);
let compressed = sink.get_server_msg_content(&vals, false);
assert_ne!(compressed, vals);
let decompressed = decompress.decode(compressed).unwrap();
assert_eq!(decompressed.len(), vals.len());