From 2bf25ee2fdabcbf36caf71ffe1b4d0d360ea4505 Mon Sep 17 00:00:00 2001 From: Connor Peet Date: Thu, 22 Aug 2024 09:41:31 -0700 Subject: [PATCH] lm: a second rendition of returning data from LM tools (#225634) * lm: a second rendition of returning data from LM tools This is an alternative to #225454. It allows the tool caller to pass through token budget and counting information to the tool, and the tool can then 'do its thing.' Most of the actual implementation is in prompt-tsx with a new method to render elements into a JSON-serializable form, which the consumer then splices back into the tree. The implementation can be found here: https://github.com/microsoft/vscode-prompt-tsx/tree/connor4312/tools-api-v2 On the tool side, this looks like: ```ts vscode.lm.registerTool('myTestTool', { async invoke(context, token): Promise<vscode.LanguageModelToolResult> { return { // context includes the token info: 'mytype': await renderElementJSON(MyCustomPrompt, {}, context, token), toString() { return 'hello world!' } }; }, }); ``` I didn't make any nice wrappers yet, but the MVP consumer side looks like: ``` export class TestPrompt extends PromptElement { async render(_state: void, sizing: PromptSizing) { const result = await vscode.lm.invokeTool('myTestTool', { parameters: {}, tokenBudget: sizing.tokenBudget, countTokens: (v, token) => tokenizer.countTokens(v, token), }, new vscode.CancellationTokenSource().token); return ( <> <elementJSON data={result.mytype} /> </> ); } } ``` I like this approach better. It avoids bleeding knowledge of TSX into the extension host and is comparatively simple. 
* address comments * address comments --- .../common/extensionsApiProposals.ts | 2 +- .../browser/mainThreadLanguageModelTools.ts | 30 ++++++++-- .../workbench/api/common/extHost.api.impl.ts | 2 +- .../workbench/api/common/extHost.protocol.ts | 8 ++- .../api/common/extHostLanguageModelTools.ts | 58 ++++++++++++++++--- .../chat/common/languageModelToolsService.ts | 28 ++++++--- .../common/mockLanguageModelToolsService.ts | 4 +- src/vscode-dts/vscode.proposed.lmTools.d.ts | 31 +++++++++- 8 files changed, 130 insertions(+), 33 deletions(-) diff --git a/src/vs/platform/extensions/common/extensionsApiProposals.ts b/src/vs/platform/extensions/common/extensionsApiProposals.ts index 3fddef41a72..4eb66e7ea6e 100644 --- a/src/vs/platform/extensions/common/extensionsApiProposals.ts +++ b/src/vs/platform/extensions/common/extensionsApiProposals.ts @@ -246,7 +246,7 @@ const _allApiProposals = { }, lmTools: { proposal: 'https://raw.githubusercontent.com/microsoft/vscode/main/src/vscode-dts/vscode.proposed.lmTools.d.ts', - version: 4 + version: 5 }, mappedEditsProvider: { proposal: 'https://raw.githubusercontent.com/microsoft/vscode/main/src/vscode-dts/vscode.proposed.mappedEditsProvider.d.ts', diff --git a/src/vs/workbench/api/browser/mainThreadLanguageModelTools.ts b/src/vs/workbench/api/browser/mainThreadLanguageModelTools.ts index 0d72268dd51..4925ab39f7e 100644 --- a/src/vs/workbench/api/browser/mainThreadLanguageModelTools.ts +++ b/src/vs/workbench/api/browser/mainThreadLanguageModelTools.ts @@ -6,7 +6,8 @@ import { CancellationToken } from 'vs/base/common/cancellation'; import { Disposable, DisposableMap } from 'vs/base/common/lifecycle'; import { ExtHostLanguageModelToolsShape, ExtHostContext, MainContext, MainThreadLanguageModelToolsShape } from 'vs/workbench/api/common/extHost.protocol'; -import { IToolData, ILanguageModelToolsService, IToolResult } from 'vs/workbench/contrib/chat/common/languageModelToolsService'; +import { IChatMessage } from 
'vs/workbench/contrib/chat/common/languageModels'; +import { IToolData, ILanguageModelToolsService, IToolResult, IToolInvokation, CountTokensCallback } from 'vs/workbench/contrib/chat/common/languageModelToolsService'; import { IExtHostContext, extHostNamedCustomer } from 'vs/workbench/services/extensions/common/extHostCustomers'; @extHostNamedCustomer(MainContext.MainThreadLanguageModelTools) @@ -14,6 +15,7 @@ export class MainThreadLanguageModelTools extends Disposable implements MainThre private readonly _proxy: ExtHostLanguageModelToolsShape; private readonly _tools = this._register(new DisposableMap()); + private readonly _countTokenCallbacks = new Map(); constructor( extHostContext: IExtHostContext, @@ -29,16 +31,34 @@ export class MainThreadLanguageModelTools extends Disposable implements MainThre return Array.from(this._languageModelToolsService.getTools()); } - $invokeTool(id: string, parameters: any, token: CancellationToken): Promise { - return this._languageModelToolsService.invokeTool(id, parameters, token); + $invokeTool(dto: IToolInvokation, token: CancellationToken): Promise { + return this._languageModelToolsService.invokeTool( + dto, + (input, token) => this._proxy.$countTokensForInvokation(dto.callId, input, token), + token, + ); + } + + $countTokensForInvokation(callId: string, input: string | IChatMessage, token: CancellationToken): Promise { + const fn = this._countTokenCallbacks.get(callId); + if (!fn) { + throw new Error(`Tool invokation call ${callId} not found`); + } + + return fn(input, token); } $registerTool(name: string): void { const disposable = this._languageModelToolsService.registerToolImplementation( name, { - invoke: async (parameters, token) => { - return await this._proxy.$invokeTool(name, parameters, token); + invoke: async (dto, countTokens, token) => { + try { + this._countTokenCallbacks.set(dto.callId, countTokens); + return await this._proxy.$invokeTool(dto, token); + } finally { + 
this._countTokenCallbacks.delete(dto.callId); + } }, }); this._tools.set(name, disposable); diff --git a/src/vs/workbench/api/common/extHost.api.impl.ts b/src/vs/workbench/api/common/extHost.api.impl.ts index cfafa785cc5..f8a167ebdf0 100644 --- a/src/vs/workbench/api/common/extHost.api.impl.ts +++ b/src/vs/workbench/api/common/extHost.api.impl.ts @@ -1489,7 +1489,7 @@ export function createApiFactoryAndRegisterActors(accessor: ServicesAccessor): I checkProposedApiEnabled(extension, 'lmTools'); return extHostLanguageModelTools.registerTool(extension, toolId, tool); }, - invokeTool(toolId: string, parameters: Object, token: vscode.CancellationToken) { + invokeTool(toolId: string, parameters: vscode.LanguageModelToolInvokationOptions, token: vscode.CancellationToken) { checkProposedApiEnabled(extension, 'lmTools'); return extHostLanguageModelTools.invokeTool(toolId, parameters, token); }, diff --git a/src/vs/workbench/api/common/extHost.protocol.ts b/src/vs/workbench/api/common/extHost.protocol.ts index 9187129951c..d7e7e3c5930 100644 --- a/src/vs/workbench/api/common/extHost.protocol.ts +++ b/src/vs/workbench/api/common/extHost.protocol.ts @@ -56,7 +56,7 @@ import { IChatProgressResponseContent } from 'vs/workbench/contrib/chat/common/c import { ChatAgentVoteDirection, IChatFollowup, IChatProgress, IChatResponseErrorDetails, IChatTask, IChatTaskDto, IChatUserActionEvent } from 'vs/workbench/contrib/chat/common/chatService'; import { IChatRequestVariableValue, IChatVariableData, IChatVariableResolverProgress } from 'vs/workbench/contrib/chat/common/chatVariables'; import { IChatMessage, IChatResponseFragment, ILanguageModelChatMetadata, ILanguageModelChatSelector, ILanguageModelsChangeEvent } from 'vs/workbench/contrib/chat/common/languageModels'; -import { IToolData, IToolDelta, IToolResult } from 'vs/workbench/contrib/chat/common/languageModelToolsService'; +import { IToolData, IToolDelta, IToolInvokation, IToolResult } from 
'vs/workbench/contrib/chat/common/languageModelToolsService'; import { DebugConfigurationProviderTriggerKind, IAdapterDescriptor, IConfig, IDebugSessionReplMode, IDebugTestRunReference, IDebugVisualization, IDebugVisualizationContext, IDebugVisualizationTreeItem, MainThreadDebugVisualization } from 'vs/workbench/contrib/debug/common/debug'; import * as notebookCommon from 'vs/workbench/contrib/notebook/common/notebookCommon'; import { CellExecutionUpdateType } from 'vs/workbench/contrib/notebook/common/notebookExecutionService'; @@ -1313,7 +1313,8 @@ export interface MainThreadChatVariablesShape extends IDisposable { export interface MainThreadLanguageModelToolsShape extends IDisposable { $getTools(): Promise[]>; - $invokeTool(name: string, parameters: any, token: CancellationToken): Promise; + $invokeTool(dto: IToolInvokation, token: CancellationToken): Promise; + $countTokensForInvokation(callId: string, input: string | IChatMessage, token: CancellationToken): Promise; $registerTool(id: string): void; $unregisterTool(name: string): void; } @@ -1326,7 +1327,8 @@ export interface ExtHostChatVariablesShape { export interface ExtHostLanguageModelToolsShape { $acceptToolDelta(delta: IToolDelta): Promise; - $invokeTool(id: string, parameters: any, token: CancellationToken): Promise; + $invokeTool(dto: IToolInvokation, token: CancellationToken): Promise; + $countTokensForInvokation(callId: string, input: string | IChatMessage, token: CancellationToken): Promise; } export interface MainThreadUrlsShape extends IDisposable { diff --git a/src/vs/workbench/api/common/extHostLanguageModelTools.ts b/src/vs/workbench/api/common/extHostLanguageModelTools.ts index d65eefe4ec9..f46b950a74a 100644 --- a/src/vs/workbench/api/common/extHostLanguageModelTools.ts +++ b/src/vs/workbench/api/common/extHostLanguageModelTools.ts @@ -3,19 +3,24 @@ * Licensed under the MIT License. See License.txt in the project root for license information. 
*--------------------------------------------------------------------------------------------*/ +import { raceCancellation } from 'vs/base/common/async'; import { CancellationToken } from 'vs/base/common/cancellation'; +import { CancellationError } from 'vs/base/common/errors'; import { IDisposable, toDisposable } from 'vs/base/common/lifecycle'; import { revive } from 'vs/base/common/marshalling'; +import { generateUuid } from 'vs/base/common/uuid'; import { IExtensionDescription } from 'vs/platform/extensions/common/extensions'; import { ExtHostLanguageModelToolsShape, IMainContext, MainContext, MainThreadLanguageModelToolsShape } from 'vs/workbench/api/common/extHost.protocol'; import * as typeConvert from 'vs/workbench/api/common/extHostTypeConverters'; -import { IToolData, IToolDelta, IToolResult } from 'vs/workbench/contrib/chat/common/languageModelToolsService'; +import { IChatMessage } from 'vs/workbench/contrib/chat/common/languageModels'; +import { IToolData, IToolDelta, IToolInvokation, IToolResult } from 'vs/workbench/contrib/chat/common/languageModelToolsService'; import type * as vscode from 'vscode'; export class ExtHostLanguageModelTools implements ExtHostLanguageModelToolsShape { /** A map of tools that were registered in this EH */ private readonly _registeredTools = new Map(); private readonly _proxy: MainThreadLanguageModelToolsShape; + private readonly _tokenCountFuncs = new Map Thenable>(); /** A map of all known tools, from other EHs or registered in vscode core */ private readonly _allTools = new Map(); @@ -30,10 +35,32 @@ export class ExtHostLanguageModelTools implements ExtHostLanguageModelToolsShape }); } - async invokeTool(id: string, parameters: any, token: CancellationToken): Promise { - // Making the round trip here because not all tools were necessarily registered in this EH - const result = await this._proxy.$invokeTool(id, parameters, token); - return typeConvert.LanguageModelToolResult.to(result); + async 
$countTokensForInvokation(callId: string, input: string | IChatMessage, token: CancellationToken): Promise { + const fn = this._tokenCountFuncs.get(callId); + if (!fn) { + throw new Error(`Tool invokation call ${callId} not found`); + } + + return await fn(typeof input === 'string' ? input : typeConvert.LanguageModelChatMessage.to(input), token); + } + + async invokeTool(toolId: string, options: vscode.LanguageModelToolInvokationOptions, token: CancellationToken): Promise { + const callId = generateUuid(); + if (options.tokenOptions) { + this._tokenCountFuncs.set(callId, options.tokenOptions.countTokens); + } + try { + // Making the round trip here because not all tools were necessarily registered in this EH + const result = await this._proxy.$invokeTool({ + toolId, + callId, + parameters: options.parameters, + tokenBudget: options.tokenOptions?.tokenBudget, + }, token); + return typeConvert.LanguageModelToolResult.to(result); + } finally { + this._tokenCountFuncs.delete(callId); + } } async $acceptToolDelta(delta: IToolDelta): Promise { @@ -51,13 +78,26 @@ export class ExtHostLanguageModelTools implements ExtHostLanguageModelToolsShape .map(tool => typeConvert.LanguageModelToolDescription.to(tool)); } - async $invokeTool(name: string, parameters: any, token: CancellationToken): Promise { - const item = this._registeredTools.get(name); + async $invokeTool(dto: IToolInvokation, token: CancellationToken): Promise { + const item = this._registeredTools.get(dto.toolId); if (!item) { - throw new Error(`Unknown tool ${name}`); + throw new Error(`Unknown tool ${dto.toolId}`); + } + + const options: vscode.LanguageModelToolInvokationOptions = { parameters: dto.parameters }; + if (dto.tokenBudget !== undefined) { + options.tokenOptions = { + tokenBudget: dto.tokenBudget, + countTokens: this._tokenCountFuncs.get(dto.callId) || ((value, token = CancellationToken.None) => + this._proxy.$countTokensForInvokation(dto.callId, typeof value === 'string' ? 
value : typeConvert.LanguageModelChatMessage.from(value), token)) + }; + } + + const extensionResult = await raceCancellation(Promise.resolve(item.tool.invoke(options, token)), token); + if (!extensionResult) { + throw new CancellationError(); } - const extensionResult = await item.tool.invoke(parameters, token); return typeConvert.LanguageModelToolResult.from(extensionResult); } diff --git a/src/vs/workbench/contrib/chat/common/languageModelToolsService.ts b/src/vs/workbench/contrib/chat/common/languageModelToolsService.ts index 70a0d44d930..9e97ebb5eb6 100644 --- a/src/vs/workbench/contrib/chat/common/languageModelToolsService.ts +++ b/src/vs/workbench/contrib/chat/common/languageModelToolsService.ts @@ -11,6 +11,7 @@ import { IDisposable, toDisposable } from 'vs/base/common/lifecycle'; import { ThemeIcon } from 'vs/base/common/themables'; import { URI } from 'vs/base/common/uri'; import { createDecorator } from 'vs/platform/instantiation/common/instantiation'; +import { IChatMessage } from 'vs/workbench/contrib/chat/common/languageModels'; import { IExtensionService } from 'vs/workbench/services/extensions/common/extensions'; export interface IToolData { @@ -29,13 +30,20 @@ interface IToolEntry { impl?: IToolImpl; } +export interface IToolInvokation { + callId: string; + toolId: string; + parameters: any; + tokenBudget?: number; +} + export interface IToolResult { [contentType: string]: any; string: string; } export interface IToolImpl { - invoke(parameters: any, token: CancellationToken): Promise; + invoke(dto: IToolInvokation, countTokens: CountTokensCallback, token: CancellationToken): Promise; } export const ILanguageModelToolsService = createDecorator('ILanguageModelToolsService'); @@ -45,6 +53,8 @@ export interface IToolDelta { removed?: string; } +export type CountTokensCallback = (input: string | IChatMessage, token: CancellationToken) => Promise; + export interface ILanguageModelToolsService { _serviceBrand: undefined; onDidChangeTools: Event; @@ -53,7 
+63,7 @@ export interface ILanguageModelToolsService { getTools(): Iterable>; getTool(id: string): IToolData | undefined; getToolByName(name: string): IToolData | undefined; - invokeTool(name: string, parameters: any, token: CancellationToken): Promise; + invokeTool(dto: IToolInvokation, countTokens: CountTokensCallback, token: CancellationToken): Promise; } export class LanguageModelToolsService implements ILanguageModelToolsService { @@ -117,22 +127,22 @@ export class LanguageModelToolsService implements ILanguageModelToolsService { return undefined; } - async invokeTool(id: string, parameters: any, token: CancellationToken): Promise { - let tool = this._tools.get(id); + async invokeTool(dto: IToolInvokation, countTokens: CountTokensCallback, token: CancellationToken): Promise { + let tool = this._tools.get(dto.toolId); if (!tool) { - throw new Error(`Tool ${id} was not contributed`); + throw new Error(`Tool ${dto.toolId} was not contributed`); } if (!tool.impl) { - await this._extensionService.activateByEvent(`onLanguageModelTool:${id}`); + await this._extensionService.activateByEvent(`onLanguageModelTool:${dto.toolId}`); // Extension should activate and register the tool implementation - tool = this._tools.get(id); + tool = this._tools.get(dto.toolId); if (!tool?.impl) { - throw new Error(`Tool ${id} does not have an implementation registered.`); + throw new Error(`Tool ${dto.toolId} does not have an implementation registered.`); } } - return tool.impl.invoke(parameters, token); + return tool.impl.invoke(dto, countTokens, token); } } diff --git a/src/vs/workbench/contrib/chat/test/common/mockLanguageModelToolsService.ts b/src/vs/workbench/contrib/chat/test/common/mockLanguageModelToolsService.ts index a764d8e88b9..5f2847ef579 100644 --- a/src/vs/workbench/contrib/chat/test/common/mockLanguageModelToolsService.ts +++ b/src/vs/workbench/contrib/chat/test/common/mockLanguageModelToolsService.ts @@ -6,7 +6,7 @@ import { CancellationToken } from 
'vs/base/common/cancellation'; import { Event } from 'vs/base/common/event'; import { Disposable, IDisposable } from 'vs/base/common/lifecycle'; -import { ILanguageModelToolsService, IToolData, IToolDelta, IToolImpl, IToolResult } from 'vs/workbench/contrib/chat/common/languageModelToolsService'; +import { CountTokensCallback, ILanguageModelToolsService, IToolData, IToolDelta, IToolImpl, IToolInvokation, IToolResult } from 'vs/workbench/contrib/chat/common/languageModelToolsService'; export class MockLanguageModelToolsService implements ILanguageModelToolsService { _serviceBrand: undefined; @@ -35,7 +35,7 @@ export class MockLanguageModelToolsService implements ILanguageModelToolsService return undefined; } - async invokeTool(name: string, parameters: any, token: CancellationToken): Promise { + async invokeTool(dto: IToolInvokation, countTokens: CountTokensCallback, token: CancellationToken): Promise { return { string: '' }; diff --git a/src/vscode-dts/vscode.proposed.lmTools.d.ts b/src/vscode-dts/vscode.proposed.lmTools.d.ts index 663aba13218..852e9328919 100644 --- a/src/vscode-dts/vscode.proposed.lmTools.d.ts +++ b/src/vscode-dts/vscode.proposed.lmTools.d.ts @@ -3,7 +3,7 @@ * Licensed under the MIT License. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ -// version: 4 +// version: 5 // https://github.com/microsoft/vscode/issues/213274 declare module 'vscode' { @@ -91,7 +91,32 @@ declare module 'vscode' { * Invoke a tool with the given parameters. * TODO@API Could request a set of contentTypes to be returned so they don't all need to be computed? 
*/ - export function invokeTool(id: string, parameters: Object, token: CancellationToken): Thenable; + export function invokeTool(id: string, options: LanguageModelToolInvokationOptions, token: CancellationToken): Thenable; + } + + export interface LanguageModelToolInvokationOptions { + /** + * Parameters with which to invoke the tool. + */ + parameters: Object; + + /** + * Options to hint at how many tokens the tool should return in its response. + */ + tokenOptions?: { + /** + * If known, the maximum number of tokens the tool should emit in its result. + */ + tokenBudget: number; + + /** + * Count the number of tokens in a message using the model specific tokenizer-logic. + * @param text A string or a message instance. + * @param token Optional cancellation token. See {@link CancellationTokenSource} for how to create one. + * @returns A thenable that resolves to the number of tokens. + */ + countTokens(text: string | LanguageModelChatMessage, token?: CancellationToken): Thenable; + }; } export type JSONSchema = object; @@ -120,7 +145,7 @@ declare module 'vscode' { export interface LanguageModelTool { // TODO@API should it be LanguageModelToolResult | string? - invoke(parameters: any, token: CancellationToken): Thenable; + invoke(options: LanguageModelToolInvokationOptions, token: CancellationToken): Thenable; } export interface ChatLanguageModelToolReference {