diff --git a/ts/textsecure/WebAPI.ts b/ts/textsecure/WebAPI.ts index b68d3cb968..ce673ca9d3 100644 --- a/ts/textsecure/WebAPI.ts +++ b/ts/textsecure/WebAPI.ts @@ -28,6 +28,7 @@ import { formatAcceptLanguageHeader } from '../util/userLanguages'; import { toWebSafeBase64, fromWebSafeBase64 } from '../util/webSafeBase64'; import { getBasicAuth } from '../util/getBasicAuth'; import { isPnpEnabled } from '../util/isPnpEnabled'; +import { lookupWithFallback } from '../util/dns'; import type { SocketStatus } from '../types/SocketStatus'; import { toLogFormat } from '../types/errors'; import { isPackIdValid, redactPackId } from '../types/Stickers'; @@ -243,6 +244,7 @@ async function _promiseAjax( agent: proxyUrl ? new ProxyAgent(proxyUrl) : new Agent({ + lookup: lookupWithFallback, keepAlive: !options.disableSessionResumption, maxCachedSessions: options.disableSessionResumption ? 0 : undefined, }), diff --git a/ts/textsecure/WebSocket.ts b/ts/textsecure/WebSocket.ts index 3a0dc19caf..f90ede37fe 100644 --- a/ts/textsecure/WebSocket.ts +++ b/ts/textsecure/WebSocket.ts @@ -10,6 +10,7 @@ import { strictAssert } from '../util/assert'; import { explodePromise } from '../util/explodePromise'; import { getUserAgent } from '../util/getUserAgent'; import * as durations from '../util/durations'; +import { lookupWithFallback } from '../util/dns'; import * as log from '../logging/log'; import * as Timers from '../Timers'; import { ConnectTimeoutError, HTTPError } from './Errors'; @@ -55,6 +56,7 @@ export function connect({ tlsOptions: { ca: certificateAuthority, agent: proxyAgent, + lookup: lookupWithFallback, }, maxReceivedFrameSize: 0x210000, }); diff --git a/ts/updater/got.ts b/ts/updater/got.ts index 1572eb0283..faea141818 100644 --- a/ts/updater/got.ts +++ b/ts/updater/got.ts @@ -8,6 +8,7 @@ import ProxyAgent from 'proxy-agent'; import * as packageJson from '../../package.json'; import { getUserAgent } from '../util/getUserAgent'; import * as durations from '../util/durations'; +import { lookupWithFallback } from '../util/dns'; export const GOT_CONNECT_TIMEOUT = durations.MINUTE; export const GOT_LOOKUP_TIMEOUT = durations.MINUTE; @@ -37,6 +38,7 @@ export function getGotOptions(): GotOptions { https: { certificateAuthority, }, + lookup: lookupWithFallback as GotOptions['lookup'], headers: { 'Cache-Control': 'no-cache', 'User-Agent': getUserAgent(packageJson.version), diff --git a/ts/util/dns.ts b/ts/util/dns.ts new file mode 100644 index 0000000000..fbb1a6b4fc --- /dev/null +++ b/ts/util/dns.ts @@ -0,0 +1,62 @@ +// Copyright 2023 Signal Messenger, LLC +// SPDX-License-Identifier: AGPL-3.0-only + +import { lookup as nativeLookup, resolve4, resolve6 } from 'dns'; +import type { LookupOneOptions } from 'dns'; + +import * as log from '../logging/log'; +import * as Errors from '../types/errors'; +import { strictAssert } from './assert'; + +export function lookupWithFallback( + hostname: string, + opts: LookupOneOptions, + callback: ( + err: NodeJS.ErrnoException | null, + address: string, + family: number + ) => void +): void { + // Node.js support various signatures, but we only support one. + strictAssert(typeof opts === 'object', 'missing options'); + strictAssert(Boolean(opts.all) !== true, 'options.all is not supported'); + strictAssert(typeof callback === 'function', 'missing callback'); + + nativeLookup(hostname, opts, (err, ...nativeArgs) => { + if (!err) { + return callback(err, ...nativeArgs); + } + + const family = opts.family === 6 ? 6 : 4; + + log.error( + `lookup: failed for ${hostname}, error: ${Errors.toLogFormat(err)}. ` + + `Retrying with c-ares (IPv${family})` + ); + const onRecords = ( + fallbackErr: NodeJS.ErrnoException | null, + records: Array + ): void => { + if (fallbackErr) { + return callback(fallbackErr, '', 0); + } + + if (!Array.isArray(records) || records.length === 0) { + return callback( + new Error(`No DNS records returned for: ${hostname}`), + '', + 0 + ); + } + + const index = Math.floor(Math.random() * records.length); + callback(null, records[index], family); + }; + + if (family === 4) { + resolve4(hostname, onRecords); + } else { + resolve6(hostname, onRecords); + } + }); +}