diff --git a/ts/LibSignalStores.ts b/ts/LibSignalStores.ts index c6c022de8a..94fce3d670 100644 --- a/ts/LibSignalStores.ts +++ b/ts/LibSignalStores.ts @@ -181,14 +181,17 @@ export class IdentityKeys extends IdentityKeyStore { export type PreKeysOptions = Readonly<{ ourServiceId: ServiceIdString; + zone?: Zone; }>; export class PreKeys extends PreKeyStore { readonly #ourServiceId: ServiceIdString; + readonly #zone: Zone | undefined; - constructor({ ourServiceId }: PreKeysOptions) { + constructor({ ourServiceId, zone }: PreKeysOptions) { super(); this.#ourServiceId = ourServiceId; + this.#zone = zone; } async savePreKey(): Promise { @@ -209,18 +212,22 @@ export class PreKeys extends PreKeyStore { } async removePreKey(id: number): Promise { - await window.textsecure.storage.protocol.removePreKeys(this.#ourServiceId, [ - id, - ]); + await window.textsecure.storage.protocol.removePreKeys( + this.#ourServiceId, + [id], + { zone: this.#zone } + ); } } export class KyberPreKeys extends KyberPreKeyStore { readonly #ourServiceId: ServiceIdString; + readonly #zone: Zone | undefined; - constructor({ ourServiceId }: PreKeysOptions) { + constructor({ ourServiceId, zone }: PreKeysOptions) { super(); this.#ourServiceId = ourServiceId; + this.#zone = zone; } async saveKyberPreKey(): Promise { @@ -244,7 +251,8 @@ export class KyberPreKeys extends KyberPreKeyStore { async markKyberPreKeyUsed(id: number): Promise { await window.textsecure.storage.protocol.maybeRemoveKyberPreKey( this.#ourServiceId, - id + id, + { zone: this.#zone } ); } } @@ -256,7 +264,6 @@ export type SenderKeysOptions = Readonly<{ export class SenderKeys extends SenderKeyStore { readonly #ourServiceId: ServiceIdString; - readonly zone: Zone | undefined; constructor({ ourServiceId, zone }: SenderKeysOptions) { @@ -300,6 +307,7 @@ export type SignedPreKeysOptions = Readonly<{ ourServiceId: ServiceIdString; }>; +// No need for zone awareness, since no mutation happens in this store export class SignedPreKeys extends SignedPreKeyStore { readonly #ourServiceId: ServiceIdString; diff --git a/ts/SignalProtocolStore.ts b/ts/SignalProtocolStore.ts index 00ba1b491f..11048b4378 100644 --- a/ts/SignalProtocolStore.ts +++ b/ts/SignalProtocolStore.ts @@ -256,6 +256,8 @@ export class SignalProtocolStore extends EventEmitter { #currentZone?: Zone; #currentZoneDepth = 0; readonly #zoneQueue: Array = []; + #pendingKyberPreKeysToRemove = new Set(); + #pendingPreKeysToRemove = new Set(); #pendingSessions = new Map(); #pendingSenderKeys = new Map(); #pendingUnprocessed = new Map(); @@ -393,6 +395,12 @@ export class SignalProtocolStore extends EventEmitter { keyId: number ): Promise { const id: PreKeyIdType = this.#_getKeyId(ourServiceId, keyId); + + if (this.#pendingKyberPreKeysToRemove.has(id)) { + log.error('Not returning kyberPreKey pending removal', id); + return undefined; + } + const entry = this.#_getKyberPreKeyEntry(id, 'loadKyberPreKey'); return entry?.item; @@ -489,7 +497,8 @@ export class SignalProtocolStore extends EventEmitter { async maybeRemoveKyberPreKey( ourServiceId: ServiceIdString, - keyId: number + keyId: number, + { zone = GLOBAL_ZONE }: SessionTransactionOptions = {} ): Promise { const id: PreKeyIdType = this.#_getKeyId(ourServiceId, keyId); const entry = this.#_getKyberPreKeyEntry(id, 'maybeRemoveKyberPreKey'); @@ -504,33 +513,37 @@ export class SignalProtocolStore extends EventEmitter { return; } - await this.removeKyberPreKeys(ourServiceId, [keyId]); + await this.removeKyberPreKeys(ourServiceId, [keyId], { zone }); } async removeKyberPreKeys( ourServiceId: ServiceIdString, - keyIds: Array + keyIds: Array, + { zone = GLOBAL_ZONE }: SessionTransactionOptions = {} ): Promise { - const kyberPreKeyCache = this.kyberPreKeys; - if (!kyberPreKeyCache) { - throw new Error('removeKyberPreKeys: this.kyberPreKeys not yet cached!'); - } + await this.withZone(zone, 'removeKyberPreKeys', async () => { + const kyberPreKeyCache = this.kyberPreKeys; + if (!kyberPreKeyCache) { + throw new Error( + 'removeKyberPreKeys: this.kyberPreKeys not yet cached!' + ); + } - const ids = keyIds.map(keyId => this.#_getKeyId(ourServiceId, keyId)); + const ids = keyIds.map(keyId => this.#_getKeyId(ourServiceId, keyId)); - log.info('removeKyberPreKeys: Removing kyber prekeys:', formatKeys(keyIds)); - const changes = await DataWriter.removeKyberPreKeyById(ids); - log.info(`removeKyberPreKeys: Removed ${changes} kyber prekeys`); - ids.forEach(id => { - kyberPreKeyCache.delete(id); - }); - - if (kyberPreKeyCache.size < LOW_KEYS_THRESHOLD) { - this.#emitLowKeys( - ourServiceId, - `removeKyberPreKeys@${kyberPreKeyCache.size}` + log.info( + `removeKyberPreKeys(${zone.name}): Will remove kyberPreKeys:`, + formatKeys(keyIds) ); - } + + ids.forEach(id => { + this.#pendingKyberPreKeysToRemove.add(id); + }); + + if (!zone.supportsPendingKyberPreKeysToRemove()) { + await this.#commitZoneChanges('removeKyberPreKeys'); + } + }); } async clearKyberPreKeyStore(): Promise { @@ -552,6 +565,11 @@ export class SignalProtocolStore extends EventEmitter { } const id: PreKeyIdType = this.#_getKeyId(ourServiceId, keyId); + if (this.#pendingPreKeysToRemove.has(id)) { + log.error('Not returning prekey pending removal', id); + return undefined; + } + const entry = this.preKeys.get(id); if (!entry) { log.error('Failed to fetch prekey:', id); @@ -630,26 +648,30 @@ export class SignalProtocolStore extends EventEmitter { async removePreKeys( ourServiceId: ServiceIdString, - keyIds: Array + keyIds: Array, + { zone = GLOBAL_ZONE }: SessionTransactionOptions = {} ): Promise { - const preKeyCache = this.preKeys; - if (!preKeyCache) { - throw new Error('removePreKeys: this.preKeys not yet cached!'); - } + await this.withZone(zone, 'removePreKeys', async () => { + const preKeyCache = this.preKeys; + if (!preKeyCache) { + throw new Error('removePreKeys: this.preKeys not yet cached!'); + } - const ids = keyIds.map(keyId => this.#_getKeyId(ourServiceId, keyId)); + const ids = keyIds.map(keyId => this.#_getKeyId(ourServiceId, keyId)); - log.info('removePreKeys: Removing prekeys:', formatKeys(keyIds)); + log.info( + `removePreKeys(${zone.name}): Will remove preKeys:`, + formatKeys(keyIds) + ); - const changes = await DataWriter.removePreKeyById(ids); - log.info(`removePreKeys: Removed ${changes} prekeys`); - ids.forEach(id => { - preKeyCache.delete(id); + ids.forEach(id => { + this.#pendingPreKeysToRemove.add(id); + }); + + if (!zone.supportsPendingPreKeysToRemove()) { + await this.#commitZoneChanges('removePreKeys'); + } }); - - if (preKeyCache.size < LOW_KEYS_THRESHOLD) { - this.#emitLowKeys(ourServiceId, `removePreKeys@${preKeyCache.size}`); - } } async clearPreKeyStore(): Promise { @@ -1123,11 +1145,15 @@ export class SignalProtocolStore extends EventEmitter { } async #commitZoneChanges(name: string): Promise { - const pendingUnprocessed = this.#pendingUnprocessed; + const pendingKyberPreKeysToRemove = this.#pendingKyberPreKeysToRemove; + const pendingPreKeysToRemove = this.#pendingPreKeysToRemove; const pendingSenderKeys = this.#pendingSenderKeys; const pendingSessions = this.#pendingSessions; + const pendingUnprocessed = this.#pendingUnprocessed; if ( + pendingKyberPreKeysToRemove.size === 0 && + pendingPreKeysToRemove.size === 0 && pendingSenderKeys.size === 0 && pendingSessions.size === 0 && pendingUnprocessed.size === 0 @@ -1137,11 +1163,15 @@ export class SignalProtocolStore extends EventEmitter { log.info( `commitZoneChanges(${name}): ` + - `pending sender keys ${pendingSenderKeys.size}, ` + + `pending kyberPreKeysToRemove ${pendingPreKeysToRemove.size}, ` + + `pending preKeysToRemove ${pendingKyberPreKeysToRemove.size}, ` + + `pending senderKeys ${pendingSenderKeys.size}, ` + `pending sessions ${pendingSessions.size}, ` + `pending unprocessed ${pendingUnprocessed.size}` ); + this.#pendingKyberPreKeysToRemove = new Set(); + this.#pendingPreKeysToRemove = new Set(); this.#pendingSenderKeys = new Map(); this.#pendingSessions = new Map(); this.#pendingUnprocessed = new Map(); @@ -1149,6 +1179,8 @@ export class SignalProtocolStore extends EventEmitter { // Commit both sender keys, sessions and unprocessed in the same database transaction // to unroll both on error. await DataWriter.commitDecryptResult({ + kyberPreKeysToRemove: Array.from(pendingKyberPreKeysToRemove.values()), + preKeysToRemove: Array.from(pendingPreKeysToRemove.values()), senderKeys: Array.from(pendingSenderKeys.values()).map( ({ fromDB }) => fromDB ), @@ -1160,14 +1192,26 @@ export class SignalProtocolStore extends EventEmitter { // Apply changes to in-memory storage after successful DB write. - const { sessions } = this; + const { kyberPreKeys } = this; assertDev( - sessions !== undefined, - "Can't commit unhydrated session storage" + kyberPreKeys !== undefined, + "Can't commit unhydrated kyberPreKeys storage" ); - pendingSessions.forEach((value, key) => { - sessions.set(key, value); + pendingKyberPreKeysToRemove.forEach(value => { + kyberPreKeys.delete(value); }); + if (kyberPreKeys.size < LOW_KEYS_THRESHOLD) { + this.#emitLowKeys(`removeKyberPreKeys@${kyberPreKeys.size}`); + } + + const { preKeys } = this; + assertDev(preKeys !== undefined, "Can't commit unhydrated preKeys storage"); + pendingPreKeysToRemove.forEach(value => { + preKeys.delete(value); + }); + if (preKeys.size < LOW_KEYS_THRESHOLD) { + this.#emitLowKeys(`removePreKeys@${preKeys.size}`); + } const { senderKeys } = this; assertDev( @@ -1177,16 +1221,29 @@ export class SignalProtocolStore extends EventEmitter { pendingSenderKeys.forEach((value, key) => { senderKeys.set(key, value); }); + + const { sessions } = this; + assertDev( + sessions !== undefined, + "Can't commit unhydrated session storage" + ); + pendingSessions.forEach((value, key) => { + sessions.set(key, value); + }); } async #revertZoneChanges(name: string, error: Error): Promise { log.info( `revertZoneChanges(${name}): ` + - `pending sender keys size ${this.#pendingSenderKeys.size}, ` + + `pending kyberPreKeysToRemove size ${this.#pendingKyberPreKeysToRemove.size}, ` + + `pending preKeysToRemove size ${this.#pendingPreKeysToRemove.size}, ` + + `pending senderKeys size ${this.#pendingSenderKeys.size}, ` + `pending sessions size ${this.#pendingSessions.size}, ` + `pending unprocessed size ${this.#pendingUnprocessed.size}`, Errors.toLogFormat(error) ); + this.#pendingKyberPreKeysToRemove.clear(); + this.#pendingPreKeysToRemove.clear(); this.#pendingSenderKeys.clear(); this.#pendingSessions.clear(); this.#pendingUnprocessed.clear(); @@ -2649,11 +2706,11 @@ export class SignalProtocolStore extends EventEmitter { return Array.from(union.values()); } - #emitLowKeys(ourServiceId: ServiceIdString, source: string) { + #emitLowKeys(source: string) { const logId = `SignalProtocolStore.emitLowKeys/${source}:`; try { log.info(`${logId}: Emitting event`); - this.emit('lowKeys', ourServiceId); + this.emit('lowKeys'); } catch (error) { log.error(`${logId}: Error thrown from emit`, Errors.toLogFormat(error)); } @@ -2663,10 +2720,7 @@ export class SignalProtocolStore extends EventEmitter { // EventEmitter types // - public override on( - name: 'lowKeys', - handler: (ourServiceId: ServiceIdString) => unknown - ): this; + public override on(name: 'lowKeys', handler: () => unknown): this; public override on( name: 'keychange', @@ -2683,7 +2737,7 @@ export class SignalProtocolStore extends EventEmitter { return super.on(eventName, listener); } - public override emit(name: 'lowKeys', ourServiceid: ServiceIdString): boolean; + public override emit(name: 'lowKeys'): boolean; public override emit( name: 'keychange', diff --git a/ts/background.ts b/ts/background.ts index 6d089b77de..a5f98e8a51 100644 --- a/ts/background.ts +++ b/ts/background.ts @@ -395,10 +395,9 @@ export async function startApp(): Promise { window.textsecure.storage.protocol.on( 'lowKeys', throttle( - (ourServiceId: ServiceIdString) => { - const serviceIdKind = - window.textsecure.storage.user.getOurServiceIdKind(ourServiceId); - drop(window.getAccountManager().maybeUpdateKeys(serviceIdKind)); + async () => { + await window.getAccountManager().maybeUpdateKeys(ServiceIdKind.ACI); + await window.getAccountManager().maybeUpdateKeys(ServiceIdKind.PNI); }, durations.MINUTE, { trailing: true, leading: false } diff --git a/ts/sql/Interface.ts b/ts/sql/Interface.ts index 4ab41792ac..3ee0966331 100644 --- a/ts/sql/Interface.ts +++ b/ts/sql/Interface.ts @@ -958,6 +958,8 @@ type WritableInterface = { createOrUpdateSession: (data: SessionType) => void; createOrUpdateSessions: (array: Array) => void; commitDecryptResult(options: { + kyberPreKeysToRemove: Array; + preKeysToRemove: Array; senderKeys: Array; sessions: Array; unprocessed: Array; diff --git a/ts/sql/Server.ts b/ts/sql/Server.ts index 42c8040a2f..d5c280608f 100644 --- a/ts/sql/Server.ts +++ b/ts/sql/Server.ts @@ -1601,16 +1601,46 @@ function createOrUpdateSessions( function commitDecryptResult( db: WritableDB, { + kyberPreKeysToRemove, + preKeysToRemove, senderKeys, sessions, unprocessed, }: { + kyberPreKeysToRemove: Array; + preKeysToRemove: Array; senderKeys: Array; sessions: Array; unprocessed: Array; } ): void { db.transaction(() => { + if (kyberPreKeysToRemove.length > 0) { + const kyberPreKeyChanges = removeKyberPreKeyById( + db, + kyberPreKeysToRemove + ); + if (kyberPreKeyChanges === kyberPreKeysToRemove.length) { + logger.info( + `commitDecryptResult: Removed ${kyberPreKeyChanges} kyberPreKeys` + ); + } else { + logger.error( + `commitDecryptResult: Changed ${kyberPreKeyChanges} keys, but had ${kyberPreKeysToRemove.length} kyberPreKeys to remove` + ); + } + } + if (preKeysToRemove.length > 0) { + const preKeyChanges = removePreKeyById(db, preKeysToRemove); + if (preKeyChanges === preKeysToRemove.length) { + logger.info(`commitDecryptResult: Removed ${preKeyChanges} preKeys`); + } else { + logger.error( + `commitDecryptResult: Changed ${preKeyChanges} keys, but had ${preKeysToRemove.length} preKeys to remove` + ); + } + } + for (const item of senderKeys) { createOrUpdateSenderKey(db, item); } diff --git a/ts/test-electron/SignalProtocolStore_test.ts b/ts/test-electron/SignalProtocolStore_test.ts index 344a75c3f3..7d4ef4ef51 100644 --- a/ts/test-electron/SignalProtocolStore_test.ts +++ b/ts/test-electron/SignalProtocolStore_test.ts @@ -293,6 +293,8 @@ describe('SignalProtocolStore', () => { it('should not deadlock', async () => { const newIdentity = getPublicKey(); const zone = new Zone('zone', { + pendingKyberPreKeysToRemove: true, + pendingPreKeysToRemove: true, pendingSenderKeys: true, pendingSessions: true, pendingUnprocessed: true, @@ -1230,6 +1232,8 @@ describe('SignalProtocolStore', () => { describe('zones', () => { const distributionId = generateUuid(); const zone = new Zone('zone', { + pendingKyberPreKeysToRemove: true, + pendingPreKeysToRemove: true, pendingSenderKeys: true, pendingSessions: true, pendingUnprocessed: true, diff --git a/ts/textsecure/MessageReceiver.ts b/ts/textsecure/MessageReceiver.ts index 7cd91cd0dc..eafa3ba1f8 100644 --- a/ts/textsecure/MessageReceiver.ts +++ b/ts/textsecure/MessageReceiver.ts @@ -216,9 +216,13 @@ type CacheAddItemType = { }; type LockedStores = { + readonly identityKeyStore: IdentityKeys; + readonly kyberPreKeyStore: KyberPreKeys; + readonly preKeyStore: PreKeys; readonly senderKeyStore: SenderKeys; readonly sessionStore: Sessions; - readonly identityKeyStore: IdentityKeys; + readonly signedPreKeyStore: SignedPreKeys; + readonly zone?: Zone; }; @@ -994,6 +998,8 @@ export default class MessageReceiver try { const zone = new Zone('decryptAndCacheBatch', { + pendingKyberPreKeysToRemove: true, + pendingPreKeysToRemove: true, pendingSenderKeys: true, pendingSessions: true, pendingUnprocessed: true, @@ -1019,19 +1025,17 @@ export default class MessageReceiver let stores = storesMap.get(destinationServiceId); if (!stores) { + const sharedParams = { + ourServiceId: destinationServiceId, + zone, + }; stores = { - senderKeyStore: new SenderKeys({ - ourServiceId: destinationServiceId, - zone, - }), - sessionStore: new Sessions({ - zone, - ourServiceId: destinationServiceId, - }), - identityKeyStore: new IdentityKeys({ - zone, - ourServiceId: destinationServiceId, - }), + identityKeyStore: new IdentityKeys(sharedParams), + kyberPreKeyStore: new KyberPreKeys(sharedParams), + preKeyStore: new PreKeys(sharedParams), + senderKeyStore: new SenderKeys(sharedParams), + sessionStore: new Sessions(sharedParams), + signedPreKeyStore: new SignedPreKeys(sharedParams), zone, }; storesMap.set(destinationServiceId, stores); @@ -1671,7 +1675,15 @@ export default class MessageReceiver } async #decryptSealedSender( - { senderKeyStore, sessionStore, identityKeyStore, zone }: LockedStores, + { + identityKeyStore, + kyberPreKeyStore, + preKeyStore, + senderKeyStore, + sessionStore, + signedPreKeyStore, + zone, + }: LockedStores, envelope: UnsealedEnvelope ): Promise { const { destinationServiceId } = envelope; @@ -1747,14 +1759,6 @@ export default class MessageReceiver 'unidentified message/passing to sealedSenderDecryptMessage' ); - const preKeyStore = new PreKeys({ ourServiceId: destinationServiceId }); - const signedPreKeyStore = new SignedPreKeys({ - ourServiceId: destinationServiceId, - }); - const kyberPreKeyStore = new KyberPreKeys({ - ourServiceId: destinationServiceId, - }); - const sealedSenderIdentifier = envelope.sourceServiceId; strictAssert( sealedSenderIdentifier !== undefined, @@ -1810,7 +1814,14 @@ export default class MessageReceiver ciphertext: Uint8Array, serviceIdKind: ServiceIdKind ): Promise { - const { sessionStore, identityKeyStore, zone } = stores; + const { + identityKeyStore, + kyberPreKeyStore, + preKeyStore, + sessionStore, + signedPreKeyStore, + zone, + } = stores; const logId = getEnvelopeId(envelope); const envelopeTypeEnum = Proto.Envelope.Type; @@ -1819,13 +1830,6 @@ export default class MessageReceiver const { sourceDevice } = envelope; const { destinationServiceId } = envelope; - const preKeyStore = new PreKeys({ ourServiceId: destinationServiceId }); - const signedPreKeyStore = new SignedPreKeys({ - ourServiceId: destinationServiceId, - }); - const kyberPreKeyStore = new KyberPreKeys({ - ourServiceId: destinationServiceId, - }); strictAssert(identifier !== undefined, 'Empty identifier'); strictAssert(sourceDevice !== undefined, 'Empty source device'); diff --git a/ts/util/Zone.ts b/ts/util/Zone.ts index 7d0379856f..f65541f268 100644 --- a/ts/util/Zone.ts +++ b/ts/util/Zone.ts @@ -2,6 +2,8 @@ // SPDX-License-Identifier: AGPL-3.0-only export type ZoneOptions = { + readonly pendingKyberPreKeysToRemove?: boolean; + readonly pendingPreKeysToRemove?: boolean; readonly pendingSenderKeys?: boolean; readonly pendingSessions?: boolean; readonly pendingUnprocessed?: boolean; @@ -13,6 +15,14 @@ export class Zone { private readonly options: ZoneOptions = {} ) {} + public supportsPendingKyberPreKeysToRemove(): boolean { + return this.options.pendingKyberPreKeysToRemove === true; + } + + public supportsPendingPreKeysToRemove(): boolean { + return this.options.pendingPreKeysToRemove === true; + } + public supportsPendingSenderKeys(): boolean { return this.options.pendingSenderKeys === true; }