diff --git a/src/vs/workbench/api/browser/mainThreadLanguageModelTools.ts b/src/vs/workbench/api/browser/mainThreadLanguageModelTools.ts index 88ab6e0e598..0a7db37edb5 100644 --- a/src/vs/workbench/api/browser/mainThreadLanguageModelTools.ts +++ b/src/vs/workbench/api/browser/mainThreadLanguageModelTools.ts @@ -6,7 +6,7 @@ import { CancellationToken } from '../../../base/common/cancellation.js'; import { Disposable, DisposableMap } from '../../../base/common/lifecycle.js'; import { revive } from '../../../base/common/marshalling.js'; -import { IProgressStep } from '../../../platform/progress/common/progress.js'; +import { IProgress, IProgressStep } from '../../../platform/progress/common/progress.js'; import { CountTokensCallback, ILanguageModelToolsService, IToolData, IToolInvocation, IToolResult } from '../../contrib/chat/common/languageModelToolsService.js'; import { IExtHostContext, extHostNamedCustomer } from '../../services/extensions/common/extHostCustomers.js'; import { Dto } from '../../services/extensions/common/proxyIdentifier.js'; @@ -17,7 +17,10 @@ export class MainThreadLanguageModelTools extends Disposable implements MainThre private readonly _proxy: ExtHostLanguageModelToolsShape; private readonly _tools = this._register(new DisposableMap()); - private readonly _countTokenCallbacks = new Map(); + private readonly _runningToolCalls = new Map; + }>(); constructor( extHostContext: IExtHostContext, @@ -46,30 +49,30 @@ export class MainThreadLanguageModelTools extends Disposable implements MainThre }; } - $acceptToolProgress(requestId: string | undefined, callId: string, progress: IProgressStep): void { - this._languageModelToolsService.acceptProgress(requestId, callId, progress); + $acceptToolProgress(callId: string, progress: IProgressStep): void { + this._runningToolCalls.get(callId)?.progress.report(progress); } $countTokensForInvocation(callId: string, input: string, token: CancellationToken): Promise { - const fn = this._countTokenCallbacks.get(callId); + const fn = this._runningToolCalls.get(callId); if (!fn) { throw new Error(`Tool invocation call ${callId} not found`); } - return fn(input, token); + return fn.countTokens(input, token); } $registerTool(id: string): void { const disposable = this._languageModelToolsService.registerToolImplementation( id, { - invoke: async (dto, countTokens, _progress, token) => { + invoke: async (dto, countTokens, progress, token) => { try { - this._countTokenCallbacks.set(dto.callId, countTokens); + this._runningToolCalls.set(dto.callId, { countTokens, progress }); const resultDto = await this._proxy.$invokeTool(dto, token); return revive(resultDto) as IToolResult; } finally { - this._countTokenCallbacks.delete(dto.callId); + this._runningToolCalls.delete(dto.callId); } }, prepareToolInvocation: (parameters, token) => this._proxy.$prepareToolInvocation(id, parameters, token), diff --git a/src/vs/workbench/api/common/extHost.protocol.ts b/src/vs/workbench/api/common/extHost.protocol.ts index 3e50b7d456e..9eeabd64abd 100644 --- a/src/vs/workbench/api/common/extHost.protocol.ts +++ b/src/vs/workbench/api/common/extHost.protocol.ts @@ -1380,7 +1380,7 @@ export type IToolDataDto = Omit; export interface MainThreadLanguageModelToolsShape extends IDisposable { $getTools(): Promise[]>; - $acceptToolProgress(requestId: string | undefined, callId: string, progress: IProgressStep): void; + $acceptToolProgress(callId: string, progress: IProgressStep): void; $invokeTool(dto: IToolInvocation, token?: CancellationToken): Promise>; $countTokensForInvocation(callId: string, input: string, token: CancellationToken): Promise; $registerTool(id: string): void; diff --git a/src/vs/workbench/api/common/extHostLanguageModelTools.ts b/src/vs/workbench/api/common/extHostLanguageModelTools.ts index 82de8a39f02..faaa208029e 100644 --- a/src/vs/workbench/api/common/extHostLanguageModelTools.ts +++ b/src/vs/workbench/api/common/extHostLanguageModelTools.ts @@ -138,18 +138,20 @@ export class ExtHostLanguageModelTools implements ExtHostLanguageModelToolsShape }; } - const progress: vscode.Progress<{ message?: string; increment?: number }> = { - report: value => { - checkProposedApiEnabled(item.extension, 'toolProgress'); - this._proxy.$acceptToolProgress(dto.chatRequestId, dto.callId, { - message: value.message, - increment: value.increment, - total: 100, - }); - } - }; + let progress: vscode.Progress<{ message?: string; increment?: number }> | undefined; + if (isProposedApiEnabled(item.extension, 'toolProgress')) { + progress = { + report: value => { + this._proxy.$acceptToolProgress(dto.callId, { + message: value.message, + increment: value.increment, + total: 100, + }); + } + }; + } - const extensionResult = await raceCancellation(Promise.resolve(item.tool.invoke(options, token, progress)), token); + const extensionResult = await raceCancellation(Promise.resolve(item.tool.invoke(options, token, progress!)), token); if (!extensionResult) { throw new CancellationError(); } diff --git a/src/vs/workbench/contrib/chat/browser/chatContentParts/chatToolInvocationPart.ts b/src/vs/workbench/contrib/chat/browser/chatContentParts/chatToolInvocationPart.ts index 20ca1f46cdb..24a8d787794 100644 --- a/src/vs/workbench/contrib/chat/browser/chatContentParts/chatToolInvocationPart.ts +++ b/src/vs/workbench/contrib/chat/browser/chatContentParts/chatToolInvocationPart.ts @@ -23,7 +23,6 @@ import { IContextKeyService } from '../../../../../platform/contextkey/common/co import { IInstantiationService } from '../../../../../platform/instantiation/common/instantiation.js'; import { IKeybindingService } from '../../../../../platform/keybinding/common/keybinding.js'; import { IMarkerData, IMarkerService, MarkerSeverity } from '../../../../../platform/markers/common/markers.js'; -import { IProgressService } from '../../../../../platform/progress/common/progress.js'; import { ChatContextKeys } from '../../common/chatContextKeys.js'; import { IChatMarkdownContent, IChatProgressMessage, IChatTerminalToolInvocationData, IChatToolInvocation, IChatToolInvocationSerialized } from '../../common/chatService.js'; import { IChatRendererContent } from '../../common/chatViewModel.js'; @@ -66,7 +65,6 @@ export class ChatToolInvocationPart extends Disposable implements IChatContentPa codeBlockModelCollection: CodeBlockModelCollection, codeBlockStartIndex: number, @IInstantiationService instantiationService: IInstantiationService, - @IProgressService progressService: IProgressService, ) { super(); diff --git a/src/vs/workbench/contrib/chat/browser/chatSetup.ts b/src/vs/workbench/contrib/chat/browser/chatSetup.ts index 7b15fa89f3f..2fb469b94f2 100644 --- a/src/vs/workbench/contrib/chat/browser/chatSetup.ts +++ b/src/vs/workbench/contrib/chat/browser/chatSetup.ts @@ -37,7 +37,7 @@ import { ILogService } from '../../../../platform/log/common/log.js'; import { IOpenerService } from '../../../../platform/opener/common/opener.js'; import product from '../../../../platform/product/common/product.js'; import { IProductService } from '../../../../platform/product/common/productService.js'; -import { IProgressService, ProgressLocation } from '../../../../platform/progress/common/progress.js'; +import { IProgress, IProgressService, IProgressStep, ProgressLocation } from '../../../../platform/progress/common/progress.js'; import { IQuickInputService } from '../../../../platform/quickinput/common/quickInput.js'; import { Registry } from '../../../../platform/registry/common/platform.js'; import { ITelemetryService, TelemetryLevel } from '../../../../platform/telemetry/common/telemetry.js'; @@ -518,7 +518,7 @@ class SetupTool extends Disposable implements IToolImpl { super(); } - invoke(invocation: IToolInvocation, countTokens: CountTokensCallback, token: CancellationToken): Promise { + invoke(invocation: IToolInvocation, countTokens: CountTokensCallback, progress: IProgress, token: CancellationToken): Promise { const result: IToolResult = { content: [ { diff --git a/src/vs/workbench/contrib/chat/browser/languageModelToolsService.ts b/src/vs/workbench/contrib/chat/browser/languageModelToolsService.ts index 4a6b02c24b8..b21654e4d9e 100644 --- a/src/vs/workbench/contrib/chat/browser/languageModelToolsService.ts +++ b/src/vs/workbench/contrib/chat/browser/languageModelToolsService.ts @@ -22,7 +22,6 @@ import { IDialogService } from '../../../../platform/dialogs/common/dialogs.js'; import { IInstantiationService } from '../../../../platform/instantiation/common/instantiation.js'; import * as JSONContributionRegistry from '../../../../platform/jsonschemas/common/jsonContributionRegistry.js'; import { ILogService } from '../../../../platform/log/common/log.js'; -import { IProgressStep } from '../../../../platform/progress/common/progress.js'; import { Registry } from '../../../../platform/registry/common/platform.js'; import { IStorageService, StorageScope, StorageTarget } from '../../../../platform/storage/common/storage.js'; import { ITelemetryService } from '../../../../platform/telemetry/common/telemetry.js'; @@ -41,7 +40,7 @@ interface IToolEntry { impl?: IToolImpl; } -interface TrackedCall { +interface ITrackedCall { invocation?: ChatToolInvocation; store: IDisposable; } @@ -59,7 +58,7 @@ export class LanguageModelToolsService extends Disposable implements ILanguageMo private _toolContextKeys = new Set(); private readonly _ctxToolsCount: IContextKey; - private _callsByRequestId = new Map(); + private _callsByRequestId = new Map(); private _workspaceToolConfirmStore: Lazy; private _profileToolConfirmStore: Lazy; @@ -97,17 +96,6 @@ export class LanguageModelToolsService extends Disposable implements ILanguageMo this._ctxToolsCount = ChatContextKeys.Tools.toolsCount.bindTo(_contextKeyService); } - acceptProgress(sessionId: string | undefined, callId: string, progress: IProgressStep): void { - if (!sessionId) { - return; // not supported, yet - } - - this._callsByRequestId.get(sessionId) - ?.find(call => call.invocation?.toolCallId === callId) - ?.invocation - ?.acceptProgress(progress); - } - registerToolData(toolData: IToolData): IDisposable { if (this._tools.has(toolData.id)) { throw new Error(`Tool "${toolData.id}" is already registered.`); @@ -252,7 +240,7 @@ export class LanguageModelToolsService extends Disposable implements ILanguageMo if (!this._callsByRequestId.has(requestId)) { this._callsByRequestId.set(requestId, []); } - const trackedCall: TrackedCall = { store }; + const trackedCall: ITrackedCall = { store }; this._callsByRequestId.get(requestId)!.push(trackedCall); const source = new CancellationTokenSource(); diff --git a/src/vs/workbench/contrib/chat/common/languageModelToolsService.ts b/src/vs/workbench/contrib/chat/common/languageModelToolsService.ts index f78950c320f..75332b21169 100644 --- a/src/vs/workbench/contrib/chat/common/languageModelToolsService.ts +++ b/src/vs/workbench/contrib/chat/common/languageModelToolsService.ts @@ -152,7 +152,6 @@ export interface ILanguageModelToolsService { getTool(id: string): IToolData | undefined; getToolByName(name: string): IToolData | undefined; invokeTool(invocation: IToolInvocation, countTokens: CountTokensCallback, token: CancellationToken): Promise; - acceptProgress(sessionId: string | undefined, callId: string, progress: IProgressStep): void; setToolAutoConfirmation(toolId: string, scope: 'workspace' | 'profile' | 'memory', autoConfirm?: boolean): void; resetToolAutoConfirmation(): void; cancelToolCallsForRequest(requestId: string): void; diff --git a/src/vscode-dts/vscode.proposed.toolProgress.d.ts b/src/vscode-dts/vscode.proposed.toolProgress.d.ts index 8deea607747..51562cd6ee3 100644 --- a/src/vscode-dts/vscode.proposed.toolProgress.d.ts +++ b/src/vscode-dts/vscode.proposed.toolProgress.d.ts @@ -6,9 +6,9 @@ declare module 'vscode' { /** - * todo@connor4312: `vscode.window.withProgres` can take this interface as well. + * A progress update during an {@link LanguageModelTool.invoke} call. */ - export interface ProgressStep { + export interface ToolProgressStep { /** * A progress message that represents a chunk of work */ @@ -20,6 +20,6 @@ declare module 'vscode' { } export interface LanguageModelTool { - invoke(options: LanguageModelToolInvocationOptions, token: CancellationToken, progress: Progress): ProviderResult; + invoke(options: LanguageModelToolInvocationOptions, token: CancellationToken, progress: Progress): ProviderResult; } }