Handle PniChangeNumber sync on linked devices.

This commit is contained in:
Alex Hart
2026-05-21 10:12:05 -03:00
committed by jeffrey-signal
parent 698fc38aed
commit 2ea59bef68
9 changed files with 852 additions and 88 deletions
@@ -0,0 +1,353 @@
/*
* Copyright 2026 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.thoughtcrime.securesms.messages
import androidx.test.ext.junit.runners.AndroidJUnit4
import assertk.assertThat
import assertk.assertions.isEqualTo
import assertk.assertions.isNotNull
import assertk.assertions.isTrue
import okio.ByteString
import okio.ByteString.Companion.toByteString
import org.junit.After
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.signal.core.models.ServiceId
import org.signal.core.util.UuidUtil
import org.signal.core.util.orNull
import org.signal.libsignal.protocol.IdentityKeyPair
import org.signal.libsignal.protocol.SignalProtocolAddress
import org.signal.libsignal.protocol.state.KyberPreKeyRecord
import org.signal.libsignal.protocol.state.SignedPreKeyRecord
import org.thoughtcrime.securesms.crypto.PreKeyUtil
import org.thoughtcrime.securesms.dependencies.AppDependencies
import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.thoughtcrime.securesms.recipients.Recipient
import org.thoughtcrime.securesms.testing.MessageContentFuzzer
import org.thoughtcrime.securesms.testing.SignalActivityRule
import org.whispersystems.signalservice.api.push.SignalServiceAddress
import org.whispersystems.signalservice.internal.push.Content
import org.whispersystems.signalservice.internal.push.SyncMessage
import java.util.UUID
@Suppress("ClassName")
@RunWith(AndroidJUnit4::class)
class SyncMessageProcessorTest_synchronizePniChangeNumber {
@get:Rule
val harness = SignalActivityRule(createGroup = true)
private lateinit var messageHelper: MessageHelper
private val newPniUuid: UUID = UUID.randomUUID()
private val newPni: ServiceId.PNI = ServiceId.PNI.from(newPniUuid)
// 16-byte raw UUID — matches the actual wire format the server sends (per proto comment and
// iOS/Desktop behavior). Do NOT use `newPni.toByteString()` here — that produces libsignal's
// 17-byte ServiceIdBinary form, which is a different format.
private val newPniBytes: ByteString = UuidUtil.toByteArray(newPniUuid).toByteString()
private val newE164 = "+15555550199"
private val newPniIdentity: IdentityKeyPair = IdentityKeyPair.generate()
private val newSignedPreKey: SignedPreKeyRecord = PreKeyUtil.generateSignedPreKey(1234, newPniIdentity.privateKey)
private val newLastResortKyber: KyberPreKeyRecord = PreKeyUtil.generateLastResortKyberPreKey(5678, newPniIdentity.privateKey)
private val newRegistrationId = 4242
@Before
fun setUp() {
messageHelper = MessageHelper(harness)
SignalStore.account.deviceId = 2
}
@After
fun tearDown() {
messageHelper.tearDown()
}
@Test
fun appliesAllStateOnHappyPath() {
sendPniChangeNumber()
assertThat(SignalStore.account.e164).isEqualTo(newE164)
assertThat(SignalStore.account.pni).isEqualTo(newPni)
assertThat(SignalStore.account.pniRegistrationId).isEqualTo(newRegistrationId)
assertThat(SignalStore.account.pniIdentityKey.publicKey.serialize().toByteString())
.isEqualTo(newPniIdentity.publicKey.serialize().toByteString())
assertThat(SignalStore.account.pniPreKeys.activeSignedPreKeyId).isEqualTo(newSignedPreKey.id)
assertThat(SignalStore.account.pniPreKeys.isSignedPreKeyRegistered).isTrue()
assertThat(SignalStore.account.pniPreKeys.lastResortKyberPreKeyId).isEqualTo(newLastResortKyber.id)
assertThat(SignalStore.misc.forcePniSignedPreKeyRotation).isTrue()
val self = Recipient.self().fresh()
assertThat(self.requireE164()).isEqualTo(newE164)
assertThat(self.pni.orNull()).isEqualTo(newPni)
val pniProtocolStore = AppDependencies.protocolStore.pni()
val storedSigned = pniProtocolStore.loadSignedPreKey(newSignedPreKey.id)
assertThat(storedSigned.serialize().toByteString()).isEqualTo(newSignedPreKey.serialize().toByteString())
val storedKyber = pniProtocolStore.loadLastResortKyberPreKeys().firstOrNull { it.id == newLastResortKyber.id }
assertThat(storedKyber).isNotNull()
assertThat(storedKyber!!.serialize().toByteString()).isEqualTo(newLastResortKyber.serialize().toByteString())
// The IdentityTable cache is keyed by ServiceId string, not RecipientId — for self, that's
// separate ACI and PNI rows. We want the PNI row, so look it up by the new PNI directly.
val selfPniIdentity = pniProtocolStore.getIdentity(SignalProtocolAddress(newPni.toString(), SignalServiceAddress.DEFAULT_DEVICE_ID))
assertThat(selfPniIdentity).isNotNull()
assertThat(selfPniIdentity!!.publicKey.serialize().toByteString())
.isEqualTo(newPniIdentity.publicKey.serialize().toByteString())
}
@Test
fun appliesStateWhenLastResortKyberAbsent() {
val original = captureOriginalState()
sendPniChangeNumber(lastResortKyberPreKey = null)
assertThat(SignalStore.account.e164).isEqualTo(newE164)
assertThat(SignalStore.account.pni).isEqualTo(newPni)
assertThat(SignalStore.account.pniRegistrationId).isEqualTo(newRegistrationId)
assertThat(SignalStore.account.pniPreKeys.activeSignedPreKeyId).isEqualTo(newSignedPreKey.id)
assertThat(SignalStore.account.pniPreKeys.isSignedPreKeyRegistered).isTrue()
// No kyber was supplied, so kyber metadata should be unchanged.
assertThat(SignalStore.account.pniPreKeys.lastResortKyberPreKeyId).isEqualTo(original.lastResortKyberPreKeyId)
assertThat(SignalStore.misc.forcePniSignedPreKeyRotation).isTrue()
}
@Test
fun bailsWhenPrimaryDevice() {
SignalStore.account.deviceId = SignalServiceAddress.DEFAULT_DEVICE_ID
val original = captureOriginalState()
sendPniChangeNumber()
assertOriginalStatePreserved(original)
}
@Test
fun bailsWhenSourceIsNotPrimaryDevice() {
val original = captureOriginalState()
sendPniChangeNumber(sourceDeviceId = 3)
assertOriginalStatePreserved(original)
}
@Test
fun bailsWhenEnvelopePniMissing() {
val original = captureOriginalState()
sendPniChangeNumber(envelopePniBinary = null)
assertOriginalStatePreserved(original)
}
@Test
fun bailsWhenIdentityKeyPairMissing() {
val original = captureOriginalState()
sendPniChangeNumber(identityKeyPair = null)
assertOriginalStatePreserved(original)
}
@Test
fun bailsWhenSignedPreKeyMissing() {
val original = captureOriginalState()
sendPniChangeNumber(signedPreKey = null)
assertOriginalStatePreserved(original)
}
@Test
fun bailsWhenRegistrationIdMissing() {
val original = captureOriginalState()
sendPniChangeNumber(registrationId = null)
assertOriginalStatePreserved(original)
}
@Test
fun bailsWhenRegistrationIdZero() {
val original = captureOriginalState()
sendPniChangeNumber(registrationId = 0)
assertOriginalStatePreserved(original)
}
@Test
fun bailsWhenNewE164Missing() {
val original = captureOriginalState()
sendPniChangeNumber(e164 = null)
assertOriginalStatePreserved(original)
}
@Test
fun bailsWhenNewE164Empty() {
val original = captureOriginalState()
sendPniChangeNumber(e164 = "")
assertOriginalStatePreserved(original)
}
@Test
fun bailsWhenNewE164NotValid() {
val original = captureOriginalState()
sendPniChangeNumber(e164 = "not a phone number")
assertOriginalStatePreserved(original)
}
@Test
fun bailsOnMalformedIdentityKeyPair() {
val original = captureOriginalState()
sendPniChangeNumber(identityKeyPair = malformedBytes())
assertOriginalStatePreserved(original)
}
@Test
fun bailsOnMalformedSignedPreKey() {
val original = captureOriginalState()
sendPniChangeNumber(signedPreKey = malformedBytes())
assertOriginalStatePreserved(original)
}
@Test
fun bailsOnMalformedLastResortKyber() {
val original = captureOriginalState()
sendPniChangeNumber(lastResortKyberPreKey = malformedBytes())
assertOriginalStatePreserved(original)
}
@Test
fun skipsRedeliveryWhenPniAlreadyMatches() {
sendPniChangeNumber()
val afterFirstApply = captureOriginalState()
val otherIdentity = IdentityKeyPair.generate()
val otherSignedPreKey = PreKeyUtil.generateSignedPreKey(9999, otherIdentity.privateKey)
sendPniChangeNumber(
identityKeyPair = otherIdentity.serialize().toByteString(),
signedPreKey = otherSignedPreKey.serialize().toByteString(),
e164 = "+15555550100",
timestamp = messageHelper.nextStartTime() + 1000
)
assertOriginalStatePreserved(afterFirstApply)
}
@Test
fun bailsWhenServerTimestampStale() {
sendPniChangeNumber()
val afterFirstApply = captureOriginalState()
val otherPniUuid = UUID.randomUUID()
val otherPniBytes = UuidUtil.toByteArray(otherPniUuid).toByteString()
sendPniChangeNumber(
envelopePniBinary = otherPniBytes,
e164 = "+15555550100",
timestamp = messageHelper.nextStartTime() - 100_000L
)
assertOriginalStatePreserved(afterFirstApply)
}
private fun captureOriginalState(): OriginalState {
val self = Recipient.self().fresh()
return OriginalState(
e164 = SignalStore.account.e164,
pni = SignalStore.account.pni,
pniRegistrationId = SignalStore.account.pniRegistrationId,
isSignedPreKeyRegistered = SignalStore.account.pniPreKeys.isSignedPreKeyRegistered,
activeSignedPreKeyId = SignalStore.account.pniPreKeys.activeSignedPreKeyId,
lastResortKyberPreKeyId = SignalStore.account.pniPreKeys.lastResortKyberPreKeyId,
pniIdentityPublicKey = SignalStore.account.pniIdentityKey.publicKey.serialize().toByteString(),
selfE164 = self.e164.orNull(),
selfPni = self.pni.orNull(),
forcePniSignedPreKeyRotation = SignalStore.misc.forcePniSignedPreKeyRotation
)
}
private fun assertOriginalStatePreserved(original: OriginalState) {
assertThat(SignalStore.account.e164).isEqualTo(original.e164)
assertThat(SignalStore.account.pni).isEqualTo(original.pni)
assertThat(SignalStore.account.pniRegistrationId).isEqualTo(original.pniRegistrationId)
assertThat(SignalStore.account.pniPreKeys.isSignedPreKeyRegistered).isEqualTo(original.isSignedPreKeyRegistered)
assertThat(SignalStore.account.pniPreKeys.activeSignedPreKeyId).isEqualTo(original.activeSignedPreKeyId)
assertThat(SignalStore.account.pniPreKeys.lastResortKyberPreKeyId).isEqualTo(original.lastResortKyberPreKeyId)
assertThat(SignalStore.account.pniIdentityKey.publicKey.serialize().toByteString())
.isEqualTo(original.pniIdentityPublicKey)
assertThat(SignalStore.misc.forcePniSignedPreKeyRotation).isEqualTo(original.forcePniSignedPreKeyRotation)
val self = Recipient.self().fresh()
assertThat(self.e164.orNull()).isEqualTo(original.selfE164)
assertThat(self.pni.orNull()).isEqualTo(original.selfPni)
}
private data class OriginalState(
val e164: String?,
val pni: ServiceId.PNI?,
val pniRegistrationId: Int,
val isSignedPreKeyRegistered: Boolean,
val activeSignedPreKeyId: Int,
val lastResortKyberPreKeyId: Int,
val pniIdentityPublicKey: ByteString,
val selfE164: String?,
val selfPni: ServiceId.PNI?,
val forcePniSignedPreKeyRotation: Boolean
)
private fun malformedBytes(): ByteString = byteArrayOf(0x00, 0x01, 0x02).toByteString()
private fun sendPniChangeNumber(
identityKeyPair: ByteString? = newPniIdentity.serialize().toByteString(),
signedPreKey: ByteString? = newSignedPreKey.serialize().toByteString(),
lastResortKyberPreKey: ByteString? = newLastResortKyber.serialize().toByteString(),
registrationId: Int? = newRegistrationId,
e164: String? = newE164,
envelopePniBinary: ByteString? = newPniBytes,
sourceDeviceId: Int = SignalServiceAddress.DEFAULT_DEVICE_ID,
timestamp: Long = messageHelper.nextStartTime()
) {
val content = Content(
syncMessage = SyncMessage(
pniChangeNumber = SyncMessage.PniChangeNumber(
identityKeyPair = identityKeyPair,
signedPreKey = signedPreKey,
lastResortKyberPreKey = lastResortKyberPreKey,
registrationId = registrationId,
newE164 = e164
)
)
)
val envelope = MessageContentFuzzer.envelope(
timestamp = timestamp,
updatedPniBinary = envelopePniBinary
)
messageHelper.processor.process(
envelope = envelope,
content = content,
metadata = MessageContentFuzzer.envelopeMetadata(harness.self.id, harness.self.id, sourceDeviceId = sourceDeviceId),
serverDeliveredTimestamp = timestamp + 10
)
}
}
@@ -41,11 +41,12 @@ object MessageContentFuzzer {
/**
* Create an [Envelope].
*/
fun envelope(timestamp: Long, serverGuid: UUID = UUID.randomUUID()): Envelope {
fun envelope(timestamp: Long, serverGuid: UUID = UUID.randomUUID(), updatedPniBinary: ByteString? = null): Envelope {
return Envelope.Builder()
.clientTimestamp(timestamp)
.serverTimestamp(timestamp + 5)
.serverGuidBinary(serverGuid.toByteArray().toByteString())
.also { if (updatedPniBinary != null) it.updatedPniBinary(updatedPniBinary) }
.build()
}
@@ -103,20 +103,6 @@ class ChangeNumberRepository(
@WorkerThread
fun changeLocalNumber(e164: String, pni: ServiceId.PNI) {
SignalDatabase.recipients.updateSelfE164(e164, pni)
AppDependencies.recipientCache.clear()
if (e164 != SignalStore.account.requireE164()) {
SignalDatabase.recipients.rotateStorageId(Recipient.self().fresh().id)
StorageSyncHelper.scheduleSyncForDataChange()
}
SignalStore.account.setE164(e164)
SignalStore.account.setPni(pni)
AppDependencies.resetProtocolStores()
AppDependencies.groupsV2Authorization.clear()
val metadata: PendingChangeNumberMetadata? = SignalStore.misc.pendingChangeNumberMetadata
if (metadata == null) {
Log.w(TAG, "No change number metadata, this shouldn't happen")
@@ -125,25 +111,32 @@ class ChangeNumberRepository(
val pniIdentityKeyPair = IdentityKeyPair(metadata.pniIdentityKeyPair.toByteArray())
val pniRegistrationId = metadata.pniRegistrationId
val pniSignedPreyKeyId = metadata.pniSignedPreKeyId
val pniSignedPreKeyId = metadata.pniSignedPreKeyId
val pniLastResortKyberPreKeyId = metadata.pniLastResortKyberPreKeyId
// Prekeys were generated and stored during createChangeNumberRequest; reload them so we can pass them through and reuse for the upload below.
val preResetPniStore = AppDependencies.protocolStore.pni()
val signedPreKey = preResetPniStore.loadSignedPreKey(pniSignedPreKeyId)
val lastResortKyberPreKey = preResetPniStore.loadLastResortKyberPreKeys().firstOrNull { it.id == pniLastResortKyberPreKeyId }
applyLocalNumberChange(
e164 = e164,
pni = pni,
pniIdentityKeyPair = pniIdentityKeyPair,
pniSignedPreKey = signedPreKey,
pniLastResortKyberPreKey = lastResortKyberPreKey,
pniRegistrationId = pniRegistrationId
)
AppDependencies.resetNetwork()
AppDependencies.startNetwork()
val pniProtocolStore = AppDependencies.protocolStore.pni()
val pniMetadataStore = SignalStore.account.pniPreKeys
SignalStore.account.pniRegistrationId = pniRegistrationId
SignalStore.account.setPniIdentityKeyAfterChangeNumber(pniIdentityKeyPair)
val signedPreKey = pniProtocolStore.loadSignedPreKey(pniSignedPreyKeyId)
val oneTimeEcPreKeys = PreKeyUtil.generateAndStoreOneTimeEcPreKeys(pniProtocolStore, pniMetadataStore)
val lastResortKyberPreKey = pniProtocolStore.loadLastResortKyberPreKeys().firstOrNull { it.id == pniLastResortKyberPreKeyId }
val oneTimeKyberPreKeys = PreKeyUtil.generateAndStoreOneTimeKyberPreKeys(pniProtocolStore, pniMetadataStore)
if (lastResortKyberPreKey == null) {
Log.w(TAG, "Last-resort kyber prekey is missing!")
}
pniMetadataStore.activeSignedPreKeyId = signedPreKey.id
Log.i(TAG, "Submitting prekeys with PNI identity key: ${pniIdentityKeyPair.publicKey.fingerprint}")
retryChangeLocalNumberNetworkOperation {
@@ -161,6 +154,61 @@ class ChangeNumberRepository(
pniMetadataStore.isSignedPreKeyRegistered = true
pniMetadataStore.lastResortKyberPreKeyId = pniLastResortKyberPreKeyId
SignalStore.misc.hasPniInitializedDevices = true
AppDependencies.jobManager.add(RefreshAttributesJob())
rotateCertificates()
SignalStore.misc.unlockChangeNumber()
}
/**
* Applies the local state for a successful number change: self recipient row, account values,
* PNI protocol store, and identity entry.
*
* Does NOT reset the network — callers must do so before any subsequent traffic that needs to
* use the new PNI. Does NOT make any server requests and does NOT flag prekeys as registered
* server-side — the caller is responsible for that once it can attest to server state.
*/
@WorkerThread
fun applyLocalNumberChange(
e164: String,
pni: ServiceId.PNI,
pniIdentityKeyPair: IdentityKeyPair,
pniSignedPreKey: SignedPreKeyRecord,
pniLastResortKyberPreKey: KyberPreKeyRecord?,
pniRegistrationId: Int
) {
SignalDatabase.recipients.updateSelfE164(e164, pni)
AppDependencies.recipientCache.clear()
if (e164 != SignalStore.account.requireE164()) {
SignalDatabase.recipients.rotateStorageId(Recipient.self().fresh().id)
StorageSyncHelper.scheduleSyncForDataChange()
}
SignalStore.account.setE164(e164)
SignalStore.account.setPni(pni)
AppDependencies.resetProtocolStores()
AppDependencies.groupsV2Authorization.clear()
val pniProtocolStore = AppDependencies.protocolStore.pni()
val pniMetadataStore = SignalStore.account.pniPreKeys
SignalStore.account.pniRegistrationId = pniRegistrationId
SignalStore.account.setPniIdentityKeyAfterChangeNumber(pniIdentityKeyPair)
PreKeyUtil.storeSignedPreKey(pniProtocolStore, pniMetadataStore, pniSignedPreKey)
pniMetadataStore.activeSignedPreKeyId = pniSignedPreKey.id
if (pniLastResortKyberPreKey != null) {
PreKeyUtil.storeLastResortKyberPreKey(pniProtocolStore, pniMetadataStore, pniLastResortKyberPreKey)
} else {
Log.w(TAG, "Last-resort kyber prekey is missing!")
}
pniProtocolStore.identities().saveIdentityWithoutSideEffects(
Recipient.self().id,
pni,
@@ -171,20 +219,8 @@ class ChangeNumberRepository(
true
)
SignalStore.misc.hasPniInitializedDevices = true
AppDependencies.groupsV2Authorization.clear()
Recipient.self().fresh()
StorageSyncHelper.scheduleSyncForDataChange()
AppDependencies.resetNetwork()
AppDependencies.startNetwork()
AppDependencies.jobManager.add(RefreshAttributesJob())
rotateCertificates()
SignalStore.misc.unlockChangeNumber()
}
@WorkerThread
@@ -125,6 +125,11 @@ class PreKeysSyncJob private constructor(
return
}
val pniRotationOverride = SignalStore.misc.forcePniSignedPreKeyRotation
if (pniRotationOverride) {
warn(TAG, ServiceIdType.PNI, "Forced PNI prekey rotation pending after PniChangeNumber sync. Bypassing dedup/interval gating for PNI.")
}
val forceRotation = if (forceRotationRequested) {
warn(TAG, "Forced rotation was requested.")
warn(TAG, ServiceIdType.ACI, "Active Signed EC: ${SignalStore.account.aciPreKeys.activeSignedPreKeyId}, Last Resort Kyber: ${SignalStore.account.aciPreKeys.lastResortKyberPreKeyId}")
@@ -146,19 +151,26 @@ class PreKeysSyncJob private constructor(
false
}
if (forceRotation) {
warn(TAG, "Forcing prekey rotation.")
val forcePniRotation = forceRotation || pniRotationOverride
if (forcePniRotation) {
warn(TAG, "Forcing prekey rotation. ACI=$forceRotation PNI=$forcePniRotation")
} else if (forceRotationRequested) {
warn(TAG, "Forced prekey rotation was requested, but we already did a forced refresh ${System.currentTimeMillis() - SignalStore.misc.lastForcedPreKeyRefresh} ms ago. Ignoring.")
}
syncPreKeys(ServiceIdType.ACI, SignalStore.account.aci, AppDependencies.protocolStore.aci(), SignalStore.account.aciPreKeys, forceRotation)
syncPreKeys(ServiceIdType.PNI, SignalStore.account.pni, AppDependencies.protocolStore.pni(), SignalStore.account.pniPreKeys, forceRotation)
syncPreKeys(ServiceIdType.PNI, SignalStore.account.pni, AppDependencies.protocolStore.pni(), SignalStore.account.pniPreKeys, forcePniRotation)
SignalStore.misc.lastFullPrekeyRefreshTime = System.currentTimeMillis()
if (forceRotation) {
if (forcePniRotation) {
SignalStore.misc.lastForcedPreKeyRefresh = System.currentTimeMillis()
}
if (pniRotationOverride) {
// Cleared only after both syncPreKeys calls completed without throwing; a thrown upload leaves the flag set for the next attempt.
SignalStore.misc.forcePniSignedPreKeyRotation = false
}
}
private fun syncPreKeys(serviceIdType: ServiceIdType, serviceId: ServiceId?, protocolStore: SignalServiceAccountDataStore, metadataStore: PreKeyMetadataStore, forceRotation: Boolean) {
@@ -32,6 +32,7 @@ class MiscellaneousValues internal constructor(store: KeyValueStore) : SignalSto
private const val LAST_SERVER_TIME_OFFSET_UPDATE = "misc.last_server_time_offset_update"
private const val NEEDS_USERNAME_RESTORE = "misc.needs_username_restore"
private const val LAST_FORCED_PREKEY_REFRESH = "misc.last_forced_prekey_refresh"
private const val FORCE_PNI_SIGNED_PREKEY_ROTATION = "misc.force_pni_signed_prekey_rotation"
private const val LAST_CDS_FOREGROUND_SYNC = "misc.last_cds_foreground_sync"
private const val LINKED_DEVICE_LAST_ACTIVE_CHECK_TIME = "misc.linked_device.last_active_check_time"
private const val LEAST_ACTIVE_LINKED_DEVICE = "misc.linked_device.least_active"
@@ -51,6 +52,7 @@ class MiscellaneousValues internal constructor(store: KeyValueStore) : SignalSto
private const val CAPTCHA_LAST_VIEWED_AT = "misc.captcha_last_viewed_at"
private const val CALLING_ASSETS_VERSION = "misc.calling_assets_version"
private const val LAST_SYNC_MESSAGE_SEEN_TIME_MS = "misc.last_sync_message_seen_time"
private const val LAST_APPLIED_PNI_CHANGE_SERVER_TIMESTAMP = "misc.last_applied_pni_change_server_timestamp"
}
public override fun onFirstEverAppLaunch() {
@@ -75,6 +77,17 @@ class MiscellaneousValues internal constructor(store: KeyValueStore) : SignalSto
*/
var lastForcedPreKeyRefresh by longValue(LAST_FORCED_PREKEY_REFRESH, 0)
/**
* Bypasses the timeout in [org.thoughtcrime.securesms.jobs.PreKeysSyncJob] since otherwise we can hit a race.
*/
var forcePniSignedPreKeyRotation by booleanValue(FORCE_PNI_SIGNED_PREKEY_ROTATION, false)
/**
* Envelope serverTimestamp of the most recently applied PniChangeNumber sync. Used to reject
* stale replays — a sync with serverTimestamp <= this value is treated as a replay and ignored.
*/
var lastAppliedPniChangeServerTimestamp by longValue(LAST_APPLIED_PNI_CHANGE_SERVER_TIMESTAMP, 0L)
/**
* The last time we completed a routine profile refresh.
*/
@@ -255,7 +255,10 @@ class IncomingMessageObserver(
val needsConnectionString = if (conclusion) "Needs Connection" else "Does Not Need Connection"
Log.d(TAG, "[$needsConnectionString] Network: $hasNetwork, Foreground: $appVisibleSnapshot, Time Since Last Interaction: $lastInteractionString, FCM: $fcmEnabled, WS Open or Keep-alives: $websocketAlreadyOpen, Registered: $registered, Unauthorized: $unauthorizedReceived, Proxy: $hasProxy, Force websocket: $forceWebsocket")
Log.d(
TAG,
"[$needsConnectionString] Network: $hasNetwork, Foreground: $appVisibleSnapshot, Time Since Last Interaction: $lastInteractionString, FCM: $fcmEnabled, WS Open or Keep-alives: $websocketAlreadyOpen, Registered: $registered, Unauthorized: $unauthorizedReceived, Proxy: $hasProxy, Force websocket: $forceWebsocket"
)
return conclusion
}
@@ -287,7 +290,7 @@ class IncomingMessageObserver(
}
@VisibleForTesting
fun processEnvelope(bufferedProtocolStore: BufferedProtocolStore, envelope: Envelope, serverDeliveredTimestamp: Long, batchCache: BatchCache): List<FollowUpOperation>? {
fun processEnvelope(bufferedProtocolStore: BufferedProtocolStore, envelope: Envelope, serverDeliveredTimestamp: Long, batchCache: BatchCache): ProcessingResult? {
return when (envelope.type) {
Envelope.Type.SERVER_DELIVERY_RECEIPT -> {
processReceipt(envelope)
@@ -299,9 +302,9 @@ class IncomingMessageObserver(
Envelope.Type.UNIDENTIFIED_SENDER,
Envelope.Type.PLAINTEXT_CONTENT -> {
SignalTrace.beginSection("IncomingMessageObserver#processMessage")
val followUps = processMessage(bufferedProtocolStore, envelope, serverDeliveredTimestamp, batchCache)
val result = processMessage(bufferedProtocolStore, envelope, serverDeliveredTimestamp, batchCache)
SignalTrace.endSection()
followUps
result
}
else -> {
@@ -311,56 +314,79 @@ class IncomingMessageObserver(
}
}
private fun processMessage(bufferedProtocolStore: BufferedProtocolStore, envelope: Envelope, serverDeliveredTimestamp: Long, batchCache: BatchCache): List<FollowUpOperation> {
private fun processMessage(bufferedProtocolStore: BufferedProtocolStore, envelope: Envelope, serverDeliveredTimestamp: Long, batchCache: BatchCache): ProcessingResult {
val localReceiveMetric = SignalLocalMetrics.MessageReceive.start()
SignalTrace.beginSection("IncomingMessageObserver#decryptMessage")
val result = MessageDecryptor.decrypt(context, bufferedProtocolStore, envelope, serverDeliveredTimestamp)
SignalTrace.endSection()
localReceiveMetric.onEnvelopeDecrypted()
var isNetworkResetRequired = false
SignalLocalMetrics.MessageLatency.onMessageReceived(envelope.serverTimestamp!!, serverDeliveredTimestamp, envelope.urgent!!)
when (result) {
is MessageDecryptor.Result.Success -> {
val job = PushProcessMessageJob.processOrDefer(messageContentProcessor, result, localReceiveMetric, batchCache)
isNetworkResetRequired = isNetworkResetRequired(result, bufferedProtocolStore.pni)
if (job != null) {
return result.followUpOperations + FollowUpOperation { job.asChain() }
}
}
is MessageDecryptor.Result.Error -> {
return result.followUpOperations + FollowUpOperation {
val jobs = mutableListOf<Job>()
if (result.errorMetadata.groupMasterKey != null) {
val groupId = result.errorMetadata.groupId!!
if (!SignalDatabase.groups.getGroup(groupId).isPresent) {
Log.w(TAG, "Decryption error in group, but group not found. Creating placeholder for groupId: $groupId")
SignalDatabase.groups.create(
groupMasterKey = result.errorMetadata.groupMasterKey!!,
groupState = DecryptedGroup(revision = GroupsV2StateProcessor.RESTORE_PLACEHOLDER_REVISION),
groupSendEndorsements = null
)
jobs += RequestGroupV2InfoJob(groupId)
}
}
jobs += PushProcessMessageErrorJob(
result.toMessageState(),
result.errorMetadata.toExceptionMetadata(),
result.envelope.clientTimestamp!!
return ProcessingResult(
followUpOperations = result.followUpOperations + FollowUpOperation { job.asChain() },
isNetworkResetRequired = isNetworkResetRequired
)
AppDependencies.jobManager.startChain(jobs)
}
}
is MessageDecryptor.Result.Error -> {
return ProcessingResult(
result.followUpOperations + FollowUpOperation {
val jobs = mutableListOf<Job>()
if (result.errorMetadata.groupMasterKey != null) {
val groupId = result.errorMetadata.groupId!!
if (!SignalDatabase.groups.getGroup(groupId).isPresent) {
Log.w(TAG, "Decryption error in group, but group not found. Creating placeholder for groupId: $groupId")
SignalDatabase.groups.create(
groupMasterKey = result.errorMetadata.groupMasterKey!!,
groupState = DecryptedGroup(revision = GroupsV2StateProcessor.RESTORE_PLACEHOLDER_REVISION),
groupSendEndorsements = null
)
jobs += RequestGroupV2InfoJob(groupId)
}
}
jobs += PushProcessMessageErrorJob(
result.toMessageState(),
result.errorMetadata.toExceptionMetadata(),
result.envelope.clientTimestamp!!
)
AppDependencies.jobManager.startChain(jobs)
}
)
}
is MessageDecryptor.Result.Ignore -> {
// No action needed
}
else -> {
throw AssertionError("Unexpected result! ${result.javaClass.simpleName}")
}
}
return result.followUpOperations
return ProcessingResult(
followUpOperations = result.followUpOperations,
isNetworkResetRequired = isNetworkResetRequired
)
}
/**
* True iff this envelope's PniChangeNumber sync actually changed our PNI within this batch.
* Comparing the batch-start PNI against the current value makes the check idempotent — a
* redelivered envelope finds the PNI already applied and won't re-trigger a websocket reset.
*/
private fun isNetworkResetRequired(result: MessageDecryptor.Result.Success, pniAtBatchStart: ServiceId.PNI): Boolean {
return result.content.syncMessage?.pniChangeNumber != null && SignalStore.account.pni != pniAtBatchStart
}
private fun processReceipt(envelope: Envelope) {
@@ -527,16 +553,26 @@ class IncomingMessageObserver(
val allFollowUpOperations = mutableListOf<FollowUpOperation>()
val bufferedStore = BufferedProtocolStore.create()
val batchCache = ReusedBatchCache()
var processedCount = 0
var networkResetRequired = false
val committed = SignalDatabase.tryRunInTransaction {
batch.forEach { response ->
for (response in batch) {
SignalTrace.beginSection("IncomingMessageObserver#perMessageTransaction")
val followUps = processEnvelope(bufferedStore, response.envelope, response.serverDeliveredTimestamp, batchCache)
val result = processEnvelope(bufferedStore, response.envelope, response.serverDeliveredTimestamp, batchCache)
bufferedStore.flushToDisk()
SignalTrace.endSection()
if (followUps?.isNotEmpty() == true) {
allFollowUpOperations += followUps
if (result?.followUpOperations?.isNotEmpty() == true) {
allFollowUpOperations += result.followUpOperations
}
processedCount++
if (result?.isNetworkResetRequired == true) {
networkResetRequired = true
Log.w(TAG, "Self identity changed mid-batch after envelope $processedCount of ${batch.size}. Committing what we have; the remainder will be redelivered to the new connection.")
break
}
}
}
@@ -550,8 +586,13 @@ class IncomingMessageObserver(
AppDependencies.jobManager.addAllChains(jobs)
}
batch.forEach { response ->
authWebSocket.sendAck(response)
for (i in 0 until processedCount) {
sendAckSafely(batch[i], i, batch.size)
}
if (networkResetRequired) {
AppDependencies.resetNetwork()
AppDependencies.startNetwork()
}
}
@@ -565,26 +606,46 @@ class IncomingMessageObserver(
val bufferedStore = BufferedProtocolStore.create()
val batchCache = ReusedBatchCache()
batch.forEach { response ->
for ((index, response) in batch.withIndex()) {
SignalTrace.beginSection("IncomingMessageObserver#perMessageTransaction")
val followUpOperations = SignalDatabase.runInTransaction {
val followUps = processEnvelope(bufferedStore, response.envelope, response.serverDeliveredTimestamp, batchCache)
val results = SignalDatabase.runInTransaction {
val result = processEnvelope(bufferedStore, response.envelope, response.serverDeliveredTimestamp, batchCache)
bufferedStore.flushToDisk()
followUps
result
}
SignalTrace.endSection()
if (followUpOperations?.isNotEmpty() == true) {
val jobs = followUpOperations.mapNotNull { it.run() }
if (results?.followUpOperations?.isNotEmpty() == true) {
val jobs = results.followUpOperations.mapNotNull { it.run() }
AppDependencies.jobManager.addAllChains(jobs)
}
authWebSocket.sendAck(response)
sendAckSafely(response, index, batch.size)
if (results?.isNetworkResetRequired == true) {
Log.w(TAG, "Self identity changed mid-batch after envelope ${index + 1} of ${batch.size}. Stopping individual processing; the remainder will be redelivered to the new connection.")
AppDependencies.resetNetwork()
AppDependencies.startNetwork()
break
}
}
batchCache.flushAndClear()
}
/**
* Best-effort ack. Failures just mean the server will redeliver — and for a redelivered
* PniChangeNumber sync, [isNetworkResetRequired] sees the PNI is already applied and won't
* re-trigger a reset, so we don't loop.
*/
private fun sendAckSafely(response: EnvelopeResponse, index: Int, size: Int) {
try {
authWebSocket.sendAck(response)
} catch (e: Exception) {
Log.w(TAG, "Failed to send ack for envelope $index of $size. The server will redeliver.", e)
}
}
override fun uncaughtException(t: Thread, e: Throwable) {
Log.w(TAG, "Uncaught exception in message thread!", e)
}
@@ -649,4 +710,9 @@ class IncomingMessageObserver(
}
}
}
data class ProcessingResult(
val followUpOperations: List<FollowUpOperation>,
val isNetworkResetRequired: Boolean = false
)
}
@@ -15,8 +15,12 @@ import org.signal.core.util.UuidUtil
import org.signal.core.util.isNotEmpty
import org.signal.core.util.orNull
import org.signal.libsignal.protocol.IdentityKey
import org.signal.libsignal.protocol.IdentityKeyPair
import org.signal.libsignal.protocol.InvalidKeyException
import org.signal.libsignal.protocol.ServiceId.Pni
import org.signal.libsignal.protocol.SignalProtocolAddress
import org.signal.libsignal.protocol.state.KyberPreKeyRecord
import org.signal.libsignal.protocol.state.SignedPreKeyRecord
import org.signal.ringrtc.CallException
import org.signal.ringrtc.CallId
import org.signal.ringrtc.CallLinkRootKey
@@ -24,6 +28,7 @@ import org.thoughtcrime.securesms.attachments.Attachment
import org.thoughtcrime.securesms.attachments.DatabaseAttachment
import org.thoughtcrime.securesms.attachments.TombstoneAttachment
import org.thoughtcrime.securesms.components.emoji.EmojiUtil
import org.thoughtcrime.securesms.components.settings.app.changenumber.ChangeNumberRepository
import org.thoughtcrime.securesms.contactshare.Contact
import org.thoughtcrime.securesms.database.AttachmentTable
import org.thoughtcrime.securesms.database.CallLinkTable
@@ -66,6 +71,7 @@ import org.thoughtcrime.securesms.jobs.MultiDeviceContactSyncJob
import org.thoughtcrime.securesms.jobs.MultiDeviceContactUpdateJob
import org.thoughtcrime.securesms.jobs.MultiDeviceKeysUpdateJob
import org.thoughtcrime.securesms.jobs.MultiDeviceStickerPackSyncJob
import org.thoughtcrime.securesms.jobs.PreKeysSyncJob
import org.thoughtcrime.securesms.jobs.PushProcessEarlyMessagesJob
import org.thoughtcrime.securesms.jobs.RefreshCallLinkDetailsJob
import org.thoughtcrime.securesms.jobs.RefreshDonationSubscriptionStatusJob
@@ -175,6 +181,7 @@ object SyncMessageProcessor {
syncMessage.outgoingPayment != null -> handleSynchronizeOutgoingPayment(syncMessage.outgoingPayment!!, envelope.clientTimestamp!!)
syncMessage.contacts != null -> handleSynchronizeContacts(syncMessage.contacts!!, envelope.clientTimestamp!!)
syncMessage.keys != null -> handleSynchronizeKeys(syncMessage.keys!!, envelope.clientTimestamp!!)
syncMessage.pniChangeNumber != null -> handleSynchronizePniChangeNumber(envelope, metadata, syncMessage.pniChangeNumber!!)
syncMessage.callEvent != null -> handleSynchronizeCallEvent(syncMessage.callEvent!!, envelope.clientTimestamp!!)
syncMessage.callLinkUpdate != null -> handleSynchronizeCallLink(syncMessage.callLinkUpdate!!, envelope.clientTimestamp!!)
syncMessage.callLogEvent != null -> handleSynchronizeCallLogEvent(syncMessage.callLogEvent!!, envelope.clientTimestamp!!)
@@ -1750,6 +1757,99 @@ object SyncMessageProcessor {
MultiDeviceAttachmentBackfillUpdateJob.enqueue(request.targetMessage!!, request.targetConversation!!, messageId)
}
private fun handleSynchronizePniChangeNumber(envelope: Envelope, metadata: EnvelopeMetadata, pniChangeNumber: SyncMessage.PniChangeNumber) {
val timestamp = envelope.clientTimestamp!!
if (SignalStore.account.isPrimaryDevice) {
warn(timestamp, "Received a PniChangeNumber sync message on the primary device. Bailing.")
return
}
if (metadata.sourceDeviceId != SignalServiceAddress.DEFAULT_DEVICE_ID) {
warn(timestamp, "Received a PniChangeNumber sync message from a non-primary device (${metadata.sourceDeviceId}). Bailing.")
return
}
if (SignalStore.account.aci == null) {
warn(timestamp, "Received a PniChangeNumber sync message but no local ACI is set. Bailing.")
return
}
val envelopeServerTimestamp = envelope.serverTimestamp ?: 0L
val lastAppliedServerTimestamp = SignalStore.misc.lastAppliedPniChangeServerTimestamp
if (envelopeServerTimestamp <= lastAppliedServerTimestamp) {
warn(timestamp, "PniChangeNumber sync serverTimestamp ($envelopeServerTimestamp) is not newer than the last applied ($lastAppliedServerTimestamp). Treating as a replay and bailing.")
return
}
// updatedPniBinary is a raw 16-byte UUID per the proto contract instead of a 17-byte service-id array.
val pni = if (envelope.updatedPniBinary != null) {
val updatedPniUuid = UuidUtil.parseOrNull(envelope.updatedPniBinary!!.toByteArray())
if (updatedPniUuid == null) {
warn(timestamp, "Could not parse updatedPniBinary as a UUID. Bailing.")
return
}
Pni(updatedPniUuid)
} else if (envelope.updatedPni != null) {
Pni.parseFromString(envelope.updatedPni)
} else {
warn(timestamp, "Neither updatedPni or updatedPniBinary were present on the envelope. Bailing.")
return
}
if (SignalStore.account.pni == PNI(pni)) {
log(timestamp, "PniChangeNumber sync already applied locally. Skipping.")
return
}
val identityKeyPairBytes = pniChangeNumber.identityKeyPair
val signedPreKeyBytes = pniChangeNumber.signedPreKey
val registrationId = pniChangeNumber.registrationId
val newE164 = pniChangeNumber.newE164
if (identityKeyPairBytes == null || signedPreKeyBytes == null || registrationId == null || registrationId <= 0 || newE164.isNullOrEmpty() || !SignalE164Util.isPotentialE164(newE164)) {
warn(timestamp, "PniChangeNumber sync message is missing or has an invalid required field. Bailing.")
return
}
val pniIdentityKeyPair: IdentityKeyPair
val pniSignedPreKey: SignedPreKeyRecord
val pniLastResortKyberPreKey: KyberPreKeyRecord?
try {
pniIdentityKeyPair = IdentityKeyPair(identityKeyPairBytes.toByteArray())
pniSignedPreKey = SignedPreKeyRecord(signedPreKeyBytes.toByteArray())
pniLastResortKyberPreKey = pniChangeNumber.lastResortKyberPreKey?.let { KyberPreKeyRecord(it.toByteArray()) }
} catch (e: Exception) {
warn(timestamp, "Failed to deserialize PniChangeNumber sync message. Bailing.", e)
return
}
log(timestamp, "Applying PniChangeNumber sync message.")
ChangeNumberRepository().applyLocalNumberChange(
e164 = newE164,
pni = PNI(pni),
pniIdentityKeyPair = pniIdentityKeyPair,
pniSignedPreKey = pniSignedPreKey,
pniLastResortKyberPreKey = pniLastResortKyberPreKey,
pniRegistrationId = registrationId
)
SignalStore.misc.lastAppliedPniChangeServerTimestamp = envelopeServerTimestamp
// The primary already submitted these per-device prekeys to the server as part of the
// change-number request, so they are registered server-side from this device's perspective.
val pniMetadataStore = SignalStore.account.pniPreKeys
pniMetadataStore.isSignedPreKeyRegistered = true
if (pniLastResortKyberPreKey != null) {
pniMetadataStore.lastResortKyberPreKeyId = pniLastResortKyberPreKey.id
}
// Rotate the primary-generated keys as soon as possible so we don't rely on them long-term.
SignalStore.misc.forcePniSignedPreKeyRotation = true
AppDependencies.jobManager.add(PreKeysSyncJob.create(forceRotationRequested = true))
}
private fun handleSynchronizedPollCreate(
envelope: Envelope,
message: DataMessage,
@@ -1,6 +1,7 @@
package org.thoughtcrime.securesms.messages.protocol
import org.signal.core.models.ServiceId
import org.signal.core.models.ServiceId.PNI
import org.thoughtcrime.securesms.dependencies.AppDependencies
import org.thoughtcrime.securesms.keyvalue.SignalStore
@@ -13,9 +14,12 @@ import org.thoughtcrime.securesms.keyvalue.SignalStore
*/
class BufferedProtocolStore private constructor(
private val aciStore: Pair<ServiceId, BufferedSignalServiceAccountDataStore>,
private val pniStore: Pair<ServiceId, BufferedSignalServiceAccountDataStore>
private val pniStore: Pair<PNI, BufferedSignalServiceAccountDataStore>
) {
/** The PNI captured when this batch's store was created. Does not refresh if [SignalStore.account.pni] later changes mid-batch. */
val pni: PNI get() = pniStore.first
fun get(serviceId: ServiceId): BufferedSignalServiceAccountDataStore {
return when (serviceId) {
aciStore.first -> aciStore.second
@@ -0,0 +1,179 @@
/*
* Copyright 2026 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.thoughtcrime.securesms.jobs
import android.app.Application
import io.mockk.every
import io.mockk.mockk
import io.mockk.mockkObject
import io.mockk.mockkStatic
import io.mockk.unmockkObject
import io.mockk.unmockkStatic
import io.mockk.verify
import org.junit.After
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.robolectric.RobolectricTestRunner
import org.robolectric.annotation.Config
import org.signal.core.models.ServiceId
import org.signal.libsignal.protocol.IdentityKeyPair
import org.signal.libsignal.protocol.state.KyberPreKeyRecord
import org.signal.libsignal.protocol.state.SignedPreKeyRecord
import org.signal.network.NetworkResult
import org.thoughtcrime.securesms.crypto.PreKeyUtil
import org.thoughtcrime.securesms.crypto.storage.PreKeyMetadataStore
import org.thoughtcrime.securesms.crypto.storage.SignalServiceAccountDataStoreImpl
import org.thoughtcrime.securesms.dependencies.AppDependencies
import org.thoughtcrime.securesms.keyvalue.MiscellaneousValues
import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.thoughtcrime.securesms.testutil.MockAppDependenciesRule
import org.thoughtcrime.securesms.testutil.MockSignalStoreRule
import org.thoughtcrime.securesms.util.RemoteConfig
import org.whispersystems.signalservice.api.keys.OneTimePreKeyCounts
import java.io.IOException
import java.util.UUID
import kotlin.time.Duration.Companion.hours
import kotlin.time.Duration.Companion.minutes
@RunWith(RobolectricTestRunner::class)
@Config(manifest = Config.NONE, application = Application::class)
class PreKeysSyncJobTest {
@get:Rule
val mockSignalStore = MockSignalStoreRule()
@get:Rule
val appDependencies = MockAppDependenciesRule()
private val misc: MiscellaneousValues = mockk(relaxUnitFun = true)
private val aciMetadataStore: PreKeyMetadataStore = mockk(relaxUnitFun = true)
private val pniMetadataStore: PreKeyMetadataStore = mockk(relaxUnitFun = true)
private val aciProtocolStore: SignalServiceAccountDataStoreImpl = mockk(relaxed = true)
private val pniProtocolStore: SignalServiceAccountDataStoreImpl = mockk(relaxed = true)
@Before
fun setUp() {
every { SignalStore.misc } returns misc
every { mockSignalStore.account.isRegistered } returns true
every { mockSignalStore.account.aci } returns ServiceId.ACI.from(UUID.randomUUID())
every { mockSignalStore.account.pni } returns ServiceId.PNI.from(UUID.randomUUID())
every { mockSignalStore.account.aciPreKeys } returns aciMetadataStore
every { mockSignalStore.account.pniPreKeys } returns pniMetadataStore
// Default metadata: everything is fresh and registered, so absent a force, no rotation triggers.
listOf(aciMetadataStore, pniMetadataStore).forEach {
every { it.isSignedPreKeyRegistered } returns true
every { it.activeSignedPreKeyId } returns 1
every { it.lastResortKyberPreKeyId } returns 1
every { it.lastSignedPreKeyRotationTime } returns System.currentTimeMillis()
every { it.lastResortKyberPreKeyRotationTime } returns System.currentTimeMillis()
}
every { misc.lastForcedPreKeyRefresh } returns 0L
every { misc.forcePniSignedPreKeyRotation } returns false
// `AppDependencies.protocolStore` / `keysApi` are already relaxed mockks set up by
// MockAppDependenciesRule; configure the chained calls we care about.
every { AppDependencies.protocolStore.aci() } returns aciProtocolStore
every { AppDependencies.protocolStore.pni() } returns pniProtocolStore
val identityKeyPair = IdentityKeyPair.generate()
every { aciProtocolStore.identityKeyPair } returns identityKeyPair
every { pniProtocolStore.identityKeyPair } returns identityKeyPair
// Counts well above ONE_TIME_PREKEY_MINIMUM (10) so we don't generate one-time keys unless forced.
every { AppDependencies.keysApi.getAvailablePreKeyCountsSync(any()) } returns NetworkResult.Success(OneTimePreKeyCounts(100, 100))
every { AppDependencies.keysApi.setPreKeysSync(any()) } returns NetworkResult.Success(Unit)
// Consistency check (only reached when forceRotationRequested=true) returns "everything matches".
every { AppDependencies.keysApi.checkRepeatedUseKeysSync(any(), any(), any(), any(), any(), any()) } returns NetworkResult.Success(Unit)
mockkObject(RemoteConfig)
every { RemoteConfig.preKeyForceRefreshInterval } returns 1.hours.inWholeMilliseconds
// Used by BaseJob's retry-backoff path when a syncPreKeys call throws a retryable IOException.
every { RemoteConfig.defaultMaxBackoff } returns 1.hours.inWholeMilliseconds
mockkStatic(PreKeyUtil::class)
every { PreKeyUtil.generateAndStoreSignedPreKey(any(), any()) } answers { fakeSignedPreKey() }
every { PreKeyUtil.generateAndStoreOneTimeEcPreKeys(any(), any()) } returns emptyList()
every { PreKeyUtil.generateAndStoreLastResortKyberPreKey(any(), any()) } answers { fakeKyberPreKey() }
every { PreKeyUtil.generateAndStoreOneTimeKyberPreKeys(any(), any()) } returns emptyList()
every { PreKeyUtil.cleanSignedPreKeys(any(), any()) } returns Unit
every { PreKeyUtil.cleanLastResortKyberPreKeys(any(), any()) } returns Unit
every { PreKeyUtil.cleanOneTimePreKeys(any()) } returns Unit
}
@After
fun tearDown() {
unmockkObject(RemoteConfig)
unmockkStatic(PreKeyUtil::class)
}
@Test
fun `when forcePniSignedPreKeyRotation flag set, PNI sync runs forced and ACI does not`() {
every { misc.forcePniSignedPreKeyRotation } returns true
PreKeysSyncJob.create(forceRotationRequested = false).run()
verify(exactly = 1) { PreKeyUtil.generateAndStoreSignedPreKey(pniProtocolStore, pniMetadataStore) }
verify(exactly = 1) { PreKeyUtil.generateAndStoreLastResortKyberPreKey(pniProtocolStore, pniMetadataStore) }
verify(exactly = 1) { PreKeyUtil.generateAndStoreOneTimeEcPreKeys(pniProtocolStore, pniMetadataStore) }
verify(exactly = 1) { PreKeyUtil.generateAndStoreOneTimeKyberPreKeys(pniProtocolStore, pniMetadataStore) }
verify(exactly = 0) { PreKeyUtil.generateAndStoreSignedPreKey(aciProtocolStore, aciMetadataStore) }
verify(exactly = 0) { PreKeyUtil.generateAndStoreLastResortKyberPreKey(aciProtocolStore, aciMetadataStore) }
verify(exactly = 0) { PreKeyUtil.generateAndStoreOneTimeEcPreKeys(aciProtocolStore, aciMetadataStore) }
verify(exactly = 0) { PreKeyUtil.generateAndStoreOneTimeKyberPreKeys(aciProtocolStore, aciMetadataStore) }
verify(exactly = 1) { misc.forcePniSignedPreKeyRotation = false }
}
@Test
fun `when forcePniSignedPreKeyRotation flag set but uploads fail, flag is preserved`() {
every { misc.forcePniSignedPreKeyRotation } returns true
// Fail PNI's upload so syncPreKeys throws before the flag-clear runs.
every { AppDependencies.keysApi.setPreKeysSync(any()) } returns NetworkResult.NetworkError(IOException("simulated"))
PreKeysSyncJob.create(forceRotationRequested = false).run()
verify(exactly = 0) { misc.forcePniSignedPreKeyRotation = false }
}
@Test
fun `when flag not set, no PNI force and flag write is skipped`() {
every { misc.forcePniSignedPreKeyRotation } returns false
PreKeysSyncJob.create(forceRotationRequested = false).run()
verify(exactly = 0) { PreKeyUtil.generateAndStoreSignedPreKey(pniProtocolStore, pniMetadataStore) }
verify(exactly = 0) { PreKeyUtil.generateAndStoreSignedPreKey(aciProtocolStore, aciMetadataStore) }
verify(exactly = 0) { misc.forcePniSignedPreKeyRotation = false }
}
@Test
fun `flag set forces PNI rotation even when consistency check passes and time gate would skip`() {
every { misc.forcePniSignedPreKeyRotation } returns true
// forceRotationRequested=true + consistency checks pass + recent forced refresh (well within
// preKeyForceRefreshInterval=1h) → without the flag, the existing logic would skip rotation.
every { misc.lastForcedPreKeyRefresh } returns System.currentTimeMillis() - 1.minutes.inWholeMilliseconds
PreKeysSyncJob.create(forceRotationRequested = true).run()
verify(exactly = 1) { PreKeyUtil.generateAndStoreSignedPreKey(pniProtocolStore, pniMetadataStore) }
verify(exactly = 1) { PreKeyUtil.generateAndStoreLastResortKyberPreKey(pniProtocolStore, pniMetadataStore) }
verify(exactly = 0) { PreKeyUtil.generateAndStoreSignedPreKey(aciProtocolStore, aciMetadataStore) }
verify(exactly = 0) { PreKeyUtil.generateAndStoreLastResortKyberPreKey(aciProtocolStore, aciMetadataStore) }
verify(exactly = 1) { misc.forcePniSignedPreKeyRotation = false }
}
private fun fakeSignedPreKey(): SignedPreKeyRecord = mockk(relaxed = true) {
every { id } returns 42
}
private fun fakeKyberPreKey(): KyberPreKeyRecord = mockk(relaxed = true) {
every { id } returns 42
}
}