From 66f0470960596850ab4d65d076105aef68be0b3d Mon Sep 17 00:00:00 2001 From: Cody Henthorne Date: Mon, 23 Feb 2026 11:37:11 -0500 Subject: [PATCH] Improve incoming group message processing. --- .../securesms/database/GroupTable.kt | 12 +- .../securesms/database/MessageTable.kt | 14 +- .../securesms/database/ThreadTable.kt | 5 + .../securesms/jobs/PushProcessMessageJob.kt | 18 +-- .../securesms/messages/BatchCache.kt | 127 ++++++++++++++++++ .../messages/DataMessageProcessor.kt | 40 ++++-- .../messages/IncomingMessageObserver.kt | 15 ++- .../messages/MessageContentProcessor.kt | 26 +++- .../GroupMessageProcessingBenchmarks.kt | 14 +- 9 files changed, 222 insertions(+), 49 deletions(-) create mode 100644 app/src/main/java/org/thoughtcrime/securesms/messages/BatchCache.kt diff --git a/app/src/main/java/org/thoughtcrime/securesms/database/GroupTable.kt b/app/src/main/java/org/thoughtcrime/securesms/database/GroupTable.kt index 8547a9ed3b..e67cae4c5f 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/database/GroupTable.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/database/GroupTable.kt @@ -313,18 +313,12 @@ class GroupTable(context: Context?, databaseHelper: SignalDatabase?) : * @return local db group revision or -1 if not present. */ fun getGroupV2Revision(groupId: GroupId.V2): Int { - readableDatabase - .select() + return readableDatabase + .select(V2_REVISION) .from(TABLE_NAME) .where("$GROUP_ID = ?", groupId.toString()) .run() - .use { cursor -> - return if (cursor.moveToNext()) { - cursor.getInt(cursor.getColumnIndexOrThrow(V2_REVISION)) - } else { - -1 - } - } + .readToSingleInt(-1) } fun isUnknownGroup(groupId: GroupId): Boolean { diff --git a/app/src/main/java/org/thoughtcrime/securesms/database/MessageTable.kt b/app/src/main/java/org/thoughtcrime/securesms/database/MessageTable.kt index e7bfd99a8b..1ccc1b731b 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/database/MessageTable.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/database/MessageTable.kt @@ -2800,7 +2800,8 @@ open class MessageTable(context: Context?, databaseHelper: SignalDatabase) : Dat retrieved: IncomingMessage, candidateThreadId: Long = -1, editedMessage: MmsMessageRecord? = null, - notifyObservers: Boolean = true + notifyObservers: Boolean = true, + skipThreadUpdate: Boolean = false ): Optional { val type = retrieved.toMessageType() @@ -2901,7 +2902,7 @@ open class MessageTable(context: Context?, databaseHelper: SignalDatabase) : Dat messageRanges = retrieved.messageRanges, contentValues = contentValues, insertListener = null, - updateThread = updateThread, + updateThread = updateThread && !skipThreadUpdate, unarchive = true, poll = retrieved.poll, pollTerminate = retrieved.messageExtras?.pollTerminate, @@ -2971,7 +2972,8 @@ open class MessageTable(context: Context?, databaseHelper: SignalDatabase) : Dat threadId = threadId, threadWasNewlyCreated = threadIdResult.newlyCreated, insertedAttachments = insertedAttachments, - quoteAttachmentId = quoteAttachments.firstOrNull()?.let { insertedAttachments?.get(it) } + quoteAttachmentId = quoteAttachments.firstOrNull()?.let { insertedAttachments?.get(it) }, + needsThreadUpdate = updateThread && skipThreadUpdate ) ) } @@ -3576,8 +3578,7 @@ open class MessageTable(context: Context?, databaseHelper: SignalDatabase) : Dat val contentValuesThreadId = contentValues.getAsLong(THREAD_ID) if (updateThread) { - threads.setLastScrolled(contentValuesThreadId, 0) - threads.update(threadId, unarchive) + threads.updateForMessageInsert(threadId, unarchive) } if (pinnedMessage != null && pinnedMessage.pinDurationInSeconds != PIN_FOREVER) { @@ -6093,7 +6094,8 @@ open class MessageTable(context: Context?, databaseHelper: SignalDatabase) : Dat val threadId: Long, val threadWasNewlyCreated: Boolean, val insertedAttachments: Map? = null, - val quoteAttachmentId: AttachmentId? = null + val quoteAttachmentId: AttachmentId? = null, + val needsThreadUpdate: Boolean = false ) data class MessageReceiptUpdate( diff --git a/app/src/main/java/org/thoughtcrime/securesms/database/ThreadTable.kt b/app/src/main/java/org/thoughtcrime/securesms/database/ThreadTable.kt index 6e343f5e66..edfec744f9 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/database/ThreadTable.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/database/ThreadTable.kt @@ -1686,6 +1686,11 @@ class ThreadTable(context: Context, databaseHelper: SignalDatabase) : DatabaseTa .run() } + fun updateForMessageInsert(threadId: Long, unarchive: Boolean) { + setLastScrolled(threadId, 0) + update(threadId, unarchive) + } + fun update(threadId: Long, unarchive: Boolean, syncThreadDelete: Boolean = true): Boolean { return update( threadId = threadId, diff --git a/app/src/main/java/org/thoughtcrime/securesms/jobs/PushProcessMessageJob.kt b/app/src/main/java/org/thoughtcrime/securesms/jobs/PushProcessMessageJob.kt index 8314b89b4c..14ddd48126 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/jobs/PushProcessMessageJob.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/jobs/PushProcessMessageJob.kt @@ -11,6 +11,7 @@ import org.thoughtcrime.securesms.groups.GroupChangeBusyException import org.thoughtcrime.securesms.jobmanager.Job import org.thoughtcrime.securesms.jobmanager.impl.ChangeNumberConstraint import org.thoughtcrime.securesms.jobmanager.impl.NetworkConstraint +import org.thoughtcrime.securesms.messages.BatchCache import org.thoughtcrime.securesms.messages.MessageContentProcessor import org.thoughtcrime.securesms.messages.MessageDecryptor import org.thoughtcrime.securesms.messages.SignalServiceProtoUtil.groupId @@ -116,14 +117,14 @@ class PushProcessMessageJob private constructor( return QUEUE_PREFIX + recipientId.toQueueKey() } - fun processOrDefer(messageProcessor: MessageContentProcessor, result: MessageDecryptor.Result.Success, localReceiveMetric: SignalLocalMetrics.MessageReceive): PushProcessMessageJob? { + fun processOrDefer(messageProcessor: MessageContentProcessor, result: MessageDecryptor.Result.Success, localReceiveMetric: SignalLocalMetrics.MessageReceive, batchCache: BatchCache): PushProcessMessageJob? { val groupContext = GroupUtil.getGroupContextIfPresent(result.content) val groupId = groupContext?.groupId var requireNetwork = false val queueName: String = if (groupId != null) { if (groupId.isV2) { - val localRevision = groups.getGroupV2Revision(groupId.requireV2()) + val localRevision = batchCache.groupRevisionCache.getOrPut(groupId) { groups.getGroupV2Revision(groupId.requireV2()) } if (groupContext.revision!! > localRevision) { Log.i(TAG, "Adding network constraint to group-related job.") @@ -140,7 +141,7 @@ class PushProcessMessageJob private constructor( getQueueName(RecipientId.from(result.metadata.sourceServiceId)) } - return if (requireNetwork || !isQueueEmpty(queueName = queueName, isGroup = groupId != null)) { + return if (requireNetwork || !isQueueEmpty(queueName = queueName, cache = if (groupId != null) batchCache.groupQueueEmptyCache else empty1to1QueueCache)) { val builder = Parameters.Builder() .setMaxAttempts(Parameters.UNLIMITED) .addConstraint(ChangeNumberConstraint.KEY) @@ -148,10 +149,11 @@ class PushProcessMessageJob private constructor( if (requireNetwork) { builder.addConstraint(NetworkConstraint.KEY).setLifespan(TimeUnit.DAYS.toMillis(30)) } + batchCache.groupQueueEmptyCache.remove(queueName) PushProcessMessageJob(builder.build(), result.envelope.newBuilder().content(null).build(), result.content, result.metadata, result.serverDeliveredTimestamp) } else { try { - messageProcessor.process(result.envelope, result.content, result.metadata, result.serverDeliveredTimestamp, localMetric = localReceiveMetric) + messageProcessor.process(result.envelope, result.content, result.metadata, result.serverDeliveredTimestamp, localMetric = localReceiveMetric, batchCache = batchCache) } catch (e: Exception) { Log.e(TAG, "Failed to process message with timestamp ${result.envelope.timestamp}. Dropping.", e) } @@ -159,13 +161,13 @@ class PushProcessMessageJob private constructor( } } - private fun isQueueEmpty(queueName: String, isGroup: Boolean): Boolean { - if (!isGroup && empty1to1QueueCache.contains(queueName)) { + private fun isQueueEmpty(queueName: String, cache: HashSet): Boolean { + if (cache.contains(queueName)) { return true } val queueEmpty = AppDependencies.jobManager.isQueueEmpty(queueName) - if (!isGroup && queueEmpty) { - empty1to1QueueCache.add(queueName) + if (queueEmpty) { + cache.add(queueName) } return queueEmpty } diff --git a/app/src/main/java/org/thoughtcrime/securesms/messages/BatchCache.kt b/app/src/main/java/org/thoughtcrime/securesms/messages/BatchCache.kt new file mode 100644 index 0000000000..26b6a996d1 --- /dev/null +++ b/app/src/main/java/org/thoughtcrime/securesms/messages/BatchCache.kt @@ -0,0 +1,127 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.thoughtcrime.securesms.messages + +import org.signal.libsignal.zkgroup.groups.GroupMasterKey +import org.signal.libsignal.zkgroup.groups.GroupSecretParams +import org.thoughtcrime.securesms.database.SignalDatabase +import org.thoughtcrime.securesms.database.model.GroupRecord +import org.thoughtcrime.securesms.dependencies.AppDependencies +import org.thoughtcrime.securesms.groups.GroupId +import org.thoughtcrime.securesms.jobmanager.Job +import org.thoughtcrime.securesms.messages.SignalServiceProtoUtil.groupMasterKey +import org.thoughtcrime.securesms.messages.SignalServiceProtoUtil.hasGroupContext +import org.whispersystems.signalservice.internal.push.DataMessage +import java.util.Optional + +/** + * A caching system for batch processing of incoming messages. + * + * The primary things that enables the cache to safely store various group state: + * 1. [IncomingMessageObserver] holds a group processing lock during a batch process preventing group state from changing. + * Helps enable [groupRevisionCache] and [groupRecordCache]. + * + * 2. Some group state doesn't change as it's derived from the [GroupMasterKey]. Enables [groupSecretParamsAndIdCache]. + */ +abstract class BatchCache { + companion object { + const val BATCH_SIZE = 30 + } + + abstract val batchThreadUpdates: Boolean + + val groupQueueEmptyCache = HashSet(BATCH_SIZE) + val groupRevisionCache = HashMap(BATCH_SIZE) + val groupRecordCache = HashMap>(BATCH_SIZE) + + protected val groupSecretParamsAndIdCache = HashMap>(BATCH_SIZE) + + fun getGroupInfo(message: DataMessage): Pair { + return if (message.hasGroupContext) { + groupSecretParamsAndIdCache.getOrPut(message.groupV2!!.groupMasterKey) { + val params = GroupSecretParams.deriveFromMasterKey(message.groupV2!!.groupMasterKey) + params to GroupId.v2(params.publicParams.groupIdentifier) + } + } else { + null to null + } + } + + open fun flushAndClear() { + groupQueueEmptyCache.clear() + groupRevisionCache.clear() + groupRecordCache.clear() + groupSecretParamsAndIdCache.clear() + } + + protected fun flushJob(job: Job) { + AppDependencies.jobManager.add(job) + } + + protected fun flushIncomingMessageInsertThreadUpdate(threadId: Long) { + SignalDatabase.threads.updateForMessageInsert(threadId, unarchive = true) + } + + abstract fun addJob(job: Job) + abstract fun addIncomingMessageInsertThreadUpdate(threadId: Long) +} + +/** + * This is intended to be used when processing messages outside of [IncomingMessageObserver] where + * no batching is possible, mostly when the [org.thoughtcrime.securesms.jobs.PushProcessMessageJob] runs. + */ +class OneTimeBatchCache : BatchCache() { + override val batchThreadUpdates: Boolean = false + + override fun addJob(job: Job) { + flushJob(job) + } + + override fun addIncomingMessageInsertThreadUpdate(threadId: Long) { + flushIncomingMessageInsertThreadUpdate(threadId) + } +} + +/** + * This is intended to be used in [IncomingMessageObserver] to batch jobs (e.g., [org.thoughtcrime.securesms.jobs.SendDeliveryReceiptJob]) + * and dedupe and batch calls to [SignalDatabase.threads.updateForMessageInsert]. + * + * Why Jobs? There's a lot of locking and database management when adding a job. Delaying that work from the processing loop + * and doing it all at once reduces the number of times we need to do either, reducing overall contention. + * + * Why thread updates? Thread updating has always been the longest thing to do in message processing. Deduping allows + * us to only call it once per thread in a batch instead of X times a message for that thread is in the batch. + */ +class ReusedBatchCache : BatchCache() { + override val batchThreadUpdates: Boolean = true + + private val batchedJobs = ArrayList(BATCH_SIZE) + private val threadUpdates = HashSet(BATCH_SIZE) + + override fun addJob(job: Job) { + batchedJobs += job + } + + override fun addIncomingMessageInsertThreadUpdate(threadId: Long) { + threadUpdates += threadId + } + + override fun flushAndClear() { + super.flushAndClear() + + if (batchedJobs.isNotEmpty()) { + AppDependencies.jobManager.addAll(batchedJobs) + } + batchedJobs.clear() + + if (threadUpdates.isNotEmpty()) { + SignalDatabase.runInTransaction { + threadUpdates.forEach { flushIncomingMessageInsertThreadUpdate(it) } + } + } + threadUpdates.clear() + } +} diff --git a/app/src/main/java/org/thoughtcrime/securesms/messages/DataMessageProcessor.kt b/app/src/main/java/org/thoughtcrime/securesms/messages/DataMessageProcessor.kt index d6d953189c..4d002e4997 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/messages/DataMessageProcessor.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/messages/DataMessageProcessor.kt @@ -9,12 +9,10 @@ import org.signal.core.models.ServiceId.ACI import org.signal.core.util.Base64 import org.signal.core.util.Hex import org.signal.core.util.UuidUtil -import org.signal.core.util.concurrent.SignalExecutors import org.signal.core.util.isNotEmpty import org.signal.core.util.logging.Log import org.signal.core.util.orNull import org.signal.core.util.toOptional -import org.signal.libsignal.zkgroup.groups.GroupSecretParams import org.signal.libsignal.zkgroup.receipts.ReceiptCredentialPresentation import org.thoughtcrime.securesms.attachments.Attachment import org.thoughtcrime.securesms.attachments.LocalStickerAttachment @@ -103,6 +101,7 @@ import org.thoughtcrime.securesms.util.MediaUtil import org.thoughtcrime.securesms.util.MessageConstraintsUtil import org.thoughtcrime.securesms.util.RemoteConfig import org.thoughtcrime.securesms.util.SignalLocalMetrics +import org.thoughtcrime.securesms.util.SignalTrace import org.thoughtcrime.securesms.util.TextSecurePreferences import org.thoughtcrime.securesms.util.hasGiftBadge import org.thoughtcrime.securesms.util.isStory @@ -137,14 +136,15 @@ object DataMessageProcessor { metadata: EnvelopeMetadata, receivedTime: Long, earlyMessageCacheEntry: EarlyMessageCacheEntry?, - localMetrics: SignalLocalMetrics.MessageReceive? + localMetrics: SignalLocalMetrics.MessageReceive?, + batchCache: BatchCache ) { val message: DataMessage = content.dataMessage!! - val groupSecretParams = if (message.hasGroupContext) GroupSecretParams.deriveFromMasterKey(message.groupV2!!.groupMasterKey) else null - val groupId: GroupId.V2? = if (groupSecretParams != null) GroupId.v2(groupSecretParams.publicParams.groupIdentifier) else null + val (groupSecretParams, groupId) = batchCache.getGroupInfo(message) var groupProcessResult: MessageContentProcessor.Gv2PreProcessResult? = null if (groupId != null) { + SignalTrace.beginSection("DataMessageProcessor#gv2PreProcessing") groupProcessResult = MessageContentProcessor.handleGv2PreProcessing( context = context, timestamp = envelope.timestamp!!, @@ -154,8 +154,10 @@ object DataMessageProcessor { groupV2 = message.groupV2!!, senderRecipient = senderRecipient, groupSecretParams = groupSecretParams, - serverGuid = UuidUtil.getStringUUID(envelope.serverGuid, envelope.serverGuidBinary) + serverGuid = UuidUtil.getStringUUID(envelope.serverGuid, envelope.serverGuidBinary), + batchCache = batchCache ) + SignalTrace.endSection() if (groupProcessResult == MessageContentProcessor.Gv2PreProcessResult.IGNORE) { return @@ -165,6 +167,7 @@ object DataMessageProcessor { var insertResult: InsertResult? = null var messageId: MessageId? = null + SignalTrace.beginSection("DataMessageProcessor#messageInsert") when { message.isInvalid -> handleInvalidMessage(context, senderRecipient.id, groupId, envelope.timestamp!!) message.isEndSession -> insertResult = handleEndSessionMessage(context, senderRecipient.id, envelope, metadata) @@ -177,8 +180,8 @@ object DataMessageProcessor { message.payment != null -> insertResult = handlePayment(context, envelope, metadata, message, senderRecipient.id, receivedTime) message.storyContext != null -> insertResult = handleStoryReply(context, envelope, metadata, message, senderRecipient, groupId, receivedTime) message.giftBadge != null -> insertResult = handleGiftMessage(context, envelope, metadata, message, senderRecipient, threadRecipient.id, receivedTime) - message.isMediaMessage -> insertResult = handleMediaMessage(context, envelope, metadata, message, senderRecipient, threadRecipient, groupId, receivedTime, localMetrics) - message.body != null -> insertResult = handleTextMessage(context, envelope, metadata, message, senderRecipient, threadRecipient, groupId, receivedTime, localMetrics) + message.isMediaMessage -> insertResult = handleMediaMessage(context, envelope, metadata, message, senderRecipient, threadRecipient, groupId, receivedTime, localMetrics, batchCache) + message.body != null -> insertResult = handleTextMessage(context, envelope, metadata, message, senderRecipient, threadRecipient, groupId, receivedTime, localMetrics, batchCache) message.groupCallUpdate != null -> handleGroupCallUpdateMessage(envelope, message, senderRecipient.id, groupId) message.pollCreate != null -> insertResult = handlePollCreate(context, envelope, metadata, message, senderRecipient, threadRecipient, groupId, receivedTime) message.pollTerminate != null -> insertResult = handlePollTerminate(context, envelope, metadata, message, senderRecipient, earlyMessageCacheEntry, threadRecipient, groupId, receivedTime) @@ -187,7 +190,9 @@ object DataMessageProcessor { message.unpinMessage != null -> messageId = handleUnpinMessage(envelope, message, senderRecipient, threadRecipient, earlyMessageCacheEntry) message.adminDelete != null -> messageId = handleAdminRemoteDelete(envelope, message, senderRecipient, threadRecipient, earlyMessageCacheEntry) } + SignalTrace.endSection() + SignalTrace.beginSection("DataMessageProcessor#postProcess") messageId = messageId ?: insertResult?.messageId?.let { MessageId(it) } if (messageId != null) { log(envelope.timestamp!!, "Inserted as messageId $messageId") @@ -212,7 +217,7 @@ object DataMessageProcessor { } if (metadata.sealedSender && messageId != null) { - SignalExecutors.BOUNDED.execute { AppDependencies.jobManager.add(SendDeliveryReceiptJob(senderRecipient.id, message.timestamp!!, messageId)) } + batchCache.addJob(SendDeliveryReceiptJob(senderRecipient.id, message.timestamp!!, messageId)) } else if (!metadata.sealedSender) { if (RecipientUtil.shouldHaveProfileKey(threadRecipient)) { Log.w(MessageContentProcessor.TAG, "Received an unsealed sender message from " + senderRecipient.id + ", but they should already have our profile key. Correcting.") @@ -251,6 +256,7 @@ object DataMessageProcessor { localMetrics?.onPostProcessComplete() localMetrics?.complete(groupId != null) + SignalTrace.endSection() } private fun handleProfileKey( @@ -906,7 +912,8 @@ object DataMessageProcessor { threadRecipient: Recipient, groupId: GroupId.V2?, receivedTime: Long, - localMetrics: SignalLocalMetrics.MessageReceive? + localMetrics: SignalLocalMetrics.MessageReceive?, + batchCache: BatchCache ): InsertResult? { log(envelope.timestamp!!, "Media message.") @@ -946,9 +953,12 @@ object DataMessageProcessor { messageRanges = messageRanges ) - insertResult = SignalDatabase.messages.insertMessageInbox(mediaMessage, -1).orNull() + insertResult = SignalDatabase.messages.insertMessageInbox(retrieved = mediaMessage, candidateThreadId = -1, skipThreadUpdate = batchCache.batchThreadUpdates).orNull() if (insertResult != null) { SignalDatabase.messages.setTransactionSuccessful() + if (insertResult.needsThreadUpdate) { + batchCache.addIncomingMessageInsertThreadUpdate(insertResult.threadId) + } } } catch (e: MmsException) { throw StorageFailedException(e, metadata.sourceServiceId.toString(), metadata.sourceDeviceId) @@ -998,7 +1008,8 @@ object DataMessageProcessor { threadRecipient: Recipient, groupId: GroupId.V2?, receivedTime: Long, - localMetrics: SignalLocalMetrics.MessageReceive? + localMetrics: SignalLocalMetrics.MessageReceive?, + batchCache: BatchCache ): InsertResult? { log(envelope.timestamp!!, "Text message.") @@ -1021,10 +1032,13 @@ object DataMessageProcessor { serverGuid = UuidUtil.getStringUUID(envelope.serverGuid, envelope.serverGuidBinary) ) - val insertResult: InsertResult? = SignalDatabase.messages.insertMessageInbox(textMessage).orNull() + val insertResult: InsertResult? = SignalDatabase.messages.insertMessageInbox(textMessage, skipThreadUpdate = batchCache.batchThreadUpdates).orNull() localMetrics?.onInsertedTextMessage() return if (insertResult != null) { + if (insertResult.needsThreadUpdate) { + batchCache.addIncomingMessageInsertThreadUpdate(insertResult.threadId) + } AppDependencies.messageNotifier.updateNotification(context, ConversationId.forConversation(insertResult.threadId)) insertResult } else { diff --git a/app/src/main/java/org/thoughtcrime/securesms/messages/IncomingMessageObserver.kt b/app/src/main/java/org/thoughtcrime/securesms/messages/IncomingMessageObserver.kt index 2ce767800a..2644ed4e39 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/messages/IncomingMessageObserver.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/messages/IncomingMessageObserver.kt @@ -276,7 +276,7 @@ class IncomingMessageObserver( } @VisibleForTesting - fun processEnvelope(bufferedProtocolStore: BufferedProtocolStore, envelope: Envelope, serverDeliveredTimestamp: Long): List? { + fun processEnvelope(bufferedProtocolStore: BufferedProtocolStore, envelope: Envelope, serverDeliveredTimestamp: Long, batchCache: BatchCache): List? { return when (envelope.type) { Envelope.Type.SERVER_DELIVERY_RECEIPT -> { SignalTrace.beginSection("IncomingMessageObserver#processReceipt") @@ -290,7 +290,7 @@ class IncomingMessageObserver( Envelope.Type.UNIDENTIFIED_SENDER, Envelope.Type.PLAINTEXT_CONTENT -> { SignalTrace.beginSection("IncomingMessageObserver#processMessage") - val followUps = processMessage(bufferedProtocolStore, envelope, serverDeliveredTimestamp) + val followUps = processMessage(bufferedProtocolStore, envelope, serverDeliveredTimestamp, batchCache) SignalTrace.endSection() followUps } @@ -302,7 +302,7 @@ class IncomingMessageObserver( } } - private fun processMessage(bufferedProtocolStore: BufferedProtocolStore, envelope: Envelope, serverDeliveredTimestamp: Long): List { + private fun processMessage(bufferedProtocolStore: BufferedProtocolStore, envelope: Envelope, serverDeliveredTimestamp: Long, batchCache: BatchCache): List { val localReceiveMetric = SignalLocalMetrics.MessageReceive.start() SignalTrace.beginSection("IncomingMessageObserver#decryptMessage") val result = MessageDecryptor.decrypt(context, bufferedProtocolStore, envelope, serverDeliveredTimestamp) @@ -312,7 +312,7 @@ class IncomingMessageObserver( SignalLocalMetrics.MessageLatency.onMessageReceived(envelope.serverTimestamp!!, serverDeliveredTimestamp, envelope.urgent!!) when (result) { is MessageDecryptor.Result.Success -> { - val job = PushProcessMessageJob.processOrDefer(messageContentProcessor, result, localReceiveMetric) + val job = PushProcessMessageJob.processOrDefer(messageContentProcessor, result, localReceiveMetric, batchCache) if (job != null) { return result.followUpOperations + FollowUpOperation { job.asChain() } } @@ -374,6 +374,7 @@ class IncomingMessageObserver( private var sleepTimer: SleepTimer private val canProcessMessages: Boolean + private val batchCache = ReusedBatchCache() init { Log.i(TAG, "Initializing! (${this.hashCode()})") @@ -433,11 +434,13 @@ class IncomingMessageObserver( GroupsV2ProcessingLock.acquireGroupProcessingLock().use { ReentrantSessionLock.INSTANCE.acquire().use { batch.forEach { response -> + SignalTrace.beginSection("IncomingMessageObserver#perMessageTransaction") val followUpOperations = SignalDatabase.runInTransaction { db -> - val followUps: List? = processEnvelope(bufferedStore, response.envelope, response.serverDeliveredTimestamp) + val followUps: List? = processEnvelope(bufferedStore, response.envelope, response.serverDeliveredTimestamp, batchCache) bufferedStore.flushToDisk() followUps } + SignalTrace.endSection() if (followUpOperations?.isNotEmpty() == true) { Log.d(TAG, "Running ${followUpOperations.size} follow-up operations...") @@ -447,6 +450,8 @@ class IncomingMessageObserver( authWebSocket.sendAck(response) } + + batchCache.flushAndClear() } } val duration = System.currentTimeMillis() - startTime diff --git a/app/src/main/java/org/thoughtcrime/securesms/messages/MessageContentProcessor.kt b/app/src/main/java/org/thoughtcrime/securesms/messages/MessageContentProcessor.kt index d2d9b767f2..04822f46a2 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/messages/MessageContentProcessor.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/messages/MessageContentProcessor.kt @@ -233,9 +233,10 @@ open class MessageContentProcessor(private val context: Context) { groupV2: GroupContextV2, senderRecipient: Recipient, groupSecretParams: GroupSecretParams? = null, - serverGuid: String? = null + serverGuid: String? = null, + batchCache: BatchCache? = null ): Gv2PreProcessResult { - val preUpdateGroupRecord = SignalDatabase.groups.getGroup(groupId) + val preUpdateGroupRecord = batchCache?.groupRecordCache[groupId] ?: SignalDatabase.groups.getGroup(groupId) val groupUpdateResult = updateGv2GroupFromServerOrP2PChange(context, timestamp, groupV2, preUpdateGroupRecord, groupSecretParams, serverGuid) if (groupUpdateResult == null) { log(timestamp, "Ignoring GV2 message for group we are not currently in $groupId") @@ -247,6 +248,7 @@ open class MessageContentProcessor(private val context: Context) { } else { SignalDatabase.groups.getGroup(groupId) } + batchCache?.groupRecordCache?.put(groupId, groupRecord) if (groupRecord.isPresent && !groupRecord.get().members.contains(senderRecipient.id)) { log(timestamp, "Ignoring GV2 message from member not in group $groupId. Sender: ${formatSender(senderRecipient.id, metadata.sourceServiceId, metadata.sourceDeviceId)}") @@ -326,11 +328,19 @@ open class MessageContentProcessor(private val context: Context) { * store or enqueue early content jobs if we detect this as being early, to avoid recursive scenarios. */ @JvmOverloads - open fun process(envelope: Envelope, content: Content, metadata: EnvelopeMetadata, serverDeliveredTimestamp: Long, processingEarlyContent: Boolean = false, localMetric: SignalLocalMetrics.MessageReceive? = null) { + open fun process( + envelope: Envelope, + content: Content, + metadata: EnvelopeMetadata, + serverDeliveredTimestamp: Long, + processingEarlyContent: Boolean = false, + localMetric: SignalLocalMetrics.MessageReceive? = null, + batchCache: BatchCache = OneTimeBatchCache() + ) { val senderRecipient = Recipient.externalPush(SignalServiceAddress(metadata.sourceServiceId, metadata.sourceE164)) SignalTrace.beginSection("MessageContentProcessor#handleMessage") - handleMessage(senderRecipient, envelope, content, metadata, serverDeliveredTimestamp, processingEarlyContent, localMetric) + handleMessage(senderRecipient, envelope, content, metadata, serverDeliveredTimestamp, processingEarlyContent, localMetric, batchCache) SignalTrace.endSection() val earlyCacheEntries: List? = AppDependencies @@ -341,7 +351,7 @@ open class MessageContentProcessor(private val context: Context) { if (!processingEarlyContent && earlyCacheEntries != null) { log(envelope.timestamp!!, "Found " + earlyCacheEntries.size + " dependent item(s) that were retrieved earlier. Processing.") for (entry in earlyCacheEntries) { - handleMessage(senderRecipient, entry.envelope, entry.content, entry.metadata, entry.serverDeliveredTimestamp, processingEarlyContent = true, localMetric = null) + handleMessage(senderRecipient, entry.envelope, entry.content, entry.metadata, entry.serverDeliveredTimestamp, processingEarlyContent = true, localMetric = null, batchCache) } } } @@ -421,7 +431,8 @@ open class MessageContentProcessor(private val context: Context) { metadata: EnvelopeMetadata, serverDeliveredTimestamp: Long, processingEarlyContent: Boolean, - localMetric: SignalLocalMetrics.MessageReceive? + localMetric: SignalLocalMetrics.MessageReceive?, + batchCache: BatchCache ) { val threadRecipient = getMessageDestination(content, senderRecipient) @@ -446,7 +457,8 @@ open class MessageContentProcessor(private val context: Context) { metadata, receivedTime, if (processingEarlyContent) null else EarlyMessageCacheEntry(envelope, content, metadata, serverDeliveredTimestamp), - localMetric + localMetric, + batchCache ) } diff --git a/benchmark/src/main/java/org/thoughtcrime/benchmark/GroupMessageProcessingBenchmarks.kt b/benchmark/src/main/java/org/thoughtcrime/benchmark/GroupMessageProcessingBenchmarks.kt index 930d11f601..7554f760c7 100644 --- a/benchmark/src/main/java/org/thoughtcrime/benchmark/GroupMessageProcessingBenchmarks.kt +++ b/benchmark/src/main/java/org/thoughtcrime/benchmark/GroupMessageProcessingBenchmarks.kt @@ -47,7 +47,19 @@ class GroupMessageProcessingBenchmarks { mode = Mode.Average ), TraceSectionMetric( - sectionName = "MessageContentProcessor#handleMessage", + sectionName = "IncomingMessageObserver#perMessageTransaction", + mode = Mode.Average + ), + TraceSectionMetric( + sectionName = "DataMessageProcessor#gv2PreProcessing", + mode = Mode.Average + ), + TraceSectionMetric( + sectionName = "DataMessageProcessor#messageInsert", + mode = Mode.Average + ), + TraceSectionMetric( + sectionName = "DataMessageProcessor#postProcess", mode = Mode.Average ), TraceSectionMetric(