diff --git a/src/vs/workbench/api/browser/mainThreadLanguageModels.ts b/src/vs/workbench/api/browser/mainThreadLanguageModels.ts index 34e79d3282e..f7da7c62939 100644 --- a/src/vs/workbench/api/browser/mainThreadLanguageModels.ts +++ b/src/vs/workbench/api/browser/mainThreadLanguageModels.ts @@ -103,11 +103,11 @@ export class MainThreadLanguageModels implements MainThreadLanguageModelsShape { this._lmProviderChange.fire({ vendor }); } - async $reportResponsePart(requestId: number, chunk: IChatResponseFragment | IChatResponseFragment[]): Promise { + async $reportResponsePart(requestId: number, chunk: SerializableObjectWithBuffers): Promise { const data = this._pendingProgress.get(requestId); this._logService.trace('[LM] report response PART', Boolean(data), requestId, chunk); if (data) { - data.stream.emitOne(chunk); + data.stream.emitOne(chunk.value); } } @@ -154,7 +154,7 @@ export class MainThreadLanguageModels implements MainThreadLanguageModelsShape { try { for await (const part of response.stream) { this._logService.trace('[CHAT] request PART', extension.value, requestId, part); - await this._proxy.$acceptResponsePart(requestId, part); + await this._proxy.$acceptResponsePart(requestId, new SerializableObjectWithBuffers(part)); } this._logService.trace('[CHAT] request DONE', extension.value, requestId); } catch (err) { diff --git a/src/vs/workbench/api/common/extHost.protocol.ts b/src/vs/workbench/api/common/extHost.protocol.ts index e4db6637d56..6dde9438d86 100644 --- a/src/vs/workbench/api/common/extHost.protocol.ts +++ b/src/vs/workbench/api/common/extHost.protocol.ts @@ -1262,7 +1262,7 @@ export interface MainThreadLanguageModelsShape extends IDisposable { $onLMProviderChange(vendor: string): void; $unregisterProvider(vendor: string): void; $tryStartChatRequest(extension: ExtensionIdentifier, modelIdentifier: string, requestId: number, messages: SerializableObjectWithBuffers, options: {}, token: CancellationToken): Promise; - $reportResponsePart(requestId: number, chunk: IChatResponseFragment | IChatResponseFragment[]): Promise; + $reportResponsePart(requestId: number, chunk: SerializableObjectWithBuffers): Promise; $reportResponseDone(requestId: number, error: SerializedError | undefined): Promise; $selectChatModels(selector: ILanguageModelChatSelector): Promise; $countTokens(modelId: string, value: string | IChatMessage, token: CancellationToken): Promise; @@ -1275,7 +1275,7 @@ export interface ExtHostLanguageModelsShape { $prepareLanguageModelProvider(vendor: string, options: { silent: boolean }, token: CancellationToken): Promise; $updateModelAccesslist(data: { from: ExtensionIdentifier; to: ExtensionIdentifier; enabled: boolean }[]): void; $startChatRequest(modelId: string, requestId: number, from: ExtensionIdentifier, messages: SerializableObjectWithBuffers, options: { [name: string]: any }, token: CancellationToken): Promise; - $acceptResponsePart(requestId: number, chunk: IChatResponseFragment | IChatResponseFragment[]): Promise; + $acceptResponsePart(requestId: number, chunk: SerializableObjectWithBuffers): Promise; $acceptResponseDone(requestId: number, error: SerializedError | undefined): Promise; $provideTokenLength(modelId: string, value: string | IChatMessage, token: CancellationToken): Promise; $isFileIgnored(handle: number, uri: UriComponents, token: CancellationToken): Promise; diff --git a/src/vs/workbench/api/common/extHostLanguageModels.ts b/src/vs/workbench/api/common/extHostLanguageModels.ts index 4e886255ddd..4b8863a58d0 100644 --- a/src/vs/workbench/api/common/extHostLanguageModels.ts +++ b/src/vs/workbench/api/common/extHostLanguageModels.ts @@ -5,6 +5,7 @@ import type * as vscode from 'vscode'; import { AsyncIterableObject, AsyncIterableSource, RunOnceScheduler } from '../../../base/common/async.js'; +import { VSBuffer } from '../../../base/common/buffer.js'; import { CancellationToken } from '../../../base/common/cancellation.js'; import { SerializedError, transformErrorForSerialization, transformErrorFromSerialization } from '../../../base/common/errors.js'; import { Emitter, Event } from '../../../base/common/event.js'; @@ -16,17 +17,16 @@ import { ExtensionIdentifier, ExtensionIdentifierMap, ExtensionIdentifierSet, IE import { createDecorator } from '../../../platform/instantiation/common/instantiation.js'; import { ILogService } from '../../../platform/log/common/log.js'; import { Progress } from '../../../platform/progress/common/progress.js'; -import { ChatImageMimeType, IChatMessage, IChatResponseFragment, IChatResponsePart, ILanguageModelChatMetadata, ILanguageModelChatMetadataAndIdentifier } from '../../contrib/chat/common/languageModels.js'; +import { IChatMessage, IChatResponseFragment, IChatResponsePart, ILanguageModelChatMetadata, ILanguageModelChatMetadataAndIdentifier } from '../../contrib/chat/common/languageModels.js'; +import { DEFAULT_MODEL_PICKER_CATEGORY } from '../../contrib/chat/common/modelPicker/modelPickerWidget.js'; import { INTERNAL_AUTH_PROVIDER_PREFIX } from '../../services/authentication/common/authentication.js'; import { checkProposedApiEnabled } from '../../services/extensions/common/extensions.js'; +import { SerializableObjectWithBuffers } from '../../services/extensions/common/proxyIdentifier.js'; import { ExtHostLanguageModelsShape, MainContext, MainThreadLanguageModelsShape } from './extHost.protocol.js'; import { IExtHostAuthentication } from './extHostAuthentication.js'; import { IExtHostRpcService } from './extHostRpcService.js'; import * as typeConvert from './extHostTypeConverters.js'; import * as extHostTypes from './extHostTypes.js'; -import { SerializableObjectWithBuffers } from '../../services/extensions/common/proxyIdentifier.js'; -import { VSBuffer } from '../../../base/common/buffer.js'; -import { DEFAULT_MODEL_PICKER_CATEGORY } from '../../contrib/chat/common/modelPicker/modelPickerWidget.js'; export interface IExtHostLanguageModels extends ExtHostLanguageModels { } @@ -38,15 +38,17 @@ type LanguageModelProviderData = { readonly provider: vscode.LanguageModelChatProvider2; }; +type LMResponsePart = vscode.LanguageModelTextPart | vscode.LanguageModelToolCallPart | vscode.LanguageModelDataPart; + class LanguageModelResponseStream { - readonly stream = new AsyncIterableSource(); + readonly stream = new AsyncIterableSource(); constructor( readonly option: number, - stream?: AsyncIterableSource + stream?: AsyncIterableSource ) { - this.stream = stream ?? new AsyncIterableSource(); + this.stream = stream ?? new AsyncIterableSource(); } } @@ -55,7 +57,7 @@ class LanguageModelResponse { readonly apiObject: vscode.LanguageModelChatResponse; private readonly _responseStreams = new Map(); - private readonly _defaultStream = new AsyncIterableSource(); + private readonly _defaultStream = new AsyncIterableSource(); private _isDone: boolean = false; constructor() { @@ -93,15 +95,15 @@ class LanguageModelResponse { return; } - const partsByIndex = new Map(); + const partsByIndex = new Map(); for (const fragment of Iterable.wrap(fragments)) { - let out: vscode.LanguageModelTextPart | vscode.LanguageModelToolCallPart; + let out: LMResponsePart; if (fragment.part.type === 'text') { out = new extHostTypes.LanguageModelTextPart(fragment.part.value, fragment.part.audience); } else if (fragment.part.type === 'data') { - out = new extHostTypes.LanguageModelTextPart(''); + out = new extHostTypes.LanguageModelDataPart(fragment.part.data.buffer, fragment.part.mimeType, fragment.part.audience); } else { out = new extHostTypes.LanguageModelToolCallPart(fragment.part.toolCallId, fragment.part.name, fragment.part.parameters); } @@ -270,7 +272,7 @@ export class ExtHostLanguageModels implements ExtHostLanguageModelsShape { const queue: IChatResponseFragment[] = []; const sendNow = () => { if (queue.length > 0) { - this._proxy.$reportResponsePart(requestId, queue); + this._proxy.$reportResponsePart(requestId, new SerializableObjectWithBuffers(queue)); queue.length = 0; } }; @@ -298,7 +300,7 @@ export class ExtHostLanguageModels implements ExtHostLanguageModelsShape { } else if (fragment.part instanceof extHostTypes.LanguageModelTextPart) { part = { type: 'text', value: fragment.part.value, audience: fragment.part.audience }; } else if (fragment.part instanceof extHostTypes.LanguageModelDataPart) { - part = { type: 'data', value: { mimeType: fragment.part.mimeType as ChatImageMimeType, data: VSBuffer.wrap(fragment.part.data) }, audience: fragment.part.audience }; + part = { type: 'data', mimeType: fragment.part.mimeType, data: VSBuffer.wrap(fragment.part.data), audience: fragment.part.audience }; } if (!part) { @@ -482,10 +484,10 @@ export class ExtHostLanguageModels implements ExtHostLanguageModelsShape { return internalMessages; } - async $acceptResponsePart(requestId: number, chunk: IChatResponseFragment | IChatResponseFragment[]): Promise { + async $acceptResponsePart(requestId: number, chunk: SerializableObjectWithBuffers): Promise { const data = this._pendingRequest.get(requestId); if (data) { - data.res.handleFragment(chunk); + data.res.handleFragment(chunk.value); } } diff --git a/src/vs/workbench/api/common/extHostTypeConverters.ts b/src/vs/workbench/api/common/extHostTypeConverters.ts index 2166c0da478..ea38c9a83ad 100644 --- a/src/vs/workbench/api/common/extHostTypeConverters.ts +++ b/src/vs/workbench/api/common/extHostTypeConverters.ts @@ -2318,13 +2318,15 @@ export namespace LanguageModelChatMessage { if (c.type === 'text') { return new LanguageModelTextPart(c.value, c.audience); } else if (c.type === 'tool_result') { - const content: (LanguageModelTextPart | LanguageModelPromptTsxPart)[] = c.value.map(part => { + const content: (LanguageModelTextPart | LanguageModelPromptTsxPart)[] = coalesce(c.value.map(part => { if (part.type === 'text') { return new types.LanguageModelTextPart(part.value, part.audience); - } else { + } else if (part.type === 'prompt_tsx') { return new types.LanguageModelPromptTsxPart(part.value); + } else { + return undefined; // Strip unknown parts } - }); + })); return new types.LanguageModelToolResultPart(c.toolCallId, content, c.isError); } else if (c.type === 'image_url') { // Non-stable types @@ -2418,7 +2420,7 @@ export namespace LanguageModelChatMessage2 { if (part.type === 'text') { return new types.LanguageModelTextPart(part.value, part.audience); } else if (part.type === 'data') { - return new types.LanguageModelDataPart(part.value.data.buffer, part.value.mimeType); + return new types.LanguageModelDataPart(part.data.buffer, part.mimeType); } else { return new types.LanguageModelPromptTsxPart(part.value); } @@ -2467,10 +2469,8 @@ export namespace LanguageModelChatMessage2 { } else if (part instanceof types.LanguageModelDataPart) { return { type: 'data', - value: { - mimeType: part.mimeType as chatProvider.ChatImageMimeType, - data: VSBuffer.wrap(part.data) - }, + mimeType: part.mimeType, + data: VSBuffer.wrap(part.data), audience: part.audience } satisfies IChatResponseDataPart; } else { diff --git a/src/vs/workbench/contrib/chat/common/languageModels.ts b/src/vs/workbench/contrib/chat/common/languageModels.ts index 2538c1efc41..92bf0570c41 100644 --- a/src/vs/workbench/contrib/chat/common/languageModels.ts +++ b/src/vs/workbench/contrib/chat/common/languageModels.ts @@ -111,7 +111,8 @@ export interface IChatResponsePromptTsxPart { export interface IChatResponseDataPart { type: 'data'; - value: IChatImageURLPart; + mimeType: string; + data: VSBuffer; audience?: LanguageModelPartAudience[]; } diff --git a/src/vscode-dts/vscode.proposed.chatProvider.d.ts b/src/vscode-dts/vscode.proposed.chatProvider.d.ts index 18e53e0a1cf..c6a58e900e9 100644 --- a/src/vscode-dts/vscode.proposed.chatProvider.d.ts +++ b/src/vscode-dts/vscode.proposed.chatProvider.d.ts @@ -139,6 +139,6 @@ declare module 'vscode' { export interface ChatResponseFragment2 { index: number; - part: LanguageModelTextPart | LanguageModelToolCallPart; + part: LanguageModelTextPart | LanguageModelToolCallPart | LanguageModelDataPart; } }