diff --git a/app/src/main/java/org/thoughtcrime/securesms/components/settings/app/changenumber/ChangeNumberRepository.kt b/app/src/main/java/org/thoughtcrime/securesms/components/settings/app/changenumber/ChangeNumberRepository.kt index ac60fbc89b..64e805aeb7 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/components/settings/app/changenumber/ChangeNumberRepository.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/components/settings/app/changenumber/ChangeNumberRepository.kt @@ -147,7 +147,7 @@ class ChangeNumberRepository( Log.i(TAG, "Submitting prekeys with PNI identity key: ${pniIdentityKeyPair.publicKey.fingerprint}") retryChangeLocalNumberNetworkOperation { - SignalNetwork.keys.setPreKeys( + SignalNetwork.keys.setPreKeysSync( PreKeyUpload( serviceIdType = ServiceIdType.PNI, signedPreKey = signedPreKey, diff --git a/app/src/main/java/org/thoughtcrime/securesms/components/settings/conversation/InternalConversationSettingsFragment.kt b/app/src/main/java/org/thoughtcrime/securesms/components/settings/conversation/InternalConversationSettingsFragment.kt index a4f4666f30..3e83f80e89 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/components/settings/conversation/InternalConversationSettingsFragment.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/components/settings/conversation/InternalConversationSettingsFragment.kt @@ -28,6 +28,7 @@ import org.signal.libsignal.zkgroup.profiles.ProfileKey import org.thoughtcrime.securesms.MainActivity import org.thoughtcrime.securesms.attachments.Attachment import org.thoughtcrime.securesms.attachments.UriAttachment +import org.thoughtcrime.securesms.components.SignalProgressDialog import org.thoughtcrime.securesms.database.AttachmentTable import org.thoughtcrime.securesms.database.MessageType import org.thoughtcrime.securesms.database.SignalDatabase @@ -297,15 +298,38 @@ class InternalConversationSettingsFragment : ComposeFragment(), InternalConversa } override fun clearSenderKeyAndArchiveSessions(recipientId: RecipientId) { - clearSenderKey(recipientId) + lifecycleScope.launch { + val dialog = withContext(Dispatchers.Main) { + SignalProgressDialog.show(requireContext(), "Clearing...", cancelable = false, indeterminate = true) + } - val group = SignalDatabase.groups.getGroup(recipientId).orNull() - if (group == null) { - Log.w(TAG, "Couldn't find group for recipientId: $recipientId") - return + withContext(Dispatchers.Default) { + clearSenderKey(recipientId) + + val group = SignalDatabase.groups.getGroup(recipientId).orNull() + if (group == null) { + Log.w(TAG, "Couldn't find group for recipientId: $recipientId") + return@withContext + } + + group.members.forEach { memberId -> + archiveSessions(memberId) + + val member = Recipient.resolved(memberId) + if (member.hasAci) { + AppDependencies.protocolStore.aci().identities().delete(member.requireAci().toString()) + } + + if (member.hasPni) { + AppDependencies.protocolStore.aci().identities().delete(member.requirePni().toString()) + } + } + } + + withContext(Dispatchers.Main) { + dialog.dismiss() + } } - - group.members.forEach { archiveSessions(it) } } class InternalViewModel( diff --git a/app/src/main/java/org/thoughtcrime/securesms/components/settings/conversation/InternalConversationSettingsScreen.kt b/app/src/main/java/org/thoughtcrime/securesms/components/settings/conversation/InternalConversationSettingsScreen.kt index de96e727c6..70701d333f 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/components/settings/conversation/InternalConversationSettingsScreen.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/components/settings/conversation/InternalConversationSettingsScreen.kt @@ -212,7 +212,7 @@ fun InternalConversationSettingsScreen( item { Rows.TextRow( text = "Clear sender key and archive sessions", - label = "Resets any sender key state and archives all sessions for group members, will force creating new sessions and re-distributing sender key material.", + label = "Resets any sender key state, archives all sessions, and removes identity keys for group members, will force creating new sessions and re-distributing sender key material.", onClick = { dialog = Dialog.CLEAR_SENDER_KEY_AND_ARCHIVE_SESSIONS } diff --git a/app/src/main/java/org/thoughtcrime/securesms/database/IdentityTable.kt b/app/src/main/java/org/thoughtcrime/securesms/database/IdentityTable.kt index fd14864010..7217f0b8e4 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/database/IdentityTable.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/database/IdentityTable.kt @@ -49,6 +49,16 @@ class IdentityTable internal constructor(context: Context?, databaseHelper: Sign companion object { private val TAG = Log.tag(IdentityTable::class.java) + + /** + * When set, [saveIdentity] will skip its per-recipient `markNeedsSync` + `scheduleSyncForDataChange` + * side effects and instead deposit the affected [RecipientId] into the set. The caller is then + * responsible for performing a single bulk follow-up (storage-id rotation, cache invalidate, + * storage-sync schedule). + */ + @JvmField + val SUPPRESS_RECIPIENT_REFRESH: ThreadLocal> = ThreadLocal() + const val TABLE_NAME = "identities" private const val ID = "_id" const val ADDRESS = "address" @@ -125,8 +135,14 @@ class IdentityTable internal constructor(context: Context?, databaseHelper: Sign nonBlockingApproval: Boolean ) { saveIdentityInternal(addressName, recipientId, identityKey, verifiedStatus, firstUse, timestamp, nonBlockingApproval) - recipients.markNeedsSync(recipientId) - StorageSyncHelper.scheduleSyncForDataChange() + + val deferred = SUPPRESS_RECIPIENT_REFRESH.get() + if (deferred != null) { + deferred += recipientId + } else { + recipients.markNeedsSync(recipientId) + StorageSyncHelper.scheduleSyncForDataChange() + } } fun setApproval(addressName: String, recipientId: RecipientId, nonBlockingApproval: Boolean) { diff --git a/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java b/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java index 430084bad6..4347d85690 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java +++ b/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java @@ -10,13 +10,27 @@ import androidx.annotation.VisibleForTesting; import org.jetbrains.annotations.NotNull; import org.signal.billing.BillingFactory; +import org.signal.core.models.ServiceId.ACI; +import org.signal.core.models.ServiceId.PNI; import org.signal.core.util.ThreadUtil; import org.signal.core.util.billing.BillingApi; import org.signal.core.util.concurrent.DeadlockDetector; import org.signal.core.util.concurrent.SignalExecutors; import org.signal.libsignal.net.Network; +import org.signal.libsignal.protocol.SignalProtocolAddress; import org.signal.libsignal.zkgroup.profiles.ClientZkProfileOperations; import org.signal.libsignal.zkgroup.receipts.ClientZkReceiptOperations; +import org.signal.network.api.ArchiveApi; +import org.signal.network.api.CallingApi; +import org.signal.network.api.CdsApi; +import org.signal.network.api.CertificateApi; +import org.signal.network.api.LinkDeviceApi; +import org.signal.network.api.PaymentsApi; +import org.signal.network.api.ProvisioningApi; +import org.signal.network.api.RateLimitChallengeApi; +import org.signal.network.api.RemoteConfigApi; +import org.signal.network.api.SvrBApi; +import org.signal.network.api.UsernameApi; import org.thoughtcrime.securesms.BuildConfig; import org.thoughtcrime.securesms.components.TypingStatusRepository; import org.thoughtcrime.securesms.components.TypingStatusSender; @@ -79,6 +93,7 @@ import org.thoughtcrime.securesms.util.ByteUnit; import org.thoughtcrime.securesms.util.EarlyMessageCache; import org.thoughtcrime.securesms.util.Environment; import org.thoughtcrime.securesms.util.FrameRateTracker; +import org.thoughtcrime.securesms.util.PreKeyBatcher; import org.thoughtcrime.securesms.util.RemoteConfig; import org.thoughtcrime.securesms.util.TextSecurePreferences; import org.thoughtcrime.securesms.video.exo.GiphyMp4Cache; @@ -89,30 +104,18 @@ import org.whispersystems.signalservice.api.SignalServiceDataStore; import org.whispersystems.signalservice.api.SignalServiceMessageReceiver; import org.whispersystems.signalservice.api.SignalServiceMessageSender; import org.whispersystems.signalservice.api.account.AccountApi; -import org.signal.network.api.ArchiveApi; import org.whispersystems.signalservice.api.attachment.AttachmentApi; -import org.signal.network.api.CallingApi; -import org.signal.network.api.CdsApi; -import org.signal.network.api.CertificateApi; import org.whispersystems.signalservice.api.donations.DonationsApi; import org.whispersystems.signalservice.api.groupsv2.ClientZkOperations; import org.whispersystems.signalservice.api.groupsv2.GroupsV2Operations; import org.whispersystems.signalservice.api.keys.KeysApi; -import org.signal.network.api.LinkDeviceApi; +import org.whispersystems.signalservice.api.keys.PreKeyRepository; import org.whispersystems.signalservice.api.message.MessageApi; -import org.signal.network.api.PaymentsApi; import org.whispersystems.signalservice.api.profiles.ProfileApi; -import org.signal.network.api.ProvisioningApi; -import org.signal.core.models.ServiceId.ACI; -import org.signal.core.models.ServiceId.PNI; -import org.signal.network.api.RateLimitChallengeApi; import org.whispersystems.signalservice.api.registration.RegistrationApi; -import org.signal.network.api.RemoteConfigApi; import org.whispersystems.signalservice.api.services.DonationsService; import org.whispersystems.signalservice.api.services.ProfileService; import org.whispersystems.signalservice.api.storage.StorageServiceApi; -import org.signal.network.api.SvrBApi; -import org.signal.network.api.UsernameApi; import org.whispersystems.signalservice.api.util.CredentialsProvider; import org.whispersystems.signalservice.api.util.SleepTimer; import org.whispersystems.signalservice.api.util.UptimeSleepTimer; @@ -179,7 +182,15 @@ public class ApplicationDependencyProvider implements AppDependencies.Provider { RemoteConfig.maxIncrementalMacsPerEnvelope(), RemoteConfig::useMessageSendRestFallback, RemoteConfig.useBinaryId(), - BuildConfig.USE_STRING_ID); + BuildConfig.USE_STRING_ID, + new PreKeyRepository( + keysApi, + protocolStore.aci(), + new SignalProtocolAddress(pushServiceSocket.getCredentialsProvider().getAci().getLibSignalServiceId(), + pushServiceSocket.getCredentialsProvider().getDeviceId()), + PreKeyBatcher.INSTANCE + ) + ); } @Override diff --git a/app/src/main/java/org/thoughtcrime/securesms/jobs/PreKeysSyncJob.kt b/app/src/main/java/org/thoughtcrime/securesms/jobs/PreKeysSyncJob.kt index 763a8dca72..96b811216a 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/jobs/PreKeysSyncJob.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/jobs/PreKeysSyncJob.kt @@ -167,7 +167,7 @@ class PreKeysSyncJob private constructor( return } - val availablePreKeyCounts = SignalNetwork.keys.getAvailablePreKeyCounts(serviceIdType).successOrThrow() + val availablePreKeyCounts = SignalNetwork.keys.getAvailablePreKeyCountsSync(serviceIdType).successOrThrow() val signedPreKeyToUpload: SignedPreKeyRecord? = signedPreKeyUploadIfNeeded(serviceIdType, protocolStore, metadataStore, forceRotation) @@ -191,7 +191,7 @@ class PreKeysSyncJob private constructor( if (signedPreKeyToUpload != null || oneTimeEcPreKeysToUpload != null || lastResortKyberPreKeyToUpload != null || oneTimeKyberPreKeysToUpload != null) { log(serviceIdType, "Something to upload. SignedPreKey: ${signedPreKeyToUpload != null}, OneTimeEcPreKeys: ${oneTimeEcPreKeysToUpload != null}, LastResortKyberPreKey: ${lastResortKyberPreKeyToUpload != null}, OneTimeKyberPreKeys: ${oneTimeKyberPreKeysToUpload != null}") - SignalNetwork.keys.setPreKeys( + SignalNetwork.keys.setPreKeysSync( PreKeyUpload( serviceIdType = serviceIdType, signedPreKey = signedPreKeyToUpload, @@ -260,7 +260,7 @@ class PreKeysSyncJob private constructor( @Throws(IOException::class) private fun checkPreKeyConsistency(serviceIdType: ServiceIdType, protocolStore: SignalServiceAccountDataStore, metadataStore: PreKeyMetadataStore): Boolean { val result: NetworkResult = try { - SignalNetwork.keys.checkRepeatedUseKeys( + SignalNetwork.keys.checkRepeatedUseKeysSync( serviceIdType = serviceIdType, identityKey = protocolStore.identityKeyPair.publicKey, signedPreKeyId = metadataStore.activeSignedPreKeyId, diff --git a/app/src/main/java/org/thoughtcrime/securesms/migrations/PniAccountInitializationMigrationJob.java b/app/src/main/java/org/thoughtcrime/securesms/migrations/PniAccountInitializationMigrationJob.java index 7cba7626dc..04f255e2e6 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/migrations/PniAccountInitializationMigrationJob.java +++ b/app/src/main/java/org/thoughtcrime/securesms/migrations/PniAccountInitializationMigrationJob.java @@ -84,7 +84,7 @@ public class PniAccountInitializationMigrationJob extends MigrationJob { SignedPreKeyRecord signedPreKey = PreKeyUtil.generateAndStoreSignedPreKey(protocolStore, metadataStore); List oneTimePreKeys = PreKeyUtil.generateAndStoreOneTimeEcPreKeys(protocolStore, metadataStore); - NetworkResultUtil.toPreKeysLegacy(SignalNetwork.keys().setPreKeys(new PreKeyUpload(ServiceIdType.PNI, signedPreKey, oneTimePreKeys, null, null))); + NetworkResultUtil.toPreKeysLegacy(SignalNetwork.keys().setPreKeysSync(new PreKeyUpload(ServiceIdType.PNI, signedPreKey, oneTimePreKeys, null, null))); metadataStore.setActiveSignedPreKeyId(signedPreKey.getId()); metadataStore.setSignedPreKeyRegistered(true); } else { diff --git a/app/src/main/java/org/thoughtcrime/securesms/recipients/LiveRecipient.java b/app/src/main/java/org/thoughtcrime/securesms/recipients/LiveRecipient.java index 822ab5ced8..e06d0485c3 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/recipients/LiveRecipient.java +++ b/app/src/main/java/org/thoughtcrime/securesms/recipients/LiveRecipient.java @@ -15,6 +15,7 @@ import org.thoughtcrime.securesms.database.SignalDatabase; import org.thoughtcrime.securesms.database.model.RecipientRecord; import org.thoughtcrime.securesms.util.livedata.LiveDataUtil; +import java.util.List; import java.util.Objects; import java.util.Set; import java.util.concurrent.CopyOnWriteArraySet; diff --git a/app/src/main/java/org/thoughtcrime/securesms/recipients/LiveRecipientCache.java b/app/src/main/java/org/thoughtcrime/securesms/recipients/LiveRecipientCache.java index e328f35fcf..ea2f5c5e4b 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/recipients/LiveRecipientCache.java +++ b/app/src/main/java/org/thoughtcrime/securesms/recipients/LiveRecipientCache.java @@ -7,7 +7,9 @@ import androidx.annotation.AnyThread; import androidx.annotation.NonNull; import androidx.annotation.Nullable; import androidx.annotation.VisibleForTesting; +import androidx.annotation.WorkerThread; +import org.jetbrains.annotations.NotNull; import org.signal.core.util.ThreadUtil; import org.signal.core.util.concurrent.SignalExecutors; import org.signal.core.util.logging.Log; @@ -15,6 +17,7 @@ import org.thoughtcrime.securesms.database.RecipientTable; import org.thoughtcrime.securesms.database.RecipientTable.MissingRecipientException; import org.thoughtcrime.securesms.database.SignalDatabase; import org.thoughtcrime.securesms.database.ThreadTable; +import org.thoughtcrime.securesms.database.model.RecipientRecord; import org.thoughtcrime.securesms.database.model.ThreadRecord; import org.thoughtcrime.securesms.keyvalue.SignalStore; import org.signal.core.util.CursorUtil; @@ -27,9 +30,11 @@ import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; public final class LiveRecipientCache { @@ -102,6 +107,29 @@ public final class LiveRecipientCache { } } + /** + * Resolves and updates entries for each recipient already in the cache. + */ + @WorkerThread + public void refresh(@NonNull Collection recipientIds) { + Set cachedIds; + synchronized (recipients) { + cachedIds = recipientIds.stream().filter(recipients::containsKey).collect(Collectors.toSet()); + } + + if (!cachedIds.isEmpty()) { + Set recipients = SignalDatabase + .recipients() + .getExistingRecords(cachedIds) + .values() + .stream() + .map(record -> RecipientCreator.forRecord(context, record)) + .collect(Collectors.toSet()); + + addToCache(recipients); + } + } + /** * Adds a recipient to the cache if we don't have an entry. This will also update a cache entry * if the provided recipient is resolved, or if the existing cache entry is unresolved. diff --git a/app/src/main/java/org/thoughtcrime/securesms/util/PreKeyBatcher.kt b/app/src/main/java/org/thoughtcrime/securesms/util/PreKeyBatcher.kt new file mode 100644 index 0000000000..9f7380714a --- /dev/null +++ b/app/src/main/java/org/thoughtcrime/securesms/util/PreKeyBatcher.kt @@ -0,0 +1,40 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.thoughtcrime.securesms.util + +import org.thoughtcrime.securesms.database.IdentityTable +import org.thoughtcrime.securesms.database.SignalDatabase +import org.thoughtcrime.securesms.dependencies.AppDependencies +import org.thoughtcrime.securesms.recipients.RecipientId +import org.thoughtcrime.securesms.storage.StorageSyncHelper.scheduleSyncForDataChange +import org.whispersystems.signalservice.api.keys.PreKeyRepository + +/** + * Helper to batch recipient updates and storage sync when doing a large prekey fetch. + * + * See [PreKeyRepository.BatchHelper] for additional details. + */ +object PreKeyBatcher : PreKeyRepository.BatchHelper { + + override fun batch(block: Runnable) { + val affected: MutableSet = HashSet() + + try { + IdentityTable.SUPPRESS_RECIPIENT_REFRESH.set(affected) + block.run() + if (!affected.isEmpty()) { + SignalDatabase.recipients.markNeedsSyncWithoutRefresh(affected) + } + } finally { + IdentityTable.SUPPRESS_RECIPIENT_REFRESH.remove() + } + + if (!affected.isEmpty()) { + AppDependencies.recipientCache.refresh(affected) + scheduleSyncForDataChange() + } + } +} diff --git a/lib/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageSender.java b/lib/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageSender.java index 0e347cb101..6abea2c9cc 100644 --- a/lib/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageSender.java +++ b/lib/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageSender.java @@ -46,6 +46,7 @@ import org.whispersystems.signalservice.api.crypto.UnidentifiedAccess; import org.whispersystems.signalservice.api.crypto.UntrustedIdentityException; import org.whispersystems.signalservice.api.groupsv2.GroupSendEndorsements; import org.whispersystems.signalservice.api.keys.KeysApi; +import org.whispersystems.signalservice.api.keys.PreKeyRepository; import org.whispersystems.signalservice.api.message.MessageApi; import org.whispersystems.signalservice.api.message.MessageApiKt; import org.whispersystems.signalservice.api.messages.SendMessageResult; @@ -183,9 +184,10 @@ public class SignalServiceMessageSender { private final Optional eventListener; private final IdentityKeyPair localPniIdentity; - private final AttachmentApi attachmentApi; - private final MessageApi messageApi; - private final KeysApi keysApi; + private final AttachmentApi attachmentApi; + private final MessageApi messageApi; + private final KeysApi keysApi; + private final PreKeyRepository preKeyRepository; private final Scheduler scheduler; private final long maxEnvelopeSize; @@ -206,7 +208,8 @@ public class SignalServiceMessageSender { int maxIncrementalMacsPerEnvelope, BooleanSupplier useRestFallback, boolean useBinaryId, - boolean useStringId) + boolean useStringId, + PreKeyRepository preKeyRepository) { CredentialsProvider credentialsProvider = pushServiceSocket.getCredentialsProvider(); @@ -225,6 +228,7 @@ public class SignalServiceMessageSender { this.localPniIdentity = store.pni().getIdentityKeyPair(); this.scheduler = Schedulers.from(executor, false, false); this.keysApi = keysApi; + this.preKeyRepository = preKeyRepository; this.useRestFallback = useRestFallback; this.useBinaryId = useBinaryId; this.useStringId = useStringId; @@ -2143,7 +2147,16 @@ public class SignalServiceMessageSender { long startTime = System.currentTimeMillis(); - eagerlyFetchMissingPreKeys(recipients, sealedSenderAccesses, story); + List eagerRequests = new ArrayList<>(recipients.size()); + for (int i = 0; i < recipients.size(); i++) { + eagerRequests.add(new PreKeyRepository.EagerPreKeyRequest(recipients.get(i), sealedSenderAccesses.get(i), story)); + } + preKeyRepository.eagerlyFetchMissingPreKeys(eagerRequests, recipient -> { + if (eventListener.isPresent()) { + eventListener.get().onSecurityEvent(recipient); + } + return kotlin.Unit.INSTANCE; + }); List> singleResults = new LinkedList<>(); Iterator recipientIterator = recipients.iterator(); @@ -2912,81 +2925,6 @@ public class SignalServiceMessageSender { } } - private void eagerlyFetchMissingPreKeys(List recipients, List sealedSenderAccesses, boolean story) { - long start = System.currentTimeMillis(); - - Iterator recipientIterator = recipients.iterator(); - Iterator sealedSenderAccessIterator = sealedSenderAccesses.iterator(); - List> eagerFetches = new LinkedList<>(); - - while (recipientIterator.hasNext()) { - SignalServiceAddress recipient = recipientIterator.next(); - SealedSenderAccess sealedSenderAccess = sealedSenderAccessIterator.next(); - SignalProtocolAddress signalProtocolAddress = new SignalProtocolAddress(recipient.getIdentifier(), SignalServiceAddress.DEFAULT_DEVICE_ID); - - if (!aciStore.containsSession(signalProtocolAddress)) { - Observable thing = Single.fromCallable(() -> { - eagerlyFetchMissingPreKeys(recipient, sealedSenderAccess, story); - return true; - }) - .subscribeOn(scheduler) - .toObservable(); - - eagerFetches.add(thing); - } - } - - if (eagerFetches.isEmpty()) { - return; - } - - Log.i(TAG, "[eagerPrefetch] Attempting to fetch prekeys for " + eagerFetches.size() + " recipients"); - - try { - //noinspection ResultOfMethodCallIgnored - Observable.mergeDelayError(eagerFetches, Integer.MAX_VALUE, 1) - .observeOn(scheduler) - .lastOrError() - .blockingGet(); - } catch (RuntimeException e) { - Log.w(TAG, "[eagerPrefetch] Unexpectedly failed eager fetching prekeys", e); - return; - } - - Log.i(TAG, "[eagerPrefetch] Completed in " + (System.currentTimeMillis() - start) + "ms"); - } - - private void eagerlyFetchMissingPreKeys(SignalServiceAddress recipient, SealedSenderAccess sealedSenderAccess, boolean story) { - SignalProtocolAddress signalProtocolAddress = new SignalProtocolAddress(recipient.getIdentifier(), SignalServiceAddress.DEFAULT_DEVICE_ID); - - try { - List preKeys = getPreKeys(recipient, sealedSenderAccess, SignalServiceAddress.DEFAULT_DEVICE_ID, story); - - for (PreKeyBundle preKey : preKeys) { - Log.d(TAG, "[eagerFetch] Initializing prekey session for " + signalProtocolAddress); - - try { - SignalProtocolAddress preKeyAddress = new SignalProtocolAddress(recipient.getIdentifier(), preKey.getDeviceId()); - SignalSessionBuilder sessionBuilder = new SignalSessionBuilder(sessionLock, new SessionBuilder(aciStore, preKeyAddress, localProtocolAddress)); - sessionBuilder.process(preKey); - } catch (org.signal.libsignal.protocol.UntrustedIdentityException e) { - Log.i(TAG, "[eagerPrefetch] Untrusted identity for recipient"); - return; - - } - } - - if (eventListener.isPresent()) { - eventListener.get().onSecurityEvent(recipient); - } - } catch (IOException e) { - Log.i(TAG, "[eagerPrefetch] Network issue encountered"); - } catch (InvalidKeyException e) { - Log.i(TAG, "[eagerPrefetch] Invalid pre-key"); - return; - } - } - private List getPreKeys(SignalServiceAddress recipient, @Nullable SealedSenderAccess sealedSenderAccess, int deviceId, boolean story) throws IOException { try { // If it's only unrestricted because it's a story send, then we know it'll fail @@ -2994,11 +2932,11 @@ public class SignalServiceMessageSender { sealedSenderAccess = null; } - return NetworkResultUtil.toPreKeysLegacy(keysApi.getPreKeys(recipient, sealedSenderAccess, deviceId)); + return NetworkResultUtil.toPreKeysLegacy(keysApi.getPreKeysSync(recipient, sealedSenderAccess, deviceId)); } catch (NonSuccessfulResponseCodeException e) { if (e.code == 401 && story) { Log.d(TAG, "Got 401 when fetching prekey for story. Trying without UD."); - return NetworkResultUtil.toPreKeysLegacy(keysApi.getPreKeys(recipient, null, deviceId)); + return NetworkResultUtil.toPreKeysLegacy(keysApi.getPreKeysSync(recipient, null, deviceId)); } else { throw e; } @@ -3019,7 +2957,7 @@ public class SignalServiceMessageSender { clearSenderKeySharedWith(recipient, mismatchedDeviceIds); for (int missingDeviceId : mismatchedDevices.getMissingDevices()) { - PreKeyBundle preKey = NetworkResultUtil.toPreKeysLegacy(keysApi.getPreKey(recipient, missingDeviceId)); + PreKeyBundle preKey = NetworkResultUtil.toPreKeysLegacy(keysApi.getPreKeySync(recipient, missingDeviceId)); try { SignalSessionBuilder sessionBuilder = new SignalSessionBuilder(sessionLock, new SessionBuilder(aciStore, new SignalProtocolAddress(recipient.getIdentifier(), missingDeviceId), localProtocolAddress)); diff --git a/lib/libsignal-service/src/main/java/org/whispersystems/signalservice/api/keys/KeysApi.kt b/lib/libsignal-service/src/main/java/org/whispersystems/signalservice/api/keys/KeysApi.kt index 0fd7488e8b..662de493e4 100644 --- a/lib/libsignal-service/src/main/java/org/whispersystems/signalservice/api/keys/KeysApi.kt +++ b/lib/libsignal-service/src/main/java/org/whispersystems/signalservice/api/keys/KeysApi.kt @@ -55,7 +55,7 @@ class KeysApi( * - 200: Everything matches * - 409: Something doesn't match */ - fun checkRepeatedUseKeys( + fun checkRepeatedUseKeysSync( serviceIdType: ServiceIdType, identityKey: IdentityKey, signedPreKeyId: Int, @@ -83,7 +83,7 @@ class KeysApi( * GET /v2/keys?identity=[serviceIdType] * - 200: Success */ - fun getAvailablePreKeyCounts(serviceIdType: ServiceIdType): NetworkResult { + fun getAvailablePreKeyCountsSync(serviceIdType: ServiceIdType): NetworkResult { val request = WebSocketRequestMessage.get("/v2/keys?identity=${serviceIdType.queryParam()}") return NetworkResult.fromWebSocketRequest(authWebSocket, request, OneTimePreKeyCounts::class) } @@ -93,7 +93,7 @@ class KeysApi( * * PUT /v2/keys?identity=[preKeyUpload]`.serviceIdType` */ - fun setPreKeys(preKeyUpload: PreKeyUpload): NetworkResult { + fun setPreKeysSync(preKeyUpload: PreKeyUpload): NetworkResult { val signedPreKey: SignedPreKeyEntity? = if (preKeyUpload.signedPreKey != null) { SignedPreKeyEntity( preKeyUpload.signedPreKey.id.toLong(), @@ -147,7 +147,18 @@ class KeysApi( * - 404: No keys found for address/device * - 429: Rate limited */ - fun getPreKeys( + fun getPreKeysSync( + destination: SignalServiceAddress, + sealedSenderAccess: SealedSenderAccess?, + deviceId: Int + ): NetworkResult> { + return getPreKeysBySpecifierSync(destination, sealedSenderAccess, if (deviceId == 1) "*" else deviceId.toString()) + } + + /** + * Coroutine-friendly variant of [getPreKeysSync] that suspends instead of blocking the calling thread. + */ + suspend fun getPreKeys( destination: SignalServiceAddress, sealedSenderAccess: SealedSenderAccess?, deviceId: Int @@ -165,8 +176,8 @@ class KeysApi( * - 404: No keys found for address/device * - 429: Rate limited */ - fun getPreKey(destination: SignalServiceAddress, deviceId: Int): NetworkResult { - return getPreKeysBySpecifier(destination, null, deviceId.toString()) + fun getPreKeySync(destination: SignalServiceAddress, deviceId: Int): NetworkResult { + return getPreKeysBySpecifierSync(destination, null, deviceId.toString()) .then { bundles -> if (bundles.isNotEmpty()) { NetworkResult.Success(bundles[0]) @@ -188,9 +199,8 @@ class KeysApi( * - 404: No keys found for address/device * - 429: Rate limited */ - private fun getPreKeysBySpecifier(destination: SignalServiceAddress, sealedSenderAccess: SealedSenderAccess?, deviceSpecifier: String): NetworkResult> { - val request = WebSocketRequestMessage.get("/v2/keys/${destination.identifier}/$deviceSpecifier") - Log.d(TAG, "Fetching prekeys for ${destination.identifier}.$deviceSpecifier, i.e. GET ${request.path}") + private fun getPreKeysBySpecifierSync(destination: SignalServiceAddress, sealedSenderAccess: SealedSenderAccess?, deviceSpecifier: String): NetworkResult> { + val request = preKeysRequest(destination, deviceSpecifier) val result: NetworkResult = NetworkResult.fromWebSocket { if (sealedSenderAccess != null) { @@ -200,22 +210,54 @@ class KeysApi( } } - if (result is NetworkResult.StatusCodeError && result.code == 404) { - return NetworkResult.NetworkError(UnregisteredUserException(destination.identifier, result.exception)) + return result.toPreKeyBundles(destination) + } + + /** + * Coroutine-friendly counterpart to [getPreKeysBySpecifierSync] that suspends instead of blocking the calling thread. + */ + private suspend fun getPreKeysBySpecifier( + destination: SignalServiceAddress, + sealedSenderAccess: SealedSenderAccess?, + deviceSpecifier: String + ): NetworkResult> { + val request = preKeysRequest(destination, deviceSpecifier) + val converter = NetworkResult.DefaultWebSocketConverter(PreKeyResponse::class) + + val result: NetworkResult = NetworkResult.fromWebSocketSuspend(converter) { + if (sealedSenderAccess != null) { + unauthWebSocket.requestSuspend(request, sealedSenderAccess) + } else { + authWebSocket.requestSuspend(request) + } } - return result.map { response -> + return result.toPreKeyBundles(destination) + } + + private fun preKeysRequest(destination: SignalServiceAddress, deviceSpecifier: String): WebSocketRequestMessage { + val request = WebSocketRequestMessage.get("/v2/keys/${destination.identifier}/$deviceSpecifier") + Log.d(TAG, "Fetching prekeys for ${destination.identifier}.$deviceSpecifier, i.e. GET ${request.path}") + return request + } + + private fun NetworkResult.toPreKeyBundles(destination: SignalServiceAddress): NetworkResult> { + if (this is NetworkResult.StatusCodeError && this.code == 404) { + return NetworkResult.NetworkError(UnregisteredUserException(destination.identifier, this.exception)) + } + + return this.map { response -> val bundles: MutableList = LinkedList() for (device in response.getDevices()) { var preKey: ECPublicKey? = null - var signedPreKey: ECPublicKey? = null - var signedPreKeySignature: ByteArray? = null + var signedPreKey: ECPublicKey? + var signedPreKeySignature: ByteArray? var preKeyId = PreKeyBundle.NULL_PRE_KEY_ID - var signedPreKeyId = PreKeyBundle.NULL_PRE_KEY_ID - var kyberPreKeyId = PreKeyBundle.NULL_PRE_KEY_ID - var kyberPreKey: KEMPublicKey? = null - var kyberPreKeySignature: ByteArray? = null + var signedPreKeyId: Int + var kyberPreKeyId: Int + var kyberPreKey: KEMPublicKey? + var kyberPreKeySignature: ByteArray? if (device.getSignedPreKey() != null) { val rawSignedPreKeyId = device.getSignedPreKey().keyId diff --git a/lib/libsignal-service/src/main/java/org/whispersystems/signalservice/api/keys/PreKeyRepository.kt b/lib/libsignal-service/src/main/java/org/whispersystems/signalservice/api/keys/PreKeyRepository.kt new file mode 100644 index 0000000000..166db473d0 --- /dev/null +++ b/lib/libsignal-service/src/main/java/org/whispersystems/signalservice/api/keys/PreKeyRepository.kt @@ -0,0 +1,185 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.signalservice.api.keys + +import kotlinx.coroutines.Deferred +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll +import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.runBlocking +import org.signal.core.util.Stopwatch +import org.signal.core.util.logging.Log +import org.signal.libsignal.protocol.InvalidKeyException +import org.signal.libsignal.protocol.SessionBuilder +import org.signal.libsignal.protocol.SignalProtocolAddress +import org.signal.libsignal.protocol.UntrustedIdentityException +import org.signal.libsignal.protocol.state.PreKeyBundle +import org.whispersystems.signalservice.api.NetworkResult +import org.whispersystems.signalservice.api.SignalServiceAccountDataStore +import org.whispersystems.signalservice.api.SignalSessionLock +import org.whispersystems.signalservice.api.crypto.SealedSenderAccess +import org.whispersystems.signalservice.api.push.SignalServiceAddress + +/** + * Perform pre-key operations to establish sessions. + */ +class PreKeyRepository( + private val keysApi: KeysApi, + private val aciStore: SignalServiceAccountDataStore, + private val localProtocolAddress: SignalProtocolAddress, + private val batchHelper: BatchHelper +) { + + companion object { + private val TAG = Log.tag(PreKeyRepository::class) + + private const val MAX_PARALLEL_FETCHES = 32 + + @OptIn(ExperimentalCoroutinesApi::class) + private val fetchDispatcher = Dispatchers.IO.limitedParallelism(MAX_PARALLEL_FETCHES, "PreKeyRepository") + } + + /** + * Wraps prekey fetching that initializes sessions in advance of a send. + * + * Network fetches run in parallel, once they all complete, sessions are built sequentially + * under a single acquisition of the [SignalSessionLock] so individual fetch threads are not + * competing for the lock. + */ + fun eagerlyFetchMissingPreKeys(requests: List, onSecurityEvent: ((SignalServiceAddress) -> Unit)? = null) { + val stopwatch = Stopwatch("eagerPrefetch") + + val needsFetch = requests.filter { request -> + val defaultAddress = SignalProtocolAddress(request.recipient.identifier, SignalServiceAddress.DEFAULT_DEVICE_ID) + !aciStore.containsSession(defaultAddress) + } + stopwatch.split("filter") + + if (needsFetch.isEmpty()) { + return + } + + Log.i(TAG, "[eagerPrefetch] Attempting to fetch prekeys for ${needsFetch.size} recipients") + + val fetched: List = try { + runBlocking { fetchAll(needsFetch) } + } catch (e: InterruptedException) { + Log.w(TAG, "[eagerPrefetch] Interrupted while fetching prekeys", e) + return + } + stopwatch.split("fetch") + + val securityEventRecipients = applySessions(fetched) + stopwatch.split("apply") + + if (onSecurityEvent != null) { + for (recipient in securityEventRecipients) { + onSecurityEvent(recipient) + } + } + + stopwatch.stop(TAG) + } + + private suspend fun fetchAll(requests: List): List = coroutineScope { + val tasks: List> = requests.map { request -> + async(fetchDispatcher) { + when (val result = getPreKeys(request)) { + is NetworkResult.Success -> PreKeyFetchResult.Success(request.recipient, result.result) + else -> { + Log.d(TAG, "[eagerPrefetch] Network issue encountered for ${request.recipient.identifier}") + PreKeyFetchResult.Failure + } + } + } + } + + try { + tasks.awaitAll() + } catch (e: Exception) { + Log.w(TAG, "Hit an exception that caused us to end early.", e) + emptyList() + } + } + + private suspend fun getPreKeys(request: EagerPreKeyRequest): NetworkResult> { + val sealedSenderAccess = if (request.story && SealedSenderAccess.isUnrestrictedForStory(request.sealedSenderAccess)) null else request.sealedSenderAccess + + val response = keysApi.getPreKeys(request.recipient, sealedSenderAccess, SignalServiceAddress.DEFAULT_DEVICE_ID) + + if (response is NetworkResult.StatusCodeError && response.code == 401 && request.story) { + Log.d(TAG, "Got 401 when fetching prekey for story. Trying without UD.") + return keysApi.getPreKeys(request.recipient, null, SignalServiceAddress.DEFAULT_DEVICE_ID) + } + + return response + } + + private fun applySessions(results: List): List { + if (results.isEmpty()) { + return emptyList() + } + + val securityEventRecipients = mutableListOf() + + batchHelper.batch { + for (result in results) { + if (result !is PreKeyFetchResult.Success) continue + + val recipient = result.recipient + var aborted = false + + for (preKey in result.bundles) { + val preKeyAddress = SignalProtocolAddress(recipient.identifier, preKey.deviceId) + try { + SessionBuilder(aciStore, preKeyAddress, localProtocolAddress).process(preKey) + } catch (_: UntrustedIdentityException) { + Log.i(TAG, "[eagerPrefetch] Untrusted identity for recipient") + aborted = true + break + } catch (_: InvalidKeyException) { + Log.i(TAG, "[eagerPrefetch] Invalid pre-key") + aborted = true + break + } + } + + if (!aborted) { + securityEventRecipients += recipient + } + } + } + + return securityEventRecipients + } + + data class EagerPreKeyRequest( + val recipient: SignalServiceAddress, + val sealedSenderAccess: SealedSenderAccess?, + val story: Boolean + ) + + private sealed interface PreKeyFetchResult { + data class Success(val recipient: SignalServiceAddress, val bundles: List) : PreKeyFetchResult + data object Failure : PreKeyFetchResult + } + + fun interface BatchHelper { + + /** + * Establishes the thread local used to optimize batch updating many Recipients when their + * identity keys change. + * + * When saving an identity from libsignal session creation, the save will happen, but defer + * rotating storage id, schedule storage sync job, and updating the live recipients. + * + * After the [block] is finished it will then perform the deferred operations. + */ + fun batch(block: Runnable) + } +}