diff --git a/extensions/microsoft-authentication/media/auth.html b/extensions/microsoft-authentication/media/index.html similarity index 93% rename from extensions/microsoft-authentication/media/auth.html rename to extensions/microsoft-authentication/media/index.html index 0fcba4e3c62..9c0a9eec080 100644 --- a/extensions/microsoft-authentication/media/auth.html +++ b/extensions/microsoft-authentication/media/index.html @@ -1,10 +1,9 @@ - + - Azure Account - Sign In diff --git a/extensions/microsoft-authentication/src/AADHelper.ts b/extensions/microsoft-authentication/src/AADHelper.ts index 72408f855a1..6b7d1a85131 100644 --- a/extensions/microsoft-authentication/src/AADHelper.ts +++ b/extensions/microsoft-authentication/src/AADHelper.ts @@ -10,12 +10,13 @@ import * as vscode from 'vscode'; import * as nls from 'vscode-nls'; import { v4 as uuid } from 'uuid'; import fetch, { Response } from 'node-fetch'; -import { createServer, startServer } from './authServer'; import { Keychain } from './keychain'; import Logger from './logger'; import { toBase64UrlEncoding } from './utils'; import { sha256 } from './env/node/sha256'; import { BetterTokenStorage, IDidChangeInOtherWindowEvent } from './betterSecretStorage'; +import { LoopbackAuthServer } from './authServer'; +import path = require('path'); const localize = nls.loadMessageBundle(); @@ -238,63 +239,42 @@ export class AzureActiveDirectoryService { } private async createSessionWithLocalServer(scopeData: IScopeData) { - const nonce = randomBytes(16).toString('base64'); - const { server, redirectPromise, codePromise } = createServer(nonce); + const codeVerifier = toBase64UrlEncoding(randomBytes(32).toString('base64')); + const codeChallenge = toBase64UrlEncoding(await sha256(codeVerifier)); + const qs = querystring.stringify({ + response_type: 'code', + response_mode: 'query', + client_id: scopeData.clientId, + redirect_uri: redirectUrl, + scope: scopeData.scopesToSend, + prompt: 'select_account', + code_challenge_method: 'S256', + code_challenge: codeChallenge, + }); + const loginUrl = `${loginEndpointUrl}${scopeData.tenant}/oauth2/v2.0/authorize?${qs}`; + const server = new LoopbackAuthServer(path.join(__dirname, '../media'), loginUrl); + await server.start(); + server.state = `${server.port},${encodeURIComponent(server.nonce)}`; - let token: IToken | undefined; + let codeToExchange; try { - const port = await startServer(server); - vscode.env.openExternal(vscode.Uri.parse(`http://localhost:${port}/signin?nonce=${encodeURIComponent(nonce)}`)); - - const redirectReq = await redirectPromise; - if ('err' in redirectReq) { - const { err, res } = redirectReq; - res.writeHead(302, { Location: `/?error=${encodeURIComponent(err && err.message || 'Unknown error')}` }); - res.end(); - throw err; - } - - const host = redirectReq.req.headers.host || ''; - const updatedPortStr = (/^[^:]+:(\d+)$/.exec(Array.isArray(host) ? host[0] : host) || [])[1]; - const updatedPort = updatedPortStr ? parseInt(updatedPortStr, 10) : port; - - const state = `${updatedPort},${encodeURIComponent(nonce)}`; - - const codeVerifier = toBase64UrlEncoding(randomBytes(32).toString('base64')); - const codeChallenge = toBase64UrlEncoding(await sha256(codeVerifier)); - - const loginUrl = `${loginEndpointUrl}${scopeData.tenant}/oauth2/v2.0/authorize?response_type=code&response_mode=query&client_id=${encodeURIComponent(scopeData.clientId)}&redirect_uri=${encodeURIComponent(redirectUrl)}&state=${state}&scope=${encodeURIComponent(scopeData.scopesToSend)}&prompt=select_account&code_challenge_method=S256&code_challenge=${codeChallenge}`; - - redirectReq.res.writeHead(302, { Location: loginUrl }); - redirectReq.res.end(); - - const codeRes = await codePromise; - const res = codeRes.res; - - try { - if ('err' in codeRes) { - throw codeRes.err; - } - token = await this.exchangeCodeForToken(codeRes.code, codeVerifier, scopeData); - if (token.expiresIn) { - this.setSessionTimeout(token.sessionId, token.refreshToken, scopeData, token.expiresIn * AzureActiveDirectoryService.REFRESH_TIMEOUT_MODIFIER); - } - await this.setToken(token, scopeData); - Logger.info(`Login successful for scopes: ${scopeData.scopeStr}`); - res.writeHead(302, { Location: '/' }); - const session = await this.convertToSession(token); - return session; - } catch (err) { - res.writeHead(302, { Location: `/?error=${encodeURIComponent(err && err.message || 'Unknown error')}` }); - throw err; - } finally { - res.end(); - } + vscode.env.openExternal(vscode.Uri.parse(`http://localhost:${server.port}/signin?nonce=${encodeURIComponent(server.nonce)}`)); + const { code } = await server.waitForOAuthResponse(); + codeToExchange = code; } finally { setTimeout(() => { - server.close(); + void server.stop(); }, 5000); } + + const token = await this.exchangeCodeForToken(codeToExchange, codeVerifier, scopeData); + if (token.expiresIn) { + this.setSessionTimeout(token.sessionId, token.refreshToken, scopeData, token.expiresIn * AzureActiveDirectoryService.REFRESH_TIMEOUT_MODIFIER); + } + await this.setToken(token, scopeData); + Logger.info(`Login successful for scopes: ${scopeData.scopeStr}`); + const session = await this.convertToSession(token); + return session; } private async createSessionWithoutLocalServer(scopeData: IScopeData): Promise { diff --git a/extensions/microsoft-authentication/src/authServer.ts b/extensions/microsoft-authentication/src/authServer.ts index 09dec9d4619..158d0db257f 100644 --- a/extensions/microsoft-authentication/src/authServer.ts +++ b/extensions/microsoft-authentication/src/authServer.ts @@ -2,65 +2,13 @@ * Copyright (c) Microsoft Corporation. All rights reserved. * Licensed under the MIT License. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ - import * as http from 'http'; -import * as url from 'url'; +import { URL } from 'url'; import * as fs from 'fs'; import * as path from 'path'; +import { randomBytes } from 'crypto'; -interface Deferred { - resolve: (result: T | Promise) => void; - reject: (reason: any) => void; -} - -/** - * Asserts that the argument passed in is neither undefined nor null. - */ -function assertIsDefined(arg: T | null | undefined): T { - if (typeof (arg) === 'undefined' || arg === null) { - throw new Error('Assertion Failed: argument is undefined or null'); - } - - return arg; -} - -export async function startServer(server: http.Server): Promise { - let portTimer: NodeJS.Timer; - - function cancelPortTimer() { - clearTimeout(portTimer); - } - - const port = new Promise((resolve, reject) => { - portTimer = setTimeout(() => { - reject(new Error('Timeout waiting for port')); - }, 5000); - - server.on('listening', () => { - const address = server.address(); - if (typeof address === 'string') { - resolve(address); - } else { - resolve(assertIsDefined(address).port.toString()); - } - }); - - server.on('error', _ => { - reject(new Error('Error listening to server')); - }); - - server.on('close', () => { - reject(new Error('Closed')); - }); - - server.listen(0, '127.0.0.1'); - }); - - port.then(cancelPortTimer, cancelPortTimer); - return port; -} - -function sendFile(res: http.ServerResponse, filepath: string, contentType: string) { +function sendFile(res: http.ServerResponse, filepath: string) { fs.readFile(filepath, (err, body) => { if (err) { console.error(err); @@ -68,89 +16,173 @@ function sendFile(res: http.ServerResponse, filepath: string, contentType: strin res.end(); } else { res.writeHead(200, { - 'Content-Length': body.length, - 'Content-Type': contentType + 'content-length': body.length, }); res.end(body); } }); } -async function callback(nonce: string, reqUrl: url.Url): Promise { - const query = reqUrl.query; - if (!query || typeof query === 'string') { - throw new Error('No query received.'); - } - - let error = query.error_description || query.error; - - if (!error) { - const state = (query.state as string) || ''; - const receivedNonce = (state.split(',')[1] || '').replace(/ /g, '+'); - if (receivedNonce !== nonce) { - error = 'Nonce does not match.'; - } - } - - const code = query.code as string; - if (!error && code) { - return code; - } - - throw new Error((error as string) || 'No code received.'); +interface IOAuthResult { + code: string; + state: string; } -export function createServer(nonce: string) { - type RedirectResult = { req: http.IncomingMessage; res: http.ServerResponse } | { err: any; res: http.ServerResponse }; - let deferredRedirect: Deferred; - const redirectPromise = new Promise((resolve, reject) => deferredRedirect = { resolve, reject }); +interface ILoopbackServer { + /** + * If undefined, the server is not started yet. + */ + port: number | undefined; - type CodeResult = { code: string; res: http.ServerResponse } | { err: any; res: http.ServerResponse }; - let deferredCode: Deferred; - const codePromise = new Promise((resolve, reject) => deferredCode = { resolve, reject }); + /** + * The nonce used + */ + nonce: string; - const codeTimer = setTimeout(() => { - deferredCode.reject(new Error('Timeout waiting for code')); - }, 5 * 60 * 1000); + /** + * The state parameter used in the OAuth flow. + */ + state: string | undefined; - function cancelCodeTimer() { - clearTimeout(codeTimer); + /** + * Starts the server. + * @returns The port to listen on. + * @throws If the server fails to start. + * @throws If the server is already started. + */ + start(): Promise; + /** + * Stops the server. + * @throws If the server is not started. + * @throws If the server fails to stop. + */ + stop(): Promise; + /** + * Returns a promise that resolves to the result of the OAuth flow. + */ + waitForOAuthResponse(): Promise; +} + +export class LoopbackAuthServer implements ILoopbackServer { + private readonly _server: http.Server; + private readonly _resultPromise: Promise; + private _startingRedirect: URL; + + public nonce = randomBytes(16).toString('base64'); + public port: number | undefined; + + public set state(state: string | undefined) { + if (state) { + this._startingRedirect.searchParams.set('state', state); + } else { + this._startingRedirect.searchParams.delete('state'); + } + } + public get state(): string | undefined { + return this._startingRedirect.searchParams.get('state') ?? undefined; } - const server = http.createServer(function (req, res) { - const reqUrl = url.parse(req.url!, /* parseQueryString */ true); - switch (reqUrl.pathname) { - case '/signin': { - const receivedNonce = ((reqUrl.query.nonce as string) || '').replace(/ /g, '+'); - if (receivedNonce === nonce) { - deferredRedirect.resolve({ req, res }); - } else { - const err = new Error('Nonce does not match.'); - deferredRedirect.resolve({ err, res }); + constructor(serveRoot: string, startingRedirect: string) { + if (!serveRoot) { + throw new Error('serveRoot must be defined'); + } + if (!startingRedirect) { + throw new Error('startingRedirect must be defined'); + } + this._startingRedirect = new URL(startingRedirect); + let deferred: { resolve: (result: IOAuthResult) => void; reject: (reason: any) => void }; + this._resultPromise = new Promise((resolve, reject) => deferred = { resolve, reject }); + + this._server = http.createServer((req, res) => { + const reqUrl = new URL(req.url!, `http://${req.headers.host}`); + switch (reqUrl.pathname) { + case '/signin': { + const receivedNonce = (reqUrl.searchParams.get('nonce') ?? '').replace(/ /g, '+'); + if (receivedNonce !== this.nonce) { + res.writeHead(302, { location: `/?error=${encodeURIComponent('Nonce does not match.')}` }); + res.end(); + } + res.writeHead(302, { location: this._startingRedirect.toString() }); + res.end(); + break; } - break; + case '/callback': { + const code = reqUrl.searchParams.get('code') ?? undefined; + const state = reqUrl.searchParams.get('state') ?? undefined; + if (!code || !state) { + res.writeHead(400); + res.end(); + return; + } + if (this.state !== state) { + res.writeHead(302, { location: `/?error=${encodeURIComponent('State does not match.')}` }); + res.end(); + throw new Error('State does not match.'); + } + deferred.resolve({ code, state }); + res.writeHead(302, { location: '/' }); + res.end(); + break; + } + // Serve the static files + case '/': + sendFile(res, path.join(serveRoot, 'index.html')); + break; + default: + // substring to get rid of leading '/' + sendFile(res, path.join(serveRoot, reqUrl.pathname.substring(1))); + break; } - case '/': - sendFile(res, path.join(__dirname, '../media/auth.html'), 'text/html; charset=utf-8'); - break; - case '/auth.css': - sendFile(res, path.join(__dirname, '../media/auth.css'), 'text/css; charset=utf-8'); - break; - case '/callback': - deferredCode.resolve(callback(nonce, reqUrl) - .then(code => ({ code, res }), err => ({ err, res }))); - break; - default: - res.writeHead(404); - res.end(); - break; - } - }); + }); + } - codePromise.then(cancelCodeTimer, cancelCodeTimer); - return { - server, - redirectPromise, - codePromise - }; + public start(): Promise { + return new Promise((resolve, reject) => { + if (this._server.listening) { + throw new Error('Server is already started'); + } + const portTimeout = setTimeout(() => { + reject(new Error('Timeout waiting for port')); + }, 5000); + this._server.on('listening', () => { + const address = this._server.address(); + if (typeof address === 'string') { + this.port = parseInt(address); + } else if (address instanceof Object) { + this.port = address.port; + } else { + throw new Error('Unable to determine port'); + } + + clearTimeout(portTimeout); + resolve(this.port); + }); + this._server.on('error', err => { + reject(new Error(`Error listening to server: ${err}`)); + }); + this._server.on('close', () => { + reject(new Error('Closed')); + }); + this._server.listen(0, '127.0.0.1'); + }); + } + + public stop(): Promise { + return new Promise((resolve, reject) => { + if (!this._server.listening) { + throw new Error('Server is not started'); + } + this._server.close((err) => { + if (err) { + reject(err); + } else { + resolve(); + } + }); + }); + } + + public waitForOAuthResponse(): Promise { + return this._resultPromise; + } }