diff --git a/app/build.gradle.kts b/app/build.gradle.kts index 2e22ad827a..afb6f206e4 100644 --- a/app/build.gradle.kts +++ b/app/build.gradle.kts @@ -727,6 +727,7 @@ dependencies { } implementation(libs.dnsjava) implementation(libs.kotlinx.collections.immutable) + implementation(libs.arrow.core) implementation(libs.accompanist.permissions) implementation(libs.accompanist.drawablepainter) implementation(libs.kotlin.stdlib.jdk8) diff --git a/app/src/main/java/org/thoughtcrime/securesms/components/settings/conversation/InternalConversationSettingsFragment.kt b/app/src/main/java/org/thoughtcrime/securesms/components/settings/conversation/InternalConversationSettingsFragment.kt index 3e83f80e89..fcb402ef46 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/components/settings/conversation/InternalConversationSettingsFragment.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/components/settings/conversation/InternalConversationSettingsFragment.kt @@ -188,7 +188,7 @@ class InternalConversationSettingsFragment : ComposeFragment(), InternalConversa message = OutgoingMessage(threadRecipient = recipient, sentTimeMillis = time, body = "Outgoing: $i"), threadId = targetThread ).messageId - SignalDatabase.messages.markAsSent(id, true) + SignalDatabase.messages.markAsSent(id) } else { SignalDatabase.messages.insertMessageInbox( retrieved = IncomingMessage(type = MessageType.NORMAL, from = recipient.id, sentTimeMillis = time, serverTimeMillis = time, receivedTimeMillis = System.currentTimeMillis(), body = "Incoming: $i"), @@ -218,7 +218,7 @@ class InternalConversationSettingsFragment : ComposeFragment(), InternalConversa message = OutgoingMessage(threadRecipient = recipient, sentTimeMillis = time, body = "Outgoing: $i", attachments = listOf(attachment)), threadId = targetThread ).messageId - SignalDatabase.messages.markAsSent(id, true) + SignalDatabase.messages.markAsSent(id) SignalDatabase.attachments.getAttachmentsForMessage(id).forEach { SignalDatabase.attachments.debugMakeValidForArchive(it.attachmentId) SignalDatabase.attachments.createRemoteKeyIfNecessary(it.attachmentId) @@ -252,7 +252,7 @@ class InternalConversationSettingsFragment : ComposeFragment(), InternalConversa false, null ).messageId - SignalDatabase.messages.markAsSent(messageId, true) + SignalDatabase.messages.markAsSent(messageId) SignalDatabase.threads.update(splitThreadId, true) diff --git a/app/src/main/java/org/thoughtcrime/securesms/conversation/v2/ConversationRepository.kt b/app/src/main/java/org/thoughtcrime/securesms/conversation/v2/ConversationRepository.kt index 70ab7fadf5..79e78857ca 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/conversation/v2/ConversationRepository.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/conversation/v2/ConversationRepository.kt @@ -273,7 +273,7 @@ class ConversationRepository( Log.i(TAG, "Some recipients skipped when sending end poll. Resending to $filterRecipientIds") MessageSender.resendGroupMessage(applicationContext, messageRecord, filterRecipientIds) } else { - SignalDatabase.messages.markAsSent(messageId, true) + SignalDatabase.messages.markAsSent(messageId) } emitter.onComplete() } else { @@ -381,7 +381,7 @@ class ConversationRepository( Log.i(TAG, "Some recipients skipped when sending pin message. Resending to $filterRecipientIds") MessageSender.resendGroupMessage(applicationContext, messageRecord, filterRecipientIds) } else { - SignalDatabase.messages.markAsSent(insertResult.messageId, true) + SignalDatabase.messages.markAsSent(insertResult.messageId) } emitter.onComplete() } else { diff --git a/app/src/main/java/org/thoughtcrime/securesms/database/MessageTable.kt b/app/src/main/java/org/thoughtcrime/securesms/database/MessageTable.kt index c85cf09416..4cc13ac782 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/database/MessageTable.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/database/MessageTable.kt @@ -2312,9 +2312,27 @@ open class MessageTable(context: Context?, databaseHelper: SignalDatabase) : Dat AppDependencies.databaseObserver.notifyConversationListListeners() } - fun markAsSent(messageId: Long, secure: Boolean) { + fun markAsSent(messageId: Long) { val threadId = getThreadIdForMessage(messageId) - updateMailboxBitmask(messageId, MessageTypes.BASE_TYPE_MASK, MessageTypes.BASE_SENT_TYPE or if (secure) MessageTypes.PUSH_MESSAGE_BIT or MessageTypes.SECURE_MESSAGE_BIT else 0, Optional.of(threadId)) + updateMailboxBitmask(messageId, MessageTypes.BASE_TYPE_MASK, MessageTypes.BASE_SENT_TYPE or MessageTypes.PUSH_MESSAGE_BIT or MessageTypes.SECURE_MESSAGE_BIT, Optional.of(threadId)) + AppDependencies.databaseObserver.notifyMessageUpdateObservers(MessageId(messageId)) + AppDependencies.databaseObserver.notifyConversationListListeners() + } + + fun markAsSent(messageId: Long, sealedSender: Boolean) { + val maskOff = MessageTypes.BASE_TYPE_MASK + val maskOn = MessageTypes.BASE_SENT_TYPE or MessageTypes.PUSH_MESSAGE_BIT or MessageTypes.SECURE_MESSAGE_BIT + + writableDatabase.execSQL( + """ + UPDATE $TABLE_NAME + SET + $TYPE = ($TYPE & ${MessageTypes.TOTAL_MASK - maskOff} | $maskOn ), + $UNIDENTIFIED = ${sealedSender.toInt()} + WHERE $ID = $messageId + """ + ) + AppDependencies.databaseObserver.notifyMessageUpdateObservers(MessageId(messageId)) AppDependencies.databaseObserver.notifyConversationListListeners() } @@ -2693,6 +2711,18 @@ open class MessageTable(context: Context?, databaseHelper: SignalDatabase) : Dat } } + fun getOutgoingMessageOrNull(messageId: Long): OutgoingMessage? { + return try { + getOutgoingMessage(messageId) + } catch (e: MmsException) { + Log.w(TAG, "Hit MmsException, returning null", e) + null + } catch (e: NoSuchMessageException) { + Log.w(TAG, "Hit NoSuchMessageException, returning null", e) + null + } + } + @Throws(MmsException::class, NoSuchMessageException::class) fun getOutgoingMessage(messageId: Long): OutgoingMessage { return queryMessages(RAW_ID_WHERE, arrayOf(messageId.toString())).readToSingleObject { cursor -> diff --git a/app/src/main/java/org/thoughtcrime/securesms/dependencies/AppDependencies.kt b/app/src/main/java/org/thoughtcrime/securesms/dependencies/AppDependencies.kt index 658de2fc07..32346ec704 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/dependencies/AppDependencies.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/dependencies/AppDependencies.kt @@ -21,7 +21,9 @@ import org.signal.network.api.AttachmentApi import org.signal.network.api.CallingApi import org.signal.network.api.CdsApi import org.signal.network.api.CertificateApi +import org.signal.network.api.KeysApiV2 import org.signal.network.api.LinkDeviceApi +import org.signal.network.api.MessageApiV2 import org.signal.network.api.PaymentsApi import org.signal.network.api.ProvisioningApi import org.signal.network.api.RateLimitChallengeApi @@ -29,6 +31,7 @@ import org.signal.network.api.RemoteConfigApi import org.signal.network.api.SvrBApi import org.signal.network.api.UsernameApi import org.signal.network.rest.SignalRestClient +import org.signal.network.service.MessageService import org.thoughtcrime.securesms.BuildConfig import org.thoughtcrime.securesms.components.TypingStatusRepository import org.thoughtcrime.securesms.components.TypingStatusSender @@ -281,6 +284,10 @@ object AppDependencies { val signalServiceMessageSender: SignalServiceMessageSender get() = networkModule.signalServiceMessageSender + @JvmStatic + val messageService: MessageService + get() = networkModule.messageService + @JvmStatic val signalServiceAccountManager: SignalServiceAccountManager get() = networkModule.signalServiceAccountManager @@ -442,6 +449,7 @@ object AppDependencies { fun provideGroupsV2Operations(signalServiceConfiguration: SignalServiceConfiguration): GroupsV2Operations fun provideSignalServiceAccountManager(authWebSocket: SignalWebSocket.AuthenticatedWebSocket, accountApi: AccountApi, pushServiceSocket: PushServiceSocket, groupsV2Operations: GroupsV2Operations): SignalServiceAccountManager fun provideSignalServiceMessageSender(protocolStore: SignalServiceDataStore, pushServiceSocket: PushServiceSocket, messageApi: MessageApi, keysApi: KeysApi): SignalServiceMessageSender + fun provideMessageService(protocolStore: SignalServiceDataStore, messageApiV2: MessageApiV2, keysApiV2: KeysApiV2): MessageService fun provideSignalServiceMessageReceiver(pushServiceSocket: PushServiceSocket): SignalServiceMessageReceiver fun provideSignalServiceNetworkAccess(): SignalServiceNetworkAccess fun provideRecipientCache(): LiveRecipientCache diff --git a/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java b/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java index 5cc9de31ff..7fe0880d3f 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java +++ b/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java @@ -23,6 +23,8 @@ import org.signal.libsignal.zkgroup.InvalidInputException; import org.signal.libsignal.zkgroup.profiles.ClientZkProfileOperations; import org.signal.libsignal.zkgroup.receipts.ClientZkReceiptOperations; import org.signal.network.api.ArchiveApi; +import org.signal.network.api.KeysApiV2; +import org.signal.network.api.MessageApiV2; import org.signal.network.rest.SignalRestClient; import org.signal.network.api.CallingApi; import org.signal.network.api.CdsApi; @@ -34,6 +36,7 @@ import org.signal.network.api.RateLimitChallengeApi; import org.signal.network.api.RemoteConfigApi; import org.signal.network.api.SvrBApi; import org.signal.network.api.UsernameApi; +import org.signal.network.service.MessageService; import org.thoughtcrime.securesms.BuildConfig; import org.thoughtcrime.securesms.components.TypingStatusRepository; import org.thoughtcrime.securesms.components.TypingStatusSender; @@ -102,12 +105,14 @@ import org.thoughtcrime.securesms.util.TextSecurePreferences; import org.thoughtcrime.securesms.video.exo.GiphyMp4Cache; import org.thoughtcrime.securesms.video.exo.SimpleExoPlayerPool; import org.thoughtcrime.securesms.webrtc.audio.AudioManagerCompat; +import org.whispersystems.signalservice.api.SignalServiceAccountDataStore; import org.whispersystems.signalservice.api.SignalServiceAccountManager; import org.whispersystems.signalservice.api.SignalServiceDataStore; import org.whispersystems.signalservice.api.SignalServiceMessageReceiver; import org.whispersystems.signalservice.api.SignalServiceMessageSender; import org.whispersystems.signalservice.api.account.AccountApi; import org.signal.network.api.AttachmentApi; +import org.whispersystems.signalservice.api.crypto.SignalServiceCipher; import org.whispersystems.signalservice.api.donations.DonationsApi; import org.whispersystems.signalservice.api.groupsv2.ClientZkOperations; import org.whispersystems.signalservice.api.groupsv2.GroupsV2Operations; @@ -115,6 +120,7 @@ import org.whispersystems.signalservice.api.keys.KeysApi; import org.whispersystems.signalservice.api.keys.PreKeyRepository; import org.whispersystems.signalservice.api.message.MessageApi; import org.whispersystems.signalservice.api.profiles.ProfileApi; +import org.whispersystems.signalservice.api.push.SignalServiceAddress; import org.whispersystems.signalservice.api.registration.RegistrationApi; import org.whispersystems.signalservice.api.services.DonationsService; import org.whispersystems.signalservice.api.services.ProfileService; @@ -200,6 +206,18 @@ public class ApplicationDependencyProvider implements AppDependencies.Provider { ); } + @Override + public @NonNull MessageService provideMessageService(@NonNull SignalServiceDataStore protocolStore, + @NonNull MessageApiV2 messageApiV2, + @NonNull KeysApiV2 keysApiV2) { + SignalServiceAddress localAddress = new SignalServiceAddress(SignalStore.account().requireAci(), SignalStore.account().getE164()); + int localDeviceId = SignalStore.account().getDeviceId(); + SignalServiceAccountDataStore aciStore = protocolStore.aci(); + SignalServiceCipher cipher = new SignalServiceCipher(localAddress, localDeviceId, aciStore, ReentrantSessionLock.INSTANCE, null); + + return new MessageService(localAddress, localDeviceId, messageApiV2, keysApiV2, aciStore, ReentrantSessionLock.INSTANCE, cipher, RemoteConfig.maxEnvelopeSizeBytes()); + } + @Override public @NonNull SignalServiceMessageReceiver provideSignalServiceMessageReceiver(@NonNull PushServiceSocket pushServiceSocket) { return new SignalServiceMessageReceiver(pushServiceSocket); diff --git a/app/src/main/java/org/thoughtcrime/securesms/dependencies/NetworkDependenciesModule.kt b/app/src/main/java/org/thoughtcrime/securesms/dependencies/NetworkDependenciesModule.kt index 9d029e580d..cc9fb8b430 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/dependencies/NetworkDependenciesModule.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/dependencies/NetworkDependenciesModule.kt @@ -21,7 +21,9 @@ import org.signal.network.api.AttachmentApi import org.signal.network.api.CallingApi import org.signal.network.api.CdsApi import org.signal.network.api.CertificateApi +import org.signal.network.api.KeysApiV2 import org.signal.network.api.LinkDeviceApi +import org.signal.network.api.MessageApiV2 import org.signal.network.api.PaymentsApi import org.signal.network.api.ProvisioningApi import org.signal.network.api.RateLimitChallengeApi @@ -29,6 +31,7 @@ import org.signal.network.api.RemoteConfigApi import org.signal.network.api.SvrBApi import org.signal.network.api.UsernameApi import org.signal.network.rest.SignalRestClient +import org.signal.network.service.MessageService import org.thoughtcrime.securesms.crypto.storage.SignalServiceDataStoreImpl import org.thoughtcrime.securesms.groups.GroupsV2Authorization import org.thoughtcrime.securesms.groups.GroupsV2AuthorizationMemoryValueCache @@ -95,6 +98,12 @@ class NetworkDependenciesModule( } val signalServiceMessageSender: SignalServiceMessageSender by _signalServiceMessageSender + val messageApiV2: MessageApiV2 by lazy { MessageApiV2(authWebSocket, unauthWebSocket) } + + val keysApiV2: KeysApiV2 by lazy { KeysApiV2(authWebSocket, unauthWebSocket) } + + val messageService: MessageService by lazy { provider.provideMessageService(protocolStore, messageApiV2, keysApiV2) } + val incomingMessageObserver: IncomingMessageObserver by lazy { provider.provideIncomingMessageObserver(authWebSocket, unauthWebSocket) } diff --git a/app/src/main/java/org/thoughtcrime/securesms/groups/GroupManagerV2.java b/app/src/main/java/org/thoughtcrime/securesms/groups/GroupManagerV2.java index 054cef5ff1..2de85ab3f4 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/groups/GroupManagerV2.java +++ b/app/src/main/java/org/thoughtcrime/securesms/groups/GroupManagerV2.java @@ -1363,7 +1363,7 @@ final class GroupManagerV2 { long threadId = SignalDatabase.threads().getOrCreateValidThreadId(outgoingMessage.getThreadRecipient(), -1, outgoingMessage.getDistributionType()); try { long messageId = SignalDatabase.messages().insertMessageOutbox(outgoingMessage, threadId, false, null).getMessageId(); - SignalDatabase.messages().markAsSent(messageId, true); + SignalDatabase.messages().markAsSent(messageId); SignalDatabase.threads().update(threadId, true, true); } catch (MmsException e) { throw new AssertionError(e); diff --git a/app/src/main/java/org/thoughtcrime/securesms/groups/v2/processing/GroupsV2StateProcessor.kt b/app/src/main/java/org/thoughtcrime/securesms/groups/v2/processing/GroupsV2StateProcessor.kt index 021a44762c..9046a86cac 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/groups/v2/processing/GroupsV2StateProcessor.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/groups/v2/processing/GroupsV2StateProcessor.kt @@ -836,7 +836,7 @@ class GroupsV2StateProcessor private constructor( try { val threadId = SignalDatabase.threads.getOrCreateThreadIdFor(groupRecipient) val id = SignalDatabase.messages.insertMessageOutbox(leaveMessage, threadId, false, null).messageId - SignalDatabase.messages.markAsSent(id, true) + SignalDatabase.messages.markAsSent(id) SignalDatabase.drafts.clearDrafts(threadId) SignalDatabase.threads.update(threadId, unarchive = false, allowDeletion = false) } catch (e: MmsException) { @@ -872,7 +872,7 @@ class GroupsV2StateProcessor private constructor( try { val threadId = SignalDatabase.threads.getOrCreateThreadIdFor(groupRecipient) val id = SignalDatabase.messages.insertMessageOutbox(terminateMessage, threadId, false, null).messageId - SignalDatabase.messages.markAsSent(id, true) + SignalDatabase.messages.markAsSent(id) SignalDatabase.threads.update(threadId, unarchive = false, allowDeletion = false) } catch (e: MmsException) { Log.w(TAG, "Failed to insert terminated group message for $groupId", e) @@ -913,7 +913,7 @@ class GroupsV2StateProcessor private constructor( try { val threadId = SignalDatabase.threads.getOrCreateThreadIdFor(groupRecipient) val id = SignalDatabase.messages.insertMessageOutbox(rejectedMessage, threadId, false, null).messageId - SignalDatabase.messages.markAsSent(id, true) + SignalDatabase.messages.markAsSent(id) SignalDatabase.threads.update(threadId, unarchive = false, allowDeletion = false) } catch (e: MmsException) { Log.w(TAG, "Failed to insert rejected join request message for $groupId", e) @@ -985,7 +985,7 @@ class GroupsV2StateProcessor private constructor( val threadId = SignalDatabase.threads.getOrCreateThreadIdFor(recipient) val messageId = SignalDatabase.messages.insertMessageOutbox(outgoingMessage, threadId, false, null).messageId - SignalDatabase.messages.markAsSent(messageId, true) + SignalDatabase.messages.markAsSent(messageId) SignalDatabase.threads.update(threadId, unarchive = false, allowDeletion = false) } catch (e: MmsException) { Log.w(TAG, "Failed to insert outgoing update message!", e) diff --git a/app/src/main/java/org/thoughtcrime/securesms/jobs/IndividualSendJob.kt b/app/src/main/java/org/thoughtcrime/securesms/jobs/IndividualSendJob.kt index def36ed531..23f54f7c56 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/jobs/IndividualSendJob.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/jobs/IndividualSendJob.kt @@ -28,6 +28,7 @@ import org.thoughtcrime.securesms.recipients.RecipientUtil import org.thoughtcrime.securesms.transport.RetryLaterException import org.thoughtcrime.securesms.transport.UndeliverableMessageException import org.thoughtcrime.securesms.util.MessageUtil +import org.thoughtcrime.securesms.util.RemoteConfig import org.thoughtcrime.securesms.util.SignalLocalMetrics import org.whispersystems.signalservice.api.SignalServiceMessageSender.IndividualSendEvents import org.whispersystems.signalservice.api.crypto.ContentHint @@ -67,12 +68,21 @@ class IndividualSendJob private constructor(parameters: Parameters, private val throw AssertionError("This job does not send group messages!") } - return IndividualSendJob(messageId, recipient, hasMedia, isScheduledSend) + return if (RemoteConfig.useIndividualSendJobV2) { + IndividualSendJobV2.create(messageId, recipient, hasMedia, isScheduledSend) + } else { + IndividualSendJob(messageId, recipient, hasMedia, isScheduledSend) + } } @JvmStatic @WorkerThread fun enqueue(context: Context, jobManager: JobManager, messageId: Long, recipient: Recipient, isScheduledSend: Boolean) { + if (RemoteConfig.useIndividualSendJobV2) { + IndividualSendJobV2.enqueue(context, messageId, recipient, isScheduledSend) + return + } + try { val message = SignalDatabase.messages.getOutgoingMessage(messageId) if (message.scheduledDate != -1L) { @@ -155,7 +165,7 @@ class IndividualSendJob private constructor(parameters: Parameters, private val val unidentified = deliver(message, originalEditedMessage) - SignalDatabase.messages.markAsSent(messageId, true) + SignalDatabase.messages.markAsSent(messageId) markAttachmentsUploaded(messageId, message) SignalDatabase.messages.markUnidentified(messageId, unidentified) diff --git a/app/src/main/java/org/thoughtcrime/securesms/jobs/IndividualSendJobV2.kt b/app/src/main/java/org/thoughtcrime/securesms/jobs/IndividualSendJobV2.kt new file mode 100644 index 0000000000..b2f5ce5cf6 --- /dev/null +++ b/app/src/main/java/org/thoughtcrime/securesms/jobs/IndividualSendJobV2.kt @@ -0,0 +1,508 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.thoughtcrime.securesms.jobs + +import android.content.Context +import androidx.annotation.WorkerThread +import arrow.core.Either +import arrow.core.getOrElse +import arrow.core.raise.Raise +import arrow.core.raise.either +import okio.utf8Size +import org.signal.core.util.logging.Log +import org.signal.core.util.orNull +import org.signal.network.service.MessageService +import org.thoughtcrime.securesms.BuildConfig +import org.thoughtcrime.securesms.attachments.Attachment +import org.thoughtcrime.securesms.attachments.DatabaseAttachment +import org.thoughtcrime.securesms.crypto.SealedSenderAccessUtil +import org.thoughtcrime.securesms.database.MessageTypes +import org.thoughtcrime.securesms.database.RecipientTable.SealedSenderAccessMode +import org.thoughtcrime.securesms.database.SignalDatabase +import org.thoughtcrime.securesms.database.model.MessageId +import org.thoughtcrime.securesms.dependencies.AppDependencies +import org.thoughtcrime.securesms.jobmanager.CoroutineJob +import org.thoughtcrime.securesms.jobmanager.Job +import org.thoughtcrime.securesms.jobmanager.JobTracker +import org.thoughtcrime.securesms.jobmanager.impl.BackoffUtil +import org.thoughtcrime.securesms.jobmanager.impl.NetworkConstraint +import org.thoughtcrime.securesms.jobmanager.impl.SealedSenderConstraint +import org.thoughtcrime.securesms.jobs.protos.IndividualSendJobV2Data +import org.thoughtcrime.securesms.keyvalue.SignalStore +import org.thoughtcrime.securesms.ratelimit.ProofRequiredExceptionHandler +import org.thoughtcrime.securesms.recipients.Recipient +import org.thoughtcrime.securesms.recipients.RecipientUtil +import org.thoughtcrime.securesms.util.MessageUtil +import org.thoughtcrime.securesms.util.RemoteConfig +import org.thoughtcrime.securesms.util.SignalLocalMetrics +import org.thoughtcrime.securesms.util.isUrgent +import org.thoughtcrime.securesms.util.toDataMessage +import org.whispersystems.signalservice.api.crypto.ContentHint +import org.whispersystems.signalservice.api.crypto.EnvelopeContent +import org.whispersystems.signalservice.api.messages.SendMessageResult +import org.whispersystems.signalservice.api.push.SignalServiceAddress +import org.whispersystems.signalservice.api.push.exceptions.ProofRequiredException +import org.whispersystems.signalservice.internal.push.Content +import org.whispersystems.signalservice.internal.push.DataMessage +import org.whispersystems.signalservice.internal.push.EditMessage +import org.whispersystems.signalservice.internal.push.ProofRequiredResponse +import org.whispersystems.signalservice.internal.push.SyncMessage +import java.util.Optional +import java.util.concurrent.TimeUnit +import kotlin.jvm.optionals.getOrNull + +/** + * Alternate implementation of [IndividualSendJob] that: + * - Extends [Job] directly rather than going through [BaseJob]/[PushSendJob]. + * - Routes the actual send through the new [MessageService] (which encapsulates device resolution, + * prekey fetching, session building, encryption, and sync-transcript delivery). + * + * Used when [RemoteConfig.useIndividualSendJobV2] is true. + * + * Behavior should match [IndividualSendJob] exactly for observable state changes (marking sent, + * UD-mode updates, expiration starts, view-once cleanup, etc.). The primary divergence is the + * network layer. + */ +class IndividualSendJobV2 private constructor(parameters: Parameters, private val messageId: Long) : CoroutineJob(parameters) { + + companion object { + const val KEY: String = "IndividualSendJobV2" + + private val TAG = Log.tag(IndividualSendJobV2::class.java) + + @JvmStatic + fun create(messageId: Long, recipient: Recipient, hasMedia: Boolean, isScheduledSend: Boolean): Job { + check(recipient.hasServiceId) { "No ServiceId!" } + check(!recipient.isGroup) { "This job does not send group messages!" } + return IndividualSendJobV2(messageId, recipient, hasMedia, isScheduledSend) + } + + @JvmStatic + @WorkerThread + fun enqueue(context: Context, messageId: Long, recipient: Recipient, isScheduledSend: Boolean) { + val message = SignalDatabase.messages.getOutgoingMessageOrNull(messageId) + if (message == null) { + Log.w(TAG, "${logPrefix(null, messageId)} Failed to enqueue message.") + SignalDatabase.messages.markAsSentFailed(messageId) + PushSendJob.notifyMediaMessageDeliveryFailed(context, messageId) + return + } + + if (message.scheduledDate != -1L) { + AppDependencies.scheduledMessageManager.scheduleIfNecessary() + return + } + + val attachmentUploadIds: Set = PushSendJob.enqueueCompressingAndUploadAttachmentsChains(AppDependencies.jobManager, message) + val hasMedia = attachmentUploadIds.isNotEmpty() + val addHardDependencies = hasMedia && !isScheduledSend + + AppDependencies.jobManager.add( + create(messageId, recipient, hasMedia, isScheduledSend), + attachmentUploadIds, + if (addHardDependencies) recipient.id.toQueueKey() else null + ) + } + + private fun logPrefix(sentTimestamp: Long? = null, messageId: Long): String = "[${sentTimestamp ?: "?"}][$messageId]" + } + + constructor(messageId: Long, recipient: Recipient, hasMedia: Boolean, isScheduledSend: Boolean) : this( + parameters = Parameters.Builder() + .setQueue(if (isScheduledSend) recipient.id.toScheduledSendQueueKey() else recipient.id.toQueueKey(hasMedia)) + .addConstraint(NetworkConstraint.KEY) + .addConstraint(SealedSenderConstraint.KEY) + .setLifespan(TimeUnit.DAYS.toMillis(1)) + .setMaxAttempts(Parameters.UNLIMITED) + .build(), + messageId = messageId + ) + + override fun serialize(): ByteArray = IndividualSendJobV2Data(messageId = messageId).encode() + + override fun getFactoryKey(): String = KEY + + override fun onAdded() { + SignalDatabase.messages.markAsSending(messageId) + } + + override suspend fun doRun(): Result { + SignalLocalMetrics.IndividualMessageSend.onJobStarted(messageId) + val result = doWork() + SignalLocalMetrics.IndividualMessageSend.onJobFinished(messageId) + return result + } + + suspend fun doWork(): Result { + syncPreKeysIfNecessary().getOrElse { return it } + + if (SignalStore.misc.isClientDeprecated) { + Log.w(TAG, "${logPrefix()} Client is deprecated (build ${BuildConfig.BUILD_TIMESTAMP}); failing message.") + return Result.failure() + } + + if (!Recipient.self().isRegistered) { + Log.w(TAG, "${logPrefix()} Self is not registered; failing.") + return Result.failure() + } + + val message = SignalDatabase.messages.getOutgoingMessageOrNull(messageId) + if (message == null) { + Log.w(TAG, "${logPrefix()} No outgoing message found for id; failing.") + return Result.failure() + } + + val messageRecord = SignalDatabase.messages.getMessageRecordOrNull(messageId) + if (messageRecord == null) { + Log.w(TAG, "${logPrefix(message.sentTimeMillis)} No message record found for id; failing.") + return Result.failure() + } + + if (MessageTypes.isSentType(messageRecord.type)) { + Log.w(TAG, "${logPrefix(message.sentTimeMillis)} Message was already sent. Ignoring.") + return Result.success() + } + + val threadId = messageRecord.threadId + val originalEditedMessage = if (message.messageToEdit > 0) { + SignalDatabase.messages.getMessageRecordOrNull(message.messageToEdit) + } else { + null + } + + if (message.body.utf8Size() > MessageUtil.MAX_INLINE_BODY_SIZE_BYTES) { + Log.w(TAG, "${logPrefix(message.sentTimeMillis)} Body size exceeds limit of ${MessageUtil.MAX_INLINE_BODY_SIZE_BYTES} bytes; failing.") + return Result.failure() + } + + val recipient = message.threadRecipient.fresh().validated(message.sentTimeMillis).getOrElse { return it } + + val dataMessage = message.toDataMessage().getOrElse { error -> + Log.w(TAG, "${logPrefix(message.sentTimeMillis)} Failed to create a data message! Reason: $error") + return Result.failure() + } + + RecipientUtil.shareProfileIfFirstSecureMessage(message.threadRecipient) + + Log.i(TAG, "${logPrefix(message.sentTimeMillis)} Sending message. Recipient: ${message.threadRecipient.id}, Thread: $threadId, Attachments: ${buildAttachmentString(message.attachments)}, Editing: ${originalEditedMessage?.dateSent ?: "N/A"}") + SignalLocalMetrics.IndividualMessageSend.onDeliveryStarted(messageId, message.sentTimeMillis) + + return sendMessage(recipient, dataMessage, originalEditedMessage?.timestamp).fold( + ifRight = { success -> + val content = success.envelopeContent.content.get() + + val syntheticResult = SendMessageResult.success( + SignalServiceAddress(recipient.requireServiceId(), recipient.e164.orNull()), + success.devices, + success.sentUnidentified, + false, + 0L, + Optional.of(content) + ) + + SignalDatabase.messageLog.insertIfPossible( + recipientId = recipient.id, + sentTimestamp = message.sentTimeMillis, + sendMessageResult = syntheticResult, + contentHint = ContentHint.RESENDABLE, + messageId = MessageId(messageId), + urgent = content.isUrgent() + ) + + if (recipient.needsPniSignature) { + SignalDatabase.pendingPniSignatureMessages.insertIfNecessary(recipient.id, message.sentTimeMillis, syntheticResult) + } + + SignalDatabase.messages.markAsSent(messageId, success.sentUnidentified) + PushSendJob.markAttachmentsUploaded(messageId, message) + + SignalDatabase.threads.updateSilently(threadId, false) + + if (recipient.isSelf) { + SignalDatabase.messages.incrementDeliveryReceiptCount(message.sentTimeMillis, recipient.id, System.currentTimeMillis()) + SignalDatabase.messages.incrementReadReceiptCount(message.sentTimeMillis, recipient.id, System.currentTimeMillis()) + SignalDatabase.messages.incrementViewedReceiptCount(message.sentTimeMillis, recipient.id, System.currentTimeMillis()) + } + + val accessMode = recipient.sealedSenderAccessMode + if (success.sentUnidentified && accessMode == SealedSenderAccessMode.UNKNOWN && recipient.profileKey == null) { + SignalDatabase.recipients.setSealedSenderAccessMode(recipient.id, SealedSenderAccessMode.UNRESTRICTED) + } else if (success.sentUnidentified && accessMode == SealedSenderAccessMode.UNKNOWN) { + SignalDatabase.recipients.setSealedSenderAccessMode(recipient.id, SealedSenderAccessMode.ENABLED) + } else if (!success.sentUnidentified && accessMode != SealedSenderAccessMode.DISABLED) { + SignalDatabase.recipients.setSealedSenderAccessMode(recipient.id, SealedSenderAccessMode.DISABLED) + } + + if (originalEditedMessage != null && originalEditedMessage.expireStarted > 0) { + SignalDatabase.messages.markExpireStarted(messageId, originalEditedMessage.expireStarted) + AppDependencies.expiringMessageManager.scheduleDeletion(messageId, true, originalEditedMessage.expireStarted, originalEditedMessage.expiresIn) + } else if (message.expiresIn > 0 && !message.isExpirationUpdate) { + SignalDatabase.messages.markExpireStarted(messageId) + AppDependencies.expiringMessageManager.scheduleDeletion(messageId, true, message.expiresIn) + } + + if (message.isViewOnce) { + SignalDatabase.attachments.deleteAttachmentFilesForViewOnceMessage(messageId) + } + + ConversationShortcutRankingUpdateJob.enqueueForOutgoingIfNecessary(recipient) + Log.i(TAG, "${logPrefix(message.sentTimeMillis)} Sent message.") + Result.success() + }, + ifLeft = { error -> + when (error) { + is MessageService.SendError.IdentityMismatch -> { + Log.w(TAG, "${logPrefix(message.sentTimeMillis)} Identity mismatch for ${error.recipient.identifier}", error.cause) + val externalRecipient = Recipient.external(error.recipient.identifier) + if (externalRecipient == null) { + Log.w(TAG, "${logPrefix(message.sentTimeMillis)} Failed to create a Recipient for the identifier!") + } else { + SignalDatabase.messages.addMismatchedIdentity(messageId, externalRecipient.id, error.cause.untrustedIdentity) + SignalDatabase.messages.markAsSentFailed(messageId) + RetrieveProfileJob.enqueue(externalRecipient.id, true) + } + Result.success() + } + + MessageService.SendError.NotRegistered -> { + Log.w(TAG, "${logPrefix(message.sentTimeMillis)} Recipient not registered") + SignalDatabase.messages.markAsSentFailed(messageId) + PushSendJob.notifyMediaMessageDeliveryFailed(context, messageId) + AppDependencies.jobManager.add(DirectoryRefreshJob(false)) + Result.success() + } + + MessageService.SendError.Unauthorized -> { + Log.w(TAG, "${logPrefix(message.sentTimeMillis)} Unauthorized send") + Result.failure() + } + + is MessageService.SendError.ChallengeRequired -> { + Log.w(TAG, "${logPrefix(message.sentTimeMillis)} Challenge required (options=${error.options})") + val proofResponse = ProofRequiredResponse().apply { + token = error.token + options = error.options + } + val proofException = ProofRequiredException(proofResponse, error.retryAfter?.inWholeSeconds ?: 0L) + val threadRecipient = SignalDatabase.threads.getRecipientForThreadId(threadId) + when (ProofRequiredExceptionHandler.handle(context, proofException, threadRecipient, threadId, messageId)) { + ProofRequiredExceptionHandler.Result.RETRY_NOW -> Result.retry(0L) + ProofRequiredExceptionHandler.Result.RETRY_LATER, + ProofRequiredExceptionHandler.Result.RETHROW -> Result.retry(nextRunAttemptBackoff(runAttempt + 1)) + } + } + + MessageService.SendError.ServerRejected -> { + Log.w(TAG, "${logPrefix(message.sentTimeMillis)} Server rejected the send") + Result.failure() + } + + is MessageService.SendError.ContentTooLarge -> { + Log.w(TAG, "${logPrefix(message.sentTimeMillis)} Content too large (${error.size} > ${error.maxAllowed} bytes); failing.") + Result.failure() + } + + MessageService.SendError.SessionAttemptsExhausted -> { + Log.w(TAG, "${logPrefix(message.sentTimeMillis)} Exhausted device-resolution attempts; retrying") + Result.retry(nextRunAttemptBackoff(runAttempt + 1)) + } + + is MessageService.SendError.PreKeyUnavailable -> { + Log.w(TAG, "${logPrefix(message.sentTimeMillis)} Prekey unavailable: ${error.reason}") + Result.retry(nextRunAttemptBackoff(runAttempt + 1)) + } + + is MessageService.SendError.RateLimited -> { + val defaultBackoff = nextRunAttemptBackoff(runAttempt + 1) + val serverBackoff = error.retryAfter?.inWholeMilliseconds ?: 0L + val backoff = maxOf(defaultBackoff, serverBackoff) + Log.w(TAG, "${logPrefix(message.sentTimeMillis)} Rate limited, retryAfter=${error.retryAfter}, using backoff=${backoff}ms") + Result.retry(backoff) + } + + is MessageService.SendError.NetworkError -> { + Log.w(TAG, "${logPrefix(message.sentTimeMillis)} Network error", error.cause) + Result.retry(nextRunAttemptBackoff(runAttempt + 1)) + } + + is MessageService.SendError.ApplicationError -> when (val cause = error.cause) { + is RuntimeException -> { + Log.e(TAG, "${logPrefix(message.sentTimeMillis)} Encountered a fatal application error. Crash imminent.", cause) + Result.fatalFailure(cause) + } + + else -> { + Log.w(TAG, "${logPrefix(message.sentTimeMillis)} Application error", cause) + Result.retry(nextRunAttemptBackoff(runAttempt + 1)) + } + } + } + } + ) + } + + private suspend fun sendMessage(recipient: Recipient, dataMessage: DataMessage, editMessageTarget: Long?): Either = either { + val primaryResult = sendPrimaryMessage( + recipient = recipient, + dataMessage = dataMessage, + editMessageTarget = editMessageTarget + ).also { + SignalLocalMetrics.IndividualMessageSend.onMessageSent(messageId) + } + + if (SignalStore.account.isMultiDevice) { + sendSyncMessage(recipient, primaryResult).also { + SignalLocalMetrics.IndividualMessageSend.onSyncMessageSent(messageId) + } + } + + primaryResult + } + + private suspend fun Raise.sendPrimaryMessage(recipient: Recipient, dataMessage: DataMessage, editMessageTarget: Long?): MessageService.SendSuccess { + val content: Content = if (editMessageTarget != null) { + Content( + editMessage = EditMessage( + targetSentTimestamp = editMessageTarget, + dataMessage = dataMessage + ) + ) + } else { + val pniSignature = if (recipient.needsPniSignature) { + Log.i(TAG, "${logPrefix(dataMessage.timestamp)} Including PNI signature.") + AppDependencies.signalServiceMessageSender.createPniSignatureMessage() + } else { + null + } + + Content( + dataMessage = dataMessage, + pniSignatureMessage = pniSignature + ) + } + + val envelopeContent = EnvelopeContent.encrypted(content, ContentHint.RESENDABLE, Optional.empty()) + + // If this is a note to self message, don't actually send it. Instead, craft a result of what we *would* send. Then it'll be sent via sync message if appropriate. + if (SignalStore.account.aci == recipient.serviceId.getOrNull()) { + Log.i(TAG, "${logPrefix(dataMessage.timestamp)} Note to self. Skipping primary send.") + return MessageService.SendSuccess(envelopeContent, true, listOf(SignalServiceAddress.DEFAULT_DEVICE_ID)) + } + + return AppDependencies.messageService.sendMessage( + recipient = SignalServiceAddress(recipient.requireServiceId(), recipient.e164.orNull()), + envelopeContent = envelopeContent, + timestamp = dataMessage.timestamp!!, + sealedSenderAccess = SealedSenderAccessUtil.getSealedSenderAccessFor(recipient), + story = false, + isOnline = false, + urgent = content.isUrgent(), + onEncrypted = { SignalLocalMetrics.IndividualMessageSend.onMessageEncrypted(messageId) } + ).bind() + } + + private suspend fun Raise.sendSyncMessage(targetRecipient: Recipient, primaryResult: MessageService.SendSuccess): MessageService.SendSuccess { + val dataMessage = primaryResult.envelopeContent.content.get().dataMessage + val editMessage = primaryResult.envelopeContent.content.get().editMessage + val timestamp = dataMessage?.timestamp ?: editMessage?.dataMessage?.timestamp ?: raise(MessageService.SendError.ApplicationError(IllegalStateException("No timestamp on primary message send!"))) + + val syncContent = Content( + syncMessage = SyncMessage( + sent = SyncMessage.Sent( + destinationServiceId = targetRecipient.serviceId.get().toString(), + timestamp = timestamp, + message = dataMessage, + editMessage = editMessage + ) + ) + ) + val syncEnvelope = EnvelopeContent.encrypted(syncContent, ContentHint.IMPLICIT, Optional.empty()) + + return AppDependencies.messageService.sendMessage( + recipient = SignalServiceAddress(SignalStore.account.requireAci()), + envelopeContent = syncEnvelope, + timestamp = timestamp, + sealedSenderAccess = null, // We don't use sealed sender for sync messages + story = false, + isOnline = false, + urgent = true, + onEncrypted = { SignalLocalMetrics.IndividualMessageSend.onSyncMessageEncrypted(messageId) } + ).bind() + } + + override fun onRetry() { + SignalLocalMetrics.IndividualMessageSend.cancel(messageId) + if (runAttempt > 1) { + AppDependencies.jobManager.add(ServiceOutageDetectionJob()) + } + } + + override fun onFailure() { + SignalLocalMetrics.IndividualMessageSend.cancel(messageId) + SignalDatabase.messages.markAsSentFailed(messageId) + PushSendJob.notifyMediaMessageDeliveryFailed(context, messageId) + } + + private fun nextRunAttemptBackoff(pastAttemptCount: Int): Long { + return BackoffUtil.exponentialBackoff(pastAttemptCount, RemoteConfig.defaultMaxBackoff) + } + + /** + * Syncs prekeys if we haven't done so for a long time. In practice, we shouldn't hit this -- it's a failsafe. + * @return if non-null, this should be used as the overall job result. + */ + private fun syncPreKeysIfNecessary(): Either = either { + val timeSinceAciSignedPreKeyRotation = System.currentTimeMillis() - SignalStore.account.aciPreKeys.lastSignedPreKeyRotationTime + val timeSincePniSignedPreKeyRotation = System.currentTimeMillis() - SignalStore.account.pniPreKeys.lastSignedPreKeyRotationTime + if (timeSinceAciSignedPreKeyRotation > PreKeysSyncJob.MAXIMUM_ALLOWED_SIGNED_PREKEY_AGE || + timeSinceAciSignedPreKeyRotation < 0 || + timeSincePniSignedPreKeyRotation > PreKeysSyncJob.MAXIMUM_ALLOWED_SIGNED_PREKEY_AGE || + timeSincePniSignedPreKeyRotation < 0 + ) { + Log.w(TAG, "${logPrefix()} It's been too long since rotating our signed prekeys. Attempting to rotate now.") + val state = AppDependencies.jobManager.runSynchronously(PreKeysSyncJob.create(), TimeUnit.SECONDS.toMillis(30)) + if (state.isPresent && state.get() == JobTracker.JobState.SUCCESS) { + Log.i(TAG, "${logPrefix()} Successfully refreshed prekeys. Continuing.") + } else { + Log.w(TAG, "${logPrefix()} Failed to refresh prekeys; retrying. State: ${if (state.isEmpty) "" else state.get()}") + raise(Result.retry(nextRunAttemptBackoff(runAttempt + 1))) + } + } + } + + private fun Recipient.validated(sentTime: Long): Either = either { + if (isUnregistered) { + Log.w(TAG, "${logPrefix(sentTime)} Recipient $id not registered; failing.") + raise(Result.failure()) + } + + if (!hasServiceId) { + Log.w(TAG, "${logPrefix(sentTime)} Recipient $id has no serviceId; failing.") + raise(Result.failure()) + } + + this@validated + } + + private fun logPrefix(sentTimestamp: Long? = null): String = logPrefix(sentTimestamp, messageId) + + private fun buildAttachmentString(attachments: List): String { + return attachments.joinToString(", ") { attachment -> + when { + attachment is DatabaseAttachment -> attachment.attachmentId.toString() + attachment.uri != null -> attachment.uri.toString() + else -> attachment.toString() + } + } + } + + class Factory : Job.Factory { + override fun create(parameters: Parameters, serializedData: ByteArray?): IndividualSendJobV2 { + val data = IndividualSendJobV2Data.ADAPTER.decode(serializedData!!) + return IndividualSendJobV2(parameters, data.messageId) + } + } +} diff --git a/app/src/main/java/org/thoughtcrime/securesms/jobs/JobManagerFactories.java b/app/src/main/java/org/thoughtcrime/securesms/jobs/JobManagerFactories.java index 056104b5c0..ef3fba57ed 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/jobs/JobManagerFactories.java +++ b/app/src/main/java/org/thoughtcrime/securesms/jobs/JobManagerFactories.java @@ -198,6 +198,7 @@ public final class JobManagerFactories { put(InAppPaymentStripeOneTimeSetupJob.KEY, new InAppPaymentStripeOneTimeSetupJob.Factory()); put(InAppPaymentStripeRecurringSetupJob.KEY, new InAppPaymentStripeRecurringSetupJob.Factory()); put(IndividualSendJob.KEY, new IndividualSendJob.Factory()); + put(IndividualSendJobV2.KEY, new IndividualSendJobV2.Factory()); put(LeaveGroupV2Job.KEY, new LeaveGroupV2Job.Factory()); put(LeaveGroupV2WorkerJob.KEY, new LeaveGroupV2WorkerJob.Factory()); put(LinkedDeviceInactiveCheckJob.KEY, new LinkedDeviceInactiveCheckJob.Factory()); diff --git a/app/src/main/java/org/thoughtcrime/securesms/jobs/PushGroupSendJob.java b/app/src/main/java/org/thoughtcrime/securesms/jobs/PushGroupSendJob.java index f70bae3c47..757070c3dc 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/jobs/PushGroupSendJob.java +++ b/app/src/main/java/org/thoughtcrime/securesms/jobs/PushGroupSendJob.java @@ -495,7 +495,7 @@ public final class PushGroupSendJob extends PushSendJob { } if (existingNetworkFailures.isEmpty() && existingIdentityMismatches.isEmpty()) { - database.markAsSent(messageId, true); + database.markAsSent(messageId); markAttachmentsUploaded(messageId, message); diff --git a/app/src/main/java/org/thoughtcrime/securesms/jobs/PushSendJob.kt b/app/src/main/java/org/thoughtcrime/securesms/jobs/PushSendJob.kt index 82f8764d33..8ca106b0bd 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/jobs/PushSendJob.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/jobs/PushSendJob.kt @@ -73,7 +73,7 @@ abstract class PushSendJob protected constructor(parameters: Parameters) : BaseJ private val TAG = Log.tag(PushSendJob::class.java) @JvmStatic - protected fun enqueueCompressingAndUploadAttachmentsChains(jobManager: JobManager, message: OutgoingMessage): Set { + fun enqueueCompressingAndUploadAttachmentsChains(jobManager: JobManager, message: OutgoingMessage): Set { val attachments: MutableList = mutableListOf() attachments += message.attachments @@ -109,7 +109,7 @@ abstract class PushSendJob protected constructor(parameters: Parameters) : BaseJ } @JvmStatic - protected fun notifyMediaMessageDeliveryFailed(context: Context, messageId: Long) { + fun notifyMediaMessageDeliveryFailed(context: Context, messageId: Long) { val threadId = messages.getThreadIdForMessage(messageId) val recipient = threads.getRecipientForThreadId(threadId) val groupReplyStoryId = messages.getParentStoryIdForGroupReply(messageId) @@ -135,7 +135,7 @@ abstract class PushSendJob protected constructor(parameters: Parameters) : BaseJ } @JvmStatic - protected fun markAttachmentsUploaded(messageId: Long, message: OutgoingMessage) { + fun markAttachmentsUploaded(messageId: Long, message: OutgoingMessage) { val attachments: MutableList = mutableListOf() attachments += message.attachments diff --git a/app/src/main/java/org/thoughtcrime/securesms/jobs/RemoteDeleteSendJob.java b/app/src/main/java/org/thoughtcrime/securesms/jobs/RemoteDeleteSendJob.java index 69f2d286b6..7b998f4487 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/jobs/RemoteDeleteSendJob.java +++ b/app/src/main/java/org/thoughtcrime/securesms/jobs/RemoteDeleteSendJob.java @@ -189,7 +189,7 @@ public class RemoteDeleteSendJob extends BaseJob { } if (recipients.isEmpty()) { - db.markAsSent(messageId, true); + db.markAsSent(messageId); } else { Log.w(TAG, "Still need to send to " + recipients.size() + " recipients. Retrying."); throw new RetryLaterException(); diff --git a/app/src/main/java/org/thoughtcrime/securesms/messages/SyncMessageProcessor.kt b/app/src/main/java/org/thoughtcrime/securesms/messages/SyncMessageProcessor.kt index f17745a735..f467e1bccd 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/messages/SyncMessageProcessor.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/messages/SyncMessageProcessor.kt @@ -425,7 +425,7 @@ object SyncMessageProcessor { SignalDatabase.messages.markUnidentified(messageId, sent.isUnidentified(toRecipient.serviceId.orNull())) } - SignalDatabase.messages.markAsSent(messageId, true) + SignalDatabase.messages.markAsSent(messageId) if (targetMessage.expireStarted > 0) { SignalDatabase.messages.markExpireStarted(messageId, targetMessage.expireStarted) AppDependencies.expiringMessageManager.scheduleDeletion(messageId, true, targetMessage.expireStarted, targetMessage.expireStarted) @@ -498,7 +498,7 @@ object SyncMessageProcessor { SignalDatabase.messages.markUnidentified(messageId, sent.isUnidentified(toRecipient.serviceId.orNull())) } - SignalDatabase.messages.markAsSent(messageId, true) + SignalDatabase.messages.markAsSent(messageId) val attachments: List = SignalDatabase.attachments.getAttachmentsForMessage(messageId) @@ -605,7 +605,7 @@ object SyncMessageProcessor { SignalDatabase.messages.markUnidentified(messageId, sent.isUnidentified(recipient.serviceId.orNull())) } - SignalDatabase.messages.markAsSent(messageId, true) + SignalDatabase.messages.markAsSent(messageId) val allAttachments = SignalDatabase.attachments.getAttachmentsForMessage(messageId) val attachments: List = allAttachments.filterNot { it.isSticker } @@ -716,14 +716,14 @@ object SyncMessageProcessor { // TODO [expireVersion] After unsupported builds expire, we can remove this branch SignalDatabase.recipients.setExpireMessagesWithoutIncrementingVersion(recipient.id, sent.message!!.expireTimerDuration.inWholeSeconds.toInt()) val messageId: Long = SignalDatabase.messages.insertMessageOutbox(expirationUpdateMessage, threadId, false, null).messageId - SignalDatabase.messages.markAsSent(messageId, true) + SignalDatabase.messages.markAsSent(messageId) } else if (sent.message!!.expireTimerVersion!! >= recipient.expireTimerVersion) { SignalDatabase.recipients.setExpireMessages(recipient.id, sent.message!!.expireTimerDuration.inWholeSeconds.toInt(), sent.message!!.expireTimerVersion!!) if (sent.message!!.expireTimerDuration != recipient.expiresInSeconds.seconds) { log(sent.timestamp!!, "Not inserted update message as timer value did not change") val messageId: Long = SignalDatabase.messages.insertMessageOutbox(expirationUpdateMessage, threadId, false, null).messageId - SignalDatabase.messages.markAsSent(messageId, true) + SignalDatabase.messages.markAsSent(messageId) } } else { warn(sent.timestamp!!, "[SynchronizeExpiration] Ignoring expire timer update with old version. Received: ${sent.message!!.expireTimerVersion}, Current: ${recipient.expireTimerVersion}") @@ -807,7 +807,7 @@ object SyncMessageProcessor { SignalDatabase.messages.markUnidentified(messageId, sent.isUnidentified(recipient.serviceId.orNull())) } - SignalDatabase.messages.markAsSent(messageId, true) + SignalDatabase.messages.markAsSent(messageId) if (dataMessage.expireTimerDuration > Duration.ZERO) { SignalDatabase.messages.markExpireStarted(messageId, sent.expirationStartTimestamp ?: 0) @@ -874,7 +874,7 @@ object SyncMessageProcessor { SignalDatabase.messages.markUnidentified(messageId, sent.isUnidentified(syncDestinationRecipient.serviceId.orNull())) } - SignalDatabase.messages.markAsSent(messageId, true) + SignalDatabase.messages.markAsSent(messageId) if (dataMessage.expireTimerDuration > Duration.ZERO) { SignalDatabase.messages.markExpireStarted(messageId, sent.expirationStartTimestamp ?: 0) @@ -949,7 +949,7 @@ object SyncMessageProcessor { log(envelopeTimestamp, "Inserted sync message as messageId $messageId") - SignalDatabase.messages.markAsSent(messageId, true) + SignalDatabase.messages.markAsSent(messageId) if (expiresInMillis > 0) { SignalDatabase.messages.markExpireStarted(messageId, sent.expirationStartTimestamp ?: 0) @@ -1889,7 +1889,7 @@ object SyncMessageProcessor { log(envelope.clientTimestamp!!, "Inserted sync poll create message as messageId $messageId") - SignalDatabase.messages.markAsSent(messageId, true) + SignalDatabase.messages.markAsSent(messageId) if (expiresInMillis > 0) { SignalDatabase.messages.markExpireStarted(messageId, sent.expirationStartTimestamp ?: 0) @@ -1947,7 +1947,7 @@ object SyncMessageProcessor { val receiptStatus = if (recipient.isGroup) GroupReceiptTable.STATUS_UNKNOWN else GroupReceiptTable.STATUS_UNDELIVERED val messageId = SignalDatabase.messages.insertMessageOutbox(outgoingMessage, threadId, false, receiptStatus, null).messageId - SignalDatabase.messages.markAsSent(messageId, true) + SignalDatabase.messages.markAsSent(messageId) log(envelope.clientTimestamp!!, "Inserted sync poll end message as messageId $messageId") @@ -2014,7 +2014,7 @@ object SyncMessageProcessor { ) val messageId = SignalDatabase.messages.insertMessageOutbox(outgoingMessage, threadId, false, GroupReceiptTable.STATUS_UNKNOWN, null).messageId - SignalDatabase.messages.markAsSent(messageId, true) + SignalDatabase.messages.markAsSent(messageId) log(envelope.clientTimestamp!!, "Inserted sync pin message as messageId $messageId") diff --git a/app/src/main/java/org/thoughtcrime/securesms/util/RemoteConfig.kt b/app/src/main/java/org/thoughtcrime/securesms/util/RemoteConfig.kt index ee5a93c686..352fe44afe 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/util/RemoteConfig.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/util/RemoteConfig.kt @@ -1180,6 +1180,19 @@ object RemoteConfig { hotSwappable = true ) + /** + * When true, individual 1:1 sends are routed through [IndividualSendJobV2], which uses the + * network-module [org.signal.network.service.MessageService] instead of the legacy + * [SignalServiceMessageSender] send path. + */ + @JvmStatic + @get:JvmName("useIndividualSendJobV2") + val useIndividualSendJobV2: Boolean by remoteBoolean( + key = "android.useIndividualSendJobV2", + defaultValue = false, + hotSwappable = true + ) + /** * Also determines how long an unregistered/deleted record should remain in storage service */ diff --git a/app/src/main/java/org/thoughtcrime/securesms/util/SignalLocalMetrics.java b/app/src/main/java/org/thoughtcrime/securesms/util/SignalLocalMetrics.java index 2ac80d3d97..62afaab757 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/util/SignalLocalMetrics.java +++ b/app/src/main/java/org/thoughtcrime/securesms/util/SignalLocalMetrics.java @@ -131,10 +131,10 @@ public final class SignalLocalMetrics { private static final String SPLIT_DB_INSERT = "db-insert"; private static final String SPLIT_JOB_ENQUEUE = "job-enqueue"; private static final String SPLIT_JOB_PRE_NETWORK = "job-pre-network"; - private static final String SPLIT_ENCRYPT = "encrypt"; - private static final String SPLIT_NETWORK_MAIN = "network-main"; + private static final String SPLIT_MAIN_ENCRYPT = "main-encrypt"; + private static final String SPLIT_MAIN_NETWORK = "main-network"; private static final String SPLIT_SYNC_ENCRYPT = "sync-encrypt"; - private static final String SPLIT_NETWORK_SYNC = "network-sync"; + private static final String SPLIT_SYNC_NETWORK = "sync-network"; private static final String SPLIT_JOB_POST_NETWORK = "job-post-network"; private static final String SPLIT_UI_UPDATE = "ui-update"; @@ -167,11 +167,11 @@ public final class SignalLocalMetrics { } public static void onMessageEncrypted(long messageId) { - split(messageId, SPLIT_ENCRYPT); + split(messageId, SPLIT_MAIN_ENCRYPT); } public static void onMessageSent(long messageId) { - split(messageId, SPLIT_NETWORK_MAIN); + split(messageId, SPLIT_MAIN_NETWORK); } public static void onSyncMessageEncrypted(long messageId) { @@ -179,7 +179,7 @@ public final class SignalLocalMetrics { } public static void onSyncMessageSent(long messageId) { - split(messageId, SPLIT_NETWORK_SYNC); + split(messageId, SPLIT_SYNC_NETWORK); } public static void onJobFinished(long messageId) { diff --git a/app/src/main/java/org/thoughtcrime/securesms/util/SignalServiceTransformExtensions.kt b/app/src/main/java/org/thoughtcrime/securesms/util/SignalServiceTransformExtensions.kt new file mode 100644 index 0000000000..7c20670be8 --- /dev/null +++ b/app/src/main/java/org/thoughtcrime/securesms/util/SignalServiceTransformExtensions.kt @@ -0,0 +1,529 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.thoughtcrime.securesms.util + +import arrow.core.Either +import arrow.core.raise.context.bind +import arrow.core.raise.either +import arrow.core.raise.ensure +import arrow.core.raise.ensureNotNull +import okio.ByteString +import okio.ByteString.Companion.toByteString +import org.signal.core.models.ServiceId +import org.signal.core.models.ServiceId.ACI +import org.signal.core.util.Base64 +import org.signal.core.util.Hex +import org.signal.core.util.UuidUtil +import org.signal.core.util.logging.Log +import org.signal.libsignal.zkgroup.InvalidInputException +import org.signal.libsignal.zkgroup.receipts.ReceiptCredentialPresentation +import org.thoughtcrime.securesms.attachments.Attachment +import org.thoughtcrime.securesms.contactshare.Contact +import org.thoughtcrime.securesms.crypto.ProfileKeyUtil +import org.thoughtcrime.securesms.database.MessageTable +import org.thoughtcrime.securesms.database.SignalDatabase +import org.thoughtcrime.securesms.database.model.databaseprotos.BodyRangeList +import org.thoughtcrime.securesms.database.model.databaseprotos.GiftBadge +import org.thoughtcrime.securesms.database.model.databaseprotos.PinnedMessage +import org.thoughtcrime.securesms.database.model.databaseprotos.PollTerminate +import org.thoughtcrime.securesms.linkpreview.LinkPreview +import org.thoughtcrime.securesms.mms.OutgoingMessage +import org.thoughtcrime.securesms.mms.QuoteModel +import org.thoughtcrime.securesms.polls.Poll +import org.thoughtcrime.securesms.recipients.Recipient +import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentRemoteId +import org.whispersystems.signalservice.internal.push.AttachmentPointer +import org.whispersystems.signalservice.internal.push.BodyRange +import org.whispersystems.signalservice.internal.push.CallMessage +import org.whispersystems.signalservice.internal.push.Content +import org.whispersystems.signalservice.internal.push.DataMessage +import org.whispersystems.signalservice.internal.push.Preview +import org.whispersystems.signalservice.internal.push.SyncMessage +import java.io.IOException + +private const val TAG = "DataMessageTransforms" + +/** + * Builds the wire [DataMessage] for this outgoing message. It is technically possible, though rare, that we may not be + * able to successfully construct a model. These are almost certainly data consistency bugs, and we'd rather fail the + * send than send something that doesn't match the user intent. + */ +fun OutgoingMessage.toDataMessage(): Either = either { + val builder = DataMessage.Builder() + + builder.body = body.ifEmpty { null } + builder.timestamp = sentTimeMillis + builder.profileKey = threadRecipient.fresh().selfProfileKeyForOutgoing() + builder.sticker = attachments.toStickerIfPresent().bind() + builder.contact = sharedContacts.map { it.toProto().bind() } + builder.preview = linkPreviews.map { it.toProto().bind() } + builder.giftBadge = giftBadge?.toProto()?.bind() + builder.bodyRanges = bodyRanges?.toProto()?.bind() ?: emptyList() + builder.pollCreate = poll?.toProto() + builder.pollTerminate = messageExtras?.pollTerminate?.toProto() + builder.pinMessage = messageExtras?.pinnedMessage?.toProto()?.bind() + builder.payment = toPaymentProtoIfPresent().bind() + builder.isViewOnce = isViewOnce + builder.flags = if (isExpirationUpdate) DataMessage.Flags.EXPIRATION_TIMER_UPDATE.value else null + builder.expireTimer = (expiresIn / 1000).toInt() + builder.expireTimerVersion = expireTimerVersion + builder.attachments = attachments + .filter { !it.isSticker } + .map { it.toAttachmentPointerProto().bind() } + .capIncrementalMacs(RemoteConfig.maxIncrementalMacsPerEnvelope) + + if (giftBadge != null || isPaymentsNotification) { + builder.body = null + } + + if (parentStoryId != null) { + val storyRecord = ensureNotNull(SignalDatabase.messages.getMessageRecordOrNull(parentStoryId.asMessageId().id)) { + DataMessageError.MissingParentStory + } + val storyAuthor = storyRecord.fromRecipient.requireServiceId() + builder.storyContext = DataMessage.StoryContext( + authorAciBinary = storyAuthor.toByteString(), + sentTimestamp = storyRecord.dateSent + ) + + if (isStoryReaction) { + builder.reaction = DataMessage.Reaction( + emoji = body, + remove = false, + targetAuthorAciBinary = storyAuthor.toByteString(), + targetSentTimestamp = storyRecord.dateSent + ) + builder.body = null + } + } else { + builder.quote = outgoingQuote?.toProto(isMessageEdit)?.bind() + } + + builder.requiredProtocolVersion = builder.getRequiredProtocolVersion(isViewOnce) + + builder.build() +} + +private fun DataMessage.Builder.getRequiredProtocolVersion(isViewOnce: Boolean): Int? { + var version = 0 + + if (isViewOnce) { + version = maxOf(version, DataMessage.ProtocolVersion.VIEW_ONCE_VIDEO.value) + } + + if (reaction != null) { + version = maxOf(version, DataMessage.ProtocolVersion.REACTIONS.value) + } + + if (payment != null) { + version = maxOf(version, DataMessage.ProtocolVersion.PAYMENTS.value) + } + + if (pollCreate != null) { + version = maxOf(version, DataMessage.ProtocolVersion.POLLS.value) + } + + return version.takeIf { it > 0 } +} + +private fun QuoteModel.toProto(isMessageEdit: Boolean): Either = either { + if (isMessageEdit) { + return@either DataMessage.Quote( + id = 0, + authorAciBinary = ACI.UNKNOWN.toByteString(), + text = "", + type = DataMessage.Quote.Type.NORMAL + ) + } + + val quoteAuthor = Recipient.resolved(author) + ensure(quoteAuthor.hasServiceId) { DataMessageError.MissingQuoteAuthorServiceId } + + val mentionBodyRanges: List = mentions.map { mention -> + BodyRange( + start = mention.start, + length = mention.length, + mentionAciBinary = Recipient.resolved(mention.recipientId).requireAci().toByteString() + ) + } + + val combinedBodyRanges: List = mentionBodyRanges + (bodyRanges?.toProto()?.bind() ?: emptyList()) + + val quoteAttachments = attachment + ?.takeUnless { MediaUtil.isViewOnceType(attachment.contentType) } + ?.toQuoteAttachmentProto() + ?.bind() + ?.let { listOf(it) } + + DataMessage.Quote( + id = id, + authorAciBinary = quoteAuthor.requireAci().toByteString(), + text = text, + attachments = quoteAttachments ?: emptyList(), + bodyRanges = combinedBodyRanges, + type = type.dataMessageType.protoType + ) +} + +private fun Attachment.toQuoteAttachmentProto(): Either = either { + DataMessage.Quote.QuotedAttachment( + contentType = quoteTargetContentType ?: MediaUtil.IMAGE_JPEG, + fileName = fileName, + thumbnail = toAttachmentPointerProto().bind() + ) +} + +private fun OutgoingMessage.toPaymentProtoIfPresent(): Either = either { + when { + isPaymentsNotification -> { + val paymentUuid = UuidUtil.parseOrThrow(body) + val payment = ensureNotNull(SignalDatabase.payments.getPayment(paymentUuid)) { DataMessageError.MissingPayment } + val receipt = ensureNotNull(payment.receipt) { DataMessageError.MissingPaymentReceipt } + + DataMessage.Payment( + notification = DataMessage.Payment.Notification( + note = payment.note, + mobileCoin = DataMessage.Payment.Notification.MobileCoin(receipt = receipt.toByteString()) + ) + ) + } + isRequestToActivatePayments -> { + DataMessage.Payment(activation = DataMessage.Payment.Activation(type = DataMessage.Payment.Activation.Type.REQUEST)) + } + isPaymentsActivated -> { + DataMessage.Payment(activation = DataMessage.Payment.Activation(type = DataMessage.Payment.Activation.Type.ACTIVATED)) + } + else -> { + null + } + } +} + +private fun Recipient.selfProfileKeyForOutgoing(): ByteString? { + val resolved = this.resolve() + return if (resolved.isSystemContact || resolved.isProfileSharing) { + ProfileKeyUtil.getSelfProfileKey().serialize().toByteString() + } else { + null + } +} + +private fun Attachment.toAttachmentPointerProto(): Either = either { + if (remoteLocation.isNullOrEmpty() || remoteKey.isNullOrEmpty() || remoteDigest == null) { + raise(DataMessageError.MissingAttachmentRemoteFields) + } + + val remoteIdResolved: SignalServiceAttachmentRemoteId = SignalServiceAttachmentRemoteId.from(remoteLocation) + + val keyBytes: ByteArray = try { + Base64.decode(remoteKey) + } catch (_: IOException) { + raise(DataMessageError.FailedToDecodeAttachmentKey) + } + + val sizeInt: Int = try { + Math.toIntExact(size) + } catch (_: ArithmeticException) { + Log.w(TAG, "Failed to parse attachment size! Skipping attachment.") + raise(DataMessageError.FailedToDecodeAttachmentSize) + } + + var flags = 0 + if (voiceNote) { + flags = flags or AttachmentPointer.Flags.VOICE_MESSAGE.value + } + if (borderless) { + flags = flags or AttachmentPointer.Flags.BORDERLESS.value + } + if (videoGif) { + flags = flags or AttachmentPointer.Flags.GIF.value + } + + val builder = AttachmentPointer.Builder() + .cdnNumber(cdn.cdnNumber) + .contentType(contentType) + .key(keyBytes.toByteString()) + .digest(remoteDigest.toByteString()) + .size(sizeInt) + .uploadTimestamp(uploadTimestamp) + .flags(flags) + + when (remoteIdResolved) { + is SignalServiceAttachmentRemoteId.V2 -> builder.cdnId(remoteIdResolved.cdnId) + is SignalServiceAttachmentRemoteId.V4 -> builder.cdnKey(remoteIdResolved.cdnKey) + is SignalServiceAttachmentRemoteId.S3, + is SignalServiceAttachmentRemoteId.Backup -> Unit + } + + incrementalDigest?.let { builder.incrementalMac(it.toByteString()) } + incrementalMacChunkSize.takeIf { it > 0 }?.let { builder.chunkSize(incrementalMacChunkSize) } + width.takeIf { it > 0 }?.let { builder.width(it) } + height.takeIf { it > 0 }?.let { builder.height(it) } + fileName?.let { builder.fileName(it) } + caption?.let { builder.caption(it) } + blurHash?.let { builder.blurHash(it.hash) } + uuid?.let { builder.clientUuid(UuidUtil.toByteString(it)) } + + builder.build() +} + +private fun List.toStickerIfPresent(): Either = either { + val stickerAttachment = firstOrNull { it.isSticker } ?: return@either null + val locator = ensureNotNull(stickerAttachment.stickerLocator) { DataMessageError.MissingStickerLocator } + + try { + val packId = Hex.fromStringCondensed(locator.packId) + val packKey = Hex.fromStringCondensed(locator.packKey) + val emoji = SignalDatabase.stickers.getSticker(locator.packId, locator.stickerId, false)?.emoji + DataMessage.Sticker( + packId = packId.toByteString(), + packKey = packKey.toByteString(), + stickerId = locator.stickerId, + emoji = emoji, + data_ = stickerAttachment.toAttachmentPointerProto().bind() + ) + } catch (e: IOException) { + Log.w(TAG, "Failed to decode sticker pack fields.", e) + raise(DataMessageError.FailedToDecodeStickerPackFields) + } +} + +private fun GiftBadge.toProto(): Either = either { + try { + val presentation = ReceiptCredentialPresentation(redemptionToken.toByteArray()) + DataMessage.GiftBadge(receiptCredentialPresentation = presentation.serialize().toByteString()) + } catch (e: InvalidInputException) { + Log.w(TAG, "Failed to parse gift badge.", e) + raise(DataMessageError.InvalidGiftBadge) + } +} + +private fun BodyRangeList.toProto(): Either> = either { + if (ranges.isEmpty()) { + return@either emptyList() + } + + ranges.map { range -> + val style = when (range.style) { + BodyRangeList.BodyRange.Style.BOLD -> BodyRange.Style.BOLD + BodyRangeList.BodyRange.Style.ITALIC -> BodyRange.Style.ITALIC + BodyRangeList.BodyRange.Style.SPOILER -> BodyRange.Style.SPOILER + BodyRangeList.BodyRange.Style.STRIKETHROUGH -> BodyRange.Style.STRIKETHROUGH + BodyRangeList.BodyRange.Style.MONOSPACE -> BodyRange.Style.MONOSPACE + null -> raise(DataMessageError.InvalidBodyRange) + } + BodyRange.Builder().start(range.start).length(range.length).style(style).build() + } +} + +private fun Poll.toProto(): DataMessage.PollCreate { + return DataMessage.PollCreate( + question = this.question, + allowMultiple = this.allowMultipleVotes, + options = this.pollOptions + ) +} + +private fun PollTerminate.toProto(): DataMessage.PollTerminate { + return DataMessage.PollTerminate(targetSentTimestamp = this.targetTimestamp) +} + +private fun PinnedMessage.toProto(): Either = either { + val targetAuthor = ensureNotNull(ServiceId.parseOrNull(targetAuthorAci)) { DataMessageError.PinnedMessageInvalidAuthorAci } + val forever = pinDurationInSeconds == MessageTable.PIN_FOREVER + DataMessage.PinMessage( + targetAuthorAciBinary = targetAuthor.toByteString(), + targetSentTimestamp = targetTimestamp, + pinDurationSeconds = if (!forever) pinDurationInSeconds.toInt() else null, + pinDurationForever = if (forever) true else null + ) +} + +private fun LinkPreview.toProto(): Either = either { + Preview( + url = url, + title = title, + description = description, + date = date, + image = thumbnail.orElse(null)?.toAttachmentPointerProto()?.bind() + ) +} + +private fun Contact.toProto(): Either = either { + DataMessage.Contact( + name = DataMessage.Contact.Name( + givenName = name.givenName, + familyName = name.familyName, + prefix = name.prefix, + suffix = name.suffix, + middleName = name.middleName, + nickname = name.nickname + ), + number = phoneNumbers.map { + DataMessage.Contact.Phone(value_ = it.number, type = it.type.toProto(), label = it.label) + }, + email = emails.map { + DataMessage.Contact.Email(value_ = it.email, type = it.type.toProto(), label = it.label) + }, + address = postalAddresses.map { + DataMessage.Contact.PostalAddress( + type = it.type.toProto(), + label = it.label, + street = it.street, + pobox = it.poBox, + neighborhood = it.neighborhood, + city = it.city, + region = it.region, + postcode = it.postalCode, + country = it.country + ) + }, + avatar = avatar?.let { avatar -> + avatar.attachment + ?.toAttachmentPointerProto() + ?.map { DataMessage.Contact.Avatar(avatar = it, isProfile = avatar.isProfile) } + ?.bind() + }, + organization = organization + ) +} + +private fun Contact.Phone.Type.toProto(): DataMessage.Contact.Phone.Type { + return when (this) { + Contact.Phone.Type.HOME -> DataMessage.Contact.Phone.Type.HOME + Contact.Phone.Type.MOBILE -> DataMessage.Contact.Phone.Type.MOBILE + Contact.Phone.Type.WORK -> DataMessage.Contact.Phone.Type.WORK + Contact.Phone.Type.CUSTOM -> DataMessage.Contact.Phone.Type.CUSTOM + } +} + +private fun Contact.Email.Type.toProto(): DataMessage.Contact.Email.Type { + return when (this) { + Contact.Email.Type.HOME -> DataMessage.Contact.Email.Type.HOME + Contact.Email.Type.MOBILE -> DataMessage.Contact.Email.Type.MOBILE + Contact.Email.Type.WORK -> DataMessage.Contact.Email.Type.WORK + Contact.Email.Type.CUSTOM -> DataMessage.Contact.Email.Type.CUSTOM + } +} + +private fun Contact.PostalAddress.Type.toProto(): DataMessage.Contact.PostalAddress.Type { + return when (this) { + Contact.PostalAddress.Type.HOME -> DataMessage.Contact.PostalAddress.Type.HOME + Contact.PostalAddress.Type.WORK -> DataMessage.Contact.PostalAddress.Type.WORK + Contact.PostalAddress.Type.CUSTOM -> DataMessage.Contact.PostalAddress.Type.CUSTOM + } +} + +/** + * Strips `incrementalMac` (and its sibling `chunkSize`) from attachments past the [max]th one + * that carries an incremental MAC, mirroring `SignalServiceMessageSender.capIncrementalMacs`. + * [max] <= 0 disables the cap. + */ +private fun List.capIncrementalMacs(max: Int): List { + if (max <= 0) { + return this + } + + val incrementalCount = count { it.incrementalMac != null } + + if (incrementalCount <= max) { + return this + } + + var kept = 0 + return map { pointer -> + if (pointer.incrementalMac == null) { + pointer + } else if (kept < max) { + kept++ + pointer + } else { + pointer.newBuilder().incrementalMac(null).chunkSize(null).build() + } + } +} + +/** + * Whether or not the content should generate a high-priority push notification for the receiver. + */ +fun Content.isUrgent(): Boolean { + dataMessage?.let { return it.isUrgent() } + editMessage?.let { return it.dataMessage?.isUrgent() ?: false } + syncMessage?.let { return it.isUrgent() } + callMessage?.let { return it.isUrgent() } + + return false +} + +private fun DataMessage.isUrgent(): Boolean { + val flagsValue = this.flags ?: 0 + + if (flagsValue and DataMessage.Flags.EXPIRATION_TIMER_UPDATE.value != 0) { + return false + } + + if (flagsValue and DataMessage.Flags.PROFILE_KEY_UPDATE.value != 0) { + return false + } + + return !this.body.isNullOrEmpty() || + this.attachments.isNotEmpty() || + this.sticker != null || + this.reaction != null || + this.quote != null || + this.contact.isNotEmpty() || + this.giftBadge != null || + this.pollCreate != null || + this.pollTerminate != null || + this.pinMessage != null || + this.delete != null || + this.payment?.notification != null +} + +private fun SyncMessage.isUrgent(): Boolean { + if (this.read.isNotEmpty()) { + return true + } + + this.request?.let { req -> + return when (req.type) { + SyncMessage.Request.Type.CONTACTS, SyncMessage.Request.Type.KEYS -> true + else -> false + } + } + + this.callEvent?.let { event -> + return event.event == SyncMessage.CallEvent.Event.ACCEPTED + } + + return false +} + +private fun CallMessage.isUrgent(): Boolean { + if (offer != null) { + return true + } + + if (opaque?.urgency == CallMessage.Opaque.Urgency.HANDLE_IMMEDIATELY) { + return true + } + + return false +} + +sealed interface DataMessageError { + data object MissingParentStory : DataMessageError + data object MissingQuoteAuthorServiceId : DataMessageError + data object MissingPayment : DataMessageError + data object MissingPaymentReceipt : DataMessageError + data object MissingAttachmentRemoteFields : DataMessageError + data object FailedToDecodeAttachmentKey : DataMessageError + data object FailedToDecodeAttachmentSize : DataMessageError + data object FailedToDecodeStickerPackFields : DataMessageError + data object MissingStickerLocator : DataMessageError + data object PinnedMessageInvalidAuthorAci : DataMessageError + data object InvalidGiftBadge : DataMessageError + data object InvalidBodyRange : DataMessageError +} diff --git a/app/src/main/protowire/JobData.proto b/app/src/main/protowire/JobData.proto index 259acba701..de3ab021e2 100644 --- a/app/src/main/protowire/JobData.proto +++ b/app/src/main/protowire/JobData.proto @@ -282,3 +282,7 @@ message AdminDeleteJobData { repeated uint64 recipientIds = 2; uint32 initialRecipientCount = 3; } + +message IndividualSendJobV2Data { + uint64 messageId = 1; +} diff --git a/app/src/test/java/org/thoughtcrime/securesms/dependencies/MockApplicationDependencyProvider.kt b/app/src/test/java/org/thoughtcrime/securesms/dependencies/MockApplicationDependencyProvider.kt index 1a6566cd87..0a2f9fc198 100644 --- a/app/src/test/java/org/thoughtcrime/securesms/dependencies/MockApplicationDependencyProvider.kt +++ b/app/src/test/java/org/thoughtcrime/securesms/dependencies/MockApplicationDependencyProvider.kt @@ -92,6 +92,14 @@ class MockApplicationDependencyProvider : AppDependencies.Provider { return mockk(relaxed = true) } + override fun provideMessageService( + protocolStore: SignalServiceDataStore, + messageApiV2: org.signal.network.api.MessageApiV2, + keysApiV2: org.signal.network.api.KeysApiV2 + ): org.signal.network.service.MessageService { + return mockk(relaxed = true) + } + override fun provideSignalServiceMessageReceiver(pushServiceSocket: PushServiceSocket): SignalServiceMessageReceiver { return mockk(relaxed = true) } diff --git a/core/serialization/build.gradle.kts b/core/serialization/build.gradle.kts index 30b7303c90..50b551ac73 100644 --- a/core/serialization/build.gradle.kts +++ b/core/serialization/build.gradle.kts @@ -27,4 +27,5 @@ dependencies { implementation(libs.kotlinx.serialization.json) implementation(libs.libsignal.client) + api(libs.arrow.core) } diff --git a/core/serialization/src/main/java/org/signal/core/util/serialization/SignalJson.kt b/core/serialization/src/main/java/org/signal/core/util/serialization/SignalJson.kt new file mode 100644 index 0000000000..6ad2913c3d --- /dev/null +++ b/core/serialization/src/main/java/org/signal/core/util/serialization/SignalJson.kt @@ -0,0 +1,66 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.signal.core.util.serialization + +import arrow.core.Either +import arrow.core.raise.either +import kotlinx.serialization.DeserializationStrategy +import kotlinx.serialization.SerializationException +import kotlinx.serialization.SerializationStrategy +import kotlinx.serialization.json.Json + +/** + * Helper for working with JSON. + */ +object SignalJson { + + val json = Json { ignoreUnknownKeys = true } + + inline fun encode(input: T): Either = either { + try { + json.encodeToString(input) + } catch (e: SerializationException) { + raise(EncodeError.BadInput(e)) + } + } + + inline fun encode(serializer: SerializationStrategy, input: T): Either = either { + try { + json.encodeToString(serializer, input) + } catch (e: SerializationException) { + raise(EncodeError.BadInput(e)) + } + } + + inline fun decode(input: String): Either = either { + try { + json.decodeFromString(input) + } catch (e: SerializationException) { + raise(DecodeError.BadInput(e)) + } catch (e: IllegalStateException) { + raise(DecodeError.BadClassAssignment(e)) + } + } + + fun decode(deserializer: DeserializationStrategy, input: String): Either = either { + try { + json.decodeFromString(deserializer, input) + } catch (e: SerializationException) { + raise(DecodeError.BadInput(e)) + } catch (e: IllegalStateException) { + raise(DecodeError.BadClassAssignment(e)) + } + } + + sealed class EncodeError(val cause: Exception) { + data class BadInput(val error: SerializationException) : EncodeError(error) + } + + sealed class DecodeError(val cause: Exception) { + data class BadInput(val error: SerializationException) : DecodeError(error) + data class BadClassAssignment(val error: IllegalStateException) : DecodeError(error) + } +} diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 66922d3f12..e14f500b0c 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -200,6 +200,7 @@ dnsjava = "dnsjava:dnsjava:3.6.4" nanohttpd-webserver = { module = "org.nanohttpd:nanohttpd-webserver", version.ref = "nanohttpd" } nanohttpd-websocket = { module = "org.nanohttpd:nanohttpd-websocket", version.ref = "nanohttpd" } kotlinx-collections-immutable = "org.jetbrains.kotlinx:kotlinx-collections-immutable:0.4.0" +arrow-core = "io.arrow-kt:arrow-core:2.2.2.1" # Can't use the newest version because it hits some weird NoClassDefFoundException jknack-handlebars = "com.github.jknack:handlebars:4.0.7" diff --git a/gradle/verification-metadata.xml b/gradle/verification-metadata.xml index 1f1dc06fb8..e194769df3 100644 --- a/gradle/verification-metadata.xml +++ b/gradle/verification-metadata.xml @@ -15380,6 +15380,102 @@ https://docs.gradle.org/current/userguide/dependency_verification.html + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + @@ -16845,6 +16941,11 @@ https://docs.gradle.org/current/userguide/dependency_verification.html + + + + + @@ -16873,6 +16974,11 @@ https://docs.gradle.org/current/userguide/dependency_verification.html + + + + + diff --git a/lib/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageSender.java b/lib/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageSender.java index 425c3be6b8..602e5ea811 100644 --- a/lib/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageSender.java +++ b/lib/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageSender.java @@ -903,7 +903,7 @@ public class SignalServiceMessageSender { return sendMessage(address, sealedSenderAccess, System.currentTimeMillis(), envelopeContent, false, null, null, false, false); } - private PniSignatureMessage createPniSignatureMessage() { + public PniSignatureMessage createPniSignatureMessage() { byte[] signature = localPniIdentity.signAlternateIdentity(aciStore.getIdentityKeyPair().getPublicKey()); return new PniSignatureMessage.Builder() diff --git a/lib/network/build.gradle.kts b/lib/network/build.gradle.kts index 679c3649bd..fd7c5687c5 100644 --- a/lib/network/build.gradle.kts +++ b/lib/network/build.gradle.kts @@ -11,6 +11,7 @@ plugins { id("org.jetbrains.kotlin.jvm") id("idea") id("org.jlleitschuh.gradle.ktlint") + alias(libs.plugins.kotlinx.serialization) } java { @@ -40,21 +41,20 @@ tasks.whenTaskAdded { dependencies { api(project(":lib:libsignal-service")) + api(project(":core:network")) + implementation(project(":core:util-jvm")) + implementation(project(":core:models-jvm")) + implementation(project(":core:serialization")) implementation(libs.libsignal.client) - api(libs.square.okhttp3) api(libs.square.okio) api(libs.rxjava3.rxjava) - implementation(libs.rxjava3.rxkotlin) implementation(libs.kotlin.stdlib.jdk8) implementation(libs.kotlinx.coroutines.core) implementation(libs.kotlinx.coroutines.core.jvm) - - api(project(":core:network")) - implementation(project(":core:util-jvm")) - implementation(project(":core:models-jvm")) + implementation(libs.kotlinx.serialization.json) testImplementation(testLibs.junit.junit) testImplementation(testLibs.assertk) diff --git a/lib/network/src/main/java/org/signal/network/api/KeysApiV2.kt b/lib/network/src/main/java/org/signal/network/api/KeysApiV2.kt new file mode 100644 index 0000000000..9ce019069b --- /dev/null +++ b/lib/network/src/main/java/org/signal/network/api/KeysApiV2.kt @@ -0,0 +1,137 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.signal.network.api + +import kotlinx.serialization.Serializable +import org.signal.core.util.serialization.ByteArrayToBase64Serializer +import org.signal.core.util.serialization.SignalJson +import org.signal.libsignal.net.BadRequestError +import org.signal.libsignal.net.RequestResult +import org.signal.network.websocket.WebSocketRequestMessage +import org.signal.network.websocket.get +import org.whispersystems.signalservice.api.crypto.SealedSenderAccess +import org.whispersystems.signalservice.api.websocket.SignalWebSocket +import java.io.IOException +import kotlin.time.Duration + +/** + * Prekey endpoints. Uses [RequestResult] and kotlinx-serializable DTOs; no jackson, no libsignal-service response types. + */ +class KeysApiV2( + private val authWebSocket: SignalWebSocket.AuthenticatedWebSocket, + private val unauthWebSocket: SignalWebSocket.UnauthenticatedWebSocket +) { + /** + * Fetch prekeys for a specific device. + * + * GET /v2/keys/[identifier]/[deviceId] + * - 200: Success + * - 401: Unauthorized + * - 404: No keys found for address/device + * - 429: Rate limited + */ + suspend fun getPreKey( + identifier: String, + deviceId: Int, + sealedSenderAccess: SealedSenderAccess? + ): RequestResult { + return getPreKeysBySpecifier(identifier, deviceId.toString(), sealedSenderAccess) + } + + /** + * Fetch prekeys for all of the recipient's devices. (Server returns a bundle per device.) + * + * Wildcard device specifier: `GET /v2/keys/{identifier}/{asterisk}` + */ + suspend fun getPreKeysForAllDevices( + identifier: String, + sealedSenderAccess: SealedSenderAccess? + ): RequestResult { + return getPreKeysBySpecifier(identifier, "*", sealedSenderAccess) + } + + private suspend fun getPreKeysBySpecifier( + identifier: String, + deviceSpecifier: String, + sealedSenderAccess: SealedSenderAccess? + ): RequestResult { + val request = WebSocketRequestMessage.get("/v2/keys/$identifier/$deviceSpecifier") + + return try { + val response = if (sealedSenderAccess != null) { + unauthWebSocket.requestSuspend(request, sealedSenderAccess) + } else { + authWebSocket.requestSuspend(request) + } + + when (response.status) { + 200 -> SignalJson.decode(PreKeyResponse.serializer(), response.body).fold( + ifLeft = { RequestResult.ApplicationError(it.cause) }, + ifRight = { RequestResult.Success(it) } + ) + 401 -> RequestResult.NonSuccess(GetPreKeysError.Unauthorized) + 404 -> RequestResult.NonSuccess(GetPreKeysError.NotFound) + 429 -> RequestResult.NonSuccess(GetPreKeysError.RateLimited(response.retryAfter())) + else -> RequestResult.ApplicationError(IllegalStateException("Unexpected response code: ${response.status}")) + } + } catch (e: IOException) { + RequestResult.RetryableNetworkError(e) + } catch (e: Throwable) { + RequestResult.ApplicationError(e) + } + } + + /** + * Full prekey bundle for a recipient, including the shared identity key and one entry per device. + * Wire format for key/signature fields is base64; [ByteArrayToBase64Serializer] handles the conversion. + */ + @Serializable + class PreKeyResponse( + @Serializable(with = ByteArrayToBase64Serializer::class) + val identityKey: ByteArray, + val devices: List = emptyList() + ) + + @Serializable + data class PreKeyResponseItem( + val deviceId: Int, + val registrationId: Int, + val signedPreKey: SignedPreKey? = null, + val preKey: PreKey? = null, + val pqPreKey: KyberPreKey? = null + ) + + @Serializable + class PreKey( + val keyId: Long, + @Serializable(with = ByteArrayToBase64Serializer::class) + val publicKey: ByteArray + ) + + @Serializable + class SignedPreKey( + val keyId: Long, + @Serializable(with = ByteArrayToBase64Serializer::class) + val publicKey: ByteArray, + @Serializable(with = ByteArrayToBase64Serializer::class) + val signature: ByteArray + ) + + @Serializable + class KyberPreKey( + val keyId: Long, + @Serializable(with = ByteArrayToBase64Serializer::class) + val publicKey: ByteArray, + @Serializable(with = ByteArrayToBase64Serializer::class) + val signature: ByteArray + ) + + sealed interface GetPreKeysError : BadRequestError { + data object Unauthorized : GetPreKeysError + data object NotFound : GetPreKeysError + data class RateLimited(val retryAfter: Duration?) : GetPreKeysError + } +} diff --git a/lib/network/src/main/java/org/signal/network/api/MessageApiV2.kt b/lib/network/src/main/java/org/signal/network/api/MessageApiV2.kt new file mode 100644 index 0000000000..87c89ffaec --- /dev/null +++ b/lib/network/src/main/java/org/signal/network/api/MessageApiV2.kt @@ -0,0 +1,160 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.signal.network.api + +import arrow.core.getOrElse +import kotlinx.serialization.Serializable +import kotlinx.serialization.Transient +import org.signal.core.util.serialization.SignalJson +import org.signal.libsignal.net.BadRequestError +import org.signal.libsignal.net.RequestResult +import org.signal.network.websocket.WebSocketRequestMessage +import org.signal.network.websocket.put +import org.whispersystems.signalservice.api.crypto.SealedSenderAccess +import org.whispersystems.signalservice.api.websocket.SignalWebSocket +import java.io.IOException +import kotlin.time.Duration + +/** + * Collection of message-related endpoints. + */ +class MessageApiV2( + private val authWebSocket: SignalWebSocket.AuthenticatedWebSocket, + private val unauthWebSocket: SignalWebSocket.UnauthenticatedWebSocket +) { + /** + * Sends a message to a single recipient. Uses the unauthenticated websocket if [sealedSenderAccess] is provided, + * and the authenticated websocket otherwise. + * + * PUT /v1/messages/[destination]?story=[story] + * - 200: Success + * - 401: Authorization or [sealedSenderAccess] is missing or incorrect + * - 404: Recipient is not a registered Signal user + * - 409: Mismatched devices for the recipient + * - 410: Stale devices for some recipient devices + * - 428: Sender must complete a challenge before proceeding + * - 508: Server rejected the message + */ + suspend fun sendMessage( + destination: String, + messageList: SendMessageRequest, + sealedSenderAccess: SealedSenderAccess?, + story: Boolean + ): RequestResult { + val requestBody = SignalJson.encode(SendMessageRequest.serializer(), messageList).getOrElse { return RequestResult.ApplicationError(it.cause) } + val request = WebSocketRequestMessage.put("/v1/messages/$destination?story=$story", requestBody) + + return try { + val response = if (sealedSenderAccess == null) { + authWebSocket.requestSuspend(request) + } else { + unauthWebSocket.requestSuspend(request, sealedSenderAccess) + } + + when (response.status) { + 200 -> { + SignalJson + .decode(SendMessageResponse.serializer(), response.body) + .map { it.copy(sentUnidentified = response.isUnidentified) } + .fold( + ifLeft = { RequestResult.ApplicationError(it.cause) }, + ifRight = { RequestResult.Success(it) } + ) + } + 401 -> { + RequestResult.NonSuccess(SendMessageError.Unauthorized) + } + 404 -> { + RequestResult.NonSuccess(SendMessageError.NotRegistered) + } + 409 -> { + SignalJson + .decode(MismatchedDevices.serializer(), response.body) + .fold( + ifLeft = { RequestResult.ApplicationError(it.cause) }, + ifRight = { RequestResult.NonSuccess(SendMessageError.MismatchedDevicesError(it)) } + ) + } + 410 -> { + SignalJson + .decode(StaleDevices.serializer(), response.body) + .fold( + ifLeft = { RequestResult.ApplicationError(it.cause) }, + ifRight = { RequestResult.NonSuccess(SendMessageError.StaleDevicesError(it)) } + ) + } + 428 -> { + SignalJson + .decode(ProofRequiredResponseBody.serializer(), response.body) + .fold( + ifLeft = { RequestResult.ApplicationError(it.cause) }, + ifRight = { RequestResult.NonSuccess(SendMessageError.ChallengeRequired(it.token, it.options, response.retryAfter())) } + ) + } + 429 -> RequestResult.NonSuccess(SendMessageError.RateLimited(response.retryAfter())) + 508 -> RequestResult.NonSuccess(SendMessageError.ServerRejected) + else -> RequestResult.ApplicationError(IllegalStateException("Unexpected response code: ${response.status}")) + } + } catch (e: IOException) { + RequestResult.RetryableNetworkError(e) + } catch (e: Throwable) { + RequestResult.ApplicationError(e) + } + } + + @Serializable + data class SendMessageRequest( + val messages: List, + val timestamp: Long, + val online: Boolean = false, + val urgent: Boolean = true + ) + + @Serializable + data class Message( + val type: Int, + val destinationDeviceId: Int, + val destinationRegistrationId: Int, + val content: String + ) + + @Serializable + data class SendMessageResponse( + val needsSync: Boolean = false, + @Transient val sentUnidentified: Boolean = false + ) + + @Serializable + data class MismatchedDevices( + val missingDevices: List = emptyList(), + val extraDevices: List = emptyList() + ) + + @Serializable + data class StaleDevices( + val staleDevices: List = emptyList() + ) + + /** + * Body of a 428 response. [token] is the proof-required challenge token; [options] is the + * list of supported challenge mechanisms (e.g. "captcha", "pushChallenge"). + */ + @Serializable + private data class ProofRequiredResponseBody( + val token: String, + val options: List = emptyList() + ) + + sealed class SendMessageError : BadRequestError { + data object Unauthorized : SendMessageError() + data object NotRegistered : SendMessageError() + data class MismatchedDevicesError(val devices: MismatchedDevices) : SendMessageError() + data class StaleDevicesError(val devices: StaleDevices) : SendMessageError() + data class ChallengeRequired(val token: String, val options: List, val retryAfter: Duration?) : SendMessageError() + data class RateLimited(val retryAfter: Duration?) : SendMessageError() + data object ServerRejected : SendMessageError() + } +} diff --git a/lib/network/src/main/java/org/signal/network/api/WebsocketResponseExtensions.kt b/lib/network/src/main/java/org/signal/network/api/WebsocketResponseExtensions.kt new file mode 100644 index 0000000000..bc32dbec82 --- /dev/null +++ b/lib/network/src/main/java/org/signal/network/api/WebsocketResponseExtensions.kt @@ -0,0 +1,19 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.signal.network.api + +import org.signal.network.websocket.WebsocketResponse +import kotlin.time.Duration +import kotlin.time.Duration.Companion.seconds + +/** + * Parses the `Retry-After` header as a whole number of seconds. Returns null if the header is + * absent or can't be parsed (e.g. HTTP-date form, which the server does not currently use). + */ +internal fun WebsocketResponse.retryAfter(): Duration? { + val raw = getHeader("retry-after") ?: return null + return raw.toLongOrNull()?.seconds +} diff --git a/lib/network/src/main/java/org/signal/network/service/MessageService.kt b/lib/network/src/main/java/org/signal/network/service/MessageService.kt new file mode 100644 index 0000000000..f836cb10d8 --- /dev/null +++ b/lib/network/src/main/java/org/signal/network/service/MessageService.kt @@ -0,0 +1,320 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.signal.network.service + +import arrow.core.Either +import arrow.core.raise.Raise +import arrow.core.raise.either +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.withContext +import org.jetbrains.annotations.VisibleForTesting +import org.signal.core.util.logging.Log +import org.signal.libsignal.net.RequestResult +import org.signal.libsignal.protocol.IdentityKey +import org.signal.libsignal.protocol.InvalidKeyException +import org.signal.libsignal.protocol.SessionBuilder +import org.signal.libsignal.protocol.SignalProtocolAddress +import org.signal.libsignal.protocol.UntrustedIdentityException +import org.signal.libsignal.protocol.ecc.ECPublicKey +import org.signal.libsignal.protocol.kem.KEMPublicKey +import org.signal.libsignal.protocol.state.PreKeyBundle +import org.signal.network.api.KeysApiV2 +import org.signal.network.api.MessageApiV2 +import org.whispersystems.signalservice.api.SignalServiceAccountDataStore +import org.whispersystems.signalservice.api.SignalSessionLock +import org.whispersystems.signalservice.api.crypto.EnvelopeContent +import org.whispersystems.signalservice.api.crypto.SealedSenderAccess +import org.whispersystems.signalservice.api.crypto.SignalServiceCipher +import org.whispersystems.signalservice.api.crypto.SignalSessionBuilder +import org.whispersystems.signalservice.api.push.SignalServiceAddress +import org.whispersystems.signalservice.internal.push.OutgoingPushMessage +import java.io.IOException +import kotlin.time.Duration + +/** + * Sends an [EnvelopeContent] to a single recipient, driving the full one-to-one flow: + * encrypt-per-device, send, recover mismatched / stale devices by fetching prekeys and rebuilding sessions. + * + * All server interaction is delegated to [MessageApiV2] and [KeysApiV2]. Encryption is delegated to + * [cipher]. Session state is read from (and archived via) [protocolStore] under [sessionLock]. + * + * Internal helpers return [Either] of [SendError] so orchestration is driven entirely by return + * values rather than exceptions. Libsignal's checked exceptions (from `cipher.encrypt` and session + * building) are caught at the single point they can be raised and `raise`d into the matching + * [SendError] variant. + * + * Sync transcripts are the caller's responsibility — issue a second [sendMessage] to the local address + * with a SyncMessage.Sent payload after a successful primary send. + */ +open class MessageService( + private val localAddress: SignalServiceAddress, + private val localDeviceId: Int, + private val messageApi: MessageApiV2, + private val keysApi: KeysApiV2, + private val protocolStore: SignalServiceAccountDataStore, + private val sessionLock: SignalSessionLock, + private val cipher: SignalServiceCipher, + private val maxContentSizeBytes: Long = 0L +) { + + companion object { + private val TAG = Log.tag(MessageService::class) + + private const val MAX_DEVICE_RECOVERY_ATTEMPTS = 3 + } + + private val localProtocolAddress: SignalProtocolAddress = SignalProtocolAddress(localAddress.identifier, localDeviceId) + + /** + * Sends [envelopeContent] to [recipient]. Handles things like establishing sessions with newly-discovered linked devices. + */ + suspend fun sendMessage( + recipient: SignalServiceAddress, + envelopeContent: EnvelopeContent, + timestamp: Long, + sealedSenderAccess: SealedSenderAccess?, + story: Boolean, + isOnline: Boolean, + urgent: Boolean = true, + onEncrypted: (() -> Unit)? = null + ): Either = withContext(Dispatchers.IO) { + either { + val contentSize = envelopeContent.size().toLong() + if (maxContentSizeBytes > 0 && contentSize > maxContentSizeBytes) { + Log.w(TAG, "Content size $contentSize exceeds limit of $maxContentSizeBytes bytes; aborting send.") + raise(SendError.ContentTooLarge(size = contentSize, maxAllowed = maxContentSizeBytes)) + } + + var encryptedReported = false + + // Certain errors self-resolve by mutating external state, like creating new sessions. + // Trying several times in a loop lets us re-read that external state and use it in the next attempt. + for (attempt in 0 until MAX_DEVICE_RECOVERY_ATTEMPTS) { + val encrypted = encryptForAllDevices(recipient, envelopeContent, sealedSenderAccess) + + if (!encryptedReported) { + onEncrypted?.invoke() + encryptedReported = true + } + + val request = MessageApiV2.SendMessageRequest( + messages = encrypted.map { it.toWireMessage() }, + timestamp = timestamp, + online = isOnline, + urgent = urgent + ) + + when (val result = messageApi.sendMessage(recipient.identifier, request, sealedSenderAccess, story)) { + is RequestResult.Success -> { + val response = result.result + val devices = encrypted.map { it.destinationDeviceId } + return@either SendSuccess(envelopeContent = envelopeContent, sentUnidentified = response.sentUnidentified, devices = devices) + } + is RequestResult.NonSuccess -> when (val err = result.error) { + is MessageApiV2.SendMessageError.MismatchedDevicesError -> { + handleMismatched(recipient, err.devices, sealedSenderAccess) + } + is MessageApiV2.SendMessageError.StaleDevicesError -> { + for (deviceId in err.devices.staleDevices) { + protocolStore.archiveSession(SignalProtocolAddress(recipient.identifier, deviceId)) + } + } + MessageApiV2.SendMessageError.Unauthorized -> raise(SendError.Unauthorized) + MessageApiV2.SendMessageError.NotRegistered -> raise(SendError.NotRegistered) + is MessageApiV2.SendMessageError.ChallengeRequired -> raise(SendError.ChallengeRequired(err.token, err.options, err.retryAfter)) + MessageApiV2.SendMessageError.ServerRejected -> raise(SendError.ServerRejected) + is MessageApiV2.SendMessageError.RateLimited -> raise(SendError.RateLimited(err.retryAfter)) + } + is RequestResult.RetryableNetworkError -> raise(SendError.NetworkError(result.networkError)) + is RequestResult.ApplicationError -> raise(SendError.ApplicationError(result.cause)) + } + } + + Log.w(TAG, "Exhausted device-recovery attempts for ${recipient.identifier}") + raise(SendError.SessionAttemptsExhausted) + } + } + + private fun Raise.encryptForAllDevices( + recipient: SignalServiceAddress, + envelopeContent: EnvelopeContent, + sealedSenderAccess: SealedSenderAccess? + ): List { + return targetDeviceIds(recipient).map { deviceId -> + val address = SignalProtocolAddress(recipient.identifier, deviceId) + encryptContent(recipient, address, envelopeContent, sealedSenderAccess) + } + } + + private fun Raise.encryptContent( + recipient: SignalServiceAddress, + address: SignalProtocolAddress, + envelopeContent: EnvelopeContent, + sealedSenderAccess: SealedSenderAccess? + ): OutgoingPushMessage = try { + cipher.encrypt(address, sealedSenderAccess, envelopeContent) + } catch (e: UntrustedIdentityException) { + raise(SendError.IdentityMismatch(recipient, e)) + } catch (e: InvalidKeyException) { + raise(SendError.ApplicationError(e)) + } + + private fun targetDeviceIds(recipient: SignalServiceAddress): List { + val subDevices: MutableSet = (protocolStore.getSubDeviceSessions(recipient.identifier) + SignalServiceAddress.DEFAULT_DEVICE_ID).toMutableSet() + + // When sending to self, skip our own device. + if (recipient.matches(localAddress)) { + subDevices -= localDeviceId + } + + return subDevices + .filter { it == SignalServiceAddress.DEFAULT_DEVICE_ID || protocolStore.containsSession(SignalProtocolAddress(recipient.identifier, it)) } + .toList() + } + + /** + * Initialize a session with the target address, which requires fetching a prekey bundle. + */ + @VisibleForTesting + internal open suspend fun Raise.initializeSession( + recipient: SignalServiceAddress, + address: SignalProtocolAddress, + sealedSenderAccess: SealedSenderAccess? + ) { + val response = when (val result = keysApi.getPreKey(address.serviceId.toServiceIdString(), address.deviceId, sealedSenderAccess)) { + is RequestResult.Success -> result.result + is RequestResult.NonSuccess -> { + when (val e = result.error) { + KeysApiV2.GetPreKeysError.Unauthorized -> raise(SendError.Unauthorized) + KeysApiV2.GetPreKeysError.NotFound -> raise(SendError.PreKeyUnavailable("No prekeys found for $address")) + is KeysApiV2.GetPreKeysError.RateLimited -> raise(SendError.RateLimited(e.retryAfter)) + } + } + is RequestResult.RetryableNetworkError -> raise(SendError.NetworkError(result.networkError)) + is RequestResult.ApplicationError -> raise(SendError.ApplicationError(result.cause)) + } + + val item = response.devices.firstOrNull { it.deviceId == address.deviceId } + ?: raise(SendError.PreKeyUnavailable("No prekey for $address")) + + val bundle = buildPreKeyBundle(response.identityKey, item, address) + + try { + SignalSessionBuilder(sessionLock, SessionBuilder(protocolStore, address, localProtocolAddress)).process(bundle) + } catch (e: UntrustedIdentityException) { + raise(SendError.IdentityMismatch(recipient, e)) + } catch (e: InvalidKeyException) { + raise(SendError.ApplicationError(e)) + } + } + + private suspend fun Raise.handleMismatched( + recipient: SignalServiceAddress, + mismatched: MessageApiV2.MismatchedDevices, + sealedSenderAccess: SealedSenderAccess? + ) { + for (extra in mismatched.extraDevices) { + protocolStore.archiveSession(SignalProtocolAddress(recipient.identifier, extra)) + } + + for (missing in mismatched.missingDevices) { + val address = SignalProtocolAddress(recipient.identifier, missing) + initializeSession(recipient, address, sealedSenderAccess) + } + } + + private fun OutgoingPushMessage.toWireMessage(): MessageApiV2.Message = MessageApiV2.Message( + type = type, + destinationDeviceId = destinationDeviceId, + destinationRegistrationId = destinationRegistrationId, + content = content + ) + + private fun Raise.buildPreKeyBundle( + identityKey: ByteArray, + item: KeysApiV2.PreKeyResponseItem, + address: SignalProtocolAddress + ): PreKeyBundle { + val signedPreKey = item.signedPreKey ?: raise(SendError.PreKeyUnavailable("No signed prekey for $address")) + val kyberPreKey = item.pqPreKey ?: raise(SendError.PreKeyUnavailable("No kyber prekey for $address")) + + return try { + PreKeyBundle( + item.registrationId, + item.deviceId, + item.preKey?.keyId?.toInt() ?: PreKeyBundle.NULL_PRE_KEY_ID, + item.preKey?.let { ECPublicKey(it.publicKey) }, + signedPreKey.keyId.toInt(), + ECPublicKey(signedPreKey.publicKey), + signedPreKey.signature, + IdentityKey(identityKey), + kyberPreKey.keyId.toInt(), + KEMPublicKey(kyberPreKey.publicKey, 0, kyberPreKey.publicKey.size), + kyberPreKey.signature + ) + } catch (e: InvalidKeyException) { + raise(SendError.ApplicationError(e)) + } + } + + /** + * Send completed successfully. + * + * [devices] is the set of recipient devices the encrypted payload was delivered to. Callers persisting + * a [org.thoughtcrime.securesms.database.MessageSendLogTables] entry (or a pending PNI signature record) + * need this to know which sessions the recipient may later reference in a retry receipt. + */ + data class SendSuccess( + val envelopeContent: EnvelopeContent, + val sentUnidentified: Boolean, + val devices: List + ) + + sealed interface SendError { + /** You discovered a safety number change during sending. */ + data class IdentityMismatch(val recipient: SignalServiceAddress, val cause: UntrustedIdentityException) : SendError + + /** The recipient is no longer registered. */ + data object NotRegistered : SendError + + /** Invalid credentials. You are likely no longer registered. */ + data object Unauthorized : SendError + + /** + * The server wants you to complete a push challenge/captcha before continuing. + * [token] is the challenge token; [options] enumerates the supported challenge mechanisms + * (e.g. "captcha", "pushChallenge"). [retryAfter] is the Retry-After hint, if provided. + */ + data class ChallengeRequired(val token: String, val options: List, val retryAfter: Duration?) : SendError + + /** The server has fully rejected your request. This usually only happens during times of turmoil. Fail and require user action to resend. */ + data object ServerRejected : SendError + + /** + * The encoded content exceeded the configured size cap. Permanent failure for this message — + * retrying with the same content won't help. + */ + data class ContentTooLarge(val size: Long, val maxAllowed: Long) : SendError + + /** + * Each send attempt may result in us having to establish sessions with linked devices and such. This indicates that we hit our max attempt count while + * trying to handle these situations. It should be safe to retry with normal backoff. + */ + data object SessionAttemptsExhausted : SendError + + /** We needed to establish a session, but the server was missing either a signed or kyber prekey for the user. */ + data class PreKeyUnavailable(val reason: String) : SendError + + /** You're rate-limited. Use the [retryAfter] for your backoff. */ + data class RateLimited(val retryAfter: Duration?) : SendError + + /** A generic, retryable network error. */ + data class NetworkError(val cause: IOException) : SendError + + /** An unexpected error. You should likely crash. */ + data class ApplicationError(val cause: Throwable) : SendError + } +} diff --git a/lib/network/src/test/java/org/signal/network/api/MessageApiV2Test.kt b/lib/network/src/test/java/org/signal/network/api/MessageApiV2Test.kt new file mode 100644 index 0000000000..38278c46f6 --- /dev/null +++ b/lib/network/src/test/java/org/signal/network/api/MessageApiV2Test.kt @@ -0,0 +1,196 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.signal.network.api + +import assertk.assertThat +import assertk.assertions.isEqualTo +import assertk.assertions.isInstanceOf +import assertk.assertions.isSameInstanceAs +import io.mockk.coEvery +import io.mockk.every +import io.mockk.mockk +import kotlinx.coroutines.test.runTest +import org.junit.Test +import org.signal.libsignal.net.RequestResult +import org.signal.network.websocket.WebSocketRequestMessage +import org.signal.network.websocket.WebsocketResponse +import org.whispersystems.signalservice.api.crypto.SealedSenderAccess +import org.whispersystems.signalservice.api.websocket.SignalWebSocket +import java.io.IOException +import kotlin.time.Duration.Companion.seconds + +class MessageApiV2Test { + + private val authSocket: SignalWebSocket.AuthenticatedWebSocket = mockk() + private val unauthSocket: SignalWebSocket.UnauthenticatedWebSocket = mockk() + private val api = MessageApiV2(authSocket, unauthSocket) + + private val request = MessageApiV2.SendMessageRequest( + messages = listOf(MessageApiV2.Message(type = 1, destinationDeviceId = 1, destinationRegistrationId = 42, content = "abc")), + timestamp = 1_700_000_000L + ) + + @Test + fun `200 parses SendMessageResponse and flags sentUnidentified from response`() = runTest { + stubAuth(status = 200, body = """{"needsSync": true}""", unidentified = true) + + val result = api.sendMessage("destination-id", request, sealedSenderAccess = null, story = false) + + assertThat(result).isInstanceOf(RequestResult.Success::class) + val success = result as RequestResult.Success + assertThat(success.result.needsSync).isEqualTo(true) + assertThat(success.result.sentUnidentified).isEqualTo(true) + } + + @Test + fun `401 maps to Unauthorized`() = runTest { + stubAuth(status = 401) + + val result = api.sendMessage("destination-id", request, sealedSenderAccess = null, story = false) + + assertNonSuccess(result, MessageApiV2.SendMessageError.Unauthorized) + } + + @Test + fun `404 maps to NotRegistered`() = runTest { + stubAuth(status = 404) + + val result = api.sendMessage("destination-id", request, sealedSenderAccess = null, story = false) + + assertNonSuccess(result, MessageApiV2.SendMessageError.NotRegistered) + } + + @Test + fun `409 parses MismatchedDevices body`() = runTest { + stubAuth(status = 409, body = """{"missingDevices": [2, 3], "extraDevices": [5]}""") + + val result = api.sendMessage("destination-id", request, sealedSenderAccess = null, story = false) + + val nonSuccess = result as RequestResult.NonSuccess + val err = nonSuccess.error as MessageApiV2.SendMessageError.MismatchedDevicesError + assertThat(err.devices.missingDevices).isEqualTo(listOf(2, 3)) + assertThat(err.devices.extraDevices).isEqualTo(listOf(5)) + } + + @Test + fun `410 parses StaleDevices body`() = runTest { + stubAuth(status = 410, body = """{"staleDevices": [2]}""") + + val result = api.sendMessage("destination-id", request, sealedSenderAccess = null, story = false) + + val nonSuccess = result as RequestResult.NonSuccess + val err = nonSuccess.error as MessageApiV2.SendMessageError.StaleDevicesError + assertThat(err.devices.staleDevices).isEqualTo(listOf(2)) + } + + @Test + fun `428 parses ProofRequired body and Retry-After header`() = runTest { + val response: WebsocketResponse = mockk { + every { status } returns 428 + every { body } returns """{"token": "abc123", "options": ["captcha", "pushChallenge"]}""" + every { isUnidentified } returns false + every { getHeader("retry-after") } returns "120" + } + coEvery { authSocket.requestSuspend(any()) } returns response + + val result = api.sendMessage("destination-id", request, sealedSenderAccess = null, story = false) + + val err = (result as RequestResult.NonSuccess).error as MessageApiV2.SendMessageError.ChallengeRequired + assertThat(err.token).isEqualTo("abc123") + assertThat(err.options).isEqualTo(listOf("captcha", "pushChallenge")) + assertThat(err.retryAfter).isEqualTo(120.seconds) + } + + @Test + fun `429 with retry-after header maps to RateLimited with Duration`() = runTest { + val response: WebsocketResponse = mockk { + every { status } returns 429 + every { body } returns "{}" + every { isUnidentified } returns false + every { getHeader("retry-after") } returns "42" + } + coEvery { authSocket.requestSuspend(any()) } returns response + + val result = api.sendMessage("destination-id", request, sealedSenderAccess = null, story = false) + + val err = (result as RequestResult.NonSuccess).error as MessageApiV2.SendMessageError.RateLimited + assertThat(err.retryAfter).isEqualTo(42.seconds) + } + + @Test + fun `429 without retry-after header maps to RateLimited with null Duration`() = runTest { + val response: WebsocketResponse = mockk { + every { status } returns 429 + every { body } returns "{}" + every { isUnidentified } returns false + every { getHeader("retry-after") } returns null + } + coEvery { authSocket.requestSuspend(any()) } returns response + + val result = api.sendMessage("destination-id", request, sealedSenderAccess = null, story = false) + + val err = (result as RequestResult.NonSuccess).error as MessageApiV2.SendMessageError.RateLimited + assertThat(err.retryAfter).isEqualTo(null) + } + + @Test + fun `508 maps to ServerRejected`() = runTest { + stubAuth(status = 508) + + val result = api.sendMessage("destination-id", request, sealedSenderAccess = null, story = false) + + assertNonSuccess(result, MessageApiV2.SendMessageError.ServerRejected) + } + + @Test + fun `unexpected status maps to ApplicationError`() = runTest { + stubAuth(status = 418) + + val result = api.sendMessage("destination-id", request, sealedSenderAccess = null, story = false) + + assertThat(result).isInstanceOf(RequestResult.ApplicationError::class) + } + + @Test + fun `IOException from socket becomes RetryableNetworkError`() = runTest { + val ioError = IOException("socket closed") + coEvery { authSocket.requestSuspend(any()) } throws ioError + + val result = api.sendMessage("destination-id", request, sealedSenderAccess = null, story = false) + + val retry = result as RequestResult.RetryableNetworkError + assertThat(retry.networkError).isSameInstanceAs(ioError) + } + + @Test + fun `sealedSenderAccess routes to unauthenticated socket`() = runTest { + val sealed: SealedSenderAccess = mockk() + val response: WebsocketResponse = mockk { + every { status } returns 200 + every { body } returns """{"needsSync": false}""" + every { isUnidentified } returns true + } + coEvery { unauthSocket.requestSuspend(any(), sealed) } returns response + + val result = api.sendMessage("destination-id", request, sealedSenderAccess = sealed, story = false) + + assertThat(result).isInstanceOf(RequestResult.Success::class) + } + + private fun stubAuth(status: Int, body: String = "{}", unidentified: Boolean = false) { + val response: WebsocketResponse = mockk { + every { this@mockk.status } returns status + every { this@mockk.body } returns body + every { isUnidentified } returns unidentified + } + coEvery { authSocket.requestSuspend(any()) } returns response + } + + private fun assertNonSuccess(result: RequestResult<*, *>, expected: MessageApiV2.SendMessageError) { + val nonSuccess = result as RequestResult.NonSuccess + assertThat(nonSuccess.error).isEqualTo(expected) + } +} diff --git a/lib/network/src/test/java/org/signal/network/service/MessageServiceTest.kt b/lib/network/src/test/java/org/signal/network/service/MessageServiceTest.kt new file mode 100644 index 0000000000..d550a5e0ec --- /dev/null +++ b/lib/network/src/test/java/org/signal/network/service/MessageServiceTest.kt @@ -0,0 +1,318 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.signal.network.service + +import arrow.core.Either +import arrow.core.raise.Raise +import assertk.assertThat +import assertk.assertions.isEqualTo +import assertk.assertions.isInstanceOf +import io.mockk.coEvery +import io.mockk.coVerify +import io.mockk.every +import io.mockk.mockk +import io.mockk.spyk +import io.mockk.verify +import kotlinx.coroutines.test.runTest +import org.junit.Test +import org.signal.core.models.ServiceId +import org.signal.libsignal.net.RequestResult +import org.signal.libsignal.protocol.SignalProtocolAddress +import org.signal.libsignal.protocol.UntrustedIdentityException +import org.signal.network.api.KeysApiV2 +import org.signal.network.api.MessageApiV2 +import org.whispersystems.signalservice.api.SignalServiceAccountDataStore +import org.whispersystems.signalservice.api.SignalSessionLock +import org.whispersystems.signalservice.api.crypto.EnvelopeContent +import org.whispersystems.signalservice.api.crypto.SealedSenderAccess +import org.whispersystems.signalservice.api.crypto.SignalServiceCipher +import org.whispersystems.signalservice.api.push.SignalServiceAddress +import org.whispersystems.signalservice.internal.push.OutgoingPushMessage +import java.io.IOException +import java.util.UUID +import kotlin.time.Duration.Companion.seconds + +class MessageServiceTest { + + private val messageApi: MessageApiV2 = mockk() + private val keysApi: KeysApiV2 = mockk() + private val protocolStore: SignalServiceAccountDataStore = mockk(relaxUnitFun = true) + private val sessionLock: SignalSessionLock = mockk() + private val cipher: SignalServiceCipher = mockk() + + private val localAci = ServiceId.ACI.from(UUID.fromString("aaaaaaaa-0000-0000-0000-000000000001")) + private val localAddress = SignalServiceAddress(localAci) + + private val recipientAci = ServiceId.ACI.from(UUID.fromString("bbbbbbbb-0000-0000-0000-000000000002")) + private val recipient = SignalServiceAddress(recipientAci) + + private val timestamp = 1_700_000_000L + private val envelopeContent: EnvelopeContent = mockk { + every { size() } returns 0 + } + + @Test + fun `happy path with existing session returns Success`() = runTest { + val service = newService() + every { protocolStore.getSubDeviceSessions(recipient.identifier) } returns emptyList() + every { protocolStore.containsSession(SignalProtocolAddress(recipient.identifier, 1)) } returns true + every { cipher.encrypt(any(), any(), any()) } returns OutgoingPushMessage(1, 1, 100, "payload") + + coEvery { messageApi.sendMessage(recipient.identifier, any(), null, false) } returns + RequestResult.Success(MessageApiV2.SendMessageResponse(sentUnidentified = true)) + + val result = service.sendMessage(recipient, envelopeContent, timestamp, sealedSenderAccess = null, story = false, isOnline = false) + + val success = (result as Either.Right).value + assertThat(success.sentUnidentified).isEqualTo(true) + } + + @Test + fun `isOnline true is forwarded to the send request`() = runTest { + val service = newService() + every { protocolStore.getSubDeviceSessions(recipient.identifier) } returns emptyList() + every { protocolStore.containsSession(SignalProtocolAddress(recipient.identifier, 1)) } returns true + every { cipher.encrypt(any(), any(), any()) } returns OutgoingPushMessage(1, 1, 100, "payload") + coEvery { messageApi.sendMessage(any(), any(), any(), any()) } returns + RequestResult.Success(MessageApiV2.SendMessageResponse()) + + service.sendMessage(recipient, envelopeContent, timestamp, sealedSenderAccess = null, story = false, isOnline = true) + + coVerify { + messageApi.sendMessage( + recipient.identifier, + match { it.online }, + null, + false + ) + } + } + + @Test + fun `sub-device without session is excluded from target devices`() = runTest { + val service = newService() + every { protocolStore.getSubDeviceSessions(recipient.identifier) } returns listOf(2, 3) + every { protocolStore.containsSession(SignalProtocolAddress(recipient.identifier, 2)) } returns true + every { protocolStore.containsSession(SignalProtocolAddress(recipient.identifier, 3)) } returns false + every { cipher.encrypt(any(), any(), any()) } returns OutgoingPushMessage(1, 1, 100, "payload") + + coEvery { messageApi.sendMessage(recipient.identifier, any(), null, false) } returns + RequestResult.Success(MessageApiV2.SendMessageResponse()) + + val result = service.sendMessage(recipient, envelopeContent, timestamp, sealedSenderAccess = null, story = false, isOnline = false) + + assertThat(result).isInstanceOf(Either.Right::class) + verify { cipher.encrypt(SignalProtocolAddress(recipient.identifier, 1), any(), any()) } + verify { cipher.encrypt(SignalProtocolAddress(recipient.identifier, 2), any(), any()) } + verify(exactly = 0) { cipher.encrypt(SignalProtocolAddress(recipient.identifier, 3), any(), any()) } + } + + @Test + fun `409 MismatchedDevices archives extras, fetches missing prekeys, and retries`() = runTest { + val service = newService() + every { protocolStore.getSubDeviceSessions(recipient.identifier) } returns emptyList() + every { protocolStore.containsSession(any()) } returns true + every { cipher.encrypt(any(), any(), any()) } returns OutgoingPushMessage(1, 1, 100, "payload") + + val mismatched = MessageApiV2.MismatchedDevices(missingDevices = listOf(2), extraDevices = listOf(5)) + coEvery { messageApi.sendMessage(recipient.identifier, any(), null, false) } returnsMany listOf( + RequestResult.NonSuccess(MessageApiV2.SendMessageError.MismatchedDevicesError(mismatched)), + RequestResult.Success(MessageApiV2.SendMessageResponse()) + ) + coEvery { keysApi.getPreKey(recipient.identifier, 2, null) } returns + RequestResult.Success(KeysApiV2.PreKeyResponse(identityKey = ByteArray(0), devices = emptyList())) + + val result = service.sendMessage(recipient, envelopeContent, timestamp, sealedSenderAccess = null, story = false, isOnline = false) + + assertThat(result).isInstanceOf(Either.Right::class) + verify { protocolStore.archiveSession(SignalProtocolAddress(recipient.identifier, 5)) } + coVerify { keysApi.getPreKey(recipient.identifier, 2, null) } + } + + @Test + fun `410 StaleDevices archives stales and retries`() = runTest { + val service = newService() + every { protocolStore.getSubDeviceSessions(recipient.identifier) } returns emptyList() + every { protocolStore.containsSession(any()) } returns true + every { cipher.encrypt(any(), any(), any()) } returns OutgoingPushMessage(1, 1, 100, "payload") + + val stale = MessageApiV2.StaleDevices(staleDevices = listOf(3)) + coEvery { messageApi.sendMessage(recipient.identifier, any(), null, false) } returnsMany listOf( + RequestResult.NonSuccess(MessageApiV2.SendMessageError.StaleDevicesError(stale)), + RequestResult.Success(MessageApiV2.SendMessageResponse()) + ) + + val result = service.sendMessage(recipient, envelopeContent, timestamp, sealedSenderAccess = null, story = false, isOnline = false) + + assertThat(result).isInstanceOf(Either.Right::class) + verify { protocolStore.archiveSession(SignalProtocolAddress(recipient.identifier, 3)) } + } + + @Test + fun `repeated device conflicts exhaust retries`() = runTest { + val service = newService() + every { protocolStore.getSubDeviceSessions(recipient.identifier) } returns emptyList() + every { protocolStore.containsSession(any()) } returns true + every { cipher.encrypt(any(), any(), any()) } returns OutgoingPushMessage(1, 1, 100, "payload") + + val stale = MessageApiV2.StaleDevices(staleDevices = listOf(4)) + coEvery { messageApi.sendMessage(recipient.identifier, any(), null, false) } returns + RequestResult.NonSuccess(MessageApiV2.SendMessageError.StaleDevicesError(stale)) + + val result = service.sendMessage(recipient, envelopeContent, timestamp, sealedSenderAccess = null, story = false, isOnline = false) + + assertThat(result).isEqualTo(Either.Left(MessageService.SendError.SessionAttemptsExhausted)) + } + + @Test + fun `401 maps to Unauthorized`() = runTest { + val service = newService() + every { protocolStore.getSubDeviceSessions(recipient.identifier) } returns emptyList() + every { protocolStore.containsSession(any()) } returns true + every { cipher.encrypt(any(), any(), any()) } returns OutgoingPushMessage(1, 1, 100, "payload") + coEvery { messageApi.sendMessage(any(), any(), any(), any()) } returns + RequestResult.NonSuccess(MessageApiV2.SendMessageError.Unauthorized) + + val result = service.sendMessage(recipient, envelopeContent, timestamp, sealedSenderAccess = null, story = false, isOnline = false) + + assertThat(result).isEqualTo(Either.Left(MessageService.SendError.Unauthorized)) + } + + @Test + fun `404 maps to NotRegistered`() = runTest { + val service = newService() + every { protocolStore.getSubDeviceSessions(recipient.identifier) } returns emptyList() + every { protocolStore.containsSession(any()) } returns true + every { cipher.encrypt(any(), any(), any()) } returns OutgoingPushMessage(1, 1, 100, "payload") + coEvery { messageApi.sendMessage(any(), any(), any(), any()) } returns + RequestResult.NonSuccess(MessageApiV2.SendMessageError.NotRegistered) + + val result = service.sendMessage(recipient, envelopeContent, timestamp, sealedSenderAccess = null, story = false, isOnline = false) + + assertThat(result).isEqualTo(Either.Left(MessageService.SendError.NotRegistered)) + } + + @Test + fun `send 429 propagates retry-after duration via SendResult RateLimited`() = runTest { + val service = newService() + every { protocolStore.getSubDeviceSessions(recipient.identifier) } returns emptyList() + every { protocolStore.containsSession(any()) } returns true + every { cipher.encrypt(any(), any(), any()) } returns OutgoingPushMessage(1, 1, 100, "payload") + coEvery { messageApi.sendMessage(any(), any(), any(), any()) } returns + RequestResult.NonSuccess(MessageApiV2.SendMessageError.RateLimited(retryAfter = 30.seconds)) + + val result = service.sendMessage(recipient, envelopeContent, timestamp, sealedSenderAccess = null, story = false, isOnline = false) + + assertThat(result).isEqualTo(Either.Left(MessageService.SendError.RateLimited(retryAfter = 30.seconds))) + } + + @Test + fun `prekey 429 during mismatched-device recovery propagates retry-after as RateLimited`() = runTest { + val service = newService() + every { protocolStore.getSubDeviceSessions(recipient.identifier) } returns emptyList() + every { protocolStore.containsSession(any()) } returns true + every { cipher.encrypt(any(), any(), any()) } returns OutgoingPushMessage(1, 1, 100, "payload") + + val mismatched = MessageApiV2.MismatchedDevices(missingDevices = listOf(2), extraDevices = emptyList()) + coEvery { messageApi.sendMessage(recipient.identifier, any(), null, false) } returns + RequestResult.NonSuccess(MessageApiV2.SendMessageError.MismatchedDevicesError(mismatched)) + coEvery { keysApi.getPreKey(recipient.identifier, 2, null) } returns + RequestResult.NonSuccess(KeysApiV2.GetPreKeysError.RateLimited(retryAfter = 60.seconds)) + + val result = service.sendMessage(recipient, envelopeContent, timestamp, sealedSenderAccess = null, story = false, isOnline = false) + + assertThat(result).isEqualTo(Either.Left(MessageService.SendError.RateLimited(retryAfter = 60.seconds))) + } + + @Test + fun `IOException from send maps to NetworkError`() = runTest { + val service = newService() + every { protocolStore.getSubDeviceSessions(recipient.identifier) } returns emptyList() + every { protocolStore.containsSession(any()) } returns true + every { cipher.encrypt(any(), any(), any()) } returns OutgoingPushMessage(1, 1, 100, "payload") + val ioError = IOException("down") + coEvery { messageApi.sendMessage(any(), any(), any(), any()) } returns RequestResult.RetryableNetworkError(ioError) + + val result = service.sendMessage(recipient, envelopeContent, timestamp, sealedSenderAccess = null, story = false, isOnline = false) + + val network = (result as Either.Left).value as MessageService.SendError.NetworkError + assertThat(network.cause).isEqualTo(ioError) + } + + @Test + fun `UntrustedIdentityException during encryption maps to IdentityMismatch`() = runTest { + val service = newService() + every { protocolStore.getSubDeviceSessions(recipient.identifier) } returns emptyList() + every { protocolStore.containsSession(any()) } returns true + val untrusted = UntrustedIdentityException(recipient.identifier) + every { cipher.encrypt(any(), any(), any()) } throws untrusted + + val result = service.sendMessage(recipient, envelopeContent, timestamp, sealedSenderAccess = null, story = false, isOnline = false) + + val mismatch = (result as Either.Left).value as MessageService.SendError.IdentityMismatch + assertThat(mismatch.cause).isEqualTo(untrusted) + } + + @Test + fun `prekey fetch 404 during mismatched-device recovery propagates as PreKeyUnavailable`() = runTest { + val service = newService() + every { protocolStore.getSubDeviceSessions(recipient.identifier) } returns emptyList() + every { protocolStore.containsSession(any()) } returns true + every { cipher.encrypt(any(), any(), any()) } returns OutgoingPushMessage(1, 1, 100, "payload") + + val mismatched = MessageApiV2.MismatchedDevices(missingDevices = listOf(2), extraDevices = emptyList()) + coEvery { messageApi.sendMessage(recipient.identifier, any(), null, false) } returns + RequestResult.NonSuccess(MessageApiV2.SendMessageError.MismatchedDevicesError(mismatched)) + coEvery { keysApi.getPreKey(recipient.identifier, 2, null) } returns + RequestResult.NonSuccess(KeysApiV2.GetPreKeysError.NotFound) + + val result = service.sendMessage(recipient, envelopeContent, timestamp, sealedSenderAccess = null, story = false, isOnline = false) + + val left = (result as Either.Left).value + assertThat(left).isInstanceOf(MessageService.SendError.PreKeyUnavailable::class) + } + + /** + * Spy with `initializeSession` stubbed so tests don't exercise real crypto / native session building. + * The stub still invokes [KeysApiV2.getPreKey] and forwards non-success [RequestResult]s as the real + * implementation would; happy path is a no-op. + */ + private fun newService(): MessageService { + val spy: MessageService = spyk( + MessageService( + localAddress = localAddress, + localDeviceId = 1, + messageApi = messageApi, + keysApi = keysApi, + protocolStore = protocolStore, + sessionLock = sessionLock, + cipher = cipher + ) + ) + coEvery { + with(spy) { + any>().initializeSession(any(), any(), any()) + } + } coAnswers { + val raiseArg = arg>(0) + val addressArg = arg(2) + val sealedArg = arg(3) + when (val r = keysApi.getPreKey(addressArg.name, addressArg.deviceId, sealedArg)) { + is RequestResult.Success -> Unit + is RequestResult.NonSuccess -> raiseArg.raise( + when (val e = r.error) { + KeysApiV2.GetPreKeysError.Unauthorized -> MessageService.SendError.Unauthorized + KeysApiV2.GetPreKeysError.NotFound -> MessageService.SendError.PreKeyUnavailable("No prekeys found for $addressArg") + is KeysApiV2.GetPreKeysError.RateLimited -> MessageService.SendError.RateLimited(e.retryAfter) + } + ) + is RequestResult.RetryableNetworkError -> raiseArg.raise(MessageService.SendError.NetworkError(r.networkError)) + is RequestResult.ApplicationError -> raiseArg.raise(MessageService.SendError.ApplicationError(r.cause)) + } + } + return spy + } +}