diff --git a/src/vs/workbench/api/browser/mainThreadChatProvider.ts b/src/vs/workbench/api/browser/mainThreadChatProvider.ts index 801b0139627..d8a197db254 100644 --- a/src/vs/workbench/api/browser/mainThreadChatProvider.ts +++ b/src/vs/workbench/api/browser/mainThreadChatProvider.ts @@ -3,6 +3,7 @@ * Licensed under the MIT License. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ +import { timeout } from 'vs/base/common/async'; import { CancellationToken } from 'vs/base/common/cancellation'; import { Emitter, Event } from 'vs/base/common/event'; import { Disposable, DisposableMap, DisposableStore, IDisposable, toDisposable } from 'vs/base/common/lifecycle'; @@ -81,7 +82,15 @@ export class MainThreadChatProvider implements MainThreadChatProviderShape { } async $prepareChatAccess(extension: ExtensionIdentifier, providerId: string, justification?: string): Promise { - return this._chatProviderService.lookupChatResponseProvider(providerId); + const metadata = this._chatProviderService.lookupChatResponseProvider(providerId); + // TODO: This should use a real activation event. Perhaps following what authentication does. + for (let i = 0; i < 3; i++) { + if (metadata) { + return metadata; + } + await timeout(2000); + } + return undefined; } async $fetchResponse(extension: ExtensionIdentifier, providerId: string, requestId: number, messages: IChatMessage[], options: {}, token: CancellationToken): Promise { diff --git a/src/vs/workbench/api/common/extHost.api.impl.ts b/src/vs/workbench/api/common/extHost.api.impl.ts index a588532185d..2757392e5f5 100644 --- a/src/vs/workbench/api/common/extHost.api.impl.ts +++ b/src/vs/workbench/api/common/extHost.api.impl.ts @@ -208,7 +208,7 @@ export function createApiFactoryAndRegisterActors(accessor: ServicesAccessor): I rpcProtocol.set(ExtHostContext.ExtHostInteractive, new ExtHostInteractive(rpcProtocol, extHostNotebook, extHostDocumentsAndEditors, extHostCommands, extHostLogService)); const extHostInteractiveEditor = rpcProtocol.set(ExtHostContext.ExtHostInlineChat, new ExtHostInteractiveEditor(rpcProtocol, extHostCommands, extHostDocuments, extHostLogService)); const extHostChatProvider = rpcProtocol.set(ExtHostContext.ExtHostChatProvider, new ExtHostChatProvider(rpcProtocol, extHostLogService, extHostAuthentication)); - const extHostChatAgents2 = rpcProtocol.set(ExtHostContext.ExtHostChatAgents2, new ExtHostChatAgents2(rpcProtocol, extHostChatProvider, extHostLogService, extHostCommands)); + const extHostChatAgents2 = rpcProtocol.set(ExtHostContext.ExtHostChatAgents2, new ExtHostChatAgents2(rpcProtocol, extHostLogService, extHostCommands)); const extHostChatVariables = rpcProtocol.set(ExtHostContext.ExtHostChatVariables, new ExtHostChatVariables(rpcProtocol)); const extHostChat = rpcProtocol.set(ExtHostContext.ExtHostChat, new ExtHostChat(rpcProtocol)); const extHostAiRelatedInformation = rpcProtocol.set(ExtHostContext.ExtHostAiRelatedInformation, new ExtHostRelatedInformation(rpcProtocol)); diff --git a/src/vs/workbench/api/common/extHost.protocol.ts b/src/vs/workbench/api/common/extHost.protocol.ts index b0fbf742fbc..b657fa82840 100644 --- a/src/vs/workbench/api/common/extHost.protocol.ts +++ b/src/vs/workbench/api/common/extHost.protocol.ts @@ -1183,7 +1183,6 @@ export interface MainThreadChatProviderShape extends IDisposable { export interface ExtHostChatProviderShape { $updateLanguageModels(data: { added?: string[]; removed?: string[] }): void; - $updateAccesslist(data: { extension: ExtensionIdentifier; enabled: boolean }[]): void; $updateModelAccesslist(data: { from: ExtensionIdentifier; to: ExtensionIdentifier; enabled: boolean }[]): void; $provideLanguageModelResponse(handle: number, requestId: number, from: ExtensionIdentifier, messages: IChatMessage[], options: { [name: string]: any }, token: CancellationToken): Promise; $handleResponseFragment(requestId: number, chunk: IChatResponseFragment): Promise; diff --git a/src/vs/workbench/api/common/extHostChatAgents2.ts b/src/vs/workbench/api/common/extHostChatAgents2.ts index bd233330fb2..efaff6f31c9 100644 --- a/src/vs/workbench/api/common/extHostChatAgents2.ts +++ b/src/vs/workbench/api/common/extHostChatAgents2.ts @@ -16,7 +16,6 @@ import { localize } from 'vs/nls'; import { IExtensionDescription } from 'vs/platform/extensions/common/extensions'; import { ILogService } from 'vs/platform/log/common/log'; import { ExtHostChatAgentsShape2, IChatAgentCompletionItem, IChatAgentHistoryEntryDto, IMainContext, MainContext, MainThreadChatAgentsShape2 } from 'vs/workbench/api/common/extHost.protocol'; -import { ExtHostChatProvider } from 'vs/workbench/api/common/extHostChatProvider'; import { CommandsConverter, ExtHostCommands } from 'vs/workbench/api/common/extHostCommands'; import * as typeConvert from 'vs/workbench/api/common/extHostTypeConverters'; import * as extHostTypes from 'vs/workbench/api/common/extHostTypes'; @@ -164,7 +163,6 @@ export class ExtHostChatAgents2 implements ExtHostChatAgentsShape2 { constructor( mainContext: IMainContext, - private readonly _extHostChatProvider: ExtHostChatProvider, private readonly _logService: ILogService, private readonly commands: ExtHostCommands, ) { @@ -186,8 +184,6 @@ export class ExtHostChatAgents2 implements ExtHostChatAgentsShape2 { throw new Error(`[CHAT](${handle}) CANNOT invoke agent because the agent is not registered`); } - this._extHostChatProvider.$updateAccesslist([{ extension: agent.extension.identifier, enabled: true }]); - // Init session disposables let sessionDisposables = this._sessionDisposables.get(request.sessionId); if (!sessionDisposables) { @@ -224,7 +220,6 @@ export class ExtHostChatAgents2 implements ExtHostChatAgentsShape2 { } finally { stream.close(); - this._extHostChatProvider.$updateAccesslist([{ extension: agent.extension.identifier, enabled: false }]); } } diff --git a/src/vs/workbench/api/common/extHostChatProvider.ts b/src/vs/workbench/api/common/extHostChatProvider.ts index 923a2db9eda..51b9c7fee3b 100644 --- a/src/vs/workbench/api/common/extHostChatProvider.ts +++ b/src/vs/workbench/api/common/extHostChatProvider.ts @@ -10,7 +10,7 @@ import { ExtHostChatProviderShape, IMainContext, MainContext, MainThreadChatProv import * as typeConvert from 'vs/workbench/api/common/extHostTypeConverters'; import type * as vscode from 'vscode'; import { Progress } from 'vs/platform/progress/common/progress'; -import { IChatMessage, IChatResponseFragment } from 'vs/workbench/contrib/chat/common/chatProvider'; +import { IChatMessage, IChatResponseFragment, IChatResponseProviderMetadata } from 'vs/workbench/contrib/chat/common/chatProvider'; import { ExtensionIdentifier, ExtensionIdentifierMap, ExtensionIdentifierSet, IExtensionDescription } from 'vs/platform/extensions/common/extensions'; import { AsyncIterableSource } from 'vs/base/common/async'; import { Emitter, Event } from 'vs/base/common/event'; @@ -97,7 +97,6 @@ export class ExtHostChatProvider implements ExtHostChatProviderShape { private readonly _languageModels = new Map(); private readonly _languageModelIds = new Set(); // these are ALL models, not just the one in this EH - private readonly _accesslist = new ExtensionIdentifierMap(); private readonly _modelAccessList = new ExtensionIdentifierMap(); private readonly _pendingRequest = new Map(); @@ -197,18 +196,6 @@ export class ExtHostChatProvider implements ExtHostChatProviderShape { return Array.from(this._languageModelIds); } - $updateAccesslist(data: { extension: ExtensionIdentifier; enabled: boolean }[]): void { - const updated = new ExtensionIdentifierSet(); - for (const { extension, enabled } of data) { - const oldValue = this._accesslist.get(extension); - if (oldValue !== enabled) { - this._accesslist.set(extension, enabled); - updated.add(extension); - } - } - this._onDidChangeAccess.fire(updated); - } - $updateModelAccesslist(data: { from: ExtensionIdentifier; to: ExtensionIdentifier; enabled: boolean }[]): void { const updated = new Array<{ from: ExtensionIdentifier; to: ExtensionIdentifier }>(); for (const { from, to, enabled } of data) { @@ -230,23 +217,15 @@ export class ExtHostChatProvider implements ExtHostChatProviderShape { async requestLanguageModelAccess(extension: IExtensionDescription, languageModelId: string, options?: vscode.LanguageModelAccessOptions): Promise { const from = extension.identifier; - // check if the extension is in the access list and allowed to make chat requests - if (this._accesslist.get(from) === false) { - throw new Error('Extension is NOT allowed to make chat requests'); - } - const justification = options?.justification; const metadata = await this._proxy.$prepareChatAccess(from, languageModelId, justification); if (!metadata) { - if (!this._accesslist.get(from)) { - throw new Error('Extension is NOT allowed to make chat requests'); - } throw new Error(`Language model '${languageModelId}' NOT found`); } - if (metadata.auth) { - await this._checkAuthAccess(extension, { identifier: metadata.extension, displayName: metadata.auth?.providerLabel }, justification); + if (this._isUsingAuth(from, metadata)) { + await this._getAuthAccess(extension, { identifier: metadata.extension, displayName: metadata.auth.providerLabel }, justification); } const that = this; @@ -256,9 +235,7 @@ export class ExtHostChatProvider implements ExtHostChatProviderShape { return metadata.model; }, get isRevoked() { - return !that._accesslist.get(from) - || (metadata.auth && !that._modelAccessList.get(from)?.has(metadata.extension)) - || !that._languageModelIds.has(languageModelId); + return (that._isUsingAuth(from, metadata) && !that._modelAccessList.get(from)?.has(metadata.extension)) || !that._languageModelIds.has(languageModelId); }, get onDidChangeAccess() { const onDidChangeAccess = Event.filter(that._onDidChangeAccess.event, set => set.has(from)); @@ -267,7 +244,7 @@ export class ExtHostChatProvider implements ExtHostChatProviderShape { return Event.signal(Event.any(onDidChangeAccess, onDidRemoveLM, onDidChangeModelAccess)); }, makeChatRequest(messages, options, token) { - if (!that._accesslist.get(from) || (metadata.auth && !that._modelAccessList.get(from)?.has(metadata.extension))) { + if (that._isUsingAuth(from, metadata) && !that._modelAccessList.get(from)?.has(metadata.extension)) { throw new Error('Access to chat has been revoked'); } if (!that._languageModelIds.has(languageModelId)) { @@ -297,7 +274,7 @@ export class ExtHostChatProvider implements ExtHostChatProviderShape { } // BIG HACK: Using AuthenticationProviders to check access to Language Models - private async _checkAuthAccess(from: IExtensionDescription, to: { identifier: ExtensionIdentifier; displayName: string }, detail?: string): Promise { + private async _getAuthAccess(from: IExtensionDescription, to: { identifier: ExtensionIdentifier; displayName: string }, detail?: string): Promise { // This needs to be done in both MainThread & ExtHost ChatProvider const providerId = INTERNAL_AUTH_PROVIDER_PREFIX + to.identifier.value; const session = await this._extHostAuthentication.getSession(from, providerId, [], { silent: true }); @@ -315,4 +292,11 @@ export class ExtHostChatProvider implements ExtHostChatProviderShape { this.$updateModelAccesslist([{ from: from.identifier, to: to.identifier, enabled: true }]); } + + private _isUsingAuth(from: ExtensionIdentifier, toMetadata: IChatResponseProviderMetadata): toMetadata is IChatResponseProviderMetadata & { auth: NonNullable } { + // If the 'to' extension uses an auth check + return !!toMetadata.auth + // And we're asking from a different extension + && !ExtensionIdentifier.equals(toMetadata.extension, from); + } }