Perform message decryptions in batches.

This commit is contained in:
Greyson Parrelli
2023-03-09 17:05:00 -05:00
parent 04baa7925f
commit 894095414a
17 changed files with 772 additions and 69 deletions

View File

@@ -14,6 +14,7 @@ import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.thoughtcrime.securesms.messages.MessageContentProcessor.ExceptionMetadata
import org.thoughtcrime.securesms.messages.MessageContentProcessor.MessageState
import org.thoughtcrime.securesms.messages.MessageDecryptor
import org.thoughtcrime.securesms.messages.protocol.BufferedProtocolStore
import org.thoughtcrime.securesms.notifications.NotificationChannels
import org.thoughtcrime.securesms.notifications.NotificationIds
import org.thoughtcrime.securesms.transport.RetryLaterException
@@ -77,7 +78,9 @@ class PushDecryptMessageJob private constructor(
throw RetryLaterException()
}
val result = MessageDecryptor.decrypt(context, envelope.proto, envelope.serverDeliveredTimestamp)
val bufferedProtocolStore = BufferedProtocolStore.create()
val result = MessageDecryptor.decrypt(context, bufferedProtocolStore, envelope.proto, envelope.serverDeliveredTimestamp)
bufferedProtocolStore.flushToDisk()
when (result) {
is MessageDecryptor.Result.Success -> {

View File

@@ -11,6 +11,7 @@ import org.thoughtcrime.securesms.database.model.MessageId;
import org.thoughtcrime.securesms.dependencies.ApplicationDependencies;
import org.thoughtcrime.securesms.jobmanager.Data;
import org.thoughtcrime.securesms.jobmanager.Job;
import org.thoughtcrime.securesms.jobmanager.impl.DecryptionsDrainedConstraint;
import org.thoughtcrime.securesms.jobmanager.impl.NetworkConstraint;
import org.thoughtcrime.securesms.net.NotPushRegisteredException;
import org.thoughtcrime.securesms.recipients.Recipient;
@@ -51,6 +52,7 @@ public class SendDeliveryReceiptJob extends BaseJob {
public SendDeliveryReceiptJob(@NonNull RecipientId recipientId, long messageSentTimestamp, @NonNull MessageId messageId) {
this(new Job.Parameters.Builder()
.addConstraint(NetworkConstraint.KEY)
.addConstraint(DecryptionsDrainedConstraint.KEY)
.setLifespan(TimeUnit.DAYS.toMillis(1))
.setMaxAttempts(Parameters.UNLIMITED)
.setQueue(recipientId.toQueueKey())

View File

@@ -16,6 +16,7 @@ import org.thoughtcrime.securesms.dependencies.ApplicationDependencies;
import org.thoughtcrime.securesms.jobmanager.Data;
import org.thoughtcrime.securesms.jobmanager.Job;
import org.thoughtcrime.securesms.jobmanager.JobManager;
import org.thoughtcrime.securesms.jobmanager.impl.DecryptionsDrainedConstraint;
import org.thoughtcrime.securesms.jobmanager.impl.NetworkConstraint;
import org.thoughtcrime.securesms.net.NotPushRegisteredException;
import org.thoughtcrime.securesms.recipients.Recipient;
@@ -65,6 +66,7 @@ public class SendReadReceiptJob extends BaseJob {
public SendReadReceiptJob(long threadId, @NonNull RecipientId recipientId, List<Long> messageSentTimestamps, List<MessageId> messageIds) {
this(new Job.Parameters.Builder()
.addConstraint(NetworkConstraint.KEY)
.addConstraint(DecryptionsDrainedConstraint.KEY)
.setLifespan(TimeUnit.DAYS.toMillis(1))
.setMaxAttempts(Parameters.UNLIMITED)
.setQueue(recipientId.toQueueKey())

View File

@@ -72,7 +72,7 @@ class PersistentLogger(
}
private fun write(level: String, tag: String?, message: String?, t: Throwable?, keepLonger: Boolean) {
logEntries.add(LogRequest(level, tag ?: "null", message, Date(), getThreadString(), t, keepLonger))
logEntries.add(LogRequest(level, tag ?: "null", message, System.currentTimeMillis(), getThreadString(), t, keepLonger))
}
private fun getThreadString(): String {
@@ -95,7 +95,7 @@ class PersistentLogger(
val level: String,
val tag: String,
val message: String?,
val date: Date,
val createTime: Long,
val threadString: String,
val throwable: Throwable?,
val keepLonger: Boolean
@@ -121,11 +121,13 @@ class PersistentLogger(
fun requestToEntries(request: LogRequest): List<LogEntry> {
val out = mutableListOf<LogEntry>()
val createDate = Date(request.createTime)
out.add(
LogEntry(
createdAt = request.date.time,
createdAt = request.createTime,
keepLonger = request.keepLonger,
body = formatBody(request.threadString, request.date, request.level, request.tag, request.message)
body = formatBody(request.threadString, createDate, request.level, request.tag, request.message)
)
)
@@ -138,9 +140,9 @@ class PersistentLogger(
val entries = lines.map { line ->
LogEntry(
createdAt = request.date.time,
createdAt = request.createTime,
keepLonger = request.keepLonger,
body = formatBody(request.threadString, request.date, request.level, request.tag, line)
body = formatBody(request.threadString, createDate, request.level, request.tag, line)
)
}

View File

@@ -14,7 +14,9 @@ import androidx.core.app.NotificationCompat
import org.signal.core.util.ThreadUtil
import org.signal.core.util.concurrent.SignalExecutors
import org.signal.core.util.logging.Log
import org.signal.core.util.withinTransaction
import org.thoughtcrime.securesms.R
import org.thoughtcrime.securesms.crypto.ReentrantSessionLock
import org.thoughtcrime.securesms.database.MessageTable
import org.thoughtcrime.securesms.database.SignalDatabase
import org.thoughtcrime.securesms.dependencies.ApplicationDependencies
@@ -28,6 +30,7 @@ import org.thoughtcrime.securesms.jobs.PushDecryptMessageJob
import org.thoughtcrime.securesms.jobs.PushProcessMessageJob
import org.thoughtcrime.securesms.jobs.UnableToStartException
import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.thoughtcrime.securesms.messages.protocol.BufferedProtocolStore
import org.thoughtcrime.securesms.notifications.NotificationChannels
import org.thoughtcrime.securesms.recipients.RecipientId
import org.thoughtcrime.securesms.util.AppForegroundObserver
@@ -200,7 +203,7 @@ class IncomingMessageObserver(private val context: Application) {
val needsConnectionString = if (conclusion) "Needs Connection" else "Does Not Need Connection"
Log.d(TAG, "[$needsConnectionString] Network: $hasNetwork, Foreground: $appVisible, Time Since Last Interaction: $lastInteractionString, FCM: $fcmEnabled, Stay open requests: [${keepAliveTokens.entries}], Registered: $registered, Proxy: $hasProxy, Force websocket: $forceWebsocket, Decrypt Queue Empty: $decryptQueueEmpty")
Log.d(TAG, "[$needsConnectionString] Network: $hasNetwork, Foreground: $appVisible, Time Since Last Interaction: $lastInteractionString, FCM: $fcmEnabled, Stay open requests: ${keepAliveTokens.entries}, Registered: $registered, Proxy: $hasProxy, Force websocket: $forceWebsocket, Decrypt Queue Empty: $decryptQueueEmpty")
return conclusion
}
}
@@ -249,19 +252,29 @@ class IncomingMessageObserver(private val context: Application) {
}
@VisibleForTesting
fun processEnvelope(envelope: SignalServiceProtos.Envelope, serverDeliveredTimestamp: Long) {
when (envelope.type.number) {
SignalServiceProtos.Envelope.Type.RECEIPT_VALUE -> processReceipt(envelope)
fun processEnvelope(bufferedProtocolStore: BufferedProtocolStore, envelope: SignalServiceProtos.Envelope, serverDeliveredTimestamp: Long): List<Runnable>? {
return when (envelope.type.number) {
SignalServiceProtos.Envelope.Type.RECEIPT_VALUE -> {
processReceipt(envelope)
null
}
SignalServiceProtos.Envelope.Type.PREKEY_BUNDLE_VALUE,
SignalServiceProtos.Envelope.Type.CIPHERTEXT_VALUE,
SignalServiceProtos.Envelope.Type.UNIDENTIFIED_SENDER_VALUE,
SignalServiceProtos.Envelope.Type.PLAINTEXT_CONTENT_VALUE -> processMessage(envelope, serverDeliveredTimestamp)
else -> Log.w(TAG, "Received envelope of unknown type: " + envelope.type)
SignalServiceProtos.Envelope.Type.PLAINTEXT_CONTENT_VALUE -> {
processMessage(bufferedProtocolStore, envelope, serverDeliveredTimestamp)
}
else -> {
Log.w(TAG, "Received envelope of unknown type: " + envelope.type)
null
}
}
}
private fun processMessage(envelope: SignalServiceProtos.Envelope, serverDeliveredTimestamp: Long) {
val result = MessageDecryptor.decrypt(context, envelope, serverDeliveredTimestamp)
private fun processMessage(bufferedProtocolStore: BufferedProtocolStore, envelope: SignalServiceProtos.Envelope, serverDeliveredTimestamp: Long): List<Runnable> {
val result = MessageDecryptor.decrypt(context, bufferedProtocolStore, envelope, serverDeliveredTimestamp)
when (result) {
is MessageDecryptor.Result.Success -> {
@@ -297,7 +310,7 @@ class IncomingMessageObserver(private val context: Application) {
}
}
result.followUpOperations.forEach { it.run() }
return result.followUpOperations
}
private fun processReceipt(envelope: SignalServiceProtos.Envelope) {
@@ -386,13 +399,31 @@ class IncomingMessageObserver(private val context: Application) {
signalWebSocket.connect()
try {
val bufferedStore = BufferedProtocolStore.create()
while (isConnectionNecessary()) {
try {
Log.d(TAG, "Reading message...")
val hasMore = signalWebSocket.readMessage(WEBSOCKET_READ_TIMEOUT) { envelope, serverDeliveredTimestamp ->
Log.i(TAG, "Retrieved envelope! " + envelope.timestamp)
processEnvelope(envelope, serverDeliveredTimestamp)
val hasMore = signalWebSocket.readMessageBatch(WEBSOCKET_READ_TIMEOUT, 30) { batch ->
Log.i(TAG, "Retrieved ${batch.size} envelopes!")
val startTime = System.currentTimeMillis()
ReentrantSessionLock.INSTANCE.acquire().use {
SignalDatabase.rawDatabase.withinTransaction {
val followUpOperations = batch
.mapNotNull { processEnvelope(bufferedStore, it.envelope, it.serverDeliveredTimestamp) }
.flatten()
bufferedStore.flushToDisk()
followUpOperations.forEach { it.run() }
}
}
val duration = System.currentTimeMillis() - startTime
Log.d(TAG, "Decrypted ${batch.size} envelopes in $duration ms (~${duration / batch.size} ms per message)")
true
}

View File

@@ -38,13 +38,13 @@ import org.thoughtcrime.securesms.jobs.PreKeysSyncJob
import org.thoughtcrime.securesms.jobs.SendRetryReceiptJob
import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.thoughtcrime.securesms.logsubmit.SubmitDebugLogActivity
import org.thoughtcrime.securesms.messages.protocol.BufferedProtocolStore
import org.thoughtcrime.securesms.notifications.NotificationChannels
import org.thoughtcrime.securesms.notifications.NotificationIds
import org.thoughtcrime.securesms.recipients.Recipient
import org.thoughtcrime.securesms.recipients.RecipientId
import org.thoughtcrime.securesms.util.FeatureFlags
import org.whispersystems.signalservice.api.InvalidMessageStructureException
import org.whispersystems.signalservice.api.SignalServiceAccountDataStore
import org.whispersystems.signalservice.api.crypto.ContentHint
import org.whispersystems.signalservice.api.crypto.EnvelopeMetadata
import org.whispersystems.signalservice.api.crypto.SignalServiceCipher
@@ -74,7 +74,12 @@ object MessageDecryptor {
* To keep that property, there may be [Result.followUpOperations] you have to perform after your transaction is committed.
* These can vary from enqueueing jobs to inserting items into the [org.thoughtcrime.securesms.database.PendingRetryReceiptCache].
*/
fun decrypt(context: Context, envelope: Envelope, serverDeliveredTimestamp: Long): Result {
fun decrypt(
context: Context,
bufferedProtocolStore: BufferedProtocolStore,
envelope: Envelope,
serverDeliveredTimestamp: Long
): Result {
val selfAci: ServiceId = SignalStore.account().requireAci()
val selfPni: ServiceId = SignalStore.account().requirePni()
@@ -106,9 +111,9 @@ object MessageDecryptor {
}
}
val protocolStore: SignalServiceAccountDataStore = ApplicationDependencies.getProtocolStore().get(destination)
val bufferedStore = bufferedProtocolStore.get(destination)
val localAddress = SignalServiceAddress(selfAci, SignalStore.account().e164)
val cipher = SignalServiceCipher(localAddress, SignalStore.account().deviceId, protocolStore, ReentrantSessionLock.INSTANCE, UnidentifiedAccessUtil.getCertificateValidator())
val cipher = SignalServiceCipher(localAddress, SignalStore.account().deviceId, bufferedStore, ReentrantSessionLock.INSTANCE, UnidentifiedAccessUtil.getCertificateValidator())
return try {
val cipherResult: SignalServiceCipherResult? = cipher.decrypt(envelope, serverDeliveredTimestamp)

View File

@@ -0,0 +1,80 @@
package org.thoughtcrime.securesms.messages.protocol
import org.signal.libsignal.protocol.IdentityKey
import org.signal.libsignal.protocol.IdentityKeyPair
import org.signal.libsignal.protocol.SignalProtocolAddress
import org.signal.libsignal.protocol.state.IdentityKeyStore
import org.thoughtcrime.securesms.database.SignalDatabase
import org.whispersystems.signalservice.api.SignalServiceAccountDataStore
import org.whispersystems.signalservice.api.push.ServiceId
/**
* An in-memory identity key store that is intended to be used temporarily while decrypting messages.
*/
class BufferedIdentityKeyStore(
private val selfServiceId: ServiceId,
private val selfIdentityKeyPair: IdentityKeyPair,
private val selfRegistrationId: Int
) : IdentityKeyStore {
private val store: MutableMap<SignalProtocolAddress, IdentityKey> = HashMap()
/** All of the keys that have been created or updated during operation. */
private val updatedKeys: MutableMap<SignalProtocolAddress, IdentityKey> = mutableMapOf()
override fun getIdentityKeyPair(): IdentityKeyPair {
return selfIdentityKeyPair
}
override fun getLocalRegistrationId(): Int {
return selfRegistrationId
}
override fun saveIdentity(address: SignalProtocolAddress, identityKey: IdentityKey): Boolean {
val existing: IdentityKey? = getIdentity(address)
store[address] = identityKey
return if (identityKey != existing) {
updatedKeys[address] = identityKey
true
} else {
false
}
}
override fun isTrustedIdentity(address: SignalProtocolAddress, identityKey: IdentityKey, direction: IdentityKeyStore.Direction): Boolean {
if (address.name == selfServiceId.toString()) {
return identityKey == selfIdentityKeyPair.publicKey
}
return when (direction) {
IdentityKeyStore.Direction.RECEIVING -> true
IdentityKeyStore.Direction.SENDING -> error("Should not happen during the intended usage pattern of this class")
else -> error("Unknown direction: $direction")
}
}
override fun getIdentity(address: SignalProtocolAddress): IdentityKey? {
val cached = store[address]
return if (cached != null) {
cached
} else {
val fromDatabase = SignalDatabase.identities.getIdentityStoreRecord(address.name)
if (fromDatabase != null) {
store[address] = fromDatabase.identityKey
}
fromDatabase?.identityKey
}
}
fun flushToDisk(persistentStore: SignalServiceAccountDataStore) {
for ((address, identityKey) in updatedKeys) {
persistentStore.saveIdentity(address, identityKey)
}
updatedKeys.clear()
}
}

View File

@@ -0,0 +1,49 @@
package org.thoughtcrime.securesms.messages.protocol
import org.signal.libsignal.protocol.InvalidKeyIdException
import org.signal.libsignal.protocol.state.PreKeyRecord
import org.signal.libsignal.protocol.state.PreKeyStore
import org.thoughtcrime.securesms.database.SignalDatabase
import org.whispersystems.signalservice.api.SignalServiceAccountDataStore
import org.whispersystems.signalservice.api.push.ServiceId
/**
* An in-memory one-time prekey store that is intended to be used temporarily while decrypting messages.
*/
class BufferedOneTimePreKeyStore(private val selfServiceId: ServiceId) : PreKeyStore {
/** Our in-memory cache of one-time prekeys. */
private val store: MutableMap<Int, PreKeyRecord> = HashMap()
/** The one-time prekeys that have been marked as removed */
private val removed: MutableList<Int> = mutableListOf()
@kotlin.jvm.Throws(InvalidKeyIdException::class)
override fun loadPreKey(id: Int): PreKeyRecord {
return store.computeIfAbsent(id) {
SignalDatabase.oneTimePreKeys.get(selfServiceId, id) ?: throw InvalidKeyIdException("Missing one-time prekey with ID: $id")
}
}
override fun storePreKey(id: Int, record: PreKeyRecord) {
error("Should not happen during the intended usage pattern of this class")
}
override fun containsPreKey(id: Int): Boolean {
loadPreKey(id)
return store.containsKey(id)
}
override fun removePreKey(id: Int) {
store.remove(id)
removed += id
}
fun flushToDisk(persistentStore: SignalServiceAccountDataStore) {
for (id in removed) {
persistentStore.removePreKey(id)
}
removed.clear()
}
}

View File

@@ -0,0 +1,46 @@
package org.thoughtcrime.securesms.messages.protocol
import org.thoughtcrime.securesms.dependencies.ApplicationDependencies
import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.whispersystems.signalservice.api.push.ServiceId
/**
* The entry point for creating and retrieving buffered protocol stores.
* These stores will read from disk, but never write, instead buffering the results in memory.
* You can then call [flushToDisk] in order to write the buffered results to disk.
*
* This allows you to efficiently do batches of work and avoid unnecessary intermediate writes.
*/
class BufferedProtocolStore private constructor(
private val aciStore: Pair<ServiceId, BufferedSignalServiceAccountDataStore>,
private val pniStore: Pair<ServiceId, BufferedSignalServiceAccountDataStore>
) {
fun get(serviceId: ServiceId): BufferedSignalServiceAccountDataStore {
return when (serviceId) {
aciStore.first -> aciStore.second
pniStore.first -> pniStore.second
else -> error("No store matching serviceId $serviceId")
}
}
/**
* Writes any buffered data to disk. You can continue to use the same buffered store afterwards.
*/
fun flushToDisk() {
aciStore.second.flushToDisk(ApplicationDependencies.getProtocolStore().aci())
pniStore.second.flushToDisk(ApplicationDependencies.getProtocolStore().pni())
}
companion object {
fun create(): BufferedProtocolStore {
val aci = SignalStore.account().requireAci()
val pni = SignalStore.account().requirePni()
return BufferedProtocolStore(
aciStore = aci to BufferedSignalServiceAccountDataStore(aci),
pniStore = pni to BufferedSignalServiceAccountDataStore(pni)
)
}
}
}

View File

@@ -0,0 +1,75 @@
package org.thoughtcrime.securesms.messages.protocol
import org.signal.libsignal.protocol.SignalProtocolAddress
import org.signal.libsignal.protocol.groups.state.SenderKeyRecord
import org.thoughtcrime.securesms.database.SignalDatabase
import org.whispersystems.signalservice.api.SignalServiceAccountDataStore
import org.whispersystems.signalservice.api.SignalServiceSenderKeyStore
import org.whispersystems.signalservice.api.push.DistributionId
import java.util.UUID
/**
* An in-memory sender key store that is intended to be used temporarily while decrypting messages.
*/
class BufferedSenderKeyStore : SignalServiceSenderKeyStore {
private val store: MutableMap<StoreKey, SenderKeyRecord> = HashMap()
/** All of the keys that have been created or updated during operation. */
private val updatedKeys: MutableMap<StoreKey, SenderKeyRecord> = mutableMapOf()
/** All of the distributionId's whose sharing has been cleared during operation. */
private val clearSharedWith: MutableSet<SignalProtocolAddress> = mutableSetOf()
override fun storeSenderKey(sender: SignalProtocolAddress, distributionId: UUID, record: SenderKeyRecord) {
val key = StoreKey(sender, distributionId)
store[key] = record
updatedKeys[key] = record
}
override fun loadSenderKey(sender: SignalProtocolAddress, distributionId: UUID): SenderKeyRecord? {
val cached: SenderKeyRecord? = store[StoreKey(sender, distributionId)]
return if (cached != null) {
cached
} else {
val fromDatabase: SenderKeyRecord? = SignalDatabase.senderKeys.load(sender, distributionId.toDistributionId())
if (fromDatabase != null) {
store[StoreKey(sender, distributionId)] = fromDatabase
}
return fromDatabase
}
}
override fun clearSenderKeySharedWith(addresses: MutableCollection<SignalProtocolAddress>) {
clearSharedWith.addAll(addresses)
}
override fun getSenderKeySharedWith(distributionId: DistributionId?): MutableSet<SignalProtocolAddress> {
error("Should not happen during the intended usage pattern of this class")
}
override fun markSenderKeySharedWith(distributionId: DistributionId?, addresses: MutableCollection<SignalProtocolAddress>?) {
error("Should not happen during the intended usage pattern of this class")
}
fun flushToDisk(persistentStore: SignalServiceAccountDataStore) {
for ((key, record) in updatedKeys) {
persistentStore.storeSenderKey(key.address, key.distributionId, record)
}
persistentStore.clearSenderKeySharedWith(clearSharedWith)
updatedKeys.clear()
clearSharedWith.clear()
}
private fun UUID.toDistributionId() = DistributionId.from(this)
data class StoreKey(
val address: SignalProtocolAddress,
val distributionId: UUID
)
}

View File

@@ -0,0 +1,115 @@
package org.thoughtcrime.securesms.messages.protocol
import org.signal.libsignal.protocol.NoSessionException
import org.signal.libsignal.protocol.SignalProtocolAddress
import org.signal.libsignal.protocol.message.CiphertextMessage
import org.signal.libsignal.protocol.state.SessionRecord
import org.thoughtcrime.securesms.database.SignalDatabase
import org.whispersystems.signalservice.api.SignalServiceAccountDataStore
import org.whispersystems.signalservice.api.SignalServiceSessionStore
import org.whispersystems.signalservice.api.push.ServiceId
import kotlin.jvm.Throws
/**
* An in-memory session store that is intended to be used temporarily while decrypting messages.
*/
class BufferedSessionStore(private val selfServiceId: ServiceId) : SignalServiceSessionStore {
private val store: MutableMap<SignalProtocolAddress, SessionRecord> = HashMap()
/** All of the sessions that have been created or updated during operation. */
private val updatedSessions: MutableMap<SignalProtocolAddress, SessionRecord> = mutableMapOf()
/** All of the sessions that have deleted during operation. */
private val deletedSessions: MutableSet<SignalProtocolAddress> = mutableSetOf()
override fun loadSession(address: SignalProtocolAddress): SessionRecord {
val session: SessionRecord = store[address]
?: SignalDatabase.sessions.load(selfServiceId, address)
?: SessionRecord()
store[address] = session
return session
}
@Throws(NoSessionException::class)
override fun loadExistingSessions(addresses: MutableList<SignalProtocolAddress>): List<SessionRecord> {
val found: MutableList<SessionRecord> = mutableListOf()
val needsDatabaseLookup: MutableList<SignalProtocolAddress> = mutableListOf()
for (address in addresses) {
val cached: SessionRecord? = store[address]
if (cached != null) {
found += cached
} else {
needsDatabaseLookup += address
}
}
if (needsDatabaseLookup.isNotEmpty()) {
found += SignalDatabase.sessions.load(selfServiceId, needsDatabaseLookup).filterNotNull()
}
if (found.size != addresses.size) {
throw NoSessionException("Failed to find one or more sessions.")
}
return found
}
override fun storeSession(address: SignalProtocolAddress, record: SessionRecord) {
store[address] = record
updatedSessions[address] = record
}
override fun containsSession(address: SignalProtocolAddress): Boolean {
return if (store.containsKey(address)) {
true
} else {
val fromDatabase: SessionRecord? = SignalDatabase.sessions.load(selfServiceId, address)
if (fromDatabase != null) {
store[address] = fromDatabase
return fromDatabase.hasSenderChain() && fromDatabase.sessionVersion == CiphertextMessage.CURRENT_VERSION
} else {
false
}
}
}
override fun deleteSession(address: SignalProtocolAddress) {
store.remove(address)
deletedSessions += address
}
override fun getSubDeviceSessions(name: String): MutableList<Int> {
error("Should not happen during the intended usage pattern of this class")
}
override fun deleteAllSessions(name: String) {
error("Should not happen during the intended usage pattern of this class")
}
override fun archiveSession(address: SignalProtocolAddress?) {
error("Should not happen during the intended usage pattern of this class")
}
override fun getAllAddressesWithActiveSessions(addressNames: MutableList<String>): Set<SignalProtocolAddress> {
error("Should not happen during the intended usage pattern of this class")
}
fun flushToDisk(persistentStore: SignalServiceAccountDataStore) {
for ((address, record) in updatedSessions) {
persistentStore.storeSession(address, record)
}
for (address in deletedSessions) {
persistentStore.deleteSession(address)
}
updatedSessions.clear()
deletedSessions.clear()
}
}

View File

@@ -0,0 +1,157 @@
package org.thoughtcrime.securesms.messages.protocol
import org.signal.libsignal.protocol.IdentityKey
import org.signal.libsignal.protocol.IdentityKeyPair
import org.signal.libsignal.protocol.SignalProtocolAddress
import org.signal.libsignal.protocol.groups.state.SenderKeyRecord
import org.signal.libsignal.protocol.state.IdentityKeyStore
import org.signal.libsignal.protocol.state.PreKeyRecord
import org.signal.libsignal.protocol.state.SessionRecord
import org.signal.libsignal.protocol.state.SignedPreKeyRecord
import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.whispersystems.signalservice.api.SignalServiceAccountDataStore
import org.whispersystems.signalservice.api.push.DistributionId
import org.whispersystems.signalservice.api.push.ServiceId
import java.util.UUID
/**
* The wrapper around all of the Buffered protocol stores. Designed to perform operations in memory,
* then [flushToDisk] at set intervals.
*/
class BufferedSignalServiceAccountDataStore(selfServiceId: ServiceId) : SignalServiceAccountDataStore {
private val identityStore: BufferedIdentityKeyStore = if (selfServiceId == SignalStore.account().pni) {
BufferedIdentityKeyStore(selfServiceId, SignalStore.account().pniIdentityKey, SignalStore.account().pniRegistrationId)
} else {
BufferedIdentityKeyStore(selfServiceId, SignalStore.account().aciIdentityKey, SignalStore.account().registrationId)
}
private val oneTimePreKeyStore: BufferedOneTimePreKeyStore = BufferedOneTimePreKeyStore(selfServiceId)
private val signedPreKeyStore: BufferedSignedPreKeyStore = BufferedSignedPreKeyStore(selfServiceId)
private val sessionStore: BufferedSessionStore = BufferedSessionStore(selfServiceId)
private val senderKeyStore: BufferedSenderKeyStore = BufferedSenderKeyStore()
override fun getIdentityKeyPair(): IdentityKeyPair {
return identityStore.identityKeyPair
}
override fun getLocalRegistrationId(): Int {
return identityStore.localRegistrationId
}
override fun saveIdentity(address: SignalProtocolAddress, identityKey: IdentityKey): Boolean {
return identityStore.saveIdentity(address, identityKey)
}
override fun isTrustedIdentity(address: SignalProtocolAddress, identityKey: IdentityKey, direction: IdentityKeyStore.Direction): Boolean {
return identityStore.isTrustedIdentity(address, identityKey, direction)
}
override fun getIdentity(address: SignalProtocolAddress): IdentityKey? {
return identityStore.getIdentity(address)
}
override fun loadPreKey(preKeyId: Int): PreKeyRecord {
return oneTimePreKeyStore.loadPreKey(preKeyId)
}
override fun storePreKey(preKeyId: Int, record: PreKeyRecord) {
return oneTimePreKeyStore.storePreKey(preKeyId, record)
}
override fun containsPreKey(preKeyId: Int): Boolean {
return oneTimePreKeyStore.containsPreKey(preKeyId)
}
override fun removePreKey(preKeyId: Int) {
oneTimePreKeyStore.removePreKey(preKeyId)
}
override fun loadSession(address: SignalProtocolAddress): SessionRecord {
return sessionStore.loadSession(address)
}
override fun loadExistingSessions(addresses: MutableList<SignalProtocolAddress>): List<SessionRecord> {
return sessionStore.loadExistingSessions(addresses)
}
override fun getSubDeviceSessions(name: String): MutableList<Int> {
return sessionStore.getSubDeviceSessions(name)
}
override fun storeSession(address: SignalProtocolAddress, record: SessionRecord) {
sessionStore.storeSession(address, record)
}
override fun containsSession(address: SignalProtocolAddress): Boolean {
return sessionStore.containsSession(address)
}
override fun deleteSession(address: SignalProtocolAddress) {
return sessionStore.deleteSession(address)
}
override fun deleteAllSessions(name: String) {
sessionStore.deleteAllSessions(name)
}
override fun loadSignedPreKey(signedPreKeyId: Int): SignedPreKeyRecord {
return signedPreKeyStore.loadSignedPreKey(signedPreKeyId)
}
override fun loadSignedPreKeys(): List<SignedPreKeyRecord> {
return signedPreKeyStore.loadSignedPreKeys()
}
override fun storeSignedPreKey(signedPreKeyId: Int, record: SignedPreKeyRecord) {
signedPreKeyStore.storeSignedPreKey(signedPreKeyId, record)
}
override fun containsSignedPreKey(signedPreKeyId: Int): Boolean {
return signedPreKeyStore.containsSignedPreKey(signedPreKeyId)
}
override fun removeSignedPreKey(signedPreKeyId: Int) {
signedPreKeyStore.removeSignedPreKey(signedPreKeyId)
}
override fun storeSenderKey(sender: SignalProtocolAddress, distributionId: UUID, record: SenderKeyRecord) {
senderKeyStore.storeSenderKey(sender, distributionId, record)
}
override fun loadSenderKey(sender: SignalProtocolAddress, distributionId: UUID): SenderKeyRecord? {
return senderKeyStore.loadSenderKey(sender, distributionId)
}
override fun archiveSession(address: SignalProtocolAddress?) {
sessionStore.archiveSession(address)
}
override fun getAllAddressesWithActiveSessions(addressNames: MutableList<String>): Set<SignalProtocolAddress> {
return sessionStore.getAllAddressesWithActiveSessions(addressNames)
}
override fun getSenderKeySharedWith(distributionId: DistributionId?): MutableSet<SignalProtocolAddress> {
return senderKeyStore.getSenderKeySharedWith(distributionId)
}
override fun markSenderKeySharedWith(distributionId: DistributionId, addresses: MutableCollection<SignalProtocolAddress>) {
senderKeyStore.markSenderKeySharedWith(distributionId, addresses)
}
override fun clearSenderKeySharedWith(addresses: MutableCollection<SignalProtocolAddress>) {
senderKeyStore.clearSenderKeySharedWith(addresses)
}
override fun isMultiDevice(): Boolean {
error("Should not happen during the intended usage pattern of this class")
}
fun flushToDisk(persistentStore: SignalServiceAccountDataStore) {
identityStore.flushToDisk(persistentStore)
oneTimePreKeyStore.flushToDisk(persistentStore)
signedPreKeyStore.flushToDisk(persistentStore)
sessionStore.flushToDisk(persistentStore)
senderKeyStore.flushToDisk(persistentStore)
}
}

View File

@@ -0,0 +1,64 @@
package org.thoughtcrime.securesms.messages.protocol
import org.signal.libsignal.protocol.InvalidKeyIdException
import org.signal.libsignal.protocol.state.SignedPreKeyRecord
import org.signal.libsignal.protocol.state.SignedPreKeyStore
import org.thoughtcrime.securesms.database.SignalDatabase
import org.whispersystems.signalservice.api.SignalServiceAccountDataStore
import org.whispersystems.signalservice.api.push.ServiceId
/**
* An in-memory signed prekey store that is intended to be used temporarily while decrypting messages.
*/
class BufferedSignedPreKeyStore(private val selfServiceId: ServiceId) : SignedPreKeyStore {
/** Our in-memory cache of signed prekeys. */
private val store: MutableMap<Int, SignedPreKeyRecord> = HashMap()
/** The signed prekeys that have been marked as removed */
private val removed: MutableList<Int> = mutableListOf()
/** Whether or not we've done a loadAll operation. Let's us avoid doing it twice. */
private var hasLoadedAll: Boolean = false
@kotlin.jvm.Throws(InvalidKeyIdException::class)
override fun loadSignedPreKey(id: Int): SignedPreKeyRecord {
return store.computeIfAbsent(id) {
SignalDatabase.signedPreKeys.get(selfServiceId, id) ?: throw InvalidKeyIdException("Missing one-time prekey with ID: $id")
}
}
override fun loadSignedPreKeys(): List<SignedPreKeyRecord> {
return if (hasLoadedAll) {
store.values.toList()
} else {
val records = SignalDatabase.signedPreKeys.getAll(selfServiceId)
records.forEach { store[it.id] = it }
hasLoadedAll = true
records
}
}
override fun storeSignedPreKey(id: Int, record: SignedPreKeyRecord) {
error("Should not happen during the intended usage pattern of this class")
}
override fun containsSignedPreKey(id: Int): Boolean {
loadSignedPreKey(id)
return store.containsKey(id)
}
override fun removeSignedPreKey(id: Int) {
store.remove(id)
removed += id
}
fun flushToDisk(persistentStore: SignalServiceAccountDataStore) {
for (id in removed) {
persistentStore.removeSignedPreKey(id)
}
removed.clear()
}
}