Make regV5 resumable if the app closes.

This commit is contained in:
Greyson Parrelli
2026-03-18 16:53:50 -04:00
committed by Michelle Tang
parent c7ec3ab837
commit f09bf5b14c
15 changed files with 895 additions and 41 deletions

View File

@@ -0,0 +1,60 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.registration
import kotlinx.serialization.Serializable
import org.signal.core.models.AccountEntropyPool
import org.signal.core.models.MasterKey
/**
* A serializable snapshot of [RegistrationFlowState] fields that need to survive app kills.
*
* Fields like [RegistrationFlowState.accountEntropyPool] and [RegistrationFlowState.temporaryMasterKey]
* are reconstructed from dedicated proto fields, not from this JSON snapshot.
* [RegistrationFlowState.preExistingRegistrationData] is loaded from permanent storage.
*/
@Serializable
data class PersistedFlowState(
val backStack: List<RegistrationRoute>,
val sessionMetadata: NetworkController.SessionMetadata?,
val sessionE164: String?,
val doNotAttemptRecoveryPassword: Boolean
)
/**
* Extracts the persistable fields from a [RegistrationFlowState].
*/
fun RegistrationFlowState.toPersistedFlowState(): PersistedFlowState {
return PersistedFlowState(
backStack = backStack,
sessionMetadata = sessionMetadata,
sessionE164 = sessionE164,
doNotAttemptRecoveryPassword = doNotAttemptRecoveryPassword
)
}
/**
* Reconstructs a full [RegistrationFlowState] from persisted data and separately-stored fields.
*
* @param accountEntropyPool Restored from the proto's dedicated `accountEntropyPool` field.
* @param temporaryMasterKey Restored from the proto's dedicated `temporaryMasterKey` field.
* @param preExistingRegistrationData Loaded from permanent storage via [StorageController.getPreExistingRegistrationData].
*/
fun PersistedFlowState.toRegistrationFlowState(
accountEntropyPool: AccountEntropyPool?,
temporaryMasterKey: MasterKey?,
preExistingRegistrationData: PreExistingRegistrationData?
): RegistrationFlowState {
return RegistrationFlowState(
backStack = backStack,
sessionMetadata = sessionMetadata,
sessionE164 = sessionE164,
accountEntropyPool = accountEntropyPool,
temporaryMasterKey = temporaryMasterKey,
preExistingRegistrationData = preExistingRegistrationData,
doNotAttemptRecoveryPassword = doNotAttemptRecoveryPassword
)
}

View File

@@ -37,5 +37,8 @@ data class RegistrationFlowState(
val preExistingRegistrationData: PreExistingRegistrationData? = null,
/** If true, do not attempt any flows where we generate RRP's. Create a session instead. */
val doNotAttemptRecoveryPassword: Boolean = false
val doNotAttemptRecoveryPassword: Boolean = false,
/** If true, the ViewModel is still deciding whether to restore a previous flow or start fresh. */
val isRestoringNavigationState: Boolean = true
) : Parcelable, DebugLoggable

View File

