Add initial storage interfaces for kyber prekeys.

This commit is contained in:
Greyson Parrelli
2023-05-17 11:58:45 -04:00
parent c76002663f
commit e2c2ace0e3
12 changed files with 352 additions and 12 deletions

View File

@@ -0,0 +1,49 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.thoughtcrime.securesms.crypto.storage
import org.signal.libsignal.protocol.InvalidKeyIdException
import org.signal.libsignal.protocol.state.KyberPreKeyRecord
import org.signal.libsignal.protocol.state.KyberPreKeyStore
import org.thoughtcrime.securesms.crypto.ReentrantSessionLock
import org.thoughtcrime.securesms.database.SignalDatabase
import org.whispersystems.signalservice.api.push.ServiceId
import kotlin.jvm.Throws
/**
* An implementation of the [KyberPreKeyStore] that stores entries in [org.thoughtcrime.securesms.database.KyberPreKeyTable].
*/
class SignalKyberPreKeyStore(private val selfServiceId: ServiceId) : KyberPreKeyStore {
@Throws(InvalidKeyIdException::class)
override fun loadKyberPreKey(kyberPreKeyId: Int): KyberPreKeyRecord {
ReentrantSessionLock.INSTANCE.acquire().use {
return SignalDatabase.kyberPreKeys.get(selfServiceId, kyberPreKeyId)?.record ?: throw InvalidKeyIdException("Missing kyber prekey with ID: $kyberPreKeyId")
}
}
override fun loadKyberPreKeys(): List<KyberPreKeyRecord> {
ReentrantSessionLock.INSTANCE.acquire().use {
return SignalDatabase.kyberPreKeys.getAll(selfServiceId).map { it.record }
}
}
override fun storeKyberPreKey(kyberPreKeyId: Int, record: KyberPreKeyRecord) {
error("This method is only used in tests")
}
override fun containsKyberPreKey(kyberPreKeyId: Int): Boolean {
ReentrantSessionLock.INSTANCE.acquire().use {
return SignalDatabase.kyberPreKeys.contains(selfServiceId, kyberPreKeyId)
}
}
override fun markKyberPreKeyUsed(kyberPreKeyId: Int) {
ReentrantSessionLock.INSTANCE.acquire().use {
SignalDatabase.kyberPreKeys.deleteIfNotLastResort(selfServiceId, kyberPreKeyId)
}
}
}

View File

