Clean up old one-time prekeys.

This commit is contained in:
Greyson Parrelli
2023-08-11 12:38:03 -04:00
committed by Cody Henthorne
parent 389b439e9a
commit d6adfea9b1
18 changed files with 572 additions and 11 deletions

View File

@@ -0,0 +1,176 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.thoughtcrime.securesms.database
import junit.framework.TestCase.assertEquals
import junit.framework.TestCase.assertNotNull
import junit.framework.TestCase.assertNull
import org.junit.Test
import org.signal.core.util.readToSingleObject
import org.signal.core.util.requireLongOrNull
import org.signal.core.util.select
import org.signal.core.util.update
import org.signal.libsignal.protocol.ecc.Curve
import org.signal.libsignal.protocol.kem.KEMKeyPair
import org.signal.libsignal.protocol.kem.KEMKeyType
import org.signal.libsignal.protocol.state.KyberPreKeyRecord
import org.whispersystems.signalservice.api.push.ServiceId
import org.whispersystems.signalservice.api.push.ServiceId.ACI
import org.whispersystems.signalservice.api.push.ServiceId.PNI
import java.util.UUID
class KyberPreKeyTableTest {
private val aci: ACI = ACI.from(UUID.randomUUID())
private val pni: PNI = PNI.from(UUID.randomUUID())
@Test
fun markAllStaleIfNecessary_onlyUpdatesMatchingAccountAndZeroValues() {
insertTestRecord(aci, id = 1)
insertTestRecord(aci, id = 2)
insertTestRecord(aci, id = 3, staleTime = 42)
insertTestRecord(pni, id = 4)
val now = System.currentTimeMillis()
SignalDatabase.kyberPreKeys.markAllStaleIfNecessary(aci, now)
assertEquals(now, getStaleTime(aci, 1))
assertEquals(now, getStaleTime(aci, 2))
assertEquals(42L, getStaleTime(aci, 3))
assertEquals(0L, getStaleTime(pni, 4))
}
@Test
fun deleteAllStaleBefore_deleteOldBeforeThreshold() {
insertTestRecord(aci, id = 1, staleTime = 10)
insertTestRecord(aci, id = 2, staleTime = 10)
insertTestRecord(aci, id = 3, staleTime = 10)
insertTestRecord(aci, id = 4, staleTime = 15)
insertTestRecord(aci, id = 5, staleTime = 0)
SignalDatabase.oneTimePreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 0)
assertNull(getStaleTime(aci, 1))
assertNull(getStaleTime(aci, 2))
assertNull(getStaleTime(aci, 3))
assertNotNull(getStaleTime(aci, 4))
assertNotNull(getStaleTime(aci, 5))
}
@Test
fun deleteAllStaleBefore_neverDeleteStaleOfZero() {
insertTestRecord(aci, id = 1, staleTime = 0)
insertTestRecord(aci, id = 2, staleTime = 0)
insertTestRecord(aci, id = 3, staleTime = 0)
insertTestRecord(aci, id = 4, staleTime = 0)
insertTestRecord(aci, id = 5, staleTime = 0)
SignalDatabase.kyberPreKeys.deleteAllStaleBefore(aci, threshold = 10, minCount = 1)
assertNotNull(getStaleTime(aci, 1))
assertNotNull(getStaleTime(aci, 2))
assertNotNull(getStaleTime(aci, 3))
assertNotNull(getStaleTime(aci, 4))
assertNotNull(getStaleTime(aci, 5))
}
@Test
fun deleteAllStaleBefore_respectMinCount() {
insertTestRecord(aci, id = 1, staleTime = 10)
insertTestRecord(aci, id = 2, staleTime = 10)
insertTestRecord(aci, id = 3, staleTime = 10)
insertTestRecord(aci, id = 4, staleTime = 10)
insertTestRecord(aci, id = 5, staleTime = 10)
SignalDatabase.kyberPreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 3)
assertNull(getStaleTime(aci, 1))
assertNull(getStaleTime(aci, 2))
assertNotNull(getStaleTime(aci, 3))
assertNotNull(getStaleTime(aci, 4))
assertNotNull(getStaleTime(aci, 5))
}
@Test
fun deleteAllStaleBefore_respectAccount() {
insertTestRecord(aci, id = 1, staleTime = 10)
insertTestRecord(aci, id = 2, staleTime = 10)
insertTestRecord(aci, id = 3, staleTime = 10)
insertTestRecord(pni, id = 4, staleTime = 10)
insertTestRecord(pni, id = 5, staleTime = 10)
SignalDatabase.kyberPreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 2)
assertNull(getStaleTime(aci, 1))
assertNotNull(getStaleTime(aci, 2))
assertNotNull(getStaleTime(aci, 3))
assertNotNull(getStaleTime(pni, 4))
assertNotNull(getStaleTime(pni, 5))
}
@Test
fun deleteAllStaleBefore_ignoreLastResortForMinCount() {
insertTestRecord(aci, id = 1, staleTime = 10)
insertTestRecord(aci, id = 2, staleTime = 10)
insertTestRecord(aci, id = 3, staleTime = 10)
insertTestRecord(aci, id = 4, staleTime = 10)
insertTestRecord(aci, id = 5, staleTime = 10, lastResort = true)
SignalDatabase.kyberPreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 3)
assertNull(getStaleTime(aci, 1))
assertNotNull(getStaleTime(aci, 2))
assertNotNull(getStaleTime(aci, 3))
assertNotNull(getStaleTime(aci, 4))
assertNotNull(getStaleTime(aci, 5))
}
@Test
fun deleteAllStaleBefore_neverDeleteLastResort() {
insertTestRecord(aci, id = 1, staleTime = 10, lastResort = true)
insertTestRecord(aci, id = 2, staleTime = 10, lastResort = true)
insertTestRecord(aci, id = 3, staleTime = 10, lastResort = true)
SignalDatabase.oneTimePreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 0)
assertNotNull(getStaleTime(aci, 1))
assertNotNull(getStaleTime(aci, 2))
assertNotNull(getStaleTime(aci, 3))
}
private fun insertTestRecord(account: ServiceId, id: Int, staleTime: Long = 0, lastResort: Boolean = false) {
val kemKeyPair = KEMKeyPair.generate(KEMKeyType.KYBER_1024)
SignalDatabase.kyberPreKeys.insert(
serviceId = account,
keyId = id,
record = KyberPreKeyRecord(
id,
System.currentTimeMillis(),
kemKeyPair,
Curve.generateKeyPair().privateKey.calculateSignature(kemKeyPair.publicKey.serialize())
),
lastResort = lastResort
)
val count = SignalDatabase.rawDatabase
.update(KyberPreKeyTable.TABLE_NAME)
.values(KyberPreKeyTable.STALE_TIMESTAMP to staleTime)
.where("${KyberPreKeyTable.ACCOUNT_ID} = ? AND ${KyberPreKeyTable.KEY_ID} = $id", account)
.run()
assertEquals(1, count)
}
private fun getStaleTime(account: ServiceId, id: Int): Long? {
return SignalDatabase.rawDatabase
.select(KyberPreKeyTable.STALE_TIMESTAMP)
.from(KyberPreKeyTable.TABLE_NAME)
.where("${KyberPreKeyTable.ACCOUNT_ID} = ? AND ${KyberPreKeyTable.KEY_ID} = $id", account)
.run()
.readToSingleObject { it.requireLongOrNull(KyberPreKeyTable.STALE_TIMESTAMP) }
}
}

