From 702a1ffd587b5fa1ed2fed03a77f0989eac4eb94 Mon Sep 17 00:00:00 2001 From: Tyler James Leonhardt Date: Tue, 13 Feb 2024 16:15:48 -0800 Subject: [PATCH] Access tweaks for requestLanguageModelAccess (#205156) 1. remove the requirement that it has to be done during agent invocation 2. don't ask for auth when the model provider and the model requester are the same extension 3. since we don't have "language model activation events" start with a simple 3*2s timeout poll to wait for the language model registration to happen. (scenario: an extension activates before the extension that registers the model activates) --- .../api/browser/mainThreadChatProvider.ts | 11 ++++- .../workbench/api/common/extHost.api.impl.ts | 2 +- .../workbench/api/common/extHost.protocol.ts | 1 - .../api/common/extHostChatAgents2.ts | 5 --- .../api/common/extHostChatProvider.ts | 42 ++++++------------- 5 files changed, 24 insertions(+), 37 deletions(-) diff --git a/src/vs/workbench/api/browser/mainThreadChatProvider.ts b/src/vs/workbench/api/browser/mainThreadChatProvider.ts index 801b0139627..d8a197db254 100644 --- a/src/vs/workbench/api/browser/mainThreadChatProvider.ts +++ b/src/vs/workbench/api/browser/mainThreadChatProvider.ts @@ -3,6 +3,7 @@ * Licensed under the MIT License. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ +import { timeout } from 'vs/base/common/async'; import { CancellationToken } from 'vs/base/common/cancellation'; import { Emitter, Event } from 'vs/base/common/event'; import { Disposable, DisposableMap, DisposableStore, IDisposable, toDisposable } from 'vs/base/common/lifecycle'; @@ -81,7 +82,15 @@ export class MainThreadChatProvider implements MainThreadChatProviderShape { } async $prepareChatAccess(extension: ExtensionIdentifier, providerId: string, justification?: string): Promise { - return this._chatProviderService.lookupChatResponseProvider(providerId); + const metadata = this._chatProviderService.lookupChatResponseProvider(providerId); + // TODO: This should use a real activation event. Perhaps following what authentication does. + for (let i = 0; i < 3; i++) { + if (metadata) { + return metadata; + } + await timeout(2000); + } + return undefined; } async $fetchResponse(extension: ExtensionIdentifier, providerId: string, requestId: number, messages: IChatMessage[], options: {}, token: CancellationToken): Promise { diff --git a/src/vs/workbench/api/common/extHost.api.impl.ts b/src/vs/workbench/api/common/extHost.api.impl.ts index a588532185d..2757392e5f5 100644 --- a/src/vs/workbench/api/common/extHost.api.impl.ts +++ b/src/vs/workbench/api/common/extHost.api.impl.ts @@ -208,7 +208,7 @@ export function createApiFactoryAndRegisterActors(accessor: ServicesAccessor): I rpcProtocol.set(ExtHostContext.ExtHostInteractive, new ExtHostInteractive(rpcProtocol, extHostNotebook, extHostDocumentsAndEditors, extHostCommands, extHostLogService)); const extHostInteractiveEditor = rpcProtocol.set(ExtHostContext.ExtHostInlineChat, new ExtHostInteractiveEditor(rpcProtocol, extHostCommands, extHostDocuments, extHostLogService)); const extHostChatProvider = rpcProtocol.set(ExtHostContext.ExtHostChatProvider, new ExtHostChatProvider(rpcProtocol, extHostLogService, extHostAuthentication)); - const extHostChatAgents2 = rpcProtocol.set(ExtHostContext.ExtHostChatAgents2, new ExtHostChatAgents2(rpcProtocol, extHostChatProvider, extHostLogService, extHostCommands)); + const extHostChatAgents2 = rpcProtocol.set(ExtHostContext.ExtHostChatAgents2, new ExtHostChatAgents2(rpcProtocol, extHostLogService, extHostCommands)); const extHostChatVariables = rpcProtocol.set(ExtHostContext.ExtHostChatVariables, new ExtHostChatVariables(rpcProtocol)); const extHostChat = rpcProtocol.set(ExtHostContext.ExtHostChat, new ExtHostChat(rpcProtocol)); const extHostAiRelatedInformation = rpcProtocol.set(ExtHostContext.ExtHostAiRelatedInformation, new ExtHostRelatedInformation(rpcProtocol)); diff --git a/src/vs/workbench/api/common/extHost.protocol.ts b/src/vs/workbench/api/common/extHost.protocol.ts index b0fbf742fbc..b657fa82840 100644 --- a/src/vs/workbench/api/common/extHost.protocol.ts +++ b/src/vs/workbench/api/common/extHost.protocol.ts @@ -1183,7 +1183,6 @@ export interface MainThreadChatProviderShape extends IDisposable { export interface ExtHostChatProviderShape { $updateLanguageModels(data: { added?: string[]; removed?: string[] }): void; - $updateAccesslist(data: { extension: ExtensionIdentifier; enabled: boolean }[]): void; $updateModelAccesslist(data: { from: ExtensionIdentifier; to: ExtensionIdentifier; enabled: boolean }[]): void; $provideLanguageModelResponse(handle: number, requestId: number, from: ExtensionIdentifier, messages: IChatMessage[], options: { [name: string]: any }, token: CancellationToken): Promise; $handleResponseFragment(requestId: number, chunk: IChatResponseFragment): Promise; diff --git a/src/vs/workbench/api/common/extHostChatAgents2.ts b/src/vs/workbench/api/common/extHostChatAgents2.ts index bd233330fb2..efaff6f31c9 100644 --- a/src/vs/workbench/api/common/extHostChatAgents2.ts +++ b/src/vs/workbench/api/common/extHostChatAgents2.ts @@ -16,7 +16,6 @@ import { localize } from 'vs/nls'; import { IExtensionDescription } from 'vs/platform/extensions/common/extensions'; import { ILogService } from 'vs/platform/log/common/log'; import { ExtHostChatAgentsShape2, IChatAgentCompletionItem, IChatAgentHistoryEntryDto, IMainContext, MainContext, MainThreadChatAgentsShape2 } from 'vs/workbench/api/common/extHost.protocol'; -import { ExtHostChatProvider } from 'vs/workbench/api/common/extHostChatProvider'; import { CommandsConverter, ExtHostCommands } from 'vs/workbench/api/common/extHostCommands'; import * as typeConvert from 'vs/workbench/api/common/extHostTypeConverters'; import * as extHostTypes from 'vs/workbench/api/common/extHostTypes'; @@ -164,7 +163,6 @@ export class ExtHostChatAgents2 implements ExtHostChatAgentsShape2 { constructor( mainContext: IMainContext, - private readonly _extHostChatProvider: ExtHostChatProvider, private readonly _logService: ILogService, private readonly commands: ExtHostCommands, ) { @@ -186,8 +184,6 @@ export class ExtHostChatAgents2 implements ExtHostChatAgentsShape2 { throw new Error(`[CHAT](${handle}) CANNOT invoke agent because the agent is not registered`); } - this._extHostChatProvider.$updateAccesslist([{ extension: agent.extension.identifier, enabled: true }]); - // Init session disposables let sessionDisposables = this._sessionDisposables.get(request.sessionId); if (!sessionDisposables) { @@ -224,7 +220,6 @@ export class ExtHostChatAgents2 implements ExtHostChatAgentsShape2 { } finally { stream.close(); - this._extHostChatProvider.$updateAccesslist([{ extension: agent.extension.identifier, enabled: false }]); } } diff --git a/src/vs/workbench/api/common/extHostChatProvider.ts b/src/vs/workbench/api/common/extHostChatProvider.ts index 923a2db9eda..51b9c7fee3b 100644 --- a/src/vs/workbench/api/common/extHostChatProvider.ts +++ b/src/vs/workbench/api/common/extHostChatProvider.ts @@ -10,7 +10,7 @@ import { ExtHostChatProviderShape, IMainContext, MainContext, MainThreadChatProv import * as typeConvert from 'vs/workbench/api/common/extHostTypeConverters'; import type * as vscode from 'vscode'; import { Progress } from 'vs/platform/progress/common/progress'; -import { IChatMessage, IChatResponseFragment } from 'vs/workbench/contrib/chat/common/chatProvider'; +import { IChatMessage, IChatResponseFragment, IChatResponseProviderMetadata } from 'vs/workbench/contrib/chat/common/chatProvider'; import { ExtensionIdentifier, ExtensionIdentifierMap, ExtensionIdentifierSet, IExtensionDescription } from 'vs/platform/extensions/common/extensions'; import { AsyncIterableSource } from 'vs/base/common/async'; import { Emitter, Event } from 'vs/base/common/event'; @@ -97,7 +97,6 @@ export class ExtHostChatProvider implements ExtHostChatProviderShape { private readonly _languageModels = new Map(); private readonly _languageModelIds = new Set(); // these are ALL models, not just the one in this EH - private readonly _accesslist = new ExtensionIdentifierMap(); private readonly _modelAccessList = new ExtensionIdentifierMap(); private readonly _pendingRequest = new Map(); @@ -197,18 +196,6 @@ export class ExtHostChatProvider implements ExtHostChatProviderShape { return Array.from(this._languageModelIds); } - $updateAccesslist(data: { extension: ExtensionIdentifier; enabled: boolean }[]): void { - const updated = new ExtensionIdentifierSet(); - for (const { extension, enabled } of data) { - const oldValue = this._accesslist.get(extension); - if (oldValue !== enabled) { - this._accesslist.set(extension, enabled); - updated.add(extension); - } - } - this._onDidChangeAccess.fire(updated); - } - $updateModelAccesslist(data: { from: ExtensionIdentifier; to: ExtensionIdentifier; enabled: boolean }[]): void { const updated = new Array<{ from: ExtensionIdentifier; to: ExtensionIdentifier }>(); for (const { from, to, enabled } of data) { @@ -230,23 +217,15 @@ export class ExtHostChatProvider implements ExtHostChatProviderShape { async requestLanguageModelAccess(extension: IExtensionDescription, languageModelId: string, options?: vscode.LanguageModelAccessOptions): Promise { const from = extension.identifier; - // check if the extension is in the access list and allowed to make chat requests - if (this._accesslist.get(from) === false) { - throw new Error('Extension is NOT allowed to make chat requests'); - } - const justification = options?.justification; const metadata = await this._proxy.$prepareChatAccess(from, languageModelId, justification); if (!metadata) { - if (!this._accesslist.get(from)) { - throw new Error('Extension is NOT allowed to make chat requests'); - } throw new Error(`Language model '${languageModelId}' NOT found`); } - if (metadata.auth) { - await this._checkAuthAccess(extension, { identifier: metadata.extension, displayName: metadata.auth?.providerLabel }, justification); + if (this._isUsingAuth(from, metadata)) { + await this._getAuthAccess(extension, { identifier: metadata.extension, displayName: metadata.auth.providerLabel }, justification); } const that = this; @@ -256,9 +235,7 @@ export class ExtHostChatProvider implements ExtHostChatProviderShape { return metadata.model; }, get isRevoked() { - return !that._accesslist.get(from) - || (metadata.auth && !that._modelAccessList.get(from)?.has(metadata.extension)) - || !that._languageModelIds.has(languageModelId); + return (that._isUsingAuth(from, metadata) && !that._modelAccessList.get(from)?.has(metadata.extension)) || !that._languageModelIds.has(languageModelId); }, get onDidChangeAccess() { const onDidChangeAccess = Event.filter(that._onDidChangeAccess.event, set => set.has(from)); @@ -267,7 +244,7 @@ export class ExtHostChatProvider implements ExtHostChatProviderShape { return Event.signal(Event.any(onDidChangeAccess, onDidRemoveLM, onDidChangeModelAccess)); }, makeChatRequest(messages, options, token) { - if (!that._accesslist.get(from) || (metadata.auth && !that._modelAccessList.get(from)?.has(metadata.extension))) { + if (that._isUsingAuth(from, metadata) && !that._modelAccessList.get(from)?.has(metadata.extension)) { throw new Error('Access to chat has been revoked'); } if (!that._languageModelIds.has(languageModelId)) { @@ -297,7 +274,7 @@ export class ExtHostChatProvider implements ExtHostChatProviderShape { } // BIG HACK: Using AuthenticationProviders to check access to Language Models - private async _checkAuthAccess(from: IExtensionDescription, to: { identifier: ExtensionIdentifier; displayName: string }, detail?: string): Promise { + private async _getAuthAccess(from: IExtensionDescription, to: { identifier: ExtensionIdentifier; displayName: string }, detail?: string): Promise { // This needs to be done in both MainThread & ExtHost ChatProvider const providerId = INTERNAL_AUTH_PROVIDER_PREFIX + to.identifier.value; const session = await this._extHostAuthentication.getSession(from, providerId, [], { silent: true }); @@ -315,4 +292,11 @@ export class ExtHostChatProvider implements ExtHostChatProviderShape { this.$updateModelAccesslist([{ from: from.identifier, to: to.identifier, enabled: true }]); } + + private _isUsingAuth(from: ExtensionIdentifier, toMetadata: IChatResponseProviderMetadata): toMetadata is IChatResponseProviderMetadata & { auth: NonNullable } { + // If the 'to' extension uses an auth check + return !!toMetadata.auth + // And we're asking from a different extension + && !ExtensionIdentifier.equals(toMetadata.extension, from); + } }