Refactor model selection logic into separate file. Add tons of tests (#299210)

* Refactor model selection logic into separate file. Add tons of tests

* Use one get all models function

* Copilot comments
This commit is contained in:
Logan Ramos
2026-03-04 11:48:18 -05:00
committed by GitHub
parent 359c7722c8
commit 98dc3fd3a4
3 changed files with 1885 additions and 86 deletions

View File

@@ -89,6 +89,7 @@ import { ChatAgentLocation, ChatConfiguration, ChatModeKind, validateChatMode }
import { IChatEditingSession, IModifiedFileEntry, ModifiedFileEntryState } from '../../../common/editing/chatEditingService.js';
import { ILanguageModelChatMetadata, ILanguageModelChatMetadataAndIdentifier, ILanguageModelsService } from '../../../common/languageModels.js';
import { IChatModelInputState, IChatRequestModeInfo, IInputModel } from '../../../common/model/chatModel.js';
import { filterModelsForSession, findDefaultModel, hasModelsTargetingSession, isModelValidForSession, mergeModelsWithCache, resolveModelFromSyncState, shouldResetModelToDefault, shouldResetOnModelListChange, shouldRestoreLateArrivingModel, shouldRestorePersistedModel } from './chatModelSelectionLogic.js';
import { getChatSessionType } from '../../../common/model/chatUri.js';
import { IChatResponseViewModel, isResponseVM } from '../../../common/model/chatViewModel.js';
import { IChatAgentService } from '../../../common/participants/chatAgents.js';
@@ -625,8 +626,7 @@ export class ChatInputPart extends Disposable implements IHistoryNavigationWidge
this.initSelectedModel();
this._register(this.languageModelsService.onDidChangeLanguageModels(() => {
const selectedModel = this._currentLanguageModel ? this.getModels().find(m => m.identifier === this._currentLanguageModel.get()?.identifier) : undefined;
if (!this.currentLanguageModel || !selectedModel) {
if (shouldResetOnModelListChange(this._currentLanguageModel.get()?.identifier, this.getModels())) {
this.setCurrentLanguageModelToDefault();
}
}));
@@ -719,25 +719,20 @@ export class ChatInputPart extends Disposable implements IHistoryNavigationWidge
const persistedAsDefault = this.storageService.getBoolean(this.getSelectedModelIsDefaultStorageKey(), StorageScope.APPLICATION, true);
if (persistedSelection) {
const model = this.getModels().find(m => m.identifier === persistedSelection);
if (model) {
// Only restore the model if it wasn't the default at the time of storing or it is now the default
if (!persistedAsDefault || model.metadata.isDefaultForLocation[this.location]) {
this.setCurrentLanguageModel(model);
this.checkModelSupported();
}
} else {
const result = shouldRestorePersistedModel(persistedSelection, persistedAsDefault, this.getModels(), this.location);
if (result.shouldRestore && result.model) {
this.setCurrentLanguageModel(result.model);
this.checkModelSupported();
} else if (!result.model) {
this._waitForPersistedLanguageModel.value = this.languageModelsService.onDidChangeLanguageModels(e => {
const persistedModel = this.languageModelsService.lookupLanguageModel(persistedSelection);
if (persistedModel) {
this._waitForPersistedLanguageModel.clear();
// Only restore the model if it wasn't the default at the time of storing or it is now the default
if (!persistedAsDefault || persistedModel.isDefaultForLocation[this.location]) {
if (persistedModel.isUserSelectable) {
this.setCurrentLanguageModel({ metadata: persistedModel, identifier: persistedSelection });
this.checkModelSupported();
}
const lateModel = { metadata: persistedModel, identifier: persistedSelection };
if (shouldRestoreLateArrivingModel(persistedSelection, persistedAsDefault, lateModel, this.location)) {
this.setCurrentLanguageModel(lateModel);
this.checkModelSupported();
}
} else {
this.setCurrentLanguageModelToDefault();
@@ -946,14 +941,18 @@ export class ChatInputPart extends Disposable implements IHistoryNavigationWidge
// Sync selected model - validate it belongs to the current session's model pool
if (state?.selectedModel) {
const lm = this._currentLanguageModel.get();
if (!lm || lm.identifier !== state.selectedModel.identifier) {
if (this.isModelValidForCurrentSession(state.selectedModel)) {
this.setCurrentLanguageModel(state.selectedModel);
} else {
// Model from state doesn't belong to this session's pool - use default
this.setCurrentLanguageModelToDefault();
}
const allModels = this.getAllMergedModels();
const sessionType = this.getCurrentSessionType();
const syncResult = resolveModelFromSyncState(state.selectedModel, this._currentLanguageModel.get(), allModels, sessionType, {
location: this.location,
currentModeKind: this.currentModeKind,
isInlineChatV2Enabled: !!this.configurationService.getValue(InlineChatConfigKeys.EnableV2),
sessionType,
});
if (syncResult.action === 'apply') {
this.setCurrentLanguageModel(state.selectedModel);
} else if (syncResult.action === 'default') {
this.setCurrentLanguageModelToDefault();
}
}
@@ -1019,7 +1018,13 @@ export class ChatInputPart extends Disposable implements IHistoryNavigationWidge
private checkModelSupported(): void {
const lm = this._currentLanguageModel.get();
if (lm && (!this.modelSupportedForDefaultAgent(lm) || !this.modelSupportedForInlineChat(lm) || !this.isModelValidForCurrentSession(lm))) {
const allModels = this.getAllMergedModels();
if (shouldResetModelToDefault(lm, this.getModels(), {
location: this.location,
currentModeKind: this.currentModeKind,
isInlineChatV2Enabled: !!this.configurationService.getValue(InlineChatConfigKeys.EnableV2),
sessionType: this.getCurrentSessionType(),
}, allModels)) {
this.setCurrentLanguageModelToDefault();
}
}
@@ -1051,56 +1056,29 @@ export class ChatInputPart extends Disposable implements IHistoryNavigationWidge
this._syncInputStateToModel();
}
private modelSupportedForDefaultAgent(model: ILanguageModelChatMetadataAndIdentifier): boolean {
// Probably this logic could live in configuration on the agent, or somewhere else, if it gets more complex
if (this.currentModeKind === ChatModeKind.Agent) {
return ILanguageModelChatMetadata.suitableForAgentMode(model.metadata);
}
return true;
}
private modelSupportedForInlineChat(model: ILanguageModelChatMetadataAndIdentifier): boolean {
if (this.location !== ChatAgentLocation.EditorInline || !this.configurationService.getValue(InlineChatConfigKeys.EnableV2)) {
return true;
}
return !!model.metadata.capabilities?.toolCalling;
}
private getModels(): ILanguageModelChatMetadataAndIdentifier[] {
/**
* Get all models merged from live and cache, without session/mode filtering.
* This is the canonical source for the full model pool, including cached models
* that bridge startup races when live models haven't loaded yet.
*/
private getAllMergedModels(): ILanguageModelChatMetadataAndIdentifier[] {
const cachedModels = this.storageService.getObject<ILanguageModelChatMetadataAndIdentifier[]>(CachedLanguageModelsKey, StorageScope.APPLICATION, []);
const liveModels = this.languageModelsService.getLanguageModelIds()
.map(modelId => ({ identifier: modelId, metadata: this.languageModelsService.lookupLanguageModel(modelId)! }));
// Merge live models with cached models per-vendor. For vendors whose
// models have resolved, use the live data. For vendors that are still
// contributed but haven't resolved yet (startup race), keep their
// cached models. Vendors that are no longer contributed at all (e.g.
// extension uninstalled) are evicted from the cache.
let models: ILanguageModelChatMetadataAndIdentifier[];
const contributedVendors = new Set(this.languageModelsService.getVendors().map(v => v.vendor));
const models = mergeModelsWithCache(liveModels, cachedModels, contributedVendors);
if (liveModels.length > 0) {
const liveVendors = new Set(liveModels.map(m => m.metadata.vendor));
const contributedVendors = new Set(this.languageModelsService.getVendors().map(v => v.vendor));
models = [
...liveModels,
...cachedModels.filter(m => !liveVendors.has(m.metadata.vendor) && contributedVendors.has(m.metadata.vendor)),
];
this.storageService.store(CachedLanguageModelsKey, models, StorageScope.APPLICATION, StorageTarget.MACHINE);
} else {
models = cachedModels;
}
return models;
}
private getModels(): ILanguageModelChatMetadataAndIdentifier[] {
const models = this.getAllMergedModels();
models.sort((a, b) => a.metadata.name.localeCompare(b.metadata.name));
const sessionType = this.getCurrentSessionType();
if (sessionType && sessionType !== AgentSessionProviders.Local) {
// Session has a specific chat session type - show only models that target
// this session type, if any such models exist.
return models.filter(entry => entry.metadata?.targetChatSessionType === sessionType && entry.metadata?.isUserSelectable);
}
// No session type or no targeted models - show general models (those without
// a targetChatSessionType) filtered by the standard criteria.
return models.filter(entry => !entry.metadata?.targetChatSessionType && entry.metadata?.isUserSelectable && this.modelSupportedForDefaultAgent(entry) && this.modelSupportedForInlineChat(entry));
return filterModelsForSession(models, this.getCurrentSessionType(), this.currentModeKind, this.location, !!this.configurationService.getValue(InlineChatConfigKeys.EnableV2));
}
/**
@@ -1122,28 +1100,11 @@ export class ChatInputPart extends Disposable implements IHistoryNavigationWidge
* This is used to set the context key that controls model picker visibility.
*/
private hasModelsTargetingSessionType(): boolean {
const sessionType = this.getCurrentSessionType();
if (!sessionType) {
return false;
}
return this.languageModelsService.getLanguageModelIds().some(modelId => {
const metadata = this.languageModelsService.lookupLanguageModel(modelId);
return metadata?.targetChatSessionType === sessionType;
});
return hasModelsTargetingSession(this.getAllMergedModels(), this.getCurrentSessionType());
}
/**
* Check if a model is valid for the current session's model pool.
* If the session has targeted models, the model must target this session type.
* If no models target this session, the model must not have a targetChatSessionType.
*/
private isModelValidForCurrentSession(model: ILanguageModelChatMetadataAndIdentifier): boolean {
if (this.hasModelsTargetingSessionType()) {
// Session has targeted models - model must match
return model.metadata.targetChatSessionType === this.getCurrentSessionType();
}
// No targeted models - model must not be session-specific
return !model.metadata.targetChatSessionType;
return isModelValidForSession(model, this.getAllMergedModels(), this.getCurrentSessionType());
}
/**
@@ -1218,7 +1179,7 @@ export class ChatInputPart extends Disposable implements IHistoryNavigationWidge
private setCurrentLanguageModelToDefault() {
const allModels = this.getModels();
const defaultModel = allModels.find(m => m.metadata.isDefaultForLocation[this.location]) || allModels[0];
const defaultModel = findDefaultModel(allModels, this.location);
if (defaultModel) {
this.setCurrentLanguageModel(defaultModel);
}

View File

@@ -0,0 +1,290 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { ChatAgentLocation, ChatModeKind } from '../../../common/constants.js';
import { ILanguageModelChatMetadata, ILanguageModelChatMetadataAndIdentifier } from '../../../common/languageModels.js';
/**
* Describes the context needed for model selection decisions.
*/
export interface IModelSelectionContext {
readonly location: ChatAgentLocation;
readonly currentModeKind: ChatModeKind;
readonly isInlineChatV2Enabled: boolean;
readonly sessionType: string | undefined;
}
/**
* Filter models based on session type.
* When a session has a specific type (and it's not 'local'), only models targeting that
* session type are returned. Otherwise, general-purpose models are returned.
*/
export function filterModelsForSession(
models: ILanguageModelChatMetadataAndIdentifier[],
sessionType: string | undefined,
currentModeKind: ChatModeKind,
location: ChatAgentLocation,
isInlineChatV2Enabled: boolean,
): ILanguageModelChatMetadataAndIdentifier[] {
if (sessionType && sessionType !== 'local' && hasModelsTargetingSession(models, sessionType)) {
return models.filter(entry =>
entry.metadata?.targetChatSessionType === sessionType &&
entry.metadata?.isUserSelectable &&
isModelSupportedForMode(entry, currentModeKind) &&
isModelSupportedForInlineChat(entry, location, isInlineChatV2Enabled)
);
}
return models.filter(entry =>
!entry.metadata?.targetChatSessionType &&
entry.metadata?.isUserSelectable &&
isModelSupportedForMode(entry, currentModeKind) &&
isModelSupportedForInlineChat(entry, location, isInlineChatV2Enabled)
);
}
/**
* Check if a model is suitable for the current chat mode (e.g., agent mode requires tool calling).
*/
export function isModelSupportedForMode(
model: ILanguageModelChatMetadataAndIdentifier,
currentModeKind: ChatModeKind,
): boolean {
if (currentModeKind === ChatModeKind.Agent) {
return ILanguageModelChatMetadata.suitableForAgentMode(model.metadata);
}
return true;
}
/**
* Check if a model is suitable for inline chat (editor inline) usage.
*/
export function isModelSupportedForInlineChat(
model: ILanguageModelChatMetadataAndIdentifier,
location: ChatAgentLocation,
isInlineChatV2Enabled: boolean,
): boolean {
if (location !== ChatAgentLocation.EditorInline || !isInlineChatV2Enabled) {
return true;
}
return !!model.metadata.capabilities?.toolCalling;
}
/**
* Check if any models in the pool target a specific session type.
*/
export function hasModelsTargetingSession(
allModels: ILanguageModelChatMetadataAndIdentifier[],
sessionType: string | undefined,
): boolean {
if (!sessionType) {
return false;
}
return allModels.some(m => m.metadata.targetChatSessionType === sessionType);
}
/**
* Check if a model is valid for the current session's model pool.
* If the session has targeted models, the model must target that session type.
* If no models target this session, the model must not be session-specific.
*/
export function isModelValidForSession(
model: ILanguageModelChatMetadataAndIdentifier,
allModels: ILanguageModelChatMetadataAndIdentifier[],
sessionType: string | undefined,
): boolean {
if (hasModelsTargetingSession(allModels, sessionType)) {
return model.metadata.targetChatSessionType === sessionType;
}
return !model.metadata.targetChatSessionType;
}
/**
* Find the default model for a given location from a list of models.
* Prefers the model marked as default for the location, falls back to the first model.
*/
export function findDefaultModel(
models: ILanguageModelChatMetadataAndIdentifier[],
location: ChatAgentLocation,
): ILanguageModelChatMetadataAndIdentifier | undefined {
return models.find(m => m.metadata.isDefaultForLocation[location]) || models[0];
}
/**
* Determine whether a persisted model selection should be restored.
*
* A persisted model should be restored if:
* 1. The model still exists in the available models list
* 2. Either the model wasn't the default at the time it was persisted,
* OR it is currently the default for the location
*
* This prevents scenarios where a user's explicit model choice gets overridden
* when the default model changes, while still tracking default model changes
* for users who never explicitly chose a model.
*/
export function shouldRestorePersistedModel(
persistedModelId: string,
persistedAsDefault: boolean,
availableModels: ILanguageModelChatMetadataAndIdentifier[],
location: ChatAgentLocation,
): { shouldRestore: boolean; model: ILanguageModelChatMetadataAndIdentifier | undefined } {
const model = availableModels.find(m => m.identifier === persistedModelId);
if (!model) {
return { shouldRestore: false, model: undefined };
}
if (!persistedAsDefault || model.metadata.isDefaultForLocation[location]) {
return { shouldRestore: true, model };
}
return { shouldRestore: false, model };
}
/**
* Determines whether the current model should be reset because it is no longer
* compatible with the current mode, session, or availability.
*
* Returns true if the model should be reset to default.
*/
export function shouldResetModelToDefault(
currentModel: ILanguageModelChatMetadataAndIdentifier | undefined,
availableModels: ILanguageModelChatMetadataAndIdentifier[],
context: IModelSelectionContext,
allModels: ILanguageModelChatMetadataAndIdentifier[],
): boolean {
if (!currentModel) {
return true;
}
// Model is no longer in the available list
if (!availableModels.some(m => m.identifier === currentModel.identifier)) {
return true;
}
// Model not supported for current mode
if (!isModelSupportedForMode(currentModel, context.currentModeKind)) {
return true;
}
// Model not supported for inline chat
if (!isModelSupportedForInlineChat(currentModel, context.location, context.isInlineChatV2Enabled)) {
return true;
}
// Model not valid for current session
if (!isModelValidForSession(currentModel, allModels, context.sessionType)) {
return true;
}
return false;
}
/**
* Determines whether a model from a sync state should be applied to the current view.
*
* Returns an action:
* - `'keep'` - the view already has the same model; no change needed.
* - `'apply'` - the state model is valid; the caller should switch to it.
* - `'default'` - the state model is incompatible (wrong session pool, unsupported
* mode, or missing inline-chat capability); the caller should fall
* back to the default model for the current location.
*
* @param context Optional because some callers (e.g. unit tests, or code paths
* that only care about session-pool validation) don't have a full UI context
* available. When omitted, mode and inline-chat checks are skipped and only
* session-pool membership is validated.
*/
export function resolveModelFromSyncState(
stateModel: ILanguageModelChatMetadataAndIdentifier,
currentModel: ILanguageModelChatMetadataAndIdentifier | undefined,
allModels: ILanguageModelChatMetadataAndIdentifier[],
sessionType: string | undefined,
context?: IModelSelectionContext,
): { action: 'keep' | 'apply' | 'default' } {
// Already the same model — nothing to do
if (currentModel && currentModel.identifier === stateModel.identifier) {
return { action: 'keep' };
}
// Validate the state model belongs to this session's model pool
if (!isModelValidForSession(stateModel, allModels, sessionType)) {
return { action: 'default' };
}
// When a UI context is available, also validate mode and inline-chat compatibility
if (context) {
if (!isModelSupportedForMode(stateModel, context.currentModeKind)) {
return { action: 'default' };
}
if (!isModelSupportedForInlineChat(stateModel, context.location, context.isInlineChatV2Enabled)) {
return { action: 'default' };
}
}
return { action: 'apply' };
}
/**
* Merges live models with cached models per-vendor.
* For vendors whose models have resolved, uses live data.
* For vendors that are contributed but haven't resolved yet (startup race), keeps cached models.
* Vendors no longer contributed are evicted from cache.
*/
export function mergeModelsWithCache(
liveModels: ILanguageModelChatMetadataAndIdentifier[],
cachedModels: ILanguageModelChatMetadataAndIdentifier[],
contributedVendors: Set<string>,
): ILanguageModelChatMetadataAndIdentifier[] {
if (liveModels.length > 0) {
const liveVendors = new Set(liveModels.map(m => m.metadata.vendor));
return [
...liveModels,
...cachedModels.filter(m => !liveVendors.has(m.metadata.vendor) && contributedVendors.has(m.metadata.vendor)),
];
}
return cachedModels;
}
/**
* Determines whether the currently selected model should be reset to default
* when the language model list changes.
*
* Returns true if the model should be reset to default (i.e., the selected model
* is no longer in the available models list).
*/
export function shouldResetOnModelListChange(
currentModelId: string | undefined,
availableModels: ILanguageModelChatMetadataAndIdentifier[],
): boolean {
if (!currentModelId) {
return true;
}
return !availableModels.some(m => m.identifier === currentModelId);
}
/**
* Determines whether a late-arriving persisted model should be restored.
* This handles the startup race where the model wasn't available during
* `initSelectedModel` but arrives later via `onDidChangeLanguageModels`.
*
* The model must pass both the persisted-default check and the `isUserSelectable` check.
*/
export function shouldRestoreLateArrivingModel(
persistedModelId: string,
persistedAsDefault: boolean,
model: ILanguageModelChatMetadataAndIdentifier,
location: ChatAgentLocation,
): boolean {
if (!model.metadata.isUserSelectable) {
return false;
}
const result = shouldRestorePersistedModel(
persistedModelId,
persistedAsDefault,
[model],
location,
);
return result.shouldRestore;
}