diff --git a/src/vs/base/parts/ipc/node/ipc.net.ts b/src/vs/base/parts/ipc/node/ipc.net.ts index 9a508286e23..5ada7242262 100644 --- a/src/vs/base/parts/ipc/node/ipc.net.ts +++ b/src/vs/base/parts/ipc/node/ipc.net.ts @@ -318,6 +318,10 @@ export class WebSocketNodeSocket extends Disposable implements ISocket, ISocketT return this._flowManager.recordedInflateBytes; } + public setRecordInflateBytes(record: boolean): void { + this._flowManager.setRecordInflateBytes(record); + } + public traceSocketEvent(type: SocketDiagnosticsEventType, data?: VSBuffer | Uint8Array | ArrayBuffer | ArrayBufferView | unknown): void { this.socket.traceSocketEvent(type, data); } @@ -598,6 +602,10 @@ class WebSocketFlowManager extends Disposable { return VSBuffer.alloc(0); } + public setRecordInflateBytes(record: boolean): void { + this._zlibInflateStream?.setRecordInflateBytes(record); + } + constructor( private readonly _tracer: ISocketTracer, permessageDeflate: boolean, @@ -714,6 +722,7 @@ class ZlibInflateStream extends Disposable { private readonly _zlibInflate: InflateRaw; private readonly _recordedInflateBytes: VSBuffer[] = []; private readonly _pendingInflateData: VSBuffer[] = []; + private _recordInflateBytes: boolean; public get recordedInflateBytes(): VSBuffer { if (this._recordInflateBytes) { @@ -724,11 +733,12 @@ class ZlibInflateStream extends Disposable { constructor( private readonly _tracer: ISocketTracer, - private readonly _recordInflateBytes: boolean, + recordInflateBytes: boolean, inflateBytes: VSBuffer | null, options: ZlibOptions ) { super(); + this._recordInflateBytes = recordInflateBytes; this._zlibInflate = createInflateRaw(options); this._zlibInflate.on('error', (err: Error) => { this._tracer.traceSocketEvent(SocketDiagnosticsEventType.zlibInflateError, { message: err?.message, code: (err as NodeJS.ErrnoException)?.code }); @@ -756,6 +766,13 @@ class ZlibInflateStream extends Disposable { this._zlibInflate.write(buffer.buffer); } + public setRecordInflateBytes(record: boolean): void { + this._recordInflateBytes = record; + if (!record) { + this._recordedInflateBytes.length = 0; + } + } + public flush(callback: (data: VSBuffer) => void): void { this._zlibInflate.flush(() => { this._tracer.traceSocketEvent(SocketDiagnosticsEventType.zlibInflateFlushFired); @@ -764,6 +781,17 @@ class ZlibInflateStream extends Disposable { callback(data); }); } + + public override dispose(): void { + this._recordedInflateBytes.length = 0; + this._pendingInflateData.length = 0; + try { + this._zlibInflate.close(); + } catch { + // ignore errors while disposing + } + super.dispose(); + } } class ZlibDeflateStream extends Disposable { @@ -812,6 +840,16 @@ class ZlibDeflateStream extends Disposable { callback(data); }); } + + public override dispose(): void { + this._pendingDeflateData.length = 0; + try { + this._zlibDeflate.close(); + } catch { + // ignore errors while disposing + } + super.dispose(); + } } function unmask(buffer: VSBuffer, mask: number): void { diff --git a/src/vs/base/parts/ipc/test/node/ipc.net.test.ts b/src/vs/base/parts/ipc/test/node/ipc.net.test.ts index a1bc9a5749a..6c96decef45 100644 --- a/src/vs/base/parts/ipc/test/node/ipc.net.test.ts +++ b/src/vs/base/parts/ipc/test/node/ipc.net.test.ts @@ -711,6 +711,47 @@ suite('WebSocketNodeSocket', () => { assert.deepStrictEqual(actual, 'Hello'); }); + test('setRecordInflateBytes(false) clears and stops recording', async () => { + const disposables = new DisposableStore(); + const socket = disposables.add(new FakeNodeSocket()); + // eslint-disable-next-line local/code-no-any-casts + const webSocket = disposables.add(new WebSocketNodeSocket(socket, true, null, true)); + + const compressedHelloFrame = [0xc1, 0x07, 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00]; + const waitForOneData = () => new Promise(resolve => { + const d = webSocket.onData(data => { + d.dispose(); + resolve(data); + }); + }); + + const firstPromise = waitForOneData(); + socket.fireData(compressedHelloFrame); + const first = await firstPromise; + assert.strictEqual(fromCharCodeArray(fromUint8Array(first.buffer)), 'Hello'); + assert.ok(webSocket.recordedInflateBytes.byteLength > 0); + + webSocket.setRecordInflateBytes(false); + assert.strictEqual(webSocket.recordedInflateBytes.byteLength, 0); + + const secondPromise = waitForOneData(); + socket.fireData(compressedHelloFrame); + const second = await secondPromise; + assert.strictEqual(fromCharCodeArray(fromUint8Array(second.buffer)), 'Hello'); + assert.strictEqual(webSocket.recordedInflateBytes.byteLength, 0); + + webSocket.setRecordInflateBytes(true); + assert.strictEqual(webSocket.recordedInflateBytes.byteLength, 0); + + const thirdPromise = waitForOneData(); + socket.fireData(compressedHelloFrame); + const third = await thirdPromise; + assert.strictEqual(fromCharCodeArray(fromUint8Array(third.buffer)), 'Hello'); + assert.ok(webSocket.recordedInflateBytes.byteLength > 0); + + disposables.dispose(); + }); + test('A fragmented compressed text message', async () => { // contains "Hello" const frames = [ // contains "Hello" diff --git a/src/vs/server/node/extensionHostConnection.ts b/src/vs/server/node/extensionHostConnection.ts index 6ae4edd84b9..0daf9ee7031 100644 --- a/src/vs/server/node/extensionHostConnection.ts +++ b/src/vs/server/node/extensionHostConnection.ts @@ -94,6 +94,7 @@ class ConnectionData { skipWebSocketFrames = false; permessageDeflate = this.socket.permessageDeflate; inflateBytes = this.socket.recordedInflateBytes; + this.socket.setRecordInflateBytes(false); } return { @@ -133,6 +134,9 @@ export class ExtensionHostConnection extends Disposable { this._remoteAddress = remoteAddress; this._extensionHostProcess = null; this._connectionData = new ConnectionData(socket, initialDataChunk); + if (!this._canSendSocket && socket instanceof WebSocketNodeSocket) { + socket.setRecordInflateBytes(false); + } this._log(`New connection established.`); } @@ -209,6 +213,9 @@ export class ExtensionHostConnection extends Disposable { public acceptReconnection(remoteAddress: string, _socket: NodeSocket | WebSocketNodeSocket, initialDataChunk: VSBuffer): void { this._remoteAddress = remoteAddress; this._log(`The client has reconnected.`); + if (!this._canSendSocket && _socket instanceof WebSocketNodeSocket) { + _socket.setRecordInflateBytes(false); + } const connectionData = new ConnectionData(_socket, initialDataChunk); if (!this._extensionHostProcess) { diff --git a/src/vs/server/node/remoteExtensionHostAgentServer.ts b/src/vs/server/node/remoteExtensionHostAgentServer.ts index 269cc3878eb..da7e417cd5c 100644 --- a/src/vs/server/node/remoteExtensionHostAgentServer.ts +++ b/src/vs/server/node/remoteExtensionHostAgentServer.ts @@ -395,6 +395,9 @@ class RemoteExtensionHostAgentServer extends Disposable implements IServerAPI { if (msg.desiredConnectionType === ConnectionType.Management) { // This should become a management connection + if (socket instanceof WebSocketNodeSocket) { + socket.setRecordInflateBytes(false); + } if (isReconnection) { // This is a reconnection @@ -484,6 +487,9 @@ class RemoteExtensionHostAgentServer extends Disposable implements IServerAPI { } } else if (msg.desiredConnectionType === ConnectionType.Tunnel) { + if (socket instanceof WebSocketNodeSocket) { + socket.setRecordInflateBytes(false); + } const tunnelStartParams = msg.args; this._createTunnel(protocol, tunnelStartParams);