diff --git a/extensions/microsoft-authentication/media/auth.html b/extensions/microsoft-authentication/media/index.html
similarity index 93%
rename from extensions/microsoft-authentication/media/auth.html
rename to extensions/microsoft-authentication/media/index.html
index 0fcba4e3c62..9c0a9eec080 100644
--- a/extensions/microsoft-authentication/media/auth.html
+++ b/extensions/microsoft-authentication/media/index.html
@@ -1,10 +1,9 @@
-
+
-
Azure Account - Sign In
diff --git a/extensions/microsoft-authentication/src/AADHelper.ts b/extensions/microsoft-authentication/src/AADHelper.ts
index 72408f855a1..6b7d1a85131 100644
--- a/extensions/microsoft-authentication/src/AADHelper.ts
+++ b/extensions/microsoft-authentication/src/AADHelper.ts
@@ -10,12 +10,13 @@ import * as vscode from 'vscode';
import * as nls from 'vscode-nls';
import { v4 as uuid } from 'uuid';
import fetch, { Response } from 'node-fetch';
-import { createServer, startServer } from './authServer';
import { Keychain } from './keychain';
import Logger from './logger';
import { toBase64UrlEncoding } from './utils';
import { sha256 } from './env/node/sha256';
import { BetterTokenStorage, IDidChangeInOtherWindowEvent } from './betterSecretStorage';
+import { LoopbackAuthServer } from './authServer';
+import path = require('path');
const localize = nls.loadMessageBundle();
@@ -238,63 +239,42 @@ export class AzureActiveDirectoryService {
}
private async createSessionWithLocalServer(scopeData: IScopeData) {
- const nonce = randomBytes(16).toString('base64');
- const { server, redirectPromise, codePromise } = createServer(nonce);
+ const codeVerifier = toBase64UrlEncoding(randomBytes(32).toString('base64'));
+ const codeChallenge = toBase64UrlEncoding(await sha256(codeVerifier));
+ const qs = querystring.stringify({
+ response_type: 'code',
+ response_mode: 'query',
+ client_id: scopeData.clientId,
+ redirect_uri: redirectUrl,
+ scope: scopeData.scopesToSend,
+ prompt: 'select_account',
+ code_challenge_method: 'S256',
+ code_challenge: codeChallenge,
+ });
+ const loginUrl = `${loginEndpointUrl}${scopeData.tenant}/oauth2/v2.0/authorize?${qs}`;
+ const server = new LoopbackAuthServer(path.join(__dirname, '../media'), loginUrl);
+ await server.start();
+ server.state = `${server.port},${encodeURIComponent(server.nonce)}`;
- let token: IToken | undefined;
+ let codeToExchange;
try {
- const port = await startServer(server);
- vscode.env.openExternal(vscode.Uri.parse(`http://localhost:${port}/signin?nonce=${encodeURIComponent(nonce)}`));
-
- const redirectReq = await redirectPromise;
- if ('err' in redirectReq) {
- const { err, res } = redirectReq;
- res.writeHead(302, { Location: `/?error=${encodeURIComponent(err && err.message || 'Unknown error')}` });
- res.end();
- throw err;
- }
-
- const host = redirectReq.req.headers.host || '';
- const updatedPortStr = (/^[^:]+:(\d+)$/.exec(Array.isArray(host) ? host[0] : host) || [])[1];
- const updatedPort = updatedPortStr ? parseInt(updatedPortStr, 10) : port;
-
- const state = `${updatedPort},${encodeURIComponent(nonce)}`;
-
- const codeVerifier = toBase64UrlEncoding(randomBytes(32).toString('base64'));
- const codeChallenge = toBase64UrlEncoding(await sha256(codeVerifier));
-
- const loginUrl = `${loginEndpointUrl}${scopeData.tenant}/oauth2/v2.0/authorize?response_type=code&response_mode=query&client_id=${encodeURIComponent(scopeData.clientId)}&redirect_uri=${encodeURIComponent(redirectUrl)}&state=${state}&scope=${encodeURIComponent(scopeData.scopesToSend)}&prompt=select_account&code_challenge_method=S256&code_challenge=${codeChallenge}`;
-
- redirectReq.res.writeHead(302, { Location: loginUrl });
- redirectReq.res.end();
-
- const codeRes = await codePromise;
- const res = codeRes.res;
-
- try {
- if ('err' in codeRes) {
- throw codeRes.err;
- }
- token = await this.exchangeCodeForToken(codeRes.code, codeVerifier, scopeData);
- if (token.expiresIn) {
- this.setSessionTimeout(token.sessionId, token.refreshToken, scopeData, token.expiresIn * AzureActiveDirectoryService.REFRESH_TIMEOUT_MODIFIER);
- }
- await this.setToken(token, scopeData);
- Logger.info(`Login successful for scopes: ${scopeData.scopeStr}`);
- res.writeHead(302, { Location: '/' });
- const session = await this.convertToSession(token);
- return session;
- } catch (err) {
- res.writeHead(302, { Location: `/?error=${encodeURIComponent(err && err.message || 'Unknown error')}` });
- throw err;
- } finally {
- res.end();
- }
+ vscode.env.openExternal(vscode.Uri.parse(`http://localhost:${server.port}/signin?nonce=${encodeURIComponent(server.nonce)}`));
+ const { code } = await server.waitForOAuthResponse();
+ codeToExchange = code;
} finally {
setTimeout(() => {
- server.close();
+ void server.stop();
}, 5000);
}
+
+ const token = await this.exchangeCodeForToken(codeToExchange, codeVerifier, scopeData);
+ if (token.expiresIn) {
+ this.setSessionTimeout(token.sessionId, token.refreshToken, scopeData, token.expiresIn * AzureActiveDirectoryService.REFRESH_TIMEOUT_MODIFIER);
+ }
+ await this.setToken(token, scopeData);
+ Logger.info(`Login successful for scopes: ${scopeData.scopeStr}`);
+ const session = await this.convertToSession(token);
+ return session;
}
private async createSessionWithoutLocalServer(scopeData: IScopeData): Promise {
diff --git a/extensions/microsoft-authentication/src/authServer.ts b/extensions/microsoft-authentication/src/authServer.ts
index 09dec9d4619..158d0db257f 100644
--- a/extensions/microsoft-authentication/src/authServer.ts
+++ b/extensions/microsoft-authentication/src/authServer.ts
@@ -2,65 +2,13 @@
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
-
import * as http from 'http';
-import * as url from 'url';
+import { URL } from 'url';
import * as fs from 'fs';
import * as path from 'path';
+import { randomBytes } from 'crypto';
-interface Deferred {
- resolve: (result: T | Promise) => void;
- reject: (reason: any) => void;
-}
-
-/**
- * Asserts that the argument passed in is neither undefined nor null.
- */
-function assertIsDefined(arg: T | null | undefined): T {
- if (typeof (arg) === 'undefined' || arg === null) {
- throw new Error('Assertion Failed: argument is undefined or null');
- }
-
- return arg;
-}
-
-export async function startServer(server: http.Server): Promise {
- let portTimer: NodeJS.Timer;
-
- function cancelPortTimer() {
- clearTimeout(portTimer);
- }
-
- const port = new Promise((resolve, reject) => {
- portTimer = setTimeout(() => {
- reject(new Error('Timeout waiting for port'));
- }, 5000);
-
- server.on('listening', () => {
- const address = server.address();
- if (typeof address === 'string') {
- resolve(address);
- } else {
- resolve(assertIsDefined(address).port.toString());
- }
- });
-
- server.on('error', _ => {
- reject(new Error('Error listening to server'));
- });
-
- server.on('close', () => {
- reject(new Error('Closed'));
- });
-
- server.listen(0, '127.0.0.1');
- });
-
- port.then(cancelPortTimer, cancelPortTimer);
- return port;
-}
-
-function sendFile(res: http.ServerResponse, filepath: string, contentType: string) {
+function sendFile(res: http.ServerResponse, filepath: string) {
fs.readFile(filepath, (err, body) => {
if (err) {
console.error(err);
@@ -68,89 +16,173 @@ function sendFile(res: http.ServerResponse, filepath: string, contentType: strin
res.end();
} else {
res.writeHead(200, {
- 'Content-Length': body.length,
- 'Content-Type': contentType
+ 'content-length': body.length,
});
res.end(body);
}
});
}
-async function callback(nonce: string, reqUrl: url.Url): Promise {
- const query = reqUrl.query;
- if (!query || typeof query === 'string') {
- throw new Error('No query received.');
- }
-
- let error = query.error_description || query.error;
-
- if (!error) {
- const state = (query.state as string) || '';
- const receivedNonce = (state.split(',')[1] || '').replace(/ /g, '+');
- if (receivedNonce !== nonce) {
- error = 'Nonce does not match.';
- }
- }
-
- const code = query.code as string;
- if (!error && code) {
- return code;
- }
-
- throw new Error((error as string) || 'No code received.');
+interface IOAuthResult {
+ code: string;
+ state: string;
}
-export function createServer(nonce: string) {
- type RedirectResult = { req: http.IncomingMessage; res: http.ServerResponse } | { err: any; res: http.ServerResponse };
- let deferredRedirect: Deferred;
- const redirectPromise = new Promise((resolve, reject) => deferredRedirect = { resolve, reject });
+interface ILoopbackServer {
+ /**
+ * If undefined, the server is not started yet.
+ */
+ port: number | undefined;
- type CodeResult = { code: string; res: http.ServerResponse } | { err: any; res: http.ServerResponse };
- let deferredCode: Deferred;
- const codePromise = new Promise((resolve, reject) => deferredCode = { resolve, reject });
+ /**
+ * The nonce used
+ */
+ nonce: string;
- const codeTimer = setTimeout(() => {
- deferredCode.reject(new Error('Timeout waiting for code'));
- }, 5 * 60 * 1000);
+ /**
+ * The state parameter used in the OAuth flow.
+ */
+ state: string | undefined;
- function cancelCodeTimer() {
- clearTimeout(codeTimer);
+ /**
+ * Starts the server.
+ * @returns The port to listen on.
+ * @throws If the server fails to start.
+ * @throws If the server is already started.
+ */
+ start(): Promise;
+ /**
+ * Stops the server.
+ * @throws If the server is not started.
+ * @throws If the server fails to stop.
+ */
+ stop(): Promise;
+ /**
+ * Returns a promise that resolves to the result of the OAuth flow.
+ */
+ waitForOAuthResponse(): Promise;
+}
+
+export class LoopbackAuthServer implements ILoopbackServer {
+ private readonly _server: http.Server;
+ private readonly _resultPromise: Promise;
+ private _startingRedirect: URL;
+
+ public nonce = randomBytes(16).toString('base64');
+ public port: number | undefined;
+
+ public set state(state: string | undefined) {
+ if (state) {
+ this._startingRedirect.searchParams.set('state', state);
+ } else {
+ this._startingRedirect.searchParams.delete('state');
+ }
+ }
+ public get state(): string | undefined {
+ return this._startingRedirect.searchParams.get('state') ?? undefined;
}
- const server = http.createServer(function (req, res) {
- const reqUrl = url.parse(req.url!, /* parseQueryString */ true);
- switch (reqUrl.pathname) {
- case '/signin': {
- const receivedNonce = ((reqUrl.query.nonce as string) || '').replace(/ /g, '+');
- if (receivedNonce === nonce) {
- deferredRedirect.resolve({ req, res });
- } else {
- const err = new Error('Nonce does not match.');
- deferredRedirect.resolve({ err, res });
+ constructor(serveRoot: string, startingRedirect: string) {
+ if (!serveRoot) {
+ throw new Error('serveRoot must be defined');
+ }
+ if (!startingRedirect) {
+ throw new Error('startingRedirect must be defined');
+ }
+ this._startingRedirect = new URL(startingRedirect);
+ let deferred: { resolve: (result: IOAuthResult) => void; reject: (reason: any) => void };
+ this._resultPromise = new Promise((resolve, reject) => deferred = { resolve, reject });
+
+ this._server = http.createServer((req, res) => {
+ const reqUrl = new URL(req.url!, `http://${req.headers.host}`);
+ switch (reqUrl.pathname) {
+ case '/signin': {
+ const receivedNonce = (reqUrl.searchParams.get('nonce') ?? '').replace(/ /g, '+');
+ if (receivedNonce !== this.nonce) {
+ res.writeHead(302, { location: `/?error=${encodeURIComponent('Nonce does not match.')}` });
+ res.end();
+ }
+ res.writeHead(302, { location: this._startingRedirect.toString() });
+ res.end();
+ break;
}
- break;
+ case '/callback': {
+ const code = reqUrl.searchParams.get('code') ?? undefined;
+ const state = reqUrl.searchParams.get('state') ?? undefined;
+ if (!code || !state) {
+ res.writeHead(400);
+ res.end();
+ return;
+ }
+ if (this.state !== state) {
+ res.writeHead(302, { location: `/?error=${encodeURIComponent('State does not match.')}` });
+ res.end();
+ throw new Error('State does not match.');
+ }
+ deferred.resolve({ code, state });
+ res.writeHead(302, { location: '/' });
+ res.end();
+ break;
+ }
+ // Serve the static files
+ case '/':
+ sendFile(res, path.join(serveRoot, 'index.html'));
+ break;
+ default:
+ // substring to get rid of leading '/'
+ sendFile(res, path.join(serveRoot, reqUrl.pathname.substring(1)));
+ break;
}
- case '/':
- sendFile(res, path.join(__dirname, '../media/auth.html'), 'text/html; charset=utf-8');
- break;
- case '/auth.css':
- sendFile(res, path.join(__dirname, '../media/auth.css'), 'text/css; charset=utf-8');
- break;
- case '/callback':
- deferredCode.resolve(callback(nonce, reqUrl)
- .then(code => ({ code, res }), err => ({ err, res })));
- break;
- default:
- res.writeHead(404);
- res.end();
- break;
- }
- });
+ });
+ }
- codePromise.then(cancelCodeTimer, cancelCodeTimer);
- return {
- server,
- redirectPromise,
- codePromise
- };
+ public start(): Promise {
+ return new Promise((resolve, reject) => {
+ if (this._server.listening) {
+ throw new Error('Server is already started');
+ }
+ const portTimeout = setTimeout(() => {
+ reject(new Error('Timeout waiting for port'));
+ }, 5000);
+ this._server.on('listening', () => {
+ const address = this._server.address();
+ if (typeof address === 'string') {
+ this.port = parseInt(address);
+ } else if (address instanceof Object) {
+ this.port = address.port;
+ } else {
+ throw new Error('Unable to determine port');
+ }
+
+ clearTimeout(portTimeout);
+ resolve(this.port);
+ });
+ this._server.on('error', err => {
+ reject(new Error(`Error listening to server: ${err}`));
+ });
+ this._server.on('close', () => {
+ reject(new Error('Closed'));
+ });
+ this._server.listen(0, '127.0.0.1');
+ });
+ }
+
+ public stop(): Promise {
+ return new Promise((resolve, reject) => {
+ if (!this._server.listening) {
+ throw new Error('Server is not started');
+ }
+ this._server.close((err) => {
+ if (err) {
+ reject(err);
+ } else {
+ resolve();
+ }
+ });
+ });
+ }
+
+ public waitForOAuthResponse(): Promise {
+ return this._resultPromise;
+ }
}