From f09bf5b14c1bf0c4e843c696d5173f99ac3ea9cd Mon Sep 17 00:00:00 2001 From: Greyson Parrelli Date: Wed, 18 Mar 2026 16:53:50 -0400 Subject: [PATCH] Make regV5 resumable if the app closes. --- .../sample/debug/NetworkDebugOverlay.kt | 2 +- .../sample/screens/main/MainScreen.kt | 78 +++++- .../sample/screens/main/MainScreenState.kt | 11 +- .../screens/main/MainScreenViewModel.kt | 24 ++ .../signal/registration/PersistedFlowState.kt | 60 +++++ .../registration/RegistrationFlowState.kt | 5 +- .../registration/RegistrationNavigation.kt | 14 +- .../registration/RegistrationRepository.kt | 77 ++++++ .../registration/RegistrationViewModel.kt | 81 +++++- .../phonenumber/PhoneNumberEntryViewModel.kt | 38 ++- .../src/main/protowire/Registration.proto | 3 + .../registration/PersistedFlowStateTest.kt | 225 ++++++++++++++++ .../RegistrationNavigationTest.kt | 19 ++ .../RegistrationViewModelRestoreTest.kt | 251 ++++++++++++++++++ .../PhoneNumberEntryViewModelTest.kt | 48 ++-- 15 files changed, 895 insertions(+), 41 deletions(-) create mode 100644 feature/registration/src/main/java/org/signal/registration/PersistedFlowState.kt create mode 100644 feature/registration/src/test/java/org/signal/registration/PersistedFlowStateTest.kt create mode 100644 feature/registration/src/test/java/org/signal/registration/RegistrationViewModelRestoreTest.kt diff --git a/demo/registration/src/main/java/org/signal/registration/sample/debug/NetworkDebugOverlay.kt b/demo/registration/src/main/java/org/signal/registration/sample/debug/NetworkDebugOverlay.kt index a9549b9ac0..c4aa1c5d80 100644 --- a/demo/registration/src/main/java/org/signal/registration/sample/debug/NetworkDebugOverlay.kt +++ b/demo/registration/src/main/java/org/signal/registration/sample/debug/NetworkDebugOverlay.kt @@ -69,7 +69,7 @@ fun NetworkDebugOverlay( onClick = { showDialog = true }, dragOffset = dragOffset, onDrag = { delta -> dragOffset += delta }, - modifier = Modifier.align(Alignment.CenterEnd) + modifier = Modifier.align(Alignment.TopEnd) ) if (showDialog) { diff --git a/demo/registration/src/main/java/org/signal/registration/sample/screens/main/MainScreen.kt b/demo/registration/src/main/java/org/signal/registration/sample/screens/main/MainScreen.kt index cec539d041..de32f414d7 100644 --- a/demo/registration/src/main/java/org/signal/registration/sample/screens/main/MainScreen.kt +++ b/demo/registration/src/main/java/org/signal/registration/sample/screens/main/MainScreen.kt @@ -14,7 +14,9 @@ import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.height import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.rememberScrollState import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.foundation.verticalScroll import androidx.compose.material3.AlertDialog import androidx.compose.material3.Button import androidx.compose.material3.ButtonDefaults @@ -74,6 +76,7 @@ fun MainScreen( Column( modifier = modifier .fillMaxSize() + .verticalScroll(rememberScrollState()) .padding(24.dp), horizontalAlignment = Alignment.CenterHorizontally, verticalArrangement = Arrangement.Center @@ -95,6 +98,30 @@ fun MainScreen( Spacer(modifier = Modifier.height(32.dp)) } + if (state.pendingFlowState != null) { + PendingFlowStateCard(state.pendingFlowState) + Spacer(modifier = Modifier.height(16.dp)) + + Button( + onClick = { onEvent(MainScreenEvents.LaunchRegistration) }, + modifier = Modifier.fillMaxWidth() + ) { + Text("Resume Registration") + } + + TextButton( + onClick = { showClearDataDialog = true }, + modifier = Modifier.fillMaxWidth(), + colors = ButtonDefaults.textButtonColors( + contentColor = MaterialTheme.colorScheme.error + ) + ) { + Text("Clear Pending Data") + } + + Spacer(modifier = Modifier.height(16.dp)) + } + if (state.existingRegistrationState != null) { if (state.registrationExpired) { Row( @@ -150,7 +177,7 @@ fun MainScreen( ) { Text("Clear All Data") } - } else { + } else if (state.pendingFlowState == null) { Button( onClick = { onEvent(MainScreenEvents.LaunchRegistration) }, modifier = Modifier.fillMaxWidth() @@ -213,6 +240,55 @@ private fun RegistrationField(label: String, value: String) { } } +@Composable +private fun PendingFlowStateCard(pending: MainScreenState.PendingFlowState) { + Card( + modifier = Modifier.fillMaxWidth(), + colors = CardDefaults.cardColors( + containerColor = MaterialTheme.colorScheme.tertiaryContainer + ) + ) { + Column( + modifier = Modifier.padding(16.dp) + ) { + Text( + text = "In-Progress Registration", + style = MaterialTheme.typography.titleMedium, + color = MaterialTheme.colorScheme.onTertiaryContainer + ) + + HorizontalDivider(modifier = Modifier.padding(vertical = 8.dp)) + + RegistrationField(label = "Current Screen", value = pending.currentScreen) + RegistrationField(label = "Backstack Depth", value = pending.backstackSize.toString()) + if (pending.e164 != null) { + RegistrationField(label = "Phone Number", value = pending.e164) + } + RegistrationField(label = "Has Session", value = if (pending.hasSession) "Yes" else "No") + RegistrationField(label = "Has AEP", value = if (pending.hasAccountEntropyPool) "Yes" else "No") + } + } +} + +@Preview(showBackground = true) +@Composable +private fun MainScreenWithPendingFlowStatePreview() { + Previews.Preview { + MainScreen( + state = MainScreenState( + pendingFlowState = MainScreenState.PendingFlowState( + e164 = "+15551234567", + backstackSize = 4, + currentScreen = "VerificationCodeEntry", + hasSession = true, + hasAccountEntropyPool = false + ) + ), + onEvent = {} + ) + } +} + @Preview(showBackground = true) @Composable private fun MainScreenPreview() { diff --git a/demo/registration/src/main/java/org/signal/registration/sample/screens/main/MainScreenState.kt b/demo/registration/src/main/java/org/signal/registration/sample/screens/main/MainScreenState.kt index 5668cec604..a7be11bbe7 100644 --- a/demo/registration/src/main/java/org/signal/registration/sample/screens/main/MainScreenState.kt +++ b/demo/registration/src/main/java/org/signal/registration/sample/screens/main/MainScreenState.kt @@ -7,8 +7,17 @@ package org.signal.registration.sample.screens.main data class MainScreenState( val existingRegistrationState: ExistingRegistrationState? = null, - val registrationExpired: Boolean = false + val registrationExpired: Boolean = false, + val pendingFlowState: PendingFlowState? = null ) { + data class PendingFlowState( + val e164: String?, + val backstackSize: Int, + val currentScreen: String, + val hasSession: Boolean, + val hasAccountEntropyPool: Boolean + ) + data class ExistingRegistrationState( val phoneNumber: String, val aci: String, diff --git a/demo/registration/src/main/java/org/signal/registration/sample/screens/main/MainScreenViewModel.kt b/demo/registration/src/main/java/org/signal/registration/sample/screens/main/MainScreenViewModel.kt index 8003bfb198..9f91460af1 100644 --- a/demo/registration/src/main/java/org/signal/registration/sample/screens/main/MainScreenViewModel.kt +++ b/demo/registration/src/main/java/org/signal/registration/sample/screens/main/MainScreenViewModel.kt @@ -12,9 +12,11 @@ import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.asStateFlow import kotlinx.coroutines.launch +import kotlinx.serialization.json.Json import org.signal.core.util.Base64 import org.signal.core.util.logging.Log import org.signal.registration.NetworkController +import org.signal.registration.PersistedFlowState import org.signal.registration.StorageController import org.signal.registration.sample.storage.RegistrationPreferences @@ -75,6 +77,7 @@ class MainScreenViewModel( } else { null }, + pendingFlowState = loadPendingFlowState(), registrationExpired = false ) @@ -84,6 +87,27 @@ class MainScreenViewModel( } } + private suspend fun loadPendingFlowState(): MainScreenState.PendingFlowState? { + return try { + val data = storageController.readInProgressRegistrationData() + if (data.flowStateJson.isEmpty()) return null + + val json = Json { ignoreUnknownKeys = true } + val persisted = json.decodeFromString(PersistedFlowState.serializer(), data.flowStateJson) + + MainScreenState.PendingFlowState( + e164 = persisted.sessionE164, + backstackSize = persisted.backStack.size, + currentScreen = persisted.backStack.lastOrNull()?.let { it::class.simpleName } ?: "Unknown", + hasSession = persisted.sessionMetadata != null, + hasAccountEntropyPool = data.accountEntropyPool.isNotEmpty() + ) + } catch (e: Exception) { + Log.w(TAG, "Failed to load pending flow state", e) + null + } + } + private suspend fun checkRegistrationStatus() { when (val result = networkController.getSvrCredentials()) { is NetworkController.RegistrationNetworkResult.Success -> { diff --git a/feature/registration/src/main/java/org/signal/registration/PersistedFlowState.kt b/feature/registration/src/main/java/org/signal/registration/PersistedFlowState.kt new file mode 100644 index 0000000000..1753fd93b2 --- /dev/null +++ b/feature/registration/src/main/java/org/signal/registration/PersistedFlowState.kt @@ -0,0 +1,60 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.signal.registration + +import kotlinx.serialization.Serializable +import org.signal.core.models.AccountEntropyPool +import org.signal.core.models.MasterKey + +/** + * A serializable snapshot of [RegistrationFlowState] fields that need to survive app kills. + * + * Fields like [RegistrationFlowState.accountEntropyPool] and [RegistrationFlowState.temporaryMasterKey] + * are reconstructed from dedicated proto fields, not from this JSON snapshot. + * [RegistrationFlowState.preExistingRegistrationData] is loaded from permanent storage. + */ +@Serializable +data class PersistedFlowState( + val backStack: List, + val sessionMetadata: NetworkController.SessionMetadata?, + val sessionE164: String?, + val doNotAttemptRecoveryPassword: Boolean +) + +/** + * Extracts the persistable fields from a [RegistrationFlowState]. + */ +fun RegistrationFlowState.toPersistedFlowState(): PersistedFlowState { + return PersistedFlowState( + backStack = backStack, + sessionMetadata = sessionMetadata, + sessionE164 = sessionE164, + doNotAttemptRecoveryPassword = doNotAttemptRecoveryPassword + ) +} + +/** + * Reconstructs a full [RegistrationFlowState] from persisted data and separately-stored fields. + * + * @param accountEntropyPool Restored from the proto's dedicated `accountEntropyPool` field. + * @param temporaryMasterKey Restored from the proto's dedicated `temporaryMasterKey` field. + * @param preExistingRegistrationData Loaded from permanent storage via [StorageController.getPreExistingRegistrationData]. + */ +fun PersistedFlowState.toRegistrationFlowState( + accountEntropyPool: AccountEntropyPool?, + temporaryMasterKey: MasterKey?, + preExistingRegistrationData: PreExistingRegistrationData? +): RegistrationFlowState { + return RegistrationFlowState( + backStack = backStack, + sessionMetadata = sessionMetadata, + sessionE164 = sessionE164, + accountEntropyPool = accountEntropyPool, + temporaryMasterKey = temporaryMasterKey, + preExistingRegistrationData = preExistingRegistrationData, + doNotAttemptRecoveryPassword = doNotAttemptRecoveryPassword + ) +} diff --git a/feature/registration/src/main/java/org/signal/registration/RegistrationFlowState.kt b/feature/registration/src/main/java/org/signal/registration/RegistrationFlowState.kt index 4f9f6ddec7..2186c4675e 100644 --- a/feature/registration/src/main/java/org/signal/registration/RegistrationFlowState.kt +++ b/feature/registration/src/main/java/org/signal/registration/RegistrationFlowState.kt @@ -37,5 +37,8 @@ data class RegistrationFlowState( val preExistingRegistrationData: PreExistingRegistrationData? = null, /** If true, do not attempt any flows where we generate RRP's. Create a session instead. */ - val doNotAttemptRecoveryPassword: Boolean = false + val doNotAttemptRecoveryPassword: Boolean = false, + + /** If true, the ViewModel is still deciding whether to restore a previous flow or start fresh. */ + val isRestoringNavigationState: Boolean = true ) : Parcelable, DebugLoggable diff --git a/feature/registration/src/main/java/org/signal/registration/RegistrationNavigation.kt b/feature/registration/src/main/java/org/signal/registration/RegistrationNavigation.kt index 07e850f529..c0b0d9a256 100644 --- a/feature/registration/src/main/java/org/signal/registration/RegistrationNavigation.kt +++ b/feature/registration/src/main/java/org/signal/registration/RegistrationNavigation.kt @@ -8,9 +8,13 @@ package org.signal.registration import android.os.Parcelable +import androidx.compose.foundation.layout.Box +import androidx.compose.foundation.layout.fillMaxSize +import androidx.compose.material3.CircularProgressIndicator import androidx.compose.runtime.Composable import androidx.compose.runtime.LaunchedEffect import androidx.compose.runtime.getValue +import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier import androidx.lifecycle.compose.collectAsStateWithLifecycle import androidx.lifecycle.viewmodel.compose.viewModel @@ -62,6 +66,7 @@ import org.signal.registration.screens.welcome.WelcomeScreenEvents * Navigation routes for the registration flow. * Using @Serializable and NavKey for type-safe navigation with Navigation 3. */ +@Serializable @Parcelize sealed interface RegistrationRoute : NavKey, Parcelable { @Serializable @@ -77,7 +82,7 @@ sealed interface RegistrationRoute : NavKey, Parcelable { data object CountryCodePicker : RegistrationRoute @Serializable - data class VerificationCodeEntry(val session: NetworkController.SessionMetadata, val e164: String) : RegistrationRoute + data object VerificationCodeEntry : RegistrationRoute @Serializable data class Captcha(val session: NetworkController.SessionMetadata) : RegistrationRoute @@ -150,6 +155,13 @@ fun RegistrationNavHost( val registrationState by viewModel.state.collectAsStateWithLifecycle() val permissions: MultiplePermissionsState = permissionsState ?: rememberMultiplePermissionsState(viewModel.getRequiredPermissions()) + if (registrationState.isRestoringNavigationState) { + Box(modifier = modifier.fillMaxSize(), contentAlignment = Alignment.Center) { + CircularProgressIndicator() + } + return + } + val entryProvider = entryProvider { navigationEntries( registrationRepository = registrationRepository, diff --git a/feature/registration/src/main/java/org/signal/registration/RegistrationRepository.kt b/feature/registration/src/main/java/org/signal/registration/RegistrationRepository.kt index 6057b1f4aa..6372625f35 100644 --- a/feature/registration/src/main/java/org/signal/registration/RegistrationRepository.kt +++ b/feature/registration/src/main/java/org/signal/registration/RegistrationRepository.kt @@ -10,6 +10,7 @@ import android.content.Context import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.withContext +import kotlinx.serialization.json.Json import okio.ByteString.Companion.toByteString import org.signal.core.models.AccountEntropyPool import org.signal.core.models.MasterKey @@ -410,6 +411,81 @@ class RegistrationRepository(val context: Context, val networkController: Networ return storageController.getPreExistingRegistrationData() } + /** + * Persists the current flow state as JSON in the in-progress registration data proto. + */ + suspend fun saveFlowState(state: RegistrationFlowState) = withContext(Dispatchers.IO) { + try { + val json = flowStateJson.encodeToString(PersistedFlowState.serializer(), state.toPersistedFlowState()) + storageController.updateInProgressRegistrationData { + flowStateJson = json + } + } catch (e: Exception) { + Log.w(TAG, "Failed to save flow state", e) + } + } + + /** + * Restores the flow state from disk. Returns null if no state is saved or deserialization fails. + * Reconstructs [RegistrationFlowState.accountEntropyPool] and [RegistrationFlowState.temporaryMasterKey] + * from their dedicated proto fields, and loads [RegistrationFlowState.preExistingRegistrationData] + * from permanent storage. + */ + suspend fun restoreFlowState(): RegistrationFlowState? = withContext(Dispatchers.IO) { + try { + val data = storageController.readInProgressRegistrationData() + if (data.flowStateJson.isEmpty()) return@withContext null + + val persisted = flowStateJson.decodeFromString(PersistedFlowState.serializer(), data.flowStateJson) + + val aep = data.accountEntropyPool.takeIf { it.isNotEmpty() }?.let { AccountEntropyPool(it) } + val masterKey = data.temporaryMasterKey.takeIf { it.size > 0 }?.let { MasterKey(it.toByteArray()) } + val preExisting = storageController.getPreExistingRegistrationData() + + persisted.toRegistrationFlowState( + accountEntropyPool = aep, + temporaryMasterKey = masterKey, + preExistingRegistrationData = preExisting + ) + } catch (e: Exception) { + Log.w(TAG, "Failed to restore flow state", e) + null + } + } + + /** + * Clears any persisted flow state JSON from the in-progress registration data. + */ + suspend fun clearFlowState() = withContext(Dispatchers.IO) { + try { + storageController.updateInProgressRegistrationData { + flowStateJson = "" + } + } catch (e: Exception) { + Log.w(TAG, "Failed to clear flow state", e) + } + } + + /** + * Validates a registration session by fetching its current status from the server. + * Returns fresh [SessionMetadata] on success, or null if the session is expired/invalid. + */ + suspend fun validateSession(sessionId: String): SessionMetadata? = withContext(Dispatchers.IO) { + when (val result = networkController.getSession(sessionId)) { + is RegistrationNetworkResult.Success -> result.data + else -> null + } + } + + /** + * Checks whether the in-progress registration data indicates a completed registration + * (i.e. both ACI and PNI have been saved). + */ + suspend fun isRegistered(): Boolean = withContext(Dispatchers.IO) { + val data = storageController.readInProgressRegistrationData() + data.aci.isNotEmpty() && data.pni.isNotEmpty() + } + private fun generateKeyMaterial( existingAccountEntropyPool: AccountEntropyPool? = null, existingAciIdentityKeyPair: IdentityKeyPair? = null, @@ -488,5 +564,6 @@ class RegistrationRepository(val context: Context, val networkController: Networ companion object { private val TAG = Log.tag(RegistrationRepository::class) + private val flowStateJson = Json { ignoreUnknownKeys = true } } } diff --git a/feature/registration/src/main/java/org/signal/registration/RegistrationViewModel.kt b/feature/registration/src/main/java/org/signal/registration/RegistrationViewModel.kt index 5d3498bc3a..6b284ba29e 100644 --- a/feature/registration/src/main/java/org/signal/registration/RegistrationViewModel.kt +++ b/feature/registration/src/main/java/org/signal/registration/RegistrationViewModel.kt @@ -13,6 +13,7 @@ import androidx.lifecycle.ViewModelProvider import androidx.lifecycle.createSavedStateHandle import androidx.lifecycle.viewModelScope import androidx.lifecycle.viewmodel.CreationExtras +import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.asStateFlow @@ -37,9 +38,17 @@ class RegistrationViewModel(private val repository: RegistrationRepository, save val resultBus = ResultEventBus() init { + _state.value = _state.value.copy(isRestoringNavigationState = true) viewModelScope.launch { - repository.getPreExistingRegistrationData()?.let { - _state.value = _state.value.copy(preExistingRegistrationData = it) + val restored = repository.restoreFlowState() + if (restored != null) { + Log.i(TAG, "[init] Restored flow state from disk. Backstack size: ${restored.backStack.size}, hasSession: ${restored.sessionMetadata != null}") + _state.value = validateRestoredState(restored).copy(isRestoringNavigationState = false) + } else { + _state.value = _state.value.copy( + preExistingRegistrationData = repository.getPreExistingRegistrationData(), + isRestoringNavigationState = false + ) } } } @@ -47,11 +56,15 @@ class RegistrationViewModel(private val repository: RegistrationRepository, save fun onEvent(event: RegistrationFlowEvent) { Log.d(TAG, "[Event] $event") _state.value = applyEvent(_state.value, event) + + viewModelScope.launch(Dispatchers.IO) { + persistFlowState(event) + } } fun applyEvent(state: RegistrationFlowState, event: RegistrationFlowEvent): RegistrationFlowState { return when (event) { - is RegistrationFlowEvent.ResetState -> RegistrationFlowState() + is RegistrationFlowEvent.ResetState -> RegistrationFlowState(isRestoringNavigationState = false) is RegistrationFlowEvent.SessionUpdated -> state.copy(sessionMetadata = event.session) is RegistrationFlowEvent.E164Chosen -> state.copy(sessionE164 = event.e164) is RegistrationFlowEvent.Registered -> state.copy(accountEntropyPool = event.accountEntropyPool) @@ -63,14 +76,43 @@ class RegistrationViewModel(private val repository: RegistrationRepository, save } private fun applyNavigationToScreenEvent(inputState: RegistrationFlowState, event: RegistrationFlowEvent.NavigateToScreen): RegistrationFlowState { - val state = inputState.copy(backStack = inputState.backStack + event.route) + return inputState.copy(backStack = inputState.backStack + event.route) + } - return when (event.route) { - is RegistrationRoute.VerificationCodeEntry -> { - state.copy(sessionMetadata = event.route.session, sessionE164 = event.route.e164) - } - else -> state + /** + * Validates a restored flow state by checking if the session is still valid. + * + * - If the session is still valid, updates session metadata with fresh data. + * - If the session is expired and the user is already registered, nulls out the session + * (post-registration screens like PinCreate don't need a session). + * - If the session is expired and the user is NOT registered, resets the backstack to + * PhoneNumberEntry with the phone number pre-filled so the user can re-submit. + */ + private suspend fun validateRestoredState(state: RegistrationFlowState): RegistrationFlowState { + val sessionMetadata = state.sessionMetadata ?: return state + + val freshSession = repository.validateSession(sessionMetadata.id) + if (freshSession != null) { + Log.i(TAG, "[validateRestoredState] Session still valid.") + return state.copy(sessionMetadata = freshSession) } + + Log.i(TAG, "[validateRestoredState] Session expired/invalid.") + + if (repository.isRegistered()) { + Log.i(TAG, "[validateRestoredState] User is registered, proceeding without session.") + return state.copy(sessionMetadata = null) + } + + Log.i(TAG, "[validateRestoredState] User is NOT registered, resetting to PhoneNumberEntry.") + return state.copy( + backStack = listOf( + RegistrationRoute.Welcome, + RegistrationRoute.Permissions(nextRoute = RegistrationRoute.PhoneNumberEntry), + RegistrationRoute.PhoneNumberEntry + ), + sessionMetadata = null + ) } /** @@ -101,6 +143,27 @@ class RegistrationViewModel(private val repository: RegistrationRepository, save } } + private suspend fun persistFlowState(event: RegistrationFlowEvent) { + when (event) { + is RegistrationFlowEvent.ResetState -> repository.clearFlowState() + is RegistrationFlowEvent.NavigateToScreen -> { + if (event.route is RegistrationRoute.FullyComplete) { + repository.clearFlowState() + } else { + repository.saveFlowState(_state.value) + } + } + is RegistrationFlowEvent.NavigateBack, + is RegistrationFlowEvent.SessionUpdated, + is RegistrationFlowEvent.E164Chosen, + is RegistrationFlowEvent.RecoveryPasswordInvalid -> repository.saveFlowState(_state.value) + + // No need to persist anything new, fields accounted for in proto already + is RegistrationFlowEvent.Registered, + is RegistrationFlowEvent.MasterKeyRestoredFromSvr -> { } + } + } + class Factory(private val repository: RegistrationRepository) : ViewModelProvider.Factory { override fun create(modelClass: KClass, extras: CreationExtras): T { return RegistrationViewModel(repository, extras.createSavedStateHandle()) as T diff --git a/feature/registration/src/main/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryViewModel.kt b/feature/registration/src/main/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryViewModel.kt index a54ba0435a..6cdd90944e 100644 --- a/feature/registration/src/main/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryViewModel.kt +++ b/feature/registration/src/main/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryViewModel.kt @@ -276,14 +276,17 @@ class PhoneNumberEntryViewModel( is NetworkController.RegistrationNetworkResult.Failure -> { return when (response.error) { is NetworkController.CreateSessionError.InvalidRequest -> { + Log.w(TAG, "[CreateSession] Invalid request when creating session. Message: ${response.error.message}") state.copy(oneTimeEvent = OneTimeEvent.UnknownError) } is NetworkController.CreateSessionError.RateLimited -> { + Log.w(TAG, "[CreateSession] Rate limited (retryAfter: ${response.error.retryAfter}).") state.copy(oneTimeEvent = OneTimeEvent.RateLimited(response.error.retryAfter)) } } } is NetworkController.RegistrationNetworkResult.NetworkError -> { + Log.w(TAG, "[CreateSession] Network error.", response.exception) return state.copy(oneTimeEvent = OneTimeEvent.NetworkError) } is NetworkController.RegistrationNetworkResult.ApplicationError -> { @@ -304,23 +307,26 @@ class PhoneNumberEntryViewModel( Log.d(TAG, "Received push challenge token, submitting...") val updateResult = repository.submitPushChallengeToken(sessionMetadata.id, pushChallengeToken) sessionMetadata = when (updateResult) { - is NetworkController.RegistrationNetworkResult.Success -> updateResult.data + is NetworkController.RegistrationNetworkResult.Success -> { + Log.d(TAG, "[SubmitPushChallengeToken] Successfully submitted push challenge token.") + updateResult.data + } is NetworkController.RegistrationNetworkResult.Failure -> { - Log.w(TAG, "Failed to submit push challenge token: ${updateResult.error}") + Log.w(TAG, "[SubmitPushChallengeToken] Failed to submit push challenge token: ${updateResult.error}") sessionMetadata } is NetworkController.RegistrationNetworkResult.NetworkError -> { - Log.w(TAG, "Network error submitting push challenge token", updateResult.exception) + Log.w(TAG, "[SubmitPushChallengeToken] Network error submitting push challenge token", updateResult.exception) sessionMetadata } is NetworkController.RegistrationNetworkResult.ApplicationError -> { - Log.w(TAG, "Application error submitting push challenge token", updateResult.exception) + Log.w(TAG, "[SubmitPushChallengeToken] Application error submitting push challenge token", updateResult.exception) sessionMetadata } } state = state.copy(sessionMetadata = sessionMetadata) } else { - Log.d(TAG, "Push challenge token not received within timeout") + Log.d(TAG, "[SubmitPushChallengeToken] Push challenge token not received within timeout") } } @@ -337,41 +343,49 @@ class PhoneNumberEntryViewModel( sessionMetadata = when (verificationCodeResponse) { is NetworkController.RegistrationNetworkResult.Success -> { + Log.d(TAG, "[RequestVerificationCode] Successfully requested verification code.") verificationCodeResponse.data } is NetworkController.RegistrationNetworkResult.Failure -> { return when (verificationCodeResponse.error) { is NetworkController.RequestVerificationCodeError.InvalidRequest -> { + Log.w(TAG, "[RequestVerificationCode] Invalid request when requesting verification code. Message: ${verificationCodeResponse.error.message}") state.copy(oneTimeEvent = OneTimeEvent.UnknownError) } is NetworkController.RequestVerificationCodeError.RateLimited -> { + Log.w(TAG, "[RequestVerificationCode] Rate limited (retryAfter: ${verificationCodeResponse.error.retryAfter}).") state.copy(oneTimeEvent = OneTimeEvent.RateLimited(verificationCodeResponse.error.retryAfter)) } is NetworkController.RequestVerificationCodeError.CouldNotFulfillWithRequestedTransport -> { + Log.w(TAG, "[RequestVerificationCode] Could not fulfill with requested transport.") state.copy(oneTimeEvent = OneTimeEvent.CouldNotRequestCodeWithSelectedTransport) } is NetworkController.RequestVerificationCodeError.InvalidSessionId -> { + Log.w(TAG, "[RequestVerificationCode] Invalid session ID when requesting verification code.") parentEventEmitter(RegistrationFlowEvent.ResetState) state } is NetworkController.RequestVerificationCodeError.MissingRequestInformationOrAlreadyVerified -> { - Log.w(TAG, "When requesting verification code, missing request information or already verified.") + Log.w(TAG, "[RequestVerificationCode] Missing request information or already verified.") state.copy(oneTimeEvent = OneTimeEvent.NetworkError) } is NetworkController.RequestVerificationCodeError.SessionNotFound -> { + Log.w(TAG, "[RequestVerificationCode] Session not found when requesting verification code.") parentEventEmitter(RegistrationFlowEvent.ResetState) state } is NetworkController.RequestVerificationCodeError.ThirdPartyServiceError -> { + Log.w(TAG, "[RequestVerificationCode] Third party service error.") state.copy(oneTimeEvent = OneTimeEvent.ThirdPartyError) } } } is NetworkController.RegistrationNetworkResult.NetworkError -> { + Log.w(TAG, "[RequestVerificationCode] Network error.", verificationCodeResponse.exception) return state.copy(oneTimeEvent = OneTimeEvent.NetworkError) } is NetworkController.RegistrationNetworkResult.ApplicationError -> { - Log.w(TAG, "Unknown error when creating session.", verificationCodeResponse.exception) + Log.w(TAG, "[RequestVerificationCode] Unknown error when creating session.", verificationCodeResponse.exception) return state.copy(oneTimeEvent = OneTimeEvent.UnknownError) } } @@ -383,7 +397,9 @@ class PhoneNumberEntryViewModel( return state } - parentEventEmitter.navigateTo(RegistrationRoute.VerificationCodeEntry(sessionMetadata, e164)) + parentEventEmitter(RegistrationFlowEvent.SessionUpdated(sessionMetadata)) + parentEventEmitter(RegistrationFlowEvent.E164Chosen(e164)) + parentEventEmitter.navigateTo(RegistrationRoute.VerificationCodeEntry) return state } @@ -470,9 +486,9 @@ class PhoneNumberEntryViewModel( } } - val e164 = "+${inputState.countryCode}${inputState.nationalNumber}" - - parentEventEmitter.navigateTo(RegistrationRoute.VerificationCodeEntry(sessionMetadata, e164)) + parentEventEmitter(RegistrationFlowEvent.SessionUpdated(sessionMetadata)) + parentEventEmitter(RegistrationFlowEvent.E164Chosen("+${inputState.countryCode}${inputState.nationalNumber}")) + parentEventEmitter.navigateTo(RegistrationRoute.VerificationCodeEntry) return state } diff --git a/feature/registration/src/main/protowire/Registration.proto b/feature/registration/src/main/protowire/Registration.proto index 7bc96390da..a040327cb3 100644 --- a/feature/registration/src/main/protowire/Registration.proto +++ b/feature/registration/src/main/protowire/Registration.proto @@ -39,6 +39,9 @@ message RegistrationData { // Provisioning data (from saveProvisioningData) ProvisioningData provisioningData = 20; + + // JSON-serialized flow state snapshot (from saveFlowState/restoreFlowState) + string flowStateJson = 21; } message SvrCredential { diff --git a/feature/registration/src/test/java/org/signal/registration/PersistedFlowStateTest.kt b/feature/registration/src/test/java/org/signal/registration/PersistedFlowStateTest.kt new file mode 100644 index 0000000000..c9ab0270f7 --- /dev/null +++ b/feature/registration/src/test/java/org/signal/registration/PersistedFlowStateTest.kt @@ -0,0 +1,225 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.signal.registration + +import assertk.assertThat +import assertk.assertions.isEqualTo +import assertk.assertions.isNull +import kotlinx.serialization.json.Json +import org.junit.Test +import org.signal.core.models.AccountEntropyPool +import org.signal.core.models.MasterKey + +class PersistedFlowStateTest { + + private val json = Json { ignoreUnknownKeys = true } + + @Test + fun `round-trip serialization with simple backstack`() { + val state = PersistedFlowState( + backStack = listOf(RegistrationRoute.Welcome, RegistrationRoute.PhoneNumberEntry), + sessionMetadata = null, + sessionE164 = null, + doNotAttemptRecoveryPassword = false + ) + + val encoded = json.encodeToString(PersistedFlowState.serializer(), state) + val decoded = json.decodeFromString(PersistedFlowState.serializer(), encoded) + + assertThat(decoded).isEqualTo(state) + } + + @Test + fun `round-trip serialization with nested Permissions route`() { + val state = PersistedFlowState( + backStack = listOf( + RegistrationRoute.Welcome, + RegistrationRoute.Permissions(nextRoute = RegistrationRoute.PhoneNumberEntry), + RegistrationRoute.PhoneNumberEntry + ), + sessionMetadata = null, + sessionE164 = "+15551234567", + doNotAttemptRecoveryPassword = false + ) + + val encoded = json.encodeToString(PersistedFlowState.serializer(), state) + val decoded = json.decodeFromString(PersistedFlowState.serializer(), encoded) + + assertThat(decoded).isEqualTo(state) + } + + @Test + fun `round-trip serialization with VerificationCodeEntry`() { + val session = NetworkController.SessionMetadata( + id = "session-123", + nextSms = 1000L, + nextCall = 2000L, + nextVerificationAttempt = 3000L, + allowedToRequestCode = true, + requestedInformation = listOf("pushChallenge"), + verified = false + ) + + val state = PersistedFlowState( + backStack = listOf( + RegistrationRoute.Welcome, + RegistrationRoute.Permissions(nextRoute = RegistrationRoute.PhoneNumberEntry), + RegistrationRoute.PhoneNumberEntry, + RegistrationRoute.VerificationCodeEntry + ), + sessionMetadata = session, + sessionE164 = "+15551234567", + doNotAttemptRecoveryPassword = false + ) + + val encoded = json.encodeToString(PersistedFlowState.serializer(), state) + val decoded = json.decodeFromString(PersistedFlowState.serializer(), encoded) + + assertThat(decoded).isEqualTo(state) + } + + @Test + fun `round-trip serialization with post-registration routes`() { + val state = PersistedFlowState( + backStack = listOf( + RegistrationRoute.Welcome, + RegistrationRoute.PinCreate, + RegistrationRoute.ArchiveRestoreSelection + ), + sessionMetadata = null, + sessionE164 = "+15551234567", + doNotAttemptRecoveryPassword = true + ) + + val encoded = json.encodeToString(PersistedFlowState.serializer(), state) + val decoded = json.decodeFromString(PersistedFlowState.serializer(), encoded) + + assertThat(decoded).isEqualTo(state) + } + + @Test + fun `round-trip serialization with PinEntryForRegistrationLock`() { + val creds = NetworkController.SvrCredentials(username = "user", password = "pass") + val state = PersistedFlowState( + backStack = listOf( + RegistrationRoute.Welcome, + RegistrationRoute.PinEntryForRegistrationLock(timeRemaining = 86400000L, svrCredentials = creds) + ), + sessionMetadata = null, + sessionE164 = "+15551234567", + doNotAttemptRecoveryPassword = false + ) + + val encoded = json.encodeToString(PersistedFlowState.serializer(), state) + val decoded = json.decodeFromString(PersistedFlowState.serializer(), encoded) + + assertThat(decoded).isEqualTo(state) + } + + @Test + fun `round-trip serialization with Captcha route`() { + val session = NetworkController.SessionMetadata( + id = "session-456", + nextSms = null, + nextCall = null, + nextVerificationAttempt = null, + allowedToRequestCode = false, + requestedInformation = listOf("captcha"), + verified = false + ) + + val state = PersistedFlowState( + backStack = listOf( + RegistrationRoute.Welcome, + RegistrationRoute.PhoneNumberEntry, + RegistrationRoute.Captcha(session = session) + ), + sessionMetadata = session, + sessionE164 = "+15551234567", + doNotAttemptRecoveryPassword = false + ) + + val encoded = json.encodeToString(PersistedFlowState.serializer(), state) + val decoded = json.decodeFromString(PersistedFlowState.serializer(), encoded) + + assertThat(decoded).isEqualTo(state) + } + + @Test + fun `deserialization ignores unknown keys for forward compatibility`() { + val validJson = """{"backStack":[{"type":"org.signal.registration.RegistrationRoute.Welcome"}],"sessionMetadata":null,"sessionE164":null,"doNotAttemptRecoveryPassword":false,"unknownField":"value"}""" + val decoded = json.decodeFromString(PersistedFlowState.serializer(), validJson) + + assertThat(decoded.backStack).isEqualTo(listOf(RegistrationRoute.Welcome)) + assertThat(decoded.sessionMetadata).isNull() + } + + @Test + fun `toPersistedFlowState captures correct fields`() { + val session = NetworkController.SessionMetadata( + id = "session-789", + nextSms = null, + nextCall = null, + nextVerificationAttempt = null, + allowedToRequestCode = true, + requestedInformation = emptyList(), + verified = true + ) + + val flowState = RegistrationFlowState( + backStack = listOf(RegistrationRoute.Welcome, RegistrationRoute.PinCreate), + sessionMetadata = session, + sessionE164 = "+15551234567", + accountEntropyPool = AccountEntropyPool.generate(), + temporaryMasterKey = MasterKey(ByteArray(32)), + doNotAttemptRecoveryPassword = true + ) + + val persisted = flowState.toPersistedFlowState() + + assertThat(persisted.backStack).isEqualTo(flowState.backStack) + assertThat(persisted.sessionMetadata).isEqualTo(session) + assertThat(persisted.sessionE164).isEqualTo("+15551234567") + assertThat(persisted.doNotAttemptRecoveryPassword).isEqualTo(true) + } + + @Test + fun `toRegistrationFlowState reconstructs all fields`() { + val session = NetworkController.SessionMetadata( + id = "session-101", + nextSms = null, + nextCall = null, + nextVerificationAttempt = null, + allowedToRequestCode = true, + requestedInformation = emptyList(), + verified = true + ) + + val persisted = PersistedFlowState( + backStack = listOf(RegistrationRoute.Welcome, RegistrationRoute.PinCreate), + sessionMetadata = session, + sessionE164 = "+15551234567", + doNotAttemptRecoveryPassword = true + ) + + val aep = AccountEntropyPool.generate() + val masterKey = MasterKey(ByteArray(32)) + + val flowState = persisted.toRegistrationFlowState( + accountEntropyPool = aep, + temporaryMasterKey = masterKey, + preExistingRegistrationData = null + ) + + assertThat(flowState.backStack).isEqualTo(persisted.backStack) + assertThat(flowState.sessionMetadata).isEqualTo(session) + assertThat(flowState.sessionE164).isEqualTo("+15551234567") + assertThat(flowState.accountEntropyPool).isEqualTo(aep) + assertThat(flowState.temporaryMasterKey).isEqualTo(masterKey) + assertThat(flowState.preExistingRegistrationData).isNull() + assertThat(flowState.doNotAttemptRecoveryPassword).isEqualTo(true) + } +} diff --git a/feature/registration/src/test/java/org/signal/registration/RegistrationNavigationTest.kt b/feature/registration/src/test/java/org/signal/registration/RegistrationNavigationTest.kt index 098601ce8d..db0ab90521 100644 --- a/feature/registration/src/test/java/org/signal/registration/RegistrationNavigationTest.kt +++ b/feature/registration/src/test/java/org/signal/registration/RegistrationNavigationTest.kt @@ -6,6 +6,7 @@ package org.signal.registration import android.app.Application +import android.os.Looper import androidx.compose.ui.test.assertIsDisplayed import androidx.compose.ui.test.junit4.createComposeRule import androidx.compose.ui.test.onNodeWithTag @@ -13,12 +14,14 @@ import androidx.compose.ui.test.performClick import androidx.lifecycle.SavedStateHandle import androidx.test.core.app.ApplicationProvider import com.google.accompanist.permissions.ExperimentalPermissionsApi +import io.mockk.coEvery import io.mockk.mockk import org.junit.Before import org.junit.Rule import org.junit.Test import org.junit.runner.RunWith import org.robolectric.RobolectricTestRunner +import org.robolectric.Shadows import org.robolectric.annotation.Config import org.signal.core.ui.CoreUiDependenciesRule import org.signal.core.ui.compose.theme.SignalTheme @@ -47,7 +50,11 @@ class RegistrationNavigationTest { @Before fun setup() { mockRepository = mockk(relaxed = true) + coEvery { mockRepository.restoreFlowState() } returns null + coEvery { mockRepository.getPreExistingRegistrationData() } returns null viewModel = RegistrationViewModel(mockRepository, SavedStateHandle()) + // Allow the init coroutine to complete so isRestoring becomes false. + Shadows.shadowOf(Looper.getMainLooper()).idle() } @Test @@ -55,6 +62,11 @@ class RegistrationNavigationTest { // Given val permissionsState = createMockPermissionsState() + // Verify the ViewModel state is correctly initialized + val state = viewModel.state.value + assert(!state.isRestoringNavigationState) { "isRestoring should be false after init, was: ${state.isRestoringNavigationState}" } + assert(state.backStack == listOf(RegistrationRoute.Welcome)) { "backStack should be [Welcome], was: ${state.backStack}" } + composeTestRule.setContent { SignalTheme(incognitoKeyboardEnabled = false) { RegistrationNavHost( @@ -86,6 +98,7 @@ class RegistrationNavigationTest { // When composeTestRule.onNodeWithTag(TestTags.WELCOME_GET_STARTED_BUTTON).performClick() + Shadows.shadowOf(Looper.getMainLooper()).idle() // Then - verify Permissions screen is displayed composeTestRule.onNodeWithTag(TestTags.PERMISSIONS_SCREEN).assertIsDisplayed() @@ -108,9 +121,11 @@ class RegistrationNavigationTest { // Navigate to Permissions screen first composeTestRule.onNodeWithTag(TestTags.WELCOME_GET_STARTED_BUTTON).performClick() + Shadows.shadowOf(Looper.getMainLooper()).idle() // When composeTestRule.onNodeWithTag(TestTags.PERMISSIONS_NEXT_BUTTON).performClick() + Shadows.shadowOf(Looper.getMainLooper()).idle() // Then - verify PhoneNumber screen is displayed composeTestRule.onNodeWithTag(TestTags.PHONE_NUMBER_SCREEN).assertIsDisplayed() @@ -133,9 +148,11 @@ class RegistrationNavigationTest { // Navigate to Permissions screen first composeTestRule.onNodeWithTag(TestTags.WELCOME_GET_STARTED_BUTTON).performClick() + Shadows.shadowOf(Looper.getMainLooper()).idle() // When composeTestRule.onNodeWithTag(TestTags.PERMISSIONS_NOT_NOW_BUTTON).performClick() + Shadows.shadowOf(Looper.getMainLooper()).idle() // Then - verify PhoneNumber screen is displayed composeTestRule.onNodeWithTag(TestTags.PHONE_NUMBER_SCREEN).assertIsDisplayed() @@ -164,6 +181,7 @@ class RegistrationNavigationTest { // When composeTestRule.onNodeWithTag(TestTags.WELCOME_RESTORE_OR_TRANSFER_BUTTON).performClick() composeTestRule.onNodeWithTag(TestTags.WELCOME_RESTORE_HAS_OLD_PHONE_BUTTON).performClick() + Shadows.shadowOf(Looper.getMainLooper()).idle() // Then - verify Permissions screen is displayed // (After permissions, user would go to RestoreViaQr screen) @@ -188,6 +206,7 @@ class RegistrationNavigationTest { // When composeTestRule.onNodeWithTag(TestTags.WELCOME_RESTORE_OR_TRANSFER_BUTTON).performClick() composeTestRule.onNodeWithTag(TestTags.WELCOME_RESTORE_NO_OLD_PHONE_BUTTON).performClick() + Shadows.shadowOf(Looper.getMainLooper()).idle() // Then - verify Restore screen is displayed (or its expected content) // Note: Update this assertion based on actual Restore screen content when implemented diff --git a/feature/registration/src/test/java/org/signal/registration/RegistrationViewModelRestoreTest.kt b/feature/registration/src/test/java/org/signal/registration/RegistrationViewModelRestoreTest.kt new file mode 100644 index 0000000000..152b2a9826 --- /dev/null +++ b/feature/registration/src/test/java/org/signal/registration/RegistrationViewModelRestoreTest.kt @@ -0,0 +1,251 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.signal.registration + +import androidx.lifecycle.SavedStateHandle +import assertk.assertThat +import assertk.assertions.isEqualTo +import assertk.assertions.isNull +import io.mockk.coEvery +import io.mockk.coVerify +import io.mockk.mockk +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.test.StandardTestDispatcher +import kotlinx.coroutines.test.advanceUntilIdle +import kotlinx.coroutines.test.resetMain +import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.test.setMain +import org.junit.After +import org.junit.Before +import org.junit.Test + +@OptIn(ExperimentalCoroutinesApi::class) +class RegistrationViewModelRestoreTest { + + private val testDispatcher = StandardTestDispatcher() + private lateinit var mockRepository: RegistrationRepository + + @Before + fun setup() { + Dispatchers.setMain(testDispatcher) + mockRepository = mockk(relaxed = true) + } + + @After + fun tearDown() { + Dispatchers.resetMain() + } + + @Test + fun `no saved state starts fresh and loads preExistingRegistrationData`() = runTest(testDispatcher) { + coEvery { mockRepository.restoreFlowState() } returns null + coEvery { mockRepository.getPreExistingRegistrationData() } returns null + + val viewModel = RegistrationViewModel(mockRepository, SavedStateHandle()) + advanceUntilIdle() + + val state = viewModel.state.value + assertThat(state.backStack).isEqualTo(listOf(RegistrationRoute.Welcome)) + assertThat(state.sessionMetadata).isNull() + } + + @Test + fun `restore with valid session proceeds normally with updated session metadata`() = runTest(testDispatcher) { + val savedSession = createSessionMetadata("session-1") + val freshSession = createSessionMetadata("session-1").copy(nextSms = 9999L) + + val savedState = RegistrationFlowState( + backStack = listOf( + RegistrationRoute.Welcome, + RegistrationRoute.Permissions(nextRoute = RegistrationRoute.PhoneNumberEntry), + RegistrationRoute.PhoneNumberEntry, + RegistrationRoute.VerificationCodeEntry + ), + sessionMetadata = savedSession, + sessionE164 = "+15551234567" + ) + + coEvery { mockRepository.restoreFlowState() } returns savedState + coEvery { mockRepository.validateSession("session-1") } returns freshSession + + val viewModel = RegistrationViewModel(mockRepository, SavedStateHandle()) + advanceUntilIdle() + + val state = viewModel.state.value + assertThat(state.backStack).isEqualTo(savedState.backStack) + assertThat(state.sessionMetadata).isEqualTo(freshSession) + assertThat(state.sessionE164).isEqualTo("+15551234567") + } + + @Test + fun `restore with expired session and not registered resets to PhoneNumberEntry with e164 preserved`() = runTest(testDispatcher) { + val savedSession = createSessionMetadata("session-expired") + + val savedState = RegistrationFlowState( + backStack = listOf( + RegistrationRoute.Welcome, + RegistrationRoute.Permissions(nextRoute = RegistrationRoute.PhoneNumberEntry), + RegistrationRoute.PhoneNumberEntry, + RegistrationRoute.VerificationCodeEntry + ), + sessionMetadata = savedSession, + sessionE164 = "+15559876543" + ) + + coEvery { mockRepository.restoreFlowState() } returns savedState + coEvery { mockRepository.validateSession("session-expired") } returns null + coEvery { mockRepository.isRegistered() } returns false + + val viewModel = RegistrationViewModel(mockRepository, SavedStateHandle()) + advanceUntilIdle() + + val state = viewModel.state.value + assertThat(state.backStack).isEqualTo( + listOf( + RegistrationRoute.Welcome, + RegistrationRoute.Permissions(nextRoute = RegistrationRoute.PhoneNumberEntry), + RegistrationRoute.PhoneNumberEntry + ) + ) + assertThat(state.sessionMetadata).isNull() + assertThat(state.sessionE164).isEqualTo("+15559876543") + } + + @Test + fun `restore with expired session and already registered proceeds with null session`() = runTest(testDispatcher) { + val savedSession = createSessionMetadata("session-expired-2") + + val savedState = RegistrationFlowState( + backStack = listOf( + RegistrationRoute.Welcome, + RegistrationRoute.PinCreate + ), + sessionMetadata = savedSession, + sessionE164 = "+15551234567" + ) + + coEvery { mockRepository.restoreFlowState() } returns savedState + coEvery { mockRepository.validateSession("session-expired-2") } returns null + coEvery { mockRepository.isRegistered() } returns true + + val viewModel = RegistrationViewModel(mockRepository, SavedStateHandle()) + advanceUntilIdle() + + val state = viewModel.state.value + assertThat(state.backStack).isEqualTo(listOf(RegistrationRoute.Welcome, RegistrationRoute.PinCreate)) + assertThat(state.sessionMetadata).isNull() + assertThat(state.sessionE164).isEqualTo("+15551234567") + } + + @Test + fun `restore with no session skips validation`() = runTest(testDispatcher) { + val savedState = RegistrationFlowState( + backStack = listOf( + RegistrationRoute.Welcome, + RegistrationRoute.PinCreate, + RegistrationRoute.ArchiveRestoreSelection + ), + sessionMetadata = null, + sessionE164 = "+15551234567", + doNotAttemptRecoveryPassword = true + ) + + coEvery { mockRepository.restoreFlowState() } returns savedState + + val viewModel = RegistrationViewModel(mockRepository, SavedStateHandle()) + advanceUntilIdle() + + val state = viewModel.state.value + assertThat(state.backStack).isEqualTo(savedState.backStack) + assertThat(state.sessionMetadata).isNull() + assertThat(state.doNotAttemptRecoveryPassword).isEqualTo(true) + + coVerify(exactly = 0) { mockRepository.validateSession(any()) } + } + + @Test + fun `onEvent ResetState clears flow state`() = runTest(testDispatcher) { + coEvery { mockRepository.restoreFlowState() } returns null + coEvery { mockRepository.getPreExistingRegistrationData() } returns null + + val viewModel = RegistrationViewModel(mockRepository, SavedStateHandle()) + advanceUntilIdle() + + viewModel.onEvent(RegistrationFlowEvent.ResetState) + advanceUntilIdle() + + coVerify { mockRepository.clearFlowState() } + } + + @Test + fun `onEvent NavigateToScreen FullyComplete clears flow state`() = runTest(testDispatcher) { + coEvery { mockRepository.restoreFlowState() } returns null + coEvery { mockRepository.getPreExistingRegistrationData() } returns null + + val viewModel = RegistrationViewModel(mockRepository, SavedStateHandle()) + advanceUntilIdle() + + viewModel.onEvent(RegistrationFlowEvent.NavigateToScreen(RegistrationRoute.FullyComplete)) + advanceUntilIdle() + + coVerify { mockRepository.clearFlowState() } + } + + @Test + fun `onEvent NavigateToScreen saves flow state`() = runTest(testDispatcher) { + coEvery { mockRepository.restoreFlowState() } returns null + coEvery { mockRepository.getPreExistingRegistrationData() } returns null + + val viewModel = RegistrationViewModel(mockRepository, SavedStateHandle()) + advanceUntilIdle() + + viewModel.onEvent(RegistrationFlowEvent.NavigateToScreen(RegistrationRoute.PhoneNumberEntry)) + advanceUntilIdle() + + coVerify { mockRepository.saveFlowState(any()) } + } + + @Test + fun `onEvent Registered does not save flow state`() = runTest(testDispatcher) { + coEvery { mockRepository.restoreFlowState() } returns null + coEvery { mockRepository.getPreExistingRegistrationData() } returns null + + val viewModel = RegistrationViewModel(mockRepository, SavedStateHandle()) + advanceUntilIdle() + + viewModel.onEvent(RegistrationFlowEvent.Registered(org.signal.core.models.AccountEntropyPool.generate())) + advanceUntilIdle() + + coVerify(exactly = 0) { mockRepository.saveFlowState(any()) } + } + + @Test + fun `onEvent MasterKeyRestoredFromSvr does not save flow state`() = runTest(testDispatcher) { + coEvery { mockRepository.restoreFlowState() } returns null + coEvery { mockRepository.getPreExistingRegistrationData() } returns null + + val viewModel = RegistrationViewModel(mockRepository, SavedStateHandle()) + advanceUntilIdle() + + viewModel.onEvent(RegistrationFlowEvent.MasterKeyRestoredFromSvr(org.signal.core.models.MasterKey(ByteArray(32)))) + advanceUntilIdle() + + coVerify(exactly = 0) { mockRepository.saveFlowState(any()) } + } + + private fun createSessionMetadata(id: String = "test-session"): NetworkController.SessionMetadata { + return NetworkController.SessionMetadata( + id = id, + nextSms = 1000L, + nextCall = 2000L, + nextVerificationAttempt = 3000L, + allowedToRequestCode = true, + requestedInformation = emptyList(), + verified = false + ) + } +} diff --git a/feature/registration/src/test/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryViewModelTest.kt b/feature/registration/src/test/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryViewModelTest.kt index c9cc1575fe..33414c1041 100644 --- a/feature/registration/src/test/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryViewModelTest.kt +++ b/feature/registration/src/test/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryViewModelTest.kt @@ -292,8 +292,10 @@ class PhoneNumberEntryViewModelTest { assertThat(emittedStates.last().showSpinner).isFalse() assertThat(emittedStates.last().sessionMetadata).isNotNull() - assertThat(emittedEvents).hasSize(1) - assertThat(emittedEvents.first()) + assertThat(emittedEvents).hasSize(3) + assertThat(emittedEvents[0]).isInstanceOf() + assertThat(emittedEvents[1]).isInstanceOf() + assertThat(emittedEvents[2]) .isInstanceOf() .prop(RegistrationFlowEvent.NavigateToScreen::route) .isInstanceOf() @@ -426,8 +428,10 @@ class PhoneNumberEntryViewModelTest { assertThat(emittedStates.last().showSpinner).isFalse() // Should not create a new session, just request verification code - assertThat(emittedEvents).hasSize(1) - assertThat(emittedEvents.first()) + assertThat(emittedEvents).hasSize(3) + assertThat(emittedEvents[0]).isInstanceOf() + assertThat(emittedEvents[1]).isInstanceOf() + assertThat(emittedEvents[2]) .isInstanceOf() .prop(RegistrationFlowEvent.NavigateToScreen::route) .isInstanceOf() @@ -563,8 +567,10 @@ class PhoneNumberEntryViewModelTest { assertThat(emittedStates.last().showSpinner).isFalse() // Verify navigation to verification code entry - assertThat(emittedEvents).hasSize(1) - assertThat(emittedEvents.first()) + assertThat(emittedEvents).hasSize(3) + assertThat(emittedEvents[0]).isInstanceOf() + assertThat(emittedEvents[1]).isInstanceOf() + assertThat(emittedEvents[2]) .isInstanceOf() .prop(RegistrationFlowEvent.NavigateToScreen::route) .isInstanceOf() @@ -595,8 +601,10 @@ class PhoneNumberEntryViewModelTest { assertThat(emittedStates.last().showSpinner).isFalse() // Verify navigation continues despite no push challenge token - assertThat(emittedEvents).hasSize(1) - assertThat(emittedEvents.first()) + assertThat(emittedEvents).hasSize(3) + assertThat(emittedEvents[0]).isInstanceOf() + assertThat(emittedEvents[1]).isInstanceOf() + assertThat(emittedEvents[2]) .isInstanceOf() .prop(RegistrationFlowEvent.NavigateToScreen::route) .isInstanceOf() @@ -631,8 +639,10 @@ class PhoneNumberEntryViewModelTest { assertThat(emittedStates.last().showSpinner).isFalse() // Verify navigation continues despite push challenge submission failure - assertThat(emittedEvents).hasSize(1) - assertThat(emittedEvents.first()) + assertThat(emittedEvents).hasSize(3) + assertThat(emittedEvents[0]).isInstanceOf() + assertThat(emittedEvents[1]).isInstanceOf() + assertThat(emittedEvents[2]) .isInstanceOf() .prop(RegistrationFlowEvent.NavigateToScreen::route) .isInstanceOf() @@ -662,8 +672,10 @@ class PhoneNumberEntryViewModelTest { assertThat(emittedStates.last().showSpinner).isFalse() // Verify navigation continues despite network error - assertThat(emittedEvents).hasSize(1) - assertThat(emittedEvents.first()) + assertThat(emittedEvents).hasSize(3) + assertThat(emittedEvents[0]).isInstanceOf() + assertThat(emittedEvents[1]).isInstanceOf() + assertThat(emittedEvents[2]) .isInstanceOf() .prop(RegistrationFlowEvent.NavigateToScreen::route) .isInstanceOf() @@ -693,8 +705,10 @@ class PhoneNumberEntryViewModelTest { assertThat(emittedStates.last().showSpinner).isFalse() // Verify navigation continues despite application error - assertThat(emittedEvents).hasSize(1) - assertThat(emittedEvents.first()) + assertThat(emittedEvents).hasSize(3) + assertThat(emittedEvents[0]).isInstanceOf() + assertThat(emittedEvents[1]).isInstanceOf() + assertThat(emittedEvents[2]) .isInstanceOf() .prop(RegistrationFlowEvent.NavigateToScreen::route) .isInstanceOf() @@ -744,8 +758,10 @@ class PhoneNumberEntryViewModelTest { viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.CaptchaCompleted("captcha-token"), stateEmitter, parentEventEmitter) - assertThat(emittedEvents).hasSize(1) - assertThat(emittedEvents.first()) + assertThat(emittedEvents).hasSize(3) + assertThat(emittedEvents[0]).isInstanceOf() + assertThat(emittedEvents[1]).isInstanceOf() + assertThat(emittedEvents[2]) .isInstanceOf() .prop(RegistrationFlowEvent.NavigateToScreen::route) .isInstanceOf()