mcp: fix concurrent request response collection race (#299628)

* mcp: fix concurrent request response collection race

- JsonRpcProtocol.handleMessage now returns JsonRpcMessage[] containing
  responses generated by incoming requests, rather than delegating
  response collection to callers via side-channel state
- McpGatewaySession simplified by removing _pendingResponses and
  _isCollectingPostResponses fields, which were susceptible to racing
  under concurrent HTTP POSTs. Now directly uses handleMessage's
  return value for the response body
- _send callback still invoked for all messages (backward compatible
  with McpServerRequestHandler and SSE notification broadcast)
- Updated tests to assert on handleMessage return values

Fixes #297780

(Commit message generated by Copilot)

* mcp: address review comments on jsonRpcProtocol changes

- Adds JSDoc to handleMessage clarifying what is returned (only responses
  for incoming requests), ordering guarantees for batch inputs, and that
  responses are still emitted via _send callback to avoid double-sending
- Tightens _handleRequest return type from Promise<JsonRpcMessage> to
  Promise<JsonRpcResponse>, enforcing that only valid responses are
  returned. Introduces JsonRpcResponse type alias for better type safety
- Expands error handling tests to assert that returned replies match
  what is emitted via _send for both JsonRpcError and generic error paths

Fixes #297780

(Commit message generated by Copilot)
This commit is contained in:
Connor Peet
2026-03-06 14:52:44 -08:00
committed by GitHub
parent 9fc3d42d22
commit 57479c0e8a
3 changed files with 71 additions and 48 deletions

View File

@@ -43,6 +43,7 @@ export interface IJsonRpcErrorResponse {
}
export type JsonRpcMessage = IJsonRpcRequest | IJsonRpcNotification | IJsonRpcSuccessResponse | IJsonRpcErrorResponse;
export type JsonRpcResponse = IJsonRpcSuccessResponse | IJsonRpcErrorResponse;
interface IPendingRequest {
promise: DeferredPromise<unknown>;
@@ -122,15 +123,31 @@ export class JsonRpcProtocol extends Disposable {
}) as Promise<T>;
}
public async handleMessage(message: JsonRpcMessage | JsonRpcMessage[]): Promise<void> {
/**
* Handles one or more incoming JSON-RPC messages.
*
* Returns an array of JSON-RPC response objects generated for any incoming
* requests in the message(s). Notifications and responses to our own
* outgoing requests do not produce return values. For batch inputs, the
* returned responses are in the same order as the corresponding requests.
*
* Note: responses are also emitted via the `_send` callback, so callers
* that rely on the return value should not re-send them.
*/
public async handleMessage(message: JsonRpcMessage | JsonRpcMessage[]): Promise<JsonRpcResponse[]> {
if (Array.isArray(message)) {
const replies: JsonRpcResponse[] = [];
for (const single of message) {
await this._handleMessage(single);
const reply = await this._handleMessage(single);
if (reply) {
replies.push(reply);
}
}
return;
return replies;
}
await this._handleMessage(message);
const reply = await this._handleMessage(message);
return reply ? [reply] : [];
}
public cancelPendingRequest(id: JsonRpcId): void {
@@ -152,22 +169,25 @@ export class JsonRpcProtocol extends Disposable {
}
}
private async _handleMessage(message: JsonRpcMessage): Promise<void> {
private async _handleMessage(message: JsonRpcMessage): Promise<JsonRpcResponse | undefined> {
if (isJsonRpcResponse(message)) {
if (hasKey(message, { result: true })) {
this._handleResult(message);
} else {
this._handleError(message);
}
return undefined;
}
if (isJsonRpcRequest(message)) {
await this._handleRequest(message);
return this._handleRequest(message);
}
if (isJsonRpcNotification(message)) {
this._handlers.handleNotification?.(message);
}
return undefined;
}
private _handleResult(response: IJsonRpcSuccessResponse): void {
@@ -192,17 +212,18 @@ export class JsonRpcProtocol extends Disposable {
}
}
private async _handleRequest(request: IJsonRpcRequest): Promise<void> {
private async _handleRequest(request: IJsonRpcRequest): Promise<JsonRpcResponse> {
if (!this._handlers.handleRequest) {
this._send({
const response: IJsonRpcErrorResponse = {
jsonrpc: '2.0',
id: request.id,
error: {
code: JsonRpcProtocol.MethodNotFound,
message: `Method not found: ${request.method}`,
}
});
return;
};
this._send(response);
return response;
}
const cts = new CancellationTokenSource();
@@ -211,14 +232,17 @@ export class JsonRpcProtocol extends Disposable {
try {
const resultOrThenable = this._handlers.handleRequest(request, cts.token);
const result = isThenable(resultOrThenable) ? await resultOrThenable : resultOrThenable;
this._send({
const response: IJsonRpcSuccessResponse = {
jsonrpc: '2.0',
id: request.id,
result,
});
};
this._send(response);
return response;
} catch (error) {
let response: IJsonRpcErrorResponse;
if (error instanceof JsonRpcError) {
this._send({
response = {
jsonrpc: '2.0',
id: request.id,
error: {
@@ -226,17 +250,19 @@ export class JsonRpcProtocol extends Disposable {
message: error.message,
data: error.data,
}
});
};
} else {
this._send({
response = {
jsonrpc: '2.0',
id: request.id,
error: {
code: JsonRpcProtocol.InternalError,
message: error instanceof Error ? error.message : 'Internal error',
}
});
};
}
this._send(response);
return response;
} finally {
cts.dispose(true);
}

View File

@@ -39,7 +39,7 @@ suite('JsonRpcProtocol', () => {
const requestPromise = protocol.sendRequest<string>({ method: 'echo', params: { value: 'ok' } });
const outgoingRequest = sentMessages[0] as IJsonRpcRequest;
await protocol.handleMessage({
const replies = await protocol.handleMessage({
jsonrpc: '2.0',
id: outgoingRequest.id,
result: 'done'
@@ -47,6 +47,7 @@ suite('JsonRpcProtocol', () => {
const result = await requestPromise;
assert.strictEqual(result, 'done');
assert.deepStrictEqual(replies, []);
});
test('sendRequest rejects on error response', async () => {
@@ -107,20 +108,22 @@ suite('JsonRpcProtocol', () => {
test('handleRequest responds with method not found without handler', async () => {
const { protocol, sentMessages } = createProtocol();
await protocol.handleMessage({
const replies = await protocol.handleMessage({
jsonrpc: '2.0',
id: 7,
method: 'unknown'
});
assert.deepStrictEqual(sentMessages, [{
const expected = [{
jsonrpc: '2.0',
id: 7,
error: {
code: -32601,
message: 'Method not found: unknown'
}
}]);
}];
assert.deepStrictEqual(sentMessages, expected);
assert.deepStrictEqual(replies, expected);
});
test('handleRequest responds with result and passes cancellation token', async () => {
@@ -134,7 +137,7 @@ suite('JsonRpcProtocol', () => {
}
});
await protocol.handleMessage({
const replies = await protocol.handleMessage({
jsonrpc: '2.0',
id: 9,
method: 'compute'
@@ -142,27 +145,29 @@ suite('JsonRpcProtocol', () => {
assert.ok(receivedToken);
assert.strictEqual(wasCanceledDuringHandler, false);
assert.deepStrictEqual(sentMessages, [{
const expected = [{
jsonrpc: '2.0',
id: 9,
result: 'compute:ok'
}]);
}];
assert.deepStrictEqual(sentMessages, expected);
assert.deepStrictEqual(replies, expected);
});
test('handleRequest serializes JsonRpcError', async () => {
test('handleRequest serializes JsonRpcError and returns it', async () => {
const { protocol, sentMessages } = createProtocol({
handleRequest: () => {
throw new JsonRpcError(88, 'bad request', { detail: true });
}
});
await protocol.handleMessage({
const replies = await protocol.handleMessage({
jsonrpc: '2.0',
id: 'a',
method: 'boom'
});
assert.deepStrictEqual(sentMessages, [{
const expected = [{
jsonrpc: '2.0',
id: 'a',
error: {
@@ -170,30 +175,34 @@ suite('JsonRpcProtocol', () => {
message: 'bad request',
data: { detail: true }
}
}]);
}];
assert.deepStrictEqual(sentMessages, expected);
assert.deepStrictEqual(replies, expected);
});
test('handleRequest maps unknown errors to internal error', async () => {
test('handleRequest maps unknown errors to internal error and returns it', async () => {
const { protocol, sentMessages } = createProtocol({
handleRequest: () => {
throw new Error('unexpected');
}
});
await protocol.handleMessage({
const replies = await protocol.handleMessage({
jsonrpc: '2.0',
id: 'b',
method: 'explode'
});
assert.deepStrictEqual(sentMessages, [{
const expected = [{
jsonrpc: '2.0',
id: 'b',
error: {
code: -32603,
message: 'unexpected'
}
}]);
}];
assert.deepStrictEqual(sentMessages, expected);
assert.deepStrictEqual(replies, expected);
});
test('handleMessage processes batch sequentially', async () => {
@@ -225,8 +234,9 @@ suite('JsonRpcProtocol', () => {
assert.deepStrictEqual(sequence, ['request:start']);
gate.complete();
await handlingPromise;
const replies = await handlingPromise;
assert.deepStrictEqual(sequence, ['request:start', 'request:end', 'notification']);
assert.deepStrictEqual(replies, [{ jsonrpc: '2.0', id: 1, result: true }]);
});
});

View File

@@ -6,7 +6,7 @@
import type * as http from 'http';
import {
IJsonRpcNotification, IJsonRpcRequest,
isJsonRpcNotification, isJsonRpcResponse, JsonRpcError, JsonRpcMessage, JsonRpcProtocol
isJsonRpcNotification, isJsonRpcResponse, JsonRpcError, JsonRpcMessage, JsonRpcProtocol, JsonRpcResponse
} from '../../../base/common/jsonRpcProtocol.js';
import { Disposable } from '../../../base/common/lifecycle.js';
import { hasKey } from '../../../base/common/types.js';
@@ -79,8 +79,6 @@ function encodeResourceUrisInContent(content: MCP.ContentBlock[], serverIndex: n
export class McpGatewaySession extends Disposable {
private readonly _rpc: JsonRpcProtocol;
private readonly _sseClients = new Set<http.ServerResponse>();
private readonly _pendingResponses: JsonRpcMessage[] = [];
private _isCollectingPostResponses = false;
private _lastEventId = 0;
private _isInitialized = false;
@@ -136,16 +134,8 @@ export class McpGatewaySession extends Disposable {
});
}
public async handleIncoming(message: JsonRpcMessage | JsonRpcMessage[]): Promise<JsonRpcMessage[]> {
this._pendingResponses.length = 0;
this._isCollectingPostResponses = true;
try {
await this._rpc.handleMessage(message);
return [...this._pendingResponses];
} finally {
this._isCollectingPostResponses = false;
this._pendingResponses.length = 0;
}
public async handleIncoming(message: JsonRpcMessage | JsonRpcMessage[]): Promise<JsonRpcResponse[]> {
return this._rpc.handleMessage(message);
}
public override dispose(): void {
@@ -162,9 +152,6 @@ export class McpGatewaySession extends Disposable {
private _handleOutgoingMessage(message: JsonRpcMessage): void {
if (isJsonRpcResponse(message)) {
if (this._isCollectingPostResponses) {
this._pendingResponses.push(message);
}
this._logService.debug(`[McpGateway][session ${this.id}] --> response: ${JSON.stringify(message)}`);
return;
}