Add message processing performance test.

This commit is contained in:
Cody Henthorne
2023-03-03 11:42:30 -05:00
committed by Greyson Parrelli
parent f719dcca6d
commit c0aff46e31
18 changed files with 704 additions and 1646 deletions

View File

@@ -0,0 +1,44 @@
package org.thoughtcrime.securesms.testing
import org.signal.libsignal.protocol.ecc.ECKeyPair
import org.signal.libsignal.zkgroup.profiles.ProfileKey
import org.thoughtcrime.securesms.crypto.ProfileKeyUtil
import org.thoughtcrime.securesms.dependencies.ApplicationDependencies
import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.thoughtcrime.securesms.recipients.Recipient
import org.thoughtcrime.securesms.testing.FakeClientHelpers.toSignalServiceEnvelope
import org.whispersystems.signalservice.api.messages.SignalServiceEnvelope
import org.whispersystems.signalservice.api.push.ServiceId
import org.whispersystems.signalservice.api.push.SignalServiceAddress
/**
* Welcome to Alice's Client.
*
* Alice represent the Android instrumentation test user. Unlike [BobClient] much less is needed here
* as it can make use of the standard Signal Android App infrastructure.
*/
class AliceClient(val serviceId: ServiceId, val e164: String, val trustRoot: ECKeyPair) {
private val aliceSenderCertificate = FakeClientHelpers.createCertificateFor(
trustRoot = trustRoot,
uuid = serviceId.uuid(),
e164 = e164,
deviceId = 1,
identityKey = SignalStore.account().aciIdentityKey.publicKey.publicKey,
expires = 31337
)
fun process(envelope: SignalServiceEnvelope) {
ApplicationDependencies.getIncomingMessageProcessor().acquire().use { processor -> processor.processEnvelope(envelope) }
}
fun encrypt(now: Long, destination: Recipient): SignalServiceEnvelope {
return ApplicationDependencies.getSignalServiceMessageSender().getEncryptedMessage(
SignalServiceAddress(destination.requireServiceId(), destination.requireE164()),
FakeClientHelpers.getTargetUnidentifiedAccess(ProfileKeyUtil.getSelfProfileKey(), ProfileKey(destination.profileKey), aliceSenderCertificate),
1,
FakeClientHelpers.encryptedTextMessage(now),
false
).toSignalServiceEnvelope(now, destination.requireServiceId())
}
}

View File

@@ -0,0 +1,167 @@
package org.thoughtcrime.securesms.testing
import org.signal.core.util.readToSingleInt
import org.signal.core.util.select
import org.signal.libsignal.protocol.IdentityKey
import org.signal.libsignal.protocol.IdentityKeyPair
import org.signal.libsignal.protocol.SessionBuilder
import org.signal.libsignal.protocol.SignalProtocolAddress
import org.signal.libsignal.protocol.ecc.ECKeyPair
import org.signal.libsignal.protocol.groups.state.SenderKeyRecord
import org.signal.libsignal.protocol.state.IdentityKeyStore
import org.signal.libsignal.protocol.state.PreKeyBundle
import org.signal.libsignal.protocol.state.PreKeyRecord
import org.signal.libsignal.protocol.state.SessionRecord
import org.signal.libsignal.protocol.state.SignedPreKeyRecord
import org.signal.libsignal.protocol.util.KeyHelper
import org.signal.libsignal.zkgroup.profiles.ProfileKey
import org.thoughtcrime.securesms.crypto.ProfileKeyUtil
import org.thoughtcrime.securesms.crypto.UnidentifiedAccessUtil
import org.thoughtcrime.securesms.database.OneTimePreKeyTable
import org.thoughtcrime.securesms.database.SignalDatabase
import org.thoughtcrime.securesms.database.SignedPreKeyTable
import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.thoughtcrime.securesms.testing.FakeClientHelpers.toSignalServiceEnvelope
import org.whispersystems.signalservice.api.SignalServiceAccountDataStore
import org.whispersystems.signalservice.api.SignalSessionLock
import org.whispersystems.signalservice.api.crypto.SignalServiceCipher
import org.whispersystems.signalservice.api.crypto.SignalSessionBuilder
import org.whispersystems.signalservice.api.crypto.UnidentifiedAccess
import org.whispersystems.signalservice.api.messages.SignalServiceEnvelope
import org.whispersystems.signalservice.api.push.DistributionId
import org.whispersystems.signalservice.api.push.ServiceId
import org.whispersystems.signalservice.api.push.SignalServiceAddress
import java.util.Optional
import java.util.UUID
import java.util.concurrent.locks.ReentrantLock
/**
* Welcome to Bob's Client.
*
* Bob is a "fake" client that can start a session with the Android instrumentation test user (Alice).
*
* Bob can create a new session using a prekey bundle created from Alice's prekeys, send a message, decrypt
* a return message from Alice, and that'll start a standard Signal session with normal keys/ratcheting.
*/
class BobClient(val serviceId: ServiceId, val e164: String, val identityKeyPair: IdentityKeyPair, val trustRoot: ECKeyPair, val profileKey: ProfileKey) {
private val serviceAddress = SignalServiceAddress(serviceId, e164)
private val registrationId = KeyHelper.generateRegistrationId(false)
private val aciStore = BobSignalServiceAccountDataStore(registrationId, identityKeyPair)
private val senderCertificate = FakeClientHelpers.createCertificateFor(trustRoot, serviceId.uuid(), e164, 1, identityKeyPair.publicKey.publicKey, 31337)
private val sessionLock = object : SignalSessionLock {
private val lock = ReentrantLock()
override fun acquire(): SignalSessionLock.Lock {
lock.lock()
return SignalSessionLock.Lock { lock.unlock() }
}
}
/** Inspired by SignalServiceMessageSender#getEncryptedMessage */
fun encrypt(now: Long): SignalServiceEnvelope {
val envelopeContent = FakeClientHelpers.encryptedTextMessage(now)
val cipher = SignalServiceCipher(serviceAddress, 1, aciStore, sessionLock, null)
if (!aciStore.containsSession(getAliceProtocolAddress())) {
val sessionBuilder = SignalSessionBuilder(sessionLock, SessionBuilder(aciStore, getAliceProtocolAddress()))
sessionBuilder.process(getAlicePreKeyBundle())
}
return cipher.encrypt(getAliceProtocolAddress(), getAliceUnidentifiedAccess(), envelopeContent)
.toSignalServiceEnvelope(envelopeContent.content.get().dataMessage.timestamp, getAliceServiceId())
}
fun decrypt(envelope: SignalServiceEnvelope) {
val cipher = SignalServiceCipher(serviceAddress, 1, aciStore, sessionLock, UnidentifiedAccessUtil.getCertificateValidator())
cipher.decrypt(envelope)
}
private fun getAliceServiceId(): ServiceId {
return SignalStore.account().requireAci()
}
private fun getAlicePreKeyBundle(): PreKeyBundle {
val selfPreKeyId = SignalDatabase.rawDatabase
.select(OneTimePreKeyTable.KEY_ID)
.from(OneTimePreKeyTable.TABLE_NAME)
.where("${OneTimePreKeyTable.ACCOUNT_ID} = ?", getAliceServiceId().toString())
.run()
.readToSingleInt(-1)
val selfPreKeyRecord = SignalDatabase.oneTimePreKeys.get(getAliceServiceId(), selfPreKeyId)!!
val selfSignedPreKeyId = SignalDatabase.rawDatabase
.select(SignedPreKeyTable.KEY_ID)
.from(SignedPreKeyTable.TABLE_NAME)
.where("${SignedPreKeyTable.ACCOUNT_ID} = ?", getAliceServiceId().toString())
.run()
.readToSingleInt(-1)
val selfSignedPreKeyRecord = SignalDatabase.signedPreKeys.get(getAliceServiceId(), selfSignedPreKeyId)!!
return PreKeyBundle(
SignalStore.account().registrationId,
1,
selfPreKeyId,
selfPreKeyRecord.keyPair.publicKey,
selfSignedPreKeyId,
selfSignedPreKeyRecord.keyPair.publicKey,
selfSignedPreKeyRecord.signature,
getAlicePublicKey()
)
}
private fun getAliceProtocolAddress(): SignalProtocolAddress {
return SignalProtocolAddress(SignalStore.account().requireAci().toString(), 1)
}
private fun getAlicePublicKey(): IdentityKey {
return SignalStore.account().aciIdentityKey.publicKey
}
private fun getAliceProfileKey(): ProfileKey {
return ProfileKeyUtil.getSelfProfileKey()
}
private fun getAliceUnidentifiedAccess(): Optional<UnidentifiedAccess> {
return FakeClientHelpers.getTargetUnidentifiedAccess(profileKey, getAliceProfileKey(), senderCertificate)
}
private class BobSignalServiceAccountDataStore(private val registrationId: Int, private val identityKeyPair: IdentityKeyPair) : SignalServiceAccountDataStore {
private var aliceSessionRecord: SessionRecord? = null
override fun getIdentityKeyPair(): IdentityKeyPair = identityKeyPair
override fun getLocalRegistrationId(): Int = registrationId
override fun isTrustedIdentity(address: SignalProtocolAddress?, identityKey: IdentityKey?, direction: IdentityKeyStore.Direction?): Boolean = true
override fun loadSession(address: SignalProtocolAddress?): SessionRecord = aliceSessionRecord ?: SessionRecord()
override fun saveIdentity(address: SignalProtocolAddress?, identityKey: IdentityKey?): Boolean = false
override fun storeSession(address: SignalProtocolAddress?, record: SessionRecord?) { aliceSessionRecord = record }
override fun getSubDeviceSessions(name: String?): List<Int> = emptyList()
override fun containsSession(address: SignalProtocolAddress?): Boolean = aliceSessionRecord != null
override fun getIdentity(address: SignalProtocolAddress?): IdentityKey = SignalStore.account().aciIdentityKey.publicKey
override fun loadPreKey(preKeyId: Int): PreKeyRecord = throw UnsupportedOperationException()
override fun storePreKey(preKeyId: Int, record: PreKeyRecord?) = throw UnsupportedOperationException()
override fun containsPreKey(preKeyId: Int): Boolean = throw UnsupportedOperationException()
override fun removePreKey(preKeyId: Int) = throw UnsupportedOperationException()
override fun loadExistingSessions(addresses: MutableList<SignalProtocolAddress>?): MutableList<SessionRecord> = throw UnsupportedOperationException()
override fun deleteSession(address: SignalProtocolAddress?) = throw UnsupportedOperationException()
override fun deleteAllSessions(name: String?) = throw UnsupportedOperationException()
override fun loadSignedPreKey(signedPreKeyId: Int): SignedPreKeyRecord = throw UnsupportedOperationException()
override fun loadSignedPreKeys(): MutableList<SignedPreKeyRecord> = throw UnsupportedOperationException()
override fun storeSignedPreKey(signedPreKeyId: Int, record: SignedPreKeyRecord?) = throw UnsupportedOperationException()
override fun containsSignedPreKey(signedPreKeyId: Int): Boolean = throw UnsupportedOperationException()
override fun removeSignedPreKey(signedPreKeyId: Int) = throw UnsupportedOperationException()
override fun storeSenderKey(sender: SignalProtocolAddress?, distributionId: UUID?, record: SenderKeyRecord?) = throw UnsupportedOperationException()
override fun loadSenderKey(sender: SignalProtocolAddress?, distributionId: UUID?): SenderKeyRecord = throw UnsupportedOperationException()
override fun archiveSession(address: SignalProtocolAddress?) = throw UnsupportedOperationException()
override fun getAllAddressesWithActiveSessions(addressNames: MutableList<String>?): MutableSet<SignalProtocolAddress> = throw UnsupportedOperationException()
override fun getSenderKeySharedWith(distributionId: DistributionId?): MutableSet<SignalProtocolAddress> = throw UnsupportedOperationException()
override fun markSenderKeySharedWith(distributionId: DistributionId?, addresses: MutableCollection<SignalProtocolAddress>?) = throw UnsupportedOperationException()
override fun clearSenderKeySharedWith(addresses: MutableCollection<SignalProtocolAddress>?) = throw UnsupportedOperationException()
override fun isMultiDevice(): Boolean = throw UnsupportedOperationException()
}
}

View File

@@ -0,0 +1,81 @@
package org.thoughtcrime.securesms.testing
import org.signal.libsignal.internal.Native
import org.signal.libsignal.internal.NativeHandleGuard
import org.signal.libsignal.metadata.certificate.CertificateValidator
import org.signal.libsignal.metadata.certificate.SenderCertificate
import org.signal.libsignal.metadata.certificate.ServerCertificate
import org.signal.libsignal.protocol.ecc.Curve
import org.signal.libsignal.protocol.ecc.ECKeyPair
import org.signal.libsignal.protocol.ecc.ECPublicKey
import org.signal.libsignal.zkgroup.profiles.ProfileKey
import org.whispersystems.signalservice.api.crypto.ContentHint
import org.whispersystems.signalservice.api.crypto.EnvelopeContent
import org.whispersystems.signalservice.api.crypto.UnidentifiedAccess
import org.whispersystems.signalservice.api.crypto.UnidentifiedAccessPair
import org.whispersystems.signalservice.api.messages.SignalServiceEnvelope
import org.whispersystems.signalservice.api.push.ServiceId
import org.whispersystems.signalservice.internal.push.OutgoingPushMessage
import org.whispersystems.signalservice.internal.push.SignalServiceProtos
import org.whispersystems.util.Base64
import java.util.Optional
import java.util.UUID
object FakeClientHelpers {
val noOpCertificateValidator = object : CertificateValidator(null) {
override fun validate(certificate: SenderCertificate, validationTime: Long) = Unit
}
fun createCertificateFor(trustRoot: ECKeyPair, uuid: UUID, e164: String, deviceId: Int, identityKey: ECPublicKey, expires: Long): SenderCertificate {
val serverKey: ECKeyPair = Curve.generateKeyPair()
NativeHandleGuard(serverKey.publicKey).use { serverPublicGuard ->
NativeHandleGuard(trustRoot.privateKey).use { trustRootPrivateGuard ->
val serverCertificate = ServerCertificate(Native.ServerCertificate_New(1, serverPublicGuard.nativeHandle(), trustRootPrivateGuard.nativeHandle()))
NativeHandleGuard(identityKey).use { identityGuard ->
NativeHandleGuard(serverCertificate).use { serverCertificateGuard ->
NativeHandleGuard(serverKey.privateKey).use { serverPrivateGuard ->
return SenderCertificate(Native.SenderCertificate_New(uuid.toString(), e164, deviceId, identityGuard.nativeHandle(), expires, serverCertificateGuard.nativeHandle(), serverPrivateGuard.nativeHandle()))
}
}
}
}
}
}
fun getTargetUnidentifiedAccess(myProfileKey: ProfileKey, theirProfileKey: ProfileKey, senderCertificate: SenderCertificate): Optional<UnidentifiedAccess> {
val selfUnidentifiedAccessKey = UnidentifiedAccess.deriveAccessKeyFrom(myProfileKey)
val themUnidentifiedAccessKey = UnidentifiedAccess.deriveAccessKeyFrom(theirProfileKey)
return UnidentifiedAccessPair(UnidentifiedAccess(selfUnidentifiedAccessKey, senderCertificate.serialized), UnidentifiedAccess(themUnidentifiedAccessKey, senderCertificate.serialized)).targetUnidentifiedAccess
}
fun encryptedTextMessage(now: Long, message: String = "Test body message"): EnvelopeContent {
val content = SignalServiceProtos.Content.newBuilder().apply {
setDataMessage(
SignalServiceProtos.DataMessage.newBuilder().apply {
body = message
timestamp = now
}
)
}
return EnvelopeContent.encrypted(content.build(), ContentHint.RESENDABLE, Optional.empty())
}
fun OutgoingPushMessage.toSignalServiceEnvelope(timestamp: Long, destination: ServiceId): SignalServiceEnvelope {
return SignalServiceEnvelope(
this.type,
Optional.empty(),
1,
timestamp,
Base64.decode(this.content),
timestamp + 1,
timestamp + 2,
UUID.randomUUID().toString(),
destination.toString(),
true,
false,
null
)
}
}

View File

@@ -0,0 +1,86 @@
package org.thoughtcrime.securesms.testing
import org.signal.core.util.concurrent.SignalExecutors
import org.signal.core.util.logging.Log
import java.util.concurrent.CountDownLatch
typealias LogPredicate = (Entry) -> Boolean
/**
* Logging implementation that holds logs in memory as they are added to be retrieve at a later time by a test.
* Can also be used for multithreaded synchronization and waiting until certain logs are emitted before continuing
* a test.
*/
class InMemoryLogger : Log.Logger() {
private val executor = SignalExecutors.newCachedSingleThreadExecutor("inmemory-logger")
private val predicates = mutableListOf<LogPredicate>()
private val logEntries = mutableListOf<Entry>()
override fun v(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) = add(Verbose(tag, message, t, System.currentTimeMillis()))
override fun d(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) = add(Debug(tag, message, t, System.currentTimeMillis()))
override fun i(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) = add(Info(tag, message, t, System.currentTimeMillis()))
override fun w(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) = add(Warn(tag, message, t, System.currentTimeMillis()))
override fun e(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) = add(Error(tag, message, t, System.currentTimeMillis()))
override fun flush() {
val latch = CountDownLatch(1)
executor.execute { latch.countDown() }
latch.await()
}
private fun add(entry: Entry) {
executor.execute {
logEntries += entry
val iterator = predicates.iterator()
while (iterator.hasNext()) {
val predicate = iterator.next()
if (predicate(entry)) {
iterator.remove()
}
}
}
}
/** Blocks until a snapshot of all log entries can be taken in a thread-safe way. */
fun entries(): List<Entry> {
val latch = CountDownLatch(1)
var entries: List<Entry> = emptyList()
executor.execute {
entries = logEntries.toList()
latch.countDown()
}
latch.await()
return entries
}
/** Returns a countdown latch that'll fire at a future point when an [Entry] is received that matches the predicate. */
fun getLockForUntil(predicate: LogPredicate): CountDownLatch {
val latch = CountDownLatch(1)
executor.execute {
predicates += { entry ->
if (predicate(entry)) {
latch.countDown()
true
} else {
false
}
}
}
return latch
}
}
sealed interface Entry {
val tag: String
val message: String?
val throwable: Throwable?
val timestamp: Long
}
data class Verbose(override val tag: String, override val message: String?, override val throwable: Throwable?, override val timestamp: Long) : Entry
data class Debug(override val tag: String, override val message: String?, override val throwable: Throwable?, override val timestamp: Long) : Entry
data class Info(override val tag: String, override val message: String?, override val throwable: Throwable?, override val timestamp: Long) : Entry
data class Warn(override val tag: String, override val message: String?, override val throwable: Throwable?, override val timestamp: Long) : Entry
data class Error(override val tag: String, override val message: String?, override val throwable: Throwable?, override val timestamp: Long) : Entry

View File

@@ -11,7 +11,9 @@ import androidx.test.platform.app.InstrumentationRegistry
import okhttp3.mockwebserver.MockResponse
import org.junit.rules.ExternalResource
import org.signal.libsignal.protocol.IdentityKey
import org.signal.libsignal.protocol.IdentityKeyPair
import org.signal.libsignal.protocol.SignalProtocolAddress
import org.thoughtcrime.securesms.SignalInstrumentationApplicationContext
import org.thoughtcrime.securesms.crypto.IdentityKeyUtil
import org.thoughtcrime.securesms.crypto.MasterSecretUtil
import org.thoughtcrime.securesms.crypto.ProfileKeyUtil
@@ -54,11 +56,18 @@ class SignalActivityRule(private val othersCount: Int = 4) : ExternalResource()
private set
lateinit var others: List<RecipientId>
private set
lateinit var othersKeys: List<IdentityKeyPair>
val inMemoryLogger: InMemoryLogger
get() = (application as SignalInstrumentationApplicationContext).inMemoryLogger
override fun before() {
context = InstrumentationRegistry.getInstrumentation().targetContext
self = setupSelf()
others = setupOthers()
val setupOthers = setupOthers()
others = setupOthers.first
othersKeys = setupOthers.second
InstrumentationApplicationDependencyProvider.clearHandlers()
}
@@ -99,8 +108,9 @@ class SignalActivityRule(private val othersCount: Int = 4) : ExternalResource()
return Recipient.self()
}
private fun setupOthers(): List<RecipientId> {
private fun setupOthers(): Pair<List<RecipientId>, List<IdentityKeyPair>> {
val others = mutableListOf<RecipientId>()
val othersKeys = mutableListOf<IdentityKeyPair>()
if (othersCount !in 0 until 1000) {
throw IllegalArgumentException("$othersCount must be between 0 and 1000")
@@ -114,11 +124,13 @@ class SignalActivityRule(private val othersCount: Int = 4) : ExternalResource()
SignalDatabase.recipients.setCapabilities(recipientId, SignalServiceProfile.Capabilities(true, true, true, true, true, true, true, true, true))
SignalDatabase.recipients.setProfileSharing(recipientId, true)
SignalDatabase.recipients.markRegistered(recipientId, aci)
ApplicationDependencies.getProtocolStore().aci().saveIdentity(SignalProtocolAddress(aci.toString(), 0), IdentityKeyUtil.generateIdentityKeyPair().publicKey)
val otherIdentity = IdentityKeyUtil.generateIdentityKeyPair()
ApplicationDependencies.getProtocolStore().aci().saveIdentity(SignalProtocolAddress(aci.toString(), 0), otherIdentity.publicKey)
others += recipientId
othersKeys += otherIdentity
}
return others
return others to othersKeys
}
inline fun <reified T : Activity> launchActivity(initIntent: Intent.() -> Unit = {}): ActivityScenario<T> {

View File

@@ -7,6 +7,9 @@ import org.hamcrest.Matchers.not
import org.hamcrest.Matchers.notNullValue
import org.hamcrest.Matchers.nullValue
import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit
import java.util.concurrent.TimeoutException
import kotlin.time.Duration
/**
* Run the given [runnable] on a new thread and wait for it to finish.
@@ -44,3 +47,9 @@ infix fun <T : Any> T.assertIsNot(expected: T) {
infix fun <E, T : Collection<E>> T.assertIsSize(expected: Int) {
assertThat(this, hasSize(expected))
}
fun CountDownLatch.awaitFor(duration: Duration) {
if (!await(duration.inWholeMilliseconds, TimeUnit.MILLISECONDS)) {
throw TimeoutException("Latch await took longer than ${duration.inWholeMilliseconds}ms")
}
}