Add microbenchmarks for message decryption.

This commit is contained in:
Greyson Parrelli
2023-04-06 10:04:37 -04:00
parent 0156e74f5a
commit 6d4906dfa8
16 changed files with 611 additions and 13 deletions

View File

@@ -0,0 +1,89 @@
package org.signal.microbenchmark
import android.util.Log
import androidx.benchmark.junit4.BenchmarkRule
import androidx.benchmark.junit4.measureRepeated
import androidx.test.ext.junit.runners.AndroidJUnit4
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.signal.libsignal.protocol.logging.SignalProtocolLogger
import org.signal.libsignal.protocol.logging.SignalProtocolLoggerProvider
import org.signal.util.SignalClient
/**
* Benchmarks for decrypting messages.
*
* Note that in order to isolate all costs to just the process of decryption itself,
* all operations are performed in in-memory stores.
*/
@RunWith(AndroidJUnit4::class)
class ProtocolBenchmarks {
@get:Rule
val benchmarkRule = BenchmarkRule()
@Before
fun setup() {
SignalProtocolLoggerProvider.setProvider { priority, tag, message ->
when (priority) {
SignalProtocolLogger.VERBOSE -> Log.v(tag, message)
SignalProtocolLogger.DEBUG -> Log.d(tag, message)
SignalProtocolLogger.INFO -> Log.i(tag, message)
SignalProtocolLogger.WARN -> Log.w(tag, message)
SignalProtocolLogger.ERROR -> Log.w(tag, message)
SignalProtocolLogger.ASSERT -> Log.e(tag, message)
}
}
}
@Test
fun decrypt_unsealedSender() {
val (alice, bob) = buildAndInitializeClients()
benchmarkRule.measureRepeated {
val envelope = runWithTimingDisabled {
alice.encryptUnsealedSender(bob)
}
bob.decryptMessage(envelope)
// Respond so that the session ratchets
runWithTimingDisabled {
alice.decryptMessage(bob.encryptUnsealedSender(alice))
}
}
}
@Test
fun decrypt_sealedSender() {
val (alice, bob) = buildAndInitializeClients()
benchmarkRule.measureRepeated {
val envelope = runWithTimingDisabled {
alice.encryptSealedSender(bob)
}
bob.decryptMessage(envelope)
// Respond so that the session ratchets
runWithTimingDisabled {
alice.decryptMessage(bob.encryptSealedSender(alice))
}
}
}
private fun buildAndInitializeClients(): Pair<SignalClient, SignalClient> {
val alice = SignalClient()
val bob = SignalClient()
// Do initial prekey dance
alice.initializeSession(bob)
bob.initializeSession(alice)
alice.decryptMessage(bob.encryptUnsealedSender(alice))
bob.decryptMessage(alice.encryptUnsealedSender(bob))
return alice to bob
}
}

View File

