From 2ea59bef686c68a37598d8f835a647ec75691864 Mon Sep 17 00:00:00 2001 From: Alex Hart Date: Thu, 21 May 2026 10:12:05 -0300 Subject: [PATCH] Handle PniChangeNumber sync on linked devices. --- ...rocessorTest_synchronizePniChangeNumber.kt | 353 ++++++++++++++++++ .../securesms/testing/MessageContentFuzzer.kt | 3 +- .../changenumber/ChangeNumberRepository.kt | 110 ++++-- .../securesms/jobs/PreKeysSyncJob.kt | 20 +- .../securesms/keyvalue/MiscellaneousValues.kt | 13 + .../messages/IncomingMessageObserver.kt | 156 +++++--- .../messages/SyncMessageProcessor.kt | 100 +++++ .../protocol/BufferedProtocolStore.kt | 6 +- .../securesms/jobs/PreKeysSyncJobTest.kt | 179 +++++++++ 9 files changed, 852 insertions(+), 88 deletions(-) create mode 100644 app/src/androidTest/java/org/thoughtcrime/securesms/messages/SyncMessageProcessorTest_synchronizePniChangeNumber.kt create mode 100644 app/src/test/java/org/thoughtcrime/securesms/jobs/PreKeysSyncJobTest.kt diff --git a/app/src/androidTest/java/org/thoughtcrime/securesms/messages/SyncMessageProcessorTest_synchronizePniChangeNumber.kt b/app/src/androidTest/java/org/thoughtcrime/securesms/messages/SyncMessageProcessorTest_synchronizePniChangeNumber.kt new file mode 100644 index 0000000000..ae0a139c5f --- /dev/null +++ b/app/src/androidTest/java/org/thoughtcrime/securesms/messages/SyncMessageProcessorTest_synchronizePniChangeNumber.kt @@ -0,0 +1,353 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.thoughtcrime.securesms.messages + +import androidx.test.ext.junit.runners.AndroidJUnit4 +import assertk.assertThat +import assertk.assertions.isEqualTo +import assertk.assertions.isNotNull +import assertk.assertions.isTrue +import okio.ByteString +import okio.ByteString.Companion.toByteString +import org.junit.After +import org.junit.Before +import org.junit.Rule +import org.junit.Test +import org.junit.runner.RunWith +import org.signal.core.models.ServiceId +import org.signal.core.util.UuidUtil +import org.signal.core.util.orNull +import org.signal.libsignal.protocol.IdentityKeyPair +import org.signal.libsignal.protocol.SignalProtocolAddress +import org.signal.libsignal.protocol.state.KyberPreKeyRecord +import org.signal.libsignal.protocol.state.SignedPreKeyRecord +import org.thoughtcrime.securesms.crypto.PreKeyUtil +import org.thoughtcrime.securesms.dependencies.AppDependencies +import org.thoughtcrime.securesms.keyvalue.SignalStore +import org.thoughtcrime.securesms.recipients.Recipient +import org.thoughtcrime.securesms.testing.MessageContentFuzzer +import org.thoughtcrime.securesms.testing.SignalActivityRule +import org.whispersystems.signalservice.api.push.SignalServiceAddress +import org.whispersystems.signalservice.internal.push.Content +import org.whispersystems.signalservice.internal.push.SyncMessage +import java.util.UUID + +@Suppress("ClassName") +@RunWith(AndroidJUnit4::class) +class SyncMessageProcessorTest_synchronizePniChangeNumber { + + @get:Rule + val harness = SignalActivityRule(createGroup = true) + + private lateinit var messageHelper: MessageHelper + + private val newPniUuid: UUID = UUID.randomUUID() + private val newPni: ServiceId.PNI = ServiceId.PNI.from(newPniUuid) + + // 16-byte raw UUID — matches the actual wire format the server sends (per proto comment and + // iOS/Desktop behavior). Do NOT use `newPni.toByteString()` here — that produces libsignal's + // 17-byte ServiceIdBinary form, which is a different format. + private val newPniBytes: ByteString = UuidUtil.toByteArray(newPniUuid).toByteString() + private val newE164 = "+15555550199" + private val newPniIdentity: IdentityKeyPair = IdentityKeyPair.generate() + private val newSignedPreKey: SignedPreKeyRecord = PreKeyUtil.generateSignedPreKey(1234, newPniIdentity.privateKey) + private val newLastResortKyber: KyberPreKeyRecord = PreKeyUtil.generateLastResortKyberPreKey(5678, newPniIdentity.privateKey) + private val newRegistrationId = 4242 + + @Before + fun setUp() { + messageHelper = MessageHelper(harness) + SignalStore.account.deviceId = 2 + } + + @After + fun tearDown() { + messageHelper.tearDown() + } + + @Test + fun appliesAllStateOnHappyPath() { + sendPniChangeNumber() + + assertThat(SignalStore.account.e164).isEqualTo(newE164) + assertThat(SignalStore.account.pni).isEqualTo(newPni) + assertThat(SignalStore.account.pniRegistrationId).isEqualTo(newRegistrationId) + assertThat(SignalStore.account.pniIdentityKey.publicKey.serialize().toByteString()) + .isEqualTo(newPniIdentity.publicKey.serialize().toByteString()) + assertThat(SignalStore.account.pniPreKeys.activeSignedPreKeyId).isEqualTo(newSignedPreKey.id) + assertThat(SignalStore.account.pniPreKeys.isSignedPreKeyRegistered).isTrue() + assertThat(SignalStore.account.pniPreKeys.lastResortKyberPreKeyId).isEqualTo(newLastResortKyber.id) + assertThat(SignalStore.misc.forcePniSignedPreKeyRotation).isTrue() + + val self = Recipient.self().fresh() + assertThat(self.requireE164()).isEqualTo(newE164) + assertThat(self.pni.orNull()).isEqualTo(newPni) + + val pniProtocolStore = AppDependencies.protocolStore.pni() + val storedSigned = pniProtocolStore.loadSignedPreKey(newSignedPreKey.id) + assertThat(storedSigned.serialize().toByteString()).isEqualTo(newSignedPreKey.serialize().toByteString()) + val storedKyber = pniProtocolStore.loadLastResortKyberPreKeys().firstOrNull { it.id == newLastResortKyber.id } + assertThat(storedKyber).isNotNull() + assertThat(storedKyber!!.serialize().toByteString()).isEqualTo(newLastResortKyber.serialize().toByteString()) + + // The IdentityTable cache is keyed by ServiceId string, not RecipientId — for self, that's + // separate ACI and PNI rows. We want the PNI row, so look it up by the new PNI directly. + val selfPniIdentity = pniProtocolStore.getIdentity(SignalProtocolAddress(newPni.toString(), SignalServiceAddress.DEFAULT_DEVICE_ID)) + assertThat(selfPniIdentity).isNotNull() + assertThat(selfPniIdentity!!.publicKey.serialize().toByteString()) + .isEqualTo(newPniIdentity.publicKey.serialize().toByteString()) + } + + @Test + fun appliesStateWhenLastResortKyberAbsent() { + val original = captureOriginalState() + + sendPniChangeNumber(lastResortKyberPreKey = null) + + assertThat(SignalStore.account.e164).isEqualTo(newE164) + assertThat(SignalStore.account.pni).isEqualTo(newPni) + assertThat(SignalStore.account.pniRegistrationId).isEqualTo(newRegistrationId) + assertThat(SignalStore.account.pniPreKeys.activeSignedPreKeyId).isEqualTo(newSignedPreKey.id) + assertThat(SignalStore.account.pniPreKeys.isSignedPreKeyRegistered).isTrue() + // No kyber was supplied, so kyber metadata should be unchanged. + assertThat(SignalStore.account.pniPreKeys.lastResortKyberPreKeyId).isEqualTo(original.lastResortKyberPreKeyId) + assertThat(SignalStore.misc.forcePniSignedPreKeyRotation).isTrue() + } + + @Test + fun bailsWhenPrimaryDevice() { + SignalStore.account.deviceId = SignalServiceAddress.DEFAULT_DEVICE_ID + val original = captureOriginalState() + + sendPniChangeNumber() + + assertOriginalStatePreserved(original) + } + + @Test + fun bailsWhenSourceIsNotPrimaryDevice() { + val original = captureOriginalState() + + sendPniChangeNumber(sourceDeviceId = 3) + + assertOriginalStatePreserved(original) + } + + @Test + fun bailsWhenEnvelopePniMissing() { + val original = captureOriginalState() + + sendPniChangeNumber(envelopePniBinary = null) + + assertOriginalStatePreserved(original) + } + + @Test + fun bailsWhenIdentityKeyPairMissing() { + val original = captureOriginalState() + + sendPniChangeNumber(identityKeyPair = null) + + assertOriginalStatePreserved(original) + } + + @Test + fun bailsWhenSignedPreKeyMissing() { + val original = captureOriginalState() + + sendPniChangeNumber(signedPreKey = null) + + assertOriginalStatePreserved(original) + } + + @Test + fun bailsWhenRegistrationIdMissing() { + val original = captureOriginalState() + + sendPniChangeNumber(registrationId = null) + + assertOriginalStatePreserved(original) + } + + @Test + fun bailsWhenRegistrationIdZero() { + val original = captureOriginalState() + + sendPniChangeNumber(registrationId = 0) + + assertOriginalStatePreserved(original) + } + + @Test + fun bailsWhenNewE164Missing() { + val original = captureOriginalState() + + sendPniChangeNumber(e164 = null) + + assertOriginalStatePreserved(original) + } + + @Test + fun bailsWhenNewE164Empty() { + val original = captureOriginalState() + + sendPniChangeNumber(e164 = "") + + assertOriginalStatePreserved(original) + } + + @Test + fun bailsWhenNewE164NotValid() { + val original = captureOriginalState() + + sendPniChangeNumber(e164 = "not a phone number") + + assertOriginalStatePreserved(original) + } + + @Test + fun bailsOnMalformedIdentityKeyPair() { + val original = captureOriginalState() + + sendPniChangeNumber(identityKeyPair = malformedBytes()) + + assertOriginalStatePreserved(original) + } + + @Test + fun bailsOnMalformedSignedPreKey() { + val original = captureOriginalState() + + sendPniChangeNumber(signedPreKey = malformedBytes()) + + assertOriginalStatePreserved(original) + } + + @Test + fun bailsOnMalformedLastResortKyber() { + val original = captureOriginalState() + + sendPniChangeNumber(lastResortKyberPreKey = malformedBytes()) + + assertOriginalStatePreserved(original) + } + + @Test + fun skipsRedeliveryWhenPniAlreadyMatches() { + sendPniChangeNumber() + val afterFirstApply = captureOriginalState() + + val otherIdentity = IdentityKeyPair.generate() + val otherSignedPreKey = PreKeyUtil.generateSignedPreKey(9999, otherIdentity.privateKey) + + sendPniChangeNumber( + identityKeyPair = otherIdentity.serialize().toByteString(), + signedPreKey = otherSignedPreKey.serialize().toByteString(), + e164 = "+15555550100", + timestamp = messageHelper.nextStartTime() + 1000 + ) + + assertOriginalStatePreserved(afterFirstApply) + } + + @Test + fun bailsWhenServerTimestampStale() { + sendPniChangeNumber() + val afterFirstApply = captureOriginalState() + + val otherPniUuid = UUID.randomUUID() + val otherPniBytes = UuidUtil.toByteArray(otherPniUuid).toByteString() + + sendPniChangeNumber( + envelopePniBinary = otherPniBytes, + e164 = "+15555550100", + timestamp = messageHelper.nextStartTime() - 100_000L + ) + + assertOriginalStatePreserved(afterFirstApply) + } + + private fun captureOriginalState(): OriginalState { + val self = Recipient.self().fresh() + return OriginalState( + e164 = SignalStore.account.e164, + pni = SignalStore.account.pni, + pniRegistrationId = SignalStore.account.pniRegistrationId, + isSignedPreKeyRegistered = SignalStore.account.pniPreKeys.isSignedPreKeyRegistered, + activeSignedPreKeyId = SignalStore.account.pniPreKeys.activeSignedPreKeyId, + lastResortKyberPreKeyId = SignalStore.account.pniPreKeys.lastResortKyberPreKeyId, + pniIdentityPublicKey = SignalStore.account.pniIdentityKey.publicKey.serialize().toByteString(), + selfE164 = self.e164.orNull(), + selfPni = self.pni.orNull(), + forcePniSignedPreKeyRotation = SignalStore.misc.forcePniSignedPreKeyRotation + ) + } + + private fun assertOriginalStatePreserved(original: OriginalState) { + assertThat(SignalStore.account.e164).isEqualTo(original.e164) + assertThat(SignalStore.account.pni).isEqualTo(original.pni) + assertThat(SignalStore.account.pniRegistrationId).isEqualTo(original.pniRegistrationId) + assertThat(SignalStore.account.pniPreKeys.isSignedPreKeyRegistered).isEqualTo(original.isSignedPreKeyRegistered) + assertThat(SignalStore.account.pniPreKeys.activeSignedPreKeyId).isEqualTo(original.activeSignedPreKeyId) + assertThat(SignalStore.account.pniPreKeys.lastResortKyberPreKeyId).isEqualTo(original.lastResortKyberPreKeyId) + assertThat(SignalStore.account.pniIdentityKey.publicKey.serialize().toByteString()) + .isEqualTo(original.pniIdentityPublicKey) + assertThat(SignalStore.misc.forcePniSignedPreKeyRotation).isEqualTo(original.forcePniSignedPreKeyRotation) + val self = Recipient.self().fresh() + assertThat(self.e164.orNull()).isEqualTo(original.selfE164) + assertThat(self.pni.orNull()).isEqualTo(original.selfPni) + } + + private data class OriginalState( + val e164: String?, + val pni: ServiceId.PNI?, + val pniRegistrationId: Int, + val isSignedPreKeyRegistered: Boolean, + val activeSignedPreKeyId: Int, + val lastResortKyberPreKeyId: Int, + val pniIdentityPublicKey: ByteString, + val selfE164: String?, + val selfPni: ServiceId.PNI?, + val forcePniSignedPreKeyRotation: Boolean + ) + + private fun malformedBytes(): ByteString = byteArrayOf(0x00, 0x01, 0x02).toByteString() + + private fun sendPniChangeNumber( + identityKeyPair: ByteString? = newPniIdentity.serialize().toByteString(), + signedPreKey: ByteString? = newSignedPreKey.serialize().toByteString(), + lastResortKyberPreKey: ByteString? = newLastResortKyber.serialize().toByteString(), + registrationId: Int? = newRegistrationId, + e164: String? = newE164, + envelopePniBinary: ByteString? = newPniBytes, + sourceDeviceId: Int = SignalServiceAddress.DEFAULT_DEVICE_ID, + timestamp: Long = messageHelper.nextStartTime() + ) { + val content = Content( + syncMessage = SyncMessage( + pniChangeNumber = SyncMessage.PniChangeNumber( + identityKeyPair = identityKeyPair, + signedPreKey = signedPreKey, + lastResortKyberPreKey = lastResortKyberPreKey, + registrationId = registrationId, + newE164 = e164 + ) + ) + ) + + val envelope = MessageContentFuzzer.envelope( + timestamp = timestamp, + updatedPniBinary = envelopePniBinary + ) + + messageHelper.processor.process( + envelope = envelope, + content = content, + metadata = MessageContentFuzzer.envelopeMetadata(harness.self.id, harness.self.id, sourceDeviceId = sourceDeviceId), + serverDeliveredTimestamp = timestamp + 10 + ) + } +} diff --git a/app/src/androidTest/java/org/thoughtcrime/securesms/testing/MessageContentFuzzer.kt b/app/src/androidTest/java/org/thoughtcrime/securesms/testing/MessageContentFuzzer.kt index 9b1d281bb0..3a9d490e47 100644 --- a/app/src/androidTest/java/org/thoughtcrime/securesms/testing/MessageContentFuzzer.kt +++ b/app/src/androidTest/java/org/thoughtcrime/securesms/testing/MessageContentFuzzer.kt @@ -41,11 +41,12 @@ object MessageContentFuzzer { /** * Create an [Envelope]. */ - fun envelope(timestamp: Long, serverGuid: UUID = UUID.randomUUID()): Envelope { + fun envelope(timestamp: Long, serverGuid: UUID = UUID.randomUUID(), updatedPniBinary: ByteString? = null): Envelope { return Envelope.Builder() .clientTimestamp(timestamp) .serverTimestamp(timestamp + 5) .serverGuidBinary(serverGuid.toByteArray().toByteString()) + .also { if (updatedPniBinary != null) it.updatedPniBinary(updatedPniBinary) } .build() } diff --git a/app/src/main/java/org/thoughtcrime/securesms/components/settings/app/changenumber/ChangeNumberRepository.kt b/app/src/main/java/org/thoughtcrime/securesms/components/settings/app/changenumber/ChangeNumberRepository.kt index 0ac0f97cee..3ef13bc8ad 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/components/settings/app/changenumber/ChangeNumberRepository.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/components/settings/app/changenumber/ChangeNumberRepository.kt @@ -103,20 +103,6 @@ class ChangeNumberRepository( @WorkerThread fun changeLocalNumber(e164: String, pni: ServiceId.PNI) { - SignalDatabase.recipients.updateSelfE164(e164, pni) - AppDependencies.recipientCache.clear() - - if (e164 != SignalStore.account.requireE164()) { - SignalDatabase.recipients.rotateStorageId(Recipient.self().fresh().id) - StorageSyncHelper.scheduleSyncForDataChange() - } - - SignalStore.account.setE164(e164) - SignalStore.account.setPni(pni) - AppDependencies.resetProtocolStores() - - AppDependencies.groupsV2Authorization.clear() - val metadata: PendingChangeNumberMetadata? = SignalStore.misc.pendingChangeNumberMetadata if (metadata == null) { Log.w(TAG, "No change number metadata, this shouldn't happen") @@ -125,25 +111,32 @@ class ChangeNumberRepository( val pniIdentityKeyPair = IdentityKeyPair(metadata.pniIdentityKeyPair.toByteArray()) val pniRegistrationId = metadata.pniRegistrationId - val pniSignedPreyKeyId = metadata.pniSignedPreKeyId + val pniSignedPreKeyId = metadata.pniSignedPreKeyId val pniLastResortKyberPreKeyId = metadata.pniLastResortKyberPreKeyId + // Prekeys were generated and stored during createChangeNumberRequest; reload them so we can pass them through and reuse for the upload below. + val preResetPniStore = AppDependencies.protocolStore.pni() + val signedPreKey = preResetPniStore.loadSignedPreKey(pniSignedPreKeyId) + val lastResortKyberPreKey = preResetPniStore.loadLastResortKyberPreKeys().firstOrNull { it.id == pniLastResortKyberPreKeyId } + + applyLocalNumberChange( + e164 = e164, + pni = pni, + pniIdentityKeyPair = pniIdentityKeyPair, + pniSignedPreKey = signedPreKey, + pniLastResortKyberPreKey = lastResortKyberPreKey, + pniRegistrationId = pniRegistrationId + ) + + AppDependencies.resetNetwork() + AppDependencies.startNetwork() + val pniProtocolStore = AppDependencies.protocolStore.pni() val pniMetadataStore = SignalStore.account.pniPreKeys - SignalStore.account.pniRegistrationId = pniRegistrationId - SignalStore.account.setPniIdentityKeyAfterChangeNumber(pniIdentityKeyPair) - - val signedPreKey = pniProtocolStore.loadSignedPreKey(pniSignedPreyKeyId) val oneTimeEcPreKeys = PreKeyUtil.generateAndStoreOneTimeEcPreKeys(pniProtocolStore, pniMetadataStore) - val lastResortKyberPreKey = pniProtocolStore.loadLastResortKyberPreKeys().firstOrNull { it.id == pniLastResortKyberPreKeyId } val oneTimeKyberPreKeys = PreKeyUtil.generateAndStoreOneTimeKyberPreKeys(pniProtocolStore, pniMetadataStore) - if (lastResortKyberPreKey == null) { - Log.w(TAG, "Last-resort kyber prekey is missing!") - } - - pniMetadataStore.activeSignedPreKeyId = signedPreKey.id Log.i(TAG, "Submitting prekeys with PNI identity key: ${pniIdentityKeyPair.publicKey.fingerprint}") retryChangeLocalNumberNetworkOperation { @@ -161,6 +154,61 @@ class ChangeNumberRepository( pniMetadataStore.isSignedPreKeyRegistered = true pniMetadataStore.lastResortKyberPreKeyId = pniLastResortKyberPreKeyId + SignalStore.misc.hasPniInitializedDevices = true + + AppDependencies.jobManager.add(RefreshAttributesJob()) + + rotateCertificates() + + SignalStore.misc.unlockChangeNumber() + } + + /** + * Applies the local state for a successful number change: self recipient row, account values, + * PNI protocol store, and identity entry. + * + * Does NOT reset the network — callers must do so before any subsequent traffic that needs to + * use the new PNI. Does NOT make any server requests and does NOT flag prekeys as registered + * server-side — the caller is responsible for that once it can attest to server state. + */ + @WorkerThread + fun applyLocalNumberChange( + e164: String, + pni: ServiceId.PNI, + pniIdentityKeyPair: IdentityKeyPair, + pniSignedPreKey: SignedPreKeyRecord, + pniLastResortKyberPreKey: KyberPreKeyRecord?, + pniRegistrationId: Int + ) { + SignalDatabase.recipients.updateSelfE164(e164, pni) + AppDependencies.recipientCache.clear() + + if (e164 != SignalStore.account.requireE164()) { + SignalDatabase.recipients.rotateStorageId(Recipient.self().fresh().id) + StorageSyncHelper.scheduleSyncForDataChange() + } + + SignalStore.account.setE164(e164) + SignalStore.account.setPni(pni) + AppDependencies.resetProtocolStores() + + AppDependencies.groupsV2Authorization.clear() + + val pniProtocolStore = AppDependencies.protocolStore.pni() + val pniMetadataStore = SignalStore.account.pniPreKeys + + SignalStore.account.pniRegistrationId = pniRegistrationId + SignalStore.account.setPniIdentityKeyAfterChangeNumber(pniIdentityKeyPair) + + PreKeyUtil.storeSignedPreKey(pniProtocolStore, pniMetadataStore, pniSignedPreKey) + pniMetadataStore.activeSignedPreKeyId = pniSignedPreKey.id + + if (pniLastResortKyberPreKey != null) { + PreKeyUtil.storeLastResortKyberPreKey(pniProtocolStore, pniMetadataStore, pniLastResortKyberPreKey) + } else { + Log.w(TAG, "Last-resort kyber prekey is missing!") + } + pniProtocolStore.identities().saveIdentityWithoutSideEffects( Recipient.self().id, pni, @@ -171,20 +219,8 @@ class ChangeNumberRepository( true ) - SignalStore.misc.hasPniInitializedDevices = true - AppDependencies.groupsV2Authorization.clear() - Recipient.self().fresh() StorageSyncHelper.scheduleSyncForDataChange() - - AppDependencies.resetNetwork() - AppDependencies.startNetwork() - - AppDependencies.jobManager.add(RefreshAttributesJob()) - - rotateCertificates() - - SignalStore.misc.unlockChangeNumber() } @WorkerThread diff --git a/app/src/main/java/org/thoughtcrime/securesms/jobs/PreKeysSyncJob.kt b/app/src/main/java/org/thoughtcrime/securesms/jobs/PreKeysSyncJob.kt index 80a9bf4a00..89a0a0d482 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/jobs/PreKeysSyncJob.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/jobs/PreKeysSyncJob.kt @@ -125,6 +125,11 @@ class PreKeysSyncJob private constructor( return } + val pniRotationOverride = SignalStore.misc.forcePniSignedPreKeyRotation + if (pniRotationOverride) { + warn(TAG, ServiceIdType.PNI, "Forced PNI prekey rotation pending after PniChangeNumber sync. Bypassing dedup/interval gating for PNI.") + } + val forceRotation = if (forceRotationRequested) { warn(TAG, "Forced rotation was requested.") warn(TAG, ServiceIdType.ACI, "Active Signed EC: ${SignalStore.account.aciPreKeys.activeSignedPreKeyId}, Last Resort Kyber: ${SignalStore.account.aciPreKeys.lastResortKyberPreKeyId}") @@ -146,19 +151,26 @@ class PreKeysSyncJob private constructor( false } - if (forceRotation) { - warn(TAG, "Forcing prekey rotation.") + val forcePniRotation = forceRotation || pniRotationOverride + + if (forcePniRotation) { + warn(TAG, "Forcing prekey rotation. ACI=$forceRotation PNI=$forcePniRotation") } else if (forceRotationRequested) { warn(TAG, "Forced prekey rotation was requested, but we already did a forced refresh ${System.currentTimeMillis() - SignalStore.misc.lastForcedPreKeyRefresh} ms ago. Ignoring.") } syncPreKeys(ServiceIdType.ACI, SignalStore.account.aci, AppDependencies.protocolStore.aci(), SignalStore.account.aciPreKeys, forceRotation) - syncPreKeys(ServiceIdType.PNI, SignalStore.account.pni, AppDependencies.protocolStore.pni(), SignalStore.account.pniPreKeys, forceRotation) + syncPreKeys(ServiceIdType.PNI, SignalStore.account.pni, AppDependencies.protocolStore.pni(), SignalStore.account.pniPreKeys, forcePniRotation) SignalStore.misc.lastFullPrekeyRefreshTime = System.currentTimeMillis() - if (forceRotation) { + if (forcePniRotation) { SignalStore.misc.lastForcedPreKeyRefresh = System.currentTimeMillis() } + + if (pniRotationOverride) { + // Cleared only after both syncPreKeys calls completed without throwing; a thrown upload leaves the flag set for the next attempt. + SignalStore.misc.forcePniSignedPreKeyRotation = false + } } private fun syncPreKeys(serviceIdType: ServiceIdType, serviceId: ServiceId?, protocolStore: SignalServiceAccountDataStore, metadataStore: PreKeyMetadataStore, forceRotation: Boolean) { diff --git a/app/src/main/java/org/thoughtcrime/securesms/keyvalue/MiscellaneousValues.kt b/app/src/main/java/org/thoughtcrime/securesms/keyvalue/MiscellaneousValues.kt index 50b3495b66..2011ba90e6 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/keyvalue/MiscellaneousValues.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/keyvalue/MiscellaneousValues.kt @@ -32,6 +32,7 @@ class MiscellaneousValues internal constructor(store: KeyValueStore) : SignalSto private const val LAST_SERVER_TIME_OFFSET_UPDATE = "misc.last_server_time_offset_update" private const val NEEDS_USERNAME_RESTORE = "misc.needs_username_restore" private const val LAST_FORCED_PREKEY_REFRESH = "misc.last_forced_prekey_refresh" + private const val FORCE_PNI_SIGNED_PREKEY_ROTATION = "misc.force_pni_signed_prekey_rotation" private const val LAST_CDS_FOREGROUND_SYNC = "misc.last_cds_foreground_sync" private const val LINKED_DEVICE_LAST_ACTIVE_CHECK_TIME = "misc.linked_device.last_active_check_time" private const val LEAST_ACTIVE_LINKED_DEVICE = "misc.linked_device.least_active" @@ -51,6 +52,7 @@ class MiscellaneousValues internal constructor(store: KeyValueStore) : SignalSto private const val CAPTCHA_LAST_VIEWED_AT = "misc.captcha_last_viewed_at" private const val CALLING_ASSETS_VERSION = "misc.calling_assets_version" private const val LAST_SYNC_MESSAGE_SEEN_TIME_MS = "misc.last_sync_message_seen_time" + private const val LAST_APPLIED_PNI_CHANGE_SERVER_TIMESTAMP = "misc.last_applied_pni_change_server_timestamp" } public override fun onFirstEverAppLaunch() { @@ -75,6 +77,17 @@ class MiscellaneousValues internal constructor(store: KeyValueStore) : SignalSto */ var lastForcedPreKeyRefresh by longValue(LAST_FORCED_PREKEY_REFRESH, 0) + /** + * Bypasses the timeout in [org.thoughtcrime.securesms.jobs.PreKeysSyncJob] since otherwise we can hit a race. + */ + var forcePniSignedPreKeyRotation by booleanValue(FORCE_PNI_SIGNED_PREKEY_ROTATION, false) + + /** + * Envelope serverTimestamp of the most recently applied PniChangeNumber sync. Used to reject + * stale replays — a sync with serverTimestamp <= this value is treated as a replay and ignored. + */ + var lastAppliedPniChangeServerTimestamp by longValue(LAST_APPLIED_PNI_CHANGE_SERVER_TIMESTAMP, 0L) + /** * The last time we completed a routine profile refresh. */ diff --git a/app/src/main/java/org/thoughtcrime/securesms/messages/IncomingMessageObserver.kt b/app/src/main/java/org/thoughtcrime/securesms/messages/IncomingMessageObserver.kt index 8338a455d4..2f4c1b9134 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/messages/IncomingMessageObserver.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/messages/IncomingMessageObserver.kt @@ -255,7 +255,10 @@ class IncomingMessageObserver( val needsConnectionString = if (conclusion) "Needs Connection" else "Does Not Need Connection" - Log.d(TAG, "[$needsConnectionString] Network: $hasNetwork, Foreground: $appVisibleSnapshot, Time Since Last Interaction: $lastInteractionString, FCM: $fcmEnabled, WS Open or Keep-alives: $websocketAlreadyOpen, Registered: $registered, Unauthorized: $unauthorizedReceived, Proxy: $hasProxy, Force websocket: $forceWebsocket") + Log.d( + TAG, + "[$needsConnectionString] Network: $hasNetwork, Foreground: $appVisibleSnapshot, Time Since Last Interaction: $lastInteractionString, FCM: $fcmEnabled, WS Open or Keep-alives: $websocketAlreadyOpen, Registered: $registered, Unauthorized: $unauthorizedReceived, Proxy: $hasProxy, Force websocket: $forceWebsocket" + ) return conclusion } @@ -287,7 +290,7 @@ class IncomingMessageObserver( } @VisibleForTesting - fun processEnvelope(bufferedProtocolStore: BufferedProtocolStore, envelope: Envelope, serverDeliveredTimestamp: Long, batchCache: BatchCache): List? { + fun processEnvelope(bufferedProtocolStore: BufferedProtocolStore, envelope: Envelope, serverDeliveredTimestamp: Long, batchCache: BatchCache): ProcessingResult? { return when (envelope.type) { Envelope.Type.SERVER_DELIVERY_RECEIPT -> { processReceipt(envelope) @@ -299,9 +302,9 @@ class IncomingMessageObserver( Envelope.Type.UNIDENTIFIED_SENDER, Envelope.Type.PLAINTEXT_CONTENT -> { SignalTrace.beginSection("IncomingMessageObserver#processMessage") - val followUps = processMessage(bufferedProtocolStore, envelope, serverDeliveredTimestamp, batchCache) + val result = processMessage(bufferedProtocolStore, envelope, serverDeliveredTimestamp, batchCache) SignalTrace.endSection() - followUps + result } else -> { @@ -311,56 +314,79 @@ class IncomingMessageObserver( } } - private fun processMessage(bufferedProtocolStore: BufferedProtocolStore, envelope: Envelope, serverDeliveredTimestamp: Long, batchCache: BatchCache): List { + private fun processMessage(bufferedProtocolStore: BufferedProtocolStore, envelope: Envelope, serverDeliveredTimestamp: Long, batchCache: BatchCache): ProcessingResult { val localReceiveMetric = SignalLocalMetrics.MessageReceive.start() SignalTrace.beginSection("IncomingMessageObserver#decryptMessage") val result = MessageDecryptor.decrypt(context, bufferedProtocolStore, envelope, serverDeliveredTimestamp) SignalTrace.endSection() localReceiveMetric.onEnvelopeDecrypted() + var isNetworkResetRequired = false + SignalLocalMetrics.MessageLatency.onMessageReceived(envelope.serverTimestamp!!, serverDeliveredTimestamp, envelope.urgent!!) when (result) { is MessageDecryptor.Result.Success -> { val job = PushProcessMessageJob.processOrDefer(messageContentProcessor, result, localReceiveMetric, batchCache) + isNetworkResetRequired = isNetworkResetRequired(result, bufferedProtocolStore.pni) if (job != null) { - return result.followUpOperations + FollowUpOperation { job.asChain() } - } - } - is MessageDecryptor.Result.Error -> { - return result.followUpOperations + FollowUpOperation { - val jobs = mutableListOf() - - if (result.errorMetadata.groupMasterKey != null) { - val groupId = result.errorMetadata.groupId!! - if (!SignalDatabase.groups.getGroup(groupId).isPresent) { - Log.w(TAG, "Decryption error in group, but group not found. Creating placeholder for groupId: $groupId") - SignalDatabase.groups.create( - groupMasterKey = result.errorMetadata.groupMasterKey!!, - groupState = DecryptedGroup(revision = GroupsV2StateProcessor.RESTORE_PLACEHOLDER_REVISION), - groupSendEndorsements = null - ) - jobs += RequestGroupV2InfoJob(groupId) - } - } - - jobs += PushProcessMessageErrorJob( - result.toMessageState(), - result.errorMetadata.toExceptionMetadata(), - result.envelope.clientTimestamp!! + return ProcessingResult( + followUpOperations = result.followUpOperations + FollowUpOperation { job.asChain() }, + isNetworkResetRequired = isNetworkResetRequired ) - - AppDependencies.jobManager.startChain(jobs) } } + + is MessageDecryptor.Result.Error -> { + return ProcessingResult( + result.followUpOperations + FollowUpOperation { + val jobs = mutableListOf() + + if (result.errorMetadata.groupMasterKey != null) { + val groupId = result.errorMetadata.groupId!! + if (!SignalDatabase.groups.getGroup(groupId).isPresent) { + Log.w(TAG, "Decryption error in group, but group not found. Creating placeholder for groupId: $groupId") + SignalDatabase.groups.create( + groupMasterKey = result.errorMetadata.groupMasterKey!!, + groupState = DecryptedGroup(revision = GroupsV2StateProcessor.RESTORE_PLACEHOLDER_REVISION), + groupSendEndorsements = null + ) + jobs += RequestGroupV2InfoJob(groupId) + } + } + + jobs += PushProcessMessageErrorJob( + result.toMessageState(), + result.errorMetadata.toExceptionMetadata(), + result.envelope.clientTimestamp!! + ) + + AppDependencies.jobManager.startChain(jobs) + } + ) + } + is MessageDecryptor.Result.Ignore -> { // No action needed } + else -> { throw AssertionError("Unexpected result! ${result.javaClass.simpleName}") } } - return result.followUpOperations + return ProcessingResult( + followUpOperations = result.followUpOperations, + isNetworkResetRequired = isNetworkResetRequired + ) + } + + /** + * True iff this envelope's PniChangeNumber sync actually changed our PNI within this batch. + * Comparing the batch-start PNI against the current value makes the check idempotent — a + * redelivered envelope finds the PNI already applied and won't re-trigger a websocket reset. + */ + private fun isNetworkResetRequired(result: MessageDecryptor.Result.Success, pniAtBatchStart: ServiceId.PNI): Boolean { + return result.content.syncMessage?.pniChangeNumber != null && SignalStore.account.pni != pniAtBatchStart } private fun processReceipt(envelope: Envelope) { @@ -527,16 +553,26 @@ class IncomingMessageObserver( val allFollowUpOperations = mutableListOf() val bufferedStore = BufferedProtocolStore.create() val batchCache = ReusedBatchCache() + var processedCount = 0 + var networkResetRequired = false val committed = SignalDatabase.tryRunInTransaction { - batch.forEach { response -> + for (response in batch) { SignalTrace.beginSection("IncomingMessageObserver#perMessageTransaction") - val followUps = processEnvelope(bufferedStore, response.envelope, response.serverDeliveredTimestamp, batchCache) + val result = processEnvelope(bufferedStore, response.envelope, response.serverDeliveredTimestamp, batchCache) bufferedStore.flushToDisk() SignalTrace.endSection() - if (followUps?.isNotEmpty() == true) { - allFollowUpOperations += followUps + if (result?.followUpOperations?.isNotEmpty() == true) { + allFollowUpOperations += result.followUpOperations + } + + processedCount++ + + if (result?.isNetworkResetRequired == true) { + networkResetRequired = true + Log.w(TAG, "Self identity changed mid-batch after envelope $processedCount of ${batch.size}. Committing what we have; the remainder will be redelivered to the new connection.") + break } } } @@ -550,8 +586,13 @@ class IncomingMessageObserver( AppDependencies.jobManager.addAllChains(jobs) } - batch.forEach { response -> - authWebSocket.sendAck(response) + for (i in 0 until processedCount) { + sendAckSafely(batch[i], i, batch.size) + } + + if (networkResetRequired) { + AppDependencies.resetNetwork() + AppDependencies.startNetwork() } } @@ -565,26 +606,46 @@ class IncomingMessageObserver( val bufferedStore = BufferedProtocolStore.create() val batchCache = ReusedBatchCache() - batch.forEach { response -> + for ((index, response) in batch.withIndex()) { SignalTrace.beginSection("IncomingMessageObserver#perMessageTransaction") - val followUpOperations = SignalDatabase.runInTransaction { - val followUps = processEnvelope(bufferedStore, response.envelope, response.serverDeliveredTimestamp, batchCache) + val results = SignalDatabase.runInTransaction { + val result = processEnvelope(bufferedStore, response.envelope, response.serverDeliveredTimestamp, batchCache) bufferedStore.flushToDisk() - followUps + result } SignalTrace.endSection() - if (followUpOperations?.isNotEmpty() == true) { - val jobs = followUpOperations.mapNotNull { it.run() } + if (results?.followUpOperations?.isNotEmpty() == true) { + val jobs = results.followUpOperations.mapNotNull { it.run() } AppDependencies.jobManager.addAllChains(jobs) } - authWebSocket.sendAck(response) + sendAckSafely(response, index, batch.size) + + if (results?.isNetworkResetRequired == true) { + Log.w(TAG, "Self identity changed mid-batch after envelope ${index + 1} of ${batch.size}. Stopping individual processing; the remainder will be redelivered to the new connection.") + AppDependencies.resetNetwork() + AppDependencies.startNetwork() + break + } } batchCache.flushAndClear() } + /** + * Best-effort ack. Failures just mean the server will redeliver — and for a redelivered + * PniChangeNumber sync, [isNetworkResetRequired] sees the PNI is already applied and won't + * re-trigger a reset, so we don't loop. + */ + private fun sendAckSafely(response: EnvelopeResponse, index: Int, size: Int) { + try { + authWebSocket.sendAck(response) + } catch (e: Exception) { + Log.w(TAG, "Failed to send ack for envelope $index of $size. The server will redeliver.", e) + } + } + override fun uncaughtException(t: Thread, e: Throwable) { Log.w(TAG, "Uncaught exception in message thread!", e) } @@ -649,4 +710,9 @@ class IncomingMessageObserver( } } } + + data class ProcessingResult( + val followUpOperations: List, + val isNetworkResetRequired: Boolean = false + ) } diff --git a/app/src/main/java/org/thoughtcrime/securesms/messages/SyncMessageProcessor.kt b/app/src/main/java/org/thoughtcrime/securesms/messages/SyncMessageProcessor.kt index d541319d81..f17745a735 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/messages/SyncMessageProcessor.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/messages/SyncMessageProcessor.kt @@ -15,8 +15,12 @@ import org.signal.core.util.UuidUtil import org.signal.core.util.isNotEmpty import org.signal.core.util.orNull import org.signal.libsignal.protocol.IdentityKey +import org.signal.libsignal.protocol.IdentityKeyPair import org.signal.libsignal.protocol.InvalidKeyException +import org.signal.libsignal.protocol.ServiceId.Pni import org.signal.libsignal.protocol.SignalProtocolAddress +import org.signal.libsignal.protocol.state.KyberPreKeyRecord +import org.signal.libsignal.protocol.state.SignedPreKeyRecord import org.signal.ringrtc.CallException import org.signal.ringrtc.CallId import org.signal.ringrtc.CallLinkRootKey @@ -24,6 +28,7 @@ import org.thoughtcrime.securesms.attachments.Attachment import org.thoughtcrime.securesms.attachments.DatabaseAttachment import org.thoughtcrime.securesms.attachments.TombstoneAttachment import org.thoughtcrime.securesms.components.emoji.EmojiUtil +import org.thoughtcrime.securesms.components.settings.app.changenumber.ChangeNumberRepository import org.thoughtcrime.securesms.contactshare.Contact import org.thoughtcrime.securesms.database.AttachmentTable import org.thoughtcrime.securesms.database.CallLinkTable @@ -66,6 +71,7 @@ import org.thoughtcrime.securesms.jobs.MultiDeviceContactSyncJob import org.thoughtcrime.securesms.jobs.MultiDeviceContactUpdateJob import org.thoughtcrime.securesms.jobs.MultiDeviceKeysUpdateJob import org.thoughtcrime.securesms.jobs.MultiDeviceStickerPackSyncJob +import org.thoughtcrime.securesms.jobs.PreKeysSyncJob import org.thoughtcrime.securesms.jobs.PushProcessEarlyMessagesJob import org.thoughtcrime.securesms.jobs.RefreshCallLinkDetailsJob import org.thoughtcrime.securesms.jobs.RefreshDonationSubscriptionStatusJob @@ -175,6 +181,7 @@ object SyncMessageProcessor { syncMessage.outgoingPayment != null -> handleSynchronizeOutgoingPayment(syncMessage.outgoingPayment!!, envelope.clientTimestamp!!) syncMessage.contacts != null -> handleSynchronizeContacts(syncMessage.contacts!!, envelope.clientTimestamp!!) syncMessage.keys != null -> handleSynchronizeKeys(syncMessage.keys!!, envelope.clientTimestamp!!) + syncMessage.pniChangeNumber != null -> handleSynchronizePniChangeNumber(envelope, metadata, syncMessage.pniChangeNumber!!) syncMessage.callEvent != null -> handleSynchronizeCallEvent(syncMessage.callEvent!!, envelope.clientTimestamp!!) syncMessage.callLinkUpdate != null -> handleSynchronizeCallLink(syncMessage.callLinkUpdate!!, envelope.clientTimestamp!!) syncMessage.callLogEvent != null -> handleSynchronizeCallLogEvent(syncMessage.callLogEvent!!, envelope.clientTimestamp!!) @@ -1750,6 +1757,99 @@ object SyncMessageProcessor { MultiDeviceAttachmentBackfillUpdateJob.enqueue(request.targetMessage!!, request.targetConversation!!, messageId) } + private fun handleSynchronizePniChangeNumber(envelope: Envelope, metadata: EnvelopeMetadata, pniChangeNumber: SyncMessage.PniChangeNumber) { + val timestamp = envelope.clientTimestamp!! + + if (SignalStore.account.isPrimaryDevice) { + warn(timestamp, "Received a PniChangeNumber sync message on the primary device. Bailing.") + return + } + + if (metadata.sourceDeviceId != SignalServiceAddress.DEFAULT_DEVICE_ID) { + warn(timestamp, "Received a PniChangeNumber sync message from a non-primary device (${metadata.sourceDeviceId}). Bailing.") + return + } + + if (SignalStore.account.aci == null) { + warn(timestamp, "Received a PniChangeNumber sync message but no local ACI is set. Bailing.") + return + } + + val envelopeServerTimestamp = envelope.serverTimestamp ?: 0L + val lastAppliedServerTimestamp = SignalStore.misc.lastAppliedPniChangeServerTimestamp + if (envelopeServerTimestamp <= lastAppliedServerTimestamp) { + warn(timestamp, "PniChangeNumber sync serverTimestamp ($envelopeServerTimestamp) is not newer than the last applied ($lastAppliedServerTimestamp). Treating as a replay and bailing.") + return + } + + // updatedPniBinary is a raw 16-byte UUID per the proto contract instead of a 17-byte service-id array. + val pni = if (envelope.updatedPniBinary != null) { + val updatedPniUuid = UuidUtil.parseOrNull(envelope.updatedPniBinary!!.toByteArray()) + if (updatedPniUuid == null) { + warn(timestamp, "Could not parse updatedPniBinary as a UUID. Bailing.") + return + } + Pni(updatedPniUuid) + } else if (envelope.updatedPni != null) { + Pni.parseFromString(envelope.updatedPni) + } else { + warn(timestamp, "Neither updatedPni or updatedPniBinary were present on the envelope. Bailing.") + return + } + + if (SignalStore.account.pni == PNI(pni)) { + log(timestamp, "PniChangeNumber sync already applied locally. Skipping.") + return + } + + val identityKeyPairBytes = pniChangeNumber.identityKeyPair + val signedPreKeyBytes = pniChangeNumber.signedPreKey + val registrationId = pniChangeNumber.registrationId + val newE164 = pniChangeNumber.newE164 + + if (identityKeyPairBytes == null || signedPreKeyBytes == null || registrationId == null || registrationId <= 0 || newE164.isNullOrEmpty() || !SignalE164Util.isPotentialE164(newE164)) { + warn(timestamp, "PniChangeNumber sync message is missing or has an invalid required field. Bailing.") + return + } + + val pniIdentityKeyPair: IdentityKeyPair + val pniSignedPreKey: SignedPreKeyRecord + val pniLastResortKyberPreKey: KyberPreKeyRecord? + try { + pniIdentityKeyPair = IdentityKeyPair(identityKeyPairBytes.toByteArray()) + pniSignedPreKey = SignedPreKeyRecord(signedPreKeyBytes.toByteArray()) + pniLastResortKyberPreKey = pniChangeNumber.lastResortKyberPreKey?.let { KyberPreKeyRecord(it.toByteArray()) } + } catch (e: Exception) { + warn(timestamp, "Failed to deserialize PniChangeNumber sync message. Bailing.", e) + return + } + + log(timestamp, "Applying PniChangeNumber sync message.") + + ChangeNumberRepository().applyLocalNumberChange( + e164 = newE164, + pni = PNI(pni), + pniIdentityKeyPair = pniIdentityKeyPair, + pniSignedPreKey = pniSignedPreKey, + pniLastResortKyberPreKey = pniLastResortKyberPreKey, + pniRegistrationId = registrationId + ) + + SignalStore.misc.lastAppliedPniChangeServerTimestamp = envelopeServerTimestamp + + // The primary already submitted these per-device prekeys to the server as part of the + // change-number request, so they are registered server-side from this device's perspective. + val pniMetadataStore = SignalStore.account.pniPreKeys + pniMetadataStore.isSignedPreKeyRegistered = true + if (pniLastResortKyberPreKey != null) { + pniMetadataStore.lastResortKyberPreKeyId = pniLastResortKyberPreKey.id + } + + // Rotate the primary-generated keys as soon as possible so we don't rely on them long-term. + SignalStore.misc.forcePniSignedPreKeyRotation = true + AppDependencies.jobManager.add(PreKeysSyncJob.create(forceRotationRequested = true)) + } + private fun handleSynchronizedPollCreate( envelope: Envelope, message: DataMessage, diff --git a/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedProtocolStore.kt b/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedProtocolStore.kt index d901c14bfa..a4e60d8fc3 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedProtocolStore.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/messages/protocol/BufferedProtocolStore.kt @@ -1,6 +1,7 @@ package org.thoughtcrime.securesms.messages.protocol import org.signal.core.models.ServiceId +import org.signal.core.models.ServiceId.PNI import org.thoughtcrime.securesms.dependencies.AppDependencies import org.thoughtcrime.securesms.keyvalue.SignalStore @@ -13,9 +14,12 @@ import org.thoughtcrime.securesms.keyvalue.SignalStore */ class BufferedProtocolStore private constructor( private val aciStore: Pair, - private val pniStore: Pair + private val pniStore: Pair ) { + /** The PNI captured when this batch's store was created. Does not refresh if [SignalStore.account.pni] later changes mid-batch. */ + val pni: PNI get() = pniStore.first + fun get(serviceId: ServiceId): BufferedSignalServiceAccountDataStore { return when (serviceId) { aciStore.first -> aciStore.second diff --git a/app/src/test/java/org/thoughtcrime/securesms/jobs/PreKeysSyncJobTest.kt b/app/src/test/java/org/thoughtcrime/securesms/jobs/PreKeysSyncJobTest.kt new file mode 100644 index 0000000000..1768972341 --- /dev/null +++ b/app/src/test/java/org/thoughtcrime/securesms/jobs/PreKeysSyncJobTest.kt @@ -0,0 +1,179 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.thoughtcrime.securesms.jobs + +import android.app.Application +import io.mockk.every +import io.mockk.mockk +import io.mockk.mockkObject +import io.mockk.mockkStatic +import io.mockk.unmockkObject +import io.mockk.unmockkStatic +import io.mockk.verify +import org.junit.After +import org.junit.Before +import org.junit.Rule +import org.junit.Test +import org.junit.runner.RunWith +import org.robolectric.RobolectricTestRunner +import org.robolectric.annotation.Config +import org.signal.core.models.ServiceId +import org.signal.libsignal.protocol.IdentityKeyPair +import org.signal.libsignal.protocol.state.KyberPreKeyRecord +import org.signal.libsignal.protocol.state.SignedPreKeyRecord +import org.signal.network.NetworkResult +import org.thoughtcrime.securesms.crypto.PreKeyUtil +import org.thoughtcrime.securesms.crypto.storage.PreKeyMetadataStore +import org.thoughtcrime.securesms.crypto.storage.SignalServiceAccountDataStoreImpl +import org.thoughtcrime.securesms.dependencies.AppDependencies +import org.thoughtcrime.securesms.keyvalue.MiscellaneousValues +import org.thoughtcrime.securesms.keyvalue.SignalStore +import org.thoughtcrime.securesms.testutil.MockAppDependenciesRule +import org.thoughtcrime.securesms.testutil.MockSignalStoreRule +import org.thoughtcrime.securesms.util.RemoteConfig +import org.whispersystems.signalservice.api.keys.OneTimePreKeyCounts +import java.io.IOException +import java.util.UUID +import kotlin.time.Duration.Companion.hours +import kotlin.time.Duration.Companion.minutes + +@RunWith(RobolectricTestRunner::class) +@Config(manifest = Config.NONE, application = Application::class) +class PreKeysSyncJobTest { + + @get:Rule + val mockSignalStore = MockSignalStoreRule() + + @get:Rule + val appDependencies = MockAppDependenciesRule() + + private val misc: MiscellaneousValues = mockk(relaxUnitFun = true) + private val aciMetadataStore: PreKeyMetadataStore = mockk(relaxUnitFun = true) + private val pniMetadataStore: PreKeyMetadataStore = mockk(relaxUnitFun = true) + private val aciProtocolStore: SignalServiceAccountDataStoreImpl = mockk(relaxed = true) + private val pniProtocolStore: SignalServiceAccountDataStoreImpl = mockk(relaxed = true) + + @Before + fun setUp() { + every { SignalStore.misc } returns misc + + every { mockSignalStore.account.isRegistered } returns true + every { mockSignalStore.account.aci } returns ServiceId.ACI.from(UUID.randomUUID()) + every { mockSignalStore.account.pni } returns ServiceId.PNI.from(UUID.randomUUID()) + every { mockSignalStore.account.aciPreKeys } returns aciMetadataStore + every { mockSignalStore.account.pniPreKeys } returns pniMetadataStore + + // Default metadata: everything is fresh and registered, so absent a force, no rotation triggers. + listOf(aciMetadataStore, pniMetadataStore).forEach { + every { it.isSignedPreKeyRegistered } returns true + every { it.activeSignedPreKeyId } returns 1 + every { it.lastResortKyberPreKeyId } returns 1 + every { it.lastSignedPreKeyRotationTime } returns System.currentTimeMillis() + every { it.lastResortKyberPreKeyRotationTime } returns System.currentTimeMillis() + } + + every { misc.lastForcedPreKeyRefresh } returns 0L + every { misc.forcePniSignedPreKeyRotation } returns false + + // `AppDependencies.protocolStore` / `keysApi` are already relaxed mockks set up by + // MockAppDependenciesRule; configure the chained calls we care about. + every { AppDependencies.protocolStore.aci() } returns aciProtocolStore + every { AppDependencies.protocolStore.pni() } returns pniProtocolStore + + val identityKeyPair = IdentityKeyPair.generate() + every { aciProtocolStore.identityKeyPair } returns identityKeyPair + every { pniProtocolStore.identityKeyPair } returns identityKeyPair + + // Counts well above ONE_TIME_PREKEY_MINIMUM (10) so we don't generate one-time keys unless forced. + every { AppDependencies.keysApi.getAvailablePreKeyCountsSync(any()) } returns NetworkResult.Success(OneTimePreKeyCounts(100, 100)) + every { AppDependencies.keysApi.setPreKeysSync(any()) } returns NetworkResult.Success(Unit) + // Consistency check (only reached when forceRotationRequested=true) returns "everything matches". + every { AppDependencies.keysApi.checkRepeatedUseKeysSync(any(), any(), any(), any(), any(), any()) } returns NetworkResult.Success(Unit) + + mockkObject(RemoteConfig) + every { RemoteConfig.preKeyForceRefreshInterval } returns 1.hours.inWholeMilliseconds + // Used by BaseJob's retry-backoff path when a syncPreKeys call throws a retryable IOException. + every { RemoteConfig.defaultMaxBackoff } returns 1.hours.inWholeMilliseconds + + mockkStatic(PreKeyUtil::class) + every { PreKeyUtil.generateAndStoreSignedPreKey(any(), any()) } answers { fakeSignedPreKey() } + every { PreKeyUtil.generateAndStoreOneTimeEcPreKeys(any(), any()) } returns emptyList() + every { PreKeyUtil.generateAndStoreLastResortKyberPreKey(any(), any()) } answers { fakeKyberPreKey() } + every { PreKeyUtil.generateAndStoreOneTimeKyberPreKeys(any(), any()) } returns emptyList() + every { PreKeyUtil.cleanSignedPreKeys(any(), any()) } returns Unit + every { PreKeyUtil.cleanLastResortKyberPreKeys(any(), any()) } returns Unit + every { PreKeyUtil.cleanOneTimePreKeys(any()) } returns Unit + } + + @After + fun tearDown() { + unmockkObject(RemoteConfig) + unmockkStatic(PreKeyUtil::class) + } + + @Test + fun `when forcePniSignedPreKeyRotation flag set, PNI sync runs forced and ACI does not`() { + every { misc.forcePniSignedPreKeyRotation } returns true + + PreKeysSyncJob.create(forceRotationRequested = false).run() + + verify(exactly = 1) { PreKeyUtil.generateAndStoreSignedPreKey(pniProtocolStore, pniMetadataStore) } + verify(exactly = 1) { PreKeyUtil.generateAndStoreLastResortKyberPreKey(pniProtocolStore, pniMetadataStore) } + verify(exactly = 1) { PreKeyUtil.generateAndStoreOneTimeEcPreKeys(pniProtocolStore, pniMetadataStore) } + verify(exactly = 1) { PreKeyUtil.generateAndStoreOneTimeKyberPreKeys(pniProtocolStore, pniMetadataStore) } + verify(exactly = 0) { PreKeyUtil.generateAndStoreSignedPreKey(aciProtocolStore, aciMetadataStore) } + verify(exactly = 0) { PreKeyUtil.generateAndStoreLastResortKyberPreKey(aciProtocolStore, aciMetadataStore) } + verify(exactly = 0) { PreKeyUtil.generateAndStoreOneTimeEcPreKeys(aciProtocolStore, aciMetadataStore) } + verify(exactly = 0) { PreKeyUtil.generateAndStoreOneTimeKyberPreKeys(aciProtocolStore, aciMetadataStore) } + verify(exactly = 1) { misc.forcePniSignedPreKeyRotation = false } + } + + @Test + fun `when forcePniSignedPreKeyRotation flag set but uploads fail, flag is preserved`() { + every { misc.forcePniSignedPreKeyRotation } returns true + // Fail PNI's upload so syncPreKeys throws before the flag-clear runs. + every { AppDependencies.keysApi.setPreKeysSync(any()) } returns NetworkResult.NetworkError(IOException("simulated")) + + PreKeysSyncJob.create(forceRotationRequested = false).run() + + verify(exactly = 0) { misc.forcePniSignedPreKeyRotation = false } + } + + @Test + fun `when flag not set, no PNI force and flag write is skipped`() { + every { misc.forcePniSignedPreKeyRotation } returns false + + PreKeysSyncJob.create(forceRotationRequested = false).run() + + verify(exactly = 0) { PreKeyUtil.generateAndStoreSignedPreKey(pniProtocolStore, pniMetadataStore) } + verify(exactly = 0) { PreKeyUtil.generateAndStoreSignedPreKey(aciProtocolStore, aciMetadataStore) } + verify(exactly = 0) { misc.forcePniSignedPreKeyRotation = false } + } + + @Test + fun `flag set forces PNI rotation even when consistency check passes and time gate would skip`() { + every { misc.forcePniSignedPreKeyRotation } returns true + // forceRotationRequested=true + consistency checks pass + recent forced refresh (well within + // preKeyForceRefreshInterval=1h) → without the flag, the existing logic would skip rotation. + every { misc.lastForcedPreKeyRefresh } returns System.currentTimeMillis() - 1.minutes.inWholeMilliseconds + + PreKeysSyncJob.create(forceRotationRequested = true).run() + + verify(exactly = 1) { PreKeyUtil.generateAndStoreSignedPreKey(pniProtocolStore, pniMetadataStore) } + verify(exactly = 1) { PreKeyUtil.generateAndStoreLastResortKyberPreKey(pniProtocolStore, pniMetadataStore) } + verify(exactly = 0) { PreKeyUtil.generateAndStoreSignedPreKey(aciProtocolStore, aciMetadataStore) } + verify(exactly = 0) { PreKeyUtil.generateAndStoreLastResortKyberPreKey(aciProtocolStore, aciMetadataStore) } + verify(exactly = 1) { misc.forcePniSignedPreKeyRotation = false } + } + + private fun fakeSignedPreKey(): SignedPreKeyRecord = mockk(relaxed = true) { + every { id } returns 42 + } + + private fun fakeKyberPreKey(): KyberPreKeyRecord = mockk(relaxed = true) { + every { id } returns 42 + } +}