Reject last-use kyber key sets that we've seen before.

This commit is contained in:
Alex Hart
2025-10-01 16:08:01 -03:00
committed by Michelle Tang
parent 5324290fab
commit 1b9695cb98
12 changed files with 362 additions and 66 deletions

View File

@@ -6,6 +6,8 @@
package org.thoughtcrime.securesms.messages.protocol
import org.signal.libsignal.protocol.InvalidKeyIdException
import org.signal.libsignal.protocol.ReusedBaseKeyException
import org.signal.libsignal.protocol.ecc.ECPublicKey
import org.signal.libsignal.protocol.state.KyberPreKeyRecord
import org.thoughtcrime.securesms.database.KyberPreKeyTable.KyberPreKey
import org.thoughtcrime.securesms.database.SignalDatabase
@@ -25,7 +27,13 @@ class BufferedKyberPreKeyStore(private val selfServiceId: ServiceId) : SignalSer
private var hasLoadedAll: Boolean = false
/** The kyber prekeys that have been marked as removed (if they're not last resort). */
private val removedIfNotLastResort: MutableSet<Int> = mutableSetOf()
private val removedIfNotLastResort: MutableSet<Triple<Int, Int, ECPublicKey>> = mutableSetOf()
/** Tuples of last-resort key data we've already seen. */
private val lastResortKeyTuples: MutableSet<Triple<Int, Int, ECPublicKey>> = mutableSetOf()
/** A separate list of tuples to flush so we don't try to flush the same one multiple times */
private val unFlushedLastResortKeyTuples: MutableSet<Triple<Int, Int, ECPublicKey>> = mutableSetOf()
@kotlin.jvm.Throws(InvalidKeyIdException::class)
override fun loadKyberPreKey(kyberPreKeyId: Int): KyberPreKeyRecord {
@@ -63,16 +71,21 @@ class BufferedKyberPreKeyStore(private val selfServiceId: ServiceId) : SignalSer
return store.containsKey(kyberPreKeyId)
}
override fun markKyberPreKeyUsed(kyberPreKeyId: Int) {
override fun markKyberPreKeyUsed(kyberPreKeyId: Int, signedPreKeyId: Int, publicKey: ECPublicKey) {
loadKyberPreKey(kyberPreKeyId)
store[kyberPreKeyId]?.let {
if (!it.lastResort) {
store.remove(kyberPreKeyId)
removedIfNotLastResort += Triple(kyberPreKeyId, signedPreKeyId, publicKey)
} else {
if (!lastResortKeyTuples.add(Triple(kyberPreKeyId, signedPreKeyId, publicKey))) {
throw ReusedBaseKeyException()
}
unFlushedLastResortKeyTuples += Triple(kyberPreKeyId, signedPreKeyId, publicKey)
}
}
removedIfNotLastResort += kyberPreKeyId
}
override fun removeKyberPreKey(kyberPreKeyId: Int) {
@@ -88,8 +101,13 @@ class BufferedKyberPreKeyStore(private val selfServiceId: ServiceId) : SignalSer
}
fun flushToDisk(persistentStore: SignalServiceAccountDataStore) {
for (id in removedIfNotLastResort) {
persistentStore.markKyberPreKeyUsed(id)
val tuples = removedIfNotLastResort + unFlushedLastResortKeyTuples
unFlushedLastResortKeyTuples.clear()
for ((key, signedKey, publicKey) in tuples) {
persistentStore.markKyberPreKeyUsed(key, signedKey, publicKey)
}
unFlushedLastResortKeyTuples.clear()
}
}

View File

@@ -1,8 +1,10 @@
package org.thoughtcrime.securesms.messages.protocol
import org.signal.core.util.withinTransaction
import org.signal.libsignal.protocol.IdentityKey
import org.signal.libsignal.protocol.IdentityKeyPair
import org.signal.libsignal.protocol.SignalProtocolAddress
import org.signal.libsignal.protocol.ecc.ECPublicKey
import org.signal.libsignal.protocol.groups.state.SenderKeyRecord
import org.signal.libsignal.protocol.state.IdentityKeyStore
import org.signal.libsignal.protocol.state.IdentityKeyStore.IdentityChange
@@ -10,6 +12,7 @@ import org.signal.libsignal.protocol.state.KyberPreKeyRecord
import org.signal.libsignal.protocol.state.PreKeyRecord
import org.signal.libsignal.protocol.state.SessionRecord
import org.signal.libsignal.protocol.state.SignedPreKeyRecord
import org.thoughtcrime.securesms.database.SignalDatabase
import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.whispersystems.signalservice.api.SignalServiceAccountDataStore
import org.whispersystems.signalservice.api.push.DistributionId
@@ -138,8 +141,8 @@ class BufferedSignalServiceAccountDataStore(selfServiceId: ServiceId) : SignalSe
return kyberPreKeyStore.containsKyberPreKey(kyberPreKeyId)
}
override fun markKyberPreKeyUsed(kyberPreKeyId: Int) {
return kyberPreKeyStore.markKyberPreKeyUsed(kyberPreKeyId)
override fun markKyberPreKeyUsed(kyberPreKeyId: Int, signedPreKeyId: Int, publicKey: ECPublicKey) {
return kyberPreKeyStore.markKyberPreKeyUsed(kyberPreKeyId, signedPreKeyId, publicKey)
}
override fun deleteAllStaleOneTimeEcPreKeys(threshold: Long, minCount: Int) {
@@ -199,11 +202,13 @@ class BufferedSignalServiceAccountDataStore(selfServiceId: ServiceId) : SignalSe
}
fun flushToDisk(persistentStore: SignalServiceAccountDataStore) {
identityStore.flushToDisk(persistentStore)
oneTimePreKeyStore.flushToDisk(persistentStore)
kyberPreKeyStore.flushToDisk(persistentStore)
signedPreKeyStore.flushToDisk(persistentStore)
sessionStore.flushToDisk(persistentStore)
senderKeyStore.flushToDisk(persistentStore)
SignalDatabase.writableDatabase.withinTransaction {
identityStore.flushToDisk(persistentStore)
oneTimePreKeyStore.flushToDisk(persistentStore)
kyberPreKeyStore.flushToDisk(persistentStore)
signedPreKeyStore.flushToDisk(persistentStore)
sessionStore.flushToDisk(persistentStore)
senderKeyStore.flushToDisk(persistentStore)
}
}
}