View File

@@ -0,0 +1,137 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.thoughtcrime.securesms.database
import junit.framework.TestCase.assertEquals
import junit.framework.TestCase.assertNotNull
import junit.framework.TestCase.assertNull
import org.junit.Test
import org.signal.core.util.readToSingleObject
import org.signal.core.util.requireLongOrNull
import org.signal.core.util.select
import org.signal.core.util.update
import org.signal.libsignal.protocol.ecc.Curve
import org.signal.libsignal.protocol.state.PreKeyRecord
import org.whispersystems.signalservice.api.push.ServiceId
import org.whispersystems.signalservice.api.push.ServiceId.ACI
import org.whispersystems.signalservice.api.push.ServiceId.PNI
import java.util.UUID
class OneTimePreKeyTableTest {
private val aci: ACI = ACI.from(UUID.randomUUID())
private val pni: PNI = PNI.from(UUID.randomUUID())
@Test
fun markAllStaleIfNecessary_onlyUpdatesMatchingAccountAndZeroValues() {
insertTestRecord(aci, id = 1)
insertTestRecord(aci, id = 2)
insertTestRecord(aci, id = 3, staleTime = 42)
insertTestRecord(pni, id = 4)
val now = System.currentTimeMillis()
SignalDatabase.oneTimePreKeys.markAllStaleIfNecessary(aci, now)
assertEquals(now, getStaleTime(aci, 1))
assertEquals(now, getStaleTime(aci, 2))
assertEquals(42L, getStaleTime(aci, 3))
assertEquals(0L, getStaleTime(pni, 4))
}
@Test
fun deleteAllStaleBefore_deleteOldBeforeThreshold() {
insertTestRecord(aci, id = 1, staleTime = 10)
insertTestRecord(aci, id = 2, staleTime = 10)
insertTestRecord(aci, id = 3, staleTime = 10)
insertTestRecord(aci, id = 4, staleTime = 15)
insertTestRecord(aci, id = 5, staleTime = 0)
SignalDatabase.oneTimePreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 0)
assertNull(getStaleTime(aci, 1))
assertNull(getStaleTime(aci, 2))
assertNull(getStaleTime(aci, 3))
assertNotNull(getStaleTime(aci, 4))
assertNotNull(getStaleTime(aci, 5))
}
@Test
fun deleteAllStaleBefore_neverDeleteStaleOfZero() {
insertTestRecord(aci, id = 1, staleTime = 0)
insertTestRecord(aci, id = 2, staleTime = 0)
insertTestRecord(aci, id = 3, staleTime = 0)
insertTestRecord(aci, id = 4, staleTime = 0)
insertTestRecord(aci, id = 5, staleTime = 0)
SignalDatabase.oneTimePreKeys.deleteAllStaleBefore(aci, threshold = 10, minCount = 0)
assertNotNull(getStaleTime(aci, 1))
assertNotNull(getStaleTime(aci, 2))
assertNotNull(getStaleTime(aci, 3))
assertNotNull(getStaleTime(aci, 4))
assertNotNull(getStaleTime(aci, 5))
}
@Test
fun deleteAllStaleBefore_respectMinCount() {
insertTestRecord(aci, id = 1, staleTime = 10)
insertTestRecord(aci, id = 2, staleTime = 10)
insertTestRecord(aci, id = 3, staleTime = 10)
insertTestRecord(aci, id = 4, staleTime = 10)
insertTestRecord(aci, id = 5, staleTime = 10)
SignalDatabase.oneTimePreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 3)
assertNull(getStaleTime(aci, 1))
assertNull(getStaleTime(aci, 2))
assertNotNull(getStaleTime(aci, 3))
assertNotNull(getStaleTime(aci, 4))
assertNotNull(getStaleTime(aci, 5))
}
@Test
fun deleteAllStaleBefore_respectAccount() {
insertTestRecord(aci, id = 1, staleTime = 10)
insertTestRecord(aci, id = 2, staleTime = 10)
insertTestRecord(aci, id = 3, staleTime = 10)
insertTestRecord(pni, id = 4, staleTime = 10)
insertTestRecord(pni, id = 5, staleTime = 10)
SignalDatabase.oneTimePreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 2)
assertNull(getStaleTime(aci, 1))
assertNotNull(getStaleTime(aci, 2))
assertNotNull(getStaleTime(aci, 3))
assertNotNull(getStaleTime(pni, 4))
assertNotNull(getStaleTime(pni, 5))
}
private fun insertTestRecord(account: ServiceId, id: Int, staleTime: Long = 0) {
SignalDatabase.oneTimePreKeys.insert(
serviceId = account,
keyId = id,
record = PreKeyRecord(id, Curve.generateKeyPair())
)
val count = SignalDatabase.rawDatabase
.update(OneTimePreKeyTable.TABLE_NAME)
.values(OneTimePreKeyTable.STALE_TIMESTAMP to staleTime)
.where("${OneTimePreKeyTable.ACCOUNT_ID} = ? AND ${OneTimePreKeyTable.KEY_ID} = $id", account)
.run()
assertEquals(1, count)
}
private fun getStaleTime(account: ServiceId, id: Int): Long? {
return SignalDatabase.rawDatabase
.select(OneTimePreKeyTable.STALE_TIMESTAMP)
.from(OneTimePreKeyTable.TABLE_NAME)
.where("${OneTimePreKeyTable.ACCOUNT_ID} = ? AND ${OneTimePreKeyTable.KEY_ID} = $id", account)
.run()
.readToSingleObject { it.requireLongOrNull(OneTimePreKeyTable.STALE_TIMESTAMP) }
}
}

