From 5c418a4260389cef7302cef769d1c7e5660b626a Mon Sep 17 00:00:00 2001 From: Greyson Parrelli Date: Wed, 11 Feb 2026 20:38:35 -0500 Subject: [PATCH] Add RRP support to regV5. --- .../signal/core/util/logging/NoopLogger.kt | 2 +- .../registration/src/main/AndroidManifest.xml | 2 + .../registration/sample/MainActivity.kt | 4 + .../sample/RegistrationApplication.kt | 41 +- .../absbackup/RegistrationBackupAgent.kt | 111 ++++ .../sample/debug/DebugNetworkController.kt | 27 +- ...Controller.kt => DemoNetworkController.kt} | 85 ++- ...Controller.kt => DemoStorageController.kt} | 30 +- .../pinsettings/PinSettingsViewModel.kt | 2 +- .../sample/storage/RegistrationPreferences.kt | 30 +- .../util/PreviewRegistrationDependencies.kt | 145 ----- .../signal/registration/NetworkController.kt | 64 +- .../registration/RegistrationActivity.kt | 1 + .../registration/RegistrationDependencies.kt | 10 +- .../registration/RegistrationFlowEvent.kt | 19 +- .../registration/RegistrationFlowState.kt | 17 +- .../registration/RegistrationNavigation.kt | 22 + .../registration/RegistrationRepository.kt | 70 ++- .../registration/RegistrationViewModel.kt | 4 +- .../signal/registration/StorageController.kt | 30 +- .../phonenumber/PhoneNumberEntryState.kt | 4 +- .../phonenumber/PhoneNumberEntryViewModel.kt | 58 +- .../PinEntryForRegistrationLockViewModel.kt | 2 +- .../pinentry/PinEntryForSmsBypassViewModel.kt | 226 +++++++ .../PinEntryForSvrRestoreViewModel.kt | 8 +- .../screens/pinentry/PinEntryScreen.kt | 3 +- .../screens/pinentry/PinEntryState.kt | 4 +- .../VerificationCodeViewModel.kt | 4 +- .../signal/registration/util/SensitiveLog.kt | 45 ++ .../PhoneNumberEntryViewModelTest.kt | 557 ++++++++++++++++++ ...inEntryForRegistrationLockViewModelTest.kt | 18 +- .../PinEntryForSmsBypassViewModelTest.kt | 397 +++++++++++++ .../PinEntryForSvrRestoreViewModelTest.kt | 11 +- 33 files changed, 1821 insertions(+), 232 deletions(-) create mode 100644 demo/registration/src/main/java/org/signal/registration/sample/absbackup/RegistrationBackupAgent.kt rename demo/registration/src/main/java/org/signal/registration/sample/dependencies/{RealNetworkController.kt => DemoNetworkController.kt} (89%) rename demo/registration/src/main/java/org/signal/registration/sample/dependencies/{RealStorageController.kt => DemoStorageController.kt} (84%) delete mode 100644 demo/registration/src/main/java/org/signal/registration/sample/util/PreviewRegistrationDependencies.kt create mode 100644 feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryForSmsBypassViewModel.kt create mode 100644 feature/registration/src/main/java/org/signal/registration/util/SensitiveLog.kt create mode 100644 feature/registration/src/test/java/org/signal/registration/screens/pinentry/PinEntryForSmsBypassViewModelTest.kt diff --git a/core/util-jvm/src/main/java/org/signal/core/util/logging/NoopLogger.kt b/core/util-jvm/src/main/java/org/signal/core/util/logging/NoopLogger.kt index bd7abd950b..a21422d3d7 100644 --- a/core/util-jvm/src/main/java/org/signal/core/util/logging/NoopLogger.kt +++ b/core/util-jvm/src/main/java/org/signal/core/util/logging/NoopLogger.kt @@ -8,7 +8,7 @@ package org.signal.core.util.logging /** * A logger that does nothing. */ -internal class NoopLogger : Log.Logger() { +class NoopLogger : Log.Logger() { override fun v(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) = Unit override fun d(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) = Unit override fun i(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) = Unit diff --git a/demo/registration/src/main/AndroidManifest.xml b/demo/registration/src/main/AndroidManifest.xml index 341fa50078..45a07e2054 100644 --- a/demo/registration/src/main/AndroidManifest.xml +++ b/demo/registration/src/main/AndroidManifest.xml @@ -5,7 +5,9 @@ , modifier: Modifier = Modifier ) { + val context = LocalContext.current + val registrationRepository = remember { RegistrationRepository( + context = context.applicationContext, networkController = registrationDependencies.networkController, storageController = registrationDependencies.storageController ) diff --git a/demo/registration/src/main/java/org/signal/registration/sample/RegistrationApplication.kt b/demo/registration/src/main/java/org/signal/registration/sample/RegistrationApplication.kt index d2a2b93892..33d8a3ea5c 100644 --- a/demo/registration/src/main/java/org/signal/registration/sample/RegistrationApplication.kt +++ b/demo/registration/src/main/java/org/signal/registration/sample/RegistrationApplication.kt @@ -15,8 +15,8 @@ import org.signal.core.util.logging.AndroidLogger import org.signal.core.util.logging.Log import org.signal.registration.RegistrationDependencies import org.signal.registration.sample.debug.DebugNetworkController -import org.signal.registration.sample.dependencies.RealNetworkController -import org.signal.registration.sample.dependencies.RealStorageController +import org.signal.registration.sample.dependencies.DemoNetworkController +import org.signal.registration.sample.dependencies.DemoStorageController import org.signal.registration.sample.storage.RegistrationPreferences import org.whispersystems.signalservice.api.push.TrustStore import org.whispersystems.signalservice.api.util.CredentialsProvider @@ -34,7 +34,7 @@ class RegistrationApplication : Application() { companion object { // Staging SVR2 mrEnclave value - private const val SVR2_MRENCLAVE = "a75542d82da9f6914a1e31f8a7407053b99cc99a0e7291d8fbd394253e19b036" + private const val SVR2_MRENCLAVE = "97f151f6ed078edbbfd72fa9cae694dcc08353f1f5e8d9ccd79a971b10ffc535" } override fun onCreate() { @@ -47,14 +47,15 @@ class RegistrationApplication : Application() { val trustStore = SampleTrustStore() val configuration = createServiceConfiguration(trustStore) val pushServiceSocket = createPushServiceSocket(configuration) - val realNetworkController = RealNetworkController(this, pushServiceSocket, configuration, SVR2_MRENCLAVE) - val networkController = DebugNetworkController(realNetworkController) - val storageController = RealStorageController(this) + val demoNetworkController = DemoNetworkController(this, pushServiceSocket, configuration, SVR2_MRENCLAVE) + val networkController = DebugNetworkController(demoNetworkController) + val storageController = DemoStorageController(this) RegistrationDependencies.provide( RegistrationDependencies( networkController = networkController, - storageController = storageController + storageController = storageController, + sensitiveLogger = LogLogger ) ) @@ -117,4 +118,30 @@ class RegistrationApplication : Application() { override fun getDeviceId(): Int = 1 override fun getPassword(): String? = null } + + private object LogLogger : Log.Logger() { + override fun v(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) { + Log.v(tag, message, t, keepLonger) + } + + override fun d(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) { + Log.d(tag, message, t, keepLonger) + } + + override fun i(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) { + Log.i(tag, message, t, keepLonger) + } + + override fun w(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) { + Log.w(tag, message, t, keepLonger) + } + + override fun e(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) { + Log.e(tag, message, t, keepLonger) + } + + override fun flush() { + Log.blockUntilAllWritesFinished() + } + } } diff --git a/demo/registration/src/main/java/org/signal/registration/sample/absbackup/RegistrationBackupAgent.kt b/demo/registration/src/main/java/org/signal/registration/sample/absbackup/RegistrationBackupAgent.kt new file mode 100644 index 0000000000..7efd616188 --- /dev/null +++ b/demo/registration/src/main/java/org/signal/registration/sample/absbackup/RegistrationBackupAgent.kt @@ -0,0 +1,111 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.signal.registration.sample.absbackup + +import android.app.backup.BackupAgent +import android.app.backup.BackupDataInput +import android.app.backup.BackupDataOutput +import android.app.backup.FullBackupDataOutput +import android.os.ParcelFileDescriptor +import org.signal.core.util.logging.Log +import org.signal.registration.NetworkController +import org.signal.registration.sample.storage.RegistrationPreferences +import java.io.ByteArrayInputStream +import java.io.ByteArrayOutputStream +import java.io.DataInputStream +import java.io.DataOutputStream +import java.io.File +import java.io.FileInputStream +import java.io.FileOutputStream +import java.io.IOException + +/** + * Uses the [Android Backup Service](https://developer.android.com/guide/topics/data/keyvaluebackup) to back up SVR2 credentials. + * These credentials can be combined with a PIN to prove ownership of a phone number in order to complete the registration process. + */ +class RegistrationBackupAgent : BackupAgent() { + + override fun onBackup(oldState: ParcelFileDescriptor?, data: BackupDataOutput, newState: ParcelFileDescriptor) { + Log.i(TAG, "Performing backup to Android Backup Service.") + val contentsHash = cumulativeHashCode() + if (oldState == null || !hashMatches(oldState, contentsHash)) { + val backupData = getDataForBackup() + data.writeEntityHeader(BACKUP_KEY, backupData.size) + data.writeEntityData(backupData, backupData.size) + } + + DataOutputStream(FileOutputStream(newState.fileDescriptor)).use { it.writeInt(contentsHash) } + Log.i(TAG, "Backup finished.") + } + + override fun onRestore(dataInput: BackupDataInput, appVersionCode: Int, newState: ParcelFileDescriptor) { + Log.i(TAG, "Restoring from Android Backup Service.") + while (dataInput.readNextHeader()) { + if (dataInput.key == BACKUP_KEY) { + val buffer = ByteArray(dataInput.dataSize) + dataInput.readEntityData(buffer, 0, dataInput.dataSize) + restoreData(buffer) + } + } + DataOutputStream(FileOutputStream(newState.fileDescriptor)).use { it.writeInt(cumulativeHashCode()) } + Log.i(TAG, "Android Backup Service restore complete.") + } + + private fun cumulativeHashCode(): Int { + return getDataForBackup().decodeToString().hashCode() + } + + private fun hashMatches(oldState: ParcelFileDescriptor, expected: Int): Boolean { + return try { + val hash = DataInputStream(FileInputStream(oldState.fileDescriptor)).use { it.readInt() } + hash == expected + } catch (e: IOException) { + false + } + } + + private fun getDataForBackup(): ByteArray { + val credentials = RegistrationPreferences.restoredSvr2Credentials + val byteArrayOutputStream = ByteArrayOutputStream() + DataOutputStream(byteArrayOutputStream).use { output -> + output.writeInt(credentials.size) + credentials.forEach { credential -> + output.writeUTF(credential.username) + output.writeUTF(credential.password) + } + } + return byteArrayOutputStream.toByteArray() + } + + private fun restoreData(data: ByteArray) { + // Only restore if we don't already have credentials + if (RegistrationPreferences.restoredSvr2Credentials.isNotEmpty()) { + return + } + + try { + val byteArrayInputStream = ByteArrayInputStream(data) + val credentials = mutableListOf() + DataInputStream(byteArrayInputStream).use { input -> + val count = input.readInt() + repeat(count) { + val username = input.readUTF() + val password = input.readUTF() + credentials.add(NetworkController.SvrCredentials(username = username, password = password)) + } + } + RegistrationPreferences.restoredSvr2Credentials = credentials + Log.i(TAG, "Successfully restored ${credentials.size} SVR2 credentials from backup service.") + } catch (e: IOException) { + Log.w(TAG, "Cannot restore SVR2 credentials from backup service.", e) + } + } + + companion object { + private val TAG = Log.tag(RegistrationBackupAgent::class) + private const val BACKUP_KEY = "Svr2Credentials" + } +} \ No newline at end of file diff --git a/demo/registration/src/main/java/org/signal/registration/sample/debug/DebugNetworkController.kt b/demo/registration/src/main/java/org/signal/registration/sample/debug/DebugNetworkController.kt index a33936c975..d61b30327c 100644 --- a/demo/registration/src/main/java/org/signal/registration/sample/debug/DebugNetworkController.kt +++ b/demo/registration/src/main/java/org/signal/registration/sample/debug/DebugNetworkController.kt @@ -27,6 +27,8 @@ import org.signal.registration.NetworkController.SubmitVerificationCodeError import org.signal.registration.NetworkController.SvrCredentials import org.signal.registration.NetworkController.UpdateSessionError import org.signal.registration.NetworkController.VerificationCodeTransport +import org.signal.registration.NetworkController.CheckSvrCredentialsError +import org.signal.registration.NetworkController.CheckSvrCredentialsResponse import java.util.Locale /** @@ -137,27 +139,32 @@ class DebugNetworkController( } override suspend fun restoreMasterKeyFromSvr( - svr2Credentials: SvrCredentials, + svrCredentials: SvrCredentials, pin: String ): RegistrationNetworkResult { NetworkDebugState.getOverride>("restoreMasterKeyFromSvr")?.let { Log.d(TAG, "[restoreMasterKeyFromSvr] Returning debug override") return it } - return delegate.restoreMasterKeyFromSvr(svr2Credentials, pin) + return delegate.restoreMasterKeyFromSvr(svrCredentials, pin) } override suspend fun setPinAndMasterKeyOnSvr( pin: String, masterKey: MasterKey - ): RegistrationNetworkResult { - NetworkDebugState.getOverride>("setPinAndMasterKeyOnSvr")?.let { + ): RegistrationNetworkResult { + NetworkDebugState.getOverride>("setPinAndMasterKeyOnSvr")?.let { Log.d(TAG, "[setPinAndMasterKeyOnSvr] Returning debug override") return it } return delegate.setPinAndMasterKeyOnSvr(pin, masterKey) } + override suspend fun enqueueSvrGuessResetJob() { + // No override support for simple value methods + delegate.enqueueSvrGuessResetJob() + } + override suspend fun enableRegistrationLock(): RegistrationNetworkResult { NetworkDebugState.getOverride>("enableRegistrationLock")?.let { Log.d(TAG, "[enableRegistrationLock] Returning debug override") @@ -189,4 +196,16 @@ class DebugNetworkController( } return delegate.getSvrCredentials() } + + override suspend fun checkSvrCredentials( + e164: String, + credentials: List + ): RegistrationNetworkResult { + NetworkDebugState.getOverride>("checkSvrCredentials")?.let { + Log.d(TAG, "[checkSvrCredentials] Returning debug override") + return it + } + + return delegate.checkSvrCredentials(e164, credentials) + } } diff --git a/demo/registration/src/main/java/org/signal/registration/sample/dependencies/RealNetworkController.kt b/demo/registration/src/main/java/org/signal/registration/sample/dependencies/DemoNetworkController.kt similarity index 89% rename from demo/registration/src/main/java/org/signal/registration/sample/dependencies/RealNetworkController.kt rename to demo/registration/src/main/java/org/signal/registration/sample/dependencies/DemoNetworkController.kt index cf80256169..61d2674b15 100644 --- a/demo/registration/src/main/java/org/signal/registration/sample/dependencies/RealNetworkController.kt +++ b/demo/registration/src/main/java/org/signal/registration/sample/dependencies/DemoNetworkController.kt @@ -14,6 +14,7 @@ import okhttp3.Response import org.signal.core.models.MasterKey import org.signal.core.util.logging.Log import org.signal.libsignal.net.Network +import org.signal.libsignal.protocol.util.Hex import org.signal.registration.NetworkController import org.signal.registration.NetworkController.AccountAttributes import org.signal.registration.NetworkController.CreateSessionError @@ -23,8 +24,9 @@ import org.signal.registration.NetworkController.RegisterAccountError import org.signal.registration.NetworkController.RegisterAccountResponse import org.signal.registration.NetworkController.RegistrationLockResponse import org.signal.registration.NetworkController.RegistrationNetworkResult -import org.signal.registration.NetworkController.RegistrationNetworkResult.* import org.signal.registration.NetworkController.RequestVerificationCodeError +import org.signal.registration.NetworkController.CheckSvrCredentialsRequest +import org.signal.registration.NetworkController.CheckSvrCredentialsResponse import org.signal.registration.NetworkController.SessionMetadata import org.signal.registration.NetworkController.SubmitVerificationCodeError import org.signal.registration.NetworkController.ThirdPartyServiceErrorResponse @@ -52,7 +54,7 @@ import kotlin.time.Duration.Companion.seconds import org.whispersystems.signalservice.api.account.AccountAttributes as ServiceAccountAttributes import org.whispersystems.signalservice.api.account.PreKeyCollection as ServicePreKeyCollection -class RealNetworkController( +class DemoNetworkController( private val context: android.content.Context, private val pushServiceSocket: PushServiceSocket, private val serviceConfiguration: SignalServiceConfiguration, @@ -60,7 +62,7 @@ class RealNetworkController( ) : NetworkController { companion object { - private val TAG = Log.tag(RealNetworkController::class) + private val TAG = Log.tag(DemoNetworkController::class) } private val json = Json { ignoreUnknownKeys = true } @@ -367,11 +369,11 @@ class RealNetworkController( } override suspend fun restoreMasterKeyFromSvr( - svr2Credentials: NetworkController.SvrCredentials, + svrCredentials: NetworkController.SvrCredentials, pin: String ): RegistrationNetworkResult = withContext(Dispatchers.IO) { try { - val authCredentials = AuthCredentials.create(svr2Credentials.username, svr2Credentials.password) + val authCredentials = AuthCredentials.create(svrCredentials.username, svrCredentials.password) // Create a stub websocket that will never be used for pre-registration restore val stubWebSocketFactory = WebSocketFactory { throw UnsupportedOperationException("WebSocket not available during pre-registration") } @@ -388,7 +390,7 @@ class RealNetworkController( when (val response = svr2.restoreDataPreRegistration(authCredentials, null, pin)) { is RestoreResponse.Success -> { - Log.i(TAG, "[restoreMasterKeyFromSvr] Successfully restored master key from SVR2") + Log.i(TAG, "[restoreMasterKeyFromSvr] Successfully restored master key from SVR2. Value: ${Hex.toStringCondensed(response.masterKey.serialize())}") RegistrationNetworkResult.Success(NetworkController.MasterKeyResponse(response.masterKey)) } is RestoreResponse.PinMismatch -> { @@ -424,7 +426,7 @@ class RealNetworkController( override suspend fun setPinAndMasterKeyOnSvr( pin: String, masterKey: MasterKey - ): RegistrationNetworkResult = withContext(Dispatchers.IO) { + ): RegistrationNetworkResult = withContext(Dispatchers.IO) { try { val aci = RegistrationPreferences.aci val pni = RegistrationPreferences.pni @@ -468,31 +470,31 @@ class RealNetworkController( when (response) { is BackupResponse.Success -> { - Log.i(TAG, "[backupMasterKeyToSvr] Successfully backed up master key to SVR2") - Success(Unit) + Log.i(TAG, "[backupMasterKeyToSvr] Successfully backed up master key to SVR2. Value: ${Hex.toStringCondensed(masterKey.serialize())}") + RegistrationNetworkResult.Success(NetworkController.SvrCredentials(response.authorization.username(), response.authorization.password())) } is BackupResponse.ApplicationError -> { Log.w(TAG, "[backupMasterKeyToSvr] Application error", response.exception) - ApplicationError(response.exception) + RegistrationNetworkResult.ApplicationError(response.exception) } is BackupResponse.NetworkError -> { Log.w(TAG, "[backupMasterKeyToSvr] Network error", response.exception) - NetworkError(response.exception) + RegistrationNetworkResult.NetworkError(response.exception) } is BackupResponse.EnclaveNotFound -> { Log.w(TAG, "[backupMasterKeyToSvr] Enclave not found") - Failure(NetworkController.BackupMasterKeyError.EnclaveNotFound) + RegistrationNetworkResult.Failure(NetworkController.BackupMasterKeyError.EnclaveNotFound) } is BackupResponse.ExposeFailure -> { Log.w(TAG, "[backupMasterKeyToSvr] Expose failure -- per spec, treat as success.") - Success(Unit) + RegistrationNetworkResult.Success(null) } is BackupResponse.ServerRejected -> { Log.w(TAG, "[backupMasterKeyToSvr] Server rejected") - NetworkError(IOException("Server rejected backup request")) + RegistrationNetworkResult.NetworkError(IOException("Server rejected backup request")) } is BackupResponse.RateLimited -> { - NetworkError(IOException("Rate limited")) + RegistrationNetworkResult.NetworkError(IOException("Rate limited")) } } } catch (e: IOException) { @@ -504,6 +506,16 @@ class RealNetworkController( } } + override suspend fun enqueueSvrGuessResetJob() { + val pin = checkNotNull(RegistrationPreferences.pin) { "Pin is not set!" } + val masterKey = checkNotNull(RegistrationPreferences.masterKey) { "Master key is not set!" } + + val result = setPinAndMasterKeyOnSvr(pin, masterKey) + if (result !is RegistrationNetworkResult.Success) { + Log.w(TAG, "Failed to set pin and master key on SVR! A real app would retry. Result: $result") + } + } + override suspend fun enableRegistrationLock(): RegistrationNetworkResult = withContext(Dispatchers.IO) { val aci = RegistrationPreferences.aci val password = RegistrationPreferences.servicePassword @@ -692,6 +704,49 @@ class RealNetworkController( } } + override suspend fun checkSvrCredentials( + e164: String, + credentials: List + ): RegistrationNetworkResult = withContext(Dispatchers.IO) { + try { + val baseUrl = serviceConfiguration.signalServiceUrls[0].url + + val requestBody = json.encodeToString( + CheckSvrCredentialsRequest.serializer(), + CheckSvrCredentialsRequest.createForCredentials(number = e164, credentials) + ).toRequestBody("application/json".toMediaType()) + + val request = okhttp3.Request.Builder() + .url("$baseUrl/v2/svr/auth/check") + .post(requestBody) + .build() + + okHttpClient.newCall(request).execute().use { response -> + when (response.code) { + 200 -> { + val result = json.decodeFromString(response.body.string()) + RegistrationNetworkResult.Success(result) + } + 400, 422 -> { + RegistrationNetworkResult.Failure(NetworkController.CheckSvrCredentialsError.InvalidRequest(response.body.string())) + } + 401 -> { + RegistrationNetworkResult.Failure(NetworkController.CheckSvrCredentialsError.Unauthorized) + } + else -> { + RegistrationNetworkResult.ApplicationError(IllegalStateException("Unexpected response code: ${response.code}, body: ${response.body?.string()}")) + } + } + } + } catch (e: IOException) { + Log.w(TAG, "[checkSvrCredentials] IOException", e) + RegistrationNetworkResult.NetworkError(e) + } catch (e: Exception) { + Log.w(TAG, "[checkSvrCredentials] Exception", e) + RegistrationNetworkResult.ApplicationError(e) + } + } + private fun AccountAttributes.toServiceAccountAttributes(): ServiceAccountAttributes { return ServiceAccountAttributes( signalingKey, diff --git a/demo/registration/src/main/java/org/signal/registration/sample/dependencies/RealStorageController.kt b/demo/registration/src/main/java/org/signal/registration/sample/dependencies/DemoStorageController.kt similarity index 84% rename from demo/registration/src/main/java/org/signal/registration/sample/dependencies/RealStorageController.kt rename to demo/registration/src/main/java/org/signal/registration/sample/dependencies/DemoStorageController.kt index 429e2a7fed..aa3434ba84 100644 --- a/demo/registration/src/main/java/org/signal/registration/sample/dependencies/RealStorageController.kt +++ b/demo/registration/src/main/java/org/signal/registration/sample/dependencies/DemoStorageController.kt @@ -19,6 +19,7 @@ import org.signal.libsignal.protocol.state.KyberPreKeyRecord import org.signal.libsignal.protocol.state.SignedPreKeyRecord import org.signal.libsignal.zkgroup.profiles.ProfileKey import org.signal.registration.KeyMaterial +import org.signal.registration.NetworkController import org.signal.registration.NewRegistrationData import org.signal.registration.PreExistingRegistrationData import org.signal.registration.StorageController @@ -33,14 +34,22 @@ import javax.crypto.spec.SecretKeySpec * Implementation of [StorageController] that persists registration data using * SharedPreferences for simple key-value data and SQLite for prekeys. */ -class RealStorageController(context: Context) : StorageController { +class DemoStorageController(context: Context) : StorageController { + + companion object { + private const val MAX_SVR_CREDENTIALS = 10 + } private val db = RegistrationDatabase(context) - override suspend fun generateAndStoreKeyMaterial(): KeyMaterial = withContext(Dispatchers.IO) { - val accountEntropyPool = AccountEntropyPool.generate() - val aciIdentityKeyPair = IdentityKeyPair.generate() - val pniIdentityKeyPair = IdentityKeyPair.generate() + override suspend fun generateAndStoreKeyMaterial( + existingAccountEntropyPool: AccountEntropyPool?, + existingAciIdentityKeyPair: IdentityKeyPair?, + existingPniIdentityKeyPair: IdentityKeyPair? + ): KeyMaterial = withContext(Dispatchers.IO) { + val accountEntropyPool = existingAccountEntropyPool ?: AccountEntropyPool.generate() + val aciIdentityKeyPair = existingAciIdentityKeyPair ?: IdentityKeyPair.generate() + val pniIdentityKeyPair = existingPniIdentityKeyPair ?: IdentityKeyPair.generate() val aciSignedPreKeyId = generatePreKeyId() val pniSignedPreKeyId = generatePreKeyId() @@ -89,6 +98,7 @@ class RealStorageController(context: Context) : StorageController { override suspend fun clearAllData() = withContext(Dispatchers.IO) { RegistrationPreferences.clearAll() + RegistrationPreferences.clearRestoredSvr2Credentials() db.clearAllPreKeys() } @@ -99,6 +109,16 @@ class RealStorageController(context: Context) : StorageController { RegistrationPreferences.registrationLockEnabled = registrationLockEnabled } + override suspend fun getRestoredSvrCredentials(): List = withContext(Dispatchers.IO) { + RegistrationPreferences.restoredSvr2Credentials + } + + override suspend fun appendSvrCredentials(credentials: List) = withContext(Dispatchers.IO) { + val existing = RegistrationPreferences.restoredSvr2Credentials + val combined = (existing + credentials).distinctBy { it.username }.takeLast(MAX_SVR_CREDENTIALS) + RegistrationPreferences.restoredSvr2Credentials = combined + } + override suspend fun saveNewlyCreatedPin(pin: String, isAlphanumeric: Boolean) { RegistrationPreferences.pin = pin RegistrationPreferences.pinAlphanumeric = isAlphanumeric diff --git a/demo/registration/src/main/java/org/signal/registration/sample/screens/pinsettings/PinSettingsViewModel.kt b/demo/registration/src/main/java/org/signal/registration/sample/screens/pinsettings/PinSettingsViewModel.kt index b4c3b4bfbb..b4d933db08 100644 --- a/demo/registration/src/main/java/org/signal/registration/sample/screens/pinsettings/PinSettingsViewModel.kt +++ b/demo/registration/src/main/java/org/signal/registration/sample/screens/pinsettings/PinSettingsViewModel.kt @@ -178,7 +178,7 @@ class PinSettingsViewModel( ), name = null, pniRegistrationId = RegistrationPreferences.pniRegistrationId, - recoveryPassword = RegistrationPreferences.masterKey?.deriveRegistrationRecoveryPassword() + recoveryPassword = null ) when (val result = networkController.setAccountAttributes(attributes)) { diff --git a/demo/registration/src/main/java/org/signal/registration/sample/storage/RegistrationPreferences.kt b/demo/registration/src/main/java/org/signal/registration/sample/storage/RegistrationPreferences.kt index 5e621a4fee..f831ccc8f5 100644 --- a/demo/registration/src/main/java/org/signal/registration/sample/storage/RegistrationPreferences.kt +++ b/demo/registration/src/main/java/org/signal/registration/sample/storage/RegistrationPreferences.kt @@ -16,6 +16,7 @@ import org.signal.core.models.ServiceId.PNI import org.signal.core.util.Base64 import org.signal.libsignal.protocol.IdentityKeyPair import org.signal.libsignal.zkgroup.profiles.ProfileKey +import org.signal.registration.NetworkController import org.signal.registration.NewRegistrationData import org.signal.registration.PreExistingRegistrationData @@ -44,6 +45,7 @@ object RegistrationPreferences { private const val KEY_PIN = "has_pin" private const val KEY_PIN_ALPHANUMERIC = "pin_alphanumeric" private const val KEY_PINS_OPTED_OUT = "pins_opted_out" + private const val KEY_SVR2_CREDENTIALS = "svr2_credentials" fun init(context: Application) { this.context = context @@ -119,6 +121,23 @@ object RegistrationPreferences { get() = prefs.getBoolean(KEY_PINS_OPTED_OUT, false) set(value) = prefs.edit { putBoolean(KEY_PINS_OPTED_OUT, value) } + var restoredSvr2Credentials: List + get() = prefs.getStringSet(KEY_SVR2_CREDENTIALS, emptySet())?.mapNotNull { parseCredential(it) } ?: emptyList() + set(value) = prefs.edit { putStringSet(KEY_SVR2_CREDENTIALS, value.map { serializeCredential(it) }.toSet()) } + + private fun parseCredential(serialized: String): NetworkController.SvrCredentials? { + val parts = serialized.split(":", limit = 2) + return if (parts.size == 2) { + NetworkController.SvrCredentials(username = parts[0], password = parts[1]) + } else { + null + } + } + + private fun serializeCredential(credential: NetworkController.SvrCredentials): String { + return "${credential.username}:${credential.password}" + } + fun saveRegistrationData(data: NewRegistrationData) { prefs.edit { putString(KEY_E164, data.e164) @@ -135,13 +154,18 @@ object RegistrationPreferences { val pni = pni ?: return null val servicePassword = servicePassword ?: return null val aep = aep ?: return null + val aciIdentityKeyPair = aciIdentityKeyPair ?: return null + val pniIdentityKeyPair = pniIdentityKeyPair ?: return null return PreExistingRegistrationData( e164 = e164, aci = aci, pni = pni, servicePassword = servicePassword, - aep = aep + aep = aep, + registrationLockEnabled = registrationLockEnabled, + aciIdentityKeyPair = aciIdentityKeyPair, + pniIdentityKeyPair = pniIdentityKeyPair ) } @@ -158,4 +182,8 @@ object RegistrationPreferences { fun clearAll() { prefs.edit { clear() } } + + fun clearRestoredSvr2Credentials() { + prefs.edit { remove(KEY_SVR2_CREDENTIALS) } + } } diff --git a/demo/registration/src/main/java/org/signal/registration/sample/util/PreviewRegistrationDependencies.kt b/demo/registration/src/main/java/org/signal/registration/sample/util/PreviewRegistrationDependencies.kt deleted file mode 100644 index e723e2a91d..0000000000 --- a/demo/registration/src/main/java/org/signal/registration/sample/util/PreviewRegistrationDependencies.kt +++ /dev/null @@ -1,145 +0,0 @@ -/* - * Copyright 2025 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.signal.registration.sample.util - -import org.signal.core.models.MasterKey -import org.signal.registration.KeyMaterial -import org.signal.registration.NetworkController -import org.signal.registration.NewRegistrationData -import org.signal.registration.PreExistingRegistrationData -import org.signal.registration.RegistrationDependencies -import org.signal.registration.StorageController -import java.util.Locale - -object PreviewRegistrationDependencies { - fun get(): RegistrationDependencies { - return RegistrationDependencies( - networkController = PreviewNewtorkController(), - storageController = PreviewStorageController() - ) - } -} - -private class PreviewNewtorkController : NetworkController { - override suspend fun createSession( - e164: String, - fcmToken: String?, - mcc: String?, - mnc: String? - ): NetworkController.RegistrationNetworkResult { - throw NotImplementedError() - } - - override suspend fun getSession(sessionId: String): NetworkController.RegistrationNetworkResult { - throw NotImplementedError() - } - - override suspend fun updateSession( - sessionId: String?, - pushChallengeToken: String?, - captchaToken: String? - ): NetworkController.RegistrationNetworkResult { - throw NotImplementedError() - } - - override suspend fun requestVerificationCode( - sessionId: String, - locale: Locale?, - androidSmsRetrieverSupported: Boolean, - transport: NetworkController.VerificationCodeTransport - ): NetworkController.RegistrationNetworkResult { - throw NotImplementedError() - } - - override suspend fun submitVerificationCode( - sessionId: String, - verificationCode: String - ): NetworkController.RegistrationNetworkResult { - throw NotImplementedError() - } - - override suspend fun registerAccount( - e164: String, - password: String, - sessionId: String?, - recoveryPassword: String?, - attributes: NetworkController.AccountAttributes, - aciPreKeys: NetworkController.PreKeyCollection, - pniPreKeys: NetworkController.PreKeyCollection, - fcmToken: String?, - skipDeviceTransfer: Boolean - ): NetworkController.RegistrationNetworkResult { - throw NotImplementedError() - } - - override suspend fun getFcmToken(): String? { - throw NotImplementedError() - } - - override suspend fun awaitPushChallengeToken(): String? { - throw NotImplementedError() - } - - override fun getCaptchaUrl(): String { - throw NotImplementedError() - } - - override suspend fun restoreMasterKeyFromSvr( - svr2Credentials: NetworkController.SvrCredentials, - pin: String - ): NetworkController.RegistrationNetworkResult { - throw NotImplementedError() - } - - override suspend fun setPinAndMasterKeyOnSvr( - pin: String, - masterKey: MasterKey - ): NetworkController.RegistrationNetworkResult { - throw NotImplementedError() - } - - override suspend fun enableRegistrationLock(): NetworkController.RegistrationNetworkResult { - throw NotImplementedError() - } - - override suspend fun disableRegistrationLock(): NetworkController.RegistrationNetworkResult { - throw NotImplementedError() - } - - override suspend fun getSvrCredentials(): NetworkController.RegistrationNetworkResult { - throw NotImplementedError() - } - - override suspend fun setAccountAttributes(attributes: NetworkController.AccountAttributes): NetworkController.RegistrationNetworkResult { - throw NotImplementedError() - } -} - -private class PreviewStorageController : StorageController { - override suspend fun generateAndStoreKeyMaterial(): KeyMaterial { - throw NotImplementedError() - } - - override suspend fun saveNewRegistrationData(newRegistrationData: NewRegistrationData) { - throw NotImplementedError() - } - - override suspend fun getPreExistingRegistrationData(): PreExistingRegistrationData? { - throw NotImplementedError() - } - - override suspend fun saveValidatedPinAndTemporaryMasterKey(pin: String, isAlphanumeric: Boolean, masterKey: MasterKey, registrationLockEnabled: Boolean) { - throw NotImplementedError() - } - - override suspend fun saveNewlyCreatedPin(pin: String, isAlphanumeric: Boolean) { - throw NotImplementedError() - } - - override suspend fun clearAllData() { - throw NotImplementedError() - } -} diff --git a/feature/registration/src/main/java/org/signal/registration/NetworkController.kt b/feature/registration/src/main/java/org/signal/registration/NetworkController.kt index 4fbe1c55df..8c3db71f19 100644 --- a/feature/registration/src/main/java/org/signal/registration/NetworkController.kt +++ b/feature/registration/src/main/java/org/signal/registration/NetworkController.kt @@ -108,12 +108,12 @@ interface NetworkController { * This is called when the user encounters a registration lock and needs to prove * they know their PIN to proceed with registration. * - * @param svr2Credentials The SVR2 credentials provided by the server during the registration lock response. + * @param svrCredentials The SVR2 credentials provided by the server during the registration lock response. * @param pin The user-entered PIN. * @return The restored master key on success, or an appropriate error. */ suspend fun restoreMasterKeyFromSvr( - svr2Credentials: SvrCredentials, + svrCredentials: SvrCredentials, pin: String ): RegistrationNetworkResult @@ -127,7 +127,14 @@ interface NetworkController { suspend fun setPinAndMasterKeyOnSvr( pin: String, masterKey: MasterKey - ): RegistrationNetworkResult + ): RegistrationNetworkResult + + /** + * Requests that the currently-set PIN and [MasterKey] are backed up to SVR. + * It should always be the case that when this is called, you should have a stored PIN and [MasterKey]. + * If you do not, you should probably crash. + */ + suspend fun enqueueSvrGuessResetJob() /** * Enables registration lock on the account using the registration lock token @@ -153,6 +160,15 @@ interface NetworkController { */ suspend fun getSvrCredentials(): RegistrationNetworkResult + /** + * Checks if the SVR2 credentials are valid for the given phone number. + * + * `POST /v2/svr/auth/check` + * + * @return A response containing a mapping of which credentials are matches. + */ + suspend fun checkSvrCredentials(e164: String, credentials: List): RegistrationNetworkResult + /** * Updates account attributes on the server. * @@ -282,6 +298,11 @@ interface NetworkController { data object NoServiceCredentialsAvailable : GetSvrCredentialsError() } + sealed class CheckSvrCredentialsError() { + data object Unauthorized : CheckSvrCredentialsError() + data class InvalidRequest(val message: String) : CheckSvrCredentialsError() + } + data class MasterKeyResponse( val masterKey: MasterKey ) @@ -373,6 +394,43 @@ interface NetworkController { val password: String ) : Parcelable + @Serializable + data class CheckSvrCredentialsResponse( + val matches: Map + ) { + /** + * The first valid credential, if any. + * + * The response is structured like this: + * { + * matches: { + * : "match|no-match|invalid" + * } + * } + * + * So we find the first map entry with "match". The token is "username:password", so we split it apart. + * Important: The password can have ":" in it, so we need to make sure to just split on the first ":". + */ + val validCredential: SvrCredentials? by lazy { + matches.entries.firstOrNull { it.value == "match" }?.key?.split(":", limit = 2)?.let { SvrCredentials(it[0], it[1]) } + } + } + + @Serializable + data class CheckSvrCredentialsRequest( + val number: String, + val tokens: List + ) { + companion object { + fun createForCredentials(number: String, credentials: List): CheckSvrCredentialsRequest { + return CheckSvrCredentialsRequest( + number = number, + tokens = credentials.map { "${it.username}:${it.password}" } + ) + } + } + } + @Serializable data class ThirdPartyServiceErrorResponse( val reason: String, diff --git a/feature/registration/src/main/java/org/signal/registration/RegistrationActivity.kt b/feature/registration/src/main/java/org/signal/registration/RegistrationActivity.kt index da9dc16daa..d6382a5221 100644 --- a/feature/registration/src/main/java/org/signal/registration/RegistrationActivity.kt +++ b/feature/registration/src/main/java/org/signal/registration/RegistrationActivity.kt @@ -23,6 +23,7 @@ class RegistrationActivity : ComponentActivity() { private val repository: RegistrationRepository by lazy { RegistrationRepository( + context = this.application, networkController = RegistrationDependencies.get().networkController, storageController = RegistrationDependencies.get().storageController ) diff --git a/feature/registration/src/main/java/org/signal/registration/RegistrationDependencies.kt b/feature/registration/src/main/java/org/signal/registration/RegistrationDependencies.kt index aa2ac3cd2c..d8d45ec8d9 100644 --- a/feature/registration/src/main/java/org/signal/registration/RegistrationDependencies.kt +++ b/feature/registration/src/main/java/org/signal/registration/RegistrationDependencies.kt @@ -5,18 +5,26 @@ package org.signal.registration +import org.signal.core.util.logging.Log +import org.signal.registration.util.SensitiveLog + /** * Injection point for dependencies needed by this module. + * + * @param sensitiveLogger A logger for logging sensitive material. The intention is this would only be used in the demo app for testing + debugging, while + * the actual app would just pass null. */ class RegistrationDependencies( val networkController: NetworkController, - val storageController: StorageController + val storageController: StorageController, + val sensitiveLogger: Log.Logger? ) { companion object { lateinit var dependencies: RegistrationDependencies fun provide(registrationDependencies: RegistrationDependencies) { dependencies = registrationDependencies + SensitiveLog.init(dependencies.sensitiveLogger) } fun get(): RegistrationDependencies = dependencies diff --git a/feature/registration/src/main/java/org/signal/registration/RegistrationFlowEvent.kt b/feature/registration/src/main/java/org/signal/registration/RegistrationFlowEvent.kt index 03ba03624f..4110230970 100644 --- a/feature/registration/src/main/java/org/signal/registration/RegistrationFlowEvent.kt +++ b/feature/registration/src/main/java/org/signal/registration/RegistrationFlowEvent.kt @@ -9,12 +9,27 @@ import org.signal.core.models.AccountEntropyPool import org.signal.core.models.MasterKey sealed interface RegistrationFlowEvent { + /** Navigate to a specific screen. */ data class NavigateToScreen(val route: RegistrationRoute) : RegistrationFlowEvent + + /** Navigate back one screen. */ data object NavigateBack : RegistrationFlowEvent + + /** We've encountered some irrecoverable state where the best course of action is to completely reset registration. */ data object ResetState : RegistrationFlowEvent + + /** An update has been made to the ongoing registration session. */ data class SessionUpdated(val session: NetworkController.SessionMetadata) : RegistrationFlowEvent + + /** The e164 associated with this registration attempt has been updated. */ data class E164Chosen(val e164: String) : RegistrationFlowEvent + + /** The user has successfully registered. */ data class Registered(val accountEntropyPool: AccountEntropyPool) : RegistrationFlowEvent - data class MasterKeyRestoredViaRegistrationLock(val masterKey: MasterKey) : RegistrationFlowEvent - data class MasterKeyRestoredViaPostRegisterPinEntry(val masterKey: MasterKey) : RegistrationFlowEvent + + /** The master key has been restored from SVR. */ + data class MasterKeyRestoredFromSvr(val masterKey: MasterKey) : RegistrationFlowEvent + + /** We've discovered that RRP-based registration is not possible for this account. */ + data object RecoveryPasswordInvalid : RegistrationFlowEvent } diff --git a/feature/registration/src/main/java/org/signal/registration/RegistrationFlowState.kt b/feature/registration/src/main/java/org/signal/registration/RegistrationFlowState.kt index b89a9b4d8e..ea06a0d971 100644 --- a/feature/registration/src/main/java/org/signal/registration/RegistrationFlowState.kt +++ b/feature/registration/src/main/java/org/signal/registration/RegistrationFlowState.kt @@ -17,12 +17,25 @@ import org.signal.registration.util.MasterKeyParceler @TypeParceler @TypeParceler data class RegistrationFlowState( + /** The navigation stack. Controls what screen we're on and what the backstack looks like. */ val backStack: List = listOf(RegistrationRoute.Welcome), + + /** The metadata for the currently-active registration session. */ val sessionMetadata: NetworkController.SessionMetadata? = null, + + /** The e164 associated with the [sessionMetadata]. */ val sessionE164: String? = null, + + /** The AEP we generated as part of this registration. */ val accountEntropyPool: AccountEntropyPool? = null, + + /** The master key we restored from SVR. Needed for initial storage service restore, but afterwards we'll generate a new one. */ val temporaryMasterKey: MasterKey? = null, - val registrationLockProof: String? = null, - val preExistingRegistrationData: PreExistingRegistrationData? = null + + /** If set, indicates that this is a re-registration. It contains a bundle of data related to that previous registration. */ + val preExistingRegistrationData: PreExistingRegistrationData? = null, + + /** If true, do not attempt any flows where we generate RRP's. Create a session instead. */ + val doNotAttemptRecoveryPassword: Boolean = false ) : Parcelable diff --git a/feature/registration/src/main/java/org/signal/registration/RegistrationNavigation.kt b/feature/registration/src/main/java/org/signal/registration/RegistrationNavigation.kt index 3b50eb4ac3..741a31e6fb 100644 --- a/feature/registration/src/main/java/org/signal/registration/RegistrationNavigation.kt +++ b/feature/registration/src/main/java/org/signal/registration/RegistrationNavigation.kt @@ -45,6 +45,7 @@ import org.signal.registration.screens.phonenumber.PhoneNumberScreen import org.signal.registration.screens.pincreation.PinCreationScreen import org.signal.registration.screens.pincreation.PinCreationViewModel import org.signal.registration.screens.pinentry.PinEntryForRegistrationLockViewModel +import org.signal.registration.screens.pinentry.PinEntryForSmsBypassViewModel import org.signal.registration.screens.pinentry.PinEntryForSvrRestoreViewModel import org.signal.registration.screens.pinentry.PinEntryScreen import org.signal.registration.screens.restore.RestoreViaQrScreen @@ -90,6 +91,9 @@ sealed interface RegistrationRoute : NavKey, Parcelable { val svrCredentials: NetworkController.SvrCredentials ) : RegistrationRoute + @Serializable + data class PinEntryForSmsBypass(val svrCredentials: NetworkController.SvrCredentials) : RegistrationRoute + @Serializable data class AccountLocked(val timeRemainingMs: Long) : RegistrationRoute @@ -367,6 +371,24 @@ private fun EntryProviderScope.navigationEntries( ) } + // -- SMS Bypass PIN Entry Screen + entry { key -> + val viewModel: PinEntryForSmsBypassViewModel = viewModel( + factory = PinEntryForSmsBypassViewModel.Factory( + repository = registrationRepository, + parentState = registrationViewModel.state, + parentEventEmitter = registrationViewModel::onEvent, + svrCredentials = key.svrCredentials + ) + ) + val state by viewModel.state.collectAsStateWithLifecycle() + + PinEntryScreen( + state = state, + onEvent = { viewModel.onEvent(it) } + ) + } + // -- Account Locked Screen entry { key -> val daysRemaining = (key.timeRemainingMs / (1000 * 60 * 60 * 24)).toInt() diff --git a/feature/registration/src/main/java/org/signal/registration/RegistrationRepository.kt b/feature/registration/src/main/java/org/signal/registration/RegistrationRepository.kt index 0e236ef68f..59128088d9 100644 --- a/feature/registration/src/main/java/org/signal/registration/RegistrationRepository.kt +++ b/feature/registration/src/main/java/org/signal/registration/RegistrationRepository.kt @@ -5,11 +5,14 @@ package org.signal.registration +import android.app.backup.BackupManager +import android.content.Context import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext import org.signal.core.models.MasterKey import org.signal.core.models.ServiceId.ACI import org.signal.core.models.ServiceId.PNI +import org.signal.core.util.logging.Log import org.signal.registration.NetworkController.AccountAttributes import org.signal.registration.NetworkController.CreateSessionError import org.signal.registration.NetworkController.MasterKeyResponse @@ -22,9 +25,14 @@ import org.signal.registration.NetworkController.RestoreMasterKeyError import org.signal.registration.NetworkController.SessionMetadata import org.signal.registration.NetworkController.SvrCredentials import org.signal.registration.NetworkController.UpdateSessionError +import org.signal.registration.util.SensitiveLog import java.util.Locale -class RegistrationRepository(val networkController: NetworkController, val storageController: StorageController) { +class RegistrationRepository(val context: Context, val networkController: NetworkController, val storageController: StorageController) { + + companion object { + private val TAG = Log.tag(RegistrationRepository::class) + } suspend fun createSession(e164: String): RegistrationNetworkResult = withContext(Dispatchers.IO) { val fcmToken = networkController.getFcmToken() @@ -88,26 +96,47 @@ class RegistrationRepository(val networkController: NetworkController, val stora } suspend fun getSvrCredentials(): RegistrationNetworkResult = withContext(Dispatchers.IO) { - networkController.getSvrCredentials() + networkController.getSvrCredentials().also { + if (it is RegistrationNetworkResult.Success) { + storageController.appendSvrCredentials(listOf(it.data)) + BackupManager(context).dataChanged() + } + } + } + + suspend fun getRestoredSvrCredentials(): List = withContext(Dispatchers.IO) { + storageController.getRestoredSvrCredentials() + } + + suspend fun checkSvrCredentials(e164: String, credentials: List): RegistrationNetworkResult = withContext(Dispatchers.IO) { + networkController.checkSvrCredentials(e164, credentials) } suspend fun restoreMasterKeyFromSvr( - svr2Credentials: SvrCredentials, + svrCredentials: SvrCredentials, pin: String, isAlphanumeric: Boolean, forRegistrationLock: Boolean ): RegistrationNetworkResult = withContext(Dispatchers.IO) { networkController.restoreMasterKeyFromSvr( - svr2Credentials = svr2Credentials, + svrCredentials = svrCredentials, pin = pin ).also { if (it is RegistrationNetworkResult.Success) { // TODO consider whether we should save this now, or whether we should keep in app state and then hand it back to the library user at the end of the flow storageController.saveValidatedPinAndTemporaryMasterKey(pin, isAlphanumeric, it.data.masterKey, forRegistrationLock) + storageController.appendSvrCredentials(listOf(svrCredentials)) } } } + /** + * See [NetworkController.enqueueSvrGuessResetJob] + */ + suspend fun enqueueSvrResetGuessCountJob() { + networkController.enqueueSvrGuessResetJob() + } + /** * Registers a new account using a recovery password derived from the user's [MasterKey]. * @@ -119,7 +148,7 @@ class RegistrationRepository(val networkController: NetworkController, val stora * * @param e164 The phone number in E.164 format (used for basic auth) * @param recoveryPassword The recovery password, derived from the user's [MasterKey], which allows us to forgo session creation. - * @param registrationLock The registration lock token derived from the master key (if unlocking a reglocked account) + * @param registrationLock The registration lock token derived from the master key, if unlocking a reglocked account. Must be null if the account is not reglocked. * @param skipDeviceTransfer Whether to skip device transfer flow * @return The registration result containing account information or an error */ @@ -127,9 +156,10 @@ class RegistrationRepository(val networkController: NetworkController, val stora e164: String, recoveryPassword: String, registrationLock: String? = null, - skipDeviceTransfer: Boolean = true + skipDeviceTransfer: Boolean = true, + preExistingRegistrationData: PreExistingRegistrationData? = null ): RegistrationNetworkResult, RegisterAccountError> = withContext(Dispatchers.IO) { - registerAccount(e164, sessionId = null, recoveryPassword, registrationLock, skipDeviceTransfer) + registerAccount(e164, sessionId = null, recoveryPassword, registrationLock, skipDeviceTransfer, preExistingRegistrationData) } /** @@ -168,8 +198,9 @@ class RegistrationRepository(val networkController: NetworkController, val stora * @param e164 The phone number in E.164 format (used for basic auth) * @param sessionId The verified session ID from phone number verification. Must provide if you're not using [recoveryPassword]. * @param recoveryPassword The recovery password, derived from the user's [MasterKey], which allows us to forgo session creation. Must provide if you're not using [sessionId]. - * @param registrationLock The registration lock token derived from the master key (if unlocking a reglocked account) + * @param registrationLock The registration lock token derived from the master key (if unlocking a reglocked account). Important: if you provide this, the user will be registered with reglock enabled. * @param skipDeviceTransfer Whether to skip device transfer flow + * @param preExistingRegistrationData If present, we will use the pre-existing key material from this pre-existing registration rather than generating new key material. * @return The registration result containing account information or an error */ private suspend fun registerAccount( @@ -177,14 +208,26 @@ class RegistrationRepository(val networkController: NetworkController, val stora sessionId: String?, recoveryPassword: String?, registrationLock: String? = null, - skipDeviceTransfer: Boolean = true + skipDeviceTransfer: Boolean = true, + preExistingRegistrationData: PreExistingRegistrationData? = null ): RegistrationNetworkResult, RegisterAccountError> = withContext(Dispatchers.IO) { check(sessionId != null || recoveryPassword != null) { "Either sessionId or recoveryPassword must be provided" } check(sessionId == null || recoveryPassword == null) { "Either sessionId or recoveryPassword must be provided, but not both" } - val keyMaterial = storageController.generateAndStoreKeyMaterial() + Log.i(TAG, "[registerAccount] Starting registration for $e164. sessionId: ${sessionId != null}, recoveryPassword: ${recoveryPassword != null}, registrationLock: ${registrationLock != null}, skipDeviceTransfer: $skipDeviceTransfer, preExistingRegistrationData: ${preExistingRegistrationData != null}") + + val keyMaterial = storageController.generateAndStoreKeyMaterial( + existingAccountEntropyPool = preExistingRegistrationData?.aep, + existingAciIdentityKeyPair = preExistingRegistrationData?.aciIdentityKeyPair, + existingPniIdentityKeyPair = preExistingRegistrationData?.pniIdentityKeyPair + ) val fcmToken = networkController.getFcmToken() + val newMasterKey = keyMaterial.accountEntropyPool.deriveMasterKey() + val newRecoveryPassword = newMasterKey.deriveRegistrationRecoveryPassword() + + SensitiveLog.d(TAG, "[registerAccount] Using master key [${org.signal.libsignal.protocol.util.Hex.toStringCondensed(newMasterKey.serialize())}] and RRP [$newRecoveryPassword]") + val accountAttributes = AccountAttributes( signalingKey = null, registrationId = keyMaterial.aciRegistrationId, @@ -203,7 +246,7 @@ class RegistrationRepository(val networkController: NetworkController, val stora ), name = null, pniRegistrationId = keyMaterial.pniRegistrationId, - recoveryPassword = keyMaterial.accountEntropyPool.deriveMasterKey().deriveRegistrationRecoveryPassword() + recoveryPassword = newRecoveryPassword ) val aciPreKeys = PreKeyCollection( @@ -249,11 +292,14 @@ class RegistrationRepository(val networkController: NetworkController, val stora pin: String, isAlphanumeric: Boolean, masterKey: MasterKey - ): RegistrationNetworkResult = withContext(Dispatchers.IO) { + ): RegistrationNetworkResult = withContext(Dispatchers.IO) { val result = networkController.setPinAndMasterKeyOnSvr(pin, masterKey) if (result is RegistrationNetworkResult.Success) { storageController.saveNewlyCreatedPin(pin, isAlphanumeric) + result.data?.let { credential -> + storageController.appendSvrCredentials(listOf(credential)) + } } result diff --git a/feature/registration/src/main/java/org/signal/registration/RegistrationViewModel.kt b/feature/registration/src/main/java/org/signal/registration/RegistrationViewModel.kt index b6872a265e..099246d426 100644 --- a/feature/registration/src/main/java/org/signal/registration/RegistrationViewModel.kt +++ b/feature/registration/src/main/java/org/signal/registration/RegistrationViewModel.kt @@ -54,10 +54,10 @@ class RegistrationViewModel(private val repository: RegistrationRepository, save is RegistrationFlowEvent.SessionUpdated -> state.copy(sessionMetadata = event.session) is RegistrationFlowEvent.E164Chosen -> state.copy(sessionE164 = event.e164) is RegistrationFlowEvent.Registered -> state.copy(accountEntropyPool = event.accountEntropyPool) - is RegistrationFlowEvent.MasterKeyRestoredViaRegistrationLock -> state.copy(temporaryMasterKey = event.masterKey, registrationLockProof = event.masterKey.deriveRegistrationLock()) - is RegistrationFlowEvent.MasterKeyRestoredViaPostRegisterPinEntry -> state.copy(temporaryMasterKey = event.masterKey) + is RegistrationFlowEvent.MasterKeyRestoredFromSvr -> state.copy(temporaryMasterKey = event.masterKey) is RegistrationFlowEvent.NavigateToScreen -> applyNavigationToScreenEvent(state, event) is RegistrationFlowEvent.NavigateBack -> state.copy(backStack = state.backStack.dropLast(1)) + is RegistrationFlowEvent.RecoveryPasswordInvalid -> state.copy(doNotAttemptRecoveryPassword = true) } } diff --git a/feature/registration/src/main/java/org/signal/registration/StorageController.kt b/feature/registration/src/main/java/org/signal/registration/StorageController.kt index 089edbb0eb..7ba35b101b 100644 --- a/feature/registration/src/main/java/org/signal/registration/StorageController.kt +++ b/feature/registration/src/main/java/org/signal/registration/StorageController.kt @@ -28,9 +28,19 @@ interface StorageController { * Generates all key material required for account registration and stores it persistently. * This includes ACI identity key, PNI identity key, and their respective pre-keys. * + * If optional parameters are provided (e.g. from a pre-existing registration), those values + * will be re-used instead of generating new ones. + * + * @param existingAccountEntropyPool If non-null, re-use this AEP instead of generating a new one. + * @param existingAciIdentityKeyPair If non-null, re-use this ACI identity key pair instead of generating a new one. + * @param existingPniIdentityKeyPair If non-null, re-use this PNI identity key pair instead of generating a new one. * @return [KeyMaterial] containing all generated cryptographic material needed for registration. */ - suspend fun generateAndStoreKeyMaterial(): KeyMaterial + suspend fun generateAndStoreKeyMaterial( + existingAccountEntropyPool: AccountEntropyPool? = null, + existingAciIdentityKeyPair: IdentityKeyPair? = null, + existingPniIdentityKeyPair: IdentityKeyPair? = null + ): KeyMaterial /** * Called after a successful registration to store new registration data. @@ -44,6 +54,18 @@ interface StorageController { */ suspend fun getPreExistingRegistrationData(): PreExistingRegistrationData? + /** + * Retrieves any SVR2 credentials that may have been restored via the OS-level backup/restore service. May be empty. + */ + suspend fun getRestoredSvrCredentials(): List + + // TODO [regV5] Can this just take a single item? + /** + * Appends known-working SVR credentials to the local store of credentials. + * Implementations should limit the number of stored credentials to some reasonable maximum. + */ + suspend fun appendSvrCredentials(credentials: List) + /** * Saves a validated PIN, temporary master key, and registration lock status. * @@ -114,10 +136,14 @@ data class NewRegistrationData( @TypeParceler @TypeParceler @TypeParceler +@TypeParceler data class PreExistingRegistrationData( val e164: String, val aci: ACI, val pni: PNI, val servicePassword: String, - val aep: AccountEntropyPool + val aep: AccountEntropyPool, + val registrationLockEnabled: Boolean, + val aciIdentityKeyPair: IdentityKeyPair, + val pniIdentityKeyPair: IdentityKeyPair ) : Parcelable diff --git a/feature/registration/src/main/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryState.kt b/feature/registration/src/main/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryState.kt index 219665508d..d19c16f45b 100644 --- a/feature/registration/src/main/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryState.kt +++ b/feature/registration/src/main/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryState.kt @@ -5,6 +5,7 @@ package org.signal.registration.screens.phonenumber +import org.signal.registration.NetworkController import org.signal.registration.NetworkController.SessionMetadata import org.signal.registration.PreExistingRegistrationData import kotlin.time.Duration @@ -18,7 +19,8 @@ data class PhoneNumberEntryState( val sessionMetadata: SessionMetadata? = null, val showFullScreenSpinner: Boolean = false, val oneTimeEvent: OneTimeEvent? = null, - val preExistingRegistrationData: PreExistingRegistrationData? = null + val preExistingRegistrationData: PreExistingRegistrationData? = null, + val restoredSvrCredentials: List = emptyList() ) { sealed interface OneTimeEvent { data object NetworkError : OneTimeEvent diff --git a/feature/registration/src/main/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryViewModel.kt b/feature/registration/src/main/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryViewModel.kt index ba551d173d..c00d476280 100644 --- a/feature/registration/src/main/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryViewModel.kt +++ b/feature/registration/src/main/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryViewModel.kt @@ -51,6 +51,14 @@ class PhoneNumberEntryViewModel( .onEach { Log.d(TAG, "[State] $it") } .stateIn(viewModelScope, SharingStarted.Eagerly, PhoneNumberEntryState()) + init { + viewModelScope.launch { + _state.value = state.value.copy( + restoredSvrCredentials = repository.getRestoredSvrCredentials() + ) + } + } + fun onEvent(event: PhoneNumberEntryScreenEvents) { viewModelScope.launch { val stateEmitter: (PhoneNumberEntryState) -> Unit = { state -> @@ -92,7 +100,8 @@ class PhoneNumberEntryViewModel( return state.copy( sessionE164 = parentState.sessionE164, sessionMetadata = parentState.sessionMetadata, - preExistingRegistrationData = parentState.preExistingRegistrationData + preExistingRegistrationData = parentState.preExistingRegistrationData, + restoredSvrCredentials = state.restoredSvrCredentials.takeUnless { parentState.doNotAttemptRecoveryPassword } ?: emptyList() ) } @@ -140,10 +149,11 @@ class PhoneNumberEntryViewModel( if (state.preExistingRegistrationData?.e164 == e164) { val masterKey = state.preExistingRegistrationData.aep.deriveMasterKey() val recoveryPassword = masterKey.deriveRegistrationRecoveryPassword() - val registrationLock = masterKey.deriveRegistrationLock() + val registrationLock = masterKey.deriveRegistrationLock().takeIf { state.preExistingRegistrationData.registrationLockEnabled } - when (val registerResult = repository.registerAccountWithRecoveryPassword(e164, recoveryPassword, registrationLock, skipDeviceTransfer = true)) { + when (val registerResult = repository.registerAccountWithRecoveryPassword(e164, recoveryPassword, registrationLock, skipDeviceTransfer = true, state.preExistingRegistrationData)) { is NetworkController.RegistrationNetworkResult.Success -> { + Log.i(TAG, "[Register] Successfully re-registered using RRP from pre-existing data.") val (response, keyMaterial) = registerResult.data parentEventEmitter(RegistrationFlowEvent.Registered(keyMaterial.accountEntropyPool)) @@ -153,6 +163,7 @@ class PhoneNumberEntryViewModel( } else { parentEventEmitter.navigateTo(RegistrationRoute.PinCreate) } + return state } is NetworkController.RegistrationNetworkResult.Failure -> { when (registerResult.error) { @@ -167,7 +178,7 @@ class PhoneNumberEntryViewModel( return state } is NetworkController.RegisterAccountError.RegistrationLock -> { - Log.w(TAG, "[Register] Reglocked.") + Log.w(TAG, "[Register] Reglocked. This implies that the user still had reglock enabled despite the pre-existing data not thinking it was.") parentEventEmitter.navigateTo( RegistrationRoute.PinEntryForRegistrationLock( timeRemaining = registerResult.error.data.timeRemaining, @@ -177,17 +188,17 @@ class PhoneNumberEntryViewModel( return state } is NetworkController.RegisterAccountError.RateLimited -> { - Log.w(TAG, "[Register] Rate limited.") + Log.w(TAG, "[Register] Rate limited (retryAfter: ${registerResult.error.retryAfter}).") return state.copy(oneTimeEvent = OneTimeEvent.RateLimited(registerResult.error.retryAfter)) } is NetworkController.RegisterAccountError.InvalidRequest -> { Log.w(TAG, "[Register] Invalid request when registering account with RRP. Ditching pre-existing data and continuing with session creation. Message: ${registerResult.error.message}") - // TODO should we clear it in the parent state as well? + parentEventEmitter(RegistrationFlowEvent.RecoveryPasswordInvalid) state = state.copy(preExistingRegistrationData = null) } is NetworkController.RegisterAccountError.RegistrationRecoveryPasswordIncorrect -> { Log.w(TAG, "[Register] Registration recovery password incorrect. Ditching pre-existing data and continuing with session creation. Message: ${registerResult.error.message}") - // TODO should we clear it in the parent state as well? + parentEventEmitter(RegistrationFlowEvent.RecoveryPasswordInvalid) state = state.copy(preExistingRegistrationData = null) } } @@ -203,6 +214,39 @@ class PhoneNumberEntryViewModel( } } + // Detect if we have valid SVR credentials for the current number. If so, we can go right to the PIN entry screen. + // If they successfully restore the master key at that screen, we can use that to build the RRP and register without SMS. + if (state.restoredSvrCredentials.isNotEmpty()) { + when (val result = repository.checkSvrCredentials(e164, state.restoredSvrCredentials)) { + is NetworkController.RegistrationNetworkResult.Success -> { + Log.i(TAG, "[CheckSVRCredentials] Successfully validated credentials for ${e164}.") + val credential = result.data.validCredential + if (credential != null) { + parentEventEmitter(RegistrationFlowEvent.E164Chosen(e164)) + parentEventEmitter.navigateTo(RegistrationRoute.PinEntryForSmsBypass(credential)) + return state + } + } + is NetworkController.RegistrationNetworkResult.NetworkError -> { + Log.w(TAG, "[CheckSVRCredentials] Network error. Ignoring error and continuing without RRP.", result.exception) + } + is NetworkController.RegistrationNetworkResult.ApplicationError -> { + Log.w(TAG, "[CheckSVRCredentials] Application error. Ignoring error and continuing without RRP.", result.exception) + } + is NetworkController.RegistrationNetworkResult.Failure -> { + when (result.error) { + is NetworkController.CheckSvrCredentialsError.InvalidRequest -> { + Log.w(TAG, "[CheckSVRCredentials] Invalid request. Ignoring error and continuing without RRP. Message: ${result.error.message}") + } + + NetworkController.CheckSvrCredentialsError.Unauthorized -> { + Log.w(TAG, "[CheckSVRCredentials] Unauthorized. Ignoring error and continuing without RRP.") + } + } + } + } + } + // Detect if someone backed into this screen and entered a different number if (state.sessionE164 != null && state.sessionE164 != e164) { state = state.copy(sessionMetadata = null) diff --git a/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryForRegistrationLockViewModel.kt b/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryForRegistrationLockViewModel.kt index 29a78d7161..7df112e0c8 100644 --- a/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryForRegistrationLockViewModel.kt +++ b/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryForRegistrationLockViewModel.kt @@ -113,7 +113,7 @@ class PinEntryForRegistrationLockViewModel( } } - parentEventEmitter(RegistrationFlowEvent.MasterKeyRestoredViaRegistrationLock(masterKey)) + parentEventEmitter(RegistrationFlowEvent.MasterKeyRestoredFromSvr(masterKey)) val registrationLockToken = masterKey.deriveRegistrationLock() diff --git a/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryForSmsBypassViewModel.kt b/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryForSmsBypassViewModel.kt new file mode 100644 index 0000000000..6589b109ed --- /dev/null +++ b/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryForSmsBypassViewModel.kt @@ -0,0 +1,226 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.signal.registration.screens.pinentry + +import androidx.annotation.VisibleForTesting +import androidx.lifecycle.ViewModel +import androidx.lifecycle.ViewModelProvider +import androidx.lifecycle.viewModelScope +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.SharingStarted +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.flow.combine +import kotlinx.coroutines.flow.onEach +import kotlinx.coroutines.flow.stateIn +import kotlinx.coroutines.launch +import org.signal.core.models.MasterKey +import org.signal.core.util.logging.Log +import org.signal.libsignal.protocol.util.Hex +import org.signal.registration.NetworkController +import org.signal.registration.RegistrationFlowEvent +import org.signal.registration.RegistrationFlowState +import org.signal.registration.RegistrationRepository +import org.signal.registration.RegistrationRoute +import org.signal.registration.screens.util.navigateBack +import org.signal.registration.screens.util.navigateTo +import org.signal.registration.util.SensitiveLog + +/** + * ViewModel for the SMS-bypass PIN entry screen. + * + * This screen is shown when we have a known-valid SVR credential for the entered phone number, + * allowing the user to restore their master key and bypass SMS verification. + */ +class PinEntryForSmsBypassViewModel( + private val repository: RegistrationRepository, + private val parentState: StateFlow, + private val parentEventEmitter: (RegistrationFlowEvent) -> Unit, + private val svrCredentials: NetworkController.SvrCredentials +) : ViewModel() { + + companion object { + private val TAG = Log.tag(PinEntryForSmsBypassViewModel::class) + } + + private val _state = MutableStateFlow( + PinEntryState( + mode = PinEntryState.Mode.SmsBypass + ) + ) + + val state: StateFlow = _state + .combine(parentState) { state, parentState -> applyParentState(state, parentState) } + .onEach { Log.d(TAG, "[State] $it") } + .stateIn(viewModelScope, SharingStarted.Eagerly, PinEntryState(showNeedHelp = true)) + + fun onEvent(event: PinEntryScreenEvents) { + viewModelScope.launch { + val stateEmitter: (PinEntryState) -> Unit = { _state.value = it } + applyEvent(state.value, event, stateEmitter, parentEventEmitter) + } + } + + @VisibleForTesting + suspend fun applyEvent( + state: PinEntryState, + event: PinEntryScreenEvents, + stateEmitter: (PinEntryState) -> Unit, + parentEventEmitter: (RegistrationFlowEvent) -> Unit + ) { + when (event) { + is PinEntryScreenEvents.PinEntered -> { + var localState = state.copy(loading = true) + stateEmitter(localState) + localState = applyPinEntered(localState, event, parentEventEmitter) + stateEmitter(localState.copy(loading = false)) + } + is PinEntryScreenEvents.Skip -> { + handleSkip() + } + is PinEntryScreenEvents.ToggleKeyboard, + is PinEntryScreenEvents.NeedHelp -> { + stateEmitter(PinEntryScreenEventHandler.applyEvent(state, event)) + } + } + } + + fun applyParentState(state: PinEntryState, parentState: RegistrationFlowState): PinEntryState { + return state.copy(e164 = parentState.sessionE164) + } + + private suspend fun applyPinEntered( + state: PinEntryState, + event: PinEntryScreenEvents.PinEntered, + parentEventEmitter: (RegistrationFlowEvent) -> Unit + ): PinEntryState { + Log.d(TAG, "[PinEntered] Attempting to restore master key from SVR...") + + if (state.e164 == null) { + Log.w(TAG, "[PinEntered] No e164 available! Shouldn't be in this state. Resetting.") + parentEventEmitter(RegistrationFlowEvent.ResetState) + return state + } + + return when (val result = repository.restoreMasterKeyFromSvr(svrCredentials, event.pin, state.isAlphanumericKeyboard, forRegistrationLock = false)) { + is NetworkController.RegistrationNetworkResult.Success -> { + Log.i(TAG, "[PinEntered] Successfully restored master key from SVR.") + parentEventEmitter(RegistrationFlowEvent.MasterKeyRestoredFromSvr(result.data.masterKey)) + attemptToRegister(state, state.e164, result.data.masterKey, provideRegistrationLock = false, parentEventEmitter) + } + is NetworkController.RegistrationNetworkResult.Failure -> { + when (result.error) { + is NetworkController.RestoreMasterKeyError.WrongPin -> { + Log.w(TAG, "[PinEntered] Wrong PIN. Tries remaining: ${result.error.triesRemaining}") + state.copy(triesRemaining = result.error.triesRemaining) + } + is NetworkController.RestoreMasterKeyError.NoDataFound -> { + Log.w(TAG, "[PinEntered] No SVR data found for sms-bypass credential. Marking RRP as invalid and navigating back.") + parentEventEmitter(RegistrationFlowEvent.RecoveryPasswordInvalid) + parentEventEmitter.navigateBack() + state + } + } + } + is NetworkController.RegistrationNetworkResult.NetworkError -> { + Log.w(TAG, "[PinEntered] Network error when restoring master key (sms-bypass).", result.exception) + state.copy(oneTimeEvent = PinEntryState.OneTimeEvent.NetworkError) + } + is NetworkController.RegistrationNetworkResult.ApplicationError -> { + Log.w(TAG, "[PinEntered] Application error when restoring master key (sms-bypass).", result.exception) + state.copy(oneTimeEvent = PinEntryState.OneTimeEvent.UnknownError) + } + } + } + + private fun handleSkip() { + // TODO: Decide desired behavior (likely return to phone number entry). + Log.d(TAG, "[Skip] Not yet implemented.") + } + + private suspend fun attemptToRegister( + state: PinEntryState, + e164: String, + masterKey: MasterKey, + provideRegistrationLock: Boolean, + parentEventEmitter: (RegistrationFlowEvent) -> Unit + ): PinEntryState { + val recoveryPassword = masterKey.deriveRegistrationRecoveryPassword() + val registrationLock = masterKey.deriveRegistrationLock().takeIf { provideRegistrationLock } + + SensitiveLog.d(TAG, "Attempting registration using master key [${Hex.toStringCondensed(masterKey.serialize())}] and RRP [$recoveryPassword]") + + return when (val result = repository.registerAccountWithRecoveryPassword(e164, recoveryPassword, registrationLock, skipDeviceTransfer = true)) { + is NetworkController.RegistrationNetworkResult.Success -> { + parentEventEmitter.navigateTo(RegistrationRoute.FullyComplete) + repository.enqueueSvrResetGuessCountJob() + state + } + is NetworkController.RegistrationNetworkResult.NetworkError -> { + state.copy(oneTimeEvent = PinEntryState.OneTimeEvent.NetworkError) + } + is NetworkController.RegistrationNetworkResult.ApplicationError -> { + state.copy(oneTimeEvent = PinEntryState.OneTimeEvent.UnknownError) + } + is NetworkController.RegistrationNetworkResult.Failure -> { + when (result.error) { + NetworkController.RegisterAccountError.DeviceTransferPossible -> { + Log.w(TAG, "[Register] Got told a device transfer is possible. We should never get into this state. Resetting.") + parentEventEmitter(RegistrationFlowEvent.ResetState) + state + } + is NetworkController.RegisterAccountError.InvalidRequest -> { + Log.w(TAG, "[Register] Invalid request when registering account with RRP. Marking RRP as invalid and navigating back. Message: ${result.error.message}") + parentEventEmitter(RegistrationFlowEvent.RecoveryPasswordInvalid) + parentEventEmitter.navigateBack() + state + } + is NetworkController.RegisterAccountError.RateLimited -> { + Log.w(TAG, "[Register] Rate limited (retryAfter: ${result.error.retryAfter}).") + state.copy(oneTimeEvent = PinEntryState.OneTimeEvent.RateLimited(result.error.retryAfter)) + } + is NetworkController.RegisterAccountError.RegistrationLock -> { + if (provideRegistrationLock) { + Log.w(TAG, "[Register] Hit reglock error when supplying RRP with reglock. This shouldn't happen and implies that the RRP is likely invalid. Marking RRP as invalid and navigating back.") + parentEventEmitter(RegistrationFlowEvent.RecoveryPasswordInvalid) + parentEventEmitter.navigateBack() + state + } else { + Log.w(TAG, "[Register] Hit reglock error when supplying RRP without reglock. Attempting again with reglock.") + attemptToRegister(state, e164, masterKey, provideRegistrationLock = true, parentEventEmitter) + } + } + is NetworkController.RegisterAccountError.RegistrationRecoveryPasswordIncorrect -> { + Log.w(TAG, "[Register] Told that RRP is incorrect. Marking RRP as invalid and navigating back.") + parentEventEmitter(RegistrationFlowEvent.RecoveryPasswordInvalid) + parentEventEmitter.navigateBack() + state + } + is NetworkController.RegisterAccountError.SessionNotFoundOrNotVerified -> { + Log.w(TAG, "[Register] Got told our session wasn't found when trying to use RRP. We should never get into this state. Resetting.") + parentEventEmitter(RegistrationFlowEvent.ResetState) + state + } + } + } + } + } + + class Factory( + private val repository: RegistrationRepository, + private val parentState: StateFlow, + private val parentEventEmitter: (RegistrationFlowEvent) -> Unit, + private val svrCredentials: NetworkController.SvrCredentials + ) : ViewModelProvider.Factory { + override fun create(modelClass: Class): T { + return PinEntryForSmsBypassViewModel( + repository = repository, + parentState = parentState, + parentEventEmitter = parentEventEmitter, + svrCredentials = svrCredentials + ) as T + } + } +} diff --git a/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryForSvrRestoreViewModel.kt b/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryForSvrRestoreViewModel.kt index 88c7940ace..c03d7eb796 100644 --- a/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryForSvrRestoreViewModel.kt +++ b/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryForSvrRestoreViewModel.kt @@ -118,7 +118,8 @@ class PinEntryForSvrRestoreViewModel( return when (val result = repository.restoreMasterKeyFromSvr(svrCredentials, event.pin, state.isAlphanumericKeyboard, forRegistrationLock = false)) { is NetworkController.RegistrationNetworkResult.Success -> { Log.i(TAG, "[PinEntered] Successfully restored master key from SVR.") - parentEventEmitter(RegistrationFlowEvent.MasterKeyRestoredViaPostRegisterPinEntry(result.data.masterKey)) + repository.enqueueSvrResetGuessCountJob() + parentEventEmitter(RegistrationFlowEvent.MasterKeyRestoredFromSvr(result.data.masterKey)) parentEventEmitter.navigateTo(RegistrationRoute.FullyComplete) state } @@ -129,8 +130,9 @@ class PinEntryForSvrRestoreViewModel( state.copy(triesRemaining = result.error.triesRemaining) } is NetworkController.RestoreMasterKeyError.NoDataFound -> { - Log.w(TAG, "[PinEntered] No SVR data found. Proceeding without restore.") - state.copy(oneTimeEvent = PinEntryState.OneTimeEvent.SvrDataMissing) + Log.w(TAG, "[PinEntered] No SVR data found. Need to create a PIN instead.") + parentEventEmitter.navigateTo(RegistrationRoute.PinCreate) + state } } } diff --git a/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryScreen.kt b/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryScreen.kt index 1c0487c819..bf4ec3d67f 100644 --- a/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryScreen.kt +++ b/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryScreen.kt @@ -73,7 +73,8 @@ fun PinEntryScreen( val titleString = remember { return@remember when (state.mode) { PinEntryState.Mode.RegistrationLock -> "Registration Lock" - PinEntryState.Mode.SvrRestore -> "Enter your PIN" + PinEntryState.Mode.SvrRestore, + PinEntryState.Mode.SmsBypass -> "Enter your PIN" } } diff --git a/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryState.kt b/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryState.kt index 901279ec34..5a7d9ab027 100644 --- a/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryState.kt +++ b/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryState.kt @@ -13,10 +13,12 @@ data class PinEntryState( val loading: Boolean = false, val triesRemaining: Int? = null, val mode: Mode = Mode.SvrRestore, - val oneTimeEvent: OneTimeEvent? = null + val oneTimeEvent: OneTimeEvent? = null, + val e164: String? = null ) { enum class Mode { RegistrationLock, + SmsBypass, SvrRestore } diff --git a/feature/registration/src/main/java/org/signal/registration/screens/verificationcode/VerificationCodeViewModel.kt b/feature/registration/src/main/java/org/signal/registration/screens/verificationcode/VerificationCodeViewModel.kt index 0498e54820..c0abb01363 100644 --- a/feature/registration/src/main/java/org/signal/registration/screens/verificationcode/VerificationCodeViewModel.kt +++ b/feature/registration/src/main/java/org/signal/registration/screens/verificationcode/VerificationCodeViewModel.kt @@ -108,7 +108,7 @@ class VerificationCodeViewModel( } } is NetworkController.SubmitVerificationCodeError.RateLimited -> { - Log.w(TAG, "[SubmitCode] Rate limited.") + Log.w(TAG, "[SubmitCode] Rate limited (retryAfter: ${result.error.retryAfter}).") return state.copy(oneTimeEvent = OneTimeEvent.RateLimited(result.error.retryAfter)) } } @@ -167,7 +167,7 @@ class VerificationCodeViewModel( state } is NetworkController.RegisterAccountError.RateLimited -> { - Log.w(TAG, "[Register] Rate limited.") + Log.w(TAG, "[Register] Rate limited (retryAfter: ${registerResult.error.retryAfter}).") state.copy(oneTimeEvent = OneTimeEvent.RateLimited(registerResult.error.retryAfter)) } is NetworkController.RegisterAccountError.InvalidRequest -> { diff --git a/feature/registration/src/main/java/org/signal/registration/util/SensitiveLog.kt b/feature/registration/src/main/java/org/signal/registration/util/SensitiveLog.kt new file mode 100644 index 0000000000..714370b598 --- /dev/null +++ b/feature/registration/src/main/java/org/signal/registration/util/SensitiveLog.kt @@ -0,0 +1,45 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.signal.registration.util + +import org.signal.core.util.logging.Log +import org.signal.core.util.logging.NoopLogger + +/** + * A logger that can be used to log sensitive information for debugging purposes. + * The actual application will use a NoopLogger, while the demo app will provide actual logging capabilities to ease debugging. + */ +object SensitiveLog : Log.Logger() { + private var logger: Log.Logger = NoopLogger() + + fun init(logger: Log.Logger?) { + this.logger = logger ?: NoopLogger() + } + + override fun v(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) { + this.logger.v(tag, "[SENSITIVE] $message", t, keepLonger) + } + + override fun d(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) { + this.logger.d(tag, "[SENSITIVE] $message", t, keepLonger) + } + + override fun i(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) { + this.logger.i(tag, "[SENSITIVE] $message", t, keepLonger) + } + + override fun w(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) { + this.logger.w(tag, "[SENSITIVE] $message", t, keepLonger) + } + + override fun e(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) { + this.logger.e(tag, "[SENSITIVE] $message", t, keepLonger) + } + + override fun flush() { + this.logger.flush() + } +} \ No newline at end of file diff --git a/feature/registration/src/test/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryViewModelTest.kt b/feature/registration/src/test/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryViewModelTest.kt index 0aaa9ead06..bce05255d9 100644 --- a/feature/registration/src/test/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryViewModelTest.kt +++ b/feature/registration/src/test/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryViewModelTest.kt @@ -7,6 +7,7 @@ package org.signal.registration.screens.phonenumber import assertk.assertThat import assertk.assertions.hasSize +import assertk.assertions.isEmpty import assertk.assertions.isEqualTo import assertk.assertions.isFalse import assertk.assertions.isInstanceOf @@ -15,16 +16,20 @@ import assertk.assertions.isNull import assertk.assertions.isTrue import assertk.assertions.prop import io.mockk.coEvery +import io.mockk.coVerify import io.mockk.mockk import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.test.runTest import org.junit.Before import org.junit.Test +import org.signal.registration.KeyMaterial import org.signal.registration.NetworkController +import org.signal.registration.PreExistingRegistrationData import org.signal.registration.RegistrationFlowEvent import org.signal.registration.RegistrationFlowState import org.signal.registration.RegistrationRepository import org.signal.registration.RegistrationRoute +import java.io.IOException import kotlin.time.Duration.Companion.seconds class PhoneNumberEntryViewModelTest { @@ -822,6 +827,542 @@ class PhoneNumberEntryViewModelTest { assertThat(emittedStates.last().oneTimeEvent).isEqualTo(PhoneNumberEntryState.OneTimeEvent.NetworkError) } + // ==================== applyParentState Tests ==================== + + @Test + fun `applyParentState copies preExistingRegistrationData from parent`() { + val preExistingData = mockk(relaxed = true) + val state = PhoneNumberEntryState() + val parentFlowState = RegistrationFlowState(preExistingRegistrationData = preExistingData) + + val result = viewModel.applyParentState(state, parentFlowState) + + assertThat(result.preExistingRegistrationData).isEqualTo(preExistingData) + } + + @Test + fun `applyParentState clears restoredSvrCredentials when doNotAttemptRecoveryPassword is true`() { + val credentials = listOf( + NetworkController.SvrCredentials(username = "user", password = "pass") + ) + val state = PhoneNumberEntryState(restoredSvrCredentials = credentials) + val parentFlowState = RegistrationFlowState(doNotAttemptRecoveryPassword = true) + + val result = viewModel.applyParentState(state, parentFlowState) + + assertThat(result.restoredSvrCredentials).isEmpty() + } + + @Test + fun `applyParentState keeps restoredSvrCredentials when doNotAttemptRecoveryPassword is false`() { + val credentials = listOf( + NetworkController.SvrCredentials(username = "user", password = "pass") + ) + val state = PhoneNumberEntryState(restoredSvrCredentials = credentials) + val parentFlowState = RegistrationFlowState(doNotAttemptRecoveryPassword = false) + + val result = viewModel.applyParentState(state, parentFlowState) + + assertThat(result.restoredSvrCredentials).isEqualTo(credentials) + } + + // ==================== Pre-existing Registration Data (RRP) Tests ==================== + + @Test + fun `PhoneNumberSubmitted with matching preExistingRegistrationData registers with RRP and navigates to PinEntryForSvrRestore`() = runTest { + val preExistingData = mockk(relaxed = true) { + coEvery { e164 } returns "+15551234567" + coEvery { registrationLockEnabled } returns false + } + val keyMaterial = mockk(relaxed = true) + val registerResponse = createRegisterAccountResponse(storageCapable = true) + + coEvery { mockRepository.registerAccountWithRecoveryPassword(any(), any(), any(), any(), any()) } returns + NetworkController.RegistrationNetworkResult.Success(registerResponse to keyMaterial) + + val initialState = PhoneNumberEntryState( + countryCode = "1", + nationalNumber = "5551234567", + preExistingRegistrationData = preExistingData + ) + + viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.PhoneNumberSubmitted, stateEmitter, parentEventEmitter) + + assertThat(emittedEvents.first()).isInstanceOf() + assertThat(emittedEvents[1]) + .isInstanceOf() + .prop(RegistrationFlowEvent.NavigateToScreen::route) + .isInstanceOf() + } + + @Test + fun `PhoneNumberSubmitted with matching preExistingRegistrationData navigates to PinCreate when not storage capable`() = runTest { + val preExistingData = mockk(relaxed = true) { + coEvery { e164 } returns "+15551234567" + coEvery { registrationLockEnabled } returns false + } + val keyMaterial = mockk(relaxed = true) + val registerResponse = createRegisterAccountResponse(storageCapable = false) + + coEvery { mockRepository.registerAccountWithRecoveryPassword(any(), any(), any(), any(), any()) } returns + NetworkController.RegistrationNetworkResult.Success(registerResponse to keyMaterial) + + val initialState = PhoneNumberEntryState( + countryCode = "1", + nationalNumber = "5551234567", + preExistingRegistrationData = preExistingData + ) + + viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.PhoneNumberSubmitted, stateEmitter, parentEventEmitter) + + assertThat(emittedEvents.first()).isInstanceOf() + assertThat(emittedEvents[1]) + .isInstanceOf() + .prop(RegistrationFlowEvent.NavigateToScreen::route) + .isInstanceOf() + } + + @Test + fun `PhoneNumberSubmitted with preExistingRegistrationData and SessionNotFoundOrNotVerified emits ResetState`() = runTest { + val preExistingData = mockk(relaxed = true) { + coEvery { e164 } returns "+15551234567" + coEvery { registrationLockEnabled } returns false + } + + coEvery { mockRepository.registerAccountWithRecoveryPassword(any(), any(), any(), any(), any()) } returns + NetworkController.RegistrationNetworkResult.Failure( + NetworkController.RegisterAccountError.SessionNotFoundOrNotVerified("Not found") + ) + + val initialState = PhoneNumberEntryState( + countryCode = "1", + nationalNumber = "5551234567", + preExistingRegistrationData = preExistingData + ) + + viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.PhoneNumberSubmitted, stateEmitter, parentEventEmitter) + + assertThat(emittedEvents).hasSize(1) + assertThat(emittedEvents.first()).isEqualTo(RegistrationFlowEvent.ResetState) + } + + @Test + fun `PhoneNumberSubmitted with preExistingRegistrationData and DeviceTransferPossible emits ResetState`() = runTest { + val preExistingData = mockk(relaxed = true) { + coEvery { e164 } returns "+15551234567" + coEvery { registrationLockEnabled } returns false + } + + coEvery { mockRepository.registerAccountWithRecoveryPassword(any(), any(), any(), any(), any()) } returns + NetworkController.RegistrationNetworkResult.Failure( + NetworkController.RegisterAccountError.DeviceTransferPossible + ) + + val initialState = PhoneNumberEntryState( + countryCode = "1", + nationalNumber = "5551234567", + preExistingRegistrationData = preExistingData + ) + + viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.PhoneNumberSubmitted, stateEmitter, parentEventEmitter) + + assertThat(emittedEvents).hasSize(1) + assertThat(emittedEvents.first()).isEqualTo(RegistrationFlowEvent.ResetState) + } + + @Test + fun `PhoneNumberSubmitted with preExistingRegistrationData and RegistrationLock navigates to PinEntryForRegistrationLock`() = runTest { + val preExistingData = mockk(relaxed = true) { + coEvery { e164 } returns "+15551234567" + coEvery { registrationLockEnabled } returns false + } + val svrCredentials = NetworkController.SvrCredentials(username = "user", password = "pass") + val registrationLockData = NetworkController.RegistrationLockResponse( + timeRemaining = 60000L, + svr2Credentials = svrCredentials + ) + + coEvery { mockRepository.registerAccountWithRecoveryPassword(any(), any(), any(), any(), any()) } returns + NetworkController.RegistrationNetworkResult.Failure( + NetworkController.RegisterAccountError.RegistrationLock(registrationLockData) + ) + + val initialState = PhoneNumberEntryState( + countryCode = "1", + nationalNumber = "5551234567", + preExistingRegistrationData = preExistingData + ) + + viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.PhoneNumberSubmitted, stateEmitter, parentEventEmitter) + + assertThat(emittedEvents).hasSize(1) + assertThat(emittedEvents.first()) + .isInstanceOf() + .prop(RegistrationFlowEvent.NavigateToScreen::route) + .isInstanceOf() + } + + @Test + fun `PhoneNumberSubmitted with preExistingRegistrationData and RateLimited returns RateLimited event`() = runTest { + val preExistingData = mockk(relaxed = true) { + coEvery { e164 } returns "+15551234567" + coEvery { registrationLockEnabled } returns false + } + + coEvery { mockRepository.registerAccountWithRecoveryPassword(any(), any(), any(), any(), any()) } returns + NetworkController.RegistrationNetworkResult.Failure( + NetworkController.RegisterAccountError.RateLimited(30.seconds) + ) + + val initialState = PhoneNumberEntryState( + countryCode = "1", + nationalNumber = "5551234567", + preExistingRegistrationData = preExistingData + ) + + viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.PhoneNumberSubmitted, stateEmitter, parentEventEmitter) + + assertThat(emittedStates.last().oneTimeEvent).isNotNull() + .isInstanceOf() + .prop(PhoneNumberEntryState.OneTimeEvent.RateLimited::retryAfter) + .isEqualTo(30.seconds) + } + + @Test + fun `PhoneNumberSubmitted with preExistingRegistrationData and InvalidRequest emits RecoveryPasswordInvalid and falls through to session creation`() = runTest { + val preExistingData = mockk(relaxed = true) { + coEvery { e164 } returns "+15551234567" + coEvery { registrationLockEnabled } returns false + } + val sessionMetadata = createSessionMetadata(requestedInformation = emptyList()) + + coEvery { mockRepository.registerAccountWithRecoveryPassword(any(), any(), any(), any(), any()) } returns + NetworkController.RegistrationNetworkResult.Failure( + NetworkController.RegisterAccountError.InvalidRequest("Bad request") + ) + coEvery { mockRepository.createSession(any()) } returns + NetworkController.RegistrationNetworkResult.Success(sessionMetadata) + coEvery { mockRepository.requestVerificationCode(any(), any(), any()) } returns + NetworkController.RegistrationNetworkResult.Success(sessionMetadata) + + val initialState = PhoneNumberEntryState( + countryCode = "1", + nationalNumber = "5551234567", + preExistingRegistrationData = preExistingData + ) + + viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.PhoneNumberSubmitted, stateEmitter, parentEventEmitter) + + // Should emit RecoveryPasswordInvalid and then continue to session creation + assertThat(emittedEvents.first()).isEqualTo(RegistrationFlowEvent.RecoveryPasswordInvalid) + // Should ultimately navigate to verification code entry after falling through + assertThat(emittedStates.last().preExistingRegistrationData).isNull() + } + + @Test + fun `PhoneNumberSubmitted with preExistingRegistrationData and RegistrationRecoveryPasswordIncorrect emits RecoveryPasswordInvalid and falls through`() = runTest { + val preExistingData = mockk(relaxed = true) { + coEvery { e164 } returns "+15551234567" + coEvery { registrationLockEnabled } returns false + } + val sessionMetadata = createSessionMetadata(requestedInformation = emptyList()) + + coEvery { mockRepository.registerAccountWithRecoveryPassword(any(), any(), any(), any(), any()) } returns + NetworkController.RegistrationNetworkResult.Failure( + NetworkController.RegisterAccountError.RegistrationRecoveryPasswordIncorrect("Wrong password") + ) + coEvery { mockRepository.createSession(any()) } returns + NetworkController.RegistrationNetworkResult.Success(sessionMetadata) + coEvery { mockRepository.requestVerificationCode(any(), any(), any()) } returns + NetworkController.RegistrationNetworkResult.Success(sessionMetadata) + + val initialState = PhoneNumberEntryState( + countryCode = "1", + nationalNumber = "5551234567", + preExistingRegistrationData = preExistingData + ) + + viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.PhoneNumberSubmitted, stateEmitter, parentEventEmitter) + + assertThat(emittedEvents.first()).isEqualTo(RegistrationFlowEvent.RecoveryPasswordInvalid) + assertThat(emittedStates.last().preExistingRegistrationData).isNull() + } + + @Test + fun `PhoneNumberSubmitted with preExistingRegistrationData and NetworkError returns NetworkError event`() = runTest { + val preExistingData = mockk(relaxed = true) { + coEvery { e164 } returns "+15551234567" + coEvery { registrationLockEnabled } returns false + } + + coEvery { mockRepository.registerAccountWithRecoveryPassword(any(), any(), any(), any(), any()) } returns + NetworkController.RegistrationNetworkResult.NetworkError(IOException("Network error")) + + val initialState = PhoneNumberEntryState( + countryCode = "1", + nationalNumber = "5551234567", + preExistingRegistrationData = preExistingData + ) + + viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.PhoneNumberSubmitted, stateEmitter, parentEventEmitter) + + assertThat(emittedStates.last().oneTimeEvent).isEqualTo(PhoneNumberEntryState.OneTimeEvent.NetworkError) + } + + @Test + fun `PhoneNumberSubmitted with preExistingRegistrationData and ApplicationError returns UnknownError event`() = runTest { + val preExistingData = mockk(relaxed = true) { + coEvery { e164 } returns "+15551234567" + coEvery { registrationLockEnabled } returns false + } + + coEvery { mockRepository.registerAccountWithRecoveryPassword(any(), any(), any(), any(), any()) } returns + NetworkController.RegistrationNetworkResult.ApplicationError(RuntimeException("Unexpected")) + + val initialState = PhoneNumberEntryState( + countryCode = "1", + nationalNumber = "5551234567", + preExistingRegistrationData = preExistingData + ) + + viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.PhoneNumberSubmitted, stateEmitter, parentEventEmitter) + + assertThat(emittedStates.last().oneTimeEvent).isEqualTo(PhoneNumberEntryState.OneTimeEvent.UnknownError) + } + + @Test + fun `PhoneNumberSubmitted with non-matching preExistingRegistrationData skips RRP and creates session`() = runTest { + val preExistingData = mockk(relaxed = true) { + coEvery { e164 } returns "+15559999999" + coEvery { registrationLockEnabled } returns false + } + val sessionMetadata = createSessionMetadata(requestedInformation = emptyList()) + + coEvery { mockRepository.createSession(any()) } returns + NetworkController.RegistrationNetworkResult.Success(sessionMetadata) + coEvery { mockRepository.requestVerificationCode(any(), any(), any()) } returns + NetworkController.RegistrationNetworkResult.Success(sessionMetadata) + + val initialState = PhoneNumberEntryState( + countryCode = "1", + nationalNumber = "5551234567", + preExistingRegistrationData = preExistingData + ) + + viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.PhoneNumberSubmitted, stateEmitter, parentEventEmitter) + + // Should skip RRP and go to session creation flow + coVerify(exactly = 0) { mockRepository.registerAccountWithRecoveryPassword(any(), any(), any(), any(), any()) } + assertThat(emittedEvents.last()) + .isInstanceOf() + .prop(RegistrationFlowEvent.NavigateToScreen::route) + .isInstanceOf() + } + + // ==================== SVR Credential Checking Tests ==================== + + @Test + fun `PhoneNumberSubmitted with valid SVR credentials navigates to PinEntryForSmsBypass`() = runTest { + val svrCredentials = listOf( + NetworkController.SvrCredentials(username = "user", password = "pass") + ) + val validCredential = NetworkController.SvrCredentials(username = "user", password = "pass") + val checkResponse = NetworkController.CheckSvrCredentialsResponse( + matches = mapOf("user:pass" to "match") + ) + + coEvery { mockRepository.checkSvrCredentials(any(), any()) } returns + NetworkController.RegistrationNetworkResult.Success(checkResponse) + + val initialState = PhoneNumberEntryState( + countryCode = "1", + nationalNumber = "5551234567", + restoredSvrCredentials = svrCredentials + ) + + viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.PhoneNumberSubmitted, stateEmitter, parentEventEmitter) + + assertThat(emittedEvents).hasSize(2) + assertThat(emittedEvents[0]).isInstanceOf() + assertThat(emittedEvents[1]) + .isInstanceOf() + .prop(RegistrationFlowEvent.NavigateToScreen::route) + .isInstanceOf() + } + + @Test + fun `PhoneNumberSubmitted with no matching SVR credentials falls through to session creation`() = runTest { + val svrCredentials = listOf( + NetworkController.SvrCredentials(username = "user", password = "pass") + ) + val checkResponse = NetworkController.CheckSvrCredentialsResponse( + matches = mapOf("user:pass" to "no-match") + ) + val sessionMetadata = createSessionMetadata(requestedInformation = emptyList()) + + coEvery { mockRepository.checkSvrCredentials(any(), any()) } returns + NetworkController.RegistrationNetworkResult.Success(checkResponse) + coEvery { mockRepository.createSession(any()) } returns + NetworkController.RegistrationNetworkResult.Success(sessionMetadata) + coEvery { mockRepository.requestVerificationCode(any(), any(), any()) } returns + NetworkController.RegistrationNetworkResult.Success(sessionMetadata) + + val initialState = PhoneNumberEntryState( + countryCode = "1", + nationalNumber = "5551234567", + restoredSvrCredentials = svrCredentials + ) + + viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.PhoneNumberSubmitted, stateEmitter, parentEventEmitter) + + // Should fall through to session creation + assertThat(emittedEvents.last()) + .isInstanceOf() + .prop(RegistrationFlowEvent.NavigateToScreen::route) + .isInstanceOf() + } + + @Test + fun `PhoneNumberSubmitted with SVR credentials network error falls through to session creation`() = runTest { + val svrCredentials = listOf( + NetworkController.SvrCredentials(username = "user", password = "pass") + ) + val sessionMetadata = createSessionMetadata(requestedInformation = emptyList()) + + coEvery { mockRepository.checkSvrCredentials(any(), any()) } returns + NetworkController.RegistrationNetworkResult.NetworkError(IOException("Network error")) + coEvery { mockRepository.createSession(any()) } returns + NetworkController.RegistrationNetworkResult.Success(sessionMetadata) + coEvery { mockRepository.requestVerificationCode(any(), any(), any()) } returns + NetworkController.RegistrationNetworkResult.Success(sessionMetadata) + + val initialState = PhoneNumberEntryState( + countryCode = "1", + nationalNumber = "5551234567", + restoredSvrCredentials = svrCredentials + ) + + viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.PhoneNumberSubmitted, stateEmitter, parentEventEmitter) + + // Should ignore error and fall through + assertThat(emittedEvents.last()) + .isInstanceOf() + .prop(RegistrationFlowEvent.NavigateToScreen::route) + .isInstanceOf() + } + + @Test + fun `PhoneNumberSubmitted with SVR credentials application error falls through to session creation`() = runTest { + val svrCredentials = listOf( + NetworkController.SvrCredentials(username = "user", password = "pass") + ) + val sessionMetadata = createSessionMetadata(requestedInformation = emptyList()) + + coEvery { mockRepository.checkSvrCredentials(any(), any()) } returns + NetworkController.RegistrationNetworkResult.ApplicationError(RuntimeException("Unexpected")) + coEvery { mockRepository.createSession(any()) } returns + NetworkController.RegistrationNetworkResult.Success(sessionMetadata) + coEvery { mockRepository.requestVerificationCode(any(), any(), any()) } returns + NetworkController.RegistrationNetworkResult.Success(sessionMetadata) + + val initialState = PhoneNumberEntryState( + countryCode = "1", + nationalNumber = "5551234567", + restoredSvrCredentials = svrCredentials + ) + + viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.PhoneNumberSubmitted, stateEmitter, parentEventEmitter) + + assertThat(emittedEvents.last()) + .isInstanceOf() + .prop(RegistrationFlowEvent.NavigateToScreen::route) + .isInstanceOf() + } + + @Test + fun `PhoneNumberSubmitted with SVR credentials invalid request falls through to session creation`() = runTest { + val svrCredentials = listOf( + NetworkController.SvrCredentials(username = "user", password = "pass") + ) + val sessionMetadata = createSessionMetadata(requestedInformation = emptyList()) + + coEvery { mockRepository.checkSvrCredentials(any(), any()) } returns + NetworkController.RegistrationNetworkResult.Failure( + NetworkController.CheckSvrCredentialsError.InvalidRequest("Bad request") + ) + coEvery { mockRepository.createSession(any()) } returns + NetworkController.RegistrationNetworkResult.Success(sessionMetadata) + coEvery { mockRepository.requestVerificationCode(any(), any(), any()) } returns + NetworkController.RegistrationNetworkResult.Success(sessionMetadata) + + val initialState = PhoneNumberEntryState( + countryCode = "1", + nationalNumber = "5551234567", + restoredSvrCredentials = svrCredentials + ) + + viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.PhoneNumberSubmitted, stateEmitter, parentEventEmitter) + + assertThat(emittedEvents.last()) + .isInstanceOf() + .prop(RegistrationFlowEvent.NavigateToScreen::route) + .isInstanceOf() + } + + @Test + fun `PhoneNumberSubmitted with SVR credentials unauthorized falls through to session creation`() = runTest { + val svrCredentials = listOf( + NetworkController.SvrCredentials(username = "user", password = "pass") + ) + val sessionMetadata = createSessionMetadata(requestedInformation = emptyList()) + + coEvery { mockRepository.checkSvrCredentials(any(), any()) } returns + NetworkController.RegistrationNetworkResult.Failure( + NetworkController.CheckSvrCredentialsError.Unauthorized + ) + coEvery { mockRepository.createSession(any()) } returns + NetworkController.RegistrationNetworkResult.Success(sessionMetadata) + coEvery { mockRepository.requestVerificationCode(any(), any(), any()) } returns + NetworkController.RegistrationNetworkResult.Success(sessionMetadata) + + val initialState = PhoneNumberEntryState( + countryCode = "1", + nationalNumber = "5551234567", + restoredSvrCredentials = svrCredentials + ) + + viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.PhoneNumberSubmitted, stateEmitter, parentEventEmitter) + + assertThat(emittedEvents.last()) + .isInstanceOf() + .prop(RegistrationFlowEvent.NavigateToScreen::route) + .isInstanceOf() + } + + @Test + fun `PhoneNumberSubmitted with empty restoredSvrCredentials skips SVR check`() = runTest { + val sessionMetadata = createSessionMetadata(requestedInformation = emptyList()) + + coEvery { mockRepository.createSession(any()) } returns + NetworkController.RegistrationNetworkResult.Success(sessionMetadata) + coEvery { mockRepository.requestVerificationCode(any(), any(), any()) } returns + NetworkController.RegistrationNetworkResult.Success(sessionMetadata) + + val initialState = PhoneNumberEntryState( + countryCode = "1", + nationalNumber = "5551234567", + restoredSvrCredentials = emptyList() + ) + + viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.PhoneNumberSubmitted, stateEmitter, parentEventEmitter) + + coVerify(exactly = 0) { mockRepository.checkSvrCredentials(any(), any()) } + assertThat(emittedEvents.last()) + .isInstanceOf() + .prop(RegistrationFlowEvent.NavigateToScreen::route) + .isInstanceOf() + } + // ==================== Helper Functions ==================== private fun createSessionMetadata( @@ -837,4 +1378,20 @@ class PhoneNumberEntryViewModelTest { requestedInformation = requestedInformation, verified = verified ) + + private fun createRegisterAccountResponse( + aci: String = "test-aci", + pni: String = "test-pni", + e164: String = "+15551234567", + storageCapable: Boolean = true + ) = NetworkController.RegisterAccountResponse( + aci = aci, + pni = pni, + e164 = e164, + usernameHash = null, + usernameLinkHandle = null, + storageCapable = storageCapable, + entitlements = null, + reregistration = false + ) } diff --git a/feature/registration/src/test/java/org/signal/registration/screens/pinentry/PinEntryForRegistrationLockViewModelTest.kt b/feature/registration/src/test/java/org/signal/registration/screens/pinentry/PinEntryForRegistrationLockViewModelTest.kt index cc0cb0b42a..fe1af6c4d9 100644 --- a/feature/registration/src/test/java/org/signal/registration/screens/pinentry/PinEntryForRegistrationLockViewModelTest.kt +++ b/feature/registration/src/test/java/org/signal/registration/screens/pinentry/PinEntryForRegistrationLockViewModelTest.kt @@ -80,7 +80,7 @@ class PinEntryForRegistrationLockViewModelTest { viewModel.applyEvent(initialState, PinEntryScreenEvents.PinEntered("123456"), stateEmitter, parentEventEmitter) assertThat(emittedParentEvents).hasSize(3) - assertThat(emittedParentEvents[0]).isInstanceOf() + assertThat(emittedParentEvents[0]).isInstanceOf() assertThat(emittedParentEvents[1]).isInstanceOf() assertThat(emittedParentEvents[2]) .isInstanceOf() @@ -166,7 +166,7 @@ class PinEntryForRegistrationLockViewModelTest { viewModel.applyEvent(initialState, PinEntryScreenEvents.PinEntered("123456"), stateEmitter, parentEventEmitter) assertThat(emittedParentEvents).hasSize(2) - assertThat(emittedParentEvents[0]).isInstanceOf() + assertThat(emittedParentEvents[0]).isInstanceOf() assertThat(emittedParentEvents[1]).isEqualTo(RegistrationFlowEvent.ResetState) } @@ -186,7 +186,7 @@ class PinEntryForRegistrationLockViewModelTest { viewModel.applyEvent(initialState, PinEntryScreenEvents.PinEntered("123456"), stateEmitter, parentEventEmitter) assertThat(emittedParentEvents).hasSize(2) - assertThat(emittedParentEvents[0]).isInstanceOf() + assertThat(emittedParentEvents[0]).isInstanceOf() assertThat(emittedParentEvents[1]).isEqualTo(RegistrationFlowEvent.ResetState) } @@ -208,7 +208,7 @@ class PinEntryForRegistrationLockViewModelTest { viewModel.applyEvent(initialState, PinEntryScreenEvents.PinEntered("123456"), stateEmitter, parentEventEmitter) assertThat(emittedParentEvents).hasSize(2) - assertThat(emittedParentEvents[0]).isInstanceOf() + assertThat(emittedParentEvents[0]).isInstanceOf() assertThat(emittedParentEvents[1]).isEqualTo(RegistrationFlowEvent.ResetState) } @@ -228,7 +228,7 @@ class PinEntryForRegistrationLockViewModelTest { viewModel.applyEvent(initialState, PinEntryScreenEvents.PinEntered("123456"), stateEmitter, parentEventEmitter) assertThat(emittedParentEvents).hasSize(1) - assertThat(emittedParentEvents[0]).isInstanceOf() + assertThat(emittedParentEvents[0]).isInstanceOf() assertThat(emittedStates.last().oneTimeEvent).isNotNull() .isInstanceOf() .prop(PinEntryState.OneTimeEvent.RateLimited::retryAfter) @@ -250,7 +250,7 @@ class PinEntryForRegistrationLockViewModelTest { viewModel.applyEvent(initialState, PinEntryScreenEvents.PinEntered("123456"), stateEmitter, parentEventEmitter) assertThat(emittedParentEvents).hasSize(1) - assertThat(emittedParentEvents[0]).isInstanceOf() + assertThat(emittedParentEvents[0]).isInstanceOf() assertThat(emittedStates.last().oneTimeEvent).isEqualTo(PinEntryState.OneTimeEvent.UnknownError) } @@ -269,7 +269,7 @@ class PinEntryForRegistrationLockViewModelTest { viewModel.applyEvent(initialState, PinEntryScreenEvents.PinEntered("123456"), stateEmitter, parentEventEmitter) assertThat(emittedParentEvents).hasSize(1) - assertThat(emittedParentEvents[0]).isInstanceOf() + assertThat(emittedParentEvents[0]).isInstanceOf() assertThat(emittedStates.last().oneTimeEvent).isEqualTo(PinEntryState.OneTimeEvent.UnknownError) } @@ -286,7 +286,7 @@ class PinEntryForRegistrationLockViewModelTest { viewModel.applyEvent(initialState, PinEntryScreenEvents.PinEntered("123456"), stateEmitter, parentEventEmitter) assertThat(emittedParentEvents).hasSize(1) - assertThat(emittedParentEvents[0]).isInstanceOf() + assertThat(emittedParentEvents[0]).isInstanceOf() assertThat(emittedStates.last().oneTimeEvent).isEqualTo(PinEntryState.OneTimeEvent.NetworkError) } @@ -303,7 +303,7 @@ class PinEntryForRegistrationLockViewModelTest { viewModel.applyEvent(initialState, PinEntryScreenEvents.PinEntered("123456"), stateEmitter, parentEventEmitter) assertThat(emittedParentEvents).hasSize(1) - assertThat(emittedParentEvents[0]).isInstanceOf() + assertThat(emittedParentEvents[0]).isInstanceOf() assertThat(emittedStates.last().oneTimeEvent).isEqualTo(PinEntryState.OneTimeEvent.UnknownError) } diff --git a/feature/registration/src/test/java/org/signal/registration/screens/pinentry/PinEntryForSmsBypassViewModelTest.kt b/feature/registration/src/test/java/org/signal/registration/screens/pinentry/PinEntryForSmsBypassViewModelTest.kt new file mode 100644 index 0000000000..aaf5fd6a49 --- /dev/null +++ b/feature/registration/src/test/java/org/signal/registration/screens/pinentry/PinEntryForSmsBypassViewModelTest.kt @@ -0,0 +1,397 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.signal.registration.screens.pinentry + +import assertk.assertThat +import assertk.assertions.hasSize +import assertk.assertions.isEqualTo +import assertk.assertions.isInstanceOf +import assertk.assertions.isNotNull +import assertk.assertions.prop +import io.mockk.coEvery +import io.mockk.coVerify +import io.mockk.mockk +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.test.runTest +import org.junit.Before +import org.junit.Test +import org.signal.core.models.MasterKey +import org.signal.registration.NetworkController +import org.signal.registration.RegistrationFlowEvent +import org.signal.registration.RegistrationFlowState +import org.signal.registration.RegistrationRepository +import org.signal.registration.RegistrationRoute +import java.io.IOException +import kotlin.time.Duration.Companion.seconds + +class PinEntryForSmsBypassViewModelTest { + + private lateinit var viewModel: PinEntryForSmsBypassViewModel + private lateinit var mockRepository: RegistrationRepository + private lateinit var parentState: MutableStateFlow + private lateinit var emittedParentEvents: MutableList + private lateinit var parentEventEmitter: (RegistrationFlowEvent) -> Unit + private lateinit var emittedStates: MutableList + private lateinit var stateEmitter: (PinEntryState) -> Unit + + private val testSvrCredentials = NetworkController.SvrCredentials( + username = "test-username", + password = "test-password" + ) + + @Before + fun setup() { + mockRepository = mockk(relaxed = true) + parentState = MutableStateFlow( + RegistrationFlowState( + sessionE164 = "+15551234567" + ) + ) + emittedParentEvents = mutableListOf() + parentEventEmitter = { event -> emittedParentEvents.add(event) } + emittedStates = mutableListOf() + stateEmitter = { state -> emittedStates.add(state) } + viewModel = PinEntryForSmsBypassViewModel( + repository = mockRepository, + parentState = parentState, + parentEventEmitter = parentEventEmitter, + svrCredentials = testSvrCredentials + ) + } + + // ==================== PinEntered - Restore Master Key Tests ==================== + + @Test + fun `PinEntered with correct PIN restores master key and registers successfully`() = runTest { + val masterKey = mockk(relaxed = true) + val initialState = PinEntryState(mode = PinEntryState.Mode.SmsBypass, e164 = "+15551234567") + + coEvery { mockRepository.restoreMasterKeyFromSvr(any(), any(), any(), forRegistrationLock = false) } returns + NetworkController.RegistrationNetworkResult.Success(NetworkController.MasterKeyResponse(masterKey)) + coEvery { mockRepository.registerAccountWithRecoveryPassword(any(), any(), any(), any()) } returns + NetworkController.RegistrationNetworkResult.Success(mockk(relaxed = true)) + + viewModel.applyEvent(initialState, PinEntryScreenEvents.PinEntered("123456"), stateEmitter, parentEventEmitter) + + assertThat(emittedParentEvents).hasSize(2) + assertThat(emittedParentEvents[0]).isInstanceOf() + assertThat(emittedParentEvents[1]) + .isInstanceOf() + .prop(RegistrationFlowEvent.NavigateToScreen::route) + .isInstanceOf() + } + + @Test + fun `PinEntered with correct PIN enqueues SVR guess reset job after successful registration`() = runTest { + val masterKey = mockk(relaxed = true) + val initialState = PinEntryState(mode = PinEntryState.Mode.SmsBypass, e164 = "+15551234567") + + coEvery { mockRepository.restoreMasterKeyFromSvr(any(), any(), any(), forRegistrationLock = false) } returns + NetworkController.RegistrationNetworkResult.Success(NetworkController.MasterKeyResponse(masterKey)) + coEvery { mockRepository.registerAccountWithRecoveryPassword(any(), any(), any(), any()) } returns + NetworkController.RegistrationNetworkResult.Success(mockk(relaxed = true)) + + viewModel.applyEvent(initialState, PinEntryScreenEvents.PinEntered("123456"), stateEmitter, parentEventEmitter) + + coVerify { mockRepository.enqueueSvrResetGuessCountJob() } + } + + @Test + fun `PinEntered with wrong PIN returns state with tries remaining`() = runTest { + val triesRemaining = 3 + val initialState = PinEntryState(mode = PinEntryState.Mode.SmsBypass, e164 = "+15551234567") + + coEvery { mockRepository.restoreMasterKeyFromSvr(any(), any(), any(), forRegistrationLock = false) } returns + NetworkController.RegistrationNetworkResult.Failure( + NetworkController.RestoreMasterKeyError.WrongPin(triesRemaining) + ) + + viewModel.applyEvent(initialState, PinEntryScreenEvents.PinEntered("wrong-pin"), stateEmitter, parentEventEmitter) + + assertThat(emittedParentEvents).hasSize(0) + assertThat(emittedStates.last().triesRemaining).isEqualTo(triesRemaining) + } + + @Test + fun `PinEntered with no SVR data emits RecoveryPasswordInvalid and navigates back`() = runTest { + val initialState = PinEntryState(mode = PinEntryState.Mode.SmsBypass, e164 = "+15551234567") + + coEvery { mockRepository.restoreMasterKeyFromSvr(any(), any(), any(), forRegistrationLock = false) } returns + NetworkController.RegistrationNetworkResult.Failure( + NetworkController.RestoreMasterKeyError.NoDataFound + ) + + viewModel.applyEvent(initialState, PinEntryScreenEvents.PinEntered("123456"), stateEmitter, parentEventEmitter) + + assertThat(emittedParentEvents).hasSize(2) + assertThat(emittedParentEvents[0]).isEqualTo(RegistrationFlowEvent.RecoveryPasswordInvalid) + assertThat(emittedParentEvents[1]).isEqualTo(RegistrationFlowEvent.NavigateBack) + } + + @Test + fun `PinEntered with network error restoring master key returns NetworkError event`() = runTest { + val initialState = PinEntryState(mode = PinEntryState.Mode.SmsBypass, e164 = "+15551234567") + + coEvery { mockRepository.restoreMasterKeyFromSvr(any(), any(), any(), forRegistrationLock = false) } returns + NetworkController.RegistrationNetworkResult.NetworkError(IOException("Network error")) + + viewModel.applyEvent(initialState, PinEntryScreenEvents.PinEntered("123456"), stateEmitter, parentEventEmitter) + + assertThat(emittedParentEvents).hasSize(0) + assertThat(emittedStates.last().oneTimeEvent).isEqualTo(PinEntryState.OneTimeEvent.NetworkError) + } + + @Test + fun `PinEntered with application error restoring master key returns UnknownError event`() = runTest { + val initialState = PinEntryState(mode = PinEntryState.Mode.SmsBypass, e164 = "+15551234567") + + coEvery { mockRepository.restoreMasterKeyFromSvr(any(), any(), any(), forRegistrationLock = false) } returns + NetworkController.RegistrationNetworkResult.ApplicationError(RuntimeException("Unexpected")) + + viewModel.applyEvent(initialState, PinEntryScreenEvents.PinEntered("123456"), stateEmitter, parentEventEmitter) + + assertThat(emittedParentEvents).hasSize(0) + assertThat(emittedStates.last().oneTimeEvent).isEqualTo(PinEntryState.OneTimeEvent.UnknownError) + } + + @Test + fun `PinEntered with missing e164 emits ResetState`() = runTest { + val initialState = PinEntryState(mode = PinEntryState.Mode.SmsBypass, e164 = null) + + viewModel.applyEvent(initialState, PinEntryScreenEvents.PinEntered("123456"), stateEmitter, parentEventEmitter) + + assertThat(emittedParentEvents).hasSize(1) + assertThat(emittedParentEvents[0]).isEqualTo(RegistrationFlowEvent.ResetState) + } + + // ==================== Registration Error Tests ==================== + + @Test + fun `PinEntered with registration network error returns NetworkError event`() = runTest { + val masterKey = mockk(relaxed = true) + val initialState = PinEntryState(mode = PinEntryState.Mode.SmsBypass, e164 = "+15551234567") + + coEvery { mockRepository.restoreMasterKeyFromSvr(any(), any(), any(), forRegistrationLock = false) } returns + NetworkController.RegistrationNetworkResult.Success(NetworkController.MasterKeyResponse(masterKey)) + coEvery { mockRepository.registerAccountWithRecoveryPassword(any(), any(), any(), any()) } returns + NetworkController.RegistrationNetworkResult.NetworkError(IOException("Network error")) + + viewModel.applyEvent(initialState, PinEntryScreenEvents.PinEntered("123456"), stateEmitter, parentEventEmitter) + + assertThat(emittedParentEvents).hasSize(1) + assertThat(emittedParentEvents[0]).isInstanceOf() + assertThat(emittedStates.last().oneTimeEvent).isEqualTo(PinEntryState.OneTimeEvent.NetworkError) + } + + @Test + fun `PinEntered with registration application error returns UnknownError event`() = runTest { + val masterKey = mockk(relaxed = true) + val initialState = PinEntryState(mode = PinEntryState.Mode.SmsBypass, e164 = "+15551234567") + + coEvery { mockRepository.restoreMasterKeyFromSvr(any(), any(), any(), forRegistrationLock = false) } returns + NetworkController.RegistrationNetworkResult.Success(NetworkController.MasterKeyResponse(masterKey)) + coEvery { mockRepository.registerAccountWithRecoveryPassword(any(), any(), any(), any()) } returns + NetworkController.RegistrationNetworkResult.ApplicationError(RuntimeException("Unexpected")) + + viewModel.applyEvent(initialState, PinEntryScreenEvents.PinEntered("123456"), stateEmitter, parentEventEmitter) + + assertThat(emittedParentEvents).hasSize(1) + assertThat(emittedParentEvents[0]).isInstanceOf() + assertThat(emittedStates.last().oneTimeEvent).isEqualTo(PinEntryState.OneTimeEvent.UnknownError) + } + + @Test + fun `PinEntered with DeviceTransferPossible during registration emits ResetState`() = runTest { + val masterKey = mockk(relaxed = true) + val initialState = PinEntryState(mode = PinEntryState.Mode.SmsBypass, e164 = "+15551234567") + + coEvery { mockRepository.restoreMasterKeyFromSvr(any(), any(), any(), forRegistrationLock = false) } returns + NetworkController.RegistrationNetworkResult.Success(NetworkController.MasterKeyResponse(masterKey)) + coEvery { mockRepository.registerAccountWithRecoveryPassword(any(), any(), any(), any()) } returns + NetworkController.RegistrationNetworkResult.Failure( + NetworkController.RegisterAccountError.DeviceTransferPossible + ) + + viewModel.applyEvent(initialState, PinEntryScreenEvents.PinEntered("123456"), stateEmitter, parentEventEmitter) + + assertThat(emittedParentEvents).hasSize(2) + assertThat(emittedParentEvents[0]).isInstanceOf() + assertThat(emittedParentEvents[1]).isEqualTo(RegistrationFlowEvent.ResetState) + } + + @Test + fun `PinEntered with InvalidRequest during registration emits RecoveryPasswordInvalid and navigates back`() = runTest { + val masterKey = mockk(relaxed = true) + val initialState = PinEntryState(mode = PinEntryState.Mode.SmsBypass, e164 = "+15551234567") + + coEvery { mockRepository.restoreMasterKeyFromSvr(any(), any(), any(), forRegistrationLock = false) } returns + NetworkController.RegistrationNetworkResult.Success(NetworkController.MasterKeyResponse(masterKey)) + coEvery { mockRepository.registerAccountWithRecoveryPassword(any(), any(), any(), any()) } returns + NetworkController.RegistrationNetworkResult.Failure( + NetworkController.RegisterAccountError.InvalidRequest("Bad request") + ) + + viewModel.applyEvent(initialState, PinEntryScreenEvents.PinEntered("123456"), stateEmitter, parentEventEmitter) + + assertThat(emittedParentEvents).hasSize(3) + assertThat(emittedParentEvents[0]).isInstanceOf() + assertThat(emittedParentEvents[1]).isEqualTo(RegistrationFlowEvent.RecoveryPasswordInvalid) + assertThat(emittedParentEvents[2]).isEqualTo(RegistrationFlowEvent.NavigateBack) + } + + @Test + fun `PinEntered with RateLimited during registration returns RateLimited event`() = runTest { + val masterKey = mockk(relaxed = true) + val retryAfter = 30.seconds + val initialState = PinEntryState(mode = PinEntryState.Mode.SmsBypass, e164 = "+15551234567") + + coEvery { mockRepository.restoreMasterKeyFromSvr(any(), any(), any(), forRegistrationLock = false) } returns + NetworkController.RegistrationNetworkResult.Success(NetworkController.MasterKeyResponse(masterKey)) + coEvery { mockRepository.registerAccountWithRecoveryPassword(any(), any(), any(), any()) } returns + NetworkController.RegistrationNetworkResult.Failure( + NetworkController.RegisterAccountError.RateLimited(retryAfter) + ) + + viewModel.applyEvent(initialState, PinEntryScreenEvents.PinEntered("123456"), stateEmitter, parentEventEmitter) + + assertThat(emittedParentEvents).hasSize(1) + assertThat(emittedParentEvents[0]).isInstanceOf() + assertThat(emittedStates.last().oneTimeEvent).isNotNull() + .isInstanceOf() + .prop(PinEntryState.OneTimeEvent.RateLimited::retryAfter) + .isEqualTo(retryAfter) + } + + @Test + fun `PinEntered with RegistrationLock without provideRegistrationLock retries with reglock`() = runTest { + val masterKey = mockk(relaxed = true) + val registrationLockData = mockk(relaxed = true) + val initialState = PinEntryState(mode = PinEntryState.Mode.SmsBypass, e164 = "+15551234567") + + coEvery { mockRepository.restoreMasterKeyFromSvr(any(), any(), any(), forRegistrationLock = false) } returns + NetworkController.RegistrationNetworkResult.Success(NetworkController.MasterKeyResponse(masterKey)) + // First call (without reglock) returns RegistrationLock error, second call (with reglock) succeeds + coEvery { mockRepository.registerAccountWithRecoveryPassword(any(), any(), registrationLock = null, any()) } returns + NetworkController.RegistrationNetworkResult.Failure( + NetworkController.RegisterAccountError.RegistrationLock(registrationLockData) + ) + coEvery { mockRepository.registerAccountWithRecoveryPassword(any(), any(), registrationLock = any(), any()) } returns + NetworkController.RegistrationNetworkResult.Success(mockk(relaxed = true)) + + viewModel.applyEvent(initialState, PinEntryScreenEvents.PinEntered("123456"), stateEmitter, parentEventEmitter) + + assertThat(emittedParentEvents).hasSize(2) + assertThat(emittedParentEvents[0]).isInstanceOf() + assertThat(emittedParentEvents[1]) + .isInstanceOf() + .prop(RegistrationFlowEvent.NavigateToScreen::route) + .isInstanceOf() + } + + @Test + fun `PinEntered with RegistrationLock when already providing reglock emits RecoveryPasswordInvalid and navigates back`() = runTest { + val masterKey = mockk(relaxed = true) + val registrationLockData = mockk(relaxed = true) + val initialState = PinEntryState(mode = PinEntryState.Mode.SmsBypass, e164 = "+15551234567") + + coEvery { mockRepository.restoreMasterKeyFromSvr(any(), any(), any(), forRegistrationLock = false) } returns + NetworkController.RegistrationNetworkResult.Success(NetworkController.MasterKeyResponse(masterKey)) + // Both calls (with and without reglock) return RegistrationLock error + coEvery { mockRepository.registerAccountWithRecoveryPassword(any(), any(), any(), any()) } returns + NetworkController.RegistrationNetworkResult.Failure( + NetworkController.RegisterAccountError.RegistrationLock(registrationLockData) + ) + + viewModel.applyEvent(initialState, PinEntryScreenEvents.PinEntered("123456"), stateEmitter, parentEventEmitter) + + // First retry (without reglock) -> reglock error -> retry with reglock -> reglock error again -> RecoveryPasswordInvalid + NavigateBack + assertThat(emittedParentEvents).hasSize(3) + assertThat(emittedParentEvents[0]).isInstanceOf() + assertThat(emittedParentEvents[1]).isEqualTo(RegistrationFlowEvent.RecoveryPasswordInvalid) + assertThat(emittedParentEvents[2]).isEqualTo(RegistrationFlowEvent.NavigateBack) + } + + @Test + fun `PinEntered with RegistrationRecoveryPasswordIncorrect emits RecoveryPasswordInvalid and navigates back`() = runTest { + val masterKey = mockk(relaxed = true) + val initialState = PinEntryState(mode = PinEntryState.Mode.SmsBypass, e164 = "+15551234567") + + coEvery { mockRepository.restoreMasterKeyFromSvr(any(), any(), any(), forRegistrationLock = false) } returns + NetworkController.RegistrationNetworkResult.Success(NetworkController.MasterKeyResponse(masterKey)) + coEvery { mockRepository.registerAccountWithRecoveryPassword(any(), any(), any(), any()) } returns + NetworkController.RegistrationNetworkResult.Failure( + NetworkController.RegisterAccountError.RegistrationRecoveryPasswordIncorrect("Wrong password") + ) + + viewModel.applyEvent(initialState, PinEntryScreenEvents.PinEntered("123456"), stateEmitter, parentEventEmitter) + + assertThat(emittedParentEvents).hasSize(3) + assertThat(emittedParentEvents[0]).isInstanceOf() + assertThat(emittedParentEvents[1]).isEqualTo(RegistrationFlowEvent.RecoveryPasswordInvalid) + assertThat(emittedParentEvents[2]).isEqualTo(RegistrationFlowEvent.NavigateBack) + } + + @Test + fun `PinEntered with SessionNotFoundOrNotVerified during registration emits ResetState`() = runTest { + val masterKey = mockk(relaxed = true) + val initialState = PinEntryState(mode = PinEntryState.Mode.SmsBypass, e164 = "+15551234567") + + coEvery { mockRepository.restoreMasterKeyFromSvr(any(), any(), any(), forRegistrationLock = false) } returns + NetworkController.RegistrationNetworkResult.Success(NetworkController.MasterKeyResponse(masterKey)) + coEvery { mockRepository.registerAccountWithRecoveryPassword(any(), any(), any(), any()) } returns + NetworkController.RegistrationNetworkResult.Failure( + NetworkController.RegisterAccountError.SessionNotFoundOrNotVerified("Not found") + ) + + viewModel.applyEvent(initialState, PinEntryScreenEvents.PinEntered("123456"), stateEmitter, parentEventEmitter) + + assertThat(emittedParentEvents).hasSize(2) + assertThat(emittedParentEvents[0]).isInstanceOf() + assertThat(emittedParentEvents[1]).isEqualTo(RegistrationFlowEvent.ResetState) + } + + // ==================== applyParentState Tests ==================== + + @Test + fun `applyParentState copies e164 from parent state`() { + val state = PinEntryState(mode = PinEntryState.Mode.SmsBypass) + val parentFlowState = RegistrationFlowState(sessionE164 = "+15559876543") + + val result = viewModel.applyParentState(state, parentFlowState) + + assertThat(result.e164).isEqualTo("+15559876543") + } + + @Test + fun `applyParentState with null e164 in parent state sets null e164`() { + val state = PinEntryState(mode = PinEntryState.Mode.SmsBypass, e164 = "+15551234567") + val parentFlowState = RegistrationFlowState(sessionE164 = null) + + val result = viewModel.applyParentState(state, parentFlowState) + + assertThat(result.e164).isEqualTo(null) + } + + // ==================== ToggleKeyboard Tests ==================== + + @Test + fun `ToggleKeyboard toggles isAlphanumericKeyboard from false to true`() = runTest { + val initialState = PinEntryState(isAlphanumericKeyboard = false) + + viewModel.applyEvent(initialState, PinEntryScreenEvents.ToggleKeyboard, stateEmitter, parentEventEmitter) + + assertThat(emittedStates.last().isAlphanumericKeyboard).isEqualTo(true) + } + + @Test + fun `ToggleKeyboard toggles isAlphanumericKeyboard from true to false`() = runTest { + val initialState = PinEntryState(isAlphanumericKeyboard = true) + + viewModel.applyEvent(initialState, PinEntryScreenEvents.ToggleKeyboard, stateEmitter, parentEventEmitter) + + assertThat(emittedStates.last().isAlphanumericKeyboard).isEqualTo(false) + } +} diff --git a/feature/registration/src/test/java/org/signal/registration/screens/pinentry/PinEntryForSvrRestoreViewModelTest.kt b/feature/registration/src/test/java/org/signal/registration/screens/pinentry/PinEntryForSvrRestoreViewModelTest.kt index 81a0742dff..8f20d95c20 100644 --- a/feature/registration/src/test/java/org/signal/registration/screens/pinentry/PinEntryForSvrRestoreViewModelTest.kt +++ b/feature/registration/src/test/java/org/signal/registration/screens/pinentry/PinEntryForSvrRestoreViewModelTest.kt @@ -72,7 +72,7 @@ class PinEntryForSvrRestoreViewModelTest { viewModel.applyEvent(initialState, PinEntryScreenEvents.PinEntered("123456"), stateEmitter, parentEventEmitter) assertThat(emittedParentEvents).hasSize(2) - assertThat(emittedParentEvents[0]).isInstanceOf() + assertThat(emittedParentEvents[0]).isInstanceOf() assertThat(emittedParentEvents[1]) .isInstanceOf() .prop(RegistrationFlowEvent.NavigateToScreen::route) @@ -162,7 +162,7 @@ class PinEntryForSvrRestoreViewModelTest { } @Test - fun `PinEntered with no SVR data returns SvrDataMissing event`() = runTest { + fun `PinEntered with no SVR data navigates to PinCreate`() = runTest { val svrCredentials = NetworkController.SvrCredentials( username = "test-username", password = "test-password" @@ -178,8 +178,11 @@ class PinEntryForSvrRestoreViewModelTest { viewModel.applyEvent(initialState, PinEntryScreenEvents.PinEntered("123456"), stateEmitter, parentEventEmitter) - assertThat(emittedParentEvents).hasSize(0) - assertThat(emittedStates.last().oneTimeEvent).isEqualTo(PinEntryState.OneTimeEvent.SvrDataMissing) + assertThat(emittedParentEvents).hasSize(1) + assertThat(emittedParentEvents.first()) + .isInstanceOf() + .prop(RegistrationFlowEvent.NavigateToScreen::route) + .isInstanceOf() } @Test