lm: a second rendition of returning data from LM tools (#225634)

* lm: a second rendition of returning data from LM tools

This is an alternative to #225454. It allows the tool caller to pass
through token budget and counting information to the tool, and the
tool can then 'do its thing.'

Most of the actual implementation is in prompt-tsx with a new method to
render elements into a JSON-serializable form, and then splice them back
into the tree by the consumer. The implementation can be found here:

https://github.com/microsoft/vscode-prompt-tsx/tree/connor4312/tools-api-v2

On the tool side, this looks like:

```ts
vscode.lm.registerTool('myTestTool', {
	async invoke(context, token): Promise<vscode.LanguageModelToolResult> {
		return {
			// context includes the token info:
			'mytype': await renderElementJSON(MyCustomPrompt, {}, context, token),
			toString() {
				return 'hello world!'
			}
		};
	},
});
```

I didn't make any nice wrappers yet, but the MVP consumer side looks like:

```
export class TestPrompt extends PromptElement {
    async render(_state: void, sizing: PromptSizing) {
        const result = await vscode.lm.invokeTool('myTestTool', {
			parameters: {},
			tokenBudget: sizing.tokenBudget,
            countTokens: (v, token) => tokenizer.countTokens(v, token),
		}, new vscode.CancellationTokenSource().token);

        return (
            <>
                <elementJSON data={result.mytype} />
            </>
        );
    }
}
```

I like this approach better. It avoids bleeding knowledge of TSX into
the extension host and comparatively simple.

* address comments

* address comments
This commit is contained in:
Connor Peet
2024-08-22 09:41:31 -07:00
committed by GitHub
parent bf52a5cfb2
commit 2bf25ee2fd
8 changed files with 130 additions and 33 deletions
@@ -246,7 +246,7 @@ const _allApiProposals = {
},
lmTools: {
proposal: 'https://raw.githubusercontent.com/microsoft/vscode/main/src/vscode-dts/vscode.proposed.lmTools.d.ts',
version: 4
version: 5
},
mappedEditsProvider: {
proposal: 'https://raw.githubusercontent.com/microsoft/vscode/main/src/vscode-dts/vscode.proposed.mappedEditsProvider.d.ts',
@@ -6,7 +6,8 @@
import { CancellationToken } from 'vs/base/common/cancellation';
import { Disposable, DisposableMap } from 'vs/base/common/lifecycle';
import { ExtHostLanguageModelToolsShape, ExtHostContext, MainContext, MainThreadLanguageModelToolsShape } from 'vs/workbench/api/common/extHost.protocol';
import { IToolData, ILanguageModelToolsService, IToolResult } from 'vs/workbench/contrib/chat/common/languageModelToolsService';
import { IChatMessage } from 'vs/workbench/contrib/chat/common/languageModels';
import { IToolData, ILanguageModelToolsService, IToolResult, IToolInvokation, CountTokensCallback } from 'vs/workbench/contrib/chat/common/languageModelToolsService';
import { IExtHostContext, extHostNamedCustomer } from 'vs/workbench/services/extensions/common/extHostCustomers';
@extHostNamedCustomer(MainContext.MainThreadLanguageModelTools)
@@ -14,6 +15,7 @@ export class MainThreadLanguageModelTools extends Disposable implements MainThre
private readonly _proxy: ExtHostLanguageModelToolsShape;
private readonly _tools = this._register(new DisposableMap<string>());
private readonly _countTokenCallbacks = new Map</* call ID */string, CountTokensCallback>();
constructor(
extHostContext: IExtHostContext,
@@ -29,16 +31,34 @@ export class MainThreadLanguageModelTools extends Disposable implements MainThre
return Array.from(this._languageModelToolsService.getTools());
}
$invokeTool(id: string, parameters: any, token: CancellationToken): Promise<IToolResult> {
return this._languageModelToolsService.invokeTool(id, parameters, token);
$invokeTool(dto: IToolInvokation, token: CancellationToken): Promise<IToolResult> {
return this._languageModelToolsService.invokeTool(
dto,
(input, token) => this._proxy.$countTokensForInvokation(dto.callId, input, token),
token,
);
}
$countTokensForInvokation(callId: string, input: string | IChatMessage, token: CancellationToken): Promise<number> {
const fn = this._countTokenCallbacks.get(callId);
if (!fn) {
throw new Error(`Tool invokation call ${callId} not found`);
}
return fn(input, token);
}
$registerTool(name: string): void {
const disposable = this._languageModelToolsService.registerToolImplementation(
name,
{
invoke: async (parameters, token) => {
return await this._proxy.$invokeTool(name, parameters, token);
invoke: async (dto, countTokens, token) => {
try {
this._countTokenCallbacks.set(dto.callId, countTokens);
return await this._proxy.$invokeTool(dto, token);
} finally {
this._countTokenCallbacks.delete(dto.callId);
}
},
});
this._tools.set(name, disposable);
@@ -1489,7 +1489,7 @@ export function createApiFactoryAndRegisterActors(accessor: ServicesAccessor): I
checkProposedApiEnabled(extension, 'lmTools');
return extHostLanguageModelTools.registerTool(extension, toolId, tool);
},
invokeTool(toolId: string, parameters: Object, token: vscode.CancellationToken) {
invokeTool(toolId: string, parameters: vscode.LanguageModelToolInvokationOptions, token: vscode.CancellationToken) {
checkProposedApiEnabled(extension, 'lmTools');
return extHostLanguageModelTools.invokeTool(toolId, parameters, token);
},
@@ -56,7 +56,7 @@ import { IChatProgressResponseContent } from 'vs/workbench/contrib/chat/common/c
import { ChatAgentVoteDirection, IChatFollowup, IChatProgress, IChatResponseErrorDetails, IChatTask, IChatTaskDto, IChatUserActionEvent } from 'vs/workbench/contrib/chat/common/chatService';
import { IChatRequestVariableValue, IChatVariableData, IChatVariableResolverProgress } from 'vs/workbench/contrib/chat/common/chatVariables';
import { IChatMessage, IChatResponseFragment, ILanguageModelChatMetadata, ILanguageModelChatSelector, ILanguageModelsChangeEvent } from 'vs/workbench/contrib/chat/common/languageModels';
import { IToolData, IToolDelta, IToolResult } from 'vs/workbench/contrib/chat/common/languageModelToolsService';
import { IToolData, IToolDelta, IToolInvokation, IToolResult } from 'vs/workbench/contrib/chat/common/languageModelToolsService';
import { DebugConfigurationProviderTriggerKind, IAdapterDescriptor, IConfig, IDebugSessionReplMode, IDebugTestRunReference, IDebugVisualization, IDebugVisualizationContext, IDebugVisualizationTreeItem, MainThreadDebugVisualization } from 'vs/workbench/contrib/debug/common/debug';
import * as notebookCommon from 'vs/workbench/contrib/notebook/common/notebookCommon';
import { CellExecutionUpdateType } from 'vs/workbench/contrib/notebook/common/notebookExecutionService';
@@ -1313,7 +1313,8 @@ export interface MainThreadChatVariablesShape extends IDisposable {
export interface MainThreadLanguageModelToolsShape extends IDisposable {
$getTools(): Promise<Dto<IToolData>[]>;
$invokeTool(name: string, parameters: any, token: CancellationToken): Promise<IToolResult>;
$invokeTool(dto: IToolInvokation, token: CancellationToken): Promise<IToolResult>;
$countTokensForInvokation(callId: string, input: string | IChatMessage, token: CancellationToken): Promise<number>;
$registerTool(id: string): void;
$unregisterTool(name: string): void;
}
@@ -1326,7 +1327,8 @@ export interface ExtHostChatVariablesShape {
export interface ExtHostLanguageModelToolsShape {
$acceptToolDelta(delta: IToolDelta): Promise<void>;
$invokeTool(id: string, parameters: any, token: CancellationToken): Promise<IToolResult>;
$invokeTool(dto: IToolInvokation, token: CancellationToken): Promise<IToolResult>;
$countTokensForInvokation(callId: string, input: string | IChatMessage, token: CancellationToken): Promise<number>;
}
export interface MainThreadUrlsShape extends IDisposable {
@@ -3,19 +3,24 @@
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { raceCancellation } from 'vs/base/common/async';
import { CancellationToken } from 'vs/base/common/cancellation';
import { CancellationError } from 'vs/base/common/errors';
import { IDisposable, toDisposable } from 'vs/base/common/lifecycle';
import { revive } from 'vs/base/common/marshalling';
import { generateUuid } from 'vs/base/common/uuid';
import { IExtensionDescription } from 'vs/platform/extensions/common/extensions';
import { ExtHostLanguageModelToolsShape, IMainContext, MainContext, MainThreadLanguageModelToolsShape } from 'vs/workbench/api/common/extHost.protocol';
import * as typeConvert from 'vs/workbench/api/common/extHostTypeConverters';
import { IToolData, IToolDelta, IToolResult } from 'vs/workbench/contrib/chat/common/languageModelToolsService';
import { IChatMessage } from 'vs/workbench/contrib/chat/common/languageModels';
import { IToolData, IToolDelta, IToolInvokation, IToolResult } from 'vs/workbench/contrib/chat/common/languageModelToolsService';
import type * as vscode from 'vscode';
export class ExtHostLanguageModelTools implements ExtHostLanguageModelToolsShape {
/** A map of tools that were registered in this EH */
private readonly _registeredTools = new Map<string, { extension: IExtensionDescription; tool: vscode.LanguageModelTool }>();
private readonly _proxy: MainThreadLanguageModelToolsShape;
private readonly _tokenCountFuncs = new Map</* call ID */string, (text: string | vscode.LanguageModelChatMessage, token?: vscode.CancellationToken) => Thenable<number>>();
/** A map of all known tools, from other EHs or registered in vscode core */
private readonly _allTools = new Map<string, IToolData>();
@@ -30,10 +35,32 @@ export class ExtHostLanguageModelTools implements ExtHostLanguageModelToolsShape
});
}
async invokeTool(id: string, parameters: any, token: CancellationToken): Promise<vscode.LanguageModelToolResult> {
// Making the round trip here because not all tools were necessarily registered in this EH
const result = await this._proxy.$invokeTool(id, parameters, token);
return typeConvert.LanguageModelToolResult.to(result);
async $countTokensForInvokation(callId: string, input: string | IChatMessage, token: CancellationToken): Promise<number> {
const fn = this._tokenCountFuncs.get(callId);
if (!fn) {
throw new Error(`Tool invokation call ${callId} not found`);
}
return await fn(typeof input === 'string' ? input : typeConvert.LanguageModelChatMessage.to(input), token);
}
async invokeTool(toolId: string, options: vscode.LanguageModelToolInvokationOptions, token: CancellationToken): Promise<vscode.LanguageModelToolResult> {
const callId = generateUuid();
if (options.tokenOptions) {
this._tokenCountFuncs.set(callId, options.tokenOptions.countTokens);
}
try {
// Making the round trip here because not all tools were necessarily registered in this EH
const result = await this._proxy.$invokeTool({
toolId,
callId,
parameters: options.parameters,
tokenBudget: options.tokenOptions?.tokenBudget,
}, token);
return typeConvert.LanguageModelToolResult.to(result);
} finally {
this._tokenCountFuncs.delete(callId);
}
}
async $acceptToolDelta(delta: IToolDelta): Promise<void> {
@@ -51,13 +78,26 @@ export class ExtHostLanguageModelTools implements ExtHostLanguageModelToolsShape
.map(tool => typeConvert.LanguageModelToolDescription.to(tool));
}
async $invokeTool(name: string, parameters: any, token: CancellationToken): Promise<IToolResult> {
const item = this._registeredTools.get(name);
async $invokeTool(dto: IToolInvokation, token: CancellationToken): Promise<IToolResult> {
const item = this._registeredTools.get(dto.toolId);
if (!item) {
throw new Error(`Unknown tool ${name}`);
throw new Error(`Unknown tool ${dto.toolId}`);
}
const options: vscode.LanguageModelToolInvokationOptions = { parameters: dto.parameters };
if (dto.tokenBudget !== undefined) {
options.tokenOptions = {
tokenBudget: dto.tokenBudget,
countTokens: this._tokenCountFuncs.get(dto.callId) || ((value, token = CancellationToken.None) =>
this._proxy.$countTokensForInvokation(dto.callId, typeof value === 'string' ? value : typeConvert.LanguageModelChatMessage.from(value), token))
};
}
const extensionResult = await raceCancellation(Promise.resolve(item.tool.invoke(options, token)), token);
if (!extensionResult) {
throw new CancellationError();
}
const extensionResult = await item.tool.invoke(parameters, token);
return typeConvert.LanguageModelToolResult.from(extensionResult);
}
@@ -11,6 +11,7 @@ import { IDisposable, toDisposable } from 'vs/base/common/lifecycle';
import { ThemeIcon } from 'vs/base/common/themables';
import { URI } from 'vs/base/common/uri';
import { createDecorator } from 'vs/platform/instantiation/common/instantiation';
import { IChatMessage } from 'vs/workbench/contrib/chat/common/languageModels';
import { IExtensionService } from 'vs/workbench/services/extensions/common/extensions';
export interface IToolData {
@@ -29,13 +30,20 @@ interface IToolEntry {
impl?: IToolImpl;
}
export interface IToolInvokation {
callId: string;
toolId: string;
parameters: any;
tokenBudget?: number;
}
export interface IToolResult {
[contentType: string]: any;
string: string;
}
export interface IToolImpl {
invoke(parameters: any, token: CancellationToken): Promise<IToolResult>;
invoke(dto: IToolInvokation, countTokens: CountTokensCallback, token: CancellationToken): Promise<IToolResult>;
}
export const ILanguageModelToolsService = createDecorator<ILanguageModelToolsService>('ILanguageModelToolsService');
@@ -45,6 +53,8 @@ export interface IToolDelta {
removed?: string;
}
export type CountTokensCallback = (input: string | IChatMessage, token: CancellationToken) => Promise<number>;
export interface ILanguageModelToolsService {
_serviceBrand: undefined;
onDidChangeTools: Event<IToolDelta>;
@@ -53,7 +63,7 @@ export interface ILanguageModelToolsService {
getTools(): Iterable<Readonly<IToolData>>;
getTool(id: string): IToolData | undefined;
getToolByName(name: string): IToolData | undefined;
invokeTool(name: string, parameters: any, token: CancellationToken): Promise<IToolResult>;
invokeTool(dto: IToolInvokation, countTokens: CountTokensCallback, token: CancellationToken): Promise<IToolResult>;
}
export class LanguageModelToolsService implements ILanguageModelToolsService {
@@ -117,22 +127,22 @@ export class LanguageModelToolsService implements ILanguageModelToolsService {
return undefined;
}
async invokeTool(id: string, parameters: any, token: CancellationToken): Promise<IToolResult> {
let tool = this._tools.get(id);
async invokeTool(dto: IToolInvokation, countTokens: CountTokensCallback, token: CancellationToken): Promise<IToolResult> {
let tool = this._tools.get(dto.toolId);
if (!tool) {
throw new Error(`Tool ${id} was not contributed`);
throw new Error(`Tool ${dto.toolId} was not contributed`);
}
if (!tool.impl) {
await this._extensionService.activateByEvent(`onLanguageModelTool:${id}`);
await this._extensionService.activateByEvent(`onLanguageModelTool:${dto.toolId}`);
// Extension should activate and register the tool implementation
tool = this._tools.get(id);
tool = this._tools.get(dto.toolId);
if (!tool?.impl) {
throw new Error(`Tool ${id} does not have an implementation registered.`);
throw new Error(`Tool ${dto.toolId} does not have an implementation registered.`);
}
}
return tool.impl.invoke(parameters, token);
return tool.impl.invoke(dto, countTokens, token);
}
}
@@ -6,7 +6,7 @@
import { CancellationToken } from 'vs/base/common/cancellation';
import { Event } from 'vs/base/common/event';
import { Disposable, IDisposable } from 'vs/base/common/lifecycle';
import { ILanguageModelToolsService, IToolData, IToolDelta, IToolImpl, IToolResult } from 'vs/workbench/contrib/chat/common/languageModelToolsService';
import { CountTokensCallback, ILanguageModelToolsService, IToolData, IToolDelta, IToolImpl, IToolInvokation, IToolResult } from 'vs/workbench/contrib/chat/common/languageModelToolsService';
export class MockLanguageModelToolsService implements ILanguageModelToolsService {
_serviceBrand: undefined;
@@ -35,7 +35,7 @@ export class MockLanguageModelToolsService implements ILanguageModelToolsService
return undefined;
}
async invokeTool(name: string, parameters: any, token: CancellationToken): Promise<IToolResult> {
async invokeTool(dto: IToolInvokation, countTokens: CountTokensCallback, token: CancellationToken): Promise<IToolResult> {
return {
string: ''
};
+28 -3
View File
@@ -3,7 +3,7 @@
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
// version: 4
// version: 5
// https://github.com/microsoft/vscode/issues/213274
declare module 'vscode' {
@@ -91,7 +91,32 @@ declare module 'vscode' {
* Invoke a tool with the given parameters.
* TODO@API Could request a set of contentTypes to be returned so they don't all need to be computed?
*/
export function invokeTool(id: string, parameters: Object, token: CancellationToken): Thenable<LanguageModelToolResult>;
export function invokeTool(id: string, options: LanguageModelToolInvokationOptions, token: CancellationToken): Thenable<LanguageModelToolResult>;
}
export interface LanguageModelToolInvokationOptions {
/**
* Parameters with which to invoke the tool.
*/
parameters: Object;
/**
* Options to hint at how many tokens the tool should return in its response.
*/
tokenOptions?: {
/**
* If known, the maximum number of tokens the tool should emit in its result.
*/
tokenBudget: number;
/**
* Count the number of tokens in a message using the model specific tokenizer-logic.
* @param text A string or a message instance.
* @param token Optional cancellation token. See {@link CancellationTokenSource} for how to create one.
* @returns A thenable that resolves to the number of tokens.
*/
countTokens(text: string | LanguageModelChatMessage, token?: CancellationToken): Thenable<number>;
};
}
export type JSONSchema = object;
@@ -120,7 +145,7 @@ declare module 'vscode' {
export interface LanguageModelTool {
// TODO@API should it be LanguageModelToolResult | string?
invoke(parameters: any, token: CancellationToken): Thenable<LanguageModelToolResult>;
invoke(options: LanguageModelToolInvokationOptions, token: CancellationToken): Thenable<LanguageModelToolResult>;
}
export interface ChatLanguageModelToolReference {