View File

@@ -32,10 +32,10 @@ import org.whispersystems.signalservice.api.push.DistributionId
import org.whispersystems.signalservice.api.push.ServiceId
import org.whispersystems.signalservice.api.push.SignalServiceAddress
import org.whispersystems.signalservice.internal.push.SignalServiceProtos
import java.lang.UnsupportedOperationException
import java.util.Optional
import java.util.UUID
import java.util.concurrent.locks.ReentrantLock
import kotlin.UnsupportedOperationException
/**
* Welcome to Bob's Client.
@@ -144,7 +144,6 @@ class BobClient(val serviceId: ServiceId, val e164: String, val identityKeyPair:
override fun getSubDeviceSessions(name: String?): List<Int> = emptyList()
override fun containsSession(address: SignalProtocolAddress?): Boolean = aliceSessionRecord != null
override fun getIdentity(address: SignalProtocolAddress?): IdentityKey = SignalStore.account().aciIdentityKey.publicKey
override fun loadPreKey(preKeyId: Int): PreKeyRecord = throw UnsupportedOperationException()
override fun storePreKey(preKeyId: Int, record: PreKeyRecord?) = throw UnsupportedOperationException()
override fun containsPreKey(preKeyId: Int): Boolean = throw UnsupportedOperationException()
@@ -162,6 +161,8 @@ class BobClient(val serviceId: ServiceId, val e164: String, val identityKeyPair:
override fun storeKyberPreKey(kyberPreKeyId: Int, record: KyberPreKeyRecord?) = throw UnsupportedOperationException()
override fun containsKyberPreKey(kyberPreKeyId: Int): Boolean = throw UnsupportedOperationException()
override fun markKyberPreKeyUsed(kyberPreKeyId: Int) = throw UnsupportedOperationException()
override fun deleteAllStaleOneTimeEcPreKeys(threshold: Long, minCount: Int) = throw UnsupportedOperationException()
override fun markAllOneTimeEcPreKeysStaleIfNecessary(staleTime: Long) = throw UnsupportedOperationException()
override fun storeSenderKey(sender: SignalProtocolAddress?, distributionId: UUID?, record: SenderKeyRecord?) = throw UnsupportedOperationException()
override fun loadSenderKey(sender: SignalProtocolAddress?, distributionId: UUID?): SenderKeyRecord = throw UnsupportedOperationException()
override fun archiveSession(address: SignalProtocolAddress?) = throw UnsupportedOperationException()
@@ -171,8 +172,9 @@ class BobClient(val serviceId: ServiceId, val e164: String, val identityKeyPair:
override fun clearSenderKeySharedWith(addresses: MutableCollection<SignalProtocolAddress>?) = throw UnsupportedOperationException()
override fun storeLastResortKyberPreKey(kyberPreKeyId: Int, kyberPreKeyRecord: KyberPreKeyRecord) = throw UnsupportedOperationException()
override fun removeKyberPreKey(kyberPreKeyId: Int) = throw UnsupportedOperationException()
override fun markAllOneTimeKyberPreKeysStaleIfNecessary(staleTime: Long) = throw UnsupportedOperationException()
override fun deleteAllStaleOneTimeKyberPreKeys(threshold: Long, minCount: Int) = throw UnsupportedOperationException()
override fun loadLastResortKyberPreKeys(): List<KyberPreKeyRecord> = throw UnsupportedOperationException()
override fun isMultiDevice(): Boolean = throw UnsupportedOperationException()
}
}