diff --git a/src/vs/workbench/api/common/extHost.api.impl.ts b/src/vs/workbench/api/common/extHost.api.impl.ts index e6d5b024f66..befec49c93b 100644 --- a/src/vs/workbench/api/common/extHost.api.impl.ts +++ b/src/vs/workbench/api/common/extHost.api.impl.ts @@ -213,7 +213,7 @@ export function createApiFactoryAndRegisterActors(accessor: ServicesAccessor): I const extHostUriOpeners = rpcProtocol.set(ExtHostContext.ExtHostUriOpeners, new ExtHostUriOpeners(rpcProtocol)); const extHostProfileContentHandlers = rpcProtocol.set(ExtHostContext.ExtHostProfileContentHandlers, new ExtHostProfileContentHandlers(rpcProtocol)); rpcProtocol.set(ExtHostContext.ExtHostInteractive, new ExtHostInteractive(rpcProtocol, extHostNotebook, extHostDocumentsAndEditors, extHostCommands, extHostLogService)); - const extHostLanguageModelTools = rpcProtocol.set(ExtHostContext.ExtHostLanguageModelTools, new ExtHostLanguageModelTools(rpcProtocol)); + const extHostLanguageModelTools = rpcProtocol.set(ExtHostContext.ExtHostLanguageModelTools, new ExtHostLanguageModelTools(rpcProtocol, extHostLanguageModels)); const extHostChatAgents2 = rpcProtocol.set(ExtHostContext.ExtHostChatAgents2, new ExtHostChatAgents2(rpcProtocol, extHostLogService, extHostCommands, extHostDocuments, extHostLanguageModels, extHostDiagnostics, extHostLanguageModelTools)); const extHostAiRelatedInformation = rpcProtocol.set(ExtHostContext.ExtHostAiRelatedInformation, new ExtHostRelatedInformation(rpcProtocol)); const extHostAiEmbeddingVector = rpcProtocol.set(ExtHostContext.ExtHostAiEmbeddingVector, new ExtHostAiEmbeddingVector(rpcProtocol)); diff --git a/src/vs/workbench/api/common/extHostLanguageModelTools.ts b/src/vs/workbench/api/common/extHostLanguageModelTools.ts index 78453ce7d16..0e7b81d42cf 100644 --- a/src/vs/workbench/api/common/extHostLanguageModelTools.ts +++ b/src/vs/workbench/api/common/extHostLanguageModelTools.ts @@ -18,6 +18,7 @@ import * as typeConvert from './extHostTypeConverters.js'; import { InternalFetchWebPageToolId, IToolInputProcessor } from '../../contrib/chat/common/tools/tools.js'; import { EditToolData, InternalEditToolId, EditToolInputProcessor, ExtensionEditToolId } from '../../contrib/chat/common/tools/editFileTool.js'; import { Dto } from '../../services/extensions/common/proxyIdentifier.js'; +import { ExtHostLanguageModels } from './extHostLanguageModels.js'; export class ExtHostLanguageModelTools implements ExtHostLanguageModelToolsShape { /** A map of tools that were registered in this EH */ @@ -30,7 +31,10 @@ export class ExtHostLanguageModelTools implements ExtHostLanguageModelToolsShape private readonly _toolInputProcessors = new Map(); - constructor(mainContext: IMainContext) { + constructor( + mainContext: IMainContext, + private readonly _languageModels: ExtHostLanguageModels, + ) { this._proxy = mainContext.getProxy(MainContext.MainThreadLanguageModelTools); this._proxy.$getTools().then(tools => { @@ -125,6 +129,10 @@ export class ExtHostLanguageModelTools implements ExtHostLanguageModelToolsShape } } + if (isProposedApiEnabled(item.extension, 'chatParticipantAdditions') && dto.modelId) { + options.model = await this.getModel(dto.modelId, item.extension); + } + if (dto.tokenBudget !== undefined) { options.tokenizationOptions = { tokenBudget: dto.tokenBudget, @@ -141,6 +149,21 @@ export class ExtHostLanguageModelTools implements ExtHostLanguageModelToolsShape return typeConvert.LanguageModelToolResult.from(extensionResult, item.extension); } + private async getModel(modelId: string, extension: IExtensionDescription): Promise { + let model: vscode.LanguageModelChat | undefined; + if (modelId) { + model = await this._languageModels.getLanguageModelByIdentifier(extension, modelId); + } + if (!model) { + model = await this._languageModels.getDefaultLanguageModel(extension); + if (!model) { + throw new Error('Language model unavailable'); + } + } + + return model; + } + async $prepareToolInvocation(toolId: string, input: any, token: CancellationToken): Promise { const item = this._registeredTools.get(toolId); if (!item) { diff --git a/src/vs/workbench/contrib/chat/browser/languageModelToolsService.ts b/src/vs/workbench/contrib/chat/browser/languageModelToolsService.ts index ac439089218..2761b05d679 100644 --- a/src/vs/workbench/contrib/chat/browser/languageModelToolsService.ts +++ b/src/vs/workbench/contrib/chat/browser/languageModelToolsService.ts @@ -212,6 +212,7 @@ export class LanguageModelToolsService extends Disposable implements ILanguageMo const request = model.getRequests().at(-1)!; requestId = request.id; + dto.modelId = request.modelId; // Replace the token with a new token that we can cancel when cancelToolCallsForRequest is called if (!this._callsByRequestId.has(requestId)) { diff --git a/src/vs/workbench/contrib/chat/common/chatModel.ts b/src/vs/workbench/contrib/chat/common/chatModel.ts index 8a18efef803..870ff0d48a2 100644 --- a/src/vs/workbench/contrib/chat/common/chatModel.ts +++ b/src/vs/workbench/contrib/chat/common/chatModel.ts @@ -383,6 +383,7 @@ export class ChatRequestModel implements IChatRequestModel { private _locationData?: IChatLocationData, private _attachedContext?: IChatRequestVariableEntry[], public readonly isCompleteAddedRequest = false, + public readonly modelId?: string, restoredId?: string, ) { this.id = restoredId ?? 'request_' + generateUuid(); @@ -1338,7 +1339,7 @@ export class ChatModel extends Disposable implements IChatModel { // Old messages don't have variableData, or have it in the wrong (non-array) shape const variableData: IChatRequestVariableData = this.reviveVariableData(raw.variableData); - const request = new ChatRequestModel(this, parsedRequest, variableData, raw.timestamp ?? -1, undefined, undefined, undefined, undefined, undefined, raw.requestId); + const request = new ChatRequestModel(this, parsedRequest, variableData, raw.timestamp ?? -1, undefined, undefined, undefined, undefined, undefined, undefined, raw.requestId); request.shouldBeRemovedOnSend = raw.isHidden ? { requestId: raw.requestId } : raw.shouldBeRemovedOnSend; if (raw.response || raw.result || (raw as any).responseErrorDetails) { const agent = (raw.agent && 'metadata' in raw.agent) ? // Check for the new format, ignore entries in the old format @@ -1470,8 +1471,8 @@ export class ChatModel extends Disposable implements IChatModel { }); } - addRequest(message: IParsedChatRequest, variableData: IChatRequestVariableData, attempt: number, chatAgent?: IChatAgentData, slashCommand?: IChatAgentCommand, confirmation?: string, locationData?: IChatLocationData, attachments?: IChatRequestVariableEntry[], workingSet?: URI[], isCompleteAddedRequest?: boolean): ChatRequestModel { - const request = new ChatRequestModel(this, message, variableData, Date.now(), attempt, confirmation, locationData, attachments, isCompleteAddedRequest); + addRequest(message: IParsedChatRequest, variableData: IChatRequestVariableData, attempt: number, chatAgent?: IChatAgentData, slashCommand?: IChatAgentCommand, confirmation?: string, locationData?: IChatLocationData, attachments?: IChatRequestVariableEntry[], isCompleteAddedRequest?: boolean, modelId?: string): ChatRequestModel { + const request = new ChatRequestModel(this, message, variableData, Date.now(), attempt, confirmation, locationData, attachments, isCompleteAddedRequest, modelId); request.response = new ChatResponseModel([], this, chatAgent, slashCommand, request.id, undefined, undefined, undefined, undefined, undefined, undefined, isCompleteAddedRequest); this._requests.push(request); diff --git a/src/vs/workbench/contrib/chat/common/chatServiceImpl.ts b/src/vs/workbench/contrib/chat/common/chatServiceImpl.ts index deccb104d3b..f908cff7c01 100644 --- a/src/vs/workbench/contrib/chat/common/chatServiceImpl.ts +++ b/src/vs/workbench/contrib/chat/common/chatServiceImpl.ts @@ -739,7 +739,7 @@ export class ChatService extends Disposable implements IChatService { if (agentPart || (defaultAgent && !commandPart)) { const prepareChatAgentRequest = async (agent: IChatAgentData, command?: IChatAgentCommand, enableCommandDetection?: boolean, chatRequest?: ChatRequestModel, isParticipantDetected?: boolean): Promise => { const initVariableData: IChatRequestVariableData = { variables: [] }; - request = chatRequest ?? model.addRequest(parsedRequest, initVariableData, attempt, agent, command, options?.confirmation, options?.locationData, options?.attachedContext); + request = chatRequest ?? model.addRequest(parsedRequest, initVariableData, attempt, agent, command, options?.confirmation, options?.locationData, options?.attachedContext, undefined, options?.userSelectedModelId); let variableData: IChatRequestVariableData; let message: string; @@ -1044,7 +1044,7 @@ export class ChatService extends Disposable implements IChatService { const parsedRequest = typeof message === 'string' ? this.instantiationService.createInstance(ChatRequestParser).parseChatRequest(sessionId, message) : message; - const request = model.addRequest(parsedRequest, variableData || { variables: [] }, attempt ?? 0, undefined, undefined, undefined, undefined, undefined, undefined, true); + const request = model.addRequest(parsedRequest, variableData || { variables: [] }, attempt ?? 0, undefined, undefined, undefined, undefined, undefined, true); if (typeof response.message === 'string') { // TODO is this possible? model.acceptResponseProgress(request, { content: new MarkdownString(response.message), kind: 'markdownContent' }); diff --git a/src/vs/workbench/contrib/chat/common/languageModelToolsService.ts b/src/vs/workbench/contrib/chat/common/languageModelToolsService.ts index 5c724835213..0e1909c202b 100644 --- a/src/vs/workbench/contrib/chat/common/languageModelToolsService.ts +++ b/src/vs/workbench/contrib/chat/common/languageModelToolsService.ts @@ -45,6 +45,7 @@ export interface IToolInvocation { chatRequestId?: string; chatInteractionId?: string; toolSpecificData?: IChatTerminalToolInvocationData | IChatToolInputInvocationData; + modelId?: string; } export interface IToolInvocationContext { diff --git a/src/vs/workbench/contrib/chat/test/common/chatModel.test.ts b/src/vs/workbench/contrib/chat/test/common/chatModel.test.ts index 9f5444ca5cd..cce4f5c0b98 100644 --- a/src/vs/workbench/contrib/chat/test/common/chatModel.test.ts +++ b/src/vs/workbench/contrib/chat/test/common/chatModel.test.ts @@ -160,7 +160,7 @@ suite('ChatModel', () => { model1.initialize(undefined); const text = 'hello'; - const request1 = model1.addRequest({ text, parts: [new ChatRequestTextPart(new OffsetRange(0, text.length), new Range(1, text.length, 1, text.length), text)] }, { variables: [] }, 0, undefined, undefined, undefined, undefined, undefined, undefined, true); + const request1 = model1.addRequest({ text, parts: [new ChatRequestTextPart(new OffsetRange(0, text.length), new Range(1, text.length, 1, text.length), text)] }, { variables: [] }, 0, undefined, undefined, undefined, undefined, undefined, true); assert.strictEqual(request1.isCompleteAddedRequest, true); assert.strictEqual(request1.response!.isCompleteAddedRequest, true); diff --git a/src/vscode-dts/vscode.proposed.chatParticipantAdditions.d.ts b/src/vscode-dts/vscode.proposed.chatParticipantAdditions.d.ts index e03a027d59b..2a33fc222f9 100644 --- a/src/vscode-dts/vscode.proposed.chatParticipantAdditions.d.ts +++ b/src/vscode-dts/vscode.proposed.chatParticipantAdditions.d.ts @@ -433,4 +433,8 @@ declare module 'vscode' { Medium = 2, Full = 3 } + + export interface LanguageModelToolInvocationOptions { + model?: LanguageModelChat; + } }