diff --git a/extensions/microsoft-authentication/src/node/cachedPublicClientApplication.ts b/extensions/microsoft-authentication/src/node/cachedPublicClientApplication.ts index 8807fc3098e..c679610466a 100644 --- a/extensions/microsoft-authentication/src/node/cachedPublicClientApplication.ts +++ b/extensions/microsoft-authentication/src/node/cachedPublicClientApplication.ts @@ -5,7 +5,7 @@ import { PublicClientApplication, AccountInfo, Configuration, SilentFlowRequest, AuthenticationResult, InteractiveRequest, LogLevel } from '@azure/msal-node'; import { Disposable, Memento, SecretStorage, LogOutputChannel, window, ProgressLocation, l10n, EventEmitter } from 'vscode'; -import { raceCancellationAndTimeoutError } from '../common/async'; +import { Delayer, raceCancellationAndTimeoutError } from '../common/async'; import { SecretStorageCachePlugin } from '../common/cachePlugin'; import { MsalLoggerOptions } from '../common/loggerOptions'; import { ICachedPublicClientApplication } from '../common/publicClientCache'; @@ -13,6 +13,7 @@ import { ICachedPublicClientApplication } from '../common/publicClientCache'; export class CachedPublicClientApplication implements ICachedPublicClientApplication { private _pca: PublicClientApplication; private _sequencer = new Sequencer(); + private readonly _refreshDelayer = new DelayerByKey(); private _accounts: AccountInfo[] = []; private readonly _disposable: Disposable; @@ -89,6 +90,7 @@ export class CachedPublicClientApplication implements ICachedPublicClientApplica this._logger.debug(`[acquireTokenSilent] [${this._clientId}] [${this._authority}] [${request.scopes.join(' ')}] [${request.account.username}] got result`); if (result.account && !result.fromCache) { this._logger.debug(`[acquireTokenSilent] [${this._clientId}] [${this._authority}] [${request.scopes.join(' ')}] [${request.account.username}] firing event due to change`); + this._setupRefresh(result); this._onDidAccountsChangeEmitter.fire({ added: [], changed: [result.account], deleted: [] }); } return result; @@ -96,7 +98,7 @@ export class CachedPublicClientApplication implements ICachedPublicClientApplica async acquireTokenInteractive(request: InteractiveRequest): Promise { this._logger.debug(`[acquireTokenInteractive] [${this._clientId}] [${this._authority}] [${request.scopes?.join(' ')}] loopbackClientOverride: ${request.loopbackClient ? 'true' : 'false'}`); - return await window.withProgress( + const result = await window.withProgress( { location: ProgressLocation.Notification, cancellable: true, @@ -108,6 +110,8 @@ export class CachedPublicClientApplication implements ICachedPublicClientApplica 1000 * 60 * 5 ) ); + this._setupRefresh(result); + return result; } removeAccount(account: AccountInfo): Promise { @@ -149,6 +153,25 @@ export class CachedPublicClientApplication implements ICachedPublicClientApplica } this._logger.debug(`[update] [${this._clientId}] [${this._authority}] CachedPublicClientApplication update complete`); } + + private _setupRefresh(result: AuthenticationResult) { + const on = result.refreshOn || result.expiresOn; + if (!result.account || !on) { + return; + } + + const account = result.account; + const scopes = result.scopes; + const timeToRefresh = on.getTime() - Date.now() - 5 * 60 * 1000; // 5 minutes before expiry + const key = JSON.stringify({ accountId: account.homeAccountId, scopes }); + this._logger.debug(`[_setupRefresh] [${this._clientId}] [${this._authority}] [${scopes.join(' ')}] [${account.username}] timeToRefresh: ${timeToRefresh}`); + this._refreshDelayer.trigger( + key, + // This may need the redirectUri when we switch to the broker + () => this.acquireTokenSilent({ account, scopes, redirectUri: undefined, forceRefresh: true }), + timeToRefresh > 0 ? timeToRefresh : 0 + ); + } } export class Sequencer { @@ -159,3 +182,17 @@ export class Sequencer { return this.current = this.current.then(() => promiseTask(), () => promiseTask()); } } + +class DelayerByKey { + private _delayers = new Map>(); + + trigger(key: string, fn: () => Promise, delay: number): Promise { + let delayer = this._delayers.get(key); + if (!delayer) { + delayer = new Delayer(delay); + this._delayers.set(key, delayer); + } + + return delayer.trigger(fn, delay); + } +}