From a5ff3ce6c405e644207ef16fc8e7d3fccbcd95fb Mon Sep 17 00:00:00 2001 From: Alex Dima Date: Tue, 4 Sep 2018 14:54:10 +0200 Subject: [PATCH] Add support for CancellationToken as last argument in RPC communication --- src/vs/base/common/async.ts | 7 ++ .../mainThreadLanguageFeatures.ts | 2 +- src/vs/workbench/api/node/extHost.protocol.ts | 3 +- .../api/node/extHostLanguageFeatures.ts | 19 +++- .../services/extensions/node/rpcProtocol.ts | 96 +++++++++++++------ .../extensions/test/node/rpcProtocol.test.ts | 41 ++++++++ 6 files changed, 131 insertions(+), 37 deletions(-) diff --git a/src/vs/base/common/async.ts b/src/vs/base/common/async.ts index b2cdcca32e8..51ef142cc04 100644 --- a/src/vs/base/common/async.ts +++ b/src/vs/base/common/async.ts @@ -87,6 +87,13 @@ export function asWinJsPromise(callback: (token: CancellationToken) => T | TP }); } +export function asThenable(item: T | TPromise | Thenable): Thenable { + if (item instanceof TPromise || isThenable(item)) { + return item; + } + return TPromise.wrap(item); +} + /** * Hook a cancellation token to a WinJS Promise */ diff --git a/src/vs/workbench/api/electron-browser/mainThreadLanguageFeatures.ts b/src/vs/workbench/api/electron-browser/mainThreadLanguageFeatures.ts index fe6a497da5c..b0c1fcaa4ea 100644 --- a/src/vs/workbench/api/electron-browser/mainThreadLanguageFeatures.ts +++ b/src/vs/workbench/api/electron-browser/mainThreadLanguageFeatures.ts @@ -179,7 +179,7 @@ export class MainThreadLanguageFeatures implements MainThreadLanguageFeaturesSha $registerHoverProvider(handle: number, selector: ISerializedDocumentFilter[]): void { this._registrations[handle] = modes.HoverProviderRegistry.register(typeConverters.LanguageSelector.from(selector), { provideHover: (model: ITextModel, position: EditorPosition, token: CancellationToken): Thenable => { - return wireCancellationToken(token, this._proxy.$provideHover(handle, model.uri, position)); + return this._proxy.$provideHover(handle, model.uri, position, token); } }); } diff --git a/src/vs/workbench/api/node/extHost.protocol.ts b/src/vs/workbench/api/node/extHost.protocol.ts index 1e847f934e7..84870e5e93a 100644 --- a/src/vs/workbench/api/node/extHost.protocol.ts +++ b/src/vs/workbench/api/node/extHost.protocol.ts @@ -42,6 +42,7 @@ import { createExtHostContextProxyIdentifier as createExtId, createMainContextPr import { IProgressOptions, IProgressStep } from 'vs/workbench/services/progress/common/progress'; import { SaveReason } from 'vs/workbench/services/textfile/common/textfiles'; import * as vscode from 'vscode'; +import { CancellationToken } from 'vs/base/common/cancellation'; export interface IEnvironment { isExtensionDevelopmentDebug: boolean; @@ -825,7 +826,7 @@ export interface ExtHostLanguageFeaturesShape { $provideDefinition(handle: number, resource: UriComponents, position: IPosition): TPromise; $provideImplementation(handle: number, resource: UriComponents, position: IPosition): TPromise; $provideTypeDefinition(handle: number, resource: UriComponents, position: IPosition): TPromise; - $provideHover(handle: number, resource: UriComponents, position: IPosition): TPromise; + $provideHover(handle: number, resource: UriComponents, position: IPosition, token: CancellationToken): Thenable; $provideDocumentHighlights(handle: number, resource: UriComponents, position: IPosition): TPromise; $provideReferences(handle: number, resource: UriComponents, position: IPosition, context: modes.ReferenceContext): TPromise; $provideCodeActions(handle: number, resource: UriComponents, rangeOrSelection: IRange | ISelection, context: modes.CodeActionContext): TPromise; diff --git a/src/vs/workbench/api/node/extHostLanguageFeatures.ts b/src/vs/workbench/api/node/extHostLanguageFeatures.ts index 52f19e427a9..ba5acc57ca2 100644 --- a/src/vs/workbench/api/node/extHostLanguageFeatures.ts +++ b/src/vs/workbench/api/node/extHostLanguageFeatures.ts @@ -16,7 +16,7 @@ import { ExtHostHeapService } from 'vs/workbench/api/node/extHostHeapService'; import { ExtHostDocuments } from 'vs/workbench/api/node/extHostDocuments'; import { ExtHostCommands, CommandsConverter } from 'vs/workbench/api/node/extHostCommands'; import { ExtHostDiagnostics } from 'vs/workbench/api/node/extHostDiagnostics'; -import { asWinJsPromise } from 'vs/base/common/async'; +import { asWinJsPromise, asThenable } from 'vs/base/common/async'; import { MainContext, MainThreadLanguageFeaturesShape, ExtHostLanguageFeaturesShape, ObjectIdentifier, IRawColorInfo, IMainContext, IdObject, ISerializedRegExp, ISerializedIndentationRule, ISerializedOnEnterRule, ISerializedLanguageConfiguration, WorkspaceSymbolDto, SuggestResultDto, WorkspaceSymbolsDto, SuggestionDto, CodeActionDto, ISerializedDocumentFilter, WorkspaceEditDto } from './extHost.protocol'; import { regExpLeadsToEndlessLoop } from 'vs/base/common/strings'; import { IPosition } from 'vs/editor/common/core/position'; @@ -26,6 +26,7 @@ import { isObject } from 'vs/base/common/types'; import { ISelection, Selection } from 'vs/editor/common/core/selection'; import { IExtensionDescription } from 'vs/workbench/services/extensions/common/extensions'; import { ILogService } from 'vs/platform/log/common/log'; +import { CancellationToken } from 'vs/base/common/cancellation'; // --- adapter @@ -202,12 +203,12 @@ class HoverAdapter { private readonly _provider: vscode.HoverProvider, ) { } - public provideHover(resource: URI, position: IPosition): TPromise { + public provideHover(resource: URI, position: IPosition, token: CancellationToken): Thenable { let doc = this._documents.getDocumentData(resource).document; let pos = typeConvert.Position.to(position); - return asWinJsPromise(token => this._provider.provideHover(doc, pos, token)).then(value => { + return asThenable(this._provider.provideHover(doc, pos, token)).then(value => { if (!value || isFalsyOrEmpty(value.contents)) { return undefined; } @@ -924,6 +925,14 @@ export class ExtHostLanguageFeatures implements ExtHostLanguageFeaturesShape { return callback(adapter); } + private _withAdapter2(handle: number, ctor: { new(...args: any[]): A }, callback: (adapter: A) => Thenable): Thenable { + let adapter = this._adapter.get(handle); + if (!(adapter instanceof ctor)) { + return TPromise.wrapError(new Error('no adapter found')); + } + return callback(adapter); + } + private _addNewAdapter(adapter: Adapter): number { const handle = this._nextHandle(); this._adapter.set(handle, adapter); @@ -1008,8 +1017,8 @@ export class ExtHostLanguageFeatures implements ExtHostLanguageFeaturesShape { return this._createDisposable(handle); } - $provideHover(handle: number, resource: UriComponents, position: IPosition): TPromise { - return this._withAdapter(handle, HoverAdapter, adpater => adpater.provideHover(URI.revive(resource), position)); + $provideHover(handle: number, resource: UriComponents, position: IPosition, token: CancellationToken): Thenable { + return this._withAdapter2(handle, HoverAdapter, adapter => adapter.provideHover(URI.revive(resource), position, token)); } // --- occurrences diff --git a/src/vs/workbench/services/extensions/node/rpcProtocol.ts b/src/vs/workbench/services/extensions/node/rpcProtocol.ts index 9a2ea7fe9a3..7327026dd0c 100644 --- a/src/vs/workbench/services/extensions/node/rpcProtocol.ts +++ b/src/vs/workbench/services/extensions/node/rpcProtocol.ts @@ -13,6 +13,7 @@ import { CharCode } from 'vs/base/common/charCode'; import { URI } from 'vs/base/common/uri'; import { MarshalledObject } from 'vs/base/common/marshalling'; import { IURITransformer } from 'vs/base/common/uriIpc'; +import { CancellationToken, CancellationTokenSource } from 'vs/base/common/cancellation'; declare var Proxy: any; // TODO@TypeScript @@ -104,7 +105,7 @@ export class RPCProtocol implements IRPCProtocol { private readonly _locals: any[]; private readonly _proxies: any[]; private _lastMessageId: number; - private readonly _invokedHandlers: { [req: string]: TPromise; }; + private readonly _cancelInvokedHandlers: { [req: string]: () => void; }; private readonly _pendingRPCReplies: { [msgId: string]: LazyPromise; }; constructor(protocol: IMessagePassingProtocol, logger: IRPCProtocolLogger = null, transformer: IURITransformer = null) { @@ -119,7 +120,7 @@ export class RPCProtocol implements IRPCProtocol { this._proxies[i] = null; } this._lastMessageId = 0; - this._invokedHandlers = Object.create(null); + this._cancelInvokedHandlers = Object.create(null); this._pendingRPCReplies = {}; this._protocol.onMessage((msg) => this._receiveOneMessage(msg)); } @@ -188,20 +189,22 @@ export class RPCProtocol implements IRPCProtocol { const req = buff.readUInt32(); switch (messageType) { - case MessageType.RequestJSONArgs: { + case MessageType.RequestJSONArgs: + case MessageType.RequestJSONArgsWithCancellation: { let { rpcId, method, args } = MessageIO.deserializeRequestJSONArgs(buff); if (this._uriTransformer) { args = transformIncomingURIs(args, this._uriTransformer); } - this._receiveRequest(msgLength, req, rpcId, method, args); + this._receiveRequest(msgLength, req, rpcId, method, args, (messageType === MessageType.RequestJSONArgsWithCancellation)); break; } - case MessageType.RequestMixedArgs: { + case MessageType.RequestMixedArgs: + case MessageType.RequestMixedArgsWithCancellation: { let { rpcId, method, args } = MessageIO.deserializeRequestMixedArgs(buff); if (this._uriTransformer) { args = transformIncomingURIs(args, this._uriTransformer); } - this._receiveRequest(msgLength, req, rpcId, method, args); + this._receiveRequest(msgLength, req, rpcId, method, args, (messageType === MessageType.RequestMixedArgsWithCancellation)); break; } case MessageType.Cancel: { @@ -240,16 +243,28 @@ export class RPCProtocol implements IRPCProtocol { } } - private _receiveRequest(msgLength: number, req: number, rpcId: number, method: string, args: any[]): void { + private _receiveRequest(msgLength: number, req: number, rpcId: number, method: string, args: any[], usesCancellationToken: boolean): void { if (this._logger) { this._logger.logIncoming(msgLength, req, RequestInitiator.OtherSide, `receiveRequest ${getStringIdentifierForProxy(rpcId)}.${method}(`, args); } const callId = String(req); - this._invokedHandlers[callId] = this._invokeHandler(rpcId, method, args); + let promise: TPromise; + let cancel: () => void; + if (usesCancellationToken) { + const cancellationTokenSource = new CancellationTokenSource(); + args.push(cancellationTokenSource.token); + promise = this._invokeHandler(rpcId, method, args); + cancel = () => cancellationTokenSource.cancel(); + } else { + promise = this._invokeHandler(rpcId, method, args); + cancel = () => promise.cancel(); + } - this._invokedHandlers[callId].then((r) => { - delete this._invokedHandlers[callId]; + this._cancelInvokedHandlers[callId] = cancel; + + promise.then((r) => { + delete this._cancelInvokedHandlers[callId]; if (this._uriTransformer) { r = transformOutgoingURIs(r, this._uriTransformer); } @@ -259,7 +274,7 @@ export class RPCProtocol implements IRPCProtocol { } this._protocol.send(msg); }, (err) => { - delete this._invokedHandlers[callId]; + delete this._cancelInvokedHandlers[callId]; const msg = MessageIO.serializeReplyErr(req, err); if (this._logger) { this._logger.logOutgoing(msg.byteLength, req, RequestInitiator.OtherSide, `replyErr:`, err); @@ -273,8 +288,8 @@ export class RPCProtocol implements IRPCProtocol { this._logger.logIncoming(msgLength, req, RequestInitiator.OtherSide, `receiveCancel`); } const callId = String(req); - if (this._invokedHandlers[callId]) { - this._invokedHandlers[callId].cancel(); + if (this._cancelInvokedHandlers[callId]) { + this._cancelInvokedHandlers[callId](); } } @@ -340,22 +355,41 @@ export class RPCProtocol implements IRPCProtocol { if (this._isDisposed) { return TPromise.wrapError(errors.canceled()); } + let cancellationToken: CancellationToken = null; + if (args.length > 0 && CancellationToken.isCancellationToken(args[args.length - 1])) { + cancellationToken = args.pop(); + } + + if (cancellationToken && cancellationToken.isCancellationRequested) { + // No need to do anything... + return TPromise.wrapError(errors.canceled()); + } + + if (cancellationToken && cancellationToken === CancellationToken.None) { + // This can never be canceled, so pretend we never even saw a cancelation token + cancellationToken = null; + } const req = ++this._lastMessageId; const callId = String(req); - const result = new LazyPromise(() => { + const sendCancel = () => { const msg = MessageIO.serializeCancel(req); if (this._logger) { this._logger.logOutgoing(msg.byteLength, req, RequestInitiator.LocalSide, `cancel`); } this._protocol.send(MessageIO.serializeCancel(req)); - }); + }; + const result = new LazyPromise(sendCancel); + + if (cancellationToken) { + cancellationToken.onCancellationRequested(sendCancel); + } this._pendingRPCReplies[callId] = result; if (this._uriTransformer) { args = transformOutgoingURIs(args, this._uriTransformer); } - const msg = MessageIO.serializeRequest(req, rpcId, methodName, args); + const msg = MessageIO.serializeRequest(req, rpcId, methodName, args, !!cancellationToken); if (this._logger) { this._logger.logOutgoing(msg.byteLength, req, RequestInitiator.LocalSide, `request: ${getStringIdentifierForProxy(rpcId)}.${methodName}(`, args); } @@ -513,7 +547,7 @@ class MessageIO { return false; } - public static serializeRequest(req: number, rpcId: number, method: string, args: any[]): Buffer { + public static serializeRequest(req: number, rpcId: number, method: string, args: any[], usesCancellationToken: boolean): Buffer { if (this._arrayContainsBuffer(args)) { let massagedArgs: (string | Buffer)[] = new Array(args.length); let argsLengths: number[] = new Array(args.length); @@ -527,12 +561,12 @@ class MessageIO { argsLengths[i] = Buffer.byteLength(massagedArgs[i], 'utf8'); } } - return this._requestMixedArgs(req, rpcId, method, massagedArgs, argsLengths); + return this._requestMixedArgs(req, rpcId, method, massagedArgs, argsLengths, usesCancellationToken); } - return this._requestJSONArgs(req, rpcId, method, JSON.stringify(args)); + return this._requestJSONArgs(req, rpcId, method, JSON.stringify(args), usesCancellationToken); } - private static _requestJSONArgs(req: number, rpcId: number, method: string, args: string): Buffer { + private static _requestJSONArgs(req: number, rpcId: number, method: string, args: string, usesCancellationToken: boolean): Buffer { const methodByteLength = Buffer.byteLength(method, 'utf8'); const argsByteLength = Buffer.byteLength(args, 'utf8'); @@ -541,7 +575,7 @@ class MessageIO { len += MessageBuffer.sizeShortString(method, methodByteLength); len += MessageBuffer.sizeLongString(args, argsByteLength); - let result = MessageBuffer.alloc(MessageType.RequestJSONArgs, req, len); + let result = MessageBuffer.alloc(usesCancellationToken ? MessageType.RequestJSONArgsWithCancellation : MessageType.RequestJSONArgs, req, len); result.writeUInt8(rpcId); result.writeShortString(method, methodByteLength); result.writeLongString(args, argsByteLength); @@ -559,7 +593,7 @@ class MessageIO { }; } - private static _requestMixedArgs(req: number, rpcId: number, method: string, args: (string | Buffer)[], argsLengths: number[]): Buffer { + private static _requestMixedArgs(req: number, rpcId: number, method: string, args: (string | Buffer)[], argsLengths: number[], usesCancellationToken: boolean): Buffer { const methodByteLength = Buffer.byteLength(method, 'utf8'); let len = 0; @@ -567,7 +601,7 @@ class MessageIO { len += MessageBuffer.sizeShortString(method, methodByteLength); len += MessageBuffer.sizeMixedArray(args, argsLengths); - let result = MessageBuffer.alloc(MessageType.RequestMixedArgs, req, len); + let result = MessageBuffer.alloc(usesCancellationToken ? MessageType.RequestMixedArgsWithCancellation : MessageType.RequestMixedArgs, req, len); result.writeUInt8(rpcId); result.writeShortString(method, methodByteLength); result.writeMixedArray(args, argsLengths); @@ -674,13 +708,15 @@ class MessageIO { const enum MessageType { RequestJSONArgs = 1, - RequestMixedArgs = 2, - Cancel = 3, - ReplyOKEmpty = 4, - ReplyOKBuffer = 5, - ReplyOKJSON = 6, - ReplyErrError = 7, - ReplyErrEmpty = 8, + RequestJSONArgsWithCancellation = 2, + RequestMixedArgs = 3, + RequestMixedArgsWithCancellation = 4, + Cancel = 5, + ReplyOKEmpty = 6, + ReplyOKBuffer = 7, + ReplyOKJSON = 8, + ReplyErrError = 9, + ReplyErrEmpty = 10, } const enum ArgType { diff --git a/src/vs/workbench/services/extensions/test/node/rpcProtocol.test.ts b/src/vs/workbench/services/extensions/test/node/rpcProtocol.test.ts index 602a7d37d62..7e5e49d8dc7 100644 --- a/src/vs/workbench/services/extensions/test/node/rpcProtocol.test.ts +++ b/src/vs/workbench/services/extensions/test/node/rpcProtocol.test.ts @@ -11,6 +11,7 @@ import { IMessagePassingProtocol } from 'vs/base/parts/ipc/node/ipc'; import { Event, Emitter } from 'vs/base/common/event'; import { ProxyIdentifier } from 'vs/workbench/services/extensions/node/proxyIdentifier'; import { TPromise } from 'vs/base/common/winjs.base'; +import { CancellationToken, CancellationTokenSource } from 'vs/base/common/cancellation'; suite('RPCProtocol', () => { @@ -116,6 +117,46 @@ suite('RPCProtocol', () => { p.cancel(); }); + test('cancelling a call via CancellationToken before', function (done) { + delegate = (a1: number, a2: number) => a1 + a2; + let p = bProxy.$m(4, CancellationToken.Cancelled); + p.then((res: number) => { + assert.fail('should not receive result'); + }, (err) => { + assert.ok(true); + done(null); + }); + }); + + test('passing CancellationToken.None', function (done) { + delegate = (a1: number, a2: number) => a1 + 1; + bProxy.$m(4, CancellationToken.None).then((res: number) => { + assert.equal(res, 5); + done(null); + }, done); + }); + + test('cancelling a call via CancellationToken quickly', function (done) { + // this is an implementation which, when cancellation is triggered, will return 7 + delegate = (a1: number, token: CancellationToken) => { + return new TPromise((resolve, reject) => { + token.onCancellationRequested((e) => { + resolve(7); + }); + }); + }; + let tokenSource = new CancellationTokenSource(); + let p = bProxy.$m(4, tokenSource.token); + p.then((res: number) => { + assert.equal(res, 7); + done(null); + }, (err) => { + assert.fail('should not receive error'); + done(); + }); + tokenSource.cancel(); + }); + test('throwing an error', function (done) { delegate = (a1: number, a2: number) => { throw new Error(`nope`);