diff --git a/src/vs/workbench/contrib/chat/browser/widget/input/chatInputPart.ts b/src/vs/workbench/contrib/chat/browser/widget/input/chatInputPart.ts index 27a57bc5aa1..2a8af368253 100644 --- a/src/vs/workbench/contrib/chat/browser/widget/input/chatInputPart.ts +++ b/src/vs/workbench/contrib/chat/browser/widget/input/chatInputPart.ts @@ -624,9 +624,9 @@ export class ChatInputPart extends Disposable implements IHistoryNavigationWidge this._register(autorun(r => { const mode = this._currentModeObservable.read(r); this.chatModeKindKey.set(mode.kind); - const model = mode.model?.read(r); - if (model) { - this.switchModelByQualifiedName(model); + const models = mode.model?.read(r); + if (models) { + this.switchModelByQualifiedName(models); } })); @@ -722,17 +722,20 @@ export class ChatInputPart extends Disposable implements IHistoryNavigationWidge } } - public switchModelByQualifiedName(qualifiedModelName: string): boolean { + public switchModelByQualifiedName(qualifiedModelNames: readonly string[]): boolean { const models = this.getModels(); - const model = models.find(m => ILanguageModelChatMetadata.matchesQualifiedName(qualifiedModelName, m.metadata)); - if (model) { - this.setCurrentLanguageModel(model); - return true; + for (const qualifiedModelName of qualifiedModelNames) { + const model = models.find(m => ILanguageModelChatMetadata.matchesQualifiedName(qualifiedModelName, m.metadata)); + if (model) { + this.setCurrentLanguageModel(model); + return true; + } } - this.logService.warn(`[chat] Model "${qualifiedModelName}" not found. Use format " ()", e.g. "GPT-4o (copilot)".`); + this.logService.warn(`[chat] Node of the models "${qualifiedModelNames.join(', ')}" not found. Use format " ()", e.g. "GPT-4o (copilot)".`); return false; } + public switchToNextModel(): void { const models = this.getModels(); if (models.length > 0) { diff --git a/src/vs/workbench/contrib/chat/common/chatModes.ts b/src/vs/workbench/contrib/chat/common/chatModes.ts index e41fa96ae97..3199f25b9af 100644 --- a/src/vs/workbench/contrib/chat/common/chatModes.ts +++ b/src/vs/workbench/contrib/chat/common/chatModes.ts @@ -23,6 +23,7 @@ import { IHandOff } from './promptSyntax/promptFileParser.js'; import { ExtensionAgentSourceType, IAgentSource, ICustomAgent, IPromptsService, PromptsStorage } from './promptSyntax/service/promptsService.js'; import { ThemeIcon } from '../../../../base/common/themables.js'; import { Codicon } from '../../../../base/common/codicons.js'; +import { isString } from '../../../../base/common/types.js'; export const IChatModeService = createDecorator('chatModeService'); export interface IChatModeService { @@ -116,7 +117,7 @@ export class ChatModeService extends Disposable implements IChatModeService { name: cachedMode.name, description: cachedMode.description, tools: cachedMode.customTools, - model: cachedMode.model, + model: isString(cachedMode.model) ? [cachedMode.model] : cachedMode.model, argumentHint: cachedMode.argumentHint, agentInstructions: cachedMode.modeInstructions ?? { content: cachedMode.body ?? '', toolReferences: [] }, handOffs: cachedMode.handOffs, @@ -236,7 +237,7 @@ export interface IChatModeData { readonly description?: string; readonly kind: ChatModeKind; readonly customTools?: readonly string[]; - readonly model?: string; + readonly model?: readonly string[] | string; readonly argumentHint?: string; readonly modeInstructions?: IChatModeInstructions; readonly body?: string; /* deprecated */ @@ -258,7 +259,7 @@ export interface IChatMode { readonly kind: ChatModeKind; readonly customTools?: IObservable; readonly handOffs?: IObservable; - readonly model?: IObservable; + readonly model?: IObservable; readonly argumentHint?: IObservable; readonly modeInstructions?: IObservable; readonly uri?: IObservable; @@ -291,7 +292,7 @@ function isCachedChatModeData(data: unknown): data is IChatModeData { (mode.description === undefined || typeof mode.description === 'string') && (mode.customTools === undefined || Array.isArray(mode.customTools)) && (mode.modeInstructions === undefined || (typeof mode.modeInstructions === 'object' && mode.modeInstructions !== null)) && - (mode.model === undefined || typeof mode.model === 'string') && + (mode.model === undefined || typeof mode.model === 'string' || Array.isArray(mode.model)) && (mode.argumentHint === undefined || typeof mode.argumentHint === 'string') && (mode.handOffs === undefined || Array.isArray(mode.handOffs)) && (mode.uri === undefined || (typeof mode.uri === 'object' && mode.uri !== null)) && @@ -307,7 +308,7 @@ export class CustomChatMode implements IChatMode { private readonly _customToolsObservable: ISettableObservable; private readonly _modeInstructions: ISettableObservable; private readonly _uriObservable: ISettableObservable; - private readonly _modelObservable: ISettableObservable; + private readonly _modelObservable: ISettableObservable; private readonly _argumentHintObservable: ISettableObservable; private readonly _handoffsObservable: ISettableObservable; private readonly _targetObservable: ISettableObservable; @@ -337,7 +338,7 @@ export class CustomChatMode implements IChatMode { return this._customToolsObservable; } - get model(): IObservable { + get model(): IObservable { return this._modelObservable; } diff --git a/src/vs/workbench/contrib/chat/common/languageModels.ts b/src/vs/workbench/contrib/chat/common/languageModels.ts index f1f0ad87618..e6a883096fe 100644 --- a/src/vs/workbench/contrib/chat/common/languageModels.ts +++ b/src/vs/workbench/contrib/chat/common/languageModels.ts @@ -297,6 +297,11 @@ export interface ILanguageModelsService { lookupLanguageModel(modelId: string): ILanguageModelChatMetadata | undefined; + /** + * Find a model by its qualified name. The qualified name is what is used in prompt and agent files and is in the format "Model Name (Vendor)". + */ + lookupLanguageModelByQualifiedName(qualifiedName: string): ILanguageModelChatMetadata | undefined; + getLanguageModelGroups(vendor: string): ILanguageModelsGroup[]; /** @@ -637,6 +642,15 @@ export class LanguageModelsService implements ILanguageModelsService { return model; } + lookupLanguageModelByQualifiedName(referenceName: string): ILanguageModelChatMetadata | undefined { + for (const model of this._modelCache.values()) { + if (ILanguageModelChatMetadata.matchesQualifiedName(referenceName, model)) { + return model; + } + } + return undefined; + } + private async _resolveAllLanguageModels(vendorId: string, silent: boolean): Promise { const vendor = this._vendors.get(vendorId); diff --git a/src/vs/workbench/contrib/chat/common/promptSyntax/languageProviders/promptHeaderAutocompletion.ts b/src/vs/workbench/contrib/chat/common/promptSyntax/languageProviders/promptHeaderAutocompletion.ts index 5cabc3ec5ed..798f943d5b9 100644 --- a/src/vs/workbench/contrib/chat/common/promptSyntax/languageProviders/promptHeaderAutocompletion.ts +++ b/src/vs/workbench/contrib/chat/common/promptSyntax/languageProviders/promptHeaderAutocompletion.ts @@ -158,6 +158,13 @@ export class PromptHeaderAutocompletion implements CompletionItemProvider { } if (promptType === PromptsType.prompt || promptType === PromptsType.agent) { + if (attribute.key === PromptHeaderAttributes.model) { + if (attribute.value.type === 'array') { + // if the position is inside the tools metadata, we provide tool name completions + const getValues = async () => this.getModelNames(promptType === PromptsType.agent); + return this.provideArrayCompletions(model, position, attribute, getValues); + } + } if (attribute.key === PromptHeaderAttributes.tools) { if (attribute.value.type === 'array') { // if the position is inside the tools metadata, we provide tool name completions diff --git a/src/vs/workbench/contrib/chat/common/promptSyntax/languageProviders/promptHovers.ts b/src/vs/workbench/contrib/chat/common/promptSyntax/languageProviders/promptHovers.ts index 2a2826a0407..a4a4d7f0d1b 100644 --- a/src/vs/workbench/contrib/chat/common/promptSyntax/languageProviders/promptHovers.ts +++ b/src/vs/workbench/contrib/chat/common/promptSyntax/languageProviders/promptHovers.ts @@ -10,7 +10,7 @@ import { Range } from '../../../../../../editor/common/core/range.js'; import { Hover, HoverContext, HoverProvider } from '../../../../../../editor/common/languages.js'; import { ITextModel } from '../../../../../../editor/common/model.js'; import { localize } from '../../../../../../nls.js'; -import { ILanguageModelChatMetadata, ILanguageModelsService } from '../../languageModels.js'; +import { ILanguageModelsService } from '../../languageModels.js'; import { ILanguageModelToolsService, isToolSet, IToolSet } from '../../tools/languageModelToolsService.js'; import { IChatModeService, isBuiltinChatMode } from '../../chatModes.js'; import { getPromptsTypeForLanguageId, PromptsType } from '../promptTypes.js'; @@ -106,7 +106,7 @@ export class PromptHoverProvider implements HoverProvider { case PromptHeaderAttributes.argumentHint: return this.createHover(localize('promptHeader.agent.argumentHint', 'The argument-hint describes what inputs the custom agent expects or supports.'), attribute.range); case PromptHeaderAttributes.model: - return this.getModelHover(attribute, attribute.range, localize('promptHeader.agent.model', 'Specify the model that runs this custom agent.'), isGithubTarget(promptType, header.target)); + return this.getModelHover(attribute, position, localize('promptHeader.agent.model', 'Specify the model that runs this custom agent. Can also be a list of models. The first available model will be used.'), isGithubTarget(promptType, header.target)); case PromptHeaderAttributes.tools: return this.getToolHover(attribute, position, localize('promptHeader.agent.tools', 'The set of tools that the custom agent has access to.')); case PromptHeaderAttributes.handOffs: @@ -132,7 +132,7 @@ export class PromptHoverProvider implements HoverProvider { case PromptHeaderAttributes.argumentHint: return this.createHover(localize('promptHeader.prompt.argumentHint', 'The argument-hint describes what inputs the prompt expects or supports.'), attribute.range); case PromptHeaderAttributes.model: - return this.getModelHover(attribute, attribute.range, localize('promptHeader.prompt.model', 'The model to use in this prompt.'), false); + return this.getModelHover(attribute, position, localize('promptHeader.prompt.model', 'The model to use in this prompt. Can also be a list of models. The first available model will be used.'), false); case PromptHeaderAttributes.tools: return this.getToolHover(attribute, position, localize('promptHeader.prompt.tools', 'The tools to use in this prompt.')); case PromptHeaderAttributes.agent: @@ -184,27 +184,41 @@ export class PromptHoverProvider implements HoverProvider { return this.createHover(lines.join('\n'), range); } - private getModelHover(node: IHeaderAttribute, range: Range, baseMessage: string, isGitHubTarget: boolean): Hover | undefined { + private getModelHover(node: IHeaderAttribute, position: Position, baseMessage: string, isGitHubTarget: boolean): Hover | undefined { if (isGitHubTarget) { - return this.createHover(baseMessage + '\n\n' + localize('promptHeader.agent.model.githubCopilot', 'Note: This attribute is not used when target is github-copilot.'), range); + return this.createHover(baseMessage + '\n\n' + localize('promptHeader.agent.model.githubCopilot', 'Note: This attribute is not used when target is github-copilot.'), node.range); } + const modelHoverContent = (modelName: string): Hover | undefined => { + const meta = this.languageModelsService.lookupLanguageModelByQualifiedName(modelName); + if (meta) { + const lines: string[] = []; + lines.push(baseMessage + '\n'); + lines.push(localize('modelName', '- Name: {0}', meta.name)); + lines.push(localize('modelFamily', '- Family: {0}', meta.family)); + lines.push(localize('modelVendor', '- Vendor: {0}', meta.vendor)); + if (meta.tooltip) { + lines.push('', '', meta.tooltip); + } + return this.createHover(lines.join('\n'), node.range); + } + return undefined; + }; if (node.value.type === 'string') { - for (const id of this.languageModelsService.getLanguageModelIds()) { - const meta = this.languageModelsService.lookupLanguageModel(id); - if (meta && ILanguageModelChatMetadata.matchesQualifiedName(node.value.value, meta)) { - const lines: string[] = []; - lines.push(baseMessage + '\n'); - lines.push(localize('modelName', '- Name: {0}', meta.name)); - lines.push(localize('modelFamily', '- Family: {0}', meta.family)); - lines.push(localize('modelVendor', '- Vendor: {0}', meta.vendor)); - if (meta.tooltip) { - lines.push('', '', meta.tooltip); + const hover = modelHoverContent(node.value.value); + if (hover) { + return hover; + } + } else if (node.value.type === 'array') { + for (const item of node.value.items) { + if (item.type === 'string' && item.range.containsPosition(position)) { + const hover = modelHoverContent(item.value); + if (hover) { + return hover; } - return this.createHover(lines.join('\n'), range); } } } - return this.createHover(baseMessage, range); + return this.createHover(baseMessage, node.range); } private getAgentHover(agentAttribute: IHeaderAttribute, position: Position): Hover | undefined { diff --git a/src/vs/workbench/contrib/chat/common/promptSyntax/languageProviders/promptValidator.ts b/src/vs/workbench/contrib/chat/common/promptSyntax/languageProviders/promptValidator.ts index 19acdca2789..c0fa72ef88c 100644 --- a/src/vs/workbench/contrib/chat/common/promptSyntax/languageProviders/promptValidator.ts +++ b/src/vs/workbench/contrib/chat/common/promptSyntax/languageProviders/promptValidator.ts @@ -282,36 +282,58 @@ export class PromptValidator { if (!attribute) { return; } - if (attribute.value.type !== 'string') { - report(toMarker(localize('promptValidator.modelMustBeString', "The 'model' attribute must be a string."), attribute.value.range, MarkerSeverity.Error)); - return; - } - const modelName = attribute.value.value.trim(); - if (modelName.length === 0) { - report(toMarker(localize('promptValidator.modelMustBeNonEmpty', "The 'model' attribute must be a non-empty string."), attribute.value.range, MarkerSeverity.Error)); + if (attribute.value.type !== 'string' && attribute.value.type !== 'array') { + report(toMarker(localize('promptValidator.modelMustBeStringOrArray', "The 'model' attribute must be a string or an array of strings."), attribute.value.range, MarkerSeverity.Error)); return; } - const languageModes = this.languageModelsService.getLanguageModelIds(); - if (languageModes.length === 0) { + const modelNames: [string, Range][] = []; + if (attribute.value.type === 'string') { + const modelName = attribute.value.value.trim(); + if (modelName.length === 0) { + report(toMarker(localize('promptValidator.modelMustBeNonEmpty', "The 'model' attribute must be a non-empty string."), attribute.value.range, MarkerSeverity.Error)); + return; + } + modelNames.push([modelName, attribute.value.range]); + } else if (attribute.value.type === 'array') { + if (attribute.value.items.length === 0) { + report(toMarker(localize('promptValidator.modelArrayMustNotBeEmpty', "The 'model' array must not be empty."), attribute.value.range, MarkerSeverity.Error)); + return; + } + for (const item of attribute.value.items) { + if (item.type !== 'string') { + report(toMarker(localize('promptValidator.modelArrayMustContainStrings', "The 'model' array must contain only strings."), item.range, MarkerSeverity.Error)); + return; + } + const modelName = item.value.trim(); + if (modelName.length === 0) { + report(toMarker(localize('promptValidator.modelArrayItemMustBeNonEmpty', "Model names in the array must be non-empty strings."), item.range, MarkerSeverity.Error)); + return; + } + modelNames.push([modelName, item.range]); + } + } + + const languageModels = this.languageModelsService.getLanguageModelIds(); + if (languageModels.length === 0) { // likely the service is not initialized yet return; } - const modelMetadata = this.findModelByName(languageModes, modelName); - if (!modelMetadata) { - report(toMarker(localize('promptValidator.modelNotFound', "Unknown model '{0}'.", modelName), attribute.value.range, MarkerSeverity.Warning)); - } else if (agentKind === ChatModeKind.Agent && !ILanguageModelChatMetadata.suitableForAgentMode(modelMetadata)) { - report(toMarker(localize('promptValidator.modelNotSuited', "Model '{0}' is not suited for agent mode.", modelName), attribute.value.range, MarkerSeverity.Warning)); + for (const [modelName, range] of modelNames) { + const modelMetadata = this.findModelByName(modelName); + if (!modelMetadata) { + report(toMarker(localize('promptValidator.modelNotFound', "Unknown model '{0}'.", modelName), range, MarkerSeverity.Warning)); + } else if (agentKind === ChatModeKind.Agent && !ILanguageModelChatMetadata.suitableForAgentMode(modelMetadata)) { + report(toMarker(localize('promptValidator.modelNotSuited', "Model '{0}' is not suited for agent mode.", modelName), range, MarkerSeverity.Warning)); + } } } - private findModelByName(languageModes: string[], modelName: string): ILanguageModelChatMetadata | undefined { - for (const model of languageModes) { - const metadata = this.languageModelsService.lookupLanguageModel(model); - if (metadata && metadata.isUserSelectable !== false && ILanguageModelChatMetadata.matchesQualifiedName(modelName, metadata)) { - return metadata; - } + private findModelByName(modelName: string): ILanguageModelChatMetadata | undefined { + const metadata = this.languageModelsService.lookupLanguageModelByQualifiedName(modelName); + if (metadata && metadata.isUserSelectable !== false) { + return metadata; } return undefined; } diff --git a/src/vs/workbench/contrib/chat/common/promptSyntax/promptFileParser.ts b/src/vs/workbench/contrib/chat/common/promptSyntax/promptFileParser.ts index a68e3ddbcef..eec997aa3ec 100644 --- a/src/vs/workbench/contrib/chat/common/promptSyntax/promptFileParser.ts +++ b/src/vs/workbench/contrib/chat/common/promptSyntax/promptFileParser.ts @@ -184,8 +184,8 @@ export class PromptHeader { return this.getStringAttribute(PromptHeaderAttributes.agent) ?? this.getStringAttribute(PromptHeaderAttributes.mode); } - public get model(): string | undefined { - return this.getStringAttribute(PromptHeaderAttributes.model); + public get model(): readonly string[] | undefined { + return this.getStringOrStringArrayAttribute(PromptHeaderAttributes.model); } public get applyTo(): string | undefined { @@ -294,6 +294,26 @@ export class PromptHeader { return undefined; } + private getStringOrStringArrayAttribute(key: string): readonly string[] | undefined { + const attribute = this._parsedHeader.attributes.find(attr => attr.key === key); + if (!attribute) { + return undefined; + } + if (attribute.value.type === 'string') { + return [attribute.value.value]; + } + if (attribute.value.type === 'array') { + const result: string[] = []; + for (const item of attribute.value.items) { + if (item.type === 'string') { + result.push(item.value); + } + } + return result; + } + return undefined; + } + public get agents(): string[] | undefined { return this.getStringArrayAttribute(PromptHeaderAttributes.agents); } diff --git a/src/vs/workbench/contrib/chat/common/promptSyntax/service/promptsService.ts b/src/vs/workbench/contrib/chat/common/promptSyntax/service/promptsService.ts index 604ea54f149..dbd2ce355a9 100644 --- a/src/vs/workbench/contrib/chat/common/promptSyntax/service/promptsService.ts +++ b/src/vs/workbench/contrib/chat/common/promptSyntax/service/promptsService.ts @@ -139,7 +139,7 @@ export interface ICustomAgent { /** * Model metadata in the prompt header. */ - readonly model?: string; + readonly model?: readonly string[]; /** * Argument hint metadata in the prompt header that describes what inputs the agent expects or supports. diff --git a/src/vs/workbench/contrib/chat/common/tools/builtinTools/runSubagentTool.ts b/src/vs/workbench/contrib/chat/common/tools/builtinTools/runSubagentTool.ts index 646f3d93516..e9245caf859 100644 --- a/src/vs/workbench/contrib/chat/common/tools/builtinTools/runSubagentTool.ts +++ b/src/vs/workbench/contrib/chat/common/tools/builtinTools/runSubagentTool.ts @@ -21,7 +21,7 @@ import { ChatMode, IChatMode, IChatModeService } from '../../chatModes.js'; import { IChatProgress, IChatService } from '../../chatService/chatService.js'; import { ChatRequestVariableSet } from '../../attachments/chatVariableEntries.js'; import { ChatAgentLocation, ChatConfiguration, ChatModeKind } from '../../constants.js'; -import { ILanguageModelChatMetadata, ILanguageModelsService } from '../../languageModels.js'; +import { ILanguageModelsService } from '../../languageModels.js'; import { CountTokensCallback, ILanguageModelToolsService, @@ -146,14 +146,13 @@ export class RunSubagentTool extends Disposable implements IToolImpl { mode = this.chatModeService.findModeByName(args.agentName); if (mode) { // Use mode-specific model if available - const modeModelQualifiedName = mode.model?.get(); - if (modeModelQualifiedName) { - // Find the actual model identifier from the qualified name - const modelIds = this.languageModelsService.getLanguageModelIds(); - for (const modelId of modelIds) { - const metadata = this.languageModelsService.lookupLanguageModel(modelId); - if (metadata && ILanguageModelChatMetadata.matchesQualifiedName(modeModelQualifiedName, metadata)) { - modeModelId = modelId; + const modeModelQualifiedNames = mode.model?.get(); + if (modeModelQualifiedNames) { + // Find the actual model identifier from the qualified name(s) + for (const qualifiedName of modeModelQualifiedNames) { + const metadata = this.languageModelsService.lookupLanguageModelByQualifiedName(qualifiedName); + if (metadata) { + modeModelId = metadata.id; break; } } diff --git a/src/vs/workbench/contrib/chat/test/browser/chatManagement/chatModelsViewModel.test.ts b/src/vs/workbench/contrib/chat/test/browser/chatManagement/chatModelsViewModel.test.ts index 01c1fa05735..322f8893f18 100644 --- a/src/vs/workbench/contrib/chat/test/browser/chatManagement/chatModelsViewModel.test.ts +++ b/src/vs/workbench/contrib/chat/test/browser/chatManagement/chatModelsViewModel.test.ts @@ -82,6 +82,15 @@ class MockLanguageModelsService implements ILanguageModelsService { return this.models.get(identifier); } + lookupLanguageModelByQualifiedName(referenceName: string): ILanguageModelChatMetadata | undefined { + for (const metadata of this.models.values()) { + if (ILanguageModelChatMetadata.matchesQualifiedName(referenceName, metadata)) { + return metadata; + } + } + return undefined; + } + getLanguageModels(): ILanguageModelChatMetadataAndIdentifier[] { const result: ILanguageModelChatMetadataAndIdentifier[] = []; for (const [identifier, metadata] of this.models.entries()) { diff --git a/src/vs/workbench/contrib/chat/test/browser/promptSyntax/languageProviders/promptHeaderAutocompletion.test.ts b/src/vs/workbench/contrib/chat/test/browser/promptSyntax/languageProviders/promptHeaderAutocompletion.test.ts index 1db2b8d62f3..6fa16703d14 100644 --- a/src/vs/workbench/contrib/chat/test/browser/promptSyntax/languageProviders/promptHeaderAutocompletion.test.ts +++ b/src/vs/workbench/contrib/chat/test/browser/promptSyntax/languageProviders/promptHeaderAutocompletion.test.ts @@ -178,6 +178,38 @@ suite('PromptHeaderAutocompletion', () => { ]); }); + test('complete model names inside model array', async () => { + const content = [ + '---', + 'description: "Test"', + 'model: [|]', + '---', + ].join('\n'); + + const actual = await getCompletions(content, PromptsType.agent); + // GPT 4 is excluded because it has agentMode: false + assert.deepStrictEqual(actual.sort(sortByLabel), [ + { label: 'MAE 4 (olama)', result: `model: ['MAE 4 (olama)']` }, + { label: 'MAE 4.1 (copilot)', result: `model: ['MAE 4.1 (copilot)']` }, + ].sort(sortByLabel)); + }); + + test('complete model names inside model array with existing entries', async () => { + const content = [ + '---', + 'description: "Test"', + `model: ['MAE 4 (olama)', |]`, + '---', + ].join('\n'); + + const actual = await getCompletions(content, PromptsType.agent); + // GPT 4 is excluded because it has agentMode: false + assert.deepStrictEqual(actual.sort(sortByLabel), [ + { label: 'MAE 4 (olama)', result: `model: ['MAE 4 (olama)', 'MAE 4 (olama)']` }, + { label: 'MAE 4.1 (copilot)', result: `model: ['MAE 4 (olama)', 'MAE 4.1 (copilot)']` }, + ].sort(sortByLabel)); + }); + test('complete tool names inside tools array', async () => { const content = [ '---', diff --git a/src/vs/workbench/contrib/chat/test/browser/promptSyntax/languageProviders/promptHovers.test.ts b/src/vs/workbench/contrib/chat/test/browser/promptSyntax/languageProviders/promptHovers.test.ts index e1a5a2d29f8..bc7cba4d758 100644 --- a/src/vs/workbench/contrib/chat/test/browser/promptSyntax/languageProviders/promptHovers.test.ts +++ b/src/vs/workbench/contrib/chat/test/browser/promptSyntax/languageProviders/promptHovers.test.ts @@ -59,8 +59,13 @@ suite('PromptHoverProvider', () => { instaService.stub(ILanguageModelsService, { getLanguageModelIds() { return testModels.map(m => m.id); }, - lookupLanguageModel(name: string) { - return testModels.find(m => m.id === name); + lookupLanguageModelByQualifiedName(qualifiedName: string) { + for (const metadata of testModels) { + if (ILanguageModelChatMetadata.matchesQualifiedName(qualifiedName, metadata)) { + return metadata; + } + } + return undefined; } }); @@ -121,7 +126,7 @@ suite('PromptHoverProvider', () => { ].join('\n'); const hover = await getHover(content, 4, 1, PromptsType.agent); const expected = [ - 'Specify the model that runs this custom agent.', + 'Specify the model that runs this custom agent. Can also be a list of models. The first available model will be used.', '', 'Note: This attribute is not used when target is github-copilot.' ].join('\n'); @@ -138,7 +143,7 @@ suite('PromptHoverProvider', () => { ].join('\n'); const hover = await getHover(content, 4, 1, PromptsType.agent); const expected = [ - 'Specify the model that runs this custom agent.', + 'Specify the model that runs this custom agent. Can also be a list of models. The first available model will be used.', '', '- Name: MAE 4', '- Family: mae', @@ -229,6 +234,44 @@ suite('PromptHoverProvider', () => { assert.strictEqual(hover, 'Test Tool 1'); }); + test('hover on model attribute with vscode target and model array', async () => { + const content = [ + '---', + 'description: "Test"', + 'target: vscode', + `model: ['MAE 4 (olama)', 'MAE 4.1 (copilot)']`, + '---', + ].join('\n'); + const hover = await getHover(content, 4, 10, PromptsType.agent); + const expected = [ + 'Specify the model that runs this custom agent. Can also be a list of models. The first available model will be used.', + '', + '- Name: MAE 4', + '- Family: mae', + '- Vendor: olama' + ].join('\n'); + assert.strictEqual(hover, expected); + }); + + test('hover on second model in model array', async () => { + const content = [ + '---', + 'description: "Test"', + 'target: vscode', + `model: ['MAE 4 (olama)', 'MAE 4.1 (copilot)']`, + '---', + ].join('\n'); + const hover = await getHover(content, 4, 30, PromptsType.agent); + const expected = [ + 'Specify the model that runs this custom agent. Can also be a list of models. The first available model will be used.', + '', + '- Name: MAE 4.1', + '- Family: mae', + '- Vendor: copilot' + ].join('\n'); + assert.strictEqual(hover, expected); + }); + test('hover on description attribute', async () => { const content = [ '---', @@ -298,7 +341,7 @@ suite('PromptHoverProvider', () => { ].join('\n'); const hover = await getHover(content, 3, 1, PromptsType.prompt); const expected = [ - 'The model to use in this prompt.', + 'The model to use in this prompt. Can also be a list of models. The first available model will be used.', '', '- Name: MAE 4', '- Family: mae', diff --git a/src/vs/workbench/contrib/chat/test/browser/promptSyntax/languageProviders/promptValidator.test.ts b/src/vs/workbench/contrib/chat/test/browser/promptSyntax/languageProviders/promptValidator.test.ts index c2af631b9a8..2589c128598 100644 --- a/src/vs/workbench/contrib/chat/test/browser/promptSyntax/languageProviders/promptValidator.test.ts +++ b/src/vs/workbench/contrib/chat/test/browser/promptSyntax/languageProviders/promptValidator.test.ts @@ -119,8 +119,13 @@ suite('PromptValidator', () => { instaService.stub(ILanguageModelsService, { getLanguageModelIds() { return testModels.map(m => m.id); }, - lookupLanguageModel(name: string) { - return testModels.find(m => m.id === name); + lookupLanguageModelByQualifiedName(qualifiedName: string) { + for (const metadata of testModels) { + if (ILanguageModelChatMetadata.matchesQualifiedName(qualifiedName, metadata)) { + return metadata; + } + } + return undefined; } }); @@ -199,6 +204,95 @@ suite('PromptValidator', () => { assert.deepStrictEqual(markers.map(m => m.message), [`The 'tools' attribute must be an array.`]); }); + test('model as string array - valid', async () => { + const content = [ + '---', + 'description: "Test with model array"', + `model: ['MAE 4 (olama)', 'MAE 4.1']`, + '---', + ].join('\n'); + const markers = await validate(content, PromptsType.agent); + assert.deepStrictEqual(markers, []); + }); + + test('model as string array - unknown model', async () => { + const content = [ + '---', + 'description: "Test with model array"', + `model: ['MAE 4 (olama)', 'Unknown Model']`, + '---', + ].join('\n'); + const markers = await validate(content, PromptsType.agent); + assert.strictEqual(markers.length, 1); + assert.strictEqual(markers[0].severity, MarkerSeverity.Warning); + assert.strictEqual(markers[0].message, `Unknown model 'Unknown Model'.`); + }); + + test('model as string array - unsuitable model', async () => { + const content = [ + '---', + 'description: "Test with model array"', + `model: ['MAE 4 (olama)', 'MAE 3.5 Turbo']`, + '---', + ].join('\n'); + const markers = await validate(content, PromptsType.agent); + assert.strictEqual(markers.length, 1); + assert.strictEqual(markers[0].severity, MarkerSeverity.Warning); + assert.strictEqual(markers[0].message, `Model 'MAE 3.5 Turbo' is not suited for agent mode.`); + }); + + test('model as string array - empty array', async () => { + const content = [ + '---', + 'description: "Test with empty model array"', + `model: []`, + '---', + ].join('\n'); + const markers = await validate(content, PromptsType.agent); + assert.strictEqual(markers.length, 1); + assert.strictEqual(markers[0].severity, MarkerSeverity.Error); + assert.strictEqual(markers[0].message, `The 'model' array must not be empty.`); + }); + + test('model as string array - non-string item', async () => { + const content = [ + '---', + 'description: "Test with invalid model array"', + `model: ['MAE 4 (olama)', 123]`, + '---', + ].join('\n'); + const markers = await validate(content, PromptsType.agent); + assert.strictEqual(markers.length, 1); + assert.strictEqual(markers[0].severity, MarkerSeverity.Error); + assert.strictEqual(markers[0].message, `The 'model' array must contain only strings.`); + }); + + test('model as string array - empty string item', async () => { + const content = [ + '---', + 'description: "Test with empty string in model array"', + `model: ['MAE 4 (olama)', '']`, + '---', + ].join('\n'); + const markers = await validate(content, PromptsType.agent); + assert.strictEqual(markers.length, 1); + assert.strictEqual(markers[0].severity, MarkerSeverity.Error); + assert.strictEqual(markers[0].message, `Model names in the array must be non-empty strings.`); + }); + + test('model as invalid type', async () => { + const content = [ + '---', + 'description: "Test with invalid model type"', + `model: 123`, + '---', + ].join('\n'); + const markers = await validate(content, PromptsType.agent); + assert.strictEqual(markers.length, 1); + assert.strictEqual(markers[0].severity, MarkerSeverity.Error); + assert.strictEqual(markers[0].message, `The 'model' attribute must be a string or an array of strings.`); + }); + test('each tool must be string', async () => { const content = [ '---', diff --git a/src/vs/workbench/contrib/chat/test/common/chatModeService.test.ts b/src/vs/workbench/contrib/chat/test/common/chatModeService.test.ts index fe448e3fd95..819bbabb1ab 100644 --- a/src/vs/workbench/contrib/chat/test/common/chatModeService.test.ts +++ b/src/vs/workbench/contrib/chat/test/common/chatModeService.test.ts @@ -195,7 +195,7 @@ suite('ChatModeService', () => { description: 'Initial description', tools: ['tool1'], agentInstructions: { content: 'Initial body', toolReferences: [] }, - model: 'gpt-4', + model: ['gpt-4'], source: workspaceSource, }; @@ -212,7 +212,7 @@ suite('ChatModeService', () => { description: 'Updated description', tools: ['tool1', 'tool2'], agentInstructions: { content: 'Updated body', toolReferences: [] }, - model: 'Updated model' + model: ['Updated model'] }; promptsService.setCustomModes([updatedMode]); @@ -228,7 +228,7 @@ suite('ChatModeService', () => { assert.strictEqual(updatedCustomMode.description.get(), 'Updated description'); assert.deepStrictEqual(updatedCustomMode.customTools?.get(), ['tool1', 'tool2']); assert.deepStrictEqual(updatedCustomMode.modeInstructions?.get(), { content: 'Updated body', toolReferences: [] }); - assert.strictEqual(updatedCustomMode.model?.get(), 'Updated model'); + assert.deepStrictEqual(updatedCustomMode.model?.get(), ['Updated model']); assert.deepStrictEqual(updatedCustomMode.source, workspaceSource); }); diff --git a/src/vs/workbench/contrib/chat/test/common/languageModels.ts b/src/vs/workbench/contrib/chat/test/common/languageModels.ts index 35a614a06b8..35b2595109d 100644 --- a/src/vs/workbench/contrib/chat/test/common/languageModels.ts +++ b/src/vs/workbench/contrib/chat/test/common/languageModels.ts @@ -40,6 +40,10 @@ export class NullLanguageModelsService implements ILanguageModelsService { return undefined; } + lookupLanguageModelByQualifiedName(qualifiedName: string) { + return undefined; + } + getLanguageModels(): ILanguageModelChatMetadataAndIdentifier[] { return []; } diff --git a/src/vs/workbench/contrib/chat/test/common/promptSyntax/service/newPromptsParser.test.ts b/src/vs/workbench/contrib/chat/test/common/promptSyntax/service/newPromptsParser.test.ts index b4326f32c2d..0c70640aadb 100644 --- a/src/vs/workbench/contrib/chat/test/common/promptSyntax/service/newPromptsParser.test.ts +++ b/src/vs/workbench/contrib/chat/test/common/promptSyntax/service/newPromptsParser.test.ts @@ -53,7 +53,7 @@ suite('NewPromptsParser', () => { { range: new Range(7, 79, 7, 85), name: 'tool-2', offset: 170 } ]); assert.deepEqual(result.header.description, 'Agent test'); - assert.deepEqual(result.header.model, 'GPT 4.1'); + assert.deepEqual(result.header.model, ['GPT 4.1']); assert.ok(result.header.tools); assert.deepEqual(result.header.tools, ['tool1', 'tool2']); }); @@ -110,7 +110,7 @@ suite('NewPromptsParser', () => { }, ]); assert.deepEqual(result.header.description, 'Agent test'); - assert.deepEqual(result.header.model, 'GPT 4.1'); + assert.deepEqual(result.header.model, ['GPT 4.1']); assert.ok(result.header.handOffs); assert.deepEqual(result.header.handOffs, [ { label: 'Implement', agent: 'Default', prompt: 'Implement the plan', send: false }, @@ -234,7 +234,7 @@ suite('NewPromptsParser', () => { ]); assert.deepEqual(result.header.description, 'General purpose coding assistant'); assert.deepEqual(result.header.agent, 'agent'); - assert.deepEqual(result.header.model, 'GPT 4.1'); + assert.deepEqual(result.header.model, ['GPT 4.1']); assert.ok(result.header.tools); assert.deepEqual(result.header.tools, ['search', 'terminal']); }); diff --git a/src/vs/workbench/contrib/chat/test/common/promptSyntax/service/promptsService.test.ts b/src/vs/workbench/contrib/chat/test/common/promptSyntax/service/promptsService.test.ts index cc9465e83eb..c31071db411 100644 --- a/src/vs/workbench/contrib/chat/test/common/promptSyntax/service/promptsService.test.ts +++ b/src/vs/workbench/contrib/chat/test/common/promptSyntax/service/promptsService.test.ts @@ -994,7 +994,7 @@ suite('PromptsService', () => { name: 'vscode-agent', description: 'VS Code specialized agent.', target: 'vscode', - model: 'gpt-4', + model: ['gpt-4'], agentInstructions: { content: 'I am specialized for VS Code editor tasks.', toolReferences: [], diff --git a/src/vs/workbench/contrib/inlineChat/browser/inlineChatController.ts b/src/vs/workbench/contrib/inlineChat/browser/inlineChatController.ts index 8ffc6fbb64e..8e1b2cb7373 100644 --- a/src/vs/workbench/contrib/inlineChat/browser/inlineChatController.ts +++ b/src/vs/workbench/contrib/inlineChat/browser/inlineChatController.ts @@ -474,7 +474,7 @@ export class InlineChatController implements IEditorContribution { // Check for default model setting const defaultModelSetting = this._configurationService.getValue(InlineChatConfigKeys.DefaultModel); - if (defaultModelSetting && !this._zone.value.widget.chatWidget.input.switchModelByQualifiedName(defaultModelSetting)) { + if (defaultModelSetting && !this._zone.value.widget.chatWidget.input.switchModelByQualifiedName([defaultModelSetting])) { this._logService.warn(`inlineChat.defaultModel setting value '${defaultModelSetting}' did not match any available model. Falling back to vendor default.`); }