diff --git a/src/vs/platform/extensions/common/extensionsApiProposals.ts b/src/vs/platform/extensions/common/extensionsApiProposals.ts index 3fddef41a72..4eb66e7ea6e 100644 --- a/src/vs/platform/extensions/common/extensionsApiProposals.ts +++ b/src/vs/platform/extensions/common/extensionsApiProposals.ts @@ -246,7 +246,7 @@ const _allApiProposals = { }, lmTools: { proposal: 'https://raw.githubusercontent.com/microsoft/vscode/main/src/vscode-dts/vscode.proposed.lmTools.d.ts', - version: 4 + version: 5 }, mappedEditsProvider: { proposal: 'https://raw.githubusercontent.com/microsoft/vscode/main/src/vscode-dts/vscode.proposed.mappedEditsProvider.d.ts', diff --git a/src/vs/workbench/api/browser/mainThreadLanguageModelTools.ts b/src/vs/workbench/api/browser/mainThreadLanguageModelTools.ts index 0d72268dd51..4925ab39f7e 100644 --- a/src/vs/workbench/api/browser/mainThreadLanguageModelTools.ts +++ b/src/vs/workbench/api/browser/mainThreadLanguageModelTools.ts @@ -6,7 +6,8 @@ import { CancellationToken } from 'vs/base/common/cancellation'; import { Disposable, DisposableMap } from 'vs/base/common/lifecycle'; import { ExtHostLanguageModelToolsShape, ExtHostContext, MainContext, MainThreadLanguageModelToolsShape } from 'vs/workbench/api/common/extHost.protocol'; -import { IToolData, ILanguageModelToolsService, IToolResult } from 'vs/workbench/contrib/chat/common/languageModelToolsService'; +import { IChatMessage } from 'vs/workbench/contrib/chat/common/languageModels'; +import { IToolData, ILanguageModelToolsService, IToolResult, IToolInvokation, CountTokensCallback } from 'vs/workbench/contrib/chat/common/languageModelToolsService'; import { IExtHostContext, extHostNamedCustomer } from 'vs/workbench/services/extensions/common/extHostCustomers'; @extHostNamedCustomer(MainContext.MainThreadLanguageModelTools) @@ -14,6 +15,7 @@ export class MainThreadLanguageModelTools extends Disposable implements MainThre private readonly _proxy: ExtHostLanguageModelToolsShape; private readonly _tools = this._register(new DisposableMap()); + private readonly _countTokenCallbacks = new Map(); constructor( extHostContext: IExtHostContext, @@ -29,16 +31,34 @@ export class MainThreadLanguageModelTools extends Disposable implements MainThre return Array.from(this._languageModelToolsService.getTools()); } - $invokeTool(id: string, parameters: any, token: CancellationToken): Promise { - return this._languageModelToolsService.invokeTool(id, parameters, token); + $invokeTool(dto: IToolInvokation, token: CancellationToken): Promise { + return this._languageModelToolsService.invokeTool( + dto, + (input, token) => this._proxy.$countTokensForInvokation(dto.callId, input, token), + token, + ); + } + + $countTokensForInvokation(callId: string, input: string | IChatMessage, token: CancellationToken): Promise { + const fn = this._countTokenCallbacks.get(callId); + if (!fn) { + throw new Error(`Tool invokation call ${callId} not found`); + } + + return fn(input, token); } $registerTool(name: string): void { const disposable = this._languageModelToolsService.registerToolImplementation( name, { - invoke: async (parameters, token) => { - return await this._proxy.$invokeTool(name, parameters, token); + invoke: async (dto, countTokens, token) => { + try { + this._countTokenCallbacks.set(dto.callId, countTokens); + return await this._proxy.$invokeTool(dto, token); + } finally { + this._countTokenCallbacks.delete(dto.callId); + } }, }); this._tools.set(name, disposable); diff --git a/src/vs/workbench/api/common/extHost.api.impl.ts b/src/vs/workbench/api/common/extHost.api.impl.ts index cfafa785cc5..f8a167ebdf0 100644 --- a/src/vs/workbench/api/common/extHost.api.impl.ts +++ b/src/vs/workbench/api/common/extHost.api.impl.ts @@ -1489,7 +1489,7 @@ export function createApiFactoryAndRegisterActors(accessor: ServicesAccessor): I checkProposedApiEnabled(extension, 'lmTools'); return extHostLanguageModelTools.registerTool(extension, toolId, tool); }, - invokeTool(toolId: string, parameters: Object, token: vscode.CancellationToken) { + invokeTool(toolId: string, parameters: vscode.LanguageModelToolInvokationOptions, token: vscode.CancellationToken) { checkProposedApiEnabled(extension, 'lmTools'); return extHostLanguageModelTools.invokeTool(toolId, parameters, token); }, diff --git a/src/vs/workbench/api/common/extHost.protocol.ts b/src/vs/workbench/api/common/extHost.protocol.ts index 9187129951c..d7e7e3c5930 100644 --- a/src/vs/workbench/api/common/extHost.protocol.ts +++ b/src/vs/workbench/api/common/extHost.protocol.ts @@ -56,7 +56,7 @@ import { IChatProgressResponseContent } from 'vs/workbench/contrib/chat/common/c import { ChatAgentVoteDirection, IChatFollowup, IChatProgress, IChatResponseErrorDetails, IChatTask, IChatTaskDto, IChatUserActionEvent } from 'vs/workbench/contrib/chat/common/chatService'; import { IChatRequestVariableValue, IChatVariableData, IChatVariableResolverProgress } from 'vs/workbench/contrib/chat/common/chatVariables'; import { IChatMessage, IChatResponseFragment, ILanguageModelChatMetadata, ILanguageModelChatSelector, ILanguageModelsChangeEvent } from 'vs/workbench/contrib/chat/common/languageModels'; -import { IToolData, IToolDelta, IToolResult } from 'vs/workbench/contrib/chat/common/languageModelToolsService'; +import { IToolData, IToolDelta, IToolInvokation, IToolResult } from 'vs/workbench/contrib/chat/common/languageModelToolsService'; import { DebugConfigurationProviderTriggerKind, IAdapterDescriptor, IConfig, IDebugSessionReplMode, IDebugTestRunReference, IDebugVisualization, IDebugVisualizationContext, IDebugVisualizationTreeItem, MainThreadDebugVisualization } from 'vs/workbench/contrib/debug/common/debug'; import * as notebookCommon from 'vs/workbench/contrib/notebook/common/notebookCommon'; import { CellExecutionUpdateType } from 'vs/workbench/contrib/notebook/common/notebookExecutionService'; @@ -1313,7 +1313,8 @@ export interface MainThreadChatVariablesShape extends IDisposable { export interface MainThreadLanguageModelToolsShape extends IDisposable { $getTools(): Promise[]>; - $invokeTool(name: string, parameters: any, token: CancellationToken): Promise; + $invokeTool(dto: IToolInvokation, token: CancellationToken): Promise; + $countTokensForInvokation(callId: string, input: string | IChatMessage, token: CancellationToken): Promise; $registerTool(id: string): void; $unregisterTool(name: string): void; } @@ -1326,7 +1327,8 @@ export interface ExtHostChatVariablesShape { export interface ExtHostLanguageModelToolsShape { $acceptToolDelta(delta: IToolDelta): Promise; - $invokeTool(id: string, parameters: any, token: CancellationToken): Promise; + $invokeTool(dto: IToolInvokation, token: CancellationToken): Promise; + $countTokensForInvokation(callId: string, input: string | IChatMessage, token: CancellationToken): Promise; } export interface MainThreadUrlsShape extends IDisposable { diff --git a/src/vs/workbench/api/common/extHostLanguageModelTools.ts b/src/vs/workbench/api/common/extHostLanguageModelTools.ts index d65eefe4ec9..f46b950a74a 100644 --- a/src/vs/workbench/api/common/extHostLanguageModelTools.ts +++ b/src/vs/workbench/api/common/extHostLanguageModelTools.ts @@ -3,19 +3,24 @@ * Licensed under the MIT License. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ +import { raceCancellation } from 'vs/base/common/async'; import { CancellationToken } from 'vs/base/common/cancellation'; +import { CancellationError } from 'vs/base/common/errors'; import { IDisposable, toDisposable } from 'vs/base/common/lifecycle'; import { revive } from 'vs/base/common/marshalling'; +import { generateUuid } from 'vs/base/common/uuid'; import { IExtensionDescription } from 'vs/platform/extensions/common/extensions'; import { ExtHostLanguageModelToolsShape, IMainContext, MainContext, MainThreadLanguageModelToolsShape } from 'vs/workbench/api/common/extHost.protocol'; import * as typeConvert from 'vs/workbench/api/common/extHostTypeConverters'; -import { IToolData, IToolDelta, IToolResult } from 'vs/workbench/contrib/chat/common/languageModelToolsService'; +import { IChatMessage } from 'vs/workbench/contrib/chat/common/languageModels'; +import { IToolData, IToolDelta, IToolInvokation, IToolResult } from 'vs/workbench/contrib/chat/common/languageModelToolsService'; import type * as vscode from 'vscode'; export class ExtHostLanguageModelTools implements ExtHostLanguageModelToolsShape { /** A map of tools that were registered in this EH */ private readonly _registeredTools = new Map(); private readonly _proxy: MainThreadLanguageModelToolsShape; + private readonly _tokenCountFuncs = new Map Thenable>(); /** A map of all known tools, from other EHs or registered in vscode core */ private readonly _allTools = new Map(); @@ -30,10 +35,32 @@ export class ExtHostLanguageModelTools implements ExtHostLanguageModelToolsShape }); } - async invokeTool(id: string, parameters: any, token: CancellationToken): Promise { - // Making the round trip here because not all tools were necessarily registered in this EH - const result = await this._proxy.$invokeTool(id, parameters, token); - return typeConvert.LanguageModelToolResult.to(result); + async $countTokensForInvokation(callId: string, input: string | IChatMessage, token: CancellationToken): Promise { + const fn = this._tokenCountFuncs.get(callId); + if (!fn) { + throw new Error(`Tool invokation call ${callId} not found`); + } + + return await fn(typeof input === 'string' ? input : typeConvert.LanguageModelChatMessage.to(input), token); + } + + async invokeTool(toolId: string, options: vscode.LanguageModelToolInvokationOptions, token: CancellationToken): Promise { + const callId = generateUuid(); + if (options.tokenOptions) { + this._tokenCountFuncs.set(callId, options.tokenOptions.countTokens); + } + try { + // Making the round trip here because not all tools were necessarily registered in this EH + const result = await this._proxy.$invokeTool({ + toolId, + callId, + parameters: options.parameters, + tokenBudget: options.tokenOptions?.tokenBudget, + }, token); + return typeConvert.LanguageModelToolResult.to(result); + } finally { + this._tokenCountFuncs.delete(callId); + } } async $acceptToolDelta(delta: IToolDelta): Promise { @@ -51,13 +78,26 @@ export class ExtHostLanguageModelTools implements ExtHostLanguageModelToolsShape .map(tool => typeConvert.LanguageModelToolDescription.to(tool)); } - async $invokeTool(name: string, parameters: any, token: CancellationToken): Promise { - const item = this._registeredTools.get(name); + async $invokeTool(dto: IToolInvokation, token: CancellationToken): Promise { + const item = this._registeredTools.get(dto.toolId); if (!item) { - throw new Error(`Unknown tool ${name}`); + throw new Error(`Unknown tool ${dto.toolId}`); + } + + const options: vscode.LanguageModelToolInvokationOptions = { parameters: dto.parameters }; + if (dto.tokenBudget !== undefined) { + options.tokenOptions = { + tokenBudget: dto.tokenBudget, + countTokens: this._tokenCountFuncs.get(dto.callId) || ((value, token = CancellationToken.None) => + this._proxy.$countTokensForInvokation(dto.callId, typeof value === 'string' ? value : typeConvert.LanguageModelChatMessage.from(value), token)) + }; + } + + const extensionResult = await raceCancellation(Promise.resolve(item.tool.invoke(options, token)), token); + if (!extensionResult) { + throw new CancellationError(); } - const extensionResult = await item.tool.invoke(parameters, token); return typeConvert.LanguageModelToolResult.from(extensionResult); } diff --git a/src/vs/workbench/contrib/chat/common/languageModelToolsService.ts b/src/vs/workbench/contrib/chat/common/languageModelToolsService.ts index 70a0d44d930..9e97ebb5eb6 100644 --- a/src/vs/workbench/contrib/chat/common/languageModelToolsService.ts +++ b/src/vs/workbench/contrib/chat/common/languageModelToolsService.ts @@ -11,6 +11,7 @@ import { IDisposable, toDisposable } from 'vs/base/common/lifecycle'; import { ThemeIcon } from 'vs/base/common/themables'; import { URI } from 'vs/base/common/uri'; import { createDecorator } from 'vs/platform/instantiation/common/instantiation'; +import { IChatMessage } from 'vs/workbench/contrib/chat/common/languageModels'; import { IExtensionService } from 'vs/workbench/services/extensions/common/extensions'; export interface IToolData { @@ -29,13 +30,20 @@ interface IToolEntry { impl?: IToolImpl; } +export interface IToolInvokation { + callId: string; + toolId: string; + parameters: any; + tokenBudget?: number; +} + export interface IToolResult { [contentType: string]: any; string: string; } export interface IToolImpl { - invoke(parameters: any, token: CancellationToken): Promise; + invoke(dto: IToolInvokation, countTokens: CountTokensCallback, token: CancellationToken): Promise; } export const ILanguageModelToolsService = createDecorator('ILanguageModelToolsService'); @@ -45,6 +53,8 @@ export interface IToolDelta { removed?: string; } +export type CountTokensCallback = (input: string | IChatMessage, token: CancellationToken) => Promise; + export interface ILanguageModelToolsService { _serviceBrand: undefined; onDidChangeTools: Event; @@ -53,7 +63,7 @@ export interface ILanguageModelToolsService { getTools(): Iterable>; getTool(id: string): IToolData | undefined; getToolByName(name: string): IToolData | undefined; - invokeTool(name: string, parameters: any, token: CancellationToken): Promise; + invokeTool(dto: IToolInvokation, countTokens: CountTokensCallback, token: CancellationToken): Promise; } export class LanguageModelToolsService implements ILanguageModelToolsService { @@ -117,22 +127,22 @@ export class LanguageModelToolsService implements ILanguageModelToolsService { return undefined; } - async invokeTool(id: string, parameters: any, token: CancellationToken): Promise { - let tool = this._tools.get(id); + async invokeTool(dto: IToolInvokation, countTokens: CountTokensCallback, token: CancellationToken): Promise { + let tool = this._tools.get(dto.toolId); if (!tool) { - throw new Error(`Tool ${id} was not contributed`); + throw new Error(`Tool ${dto.toolId} was not contributed`); } if (!tool.impl) { - await this._extensionService.activateByEvent(`onLanguageModelTool:${id}`); + await this._extensionService.activateByEvent(`onLanguageModelTool:${dto.toolId}`); // Extension should activate and register the tool implementation - tool = this._tools.get(id); + tool = this._tools.get(dto.toolId); if (!tool?.impl) { - throw new Error(`Tool ${id} does not have an implementation registered.`); + throw new Error(`Tool ${dto.toolId} does not have an implementation registered.`); } } - return tool.impl.invoke(parameters, token); + return tool.impl.invoke(dto, countTokens, token); } } diff --git a/src/vs/workbench/contrib/chat/test/common/mockLanguageModelToolsService.ts b/src/vs/workbench/contrib/chat/test/common/mockLanguageModelToolsService.ts index a764d8e88b9..5f2847ef579 100644 --- a/src/vs/workbench/contrib/chat/test/common/mockLanguageModelToolsService.ts +++ b/src/vs/workbench/contrib/chat/test/common/mockLanguageModelToolsService.ts @@ -6,7 +6,7 @@ import { CancellationToken } from 'vs/base/common/cancellation'; import { Event } from 'vs/base/common/event'; import { Disposable, IDisposable } from 'vs/base/common/lifecycle'; -import { ILanguageModelToolsService, IToolData, IToolDelta, IToolImpl, IToolResult } from 'vs/workbench/contrib/chat/common/languageModelToolsService'; +import { CountTokensCallback, ILanguageModelToolsService, IToolData, IToolDelta, IToolImpl, IToolInvokation, IToolResult } from 'vs/workbench/contrib/chat/common/languageModelToolsService'; export class MockLanguageModelToolsService implements ILanguageModelToolsService { _serviceBrand: undefined; @@ -35,7 +35,7 @@ export class MockLanguageModelToolsService implements ILanguageModelToolsService return undefined; } - async invokeTool(name: string, parameters: any, token: CancellationToken): Promise { + async invokeTool(dto: IToolInvokation, countTokens: CountTokensCallback, token: CancellationToken): Promise { return { string: '' }; diff --git a/src/vscode-dts/vscode.proposed.lmTools.d.ts b/src/vscode-dts/vscode.proposed.lmTools.d.ts index 663aba13218..852e9328919 100644 --- a/src/vscode-dts/vscode.proposed.lmTools.d.ts +++ b/src/vscode-dts/vscode.proposed.lmTools.d.ts @@ -3,7 +3,7 @@ * Licensed under the MIT License. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ -// version: 4 +// version: 5 // https://github.com/microsoft/vscode/issues/213274 declare module 'vscode' { @@ -91,7 +91,32 @@ declare module 'vscode' { * Invoke a tool with the given parameters. * TODO@API Could request a set of contentTypes to be returned so they don't all need to be computed? */ - export function invokeTool(id: string, parameters: Object, token: CancellationToken): Thenable; + export function invokeTool(id: string, options: LanguageModelToolInvokationOptions, token: CancellationToken): Thenable; + } + + export interface LanguageModelToolInvokationOptions { + /** + * Parameters with which to invoke the tool. + */ + parameters: Object; + + /** + * Options to hint at how many tokens the tool should return in its response. + */ + tokenOptions?: { + /** + * If known, the maximum number of tokens the tool should emit in its result. + */ + tokenBudget: number; + + /** + * Count the number of tokens in a message using the model specific tokenizer-logic. + * @param text A string or a message instance. + * @param token Optional cancellation token. See {@link CancellationTokenSource} for how to create one. + * @returns A thenable that resolves to the number of tokens. + */ + countTokens(text: string | LanguageModelChatMessage, token?: CancellationToken): Thenable; + }; } export type JSONSchema = object; @@ -120,7 +145,7 @@ declare module 'vscode' { export interface LanguageModelTool { // TODO@API should it be LanguageModelToolResult | string? - invoke(parameters: any, token: CancellationToken): Thenable; + invoke(options: LanguageModelToolInvokationOptions, token: CancellationToken): Thenable; } export interface ChatLanguageModelToolReference {