diff --git a/src/vs/base/common/jsonRpcProtocol.ts b/src/vs/base/common/jsonRpcProtocol.ts index 67c4ed4fc4d..35d7144ba82 100644 --- a/src/vs/base/common/jsonRpcProtocol.ts +++ b/src/vs/base/common/jsonRpcProtocol.ts @@ -43,6 +43,7 @@ export interface IJsonRpcErrorResponse { } export type JsonRpcMessage = IJsonRpcRequest | IJsonRpcNotification | IJsonRpcSuccessResponse | IJsonRpcErrorResponse; +export type JsonRpcResponse = IJsonRpcSuccessResponse | IJsonRpcErrorResponse; interface IPendingRequest { promise: DeferredPromise; @@ -122,15 +123,31 @@ export class JsonRpcProtocol extends Disposable { }) as Promise; } - public async handleMessage(message: JsonRpcMessage | JsonRpcMessage[]): Promise { + /** + * Handles one or more incoming JSON-RPC messages. + * + * Returns an array of JSON-RPC response objects generated for any incoming + * requests in the message(s). Notifications and responses to our own + * outgoing requests do not produce return values. For batch inputs, the + * returned responses are in the same order as the corresponding requests. + * + * Note: responses are also emitted via the `_send` callback, so callers + * that rely on the return value should not re-send them. + */ + public async handleMessage(message: JsonRpcMessage | JsonRpcMessage[]): Promise { if (Array.isArray(message)) { + const replies: JsonRpcResponse[] = []; for (const single of message) { - await this._handleMessage(single); + const reply = await this._handleMessage(single); + if (reply) { + replies.push(reply); + } } - return; + return replies; } - await this._handleMessage(message); + const reply = await this._handleMessage(message); + return reply ? [reply] : []; } public cancelPendingRequest(id: JsonRpcId): void { @@ -152,22 +169,25 @@ export class JsonRpcProtocol extends Disposable { } } - private async _handleMessage(message: JsonRpcMessage): Promise { + private async _handleMessage(message: JsonRpcMessage): Promise { if (isJsonRpcResponse(message)) { if (hasKey(message, { result: true })) { this._handleResult(message); } else { this._handleError(message); } + return undefined; } if (isJsonRpcRequest(message)) { - await this._handleRequest(message); + return this._handleRequest(message); } if (isJsonRpcNotification(message)) { this._handlers.handleNotification?.(message); } + + return undefined; } private _handleResult(response: IJsonRpcSuccessResponse): void { @@ -192,17 +212,18 @@ export class JsonRpcProtocol extends Disposable { } } - private async _handleRequest(request: IJsonRpcRequest): Promise { + private async _handleRequest(request: IJsonRpcRequest): Promise { if (!this._handlers.handleRequest) { - this._send({ + const response: IJsonRpcErrorResponse = { jsonrpc: '2.0', id: request.id, error: { code: JsonRpcProtocol.MethodNotFound, message: `Method not found: ${request.method}`, } - }); - return; + }; + this._send(response); + return response; } const cts = new CancellationTokenSource(); @@ -211,14 +232,17 @@ export class JsonRpcProtocol extends Disposable { try { const resultOrThenable = this._handlers.handleRequest(request, cts.token); const result = isThenable(resultOrThenable) ? await resultOrThenable : resultOrThenable; - this._send({ + const response: IJsonRpcSuccessResponse = { jsonrpc: '2.0', id: request.id, result, - }); + }; + this._send(response); + return response; } catch (error) { + let response: IJsonRpcErrorResponse; if (error instanceof JsonRpcError) { - this._send({ + response = { jsonrpc: '2.0', id: request.id, error: { @@ -226,17 +250,19 @@ export class JsonRpcProtocol extends Disposable { message: error.message, data: error.data, } - }); + }; } else { - this._send({ + response = { jsonrpc: '2.0', id: request.id, error: { code: JsonRpcProtocol.InternalError, message: error instanceof Error ? error.message : 'Internal error', } - }); + }; } + this._send(response); + return response; } finally { cts.dispose(true); } diff --git a/src/vs/base/test/common/jsonRpcProtocol.test.ts b/src/vs/base/test/common/jsonRpcProtocol.test.ts index 4a167d2cc8a..9a000e35f48 100644 --- a/src/vs/base/test/common/jsonRpcProtocol.test.ts +++ b/src/vs/base/test/common/jsonRpcProtocol.test.ts @@ -39,7 +39,7 @@ suite('JsonRpcProtocol', () => { const requestPromise = protocol.sendRequest({ method: 'echo', params: { value: 'ok' } }); const outgoingRequest = sentMessages[0] as IJsonRpcRequest; - await protocol.handleMessage({ + const replies = await protocol.handleMessage({ jsonrpc: '2.0', id: outgoingRequest.id, result: 'done' @@ -47,6 +47,7 @@ suite('JsonRpcProtocol', () => { const result = await requestPromise; assert.strictEqual(result, 'done'); + assert.deepStrictEqual(replies, []); }); test('sendRequest rejects on error response', async () => { @@ -107,20 +108,22 @@ suite('JsonRpcProtocol', () => { test('handleRequest responds with method not found without handler', async () => { const { protocol, sentMessages } = createProtocol(); - await protocol.handleMessage({ + const replies = await protocol.handleMessage({ jsonrpc: '2.0', id: 7, method: 'unknown' }); - assert.deepStrictEqual(sentMessages, [{ + const expected = [{ jsonrpc: '2.0', id: 7, error: { code: -32601, message: 'Method not found: unknown' } - }]); + }]; + assert.deepStrictEqual(sentMessages, expected); + assert.deepStrictEqual(replies, expected); }); test('handleRequest responds with result and passes cancellation token', async () => { @@ -134,7 +137,7 @@ suite('JsonRpcProtocol', () => { } }); - await protocol.handleMessage({ + const replies = await protocol.handleMessage({ jsonrpc: '2.0', id: 9, method: 'compute' @@ -142,27 +145,29 @@ suite('JsonRpcProtocol', () => { assert.ok(receivedToken); assert.strictEqual(wasCanceledDuringHandler, false); - assert.deepStrictEqual(sentMessages, [{ + const expected = [{ jsonrpc: '2.0', id: 9, result: 'compute:ok' - }]); + }]; + assert.deepStrictEqual(sentMessages, expected); + assert.deepStrictEqual(replies, expected); }); - test('handleRequest serializes JsonRpcError', async () => { + test('handleRequest serializes JsonRpcError and returns it', async () => { const { protocol, sentMessages } = createProtocol({ handleRequest: () => { throw new JsonRpcError(88, 'bad request', { detail: true }); } }); - await protocol.handleMessage({ + const replies = await protocol.handleMessage({ jsonrpc: '2.0', id: 'a', method: 'boom' }); - assert.deepStrictEqual(sentMessages, [{ + const expected = [{ jsonrpc: '2.0', id: 'a', error: { @@ -170,30 +175,34 @@ suite('JsonRpcProtocol', () => { message: 'bad request', data: { detail: true } } - }]); + }]; + assert.deepStrictEqual(sentMessages, expected); + assert.deepStrictEqual(replies, expected); }); - test('handleRequest maps unknown errors to internal error', async () => { + test('handleRequest maps unknown errors to internal error and returns it', async () => { const { protocol, sentMessages } = createProtocol({ handleRequest: () => { throw new Error('unexpected'); } }); - await protocol.handleMessage({ + const replies = await protocol.handleMessage({ jsonrpc: '2.0', id: 'b', method: 'explode' }); - assert.deepStrictEqual(sentMessages, [{ + const expected = [{ jsonrpc: '2.0', id: 'b', error: { code: -32603, message: 'unexpected' } - }]); + }]; + assert.deepStrictEqual(sentMessages, expected); + assert.deepStrictEqual(replies, expected); }); test('handleMessage processes batch sequentially', async () => { @@ -225,8 +234,9 @@ suite('JsonRpcProtocol', () => { assert.deepStrictEqual(sequence, ['request:start']); gate.complete(); - await handlingPromise; + const replies = await handlingPromise; assert.deepStrictEqual(sequence, ['request:start', 'request:end', 'notification']); + assert.deepStrictEqual(replies, [{ jsonrpc: '2.0', id: 1, result: true }]); }); }); diff --git a/src/vs/platform/mcp/node/mcpGatewaySession.ts b/src/vs/platform/mcp/node/mcpGatewaySession.ts index 579b0184495..20f6d23dc71 100644 --- a/src/vs/platform/mcp/node/mcpGatewaySession.ts +++ b/src/vs/platform/mcp/node/mcpGatewaySession.ts @@ -6,7 +6,7 @@ import type * as http from 'http'; import { IJsonRpcNotification, IJsonRpcRequest, - isJsonRpcNotification, isJsonRpcResponse, JsonRpcError, JsonRpcMessage, JsonRpcProtocol + isJsonRpcNotification, isJsonRpcResponse, JsonRpcError, JsonRpcMessage, JsonRpcProtocol, JsonRpcResponse } from '../../../base/common/jsonRpcProtocol.js'; import { Disposable } from '../../../base/common/lifecycle.js'; import { hasKey } from '../../../base/common/types.js'; @@ -79,8 +79,6 @@ function encodeResourceUrisInContent(content: MCP.ContentBlock[], serverIndex: n export class McpGatewaySession extends Disposable { private readonly _rpc: JsonRpcProtocol; private readonly _sseClients = new Set(); - private readonly _pendingResponses: JsonRpcMessage[] = []; - private _isCollectingPostResponses = false; private _lastEventId = 0; private _isInitialized = false; @@ -136,16 +134,8 @@ export class McpGatewaySession extends Disposable { }); } - public async handleIncoming(message: JsonRpcMessage | JsonRpcMessage[]): Promise { - this._pendingResponses.length = 0; - this._isCollectingPostResponses = true; - try { - await this._rpc.handleMessage(message); - return [...this._pendingResponses]; - } finally { - this._isCollectingPostResponses = false; - this._pendingResponses.length = 0; - } + public async handleIncoming(message: JsonRpcMessage | JsonRpcMessage[]): Promise { + return this._rpc.handleMessage(message); } public override dispose(): void { @@ -162,9 +152,6 @@ export class McpGatewaySession extends Disposable { private _handleOutgoingMessage(message: JsonRpcMessage): void { if (isJsonRpcResponse(message)) { - if (this._isCollectingPostResponses) { - this._pendingResponses.push(message); - } this._logService.debug(`[McpGateway][session ${this.id}] --> response: ${JSON.stringify(message)}`); return; }