@@ -0,0 +1,157 @@
package org.signal.util
import org.signal.libsignal.protocol.IdentityKey
import org.signal.libsignal.protocol.IdentityKeyPair
import org.signal.libsignal.protocol.SignalProtocolAddress
import org.signal.libsignal.protocol.groups.state.SenderKeyRecord
import org.signal.libsignal.protocol.message.CiphertextMessage
import org.signal.libsignal.protocol.state.IdentityKeyStore
import org.signal.libsignal.protocol.state.PreKeyRecord
import org.signal.libsignal.protocol.state.SessionRecord
import org.signal.libsignal.protocol.state.SignedPreKeyRecord
import org.whispersystems.signalservice.api.SignalServiceAccountDataStore
import org.whispersystems.signalservice.api.push.DistributionId
import java.util.UUID
/**
* An in-memory datastore specifically designed for tests.
*/
class InMemorySignalServiceAccountDataStore : SignalServiceAccountDataStore {
private val identityKey: IdentityKeyPair = IdentityKeyPair.generate()
private val identities: MutableMap<SignalProtocolAddress, IdentityKey> = mutableMapOf()
private val oneTimePreKeys: MutableMap<Int, PreKeyRecord> = mutableMapOf()
private val signedPreKeys: MutableMap<Int, SignedPreKeyRecord> = mutableMapOf()
private var sessions: MutableMap<SignalProtocolAddress, SessionRecord> = mutableMapOf()
private val senderKeys: MutableMap<SenderKeyLocator, SenderKeyRecord> = mutableMapOf()
override fun getIdentityKeyPair(): IdentityKeyPair {
return identityKey
}
override fun getLocalRegistrationId(): Int {
return 1
}
override fun saveIdentity(address: SignalProtocolAddress, identityKey: IdentityKey): Boolean {
val hadPrevious = identities.containsKey(address)
identities[address] = identityKey
return hadPrevious
}
override fun isTrustedIdentity(address: SignalProtocolAddress?, identityKey: IdentityKey?, direction: IdentityKeyStore.Direction?): Boolean {
return true
}
override fun getIdentity(address: SignalProtocolAddress): IdentityKey? {
return identities[address]
}
override fun loadPreKey(preKeyId: Int): PreKeyRecord {
return oneTimePreKeys[preKeyId]!!
}
override fun storePreKey(preKeyId: Int, record: PreKeyRecord) {
oneTimePreKeys[preKeyId] = record
}
override fun containsPreKey(preKeyId: Int): Boolean {
return oneTimePreKeys.containsKey(preKeyId)
}
override fun removePreKey(preKeyId: Int) {
oneTimePreKeys.remove(preKeyId)
}
override fun loadSession(address: SignalProtocolAddress): SessionRecord {
return sessions.getOrPut(address) { SessionRecord() }
}
override fun loadExistingSessions(addresses: List<SignalProtocolAddress>): List<SessionRecord> {
return addresses.map { sessions[it]!! }
}
override fun getSubDeviceSessions(name: String): List<Int> {
return sessions
.filter { it.key.name == name && it.key.deviceId != 1 && it.value.isValid() }
.map { it.key.deviceId }
}
override fun storeSession(address: SignalProtocolAddress, record: SessionRecord) {
sessions[address] = record
}
override fun containsSession(address: SignalProtocolAddress): Boolean {
return sessions[address]?.isValid() ?: false
}
override fun deleteSession(address: SignalProtocolAddress) {
sessions -= address
}
override fun deleteAllSessions(name: String) {
sessions = sessions.filter { it.key.name == name }.toMutableMap()
}
override fun loadSignedPreKey(signedPreKeyId: Int): SignedPreKeyRecord {
return signedPreKeys[signedPreKeyId]!!
}
override fun loadSignedPreKeys(): List<SignedPreKeyRecord> {
return signedPreKeys.values.toList()
}
override fun storeSignedPreKey(signedPreKeyId: Int, record: SignedPreKeyRecord) {
signedPreKeys[signedPreKeyId] = record
}
override fun containsSignedPreKey(signedPreKeyId: Int): Boolean {
return signedPreKeys.containsKey(signedPreKeyId)
}
override fun removeSignedPreKey(signedPreKeyId: Int) {
signedPreKeys -= signedPreKeyId
}
override fun storeSenderKey(sender: SignalProtocolAddress, distributionId: UUID, record: SenderKeyRecord) {
senderKeys[SenderKeyLocator(sender, distributionId)] = record
}
override fun loadSenderKey(sender: SignalProtocolAddress, distributionId: UUID): SenderKeyRecord {
return senderKeys[SenderKeyLocator(sender, distributionId)]!!
}
override fun archiveSession(address: SignalProtocolAddress) {
sessions[address]!!.archiveCurrentState()
}
override fun getAllAddressesWithActiveSessions(addressNames: MutableList<String>): Set<SignalProtocolAddress> {
return sessions
.filter { it.key.name in addressNames }
.filter { it.value.isValid() }
.map { it.key }
.toSet()
}
override fun getSenderKeySharedWith(distributionId: DistributionId): Set<SignalProtocolAddress> {
error("Not used")
}
override fun markSenderKeySharedWith(distributionId: DistributionId, addresses: Collection<SignalProtocolAddress>) {
// Not used
}
override fun clearSenderKeySharedWith(addresses: Collection<SignalProtocolAddress>) {
// Not used
}
override fun isMultiDevice(): Boolean {
return false
}
private fun SessionRecord.isValid(): Boolean {
return this.hasSenderChain() && this.sessionVersion == CiphertextMessage.CURRENT_VERSION
}
private data class SenderKeyLocator(val address: SignalProtocolAddress, val distributionId: UUID)
}

