Make regV5 resumable if the app closes.

This commit is contained in:
Greyson Parrelli
2026-03-18 16:53:50 -04:00
committed by Michelle Tang
parent c7ec3ab837
commit f09bf5b14c
15 changed files with 895 additions and 41 deletions

View File

@@ -69,7 +69,7 @@ fun NetworkDebugOverlay(
onClick = { showDialog = true }, onClick = { showDialog = true },
dragOffset = dragOffset, dragOffset = dragOffset,
onDrag = { delta -> dragOffset += delta }, onDrag = { delta -> dragOffset += delta },
modifier = Modifier.align(Alignment.CenterEnd) modifier = Modifier.align(Alignment.TopEnd)
) )
if (showDialog) { if (showDialog) {

View File

@@ -14,7 +14,9 @@ import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.rememberScrollState
import androidx.compose.foundation.shape.RoundedCornerShape import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.foundation.verticalScroll
import androidx.compose.material3.AlertDialog import androidx.compose.material3.AlertDialog
import androidx.compose.material3.Button import androidx.compose.material3.Button
import androidx.compose.material3.ButtonDefaults import androidx.compose.material3.ButtonDefaults
@@ -74,6 +76,7 @@ fun MainScreen(
Column( Column(
modifier = modifier modifier = modifier
.fillMaxSize() .fillMaxSize()
.verticalScroll(rememberScrollState())
.padding(24.dp), .padding(24.dp),
horizontalAlignment = Alignment.CenterHorizontally, horizontalAlignment = Alignment.CenterHorizontally,
verticalArrangement = Arrangement.Center verticalArrangement = Arrangement.Center
@@ -95,6 +98,30 @@ fun MainScreen(
Spacer(modifier = Modifier.height(32.dp)) Spacer(modifier = Modifier.height(32.dp))
} }
if (state.pendingFlowState != null) {
PendingFlowStateCard(state.pendingFlowState)
Spacer(modifier = Modifier.height(16.dp))
Button(
onClick = { onEvent(MainScreenEvents.LaunchRegistration) },
modifier = Modifier.fillMaxWidth()
) {
Text("Resume Registration")
}
TextButton(
onClick = { showClearDataDialog = true },
modifier = Modifier.fillMaxWidth(),
colors = ButtonDefaults.textButtonColors(
contentColor = MaterialTheme.colorScheme.error
)
) {
Text("Clear Pending Data")
}
Spacer(modifier = Modifier.height(16.dp))
}
if (state.existingRegistrationState != null) { if (state.existingRegistrationState != null) {
if (state.registrationExpired) { if (state.registrationExpired) {
Row( Row(
@@ -150,7 +177,7 @@ fun MainScreen(
) { ) {
Text("Clear All Data") Text("Clear All Data")
} }
} else { } else if (state.pendingFlowState == null) {
Button( Button(
onClick = { onEvent(MainScreenEvents.LaunchRegistration) }, onClick = { onEvent(MainScreenEvents.LaunchRegistration) },
modifier = Modifier.fillMaxWidth() modifier = Modifier.fillMaxWidth()
@@ -213,6 +240,55 @@ private fun RegistrationField(label: String, value: String) {
} }
} }
@Composable
private fun PendingFlowStateCard(pending: MainScreenState.PendingFlowState) {
Card(
modifier = Modifier.fillMaxWidth(),
colors = CardDefaults.cardColors(
containerColor = MaterialTheme.colorScheme.tertiaryContainer
)
) {
Column(
modifier = Modifier.padding(16.dp)
) {
Text(
text = "In-Progress Registration",
style = MaterialTheme.typography.titleMedium,
color = MaterialTheme.colorScheme.onTertiaryContainer
)
HorizontalDivider(modifier = Modifier.padding(vertical = 8.dp))
RegistrationField(label = "Current Screen", value = pending.currentScreen)
RegistrationField(label = "Backstack Depth", value = pending.backstackSize.toString())
if (pending.e164 != null) {
RegistrationField(label = "Phone Number", value = pending.e164)
}
RegistrationField(label = "Has Session", value = if (pending.hasSession) "Yes" else "No")
RegistrationField(label = "Has AEP", value = if (pending.hasAccountEntropyPool) "Yes" else "No")
}
}
}
@Preview(showBackground = true)
@Composable
private fun MainScreenWithPendingFlowStatePreview() {
Previews.Preview {
MainScreen(
state = MainScreenState(
pendingFlowState = MainScreenState.PendingFlowState(
e164 = "+15551234567",
backstackSize = 4,
currentScreen = "VerificationCodeEntry",
hasSession = true,
hasAccountEntropyPool = false
)
),
onEvent = {}
)
}
}
@Preview(showBackground = true) @Preview(showBackground = true)
@Composable @Composable
private fun MainScreenPreview() { private fun MainScreenPreview() {

View File

@@ -7,8 +7,17 @@ package org.signal.registration.sample.screens.main
data class MainScreenState( data class MainScreenState(
val existingRegistrationState: ExistingRegistrationState? = null, val existingRegistrationState: ExistingRegistrationState? = null,
val registrationExpired: Boolean = false val registrationExpired: Boolean = false,
val pendingFlowState: PendingFlowState? = null
) { ) {
data class PendingFlowState(
val e164: String?,
val backstackSize: Int,
val currentScreen: String,
val hasSession: Boolean,
val hasAccountEntropyPool: Boolean
)
data class ExistingRegistrationState( data class ExistingRegistrationState(
val phoneNumber: String, val phoneNumber: String,
val aci: String, val aci: String,

View File

@@ -12,9 +12,11 @@ import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.serialization.json.Json
import org.signal.core.util.Base64 import org.signal.core.util.Base64
import org.signal.core.util.logging.Log import org.signal.core.util.logging.Log
import org.signal.registration.NetworkController import org.signal.registration.NetworkController
import org.signal.registration.PersistedFlowState
import org.signal.registration.StorageController import org.signal.registration.StorageController
import org.signal.registration.sample.storage.RegistrationPreferences import org.signal.registration.sample.storage.RegistrationPreferences
@@ -75,6 +77,7 @@ class MainScreenViewModel(
} else { } else {
null null
}, },
pendingFlowState = loadPendingFlowState(),
registrationExpired = false registrationExpired = false
) )
@@ -84,6 +87,27 @@ class MainScreenViewModel(
} }
} }
private suspend fun loadPendingFlowState(): MainScreenState.PendingFlowState? {
return try {
val data = storageController.readInProgressRegistrationData()
if (data.flowStateJson.isEmpty()) return null
val json = Json { ignoreUnknownKeys = true }
val persisted = json.decodeFromString(PersistedFlowState.serializer(), data.flowStateJson)
MainScreenState.PendingFlowState(
e164 = persisted.sessionE164,
backstackSize = persisted.backStack.size,
currentScreen = persisted.backStack.lastOrNull()?.let { it::class.simpleName } ?: "Unknown",
hasSession = persisted.sessionMetadata != null,
hasAccountEntropyPool = data.accountEntropyPool.isNotEmpty()
)
} catch (e: Exception) {
Log.w(TAG, "Failed to load pending flow state", e)
null
}
}
private suspend fun checkRegistrationStatus() { private suspend fun checkRegistrationStatus() {
when (val result = networkController.getSvrCredentials()) { when (val result = networkController.getSvrCredentials()) {
is NetworkController.RegistrationNetworkResult.Success -> { is NetworkController.RegistrationNetworkResult.Success -> {

View File

@@ -0,0 +1,60 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.registration
import kotlinx.serialization.Serializable
import org.signal.core.models.AccountEntropyPool
import org.signal.core.models.MasterKey
/**
* A serializable snapshot of [RegistrationFlowState] fields that need to survive app kills.
*
* Fields like [RegistrationFlowState.accountEntropyPool] and [RegistrationFlowState.temporaryMasterKey]
* are reconstructed from dedicated proto fields, not from this JSON snapshot.
* [RegistrationFlowState.preExistingRegistrationData] is loaded from permanent storage.
*/
@Serializable
data class PersistedFlowState(
val backStack: List<RegistrationRoute>,
val sessionMetadata: NetworkController.SessionMetadata?,
val sessionE164: String?,
val doNotAttemptRecoveryPassword: Boolean
)
/**
* Extracts the persistable fields from a [RegistrationFlowState].
*/
fun RegistrationFlowState.toPersistedFlowState(): PersistedFlowState {
return PersistedFlowState(
backStack = backStack,
sessionMetadata = sessionMetadata,
sessionE164 = sessionE164,
doNotAttemptRecoveryPassword = doNotAttemptRecoveryPassword
)
}
/**
* Reconstructs a full [RegistrationFlowState] from persisted data and separately-stored fields.
*
* @param accountEntropyPool Restored from the proto's dedicated `accountEntropyPool` field.
* @param temporaryMasterKey Restored from the proto's dedicated `temporaryMasterKey` field.
* @param preExistingRegistrationData Loaded from permanent storage via [StorageController.getPreExistingRegistrationData].
*/
fun PersistedFlowState.toRegistrationFlowState(
accountEntropyPool: AccountEntropyPool?,
temporaryMasterKey: MasterKey?,
preExistingRegistrationData: PreExistingRegistrationData?
): RegistrationFlowState {
return RegistrationFlowState(
backStack = backStack,
sessionMetadata = sessionMetadata,
sessionE164 = sessionE164,
accountEntropyPool = accountEntropyPool,
temporaryMasterKey = temporaryMasterKey,
preExistingRegistrationData = preExistingRegistrationData,
doNotAttemptRecoveryPassword = doNotAttemptRecoveryPassword
)
}

View File

@@ -37,5 +37,8 @@ data class RegistrationFlowState(
val preExistingRegistrationData: PreExistingRegistrationData? = null, val preExistingRegistrationData: PreExistingRegistrationData? = null,
/** If true, do not attempt any flows where we generate RRP's. Create a session instead. */ /** If true, do not attempt any flows where we generate RRP's. Create a session instead. */
val doNotAttemptRecoveryPassword: Boolean = false val doNotAttemptRecoveryPassword: Boolean = false,
/** If true, the ViewModel is still deciding whether to restore a previous flow or start fresh. */
val isRestoringNavigationState: Boolean = true
) : Parcelable, DebugLoggable ) : Parcelable, DebugLoggable

View File

@@ -8,9 +8,13 @@
package org.signal.registration package org.signal.registration
import android.os.Parcelable import android.os.Parcelable
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.material3.CircularProgressIndicator
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.getValue import androidx.compose.runtime.getValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.lifecycle.compose.collectAsStateWithLifecycle import androidx.lifecycle.compose.collectAsStateWithLifecycle
import androidx.lifecycle.viewmodel.compose.viewModel import androidx.lifecycle.viewmodel.compose.viewModel
@@ -62,6 +66,7 @@ import org.signal.registration.screens.welcome.WelcomeScreenEvents
* Navigation routes for the registration flow. * Navigation routes for the registration flow.
* Using @Serializable and NavKey for type-safe navigation with Navigation 3. * Using @Serializable and NavKey for type-safe navigation with Navigation 3.
*/ */
@Serializable
@Parcelize @Parcelize
sealed interface RegistrationRoute : NavKey, Parcelable { sealed interface RegistrationRoute : NavKey, Parcelable {
@Serializable @Serializable
@@ -77,7 +82,7 @@ sealed interface RegistrationRoute : NavKey, Parcelable {
data object CountryCodePicker : RegistrationRoute data object CountryCodePicker : RegistrationRoute
@Serializable @Serializable
data class VerificationCodeEntry(val session: NetworkController.SessionMetadata, val e164: String) : RegistrationRoute data object VerificationCodeEntry : RegistrationRoute
@Serializable @Serializable
data class Captcha(val session: NetworkController.SessionMetadata) : RegistrationRoute data class Captcha(val session: NetworkController.SessionMetadata) : RegistrationRoute
@@ -150,6 +155,13 @@ fun RegistrationNavHost(
val registrationState by viewModel.state.collectAsStateWithLifecycle() val registrationState by viewModel.state.collectAsStateWithLifecycle()
val permissions: MultiplePermissionsState = permissionsState ?: rememberMultiplePermissionsState(viewModel.getRequiredPermissions()) val permissions: MultiplePermissionsState = permissionsState ?: rememberMultiplePermissionsState(viewModel.getRequiredPermissions())
if (registrationState.isRestoringNavigationState) {
Box(modifier = modifier.fillMaxSize(), contentAlignment = Alignment.Center) {
CircularProgressIndicator()
}
return
}
val entryProvider = entryProvider { val entryProvider = entryProvider {
navigationEntries( navigationEntries(
registrationRepository = registrationRepository, registrationRepository = registrationRepository,

View File

@@ -10,6 +10,7 @@ import android.content.Context
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.withContext import kotlinx.coroutines.withContext
import kotlinx.serialization.json.Json
import okio.ByteString.Companion.toByteString import okio.ByteString.Companion.toByteString
import org.signal.core.models.AccountEntropyPool import org.signal.core.models.AccountEntropyPool
import org.signal.core.models.MasterKey import org.signal.core.models.MasterKey
@@ -410,6 +411,81 @@ class RegistrationRepository(val context: Context, val networkController: Networ
return storageController.getPreExistingRegistrationData() return storageController.getPreExistingRegistrationData()
} }
/**
* Persists the current flow state as JSON in the in-progress registration data proto.
*/
suspend fun saveFlowState(state: RegistrationFlowState) = withContext(Dispatchers.IO) {
try {
val json = flowStateJson.encodeToString(PersistedFlowState.serializer(), state.toPersistedFlowState())
storageController.updateInProgressRegistrationData {
flowStateJson = json
}
} catch (e: Exception) {
Log.w(TAG, "Failed to save flow state", e)
}
}
/**
* Restores the flow state from disk. Returns null if no state is saved or deserialization fails.
* Reconstructs [RegistrationFlowState.accountEntropyPool] and [RegistrationFlowState.temporaryMasterKey]
* from their dedicated proto fields, and loads [RegistrationFlowState.preExistingRegistrationData]
* from permanent storage.
*/
suspend fun restoreFlowState(): RegistrationFlowState? = withContext(Dispatchers.IO) {
try {
val data = storageController.readInProgressRegistrationData()
if (data.flowStateJson.isEmpty()) return@withContext null
val persisted = flowStateJson.decodeFromString(PersistedFlowState.serializer(), data.flowStateJson)
val aep = data.accountEntropyPool.takeIf { it.isNotEmpty() }?.let { AccountEntropyPool(it) }
val masterKey = data.temporaryMasterKey.takeIf { it.size > 0 }?.let { MasterKey(it.toByteArray()) }
val preExisting = storageController.getPreExistingRegistrationData()
persisted.toRegistrationFlowState(
accountEntropyPool = aep,
temporaryMasterKey = masterKey,
preExistingRegistrationData = preExisting
)
} catch (e: Exception) {
Log.w(TAG, "Failed to restore flow state", e)
null
}
}
/**
* Clears any persisted flow state JSON from the in-progress registration data.
*/
suspend fun clearFlowState() = withContext(Dispatchers.IO) {
try {
storageController.updateInProgressRegistrationData {
flowStateJson = ""
}
} catch (e: Exception) {
Log.w(TAG, "Failed to clear flow state", e)
}
}
/**
* Validates a registration session by fetching its current status from the server.
* Returns fresh [SessionMetadata] on success, or null if the session is expired/invalid.
*/
suspend fun validateSession(sessionId: String): SessionMetadata? = withContext(Dispatchers.IO) {
when (val result = networkController.getSession(sessionId)) {
is RegistrationNetworkResult.Success -> result.data
else -> null
}
}
/**
* Checks whether the in-progress registration data indicates a completed registration
* (i.e. both ACI and PNI have been saved).
*/
suspend fun isRegistered(): Boolean = withContext(Dispatchers.IO) {
val data = storageController.readInProgressRegistrationData()
data.aci.isNotEmpty() && data.pni.isNotEmpty()
}
private fun generateKeyMaterial( private fun generateKeyMaterial(
existingAccountEntropyPool: AccountEntropyPool? = null, existingAccountEntropyPool: AccountEntropyPool? = null,
existingAciIdentityKeyPair: IdentityKeyPair? = null, existingAciIdentityKeyPair: IdentityKeyPair? = null,
@@ -488,5 +564,6 @@ class RegistrationRepository(val context: Context, val networkController: Networ
companion object { companion object {
private val TAG = Log.tag(RegistrationRepository::class) private val TAG = Log.tag(RegistrationRepository::class)
private val flowStateJson = Json { ignoreUnknownKeys = true }
} }
} }

View File

@@ -13,6 +13,7 @@ import androidx.lifecycle.ViewModelProvider
import androidx.lifecycle.createSavedStateHandle import androidx.lifecycle.createSavedStateHandle
import androidx.lifecycle.viewModelScope import androidx.lifecycle.viewModelScope
import androidx.lifecycle.viewmodel.CreationExtras import androidx.lifecycle.viewmodel.CreationExtras
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow import kotlinx.coroutines.flow.asStateFlow
@@ -37,9 +38,17 @@ class RegistrationViewModel(private val repository: RegistrationRepository, save
val resultBus = ResultEventBus() val resultBus = ResultEventBus()
init { init {
_state.value = _state.value.copy(isRestoringNavigationState = true)
viewModelScope.launch { viewModelScope.launch {
repository.getPreExistingRegistrationData()?.let { val restored = repository.restoreFlowState()
_state.value = _state.value.copy(preExistingRegistrationData = it) if (restored != null) {
Log.i(TAG, "[init] Restored flow state from disk. Backstack size: ${restored.backStack.size}, hasSession: ${restored.sessionMetadata != null}")
_state.value = validateRestoredState(restored).copy(isRestoringNavigationState = false)
} else {
_state.value = _state.value.copy(
preExistingRegistrationData = repository.getPreExistingRegistrationData(),
isRestoringNavigationState = false
)
} }
} }
} }
@@ -47,11 +56,15 @@ class RegistrationViewModel(private val repository: RegistrationRepository, save
fun onEvent(event: RegistrationFlowEvent) { fun onEvent(event: RegistrationFlowEvent) {
Log.d(TAG, "[Event] $event") Log.d(TAG, "[Event] $event")
_state.value = applyEvent(_state.value, event) _state.value = applyEvent(_state.value, event)
viewModelScope.launch(Dispatchers.IO) {
persistFlowState(event)
}
} }
fun applyEvent(state: RegistrationFlowState, event: RegistrationFlowEvent): RegistrationFlowState { fun applyEvent(state: RegistrationFlowState, event: RegistrationFlowEvent): RegistrationFlowState {
return when (event) { return when (event) {
is RegistrationFlowEvent.ResetState -> RegistrationFlowState() is RegistrationFlowEvent.ResetState -> RegistrationFlowState(isRestoringNavigationState = false)
is RegistrationFlowEvent.SessionUpdated -> state.copy(sessionMetadata = event.session) is RegistrationFlowEvent.SessionUpdated -> state.copy(sessionMetadata = event.session)
is RegistrationFlowEvent.E164Chosen -> state.copy(sessionE164 = event.e164) is RegistrationFlowEvent.E164Chosen -> state.copy(sessionE164 = event.e164)
is RegistrationFlowEvent.Registered -> state.copy(accountEntropyPool = event.accountEntropyPool) is RegistrationFlowEvent.Registered -> state.copy(accountEntropyPool = event.accountEntropyPool)
@@ -63,14 +76,43 @@ class RegistrationViewModel(private val repository: RegistrationRepository, save
} }
private fun applyNavigationToScreenEvent(inputState: RegistrationFlowState, event: RegistrationFlowEvent.NavigateToScreen): RegistrationFlowState { private fun applyNavigationToScreenEvent(inputState: RegistrationFlowState, event: RegistrationFlowEvent.NavigateToScreen): RegistrationFlowState {
val state = inputState.copy(backStack = inputState.backStack + event.route) return inputState.copy(backStack = inputState.backStack + event.route)
}
return when (event.route) { /**
is RegistrationRoute.VerificationCodeEntry -> { * Validates a restored flow state by checking if the session is still valid.
state.copy(sessionMetadata = event.route.session, sessionE164 = event.route.e164) *
} * - If the session is still valid, updates session metadata with fresh data.
else -> state * - If the session is expired and the user is already registered, nulls out the session
* (post-registration screens like PinCreate don't need a session).
* - If the session is expired and the user is NOT registered, resets the backstack to
* PhoneNumberEntry with the phone number pre-filled so the user can re-submit.
*/
private suspend fun validateRestoredState(state: RegistrationFlowState): RegistrationFlowState {
val sessionMetadata = state.sessionMetadata ?: return state
val freshSession = repository.validateSession(sessionMetadata.id)
if (freshSession != null) {
Log.i(TAG, "[validateRestoredState] Session still valid.")
return state.copy(sessionMetadata = freshSession)
} }
Log.i(TAG, "[validateRestoredState] Session expired/invalid.")
if (repository.isRegistered()) {
Log.i(TAG, "[validateRestoredState] User is registered, proceeding without session.")
return state.copy(sessionMetadata = null)
}
Log.i(TAG, "[validateRestoredState] User is NOT registered, resetting to PhoneNumberEntry.")
return state.copy(
backStack = listOf(
RegistrationRoute.Welcome,
RegistrationRoute.Permissions(nextRoute = RegistrationRoute.PhoneNumberEntry),
RegistrationRoute.PhoneNumberEntry
),
sessionMetadata = null
)
} }
/** /**
@@ -101,6 +143,27 @@ class RegistrationViewModel(private val repository: RegistrationRepository, save
} }
} }
private suspend fun persistFlowState(event: RegistrationFlowEvent) {
when (event) {
is RegistrationFlowEvent.ResetState -> repository.clearFlowState()
is RegistrationFlowEvent.NavigateToScreen -> {
if (event.route is RegistrationRoute.FullyComplete) {
repository.clearFlowState()
} else {
repository.saveFlowState(_state.value)
}
}
is RegistrationFlowEvent.NavigateBack,
is RegistrationFlowEvent.SessionUpdated,
is RegistrationFlowEvent.E164Chosen,
is RegistrationFlowEvent.RecoveryPasswordInvalid -> repository.saveFlowState(_state.value)
// No need to persist anything new, fields accounted for in proto already
is RegistrationFlowEvent.Registered,
is RegistrationFlowEvent.MasterKeyRestoredFromSvr -> { }
}
}
class Factory(private val repository: RegistrationRepository) : ViewModelProvider.Factory { class Factory(private val repository: RegistrationRepository) : ViewModelProvider.Factory {
override fun <T : ViewModel> create(modelClass: KClass<T>, extras: CreationExtras): T { override fun <T : ViewModel> create(modelClass: KClass<T>, extras: CreationExtras): T {
return RegistrationViewModel(repository, extras.createSavedStateHandle()) as T return RegistrationViewModel(repository, extras.createSavedStateHandle()) as T

View File

@@ -276,14 +276,17 @@ class PhoneNumberEntryViewModel(
is NetworkController.RegistrationNetworkResult.Failure<NetworkController.CreateSessionError> -> { is NetworkController.RegistrationNetworkResult.Failure<NetworkController.CreateSessionError> -> {
return when (response.error) { return when (response.error) {
is NetworkController.CreateSessionError.InvalidRequest -> { is NetworkController.CreateSessionError.InvalidRequest -> {
Log.w(TAG, "[CreateSession] Invalid request when creating session. Message: ${response.error.message}")
state.copy(oneTimeEvent = OneTimeEvent.UnknownError) state.copy(oneTimeEvent = OneTimeEvent.UnknownError)
} }
is NetworkController.CreateSessionError.RateLimited -> { is NetworkController.CreateSessionError.RateLimited -> {
Log.w(TAG, "[CreateSession] Rate limited (retryAfter: ${response.error.retryAfter}).")
state.copy(oneTimeEvent = OneTimeEvent.RateLimited(response.error.retryAfter)) state.copy(oneTimeEvent = OneTimeEvent.RateLimited(response.error.retryAfter))
} }
} }
} }
is NetworkController.RegistrationNetworkResult.NetworkError -> { is NetworkController.RegistrationNetworkResult.NetworkError -> {
Log.w(TAG, "[CreateSession] Network error.", response.exception)
return state.copy(oneTimeEvent = OneTimeEvent.NetworkError) return state.copy(oneTimeEvent = OneTimeEvent.NetworkError)
} }
is NetworkController.RegistrationNetworkResult.ApplicationError -> { is NetworkController.RegistrationNetworkResult.ApplicationError -> {
@@ -304,23 +307,26 @@ class PhoneNumberEntryViewModel(
Log.d(TAG, "Received push challenge token, submitting...") Log.d(TAG, "Received push challenge token, submitting...")
val updateResult = repository.submitPushChallengeToken(sessionMetadata.id, pushChallengeToken) val updateResult = repository.submitPushChallengeToken(sessionMetadata.id, pushChallengeToken)
sessionMetadata = when (updateResult) { sessionMetadata = when (updateResult) {
is NetworkController.RegistrationNetworkResult.Success -> updateResult.data is NetworkController.RegistrationNetworkResult.Success -> {
Log.d(TAG, "[SubmitPushChallengeToken] Successfully submitted push challenge token.")
updateResult.data
}
is NetworkController.RegistrationNetworkResult.Failure -> { is NetworkController.RegistrationNetworkResult.Failure -> {
Log.w(TAG, "Failed to submit push challenge token: ${updateResult.error}") Log.w(TAG, "[SubmitPushChallengeToken] Failed to submit push challenge token: ${updateResult.error}")
sessionMetadata sessionMetadata
} }
is NetworkController.RegistrationNetworkResult.NetworkError -> { is NetworkController.RegistrationNetworkResult.NetworkError -> {
Log.w(TAG, "Network error submitting push challenge token", updateResult.exception) Log.w(TAG, "[SubmitPushChallengeToken] Network error submitting push challenge token", updateResult.exception)
sessionMetadata sessionMetadata
} }
is NetworkController.RegistrationNetworkResult.ApplicationError -> { is NetworkController.RegistrationNetworkResult.ApplicationError -> {
Log.w(TAG, "Application error submitting push challenge token", updateResult.exception) Log.w(TAG, "[SubmitPushChallengeToken] Application error submitting push challenge token", updateResult.exception)
sessionMetadata sessionMetadata
} }
} }
state = state.copy(sessionMetadata = sessionMetadata) state = state.copy(sessionMetadata = sessionMetadata)
} else { } else {
Log.d(TAG, "Push challenge token not received within timeout") Log.d(TAG, "[SubmitPushChallengeToken] Push challenge token not received within timeout")
} }
} }
@@ -337,41 +343,49 @@ class PhoneNumberEntryViewModel(
sessionMetadata = when (verificationCodeResponse) { sessionMetadata = when (verificationCodeResponse) {
is NetworkController.RegistrationNetworkResult.Success<NetworkController.SessionMetadata> -> { is NetworkController.RegistrationNetworkResult.Success<NetworkController.SessionMetadata> -> {
Log.d(TAG, "[RequestVerificationCode] Successfully requested verification code.")
verificationCodeResponse.data verificationCodeResponse.data
} }
is NetworkController.RegistrationNetworkResult.Failure<NetworkController.RequestVerificationCodeError> -> { is NetworkController.RegistrationNetworkResult.Failure<NetworkController.RequestVerificationCodeError> -> {
return when (verificationCodeResponse.error) { return when (verificationCodeResponse.error) {
is NetworkController.RequestVerificationCodeError.InvalidRequest -> { is NetworkController.RequestVerificationCodeError.InvalidRequest -> {
Log.w(TAG, "[RequestVerificationCode] Invalid request when requesting verification code. Message: ${verificationCodeResponse.error.message}")
state.copy(oneTimeEvent = OneTimeEvent.UnknownError) state.copy(oneTimeEvent = OneTimeEvent.UnknownError)
} }
is NetworkController.RequestVerificationCodeError.RateLimited -> { is NetworkController.RequestVerificationCodeError.RateLimited -> {
Log.w(TAG, "[RequestVerificationCode] Rate limited (retryAfter: ${verificationCodeResponse.error.retryAfter}).")
state.copy(oneTimeEvent = OneTimeEvent.RateLimited(verificationCodeResponse.error.retryAfter)) state.copy(oneTimeEvent = OneTimeEvent.RateLimited(verificationCodeResponse.error.retryAfter))
} }
is NetworkController.RequestVerificationCodeError.CouldNotFulfillWithRequestedTransport -> { is NetworkController.RequestVerificationCodeError.CouldNotFulfillWithRequestedTransport -> {
Log.w(TAG, "[RequestVerificationCode] Could not fulfill with requested transport.")
state.copy(oneTimeEvent = OneTimeEvent.CouldNotRequestCodeWithSelectedTransport) state.copy(oneTimeEvent = OneTimeEvent.CouldNotRequestCodeWithSelectedTransport)
} }
is NetworkController.RequestVerificationCodeError.InvalidSessionId -> { is NetworkController.RequestVerificationCodeError.InvalidSessionId -> {
Log.w(TAG, "[RequestVerificationCode] Invalid session ID when requesting verification code.")
parentEventEmitter(RegistrationFlowEvent.ResetState) parentEventEmitter(RegistrationFlowEvent.ResetState)
state state
} }
is NetworkController.RequestVerificationCodeError.MissingRequestInformationOrAlreadyVerified -> { is NetworkController.RequestVerificationCodeError.MissingRequestInformationOrAlreadyVerified -> {
Log.w(TAG, "When requesting verification code, missing request information or already verified.") Log.w(TAG, "[RequestVerificationCode] Missing request information or already verified.")
state.copy(oneTimeEvent = OneTimeEvent.NetworkError) state.copy(oneTimeEvent = OneTimeEvent.NetworkError)
} }
is NetworkController.RequestVerificationCodeError.SessionNotFound -> { is NetworkController.RequestVerificationCodeError.SessionNotFound -> {
Log.w(TAG, "[RequestVerificationCode] Session not found when requesting verification code.")
parentEventEmitter(RegistrationFlowEvent.ResetState) parentEventEmitter(RegistrationFlowEvent.ResetState)
state state
} }
is NetworkController.RequestVerificationCodeError.ThirdPartyServiceError -> { is NetworkController.RequestVerificationCodeError.ThirdPartyServiceError -> {
Log.w(TAG, "[RequestVerificationCode] Third party service error.")
state.copy(oneTimeEvent = OneTimeEvent.ThirdPartyError) state.copy(oneTimeEvent = OneTimeEvent.ThirdPartyError)
} }
} }
} }
is NetworkController.RegistrationNetworkResult.NetworkError -> { is NetworkController.RegistrationNetworkResult.NetworkError -> {
Log.w(TAG, "[RequestVerificationCode] Network error.", verificationCodeResponse.exception)
return state.copy(oneTimeEvent = OneTimeEvent.NetworkError) return state.copy(oneTimeEvent = OneTimeEvent.NetworkError)
} }
is NetworkController.RegistrationNetworkResult.ApplicationError -> { is NetworkController.RegistrationNetworkResult.ApplicationError -> {
Log.w(TAG, "Unknown error when creating session.", verificationCodeResponse.exception) Log.w(TAG, "[RequestVerificationCode] Unknown error when creating session.", verificationCodeResponse.exception)
return state.copy(oneTimeEvent = OneTimeEvent.UnknownError) return state.copy(oneTimeEvent = OneTimeEvent.UnknownError)
} }
} }
@@ -383,7 +397,9 @@ class PhoneNumberEntryViewModel(
return state return state
} }
parentEventEmitter.navigateTo(RegistrationRoute.VerificationCodeEntry(sessionMetadata, e164)) parentEventEmitter(RegistrationFlowEvent.SessionUpdated(sessionMetadata))
parentEventEmitter(RegistrationFlowEvent.E164Chosen(e164))
parentEventEmitter.navigateTo(RegistrationRoute.VerificationCodeEntry)
return state return state
} }
@@ -470,9 +486,9 @@ class PhoneNumberEntryViewModel(
} }
} }
val e164 = "+${inputState.countryCode}${inputState.nationalNumber}" parentEventEmitter(RegistrationFlowEvent.SessionUpdated(sessionMetadata))
parentEventEmitter(RegistrationFlowEvent.E164Chosen("+${inputState.countryCode}${inputState.nationalNumber}"))
parentEventEmitter.navigateTo(RegistrationRoute.VerificationCodeEntry(sessionMetadata, e164)) parentEventEmitter.navigateTo(RegistrationRoute.VerificationCodeEntry)
return state return state
} }

View File

@@ -39,6 +39,9 @@ message RegistrationData {
// Provisioning data (from saveProvisioningData) // Provisioning data (from saveProvisioningData)
ProvisioningData provisioningData = 20; ProvisioningData provisioningData = 20;
// JSON-serialized flow state snapshot (from saveFlowState/restoreFlowState)
string flowStateJson = 21;
} }
message SvrCredential { message SvrCredential {

View File

@@ -0,0 +1,225 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.registration
import assertk.assertThat
import assertk.assertions.isEqualTo
import assertk.assertions.isNull
import kotlinx.serialization.json.Json
import org.junit.Test
import org.signal.core.models.AccountEntropyPool
import org.signal.core.models.MasterKey
class PersistedFlowStateTest {
private val json = Json { ignoreUnknownKeys = true }
@Test
fun `round-trip serialization with simple backstack`() {
val state = PersistedFlowState(
backStack = listOf(RegistrationRoute.Welcome, RegistrationRoute.PhoneNumberEntry),
sessionMetadata = null,
sessionE164 = null,
doNotAttemptRecoveryPassword = false
)
val encoded = json.encodeToString(PersistedFlowState.serializer(), state)
val decoded = json.decodeFromString(PersistedFlowState.serializer(), encoded)
assertThat(decoded).isEqualTo(state)
}
@Test
fun `round-trip serialization with nested Permissions route`() {
val state = PersistedFlowState(
backStack = listOf(
RegistrationRoute.Welcome,
RegistrationRoute.Permissions(nextRoute = RegistrationRoute.PhoneNumberEntry),
RegistrationRoute.PhoneNumberEntry
),
sessionMetadata = null,
sessionE164 = "+15551234567",
doNotAttemptRecoveryPassword = false
)
val encoded = json.encodeToString(PersistedFlowState.serializer(), state)
val decoded = json.decodeFromString(PersistedFlowState.serializer(), encoded)
assertThat(decoded).isEqualTo(state)
}
@Test
fun `round-trip serialization with VerificationCodeEntry`() {
val session = NetworkController.SessionMetadata(
id = "session-123",
nextSms = 1000L,
nextCall = 2000L,
nextVerificationAttempt = 3000L,
allowedToRequestCode = true,
requestedInformation = listOf("pushChallenge"),
verified = false
)
val state = PersistedFlowState(
backStack = listOf(
RegistrationRoute.Welcome,
RegistrationRoute.Permissions(nextRoute = RegistrationRoute.PhoneNumberEntry),
RegistrationRoute.PhoneNumberEntry,
RegistrationRoute.VerificationCodeEntry
),
sessionMetadata = session,
sessionE164 = "+15551234567",
doNotAttemptRecoveryPassword = false
)
val encoded = json.encodeToString(PersistedFlowState.serializer(), state)
val decoded = json.decodeFromString(PersistedFlowState.serializer(), encoded)
assertThat(decoded).isEqualTo(state)
}
@Test
fun `round-trip serialization with post-registration routes`() {
val state = PersistedFlowState(
backStack = listOf(
RegistrationRoute.Welcome,
RegistrationRoute.PinCreate,
RegistrationRoute.ArchiveRestoreSelection
),
sessionMetadata = null,
sessionE164 = "+15551234567",
doNotAttemptRecoveryPassword = true
)
val encoded = json.encodeToString(PersistedFlowState.serializer(), state)
val decoded = json.decodeFromString(PersistedFlowState.serializer(), encoded)
assertThat(decoded).isEqualTo(state)
}
@Test
fun `round-trip serialization with PinEntryForRegistrationLock`() {
val creds = NetworkController.SvrCredentials(username = "user", password = "pass")
val state = PersistedFlowState(
backStack = listOf(
RegistrationRoute.Welcome,
RegistrationRoute.PinEntryForRegistrationLock(timeRemaining = 86400000L, svrCredentials = creds)
),
sessionMetadata = null,
sessionE164 = "+15551234567",
doNotAttemptRecoveryPassword = false
)
val encoded = json.encodeToString(PersistedFlowState.serializer(), state)
val decoded = json.decodeFromString(PersistedFlowState.serializer(), encoded)
assertThat(decoded).isEqualTo(state)
}
@Test
fun `round-trip serialization with Captcha route`() {
val session = NetworkController.SessionMetadata(
id = "session-456",
nextSms = null,
nextCall = null,
nextVerificationAttempt = null,
allowedToRequestCode = false,
requestedInformation = listOf("captcha"),
verified = false
)
val state = PersistedFlowState(
backStack = listOf(
RegistrationRoute.Welcome,
RegistrationRoute.PhoneNumberEntry,
RegistrationRoute.Captcha(session = session)
),
sessionMetadata = session,
sessionE164 = "+15551234567",
doNotAttemptRecoveryPassword = false
)
val encoded = json.encodeToString(PersistedFlowState.serializer(), state)
val decoded = json.decodeFromString(PersistedFlowState.serializer(), encoded)
assertThat(decoded).isEqualTo(state)
}
@Test
fun `deserialization ignores unknown keys for forward compatibility`() {
val validJson = """{"backStack":[{"type":"org.signal.registration.RegistrationRoute.Welcome"}],"sessionMetadata":null,"sessionE164":null,"doNotAttemptRecoveryPassword":false,"unknownField":"value"}"""
val decoded = json.decodeFromString(PersistedFlowState.serializer(), validJson)
assertThat(decoded.backStack).isEqualTo(listOf(RegistrationRoute.Welcome))
assertThat(decoded.sessionMetadata).isNull()
}
@Test
fun `toPersistedFlowState captures correct fields`() {
val session = NetworkController.SessionMetadata(
id = "session-789",
nextSms = null,
nextCall = null,
nextVerificationAttempt = null,
allowedToRequestCode = true,
requestedInformation = emptyList(),
verified = true
)
val flowState = RegistrationFlowState(
backStack = listOf(RegistrationRoute.Welcome, RegistrationRoute.PinCreate),
sessionMetadata = session,
sessionE164 = "+15551234567",
accountEntropyPool = AccountEntropyPool.generate(),
temporaryMasterKey = MasterKey(ByteArray(32)),
doNotAttemptRecoveryPassword = true
)
val persisted = flowState.toPersistedFlowState()
assertThat(persisted.backStack).isEqualTo(flowState.backStack)
assertThat(persisted.sessionMetadata).isEqualTo(session)
assertThat(persisted.sessionE164).isEqualTo("+15551234567")
assertThat(persisted.doNotAttemptRecoveryPassword).isEqualTo(true)
}
@Test
fun `toRegistrationFlowState reconstructs all fields`() {
val session = NetworkController.SessionMetadata(
id = "session-101",
nextSms = null,
nextCall = null,
nextVerificationAttempt = null,
allowedToRequestCode = true,
requestedInformation = emptyList(),
verified = true
)
val persisted = PersistedFlowState(
backStack = listOf(RegistrationRoute.Welcome, RegistrationRoute.PinCreate),
sessionMetadata = session,
sessionE164 = "+15551234567",
doNotAttemptRecoveryPassword = true
)
val aep = AccountEntropyPool.generate()
val masterKey = MasterKey(ByteArray(32))
val flowState = persisted.toRegistrationFlowState(
accountEntropyPool = aep,
temporaryMasterKey = masterKey,
preExistingRegistrationData = null
)
assertThat(flowState.backStack).isEqualTo(persisted.backStack)
assertThat(flowState.sessionMetadata).isEqualTo(session)
assertThat(flowState.sessionE164).isEqualTo("+15551234567")
assertThat(flowState.accountEntropyPool).isEqualTo(aep)
assertThat(flowState.temporaryMasterKey).isEqualTo(masterKey)
assertThat(flowState.preExistingRegistrationData).isNull()
assertThat(flowState.doNotAttemptRecoveryPassword).isEqualTo(true)
}
}

View File

@@ -6,6 +6,7 @@
package org.signal.registration package org.signal.registration
import android.app.Application import android.app.Application
import android.os.Looper
import androidx.compose.ui.test.assertIsDisplayed import androidx.compose.ui.test.assertIsDisplayed
import androidx.compose.ui.test.junit4.createComposeRule import androidx.compose.ui.test.junit4.createComposeRule
import androidx.compose.ui.test.onNodeWithTag import androidx.compose.ui.test.onNodeWithTag
@@ -13,12 +14,14 @@ import androidx.compose.ui.test.performClick
import androidx.lifecycle.SavedStateHandle import androidx.lifecycle.SavedStateHandle
import androidx.test.core.app.ApplicationProvider import androidx.test.core.app.ApplicationProvider
import com.google.accompanist.permissions.ExperimentalPermissionsApi import com.google.accompanist.permissions.ExperimentalPermissionsApi
import io.mockk.coEvery
import io.mockk.mockk import io.mockk.mockk
import org.junit.Before import org.junit.Before
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.junit.runner.RunWith import org.junit.runner.RunWith
import org.robolectric.RobolectricTestRunner import org.robolectric.RobolectricTestRunner
import org.robolectric.Shadows
import org.robolectric.annotation.Config import org.robolectric.annotation.Config
import org.signal.core.ui.CoreUiDependenciesRule import org.signal.core.ui.CoreUiDependenciesRule
import org.signal.core.ui.compose.theme.SignalTheme import org.signal.core.ui.compose.theme.SignalTheme
@@ -47,7 +50,11 @@ class RegistrationNavigationTest {
@Before @Before
fun setup() { fun setup() {
mockRepository = mockk<RegistrationRepository>(relaxed = true) mockRepository = mockk<RegistrationRepository>(relaxed = true)
coEvery { mockRepository.restoreFlowState() } returns null
coEvery { mockRepository.getPreExistingRegistrationData() } returns null
viewModel = RegistrationViewModel(mockRepository, SavedStateHandle()) viewModel = RegistrationViewModel(mockRepository, SavedStateHandle())
// Allow the init coroutine to complete so isRestoring becomes false.
Shadows.shadowOf(Looper.getMainLooper()).idle()
} }
@Test @Test
@@ -55,6 +62,11 @@ class RegistrationNavigationTest {
// Given // Given
val permissionsState = createMockPermissionsState() val permissionsState = createMockPermissionsState()
// Verify the ViewModel state is correctly initialized
val state = viewModel.state.value
assert(!state.isRestoringNavigationState) { "isRestoring should be false after init, was: ${state.isRestoringNavigationState}" }
assert(state.backStack == listOf(RegistrationRoute.Welcome)) { "backStack should be [Welcome], was: ${state.backStack}" }
composeTestRule.setContent { composeTestRule.setContent {
SignalTheme(incognitoKeyboardEnabled = false) { SignalTheme(incognitoKeyboardEnabled = false) {
RegistrationNavHost( RegistrationNavHost(
@@ -86,6 +98,7 @@ class RegistrationNavigationTest {
// When // When
composeTestRule.onNodeWithTag(TestTags.WELCOME_GET_STARTED_BUTTON).performClick() composeTestRule.onNodeWithTag(TestTags.WELCOME_GET_STARTED_BUTTON).performClick()
Shadows.shadowOf(Looper.getMainLooper()).idle()
// Then - verify Permissions screen is displayed // Then - verify Permissions screen is displayed
composeTestRule.onNodeWithTag(TestTags.PERMISSIONS_SCREEN).assertIsDisplayed() composeTestRule.onNodeWithTag(TestTags.PERMISSIONS_SCREEN).assertIsDisplayed()
@@ -108,9 +121,11 @@ class RegistrationNavigationTest {
// Navigate to Permissions screen first // Navigate to Permissions screen first
composeTestRule.onNodeWithTag(TestTags.WELCOME_GET_STARTED_BUTTON).performClick() composeTestRule.onNodeWithTag(TestTags.WELCOME_GET_STARTED_BUTTON).performClick()
Shadows.shadowOf(Looper.getMainLooper()).idle()
// When // When
composeTestRule.onNodeWithTag(TestTags.PERMISSIONS_NEXT_BUTTON).performClick() composeTestRule.onNodeWithTag(TestTags.PERMISSIONS_NEXT_BUTTON).performClick()
Shadows.shadowOf(Looper.getMainLooper()).idle()
// Then - verify PhoneNumber screen is displayed // Then - verify PhoneNumber screen is displayed
composeTestRule.onNodeWithTag(TestTags.PHONE_NUMBER_SCREEN).assertIsDisplayed() composeTestRule.onNodeWithTag(TestTags.PHONE_NUMBER_SCREEN).assertIsDisplayed()
@@ -133,9 +148,11 @@ class RegistrationNavigationTest {
// Navigate to Permissions screen first // Navigate to Permissions screen first
composeTestRule.onNodeWithTag(TestTags.WELCOME_GET_STARTED_BUTTON).performClick() composeTestRule.onNodeWithTag(TestTags.WELCOME_GET_STARTED_BUTTON).performClick()
Shadows.shadowOf(Looper.getMainLooper()).idle()
// When // When
composeTestRule.onNodeWithTag(TestTags.PERMISSIONS_NOT_NOW_BUTTON).performClick() composeTestRule.onNodeWithTag(TestTags.PERMISSIONS_NOT_NOW_BUTTON).performClick()
Shadows.shadowOf(Looper.getMainLooper()).idle()
// Then - verify PhoneNumber screen is displayed // Then - verify PhoneNumber screen is displayed
composeTestRule.onNodeWithTag(TestTags.PHONE_NUMBER_SCREEN).assertIsDisplayed() composeTestRule.onNodeWithTag(TestTags.PHONE_NUMBER_SCREEN).assertIsDisplayed()
@@ -164,6 +181,7 @@ class RegistrationNavigationTest {
// When // When
composeTestRule.onNodeWithTag(TestTags.WELCOME_RESTORE_OR_TRANSFER_BUTTON).performClick() composeTestRule.onNodeWithTag(TestTags.WELCOME_RESTORE_OR_TRANSFER_BUTTON).performClick()
composeTestRule.onNodeWithTag(TestTags.WELCOME_RESTORE_HAS_OLD_PHONE_BUTTON).performClick() composeTestRule.onNodeWithTag(TestTags.WELCOME_RESTORE_HAS_OLD_PHONE_BUTTON).performClick()
Shadows.shadowOf(Looper.getMainLooper()).idle()
// Then - verify Permissions screen is displayed // Then - verify Permissions screen is displayed
// (After permissions, user would go to RestoreViaQr screen) // (After permissions, user would go to RestoreViaQr screen)
@@ -188,6 +206,7 @@ class RegistrationNavigationTest {
// When // When
composeTestRule.onNodeWithTag(TestTags.WELCOME_RESTORE_OR_TRANSFER_BUTTON).performClick() composeTestRule.onNodeWithTag(TestTags.WELCOME_RESTORE_OR_TRANSFER_BUTTON).performClick()
composeTestRule.onNodeWithTag(TestTags.WELCOME_RESTORE_NO_OLD_PHONE_BUTTON).performClick() composeTestRule.onNodeWithTag(TestTags.WELCOME_RESTORE_NO_OLD_PHONE_BUTTON).performClick()
Shadows.shadowOf(Looper.getMainLooper()).idle()
// Then - verify Restore screen is displayed (or its expected content) // Then - verify Restore screen is displayed (or its expected content)
// Note: Update this assertion based on actual Restore screen content when implemented // Note: Update this assertion based on actual Restore screen content when implemented

View File

@@ -0,0 +1,251 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.registration
import androidx.lifecycle.SavedStateHandle
import assertk.assertThat
import assertk.assertions.isEqualTo
import assertk.assertions.isNull
import io.mockk.coEvery
import io.mockk.coVerify
import io.mockk.mockk
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.test.StandardTestDispatcher
import kotlinx.coroutines.test.advanceUntilIdle
import kotlinx.coroutines.test.resetMain
import kotlinx.coroutines.test.runTest
import kotlinx.coroutines.test.setMain
import org.junit.After
import org.junit.Before
import org.junit.Test
@OptIn(ExperimentalCoroutinesApi::class)
class RegistrationViewModelRestoreTest {
private val testDispatcher = StandardTestDispatcher()
private lateinit var mockRepository: RegistrationRepository
@Before
fun setup() {
Dispatchers.setMain(testDispatcher)
mockRepository = mockk(relaxed = true)
}
@After
fun tearDown() {
Dispatchers.resetMain()
}
@Test
fun `no saved state starts fresh and loads preExistingRegistrationData`() = runTest(testDispatcher) {
coEvery { mockRepository.restoreFlowState() } returns null
coEvery { mockRepository.getPreExistingRegistrationData() } returns null
val viewModel = RegistrationViewModel(mockRepository, SavedStateHandle())
advanceUntilIdle()
val state = viewModel.state.value
assertThat(state.backStack).isEqualTo(listOf(RegistrationRoute.Welcome))
assertThat(state.sessionMetadata).isNull()
}
@Test
fun `restore with valid session proceeds normally with updated session metadata`() = runTest(testDispatcher) {
val savedSession = createSessionMetadata("session-1")
val freshSession = createSessionMetadata("session-1").copy(nextSms = 9999L)
val savedState = RegistrationFlowState(
backStack = listOf(
RegistrationRoute.Welcome,
RegistrationRoute.Permissions(nextRoute = RegistrationRoute.PhoneNumberEntry),
RegistrationRoute.PhoneNumberEntry,
RegistrationRoute.VerificationCodeEntry
),
sessionMetadata = savedSession,
sessionE164 = "+15551234567"
)
coEvery { mockRepository.restoreFlowState() } returns savedState
coEvery { mockRepository.validateSession("session-1") } returns freshSession
val viewModel = RegistrationViewModel(mockRepository, SavedStateHandle())
advanceUntilIdle()
val state = viewModel.state.value
assertThat(state.backStack).isEqualTo(savedState.backStack)
assertThat(state.sessionMetadata).isEqualTo(freshSession)
assertThat(state.sessionE164).isEqualTo("+15551234567")
}
@Test
fun `restore with expired session and not registered resets to PhoneNumberEntry with e164 preserved`() = runTest(testDispatcher) {
val savedSession = createSessionMetadata("session-expired")
val savedState = RegistrationFlowState(
backStack = listOf(
RegistrationRoute.Welcome,
RegistrationRoute.Permissions(nextRoute = RegistrationRoute.PhoneNumberEntry),
RegistrationRoute.PhoneNumberEntry,
RegistrationRoute.VerificationCodeEntry
),
sessionMetadata = savedSession,
sessionE164 = "+15559876543"
)
coEvery { mockRepository.restoreFlowState() } returns savedState
coEvery { mockRepository.validateSession("session-expired") } returns null
coEvery { mockRepository.isRegistered() } returns false
val viewModel = RegistrationViewModel(mockRepository, SavedStateHandle())
advanceUntilIdle()
val state = viewModel.state.value
assertThat(state.backStack).isEqualTo(
listOf(
RegistrationRoute.Welcome,
RegistrationRoute.Permissions(nextRoute = RegistrationRoute.PhoneNumberEntry),
RegistrationRoute.PhoneNumberEntry
)
)
assertThat(state.sessionMetadata).isNull()
assertThat(state.sessionE164).isEqualTo("+15559876543")
}
@Test
fun `restore with expired session and already registered proceeds with null session`() = runTest(testDispatcher) {
val savedSession = createSessionMetadata("session-expired-2")
val savedState = RegistrationFlowState(
backStack = listOf(
RegistrationRoute.Welcome,
RegistrationRoute.PinCreate
),
sessionMetadata = savedSession,
sessionE164 = "+15551234567"
)
coEvery { mockRepository.restoreFlowState() } returns savedState
coEvery { mockRepository.validateSession("session-expired-2") } returns null
coEvery { mockRepository.isRegistered() } returns true
val viewModel = RegistrationViewModel(mockRepository, SavedStateHandle())
advanceUntilIdle()
val state = viewModel.state.value
assertThat(state.backStack).isEqualTo(listOf(RegistrationRoute.Welcome, RegistrationRoute.PinCreate))
assertThat(state.sessionMetadata).isNull()
assertThat(state.sessionE164).isEqualTo("+15551234567")
}
@Test
fun `restore with no session skips validation`() = runTest(testDispatcher) {
val savedState = RegistrationFlowState(
backStack = listOf(
RegistrationRoute.Welcome,
RegistrationRoute.PinCreate,
RegistrationRoute.ArchiveRestoreSelection
),
sessionMetadata = null,
sessionE164 = "+15551234567",
doNotAttemptRecoveryPassword = true
)
coEvery { mockRepository.restoreFlowState() } returns savedState
val viewModel = RegistrationViewModel(mockRepository, SavedStateHandle())
advanceUntilIdle()
val state = viewModel.state.value
assertThat(state.backStack).isEqualTo(savedState.backStack)
assertThat(state.sessionMetadata).isNull()
assertThat(state.doNotAttemptRecoveryPassword).isEqualTo(true)
coVerify(exactly = 0) { mockRepository.validateSession(any()) }
}
@Test
fun `onEvent ResetState clears flow state`() = runTest(testDispatcher) {
coEvery { mockRepository.restoreFlowState() } returns null
coEvery { mockRepository.getPreExistingRegistrationData() } returns null
val viewModel = RegistrationViewModel(mockRepository, SavedStateHandle())
advanceUntilIdle()
viewModel.onEvent(RegistrationFlowEvent.ResetState)
advanceUntilIdle()
coVerify { mockRepository.clearFlowState() }
}
@Test
fun `onEvent NavigateToScreen FullyComplete clears flow state`() = runTest(testDispatcher) {
coEvery { mockRepository.restoreFlowState() } returns null
coEvery { mockRepository.getPreExistingRegistrationData() } returns null
val viewModel = RegistrationViewModel(mockRepository, SavedStateHandle())
advanceUntilIdle()
viewModel.onEvent(RegistrationFlowEvent.NavigateToScreen(RegistrationRoute.FullyComplete))
advanceUntilIdle()
coVerify { mockRepository.clearFlowState() }
}
@Test
fun `onEvent NavigateToScreen saves flow state`() = runTest(testDispatcher) {
coEvery { mockRepository.restoreFlowState() } returns null
coEvery { mockRepository.getPreExistingRegistrationData() } returns null
val viewModel = RegistrationViewModel(mockRepository, SavedStateHandle())
advanceUntilIdle()
viewModel.onEvent(RegistrationFlowEvent.NavigateToScreen(RegistrationRoute.PhoneNumberEntry))
advanceUntilIdle()
coVerify { mockRepository.saveFlowState(any()) }
}
@Test
fun `onEvent Registered does not save flow state`() = runTest(testDispatcher) {
coEvery { mockRepository.restoreFlowState() } returns null
coEvery { mockRepository.getPreExistingRegistrationData() } returns null
val viewModel = RegistrationViewModel(mockRepository, SavedStateHandle())
advanceUntilIdle()
viewModel.onEvent(RegistrationFlowEvent.Registered(org.signal.core.models.AccountEntropyPool.generate()))
advanceUntilIdle()
coVerify(exactly = 0) { mockRepository.saveFlowState(any()) }
}
@Test
fun `onEvent MasterKeyRestoredFromSvr does not save flow state`() = runTest(testDispatcher) {
coEvery { mockRepository.restoreFlowState() } returns null
coEvery { mockRepository.getPreExistingRegistrationData() } returns null
val viewModel = RegistrationViewModel(mockRepository, SavedStateHandle())
advanceUntilIdle()
viewModel.onEvent(RegistrationFlowEvent.MasterKeyRestoredFromSvr(org.signal.core.models.MasterKey(ByteArray(32))))
advanceUntilIdle()
coVerify(exactly = 0) { mockRepository.saveFlowState(any()) }
}
private fun createSessionMetadata(id: String = "test-session"): NetworkController.SessionMetadata {
return NetworkController.SessionMetadata(
id = id,
nextSms = 1000L,
nextCall = 2000L,
nextVerificationAttempt = 3000L,
allowedToRequestCode = true,
requestedInformation = emptyList(),
verified = false
)
}
}

View File

@@ -292,8 +292,10 @@ class PhoneNumberEntryViewModelTest {
assertThat(emittedStates.last().showSpinner).isFalse() assertThat(emittedStates.last().showSpinner).isFalse()
assertThat(emittedStates.last().sessionMetadata).isNotNull() assertThat(emittedStates.last().sessionMetadata).isNotNull()
assertThat(emittedEvents).hasSize(1) assertThat(emittedEvents).hasSize(3)
assertThat(emittedEvents.first()) assertThat(emittedEvents[0]).isInstanceOf<RegistrationFlowEvent.SessionUpdated>()
assertThat(emittedEvents[1]).isInstanceOf<RegistrationFlowEvent.E164Chosen>()
assertThat(emittedEvents[2])
.isInstanceOf<RegistrationFlowEvent.NavigateToScreen>() .isInstanceOf<RegistrationFlowEvent.NavigateToScreen>()
.prop(RegistrationFlowEvent.NavigateToScreen::route) .prop(RegistrationFlowEvent.NavigateToScreen::route)
.isInstanceOf<RegistrationRoute.VerificationCodeEntry>() .isInstanceOf<RegistrationRoute.VerificationCodeEntry>()
@@ -426,8 +428,10 @@ class PhoneNumberEntryViewModelTest {
assertThat(emittedStates.last().showSpinner).isFalse() assertThat(emittedStates.last().showSpinner).isFalse()
// Should not create a new session, just request verification code // Should not create a new session, just request verification code
assertThat(emittedEvents).hasSize(1) assertThat(emittedEvents).hasSize(3)
assertThat(emittedEvents.first()) assertThat(emittedEvents[0]).isInstanceOf<RegistrationFlowEvent.SessionUpdated>()
assertThat(emittedEvents[1]).isInstanceOf<RegistrationFlowEvent.E164Chosen>()
assertThat(emittedEvents[2])
.isInstanceOf<RegistrationFlowEvent.NavigateToScreen>() .isInstanceOf<RegistrationFlowEvent.NavigateToScreen>()
.prop(RegistrationFlowEvent.NavigateToScreen::route) .prop(RegistrationFlowEvent.NavigateToScreen::route)
.isInstanceOf<RegistrationRoute.VerificationCodeEntry>() .isInstanceOf<RegistrationRoute.VerificationCodeEntry>()
@@ -563,8 +567,10 @@ class PhoneNumberEntryViewModelTest {
assertThat(emittedStates.last().showSpinner).isFalse() assertThat(emittedStates.last().showSpinner).isFalse()
// Verify navigation to verification code entry // Verify navigation to verification code entry
assertThat(emittedEvents).hasSize(1) assertThat(emittedEvents).hasSize(3)
assertThat(emittedEvents.first()) assertThat(emittedEvents[0]).isInstanceOf<RegistrationFlowEvent.SessionUpdated>()
assertThat(emittedEvents[1]).isInstanceOf<RegistrationFlowEvent.E164Chosen>()
assertThat(emittedEvents[2])
.isInstanceOf<RegistrationFlowEvent.NavigateToScreen>() .isInstanceOf<RegistrationFlowEvent.NavigateToScreen>()
.prop(RegistrationFlowEvent.NavigateToScreen::route) .prop(RegistrationFlowEvent.NavigateToScreen::route)
.isInstanceOf<RegistrationRoute.VerificationCodeEntry>() .isInstanceOf<RegistrationRoute.VerificationCodeEntry>()
@@ -595,8 +601,10 @@ class PhoneNumberEntryViewModelTest {
assertThat(emittedStates.last().showSpinner).isFalse() assertThat(emittedStates.last().showSpinner).isFalse()
// Verify navigation continues despite no push challenge token // Verify navigation continues despite no push challenge token
assertThat(emittedEvents).hasSize(1) assertThat(emittedEvents).hasSize(3)
assertThat(emittedEvents.first()) assertThat(emittedEvents[0]).isInstanceOf<RegistrationFlowEvent.SessionUpdated>()
assertThat(emittedEvents[1]).isInstanceOf<RegistrationFlowEvent.E164Chosen>()
assertThat(emittedEvents[2])
.isInstanceOf<RegistrationFlowEvent.NavigateToScreen>() .isInstanceOf<RegistrationFlowEvent.NavigateToScreen>()
.prop(RegistrationFlowEvent.NavigateToScreen::route) .prop(RegistrationFlowEvent.NavigateToScreen::route)
.isInstanceOf<RegistrationRoute.VerificationCodeEntry>() .isInstanceOf<RegistrationRoute.VerificationCodeEntry>()
@@ -631,8 +639,10 @@ class PhoneNumberEntryViewModelTest {
assertThat(emittedStates.last().showSpinner).isFalse() assertThat(emittedStates.last().showSpinner).isFalse()
// Verify navigation continues despite push challenge submission failure // Verify navigation continues despite push challenge submission failure
assertThat(emittedEvents).hasSize(1) assertThat(emittedEvents).hasSize(3)
assertThat(emittedEvents.first()) assertThat(emittedEvents[0]).isInstanceOf<RegistrationFlowEvent.SessionUpdated>()
assertThat(emittedEvents[1]).isInstanceOf<RegistrationFlowEvent.E164Chosen>()
assertThat(emittedEvents[2])
.isInstanceOf<RegistrationFlowEvent.NavigateToScreen>() .isInstanceOf<RegistrationFlowEvent.NavigateToScreen>()
.prop(RegistrationFlowEvent.NavigateToScreen::route) .prop(RegistrationFlowEvent.NavigateToScreen::route)
.isInstanceOf<RegistrationRoute.VerificationCodeEntry>() .isInstanceOf<RegistrationRoute.VerificationCodeEntry>()
@@ -662,8 +672,10 @@ class PhoneNumberEntryViewModelTest {
assertThat(emittedStates.last().showSpinner).isFalse() assertThat(emittedStates.last().showSpinner).isFalse()
// Verify navigation continues despite network error // Verify navigation continues despite network error
assertThat(emittedEvents).hasSize(1) assertThat(emittedEvents).hasSize(3)
assertThat(emittedEvents.first()) assertThat(emittedEvents[0]).isInstanceOf<RegistrationFlowEvent.SessionUpdated>()
assertThat(emittedEvents[1]).isInstanceOf<RegistrationFlowEvent.E164Chosen>()
assertThat(emittedEvents[2])
.isInstanceOf<RegistrationFlowEvent.NavigateToScreen>() .isInstanceOf<RegistrationFlowEvent.NavigateToScreen>()
.prop(RegistrationFlowEvent.NavigateToScreen::route) .prop(RegistrationFlowEvent.NavigateToScreen::route)
.isInstanceOf<RegistrationRoute.VerificationCodeEntry>() .isInstanceOf<RegistrationRoute.VerificationCodeEntry>()
@@ -693,8 +705,10 @@ class PhoneNumberEntryViewModelTest {
assertThat(emittedStates.last().showSpinner).isFalse() assertThat(emittedStates.last().showSpinner).isFalse()
// Verify navigation continues despite application error // Verify navigation continues despite application error
assertThat(emittedEvents).hasSize(1) assertThat(emittedEvents).hasSize(3)
assertThat(emittedEvents.first()) assertThat(emittedEvents[0]).isInstanceOf<RegistrationFlowEvent.SessionUpdated>()
assertThat(emittedEvents[1]).isInstanceOf<RegistrationFlowEvent.E164Chosen>()
assertThat(emittedEvents[2])
.isInstanceOf<RegistrationFlowEvent.NavigateToScreen>() .isInstanceOf<RegistrationFlowEvent.NavigateToScreen>()
.prop(RegistrationFlowEvent.NavigateToScreen::route) .prop(RegistrationFlowEvent.NavigateToScreen::route)
.isInstanceOf<RegistrationRoute.VerificationCodeEntry>() .isInstanceOf<RegistrationRoute.VerificationCodeEntry>()
@@ -744,8 +758,10 @@ class PhoneNumberEntryViewModelTest {
viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.CaptchaCompleted("captcha-token"), stateEmitter, parentEventEmitter) viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.CaptchaCompleted("captcha-token"), stateEmitter, parentEventEmitter)
assertThat(emittedEvents).hasSize(1) assertThat(emittedEvents).hasSize(3)
assertThat(emittedEvents.first()) assertThat(emittedEvents[0]).isInstanceOf<RegistrationFlowEvent.SessionUpdated>()
assertThat(emittedEvents[1]).isInstanceOf<RegistrationFlowEvent.E164Chosen>()
assertThat(emittedEvents[2])
.isInstanceOf<RegistrationFlowEvent.NavigateToScreen>() .isInstanceOf<RegistrationFlowEvent.NavigateToScreen>()
.prop(RegistrationFlowEvent.NavigateToScreen::route) .prop(RegistrationFlowEvent.NavigateToScreen::route)
.isInstanceOf<RegistrationRoute.VerificationCodeEntry>() .isInstanceOf<RegistrationRoute.VerificationCodeEntry>()