@@ -10,6 +10,7 @@ import org.signal.libsignal.protocol.InvalidKeyIdException;
import org.signal.libsignal.protocol.NoSessionException;
import org.signal.libsignal.protocol.SignalProtocolAddress;
import org.signal.libsignal.protocol.groups.state.SenderKeyRecord;
import org.signal.libsignal.protocol.state.KyberPreKeyRecord;
import org.signal.libsignal.protocol.state.PreKeyRecord;
import org.signal.libsignal.protocol.state.SessionRecord;
import org.signal.libsignal.protocol.state.SignedPreKeyRecord;
@@ -31,15 +32,18 @@ public class SignalServiceAccountDataStoreImpl implements SignalServiceAccountDa
private final SignalIdentityKeyStore identityKeyStore;
private final TextSecureSessionStore sessionStore;
private final SignalSenderKeyStore senderKeyStore;
private final SignalKyberPreKeyStore kyberPreKeyStore;
public SignalServiceAccountDataStoreImpl(@NonNull Context context,
@NonNull TextSecurePreKeyStore preKeyStore,
@NonNull SignalKyberPreKeyStore kyberPreKeyStore,
@NonNull SignalIdentityKeyStore identityKeyStore,
@NonNull TextSecureSessionStore sessionStore,
@NonNull SignalSenderKeyStore senderKeyStore)
{
this.context = context;
this.preKeyStore = preKeyStore;
this.kyberPreKeyStore = kyberPreKeyStore;
this.signedPreKeyStore = preKeyStore;
this.identityKeyStore = identityKeyStore;
this.sessionStore = sessionStore;
@@ -167,6 +171,31 @@ public class SignalServiceAccountDataStoreImpl implements SignalServiceAccountDa
signedPreKeyStore.removeSignedPreKey(signedPreKeyId);
}
@Override
public KyberPreKeyRecord loadKyberPreKey(int kyberPreKeyId) throws InvalidKeyIdException {
return kyberPreKeyStore.loadKyberPreKey(kyberPreKeyId);
}
@Override
public List<KyberPreKeyRecord> loadKyberPreKeys() {
return kyberPreKeyStore.loadKyberPreKeys();
}
@Override
public void storeKyberPreKey(int kyberPreKeyId, KyberPreKeyRecord record) {
kyberPreKeyStore.storeKyberPreKey(kyberPreKeyId, record);
}
@Override
public boolean containsKyberPreKey(int kyberPreKeyId) {
return kyberPreKeyStore.containsKyberPreKey(kyberPreKeyId);
}
@Override
public void markKyberPreKeyUsed(int kyberPreKeyId) {
kyberPreKeyStore.markKyberPreKeyUsed(kyberPreKeyId);
}
@Override
public void storeSenderKey(SignalProtocolAddress sender, UUID distributionId, SenderKeyRecord record) {
senderKeyStore.storeSenderKey(sender, distributionId, record);

View File

@@ -0,0 +1,111 @@
package org.thoughtcrime.securesms.database
import android.content.Context
import org.signal.core.util.delete
import org.signal.core.util.exists
import org.signal.core.util.insertInto
import org.signal.core.util.readToList
import org.signal.core.util.readToSingleObject
import org.signal.core.util.requireBoolean
import org.signal.core.util.requireNonNullBlob
import org.signal.core.util.select
import org.signal.libsignal.protocol.state.KyberPreKeyRecord
import org.whispersystems.signalservice.api.push.ServiceId
/**
* A table for storing data related to [org.thoughtcrime.securesms.crypto.storage.SignalKyberPreKeyStore].
*/
class KyberPreKeyTable(context: Context, databaseHelper: SignalDatabase) : DatabaseTable(context, databaseHelper) {
companion object {
const val TABLE_NAME = "kyber_prekey"
const val ID = "_id"
const val ACCOUNT_ID = "account_id"
const val KEY_ID = "key_id"
const val TIMESTAMP = "timestamp"
const val LAST_RESORT = "last_resort"
const val SERIALIZED = "serialized"
const val CREATE_TABLE = """
CREATE TABLE $TABLE_NAME (
$ID INTEGER PRIMARY KEY,
$ACCOUNT_ID TEXT NOT NULL,
$KEY_ID INTEGER UNIQUE NOT NULL,
$TIMESTAMP INTEGER NOT NULL,
$LAST_RESORT INTEGER NOT NULL,
$SERIALIZED BLOB NOT NULL,
UNIQUE($ACCOUNT_ID, $KEY_ID)
)
"""
private const val INDEX_ACCOUNT_KEY = "kyber_account_id_key_id"
val CREATE_INDEXES = arrayOf(
"CREATE INDEX IF NOT EXISTS $INDEX_ACCOUNT_KEY ON $TABLE_NAME ($ACCOUNT_ID, $KEY_ID, $LAST_RESORT, $SERIALIZED)"
)
}
fun get(serviceId: ServiceId, keyId: Int): KyberPreKey? {
return readableDatabase
.select(LAST_RESORT, SERIALIZED)
.from("$TABLE_NAME INDEXED BY $INDEX_ACCOUNT_KEY")
.where("$ACCOUNT_ID = ? AND $KEY_ID = ?", serviceId, keyId)
.run()
.readToSingleObject { cursor ->
KyberPreKey(
record = KyberPreKeyRecord(cursor.requireNonNullBlob(SERIALIZED)),
lastResort = cursor.requireBoolean(LAST_RESORT)
)
}
}
fun getAll(serviceId: ServiceId): List<KyberPreKey> {
return readableDatabase
.select(LAST_RESORT, SERIALIZED)
.from("$TABLE_NAME INDEXED BY $INDEX_ACCOUNT_KEY")
.where("$ACCOUNT_ID = ?", serviceId)
.run()
.readToList { cursor ->
KyberPreKey(
record = KyberPreKeyRecord(cursor.requireNonNullBlob(SERIALIZED)),
lastResort = cursor.requireBoolean(LAST_RESORT)
)
}
}
fun contains(serviceId: ServiceId, keyId: Int): Boolean {
return readableDatabase
.exists("$TABLE_NAME INDEXED BY $INDEX_ACCOUNT_KEY")
.where("$ACCOUNT_ID = ? AND $KEY_ID = ?", serviceId, keyId)
.run()
}
fun insert(serviceId: ServiceId, keyId: Int, record: KyberPreKeyRecord) {
writableDatabase
.insertInto(TABLE_NAME)
.values(
ACCOUNT_ID to serviceId.toString(),
KEY_ID to keyId,
TIMESTAMP to record.timestamp,
SERIALIZED to record.serialize()
)
.run(SQLiteDatabase.CONFLICT_REPLACE)
}
fun deleteIfNotLastResort(serviceId: ServiceId, keyId: Int) {
writableDatabase
.delete("$TABLE_NAME INDEXED BY $INDEX_ACCOUNT_KEY")
.where("$ACCOUNT_ID = ? AND $KEY_ID = ? AND $LAST_RESORT = ?", serviceId, keyId, 0)
.run()
}
fun delete(serviceId: ServiceId, keyId: Int) {
writableDatabase
.delete("$TABLE_NAME INDEXED BY $INDEX_ACCOUNT_KEY")
.where("$ACCOUNT_ID = ? AND $KEY_ID = ?", serviceId, keyId)
.run()
}
data class KyberPreKey(
val record: KyberPreKeyRecord,
val lastResort: Boolean
)
}

View File

@@ -74,6 +74,7 @@ open class SignalDatabase(private val context: Application, databaseSecret: Data
val remoteMegaphoneTable: RemoteMegaphoneTable = RemoteMegaphoneTable(context, this)
val pendingPniSignatureMessageTable: PendingPniSignatureMessageTable = PendingPniSignatureMessageTable(context, this)
val callTable: CallTable = CallTable(context, this)
val kyberPreKeyTable: KyberPreKeyTable = KyberPreKeyTable(context, this)
override fun onOpen(db: net.zetetic.database.sqlcipher.SQLiteDatabase) {
db.setForeignKeyConstraintsEnabled(true)
@@ -110,6 +111,7 @@ open class SignalDatabase(private val context: Application, databaseSecret: Data
db.execSQL(PendingPniSignatureMessageTable.CREATE_TABLE)
db.execSQL(CallLinkTable.CREATE_TABLE)
db.execSQL(CallTable.CREATE_TABLE)
db.execSQL(KyberPreKeyTable.CREATE_TABLE)
executeStatements(db, SearchTable.CREATE_TABLE)
executeStatements(db, RemappedRecordTables.CREATE_TABLE)
executeStatements(db, MessageSendLogTables.CREATE_TABLE)
@@ -135,6 +137,7 @@ open class SignalDatabase(private val context: Application, databaseSecret: Data
executeStatements(db, PendingPniSignatureMessageTable.CREATE_INDEXES)
executeStatements(db, CallTable.CREATE_INDEXES)
executeStatements(db, ReactionTable.CREATE_INDEXES)
executeStatements(db, KyberPreKeyTable.CREATE_INDEXES)
executeStatements(db, SearchTable.CREATE_TRIGGERS)
executeStatements(db, MessageSendLogTables.CREATE_TRIGGERS)
@@ -403,6 +406,11 @@ open class SignalDatabase(private val context: Application, databaseSecret: Data
val identities: IdentityTable
get() = instance!!.identityTable
@get:JvmStatic
@get:JvmName("kyberPreKeys")
val kyberPreKeys: KyberPreKeyTable
get() = instance!!.kyberPreKeyTable
@get:JvmStatic
@get:JvmName("media")
val media: MediaTable

View File

@@ -49,6 +49,7 @@ import org.thoughtcrime.securesms.database.helpers.migration.V190_UniqueMessageM
import org.thoughtcrime.securesms.database.helpers.migration.V191_UniqueMessageMigrationV2
import org.thoughtcrime.securesms.database.helpers.migration.V192_CallLinkTableNullableRootKeys
import org.thoughtcrime.securesms.database.helpers.migration.V193_BackCallLinksWithRecipient
import org.thoughtcrime.securesms.database.helpers.migration.V194_KyberPreKeyMigration
/**
* Contains all of the database migrations for [SignalDatabase]. Broken into a separate file for cleanliness.
@@ -57,7 +58,7 @@ object SignalDatabaseMigrations {
val TAG: String = Log.tag(SignalDatabaseMigrations.javaClass)
const val DATABASE_VERSION = 193
const val DATABASE_VERSION = 194
@JvmStatic
fun migrate(context: Application, db: SQLiteDatabase, oldVersion: Int, newVersion: Int) {
@@ -240,6 +241,10 @@ object SignalDatabaseMigrations {
if (oldVersion < 193) {
V193_BackCallLinksWithRecipient.migrate(context, db, oldVersion, newVersion)
}
if (oldVersion < 194) {
V194_KyberPreKeyMigration.migrate(context, db, oldVersion, newVersion)
}
}
@JvmStatic

View File

@@ -0,0 +1,32 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.thoughtcrime.securesms.database.helpers.migration
import android.app.Application
import net.zetetic.database.sqlcipher.SQLiteDatabase
/**
* Introduces [org.thoughtcrime.securesms.database.KyberPreKeyTable].
*/
object V194_KyberPreKeyMigration : SignalDatabaseMigration {
override fun migrate(context: Application, db: SQLiteDatabase, oldVersion: Int, newVersion: Int) {
db.execSQL(
"""
CREATE TABLE kyber_prekey (
_id INTEGER PRIMARY KEY,
account_id TEXT NOT NULL,
key_id INTEGER UNIQUE NOT NULL,
timestamp INTEGER NOT NULL,
last_resort INTEGER NOT NULL,
serialized BLOB NOT NULL,
UNIQUE(account_id, key_id)
)
"""
)
db.execSQL("CREATE INDEX IF NOT EXISTS kyber_account_id_key_id ON kyber_prekey (account_id, key_id, last_resort, serialized)")
}
}

View File

@@ -21,6 +21,7 @@ import org.thoughtcrime.securesms.components.TypingStatusSender;
import org.thoughtcrime.securesms.crypto.ReentrantSessionLock;
import org.thoughtcrime.securesms.crypto.storage.SignalBaseIdentityKeyStore;
import org.thoughtcrime.securesms.crypto.storage.SignalIdentityKeyStore;
import org.thoughtcrime.securesms.crypto.storage.SignalKyberPreKeyStore;
import org.thoughtcrime.securesms.crypto.storage.SignalSenderKeyStore;
import org.thoughtcrime.securesms.crypto.storage.SignalServiceAccountDataStoreImpl;
import org.thoughtcrime.securesms.crypto.storage.SignalServiceDataStoreImpl;
@@ -333,12 +334,14 @@ public class ApplicationDependencyProvider implements ApplicationDependencies.Pr
SignalServiceAccountDataStoreImpl aciStore = new SignalServiceAccountDataStoreImpl(context,
new TextSecurePreKeyStore(localAci),
new SignalKyberPreKeyStore(localAci),
new SignalIdentityKeyStore(baseIdentityStore, () -> SignalStore.account().getAciIdentityKey()),
new TextSecureSessionStore(localAci),
new SignalSenderKeyStore(context));
SignalServiceAccountDataStoreImpl pniStore = new SignalServiceAccountDataStoreImpl(context,
new TextSecurePreKeyStore(localPni),
new SignalKyberPreKeyStore(localPni),
new SignalIdentityKeyStore(baseIdentityStore, () -> SignalStore.account().getPniIdentityKey()),
new TextSecureSessionStore(localPni),
new SignalSenderKeyStore(context));

View File

@@ -0,0 +1,75 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.thoughtcrime.securesms.messages.protocol
import org.signal.libsignal.protocol.InvalidKeyIdException
import org.signal.libsignal.protocol.state.KyberPreKeyRecord
import org.signal.libsignal.protocol.state.KyberPreKeyStore
import org.thoughtcrime.securesms.database.KyberPreKeyTable.KyberPreKey
import org.thoughtcrime.securesms.database.SignalDatabase
import org.whispersystems.signalservice.api.SignalServiceAccountDataStore
import org.whispersystems.signalservice.api.push.ServiceId
/**
* An in-memory kyber prekey store that is intended to be used temporarily while decrypting messages.
*/
class BufferedKyberPreKeyStore(private val selfServiceId: ServiceId) : KyberPreKeyStore {
/** Our in-memory cache of kyber prekeys. */
val store: MutableMap<Int, KyberPreKey> = mutableMapOf()
/** Whether or not we've done a loadAll operation. Let's us avoid doing it twice. */
private var hasLoadedAll: Boolean = false
/** The kyber prekeys that have been marked as removed (if they're not last resort). */
private val removedIfNotLastResort: MutableList<Int> = mutableListOf()
@kotlin.jvm.Throws(InvalidKeyIdException::class)
override fun loadKyberPreKey(kyberPreKeyId: Int): KyberPreKeyRecord {
return store.computeIfAbsent(kyberPreKeyId) {
SignalDatabase.kyberPreKeys.get(selfServiceId, kyberPreKeyId) ?: throw InvalidKeyIdException("Missing kyber prekey with ID: $kyberPreKeyId")
}.record
}
override fun loadKyberPreKeys(): List<KyberPreKeyRecord> {
return if (hasLoadedAll) {
store.values.map { it.record }
} else {
val models = SignalDatabase.kyberPreKeys.getAll(selfServiceId)
models.forEach { store[it.record.id] = it }
hasLoadedAll = true
models.map { it.record }
}
}
override fun storeKyberPreKey(kyberPreKeyId: Int, record: KyberPreKeyRecord) {
error("This method is only used in tests")
}
override fun containsKyberPreKey(kyberPreKeyId: Int): Boolean {
loadKyberPreKey(kyberPreKeyId)
return store.containsKey(kyberPreKeyId)
}
override fun markKyberPreKeyUsed(kyberPreKeyId: Int) {
loadKyberPreKey(kyberPreKeyId)
store[kyberPreKeyId]?.let {
if (!it.lastResort) {
store.remove(kyberPreKeyId)
}
}
removedIfNotLastResort += kyberPreKeyId
}
fun flushToDisk(persistentStore: SignalServiceAccountDataStore) {
for (id in removedIfNotLastResort) {
persistentStore.markKyberPreKeyUsed(id)
}
}
}

View File

@@ -5,6 +5,7 @@ import org.signal.libsignal.protocol.IdentityKeyPair
import org.signal.libsignal.protocol.SignalProtocolAddress
import org.signal.libsignal.protocol.groups.state.SenderKeyRecord
import org.signal.libsignal.protocol.state.IdentityKeyStore
import org.signal.libsignal.protocol.state.KyberPreKeyRecord
import org.signal.libsignal.protocol.state.PreKeyRecord
import org.signal.libsignal.protocol.state.SessionRecord
import org.signal.libsignal.protocol.state.SignedPreKeyRecord
@@ -28,6 +29,7 @@ class BufferedSignalServiceAccountDataStore(selfServiceId: ServiceId) : SignalSe
private val oneTimePreKeyStore: BufferedOneTimePreKeyStore = BufferedOneTimePreKeyStore(selfServiceId)
private val signedPreKeyStore: BufferedSignedPreKeyStore = BufferedSignedPreKeyStore(selfServiceId)
private val kyberPreKeyStore: BufferedKyberPreKeyStore = BufferedKyberPreKeyStore(selfServiceId)
private val sessionStore: BufferedSessionStore = BufferedSessionStore(selfServiceId)
private val senderKeyStore: BufferedSenderKeyStore = BufferedSenderKeyStore()
@@ -115,6 +117,26 @@ class BufferedSignalServiceAccountDataStore(selfServiceId: ServiceId) : SignalSe
signedPreKeyStore.removeSignedPreKey(signedPreKeyId)
}
override fun loadKyberPreKey(kyberPreKeyId: Int): KyberPreKeyRecord {
return kyberPreKeyStore.loadKyberPreKey(kyberPreKeyId)
}
override fun loadKyberPreKeys(): List<KyberPreKeyRecord> {
return kyberPreKeyStore.loadKyberPreKeys()
}
override fun storeKyberPreKey(kyberPreKeyId: Int, record: KyberPreKeyRecord) {
kyberPreKeyStore.storeKyberPreKey(kyberPreKeyId, record)
}
override fun containsKyberPreKey(kyberPreKeyId: Int): Boolean {
return kyberPreKeyStore.containsKyberPreKey(kyberPreKeyId)
}
override fun markKyberPreKeyUsed(kyberPreKeyId: Int) {
return kyberPreKeyStore.markKyberPreKeyUsed(kyberPreKeyId)
}
override fun storeSenderKey(sender: SignalProtocolAddress, distributionId: UUID, record: SenderKeyRecord) {
senderKeyStore.storeSenderKey(sender, distributionId, record)
}