Improve the storage controller for regV5.

This commit is contained in:
Greyson Parrelli
2026-03-16 09:07:04 -04:00
committed by Michelle Tang
parent 6877b9163b
commit d2c8b6e14c
5 changed files with 391 additions and 235 deletions

View File

@@ -10,191 +10,159 @@ import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import org.signal.core.models.AccountEntropyPool
import org.signal.core.models.MasterKey
import org.signal.core.util.Base64
import org.signal.core.models.ServiceId.ACI
import org.signal.core.models.ServiceId.PNI
import org.signal.libsignal.protocol.IdentityKeyPair
import org.signal.libsignal.protocol.ecc.ECKeyPair
import org.signal.libsignal.protocol.kem.KEMKeyPair
import org.signal.libsignal.protocol.kem.KEMKeyType
import org.signal.libsignal.protocol.state.KyberPreKeyRecord
import org.signal.libsignal.protocol.state.SignedPreKeyRecord
import org.signal.libsignal.zkgroup.profiles.ProfileKey
import org.signal.registration.KeyMaterial
import org.signal.registration.NetworkController
import org.signal.registration.NewRegistrationData
import org.signal.registration.PreExistingRegistrationData
import org.signal.registration.StorageController
import org.signal.registration.proto.ProvisioningData
import org.signal.registration.proto.RegistrationData
import org.signal.registration.sample.storage.RegistrationDatabase
import org.signal.registration.sample.storage.RegistrationPreferences
import java.security.SecureRandom
import javax.crypto.Cipher
import javax.crypto.spec.GCMParameterSpec
import javax.crypto.spec.SecretKeySpec
import java.io.File
/**
* Implementation of [StorageController] that persists registration data using
* SharedPreferences for simple key-value data and SQLite for prekeys.
*/
class DemoStorageController(context: Context) : StorageController {
class DemoStorageController(private val context: Context) : StorageController {
companion object {
private const val MAX_SVR_CREDENTIALS = 10
private const val TEMP_PROTO_FILENAME = "registration_data.pb"
}
private val db = RegistrationDatabase(context)
override suspend fun generateAndStoreKeyMaterial(
existingAccountEntropyPool: AccountEntropyPool?,
existingAciIdentityKeyPair: IdentityKeyPair?,
existingPniIdentityKeyPair: IdentityKeyPair?
): KeyMaterial = withContext(Dispatchers.IO) {
val accountEntropyPool = existingAccountEntropyPool ?: AccountEntropyPool.generate()
val aciIdentityKeyPair = existingAciIdentityKeyPair ?: IdentityKeyPair.generate()
val pniIdentityKeyPair = existingPniIdentityKeyPair ?: IdentityKeyPair.generate()
val aciSignedPreKeyId = generatePreKeyId()
val pniSignedPreKeyId = generatePreKeyId()
val aciKyberPreKeyId = generatePreKeyId()
val pniKyberPreKeyId = generatePreKeyId()
val timestamp = System.currentTimeMillis()
val aciSignedPreKey = generateSignedPreKey(aciSignedPreKeyId, timestamp, aciIdentityKeyPair)
val pniSignedPreKey = generateSignedPreKey(pniSignedPreKeyId, timestamp, pniIdentityKeyPair)
val aciLastResortKyberPreKey = generateKyberPreKey(aciKyberPreKeyId, timestamp, aciIdentityKeyPair)
val pniLastResortKyberPreKey = generateKyberPreKey(pniKyberPreKeyId, timestamp, pniIdentityKeyPair)
val aciRegistrationId = generateRegistrationId()
val pniRegistrationId = generateRegistrationId()
val profileKey = generateProfileKey()
val unidentifiedAccessKey = deriveUnidentifiedAccessKey(profileKey)
val password = generatePassword()
val keyMaterial = KeyMaterial(
aciIdentityKeyPair = aciIdentityKeyPair,
aciSignedPreKey = aciSignedPreKey,
aciLastResortKyberPreKey = aciLastResortKyberPreKey,
pniIdentityKeyPair = pniIdentityKeyPair,
pniSignedPreKey = pniSignedPreKey,
pniLastResortKyberPreKey = pniLastResortKyberPreKey,
aciRegistrationId = aciRegistrationId,
pniRegistrationId = pniRegistrationId,
unidentifiedAccessKey = unidentifiedAccessKey,
servicePassword = password,
accountEntropyPool = accountEntropyPool
)
storeKeyMaterial(keyMaterial, profileKey)
keyMaterial
}
override suspend fun saveNewRegistrationData(newRegistrationData: NewRegistrationData) = withContext(Dispatchers.IO) {
RegistrationPreferences.saveRegistrationData(newRegistrationData)
}
override suspend fun getPreExistingRegistrationData(): PreExistingRegistrationData? = withContext(Dispatchers.IO) {
RegistrationPreferences.getPreExistingRegistrationData()
}
override suspend fun clearAllData() = withContext(Dispatchers.IO) {
File(context.filesDir, TEMP_PROTO_FILENAME).takeIf { it.exists() }?.delete()
RegistrationPreferences.clearAll()
RegistrationPreferences.clearRestoredSvr2Credentials()
db.clearAllPreKeys()
}
override suspend fun saveValidatedPinAndTemporaryMasterKey(pin: String, isAlphanumeric: Boolean, masterKey: MasterKey, registrationLockEnabled: Boolean) = withContext(Dispatchers.IO) {
RegistrationPreferences.pin = pin
RegistrationPreferences.pinAlphanumeric = isAlphanumeric
RegistrationPreferences.temporaryMasterKey = masterKey
RegistrationPreferences.registrationLockEnabled = registrationLockEnabled
override suspend fun readInProgressRegistrationData(): RegistrationData = withContext(Dispatchers.IO) {
val file = File(context.filesDir, TEMP_PROTO_FILENAME)
if (file.exists()) {
RegistrationData.ADAPTER.decode(file.readBytes())
} else {
RegistrationData()
}
}
override suspend fun getRestoredSvrCredentials(): List<NetworkController.SvrCredentials> = withContext(Dispatchers.IO) {
RegistrationPreferences.restoredSvr2Credentials
override suspend fun updateInProgressRegistrationData(updater: RegistrationData.Builder.() -> Unit) = withContext(Dispatchers.IO) {
val current = readInProgressRegistrationData()
val updated = current.newBuilder().apply(updater).build()
writeRegistrationData(updated)
}
override suspend fun appendSvrCredentials(credentials: List<NetworkController.SvrCredentials>) = withContext(Dispatchers.IO) {
val existing = RegistrationPreferences.restoredSvr2Credentials
val combined = (existing + credentials).distinctBy { it.username }.takeLast(MAX_SVR_CREDENTIALS)
RegistrationPreferences.restoredSvr2Credentials = combined
override suspend fun commitRegistrationData() = withContext(Dispatchers.IO) {
val file = File(context.filesDir, TEMP_PROTO_FILENAME)
val data = RegistrationData.ADAPTER.decode(file.readBytes())
// Key material
if (data.aciIdentityKeyPair.size > 0) {
RegistrationPreferences.aciIdentityKeyPair = IdentityKeyPair(data.aciIdentityKeyPair.toByteArray())
}
if (data.pniIdentityKeyPair.size > 0) {
RegistrationPreferences.pniIdentityKeyPair = IdentityKeyPair(data.pniIdentityKeyPair.toByteArray())
}
if (data.aciRegistrationId != 0) {
RegistrationPreferences.aciRegistrationId = data.aciRegistrationId
}
if (data.pniRegistrationId != 0) {
RegistrationPreferences.pniRegistrationId = data.pniRegistrationId
}
if (data.servicePassword.isNotEmpty()) {
RegistrationPreferences.servicePassword = data.servicePassword
}
if (data.accountEntropyPool.isNotEmpty()) {
RegistrationPreferences.aep = AccountEntropyPool(data.accountEntropyPool)
}
// Pre-keys
if (data.aciSignedPreKey.size > 0) {
db.signedPreKeys.insert(RegistrationDatabase.ACCOUNT_TYPE_ACI, SignedPreKeyRecord(data.aciSignedPreKey.toByteArray()))
}
if (data.pniSignedPreKey.size > 0) {
db.signedPreKeys.insert(RegistrationDatabase.ACCOUNT_TYPE_PNI, SignedPreKeyRecord(data.pniSignedPreKey.toByteArray()))
}
if (data.aciLastResortKyberPreKey.size > 0) {
db.kyberPreKeys.insert(RegistrationDatabase.ACCOUNT_TYPE_ACI, KyberPreKeyRecord(data.aciLastResortKyberPreKey.toByteArray()))
}
if (data.pniLastResortKyberPreKey.size > 0) {
db.kyberPreKeys.insert(RegistrationDatabase.ACCOUNT_TYPE_PNI, KyberPreKeyRecord(data.pniLastResortKyberPreKey.toByteArray()))
}
// Account identity
if (data.e164.isNotEmpty() && data.aci.isNotEmpty() && data.pni.isNotEmpty() && data.servicePassword.isNotEmpty() && data.accountEntropyPool.isNotEmpty()) {
RegistrationPreferences.saveRegistrationData(
NewRegistrationData(
e164 = data.e164,
aci = ACI.parseOrThrow(data.aci),
pni = PNI.parseOrThrow(data.pni),
servicePassword = data.servicePassword,
aep = AccountEntropyPool(data.accountEntropyPool)
)
)
}
// PIN data
if (data.pin.isNotEmpty()) {
RegistrationPreferences.pin = data.pin
RegistrationPreferences.pinAlphanumeric = data.pinIsAlphanumeric
}
if (data.temporaryMasterKey.size > 0) {
RegistrationPreferences.temporaryMasterKey = MasterKey(data.temporaryMasterKey.toByteArray())
}
RegistrationPreferences.registrationLockEnabled = data.registrationLockEnabled
// SVR credentials
if (data.svrCredentials.isNotEmpty()) {
RegistrationPreferences.restoredSvr2Credentials = data.svrCredentials.map {
NetworkController.SvrCredentials(username = it.username, password = it.password)
}
}
// Provisioning data
data.provisioningData?.let { prov ->
RegistrationPreferences.saveProvisioningData(
NetworkController.ProvisioningMessage(
accountEntropyPool = data.accountEntropyPool,
e164 = data.e164,
pin = data.pin.ifEmpty { null },
aciIdentityKeyPair = IdentityKeyPair(data.aciIdentityKeyPair.toByteArray()),
pniIdentityKeyPair = IdentityKeyPair(data.pniIdentityKeyPair.toByteArray()),
platform = when (prov.platform) {
ProvisioningData.Platform.ANDROID -> NetworkController.ProvisioningMessage.Platform.ANDROID
ProvisioningData.Platform.IOS -> NetworkController.ProvisioningMessage.Platform.IOS
else -> NetworkController.ProvisioningMessage.Platform.ANDROID
},
tier = when (prov.tier) {
ProvisioningData.Tier.FREE -> NetworkController.ProvisioningMessage.Tier.FREE
ProvisioningData.Tier.PAID -> NetworkController.ProvisioningMessage.Tier.PAID
else -> null
},
backupTimestampMs = prov.backupTimestampMs,
backupSizeBytes = prov.backupSizeBytes,
restoreMethodToken = prov.restoreMethodToken,
backupVersion = prov.backupVersion
)
)
}
Unit
}
override suspend fun saveNewlyCreatedPin(pin: String, isAlphanumeric: Boolean) {
RegistrationPreferences.pin = pin
RegistrationPreferences.pinAlphanumeric = isAlphanumeric
}
override suspend fun saveProvisioningData(provisioningMessage: NetworkController.ProvisioningMessage) = withContext(Dispatchers.IO) {
RegistrationPreferences.saveProvisioningData(provisioningMessage)
}
private fun storeKeyMaterial(keyMaterial: KeyMaterial, profileKey: ProfileKey) {
// Clear existing data
RegistrationPreferences.clearKeyMaterial()
db.clearAllPreKeys()
// Store in SharedPreferences
RegistrationPreferences.aciIdentityKeyPair = keyMaterial.aciIdentityKeyPair
RegistrationPreferences.pniIdentityKeyPair = keyMaterial.pniIdentityKeyPair
RegistrationPreferences.aciRegistrationId = keyMaterial.aciRegistrationId
RegistrationPreferences.pniRegistrationId = keyMaterial.pniRegistrationId
RegistrationPreferences.profileKey = profileKey
// Store prekeys in database
db.signedPreKeys.insert(RegistrationDatabase.ACCOUNT_TYPE_ACI, keyMaterial.aciSignedPreKey)
db.signedPreKeys.insert(RegistrationDatabase.ACCOUNT_TYPE_PNI, keyMaterial.pniSignedPreKey)
db.kyberPreKeys.insert(RegistrationDatabase.ACCOUNT_TYPE_ACI, keyMaterial.aciLastResortKyberPreKey)
db.kyberPreKeys.insert(RegistrationDatabase.ACCOUNT_TYPE_PNI, keyMaterial.pniLastResortKyberPreKey)
}
private fun generateSignedPreKey(id: Int, timestamp: Long, identityKeyPair: IdentityKeyPair): SignedPreKeyRecord {
val keyPair = ECKeyPair.generate()
val signature = identityKeyPair.privateKey.calculateSignature(keyPair.publicKey.serialize())
return SignedPreKeyRecord(id, timestamp, keyPair, signature)
}
private fun generateKyberPreKey(id: Int, timestamp: Long, identityKeyPair: IdentityKeyPair): KyberPreKeyRecord {
val kemKeyPair = KEMKeyPair.generate(KEMKeyType.KYBER_1024)
val signature = identityKeyPair.privateKey.calculateSignature(kemKeyPair.publicKey.serialize())
return KyberPreKeyRecord(id, timestamp, kemKeyPair, signature)
}
private fun generatePreKeyId(): Int {
return SecureRandom().nextInt(Int.MAX_VALUE - 1) + 1
}
private fun generateRegistrationId(): Int {
return SecureRandom().nextInt(16380) + 1
}
private fun generateProfileKey(): ProfileKey {
val keyBytes = ByteArray(32)
SecureRandom().nextBytes(keyBytes)
return ProfileKey(keyBytes)
}
/**
* Generates a password for basic auth during registration.
* 18 random bytes, base64 encoded with padding.
*/
private fun generatePassword(): String {
val passwordBytes = ByteArray(18)
SecureRandom().nextBytes(passwordBytes)
return Base64.encodeWithPadding(passwordBytes)
}
/**
* Derives the unidentified access key from a profile key.
* This mirrors the logic in UnidentifiedAccess.deriveAccessKeyFrom().
*/
private fun deriveUnidentifiedAccessKey(profileKey: ProfileKey): ByteArray {
val nonce = ByteArray(12)
val input = ByteArray(16)
val cipher = Cipher.getInstance("AES/GCM/NoPadding")
cipher.init(Cipher.ENCRYPT_MODE, SecretKeySpec(profileKey.serialize(), "AES"), GCMParameterSpec(128, nonce))
val ciphertext = cipher.doFinal(input)
return ciphertext.copyOf(16)
private suspend fun writeRegistrationData(data: RegistrationData) = withContext(Dispatchers.IO) {
val file = File(context.filesDir, TEMP_PROTO_FILENAME)
file.writeBytes(RegistrationData.ADAPTER.encode(data))
}
}