@@ -8,9 +8,13 @@
package org.signal.registration
import android.os.Parcelable
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.material3.CircularProgressIndicator
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.getValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.lifecycle.compose.collectAsStateWithLifecycle
import androidx.lifecycle.viewmodel.compose.viewModel
@@ -62,6 +66,7 @@ import org.signal.registration.screens.welcome.WelcomeScreenEvents
* Navigation routes for the registration flow.
* Using @Serializable and NavKey for type-safe navigation with Navigation 3.
*/
@Serializable
@Parcelize
sealed interface RegistrationRoute : NavKey, Parcelable {
@Serializable
@@ -77,7 +82,7 @@ sealed interface RegistrationRoute : NavKey, Parcelable {
data object CountryCodePicker : RegistrationRoute
@Serializable
data class VerificationCodeEntry(val session: NetworkController.SessionMetadata, val e164: String) : RegistrationRoute
data object VerificationCodeEntry : RegistrationRoute
@Serializable
data class Captcha(val session: NetworkController.SessionMetadata) : RegistrationRoute
@@ -150,6 +155,13 @@ fun RegistrationNavHost(
val registrationState by viewModel.state.collectAsStateWithLifecycle()
val permissions: MultiplePermissionsState = permissionsState ?: rememberMultiplePermissionsState(viewModel.getRequiredPermissions())
if (registrationState.isRestoringNavigationState) {
Box(modifier = modifier.fillMaxSize(), contentAlignment = Alignment.Center) {
CircularProgressIndicator()
}
return
}
val entryProvider = entryProvider {
navigationEntries(
registrationRepository = registrationRepository,

View File

@@ -10,6 +10,7 @@ import android.content.Context
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.withContext
import kotlinx.serialization.json.Json
import okio.ByteString.Companion.toByteString
import org.signal.core.models.AccountEntropyPool
import org.signal.core.models.MasterKey
@@ -410,6 +411,81 @@ class RegistrationRepository(val context: Context, val networkController: Networ
return storageController.getPreExistingRegistrationData()
}
/**
* Persists the current flow state as JSON in the in-progress registration data proto.
*/
suspend fun saveFlowState(state: RegistrationFlowState) = withContext(Dispatchers.IO) {
try {
val json = flowStateJson.encodeToString(PersistedFlowState.serializer(), state.toPersistedFlowState())
storageController.updateInProgressRegistrationData {
flowStateJson = json
}
} catch (e: Exception) {
Log.w(TAG, "Failed to save flow state", e)
}
}
/**
* Restores the flow state from disk. Returns null if no state is saved or deserialization fails.
* Reconstructs [RegistrationFlowState.accountEntropyPool] and [RegistrationFlowState.temporaryMasterKey]
* from their dedicated proto fields, and loads [RegistrationFlowState.preExistingRegistrationData]
* from permanent storage.
*/
suspend fun restoreFlowState(): RegistrationFlowState? = withContext(Dispatchers.IO) {
try {
val data = storageController.readInProgressRegistrationData()
if (data.flowStateJson.isEmpty()) return@withContext null
val persisted = flowStateJson.decodeFromString(PersistedFlowState.serializer(), data.flowStateJson)
val aep = data.accountEntropyPool.takeIf { it.isNotEmpty() }?.let { AccountEntropyPool(it) }
val masterKey = data.temporaryMasterKey.takeIf { it.size > 0 }?.let { MasterKey(it.toByteArray()) }
val preExisting = storageController.getPreExistingRegistrationData()
persisted.toRegistrationFlowState(
accountEntropyPool = aep,
temporaryMasterKey = masterKey,
preExistingRegistrationData = preExisting
)
} catch (e: Exception) {
Log.w(TAG, "Failed to restore flow state", e)
null
}
}
/**
* Clears any persisted flow state JSON from the in-progress registration data.
*/
suspend fun clearFlowState() = withContext(Dispatchers.IO) {
try {
storageController.updateInProgressRegistrationData {
flowStateJson = ""
}
} catch (e: Exception) {
Log.w(TAG, "Failed to clear flow state", e)
}
}
/**
* Validates a registration session by fetching its current status from the server.
* Returns fresh [SessionMetadata] on success, or null if the session is expired/invalid.
*/
suspend fun validateSession(sessionId: String): SessionMetadata? = withContext(Dispatchers.IO) {
when (val result = networkController.getSession(sessionId)) {
is RegistrationNetworkResult.Success -> result.data
else -> null
}
}
/**
* Checks whether the in-progress registration data indicates a completed registration
* (i.e. both ACI and PNI have been saved).
*/
suspend fun isRegistered(): Boolean = withContext(Dispatchers.IO) {
val data = storageController.readInProgressRegistrationData()
data.aci.isNotEmpty() && data.pni.isNotEmpty()
}
private fun generateKeyMaterial(
existingAccountEntropyPool: AccountEntropyPool? = null,
existingAciIdentityKeyPair: IdentityKeyPair? = null,
@@ -488,5 +564,6 @@ class RegistrationRepository(val context: Context, val networkController: Networ
companion object {
private val TAG = Log.tag(RegistrationRepository::class)
private val flowStateJson = Json { ignoreUnknownKeys = true }
}
}

View File

@@ -13,6 +13,7 @@ import androidx.lifecycle.ViewModelProvider
import androidx.lifecycle.createSavedStateHandle
import androidx.lifecycle.viewModelScope
import androidx.lifecycle.viewmodel.CreationExtras
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
@@ -37,9 +38,17 @@ class RegistrationViewModel(private val repository: RegistrationRepository, save
val resultBus = ResultEventBus()
init {
_state.value = _state.value.copy(isRestoringNavigationState = true)
viewModelScope.launch {
repository.getPreExistingRegistrationData()?.let {
_state.value = _state.value.copy(preExistingRegistrationData = it)
val restored = repository.restoreFlowState()
if (restored != null) {
Log.i(TAG, "[init] Restored flow state from disk. Backstack size: ${restored.backStack.size}, hasSession: ${restored.sessionMetadata != null}")
_state.value = validateRestoredState(restored).copy(isRestoringNavigationState = false)
} else {
_state.value = _state.value.copy(
preExistingRegistrationData = repository.getPreExistingRegistrationData(),
isRestoringNavigationState = false
)
}
}
}
@@ -47,11 +56,15 @@ class RegistrationViewModel(private val repository: RegistrationRepository, save
fun onEvent(event: RegistrationFlowEvent) {
Log.d(TAG, "[Event] $event")
_state.value = applyEvent(_state.value, event)
viewModelScope.launch(Dispatchers.IO) {
persistFlowState(event)
}
}
fun applyEvent(state: RegistrationFlowState, event: RegistrationFlowEvent): RegistrationFlowState {
return when (event) {
is RegistrationFlowEvent.ResetState -> RegistrationFlowState()
is RegistrationFlowEvent.ResetState -> RegistrationFlowState(isRestoringNavigationState = false)
is RegistrationFlowEvent.SessionUpdated -> state.copy(sessionMetadata = event.session)
is RegistrationFlowEvent.E164Chosen -> state.copy(sessionE164 = event.e164)
is RegistrationFlowEvent.Registered -> state.copy(accountEntropyPool = event.accountEntropyPool)
@@ -63,14 +76,43 @@ class RegistrationViewModel(private val repository: RegistrationRepository, save
}
private fun applyNavigationToScreenEvent(inputState: RegistrationFlowState, event: RegistrationFlowEvent.NavigateToScreen): RegistrationFlowState {
val state = inputState.copy(backStack = inputState.backStack + event.route)
return inputState.copy(backStack = inputState.backStack + event.route)
}
return when (event.route) {
is RegistrationRoute.VerificationCodeEntry -> {
state.copy(sessionMetadata = event.route.session, sessionE164 = event.route.e164)
}
else -> state
/**
* Validates a restored flow state by checking if the session is still valid.
*
* - If the session is still valid, updates session metadata with fresh data.
* - If the session is expired and the user is already registered, nulls out the session
* (post-registration screens like PinCreate don't need a session).
* - If the session is expired and the user is NOT registered, resets the backstack to
* PhoneNumberEntry with the phone number pre-filled so the user can re-submit.
*/
private suspend fun validateRestoredState(state: RegistrationFlowState): RegistrationFlowState {
val sessionMetadata = state.sessionMetadata ?: return state
val freshSession = repository.validateSession(sessionMetadata.id)
if (freshSession != null) {
Log.i(TAG, "[validateRestoredState] Session still valid.")
return state.copy(sessionMetadata = freshSession)
}
Log.i(TAG, "[validateRestoredState] Session expired/invalid.")
if (repository.isRegistered()) {
Log.i(TAG, "[validateRestoredState] User is registered, proceeding without session.")
return state.copy(sessionMetadata = null)
}
Log.i(TAG, "[validateRestoredState] User is NOT registered, resetting to PhoneNumberEntry.")
return state.copy(
backStack = listOf(
RegistrationRoute.Welcome,
RegistrationRoute.Permissions(nextRoute = RegistrationRoute.PhoneNumberEntry),
RegistrationRoute.PhoneNumberEntry
),
sessionMetadata = null
)
}
/**
@@ -101,6 +143,27 @@ class RegistrationViewModel(private val repository: RegistrationRepository, save
}
}
private suspend fun persistFlowState(event: RegistrationFlowEvent) {
when (event) {
is RegistrationFlowEvent.ResetState -> repository.clearFlowState()
is RegistrationFlowEvent.NavigateToScreen -> {
if (event.route is RegistrationRoute.FullyComplete) {
repository.clearFlowState()
} else {
repository.saveFlowState(_state.value)
}
}
is RegistrationFlowEvent.NavigateBack,
is RegistrationFlowEvent.SessionUpdated,
is RegistrationFlowEvent.E164Chosen,
is RegistrationFlowEvent.RecoveryPasswordInvalid -> repository.saveFlowState(_state.value)
// No need to persist anything new, fields accounted for in proto already
is RegistrationFlowEvent.Registered,
is RegistrationFlowEvent.MasterKeyRestoredFromSvr -> { }
}
}
class Factory(private val repository: RegistrationRepository) : ViewModelProvider.Factory {
override fun <T : ViewModel> create(modelClass: KClass<T>, extras: CreationExtras): T {
return RegistrationViewModel(repository, extras.createSavedStateHandle()) as T

View File

@@ -276,14 +276,17 @@ class PhoneNumberEntryViewModel(
is NetworkController.RegistrationNetworkResult.Failure<NetworkController.CreateSessionError> -> {
return when (response.error) {
is NetworkController.CreateSessionError.InvalidRequest -> {
Log.w(TAG, "[CreateSession] Invalid request when creating session. Message: ${response.error.message}")
state.copy(oneTimeEvent = OneTimeEvent.UnknownError)
}
is NetworkController.CreateSessionError.RateLimited -> {
Log.w(TAG, "[CreateSession] Rate limited (retryAfter: ${response.error.retryAfter}).")
state.copy(oneTimeEvent = OneTimeEvent.RateLimited(response.error.retryAfter))
}
}
}
is NetworkController.RegistrationNetworkResult.NetworkError -> {
Log.w(TAG, "[CreateSession] Network error.", response.exception)
return state.copy(oneTimeEvent = OneTimeEvent.NetworkError)
}
is NetworkController.RegistrationNetworkResult.ApplicationError -> {
@@ -304,23 +307,26 @@ class PhoneNumberEntryViewModel(
Log.d(TAG, "Received push challenge token, submitting...")
val updateResult = repository.submitPushChallengeToken(sessionMetadata.id, pushChallengeToken)
sessionMetadata = when (updateResult) {
is NetworkController.RegistrationNetworkResult.Success -> updateResult.data
is NetworkController.RegistrationNetworkResult.Success -> {
Log.d(TAG, "[SubmitPushChallengeToken] Successfully submitted push challenge token.")
updateResult.data
}
is NetworkController.RegistrationNetworkResult.Failure -> {
Log.w(TAG, "Failed to submit push challenge token: ${updateResult.error}")
Log.w(TAG, "[SubmitPushChallengeToken] Failed to submit push challenge token: ${updateResult.error}")
sessionMetadata
}
is NetworkController.RegistrationNetworkResult.NetworkError -> {
Log.w(TAG, "Network error submitting push challenge token", updateResult.exception)
Log.w(TAG, "[SubmitPushChallengeToken] Network error submitting push challenge token", updateResult.exception)
sessionMetadata
}
is NetworkController.RegistrationNetworkResult.ApplicationError -> {
Log.w(TAG, "Application error submitting push challenge token", updateResult.exception)
Log.w(TAG, "[SubmitPushChallengeToken] Application error submitting push challenge token", updateResult.exception)
sessionMetadata
}
}
state = state.copy(sessionMetadata = sessionMetadata)
} else {
Log.d(TAG, "Push challenge token not received within timeout")
Log.d(TAG, "[SubmitPushChallengeToken] Push challenge token not received within timeout")
}
}
@@ -337,41 +343,49 @@ class PhoneNumberEntryViewModel(
sessionMetadata = when (verificationCodeResponse) {
is NetworkController.RegistrationNetworkResult.Success<NetworkController.SessionMetadata> -> {
Log.d(TAG, "[RequestVerificationCode] Successfully requested verification code.")
verificationCodeResponse.data
}
is NetworkController.RegistrationNetworkResult.Failure<NetworkController.RequestVerificationCodeError> -> {
return when (verificationCodeResponse.error) {
is NetworkController.RequestVerificationCodeError.InvalidRequest -> {
Log.w(TAG, "[RequestVerificationCode] Invalid request when requesting verification code. Message: ${verificationCodeResponse.error.message}")
state.copy(oneTimeEvent = OneTimeEvent.UnknownError)
}
is NetworkController.RequestVerificationCodeError.RateLimited -> {
Log.w(TAG, "[RequestVerificationCode] Rate limited (retryAfter: ${verificationCodeResponse.error.retryAfter}).")
state.copy(oneTimeEvent = OneTimeEvent.RateLimited(verificationCodeResponse.error.retryAfter))
}
is NetworkController.RequestVerificationCodeError.CouldNotFulfillWithRequestedTransport -> {
Log.w(TAG, "[RequestVerificationCode] Could not fulfill with requested transport.")
state.copy(oneTimeEvent = OneTimeEvent.CouldNotRequestCodeWithSelectedTransport)
}
is NetworkController.RequestVerificationCodeError.InvalidSessionId -> {
Log.w(TAG, "[RequestVerificationCode] Invalid session ID when requesting verification code.")
parentEventEmitter(RegistrationFlowEvent.ResetState)
state
}
is NetworkController.RequestVerificationCodeError.MissingRequestInformationOrAlreadyVerified -> {
Log.w(TAG, "When requesting verification code, missing request information or already verified.")
Log.w(TAG, "[RequestVerificationCode] Missing request information or already verified.")
state.copy(oneTimeEvent = OneTimeEvent.NetworkError)
}
is NetworkController.RequestVerificationCodeError.SessionNotFound -> {
Log.w(TAG, "[RequestVerificationCode] Session not found when requesting verification code.")
parentEventEmitter(RegistrationFlowEvent.ResetState)
state
}
is NetworkController.RequestVerificationCodeError.ThirdPartyServiceError -> {
Log.w(TAG, "[RequestVerificationCode] Third party service error.")
state.copy(oneTimeEvent = OneTimeEvent.ThirdPartyError)
}
}
}
is NetworkController.RegistrationNetworkResult.NetworkError -> {
Log.w(TAG, "[RequestVerificationCode] Network error.", verificationCodeResponse.exception)
return state.copy(oneTimeEvent = OneTimeEvent.NetworkError)
}
is NetworkController.RegistrationNetworkResult.ApplicationError -> {
Log.w(TAG, "Unknown error when creating session.", verificationCodeResponse.exception)
Log.w(TAG, "[RequestVerificationCode] Unknown error when creating session.", verificationCodeResponse.exception)
return state.copy(oneTimeEvent = OneTimeEvent.UnknownError)
}
}
@@ -383,7 +397,9 @@ class PhoneNumberEntryViewModel(
return state
}
parentEventEmitter.navigateTo(RegistrationRoute.VerificationCodeEntry(sessionMetadata, e164))
parentEventEmitter(RegistrationFlowEvent.SessionUpdated(sessionMetadata))
parentEventEmitter(RegistrationFlowEvent.E164Chosen(e164))
parentEventEmitter.navigateTo(RegistrationRoute.VerificationCodeEntry)
return state
}
@@ -470,9 +486,9 @@ class PhoneNumberEntryViewModel(
}
}
val e164 = "+${inputState.countryCode}${inputState.nationalNumber}"
parentEventEmitter.navigateTo(RegistrationRoute.VerificationCodeEntry(sessionMetadata, e164))
parentEventEmitter(RegistrationFlowEvent.SessionUpdated(sessionMetadata))
parentEventEmitter(RegistrationFlowEvent.E164Chosen("+${inputState.countryCode}${inputState.nationalNumber}"))
parentEventEmitter.navigateTo(RegistrationRoute.VerificationCodeEntry)
return state
}

View File

@@ -39,6 +39,9 @@ message RegistrationData {
// Provisioning data (from saveProvisioningData)
ProvisioningData provisioningData = 20;
// JSON-serialized flow state snapshot (from saveFlowState/restoreFlowState)
string flowStateJson = 21;
}
message SvrCredential {

View File

@@ -0,0 +1,225 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.registration
import assertk.assertThat
import assertk.assertions.isEqualTo
import assertk.assertions.isNull
import kotlinx.serialization.json.Json
import org.junit.Test
import org.signal.core.models.AccountEntropyPool
import org.signal.core.models.MasterKey
class PersistedFlowStateTest {
private val json = Json { ignoreUnknownKeys = true }
@Test
fun `round-trip serialization with simple backstack`() {
val state = PersistedFlowState(
backStack = listOf(RegistrationRoute.Welcome, RegistrationRoute.PhoneNumberEntry),
sessionMetadata = null,
sessionE164 = null,
doNotAttemptRecoveryPassword = false
)
val encoded = json.encodeToString(PersistedFlowState.serializer(), state)
val decoded = json.decodeFromString(PersistedFlowState.serializer(), encoded)
assertThat(decoded).isEqualTo(state)
}
@Test
fun `round-trip serialization with nested Permissions route`() {
val state = PersistedFlowState(
backStack = listOf(
RegistrationRoute.Welcome,
RegistrationRoute.Permissions(nextRoute = RegistrationRoute.PhoneNumberEntry),
RegistrationRoute.PhoneNumberEntry
),
sessionMetadata = null,
sessionE164 = "+15551234567",
doNotAttemptRecoveryPassword = false
)
val encoded = json.encodeToString(PersistedFlowState.serializer(), state)
val decoded = json.decodeFromString(PersistedFlowState.serializer(), encoded)
assertThat(decoded).isEqualTo(state)
}
@Test
fun `round-trip serialization with VerificationCodeEntry`() {
val session = NetworkController.SessionMetadata(
id = "session-123",
nextSms = 1000L,
nextCall = 2000L,
nextVerificationAttempt = 3000L,
allowedToRequestCode = true,
requestedInformation = listOf("pushChallenge"),
verified = false
)
val state = PersistedFlowState(
backStack = listOf(
RegistrationRoute.Welcome,
RegistrationRoute.Permissions(nextRoute = RegistrationRoute.PhoneNumberEntry),
RegistrationRoute.PhoneNumberEntry,
RegistrationRoute.VerificationCodeEntry
),
sessionMetadata = session,
sessionE164 = "+15551234567",
doNotAttemptRecoveryPassword = false
)
val encoded = json.encodeToString(PersistedFlowState.serializer(), state)
val decoded = json.decodeFromString(PersistedFlowState.serializer(), encoded)
assertThat(decoded).isEqualTo(state)
}
@Test
fun `round-trip serialization with post-registration routes`() {
val state = PersistedFlowState(
backStack = listOf(
RegistrationRoute.Welcome,
RegistrationRoute.PinCreate,
RegistrationRoute.ArchiveRestoreSelection
),
sessionMetadata = null,
sessionE164 = "+15551234567",
doNotAttemptRecoveryPassword = true
)
val encoded = json.encodeToString(PersistedFlowState.serializer(), state)
val decoded = json.decodeFromString(PersistedFlowState.serializer(), encoded)
assertThat(decoded).isEqualTo(state)
}
@Test
fun `round-trip serialization with PinEntryForRegistrationLock`() {
val creds = NetworkController.SvrCredentials(username = "user", password = "pass")
val state = PersistedFlowState(
backStack = listOf(
RegistrationRoute.Welcome,
RegistrationRoute.PinEntryForRegistrationLock(timeRemaining = 86400000L, svrCredentials = creds)
),
sessionMetadata = null,
sessionE164 = "+15551234567",
doNotAttemptRecoveryPassword = false
)
val encoded = json.encodeToString(PersistedFlowState.serializer(), state)
val decoded = json.decodeFromString(PersistedFlowState.serializer(), encoded)
assertThat(decoded).isEqualTo(state)
}
@Test
fun `round-trip serialization with Captcha route`() {
val session = NetworkController.SessionMetadata(
id = "session-456",
nextSms = null,
nextCall = null,
nextVerificationAttempt = null,
allowedToRequestCode = false,
requestedInformation = listOf("captcha"),
verified = false
)
val state = PersistedFlowState(
backStack = listOf(
RegistrationRoute.Welcome,
RegistrationRoute.PhoneNumberEntry,
RegistrationRoute.Captcha(session = session)
),
sessionMetadata = session,
sessionE164 = "+15551234567",
doNotAttemptRecoveryPassword = false
)
val encoded = json.encodeToString(PersistedFlowState.serializer(), state)
val decoded = json.decodeFromString(PersistedFlowState.serializer(), encoded)
assertThat(decoded).isEqualTo(state)
}
@Test
fun `deserialization ignores unknown keys for forward compatibility`() {
val validJson = """{"backStack":[{"type":"org.signal.registration.RegistrationRoute.Welcome"}],"sessionMetadata":null,"sessionE164":null,"doNotAttemptRecoveryPassword":false,"unknownField":"value"}"""
val decoded = json.decodeFromString(PersistedFlowState.serializer(), validJson)
assertThat(decoded.backStack).isEqualTo(listOf(RegistrationRoute.Welcome))
assertThat(decoded.sessionMetadata).isNull()
}
@Test
fun `toPersistedFlowState captures correct fields`() {
val session = NetworkController.SessionMetadata(
id = "session-789",
nextSms = null,
nextCall = null,
nextVerificationAttempt = null,
allowedToRequestCode = true,
requestedInformation = emptyList(),
verified = true
)
val flowState = RegistrationFlowState(
backStack = listOf(RegistrationRoute.Welcome, RegistrationRoute.PinCreate),
sessionMetadata = session,
sessionE164 = "+15551234567",
accountEntropyPool = AccountEntropyPool.generate(),
temporaryMasterKey = MasterKey(ByteArray(32)),
doNotAttemptRecoveryPassword = true
)
val persisted = flowState.toPersistedFlowState()
assertThat(persisted.backStack).isEqualTo(flowState.backStack)
assertThat(persisted.sessionMetadata).isEqualTo(session)
assertThat(persisted.sessionE164).isEqualTo("+15551234567")
assertThat(persisted.doNotAttemptRecoveryPassword).isEqualTo(true)
}
@Test
fun `toRegistrationFlowState reconstructs all fields`() {
val session = NetworkController.SessionMetadata(
id = "session-101",
nextSms = null,
nextCall = null,
nextVerificationAttempt = null,
allowedToRequestCode = true,
requestedInformation = emptyList(),
verified = true
)
val persisted = PersistedFlowState(
backStack = listOf(RegistrationRoute.Welcome, RegistrationRoute.PinCreate),
sessionMetadata = session,
sessionE164 = "+15551234567",
doNotAttemptRecoveryPassword = true
)
val aep = AccountEntropyPool.generate()
val masterKey = MasterKey(ByteArray(32))
val flowState = persisted.toRegistrationFlowState(
accountEntropyPool = aep,
temporaryMasterKey = masterKey,
preExistingRegistrationData = null
)
assertThat(flowState.backStack).isEqualTo(persisted.backStack)
assertThat(flowState.sessionMetadata).isEqualTo(session)
assertThat(flowState.sessionE164).isEqualTo("+15551234567")
assertThat(flowState.accountEntropyPool).isEqualTo(aep)
assertThat(flowState.temporaryMasterKey).isEqualTo(masterKey)
assertThat(flowState.preExistingRegistrationData).isNull()
assertThat(flowState.doNotAttemptRecoveryPassword).isEqualTo(true)
}
}

View File

@@ -6,6 +6,7 @@
package org.signal.registration
import android.app.Application
import android.os.Looper
import androidx.compose.ui.test.assertIsDisplayed
import androidx.compose.ui.test.junit4.createComposeRule
import androidx.compose.ui.test.onNodeWithTag
@@ -13,12 +14,14 @@ import androidx.compose.ui.test.performClick
import androidx.lifecycle.SavedStateHandle
import androidx.test.core.app.ApplicationProvider
import com.google.accompanist.permissions.ExperimentalPermissionsApi
import io.mockk.coEvery
import io.mockk.mockk
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.robolectric.RobolectricTestRunner
import org.robolectric.Shadows
import org.robolectric.annotation.Config
import org.signal.core.ui.CoreUiDependenciesRule
import org.signal.core.ui.compose.theme.SignalTheme
@@ -47,7 +50,11 @@ class RegistrationNavigationTest {
@Before
fun setup() {
mockRepository = mockk<RegistrationRepository>(relaxed = true)
coEvery { mockRepository.restoreFlowState() } returns null
coEvery { mockRepository.getPreExistingRegistrationData() } returns null
viewModel = RegistrationViewModel(mockRepository, SavedStateHandle())
// Allow the init coroutine to complete so isRestoring becomes false.
Shadows.shadowOf(Looper.getMainLooper()).idle()
}
@Test
@@ -55,6 +62,11 @@ class RegistrationNavigationTest {
// Given
val permissionsState = createMockPermissionsState()
// Verify the ViewModel state is correctly initialized
val state = viewModel.state.value
assert(!state.isRestoringNavigationState) { "isRestoring should be false after init, was: ${state.isRestoringNavigationState}" }
assert(state.backStack == listOf(RegistrationRoute.Welcome)) { "backStack should be [Welcome], was: ${state.backStack}" }
composeTestRule.setContent {
SignalTheme(incognitoKeyboardEnabled = false) {
RegistrationNavHost(
@@ -86,6 +98,7 @@ class RegistrationNavigationTest {
// When
composeTestRule.onNodeWithTag(TestTags.WELCOME_GET_STARTED_BUTTON).performClick()
Shadows.shadowOf(Looper.getMainLooper()).idle()
// Then - verify Permissions screen is displayed
composeTestRule.onNodeWithTag(TestTags.PERMISSIONS_SCREEN).assertIsDisplayed()
@@ -108,9 +121,11 @@ class RegistrationNavigationTest {
// Navigate to Permissions screen first
composeTestRule.onNodeWithTag(TestTags.WELCOME_GET_STARTED_BUTTON).performClick()
Shadows.shadowOf(Looper.getMainLooper()).idle()
// When
composeTestRule.onNodeWithTag(TestTags.PERMISSIONS_NEXT_BUTTON).performClick()
Shadows.shadowOf(Looper.getMainLooper()).idle()
// Then - verify PhoneNumber screen is displayed
composeTestRule.onNodeWithTag(TestTags.PHONE_NUMBER_SCREEN).assertIsDisplayed()
@@ -133,9 +148,11 @@ class RegistrationNavigationTest {
// Navigate to Permissions screen first
composeTestRule.onNodeWithTag(TestTags.WELCOME_GET_STARTED_BUTTON).performClick()
Shadows.shadowOf(Looper.getMainLooper()).idle()
// When
composeTestRule.onNodeWithTag(TestTags.PERMISSIONS_NOT_NOW_BUTTON).performClick()
Shadows.shadowOf(Looper.getMainLooper()).idle()
// Then - verify PhoneNumber screen is displayed
composeTestRule.onNodeWithTag(TestTags.PHONE_NUMBER_SCREEN).assertIsDisplayed()
@@ -164,6 +181,7 @@ class RegistrationNavigationTest {
// When
composeTestRule.onNodeWithTag(TestTags.WELCOME_RESTORE_OR_TRANSFER_BUTTON).performClick()
composeTestRule.onNodeWithTag(TestTags.WELCOME_RESTORE_HAS_OLD_PHONE_BUTTON).performClick()
Shadows.shadowOf(Looper.getMainLooper()).idle()
// Then - verify Permissions screen is displayed
// (After permissions, user would go to RestoreViaQr screen)
@@ -188,6 +206,7 @@ class RegistrationNavigationTest {
// When
composeTestRule.onNodeWithTag(TestTags.WELCOME_RESTORE_OR_TRANSFER_BUTTON).performClick()
composeTestRule.onNodeWithTag(TestTags.WELCOME_RESTORE_NO_OLD_PHONE_BUTTON).performClick()
Shadows.shadowOf(Looper.getMainLooper()).idle()
// Then - verify Restore screen is displayed (or its expected content)
// Note: Update this assertion based on actual Restore screen content when implemented

View File

@@ -0,0 +1,251 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.registration
import androidx.lifecycle.SavedStateHandle
import assertk.assertThat
import assertk.assertions.isEqualTo
import assertk.assertions.isNull
import io.mockk.coEvery
import io.mockk.coVerify
import io.mockk.mockk
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.test.StandardTestDispatcher
import kotlinx.coroutines.test.advanceUntilIdle
import kotlinx.coroutines.test.resetMain
import kotlinx.coroutines.test.runTest
import kotlinx.coroutines.test.setMain
import org.junit.After
import org.junit.Before
import org.junit.Test
@OptIn(ExperimentalCoroutinesApi::class)
class RegistrationViewModelRestoreTest {
private val testDispatcher = StandardTestDispatcher()
private lateinit var mockRepository: RegistrationRepository
@Before
fun setup() {
Dispatchers.setMain(testDispatcher)
mockRepository = mockk(relaxed = true)
}
@After
fun tearDown() {
Dispatchers.resetMain()
}
@Test
fun `no saved state starts fresh and loads preExistingRegistrationData`() = runTest(testDispatcher) {
coEvery { mockRepository.restoreFlowState() } returns null
coEvery { mockRepository.getPreExistingRegistrationData() } returns null
val viewModel = RegistrationViewModel(mockRepository, SavedStateHandle())
advanceUntilIdle()
val state = viewModel.state.value
assertThat(state.backStack).isEqualTo(listOf(RegistrationRoute.Welcome))
assertThat(state.sessionMetadata).isNull()
}
@Test
fun `restore with valid session proceeds normally with updated session metadata`() = runTest(testDispatcher) {
val savedSession = createSessionMetadata("session-1")
val freshSession = createSessionMetadata("session-1").copy(nextSms = 9999L)
val savedState = RegistrationFlowState(
backStack = listOf(
RegistrationRoute.Welcome,
RegistrationRoute.Permissions(nextRoute = RegistrationRoute.PhoneNumberEntry),
RegistrationRoute.PhoneNumberEntry,
RegistrationRoute.VerificationCodeEntry
),
sessionMetadata = savedSession,
sessionE164 = "+15551234567"
)
coEvery { mockRepository.restoreFlowState() } returns savedState
coEvery { mockRepository.validateSession("session-1") } returns freshSession
val viewModel = RegistrationViewModel(mockRepository, SavedStateHandle())
advanceUntilIdle()
val state = viewModel.state.value
assertThat(state.backStack).isEqualTo(savedState.backStack)
assertThat(state.sessionMetadata).isEqualTo(freshSession)
assertThat(state.sessionE164).isEqualTo("+15551234567")
}
@Test
fun `restore with expired session and not registered resets to PhoneNumberEntry with e164 preserved`() = runTest(testDispatcher) {
val savedSession = createSessionMetadata("session-expired")
val savedState = RegistrationFlowState(
backStack = listOf(
RegistrationRoute.Welcome,
RegistrationRoute.Permissions(nextRoute = RegistrationRoute.PhoneNumberEntry),
RegistrationRoute.PhoneNumberEntry,
RegistrationRoute.VerificationCodeEntry
),
sessionMetadata = savedSession,
sessionE164 = "+15559876543"
)
coEvery { mockRepository.restoreFlowState() } returns savedState
coEvery { mockRepository.validateSession("session-expired") } returns null
coEvery { mockRepository.isRegistered() } returns false
val viewModel = RegistrationViewModel(mockRepository, SavedStateHandle())
advanceUntilIdle()
val state = viewModel.state.value
assertThat(state.backStack).isEqualTo(
listOf(
RegistrationRoute.Welcome,
RegistrationRoute.Permissions(nextRoute = RegistrationRoute.PhoneNumberEntry),
RegistrationRoute.PhoneNumberEntry
)
)
assertThat(state.sessionMetadata).isNull()
assertThat(state.sessionE164).isEqualTo("+15559876543")
}
@Test
fun `restore with expired session and already registered proceeds with null session`() = runTest(testDispatcher) {
val savedSession = createSessionMetadata("session-expired-2")
val savedState = RegistrationFlowState(
backStack = listOf(
RegistrationRoute.Welcome,
RegistrationRoute.PinCreate
),
sessionMetadata = savedSession,
sessionE164 = "+15551234567"
)
coEvery { mockRepository.restoreFlowState() } returns savedState
coEvery { mockRepository.validateSession("session-expired-2") } returns null
coEvery { mockRepository.isRegistered() } returns true
val viewModel = RegistrationViewModel(mockRepository, SavedStateHandle())
advanceUntilIdle()
val state = viewModel.state.value
assertThat(state.backStack).isEqualTo(listOf(RegistrationRoute.Welcome, RegistrationRoute.PinCreate))
assertThat(state.sessionMetadata).isNull()
assertThat(state.sessionE164).isEqualTo("+15551234567")
}
@Test
fun `restore with no session skips validation`() = runTest(testDispatcher) {
val savedState = RegistrationFlowState(
backStack = listOf(
RegistrationRoute.Welcome,
RegistrationRoute.PinCreate,
RegistrationRoute.ArchiveRestoreSelection
),
sessionMetadata = null,
sessionE164 = "+15551234567",
doNotAttemptRecoveryPassword = true
)
coEvery { mockRepository.restoreFlowState() } returns savedState
val viewModel = RegistrationViewModel(mockRepository, SavedStateHandle())
advanceUntilIdle()
val state = viewModel.state.value
assertThat(state.backStack).isEqualTo(savedState.backStack)
assertThat(state.sessionMetadata).isNull()
assertThat(state.doNotAttemptRecoveryPassword).isEqualTo(true)
coVerify(exactly = 0) { mockRepository.validateSession(any()) }
}
@Test
fun `onEvent ResetState clears flow state`() = runTest(testDispatcher) {
coEvery { mockRepository.restoreFlowState() } returns null
coEvery { mockRepository.getPreExistingRegistrationData() } returns null
val viewModel = RegistrationViewModel(mockRepository, SavedStateHandle())
advanceUntilIdle()
viewModel.onEvent(RegistrationFlowEvent.ResetState)
advanceUntilIdle()
coVerify { mockRepository.clearFlowState() }
}
@Test
fun `onEvent NavigateToScreen FullyComplete clears flow state`() = runTest(testDispatcher) {
coEvery { mockRepository.restoreFlowState() } returns null
coEvery { mockRepository.getPreExistingRegistrationData() } returns null
val viewModel = RegistrationViewModel(mockRepository, SavedStateHandle())
advanceUntilIdle()
viewModel.onEvent(RegistrationFlowEvent.NavigateToScreen(RegistrationRoute.FullyComplete))
advanceUntilIdle()
coVerify { mockRepository.clearFlowState() }
}
@Test
fun `onEvent NavigateToScreen saves flow state`() = runTest(testDispatcher) {
coEvery { mockRepository.restoreFlowState() } returns null
coEvery { mockRepository.getPreExistingRegistrationData() } returns null
val viewModel = RegistrationViewModel(mockRepository, SavedStateHandle())
advanceUntilIdle()
viewModel.onEvent(RegistrationFlowEvent.NavigateToScreen(RegistrationRoute.PhoneNumberEntry))
advanceUntilIdle()
coVerify { mockRepository.saveFlowState(any()) }
}
@Test
fun `onEvent Registered does not save flow state`() = runTest(testDispatcher) {
coEvery { mockRepository.restoreFlowState() } returns null
coEvery { mockRepository.getPreExistingRegistrationData() } returns null
val viewModel = RegistrationViewModel(mockRepository, SavedStateHandle())
advanceUntilIdle()
viewModel.onEvent(RegistrationFlowEvent.Registered(org.signal.core.models.AccountEntropyPool.generate()))
advanceUntilIdle()
coVerify(exactly = 0) { mockRepository.saveFlowState(any()) }
}
@Test
fun `onEvent MasterKeyRestoredFromSvr does not save flow state`() = runTest(testDispatcher) {
coEvery { mockRepository.restoreFlowState() } returns null
coEvery { mockRepository.getPreExistingRegistrationData() } returns null
val viewModel = RegistrationViewModel(mockRepository, SavedStateHandle())
advanceUntilIdle()
viewModel.onEvent(RegistrationFlowEvent.MasterKeyRestoredFromSvr(org.signal.core.models.MasterKey(ByteArray(32))))
advanceUntilIdle()
coVerify(exactly = 0) { mockRepository.saveFlowState(any()) }
}
private fun createSessionMetadata(id: String = "test-session"): NetworkController.SessionMetadata {
return NetworkController.SessionMetadata(
id = id,
nextSms = 1000L,
nextCall = 2000L,
nextVerificationAttempt = 3000L,
allowedToRequestCode = true,
requestedInformation = emptyList(),
verified = false
)
}
}

View File

@@ -292,8 +292,10 @@ class PhoneNumberEntryViewModelTest {
assertThat(emittedStates.last().showSpinner).isFalse()
assertThat(emittedStates.last().sessionMetadata).isNotNull()
assertThat(emittedEvents).hasSize(1)
assertThat(emittedEvents.first())
assertThat(emittedEvents).hasSize(3)
assertThat(emittedEvents[0]).isInstanceOf<RegistrationFlowEvent.SessionUpdated>()
assertThat(emittedEvents[1]).isInstanceOf<RegistrationFlowEvent.E164Chosen>()
assertThat(emittedEvents[2])
.isInstanceOf<RegistrationFlowEvent.NavigateToScreen>()
.prop(RegistrationFlowEvent.NavigateToScreen::route)
.isInstanceOf<RegistrationRoute.VerificationCodeEntry>()
@@ -426,8 +428,10 @@ class PhoneNumberEntryViewModelTest {
assertThat(emittedStates.last().showSpinner).isFalse()
// Should not create a new session, just request verification code
assertThat(emittedEvents).hasSize(1)
assertThat(emittedEvents.first())
assertThat(emittedEvents).hasSize(3)
assertThat(emittedEvents[0]).isInstanceOf<RegistrationFlowEvent.SessionUpdated>()
assertThat(emittedEvents[1]).isInstanceOf<RegistrationFlowEvent.E164Chosen>()
assertThat(emittedEvents[2])
.isInstanceOf<RegistrationFlowEvent.NavigateToScreen>()
.prop(RegistrationFlowEvent.NavigateToScreen::route)
.isInstanceOf<RegistrationRoute.VerificationCodeEntry>()
@@ -563,8 +567,10 @@ class PhoneNumberEntryViewModelTest {
assertThat(emittedStates.last().showSpinner).isFalse()
// Verify navigation to verification code entry
assertThat(emittedEvents).hasSize(1)
assertThat(emittedEvents.first())
assertThat(emittedEvents).hasSize(3)
assertThat(emittedEvents[0]).isInstanceOf<RegistrationFlowEvent.SessionUpdated>()
assertThat(emittedEvents[1]).isInstanceOf<RegistrationFlowEvent.E164Chosen>()
assertThat(emittedEvents[2])
.isInstanceOf<RegistrationFlowEvent.NavigateToScreen>()
.prop(RegistrationFlowEvent.NavigateToScreen::route)
.isInstanceOf<RegistrationRoute.VerificationCodeEntry>()
@@ -595,8 +601,10 @@ class PhoneNumberEntryViewModelTest {
assertThat(emittedStates.last().showSpinner).isFalse()
// Verify navigation continues despite no push challenge token
assertThat(emittedEvents).hasSize(1)
assertThat(emittedEvents.first())
assertThat(emittedEvents).hasSize(3)
assertThat(emittedEvents[0]).isInstanceOf<RegistrationFlowEvent.SessionUpdated>()
assertThat(emittedEvents[1]).isInstanceOf<RegistrationFlowEvent.E164Chosen>()
assertThat(emittedEvents[2])
.isInstanceOf<RegistrationFlowEvent.NavigateToScreen>()
.prop(RegistrationFlowEvent.NavigateToScreen::route)
.isInstanceOf<RegistrationRoute.VerificationCodeEntry>()
@@ -631,8 +639,10 @@ class PhoneNumberEntryViewModelTest {
assertThat(emittedStates.last().showSpinner).isFalse()
// Verify navigation continues despite push challenge submission failure
assertThat(emittedEvents).hasSize(1)
assertThat(emittedEvents.first())
assertThat(emittedEvents).hasSize(3)
assertThat(emittedEvents[0]).isInstanceOf<RegistrationFlowEvent.SessionUpdated>()
assertThat(emittedEvents[1]).isInstanceOf<RegistrationFlowEvent.E164Chosen>()
assertThat(emittedEvents[2])
.isInstanceOf<RegistrationFlowEvent.NavigateToScreen>()
.prop(RegistrationFlowEvent.NavigateToScreen::route)
.isInstanceOf<RegistrationRoute.VerificationCodeEntry>()
@@ -662,8 +672,10 @@ class PhoneNumberEntryViewModelTest {
assertThat(emittedStates.last().showSpinner).isFalse()
// Verify navigation continues despite network error
assertThat(emittedEvents).hasSize(1)
assertThat(emittedEvents.first())
assertThat(emittedEvents).hasSize(3)
assertThat(emittedEvents[0]).isInstanceOf<RegistrationFlowEvent.SessionUpdated>()
assertThat(emittedEvents[1]).isInstanceOf<RegistrationFlowEvent.E164Chosen>()
assertThat(emittedEvents[2])
.isInstanceOf<RegistrationFlowEvent.NavigateToScreen>()
.prop(RegistrationFlowEvent.NavigateToScreen::route)
.isInstanceOf<RegistrationRoute.VerificationCodeEntry>()
@@ -693,8 +705,10 @@ class PhoneNumberEntryViewModelTest {
assertThat(emittedStates.last().showSpinner).isFalse()
// Verify navigation continues despite application error
assertThat(emittedEvents).hasSize(1)
assertThat(emittedEvents.first())
assertThat(emittedEvents).hasSize(3)
assertThat(emittedEvents[0]).isInstanceOf<RegistrationFlowEvent.SessionUpdated>()
assertThat(emittedEvents[1]).isInstanceOf<RegistrationFlowEvent.E164Chosen>()
assertThat(emittedEvents[2])
.isInstanceOf<RegistrationFlowEvent.NavigateToScreen>()
.prop(RegistrationFlowEvent.NavigateToScreen::route)
.isInstanceOf<RegistrationRoute.VerificationCodeEntry>()
@@ -744,8 +758,10 @@ class PhoneNumberEntryViewModelTest {
viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.CaptchaCompleted("captcha-token"), stateEmitter, parentEventEmitter)
assertThat(emittedEvents).hasSize(1)
assertThat(emittedEvents.first())
assertThat(emittedEvents).hasSize(3)
assertThat(emittedEvents[0]).isInstanceOf<RegistrationFlowEvent.SessionUpdated>()
assertThat(emittedEvents[1]).isInstanceOf<RegistrationFlowEvent.E164Chosen>()
assertThat(emittedEvents[2])
.isInstanceOf<RegistrationFlowEvent.NavigateToScreen>()
.prop(RegistrationFlowEvent.NavigateToScreen::route)
.isInstanceOf<RegistrationRoute.VerificationCodeEntry>()