Reject last-use kyber key sets that we've seen before.

This commit is contained in:
Alex Hart
2025-10-01 16:08:01 -03:00
committed by Michelle Tang
parent 5324290fab
commit 1b9695cb98
12 changed files with 362 additions and 66 deletions

View File

@@ -9,15 +9,10 @@ import org.junit.Assert.assertEquals
import org.junit.Assert.assertNotNull
import org.junit.Assert.assertNull
import org.junit.Test
import org.signal.core.util.readToSingleObject
import org.signal.core.util.requireLongOrNull
import org.signal.core.util.select
import org.signal.core.util.update
import org.signal.libsignal.protocol.ecc.ECKeyPair
import org.signal.libsignal.protocol.kem.KEMKeyPair
import org.signal.libsignal.protocol.kem.KEMKeyType
import org.signal.libsignal.protocol.state.KyberPreKeyRecord
import org.whispersystems.signalservice.api.push.ServiceId
import org.signal.libsignal.protocol.ReusedBaseKeyException
import org.thoughtcrime.securesms.util.KyberPreKeysTestUtil.generateECPublicKey
import org.thoughtcrime.securesms.util.KyberPreKeysTestUtil.getStaleTime
import org.thoughtcrime.securesms.util.KyberPreKeysTestUtil.insertTestRecord
import org.whispersystems.signalservice.api.push.ServiceId.ACI
import org.whispersystems.signalservice.api.push.ServiceId.PNI
import java.util.UUID
@@ -142,42 +137,43 @@ class KyberPreKeyTableTest {
assertNotNull(getStaleTime(aci, 3))
}
private fun insertTestRecord(account: ServiceId, id: Int, staleTime: Long = 0, lastResort: Boolean = false) {
val kemKeyPair = KEMKeyPair.generate(KEMKeyType.KYBER_1024)
SignalDatabase.kyberPreKeys.insert(
serviceId = account,
keyId = id,
record = KyberPreKeyRecord(
id,
System.currentTimeMillis(),
kemKeyPair,
ECKeyPair.generate().privateKey.calculateSignature(kemKeyPair.publicKey.serialize())
),
lastResort = lastResort
@Test(expected = ReusedBaseKeyException::class)
fun handleMarkKyberPreKeyUsed_doesNotAllowDuplicateLastResortKeyEntries() {
insertTestRecord(aci, id = 1, staleTime = 10, lastResort = true)
val publicKey = generateECPublicKey()
SignalDatabase.kyberPreKeys.handleMarkKyberPreKeyUsed(
serviceId = aci,
kyberPreKeyId = 1,
signedPreKeyId = 1,
baseKey = publicKey
)
val count = SignalDatabase.rawDatabase
.update(KyberPreKeyTable.TABLE_NAME)
.values(KyberPreKeyTable.STALE_TIMESTAMP to staleTime)
.where("${KyberPreKeyTable.ACCOUNT_ID} = ? AND ${KyberPreKeyTable.KEY_ID} = $id", account.toAccountId())
.run()
assertEquals(1, count)
SignalDatabase.kyberPreKeys.handleMarkKyberPreKeyUsed(
serviceId = aci,
kyberPreKeyId = 1,
signedPreKeyId = 1,
baseKey = publicKey
)
}
private fun getStaleTime(account: ServiceId, id: Int): Long? {
return SignalDatabase.rawDatabase
.select(KyberPreKeyTable.STALE_TIMESTAMP)
.from(KyberPreKeyTable.TABLE_NAME)
.where("${KyberPreKeyTable.ACCOUNT_ID} = ? AND ${KyberPreKeyTable.KEY_ID} = $id", account.toAccountId())
.run()
.readToSingleObject { it.requireLongOrNull(KyberPreKeyTable.STALE_TIMESTAMP) }
}
@Test
fun handleMarkKyberPreKeyUsed_allowDuplicateNonLastResortKeyEntries() {
insertTestRecord(aci, id = 1, staleTime = 10, lastResort = false)
val publicKey = generateECPublicKey()
private fun ServiceId.toAccountId(): String {
return when (this) {
is ACI -> this.toString()
is PNI -> KyberPreKeyTable.PNI_ACCOUNT_ID
}
SignalDatabase.kyberPreKeys.handleMarkKyberPreKeyUsed(
serviceId = aci,
kyberPreKeyId = 1,
signedPreKeyId = 1,
baseKey = publicKey
)
SignalDatabase.kyberPreKeys.handleMarkKyberPreKeyUsed(
serviceId = aci,
kyberPreKeyId = 1,
signedPreKeyId = 1,
baseKey = publicKey
)
}
}

View File

@@ -0,0 +1,79 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.thoughtcrime.securesms.messages.protocol
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.signal.libsignal.protocol.ReusedBaseKeyException
import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.thoughtcrime.securesms.testing.SignalDatabaseRule
import org.thoughtcrime.securesms.util.KyberPreKeysTestUtil
import org.whispersystems.signalservice.api.push.ServiceId
class BufferedKyberPreKeyStoreTest {
@get:Rule
val harness = SignalDatabaseRule()
private lateinit var aci: ServiceId
private lateinit var testSubject: BufferedKyberPreKeyStore
private lateinit var dataStore: BufferedSignalServiceAccountDataStore
@Before
fun setUp() {
SignalStore.account.generateAciIdentityKeyIfNecessary()
aci = harness.localAci
testSubject = BufferedKyberPreKeyStore(aci)
dataStore = BufferedSignalServiceAccountDataStore(aci)
}
@Test
fun givenALastResortKey_whenIMarkKyberPreKeyUsed_thenIExpectNoIssues() {
KyberPreKeysTestUtil.insertTestRecord(aci, 1, lastResort = true)
val publicKey = KyberPreKeysTestUtil.generateECPublicKey()
testSubject.markKyberPreKeyUsed(
kyberPreKeyId = 1,
signedPreKeyId = 2,
publicKey = publicKey
)
}
@Test(expected = ReusedBaseKeyException::class)
fun givenALastResortKey_whenIMarkKyberPreKeyUsedTwice_thenIExpectException() {
KyberPreKeysTestUtil.insertTestRecord(aci, 1, lastResort = true)
val publicKey = KyberPreKeysTestUtil.generateECPublicKey()
testSubject.markKyberPreKeyUsed(
kyberPreKeyId = 1,
signedPreKeyId = 2,
publicKey = publicKey
)
testSubject.markKyberPreKeyUsed(
kyberPreKeyId = 1,
signedPreKeyId = 2,
publicKey = publicKey
)
}
@Test
fun givenAMarkedLastResortKey_whenIFlushTwice_thenIExpectNoIssues() {
KyberPreKeysTestUtil.insertTestRecord(aci, 1, lastResort = true)
val publicKey = KyberPreKeysTestUtil.generateECPublicKey()
testSubject.markKyberPreKeyUsed(
kyberPreKeyId = 1,
signedPreKeyId = 2,
publicKey = publicKey
)
testSubject.flushToDisk(dataStore)
testSubject.flushToDisk(dataStore)
}
}

View File

@@ -0,0 +1,71 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.thoughtcrime.securesms.util
import org.junit.Assert.assertEquals
import org.signal.core.util.readToSingleObject
import org.signal.core.util.requireLongOrNull
import org.signal.core.util.select
import org.signal.core.util.update
import org.signal.libsignal.protocol.ecc.ECKeyPair
import org.signal.libsignal.protocol.ecc.ECPublicKey
import org.signal.libsignal.protocol.kem.KEMKeyPair
import org.signal.libsignal.protocol.kem.KEMKeyType
import org.signal.libsignal.protocol.state.KyberPreKeyRecord
import org.thoughtcrime.securesms.database.KyberPreKeyTable
import org.thoughtcrime.securesms.database.SignalDatabase
import org.whispersystems.signalservice.api.push.ServiceId
import org.whispersystems.signalservice.api.push.ServiceId.ACI
import org.whispersystems.signalservice.api.push.ServiceId.PNI
import java.security.SecureRandom
object KyberPreKeysTestUtil {
fun insertTestRecord(account: ServiceId, id: Int, staleTime: Long = 0, lastResort: Boolean = false) {
val kemKeyPair = KEMKeyPair.generate(KEMKeyType.KYBER_1024)
SignalDatabase.kyberPreKeys.insert(
serviceId = account,
keyId = id,
record = KyberPreKeyRecord(
id,
System.currentTimeMillis(),
kemKeyPair,
ECKeyPair.generate().privateKey.calculateSignature(kemKeyPair.publicKey.serialize())
),
lastResort = lastResort
)
val count = SignalDatabase.rawDatabase
.update(KyberPreKeyTable.TABLE_NAME)
.values(KyberPreKeyTable.STALE_TIMESTAMP to staleTime)
.where("${KyberPreKeyTable.ACCOUNT_ID} = ? AND ${KyberPreKeyTable.KEY_ID} = $id", account.toAccountId())
.run()
assertEquals(1, count)
}
fun getStaleTime(account: ServiceId, id: Int): Long? {
return SignalDatabase.rawDatabase
.select(KyberPreKeyTable.STALE_TIMESTAMP)
.from(KyberPreKeyTable.TABLE_NAME)
.where("${KyberPreKeyTable.ACCOUNT_ID} = ? AND ${KyberPreKeyTable.KEY_ID} = $id", account.toAccountId())
.run()
.readToSingleObject { it.requireLongOrNull(KyberPreKeyTable.STALE_TIMESTAMP) }
}
fun generateECPublicKey(): ECPublicKey {
val byteArray = ByteArray(ECPublicKey.KEY_SIZE - 1)
SecureRandom().nextBytes(byteArray)
return ECPublicKey.fromPublicKeyBytes(byteArray)
}
private fun ServiceId.toAccountId(): String {
return when (this) {
is ACI -> this.toString()
is PNI -> KyberPreKeyTable.PNI_ACCOUNT_ID
}
}
}