diff --git a/src/vs/workbench/contrib/chat/browser/chatManagement/chatModelsViewModel.ts b/src/vs/workbench/contrib/chat/browser/chatManagement/chatModelsViewModel.ts index fc73c476433..482c495d720 100644 --- a/src/vs/workbench/contrib/chat/browser/chatManagement/chatModelsViewModel.ts +++ b/src/vs/workbench/contrib/chat/browser/chatManagement/chatModelsViewModel.ts @@ -16,6 +16,7 @@ export const VENDOR_ENTRY_TEMPLATE_ID = 'vendor.entry.template'; const wordFilter = or(matchesBaseContiguousSubString, matchesWords); const CAPABILITY_REGEX = /@capability:\s*([^\s]+)/gi; const VISIBLE_REGEX = /@visible:\s*(true|false)/i; +const PROVIDER_REGEX = /@provider:\s*((".+?")|([^\s]+))/gi; export const SEARCH_SUGGESTIONS = { FILTER_TYPES: [ @@ -54,6 +55,7 @@ export interface IModelItemEntry { templateId: string; providerMatches?: IMatch[]; modelNameMatches?: IMatch[]; + modelIdMatches?: IMatch[]; capabilityMatches?: string[]; } @@ -111,92 +113,154 @@ export class ChatModelsViewModel extends EditorModel { filter(searchValue: string): readonly IViewModelEntry[] { this.searchValue = searchValue; - - let modelEntries = this.modelEntries; - const capabilityMatchesMap = new Map(); - - const visibleMatches = VISIBLE_REGEX.exec(searchValue); - if (visibleMatches && visibleMatches[1]) { - const visible = visibleMatches[1].toLowerCase() === 'true'; - modelEntries = this.filterByVisible(modelEntries, visible); - searchValue = searchValue.replace(VISIBLE_REGEX, ''); - } - - const providerNames: string[] = []; - let match: RegExpExecArray | null; - - const providerRegexGlobal = /@provider:\s*((".+?")|([^\s]+))/gi; - while ((match = providerRegexGlobal.exec(searchValue)) !== null) { - const providerName = match[2] ? match[2].substring(1, match[2].length - 1) : match[3]; - providerNames.push(providerName); - } - - // Apply provider filter with OR logic if multiple providers - if (providerNames.length > 0) { - modelEntries = this.filterByProviders(modelEntries, providerNames); - searchValue = searchValue.replace(/@provider:\s*((".+?")|([^\s]+))/gi, '').replace(/@vendor:\s*((".+?")|([^\s]+))/gi, ''); - } - - // Apply capability filters with AND logic if multiple capabilities - const capabilityNames: string[] = []; - let capabilityMatch: RegExpExecArray | null; - - while ((capabilityMatch = CAPABILITY_REGEX.exec(searchValue)) !== null) { - capabilityNames.push(capabilityMatch[1].toLowerCase()); - } - - if (capabilityNames.length > 0) { - const filteredEntries = this.filterByCapabilities(modelEntries, capabilityNames); - modelEntries = []; - for (const { entry, matchedCapabilities } of filteredEntries) { - modelEntries.push(entry); - capabilityMatchesMap.set(ChatModelsViewModel.getId(entry), matchedCapabilities); - } - searchValue = searchValue.replace(/@capability:\s*([^\s]+)/gi, ''); - } - - searchValue = searchValue.trim(); - const filtered = searchValue ? this.filterByText(modelEntries, searchValue, capabilityMatchesMap) : this.toEntries(modelEntries, capabilityMatchesMap); - + const filtered = this.filterModels(this.modelEntries, searchValue); this.splice(0, this._viewModelEntries.length, filtered); return this.viewModelEntries; } - private filterByProviders(modelEntries: IModelEntry[], providers: string[]): IModelEntry[] { - const lowerProviders = providers.map(p => p.toLowerCase().trim()); - return modelEntries.filter(m => - lowerProviders.some(provider => - m.vendor.toLowerCase() === provider || - m.vendorDisplayName.toLowerCase() === provider - ) - ); - } + private filterModels(modelEntries: IModelEntry[], searchValue: string): (IVendorItemEntry | IModelItemEntry)[] { + let visible: boolean | undefined; - private filterByVisible(modelEntries: IModelEntry[], visible: boolean): IModelEntry[] { - return modelEntries.filter(m => (m.metadata.isUserSelectable ?? false) === visible); - } + const visibleMatches = VISIBLE_REGEX.exec(searchValue); + if (visibleMatches && visibleMatches[1]) { + visible = visibleMatches[1].toLowerCase() === 'true'; + searchValue = searchValue.replace(VISIBLE_REGEX, ''); + } - private filterByCapabilities(modelEntries: IModelEntry[], capabilities: string[]): { entry: IModelEntry; matchedCapabilities: string[] }[] { - const result: { entry: IModelEntry; matchedCapabilities: string[] }[] = []; - for (const m of modelEntries) { - if (!m.metadata.capabilities) { + const providerNames: string[] = []; + let providerMatch: RegExpExecArray | null; + PROVIDER_REGEX.lastIndex = 0; + while ((providerMatch = PROVIDER_REGEX.exec(searchValue)) !== null) { + const providerName = providerMatch[2] ? providerMatch[2].substring(1, providerMatch[2].length - 1) : providerMatch[3]; + providerNames.push(providerName); + } + if (providerNames.length > 0) { + searchValue = searchValue.replace(PROVIDER_REGEX, ''); + } + + const capabilities: string[] = []; + let capabilityMatch: RegExpExecArray | null; + CAPABILITY_REGEX.lastIndex = 0; + while ((capabilityMatch = CAPABILITY_REGEX.exec(searchValue)) !== null) { + capabilities.push(capabilityMatch[1].toLowerCase()); + } + if (capabilities.length > 0) { + searchValue = searchValue.replace(CAPABILITY_REGEX, ''); + } + + const quoteAtFirstChar = searchValue.charAt(0) === '"'; + const quoteAtLastChar = searchValue.charAt(searchValue.length - 1) === '"'; + const completeMatch = quoteAtFirstChar && quoteAtLastChar; + if (quoteAtFirstChar) { + searchValue = searchValue.substring(1); + } + if (quoteAtLastChar) { + searchValue = searchValue.substring(0, searchValue.length - 1); + } + searchValue = searchValue.trim(); + + const isFiltering = searchValue !== '' || capabilities.length > 0 || providerNames.length > 0 || visible !== undefined; + + const result: (IVendorItemEntry | IModelItemEntry)[] = []; + const words = searchValue.split(' '); + const allVendors = new Set(this.modelEntries.map(m => m.vendor)); + const showHeaders = allVendors.size > 1; + const addedVendors = new Set(); + const lowerProviders = providerNames.map(p => p.toLowerCase().trim()); + + for (const modelEntry of modelEntries) { + if (!isFiltering && showHeaders && this.collapsedVendors.has(modelEntry.vendor)) { + if (!addedVendors.has(modelEntry.vendor)) { + const vendorInfo = this.languageModelsService.getVendors().find(v => v.vendor === modelEntry.vendor); + result.push({ + type: 'vendor', + id: `vendor-${modelEntry.vendor}`, + vendorEntry: { + vendor: modelEntry.vendor, + vendorDisplayName: modelEntry.vendorDisplayName, + managementCommand: vendorInfo?.managementCommand + }, + templateId: VENDOR_ENTRY_TEMPLATE_ID, + collapsed: true + }); + addedVendors.add(modelEntry.vendor); + } continue; } - const allMatchedCapabilities: string[] = []; - let matchesAll = true; - for (const capability of capabilities) { - const matchedForThisCapability = this.getMatchingCapabilities(m, capability); - if (matchedForThisCapability.length === 0) { - matchesAll = false; - break; + if (visible !== undefined) { + if ((modelEntry.metadata.isUserSelectable ?? false) !== visible) { + continue; } - allMatchedCapabilities.push(...matchedForThisCapability); } - if (matchesAll) { - result.push({ entry: m, matchedCapabilities: distinct(allMatchedCapabilities) }); + if (lowerProviders.length > 0) { + const matchesProvider = lowerProviders.some(provider => + modelEntry.vendor.toLowerCase() === provider || + modelEntry.vendorDisplayName.toLowerCase() === provider + ); + if (!matchesProvider) { + continue; + } } + + // Filter by capabilities + let matchedCapabilities: string[] = []; + if (capabilities.length > 0) { + if (!modelEntry.metadata.capabilities) { + continue; + } + let matchesAll = true; + for (const capability of capabilities) { + const matchedForThisCapability = this.getMatchingCapabilities(modelEntry, capability); + if (matchedForThisCapability.length === 0) { + matchesAll = false; + break; + } + matchedCapabilities.push(...matchedForThisCapability); + } + if (!matchesAll) { + continue; + } + matchedCapabilities = distinct(matchedCapabilities); + } + + // Filter by text + let modelMatches: ModelItemMatches | undefined; + if (searchValue) { + modelMatches = new ModelItemMatches(modelEntry, searchValue, words, completeMatch); + if (!modelMatches.modelNameMatches && !modelMatches.modelIdMatches && !modelMatches.providerMatches && !modelMatches.capabilityMatches) { + continue; + } + } + + if (showHeaders && !addedVendors.has(modelEntry.vendor)) { + const vendorInfo = this.languageModelsService.getVendors().find(v => v.vendor === modelEntry.vendor); + result.push({ + type: 'vendor', + id: `vendor-${modelEntry.vendor}`, + vendorEntry: { + vendor: modelEntry.vendor, + vendorDisplayName: modelEntry.vendorDisplayName, + managementCommand: vendorInfo?.managementCommand + }, + templateId: VENDOR_ENTRY_TEMPLATE_ID, + collapsed: false + }); + addedVendors.add(modelEntry.vendor); + } + + const modelId = ChatModelsViewModel.getId(modelEntry); + result.push({ + type: 'model', + id: modelId, + templateId: MODEL_ENTRY_TEMPLATE_ID, + modelEntry, + modelNameMatches: modelMatches?.modelNameMatches || undefined, + modelIdMatches: modelMatches?.modelIdMatches || undefined, + providerMatches: modelMatches?.providerMatches || undefined, + capabilityMatches: matchedCapabilities.length ? matchedCapabilities : undefined, + }); } return result; } @@ -239,42 +303,6 @@ export class ChatModelsViewModel extends EditorModel { return matchedCapabilities; } - private filterByText(modelEntries: IModelEntry[], searchValue: string, capabilityMatchesMap: Map): IModelItemEntry[] { - const quoteAtFirstChar = searchValue.charAt(0) === '"'; - const quoteAtLastChar = searchValue.charAt(searchValue.length - 1) === '"'; - const completeMatch = quoteAtFirstChar && quoteAtLastChar; - if (quoteAtFirstChar) { - searchValue = searchValue.substring(1); - } - if (quoteAtLastChar) { - searchValue = searchValue.substring(0, searchValue.length - 1); - } - searchValue = searchValue.trim(); - - const result: IModelItemEntry[] = []; - const words = searchValue.split(' '); - - for (const modelEntry of modelEntries) { - const modelMatches = new ModelItemMatches(modelEntry, searchValue, words, completeMatch); - if (modelMatches.modelNameMatches - || modelMatches.providerMatches - || modelMatches.capabilityMatches - ) { - const modelId = ChatModelsViewModel.getId(modelEntry); - result.push({ - type: 'model', - id: modelId, - templateId: MODEL_ENTRY_TEMPLATE_ID, - modelEntry, - modelNameMatches: modelMatches.modelNameMatches || undefined, - providerMatches: modelMatches.providerMatches || undefined, - capabilityMatches: capabilityMatchesMap.get(modelId), - }); - } - } - return result; - } - getVendors(): IUserFriendlyLanguageModel[] { return [...this.languageModelsService.getVendors()].sort((a, b) => { if (a.vendor === 'copilot') { return -1; } @@ -342,55 +370,20 @@ export class ChatModelsViewModel extends EditorModel { this.filter(this.searchValue); } - getConfiguredVendors(): IVendorItemEntry[] { - return this.toEntries(this.modelEntries, new Map(), true) as IVendorItemEntry[]; - } - - private toEntries(modelEntries: IModelEntry[], capabilityMatchesMap: Map, excludeModels?: boolean): (IVendorItemEntry | IModelItemEntry)[] { - const result: (IVendorItemEntry | IModelItemEntry)[] = []; - const vendorMap = new Map(); - - for (const modelEntry of modelEntries) { - const models = vendorMap.get(modelEntry.vendor) || []; - models.push(modelEntry); - vendorMap.set(modelEntry.vendor, models); - } - - const showVendorHeaders = vendorMap.size > 1; - - for (const [vendor, models] of vendorMap) { - const firstModel = models[0]; - const isCollapsed = this.collapsedVendors.has(vendor); - const vendorInfo = this.languageModelsService.getVendors().find(v => v.vendor === vendor); - - if (showVendorHeaders) { + getConfiguredVendors(): IVendorEntry[] { + const result: IVendorEntry[] = []; + const seenVendors = new Set(); + for (const modelEntry of this.modelEntries) { + if (!seenVendors.has(modelEntry.vendor)) { + seenVendors.add(modelEntry.vendor); + const vendorInfo = this.languageModelsService.getVendors().find(v => v.vendor === modelEntry.vendor); result.push({ - type: 'vendor', - id: `vendor-${vendor}`, - vendorEntry: { - vendor: firstModel.vendor, - vendorDisplayName: firstModel.vendorDisplayName, - managementCommand: vendorInfo?.managementCommand - }, - templateId: VENDOR_ENTRY_TEMPLATE_ID, - collapsed: isCollapsed + vendor: modelEntry.vendor, + vendorDisplayName: modelEntry.vendorDisplayName, + managementCommand: vendorInfo?.managementCommand }); } - - if (!excludeModels && (!isCollapsed || !showVendorHeaders)) { - for (const modelEntry of models) { - const modelId = ChatModelsViewModel.getId(modelEntry); - result.push({ - type: 'model', - id: modelId, - modelEntry, - templateId: MODEL_ENTRY_TEMPLATE_ID, - capabilityMatches: capabilityMatchesMap.get(modelId), - }); - } - } } - return result; } } @@ -398,6 +391,7 @@ export class ChatModelsViewModel extends EditorModel { class ModelItemMatches { readonly modelNameMatches: IMatch[] | null = null; + readonly modelIdMatches: IMatch[] | null = null; readonly providerMatches: IMatch[] | null = null; readonly capabilityMatches: IMatch[] | null = null; @@ -408,10 +402,7 @@ class ModelItemMatches { this.matches(searchValue, modelEntry.metadata.name, (word, wordToMatchAgainst) => matchesWords(word, wordToMatchAgainst, true), words) : null; - // Match against model identifier - if (!this.modelNameMatches) { - this.modelNameMatches = this.matches(searchValue, modelEntry.identifier, or(matchesWords, matchesCamelCase), words); - } + this.modelIdMatches = this.matches(searchValue, modelEntry.identifier, or(matchesWords, matchesCamelCase), words); // Match against vendor display name this.providerMatches = this.matches(searchValue, modelEntry.vendorDisplayName, (word, wordToMatchAgainst) => matchesWords(word, wordToMatchAgainst, true), words); diff --git a/src/vs/workbench/contrib/chat/browser/chatManagement/chatModelsWidget.ts b/src/vs/workbench/contrib/chat/browser/chatManagement/chatModelsWidget.ts index d23b1fa993e..a7608c3200c 100644 --- a/src/vs/workbench/contrib/chat/browser/chatManagement/chatModelsWidget.ts +++ b/src/vs/workbench/contrib/chat/browser/chatManagement/chatModelsWidget.ts @@ -226,7 +226,7 @@ class ModelsSearchFilterDropdownMenuActionViewItem extends DropdownMenuActionVie const configuredVendors = this.viewModel.getConfiguredVendors(); if (configuredVendors.length > 1) { actions.push(new Separator()); - actions.push(...configuredVendors.map(vendor => this.createProviderAction(vendor.vendorEntry.vendor, vendor.vendorEntry.vendorDisplayName))); + actions.push(...configuredVendors.map(vendor => this.createProviderAction(vendor.vendor, vendor.vendorDisplayName))); } return actions; @@ -717,17 +717,25 @@ export class ChatModelsWidget extends Disposable { { triggerCharacters: ['@', ':'], provideResults: (query: string) => { + const providerSuggestions = this.viewModel.getVendors().map(v => `@provider:"${v.displayName}"`); + const allSuggestions = [ + ...providerSuggestions, + ...SEARCH_SUGGESTIONS.CAPABILITIES, + ...SEARCH_SUGGESTIONS.VISIBILITY, + ]; + if (!query.trim()) { + return allSuggestions; + } const queryParts = query.split(/\s/g); const lastPart = queryParts[queryParts.length - 1]; if (lastPart.startsWith('@provider:')) { - const vendors = this.viewModel.getVendors(); - return vendors.map(v => `@provider:"${v.displayName}"`); + return providerSuggestions; } else if (lastPart.startsWith('@capability:')) { return SEARCH_SUGGESTIONS.CAPABILITIES; } else if (lastPart.startsWith('@visible:')) { return SEARCH_SUGGESTIONS.VISIBILITY; } else if (lastPart.startsWith('@')) { - return SEARCH_SUGGESTIONS.FILTER_TYPES; + return allSuggestions; } return []; } @@ -930,7 +938,7 @@ export class ChatModelsWidget extends Disposable { } const vendors = this.viewModel.getVendors(); - const configuredVendors = new Set(this.viewModel.getConfiguredVendors().map(cv => cv.vendorEntry.vendor)); + const configuredVendors = new Set(this.viewModel.getConfiguredVendors().map(cv => cv.vendor)); const vendorsWithoutModels = vendors.filter(v => !configuredVendors.has(v.vendor)); const hasPlan = this.chatEntitlementService.entitlement !== ChatEntitlement.Unknown && this.chatEntitlementService.entitlement !== ChatEntitlement.Available; diff --git a/src/vs/workbench/contrib/chat/test/browser/chatModelsViewModel.test.ts b/src/vs/workbench/contrib/chat/test/browser/chatModelsViewModel.test.ts index 47af77255b2..6afc0021d30 100644 --- a/src/vs/workbench/contrib/chat/test/browser/chatModelsViewModel.test.ts +++ b/src/vs/workbench/contrib/chat/test/browser/chatModelsViewModel.test.ts @@ -383,6 +383,15 @@ suite('ChatModelsViewModel', () => { assert.ok(models[0].modelNameMatches); }); + test('should filter by text matching model id', () => { + const results = viewModel.filter('copilot-gpt-4o'); + + const models = results.filter(r => !isVendorEntry(r)) as IModelItemEntry[]; + assert.strictEqual(models.length, 1); + assert.strictEqual(models[0].modelEntry.identifier, 'copilot-gpt-4o'); + assert.ok(models[0].modelIdMatches); + }); + test('should filter by text matching vendor name', () => { const results = viewModel.filter('GitHub'); @@ -731,4 +740,19 @@ suite('ChatModelsViewModel', () => { assert.strictEqual(vendors[0].vendorEntry.vendor, 'copilot'); } }); + + test('should show vendor headers when filtered', () => { + const results = viewModel.filter('GPT'); + const vendors = results.filter(isVendorEntry); + assert.ok(vendors.length > 0); + }); + + test('should not show vendor headers when filtered if only one vendor exists', async () => { + const { viewModel: singleVendorViewModel } = createSingleVendorViewModel(store, chatEntitlementService); + await singleVendorViewModel.resolve(); + + const results = singleVendorViewModel.filter('GPT'); + const vendors = results.filter(isVendorEntry); + assert.strictEqual(vendors.length, 0); + }); });