View File

@@ -0,0 +1,175 @@
package org.signal.util
import com.google.protobuf.ByteString
import org.signal.libsignal.internal.Native
import org.signal.libsignal.internal.NativeHandleGuard
import org.signal.libsignal.metadata.certificate.CertificateValidator
import org.signal.libsignal.metadata.certificate.SenderCertificate
import org.signal.libsignal.metadata.certificate.ServerCertificate
import org.signal.libsignal.protocol.SessionBuilder
import org.signal.libsignal.protocol.SignalProtocolAddress
import org.signal.libsignal.protocol.ecc.Curve
import org.signal.libsignal.protocol.ecc.ECKeyPair
import org.signal.libsignal.protocol.ecc.ECPublicKey
import org.signal.libsignal.protocol.state.PreKeyBundle
import org.signal.libsignal.protocol.state.PreKeyRecord
import org.signal.libsignal.protocol.state.SignedPreKeyRecord
import org.whispersystems.signalservice.api.SignalServiceAccountDataStore
import org.whispersystems.signalservice.api.SignalSessionLock
import org.whispersystems.signalservice.api.crypto.ContentHint
import org.whispersystems.signalservice.api.crypto.EnvelopeContent
import org.whispersystems.signalservice.api.crypto.SignalServiceCipher
import org.whispersystems.signalservice.api.crypto.UnidentifiedAccess
import org.whispersystems.signalservice.api.push.ServiceId
import org.whispersystems.signalservice.api.push.SignalServiceAddress
import org.whispersystems.signalservice.internal.push.OutgoingPushMessage
import org.whispersystems.signalservice.internal.push.SignalServiceProtos
import org.whispersystems.signalservice.internal.util.Util
import org.whispersystems.util.Base64
import java.util.Optional
import java.util.UUID
import java.util.concurrent.locks.ReentrantLock
import kotlin.random.Random
/**
* An in-memory signal client that can encrypt and decrypt messages.
*
* Has a single prekey bundle that can be used to initialize a session with another client.
*/
class SignalClient {
companion object {
private val trustRoot: ECKeyPair = Curve.generateKeyPair()
}
private val serviceId: ServiceId = ServiceId.from(UUID.randomUUID())
private val store: SignalServiceAccountDataStore = InMemorySignalServiceAccountDataStore()
private val preKeyBundle: PreKeyBundle = let {
val preKeyRecord = PreKeyRecord(1, Curve.generateKeyPair())
val signedPreKeyPair = Curve.generateKeyPair()
val signedPreKeySignature = Curve.calculateSignature(store.identityKeyPair.privateKey, signedPreKeyPair.publicKey.serialize())
store.storePreKey(1, preKeyRecord)
store.storeSignedPreKey(1, SignedPreKeyRecord(1, System.currentTimeMillis(), signedPreKeyPair, signedPreKeySignature))
PreKeyBundle(1, 1, 1, preKeyRecord.keyPair.publicKey, 1, signedPreKeyPair.publicKey, signedPreKeySignature, store.identityKeyPair.publicKey)
}
private val unidentifiedAccessKey: ByteArray = Util.getSecretBytes(32)
private val senderCertificate: SenderCertificate = createCertificateFor(
trustRoot = trustRoot,
uuid = serviceId.uuid(),
e164 = "+${Random.nextLong(1111111111L, 9999999999L)}",
deviceId = 1,
identityKey = store.identityKeyPair.publicKey.publicKey,
expires = Long.MAX_VALUE
)
private val cipher = SignalServiceCipher(SignalServiceAddress(serviceId), 1, store, TestSessionLock(), CertificateValidator(trustRoot.publicKey))
/**
* Sets up sessions using the [to] client's [preKeyBundle]. Note that you can only initialize a client once
* since we currently only make a single prekey bundle.
*/
fun initializeSession(to: SignalClient) {
val address = SignalProtocolAddress(to.serviceId.toString(), 1)
SessionBuilder(store, address).process(to.preKeyBundle)
}
fun encryptUnsealedSender(to: SignalClient): SignalServiceProtos.Envelope {
val sentTimestamp = System.currentTimeMillis()
val message = SignalServiceProtos.DataMessage.newBuilder()
.setBody("Test Message")
.setTimestamp(sentTimestamp)
.build()
val content = SignalServiceProtos.Content.newBuilder()
.setDataMessage(message)
.build()
val outgoingPushMessage: OutgoingPushMessage = cipher.encrypt(
SignalProtocolAddress(to.serviceId.toString(), 1),
Optional.empty(),
EnvelopeContent.encrypted(content, ContentHint.RESENDABLE, Optional.empty())
)
val encryptedContent: ByteArray = Base64.decode(outgoingPushMessage.content)
return SignalServiceProtos.Envelope.newBuilder()
.setSourceUuid(serviceId.toString())
.setSourceDevice(1)
.setDestinationUuid(to.serviceId.toString())
.setTimestamp(sentTimestamp)
.setServerTimestamp(sentTimestamp)
.setServerGuid(UUID.randomUUID().toString())
.setType(SignalServiceProtos.Envelope.Type.valueOf(outgoingPushMessage.type))
.setUrgent(true)
.setContent(ByteString.copyFrom(encryptedContent))
.build()
}
fun encryptSealedSender(to: SignalClient): SignalServiceProtos.Envelope {
val sentTimestamp = System.currentTimeMillis()
val message = SignalServiceProtos.DataMessage.newBuilder()
.setBody("Test Message")
.setTimestamp(sentTimestamp)
.build()
val content = SignalServiceProtos.Content.newBuilder()
.setDataMessage(message)
.build()
val outgoingPushMessage: OutgoingPushMessage = cipher.encrypt(
SignalProtocolAddress(to.serviceId.toString(), 1),
Optional.of(UnidentifiedAccess(to.unidentifiedAccessKey, senderCertificate.serialized)),
EnvelopeContent.encrypted(content, ContentHint.RESENDABLE, Optional.empty())
)
val encryptedContent: ByteArray = Base64.decode(outgoingPushMessage.content)
return SignalServiceProtos.Envelope.newBuilder()
.setSourceUuid(serviceId.toString())
.setSourceDevice(1)
.setDestinationUuid(to.serviceId.toString())
.setTimestamp(sentTimestamp)
.setServerTimestamp(sentTimestamp)
.setServerGuid(UUID.randomUUID().toString())
.setType(SignalServiceProtos.Envelope.Type.valueOf(outgoingPushMessage.type))
.setUrgent(true)
.setContent(ByteString.copyFrom(encryptedContent))
.build()
}
fun decryptMessage(envelope: SignalServiceProtos.Envelope) {
cipher.decrypt(envelope, System.currentTimeMillis())
}
}
private fun createCertificateFor(trustRoot: ECKeyPair, uuid: UUID, e164: String, deviceId: Int, identityKey: ECPublicKey, expires: Long): SenderCertificate {
val serverKey: ECKeyPair = Curve.generateKeyPair()
NativeHandleGuard(serverKey.publicKey).use { serverPublicGuard ->
NativeHandleGuard(trustRoot.privateKey).use { trustRootPrivateGuard ->
val serverCertificate = ServerCertificate(Native.ServerCertificate_New(1, serverPublicGuard.nativeHandle(), trustRootPrivateGuard.nativeHandle()))
NativeHandleGuard(identityKey).use { identityGuard ->
NativeHandleGuard(serverCertificate).use { serverCertificateGuard ->
NativeHandleGuard(serverKey.privateKey).use { serverPrivateGuard ->
return SenderCertificate(Native.SenderCertificate_New(uuid.toString(), e164, deviceId, identityGuard.nativeHandle(), expires, serverCertificateGuard.nativeHandle(), serverPrivateGuard.nativeHandle()))
}
}
}
}
}
}
private class TestSessionLock : SignalSessionLock {
val lock = ReentrantLock()
override fun acquire(): SignalSessionLock.Lock {
lock.lock()
return SignalSessionLock.Lock { lock.unlock() }
}
}