From 64496d1d92cfa1d7318763ba67fffe2b6192cc71 Mon Sep 17 00:00:00 2001 From: Greyson Parrelli Date: Thu, 18 Jun 2026 16:30:02 -0400 Subject: [PATCH] Fix regV5 integration points. --- .../securesms/ApplicationContext.java | 11 ++-- .../securesms/PassphraseRequiredActivity.java | 6 +- .../megaphone/MegaphoneRepository.kt | 3 +- .../v2/AppRegistrationStorageController.kt | 55 ++++++++++++++++--- .../registration/RegistrationActivity.kt | 21 ++++++- .../registration/RegistrationDependencies.kt | 1 + .../registration/RegistrationRepository.kt | 14 +++++ .../DeviceTransferCompleteViewModel.kt | 2 + .../LocalBackupRestoreViewModel.kt | 2 + .../pincreation/PinCreationViewModel.kt | 3 + .../PinEntryForSvrRestoreViewModel.kt | 3 + .../RemoteBackupRestoreViewModel.kt | 2 + .../ArchiveRestoreSelectionViewModel.kt | 2 + .../src/main/protowire/Registration.proto | 14 +++++ .../DeviceTransferCompleteViewModelTest.kt | 2 + .../LocalBackupRestoreViewModelTest.kt | 41 ++++++++++++++ .../pincreation/PinCreationViewModelTest.kt | 3 + .../PinEntryForSvrRestoreViewModelTest.kt | 3 + .../RemoteBackupRestoreViewModelTest.kt | 2 + .../ArchiveRestoreSelectionViewModelTest.kt | 7 ++- 20 files changed, 174 insertions(+), 23 deletions(-) diff --git a/app/src/main/java/org/thoughtcrime/securesms/ApplicationContext.java b/app/src/main/java/org/thoughtcrime/securesms/ApplicationContext.java index 6ac02397c1..89c64ae92d 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/ApplicationContext.java +++ b/app/src/main/java/org/thoughtcrime/securesms/ApplicationContext.java @@ -46,6 +46,7 @@ import org.signal.core.util.tracing.Tracer; import org.signal.glide.SignalGlideCodecs; import org.signal.libsignal.net.ChatServiceException; import org.signal.libsignal.protocol.logging.SignalProtocolLoggerProvider; +import org.signal.registration.RegistrationDependencies; import org.signal.ringrtc.CallManager; import org.thoughtcrime.securesms.apkupdate.ApkUpdateRefreshListener; import org.thoughtcrime.securesms.avatar.AvatarPickerStorage; @@ -102,6 +103,8 @@ import org.thoughtcrime.securesms.providers.BlobProvider; import org.thoughtcrime.securesms.ratelimit.RateLimitUtil; import org.thoughtcrime.securesms.recipients.Recipient; import org.thoughtcrime.securesms.registration.util.RegistrationUtil; +import org.thoughtcrime.securesms.registration.v2.AppRegistrationNetworkController; +import org.thoughtcrime.securesms.registration.v2.AppRegistrationStorageController; import org.thoughtcrime.securesms.ringrtc.RingRtcLogger; import org.thoughtcrime.securesms.service.AnalyzeDatabaseAlarmListener; import org.thoughtcrime.securesms.service.DirectoryRefreshListener; @@ -421,10 +424,10 @@ public class ApplicationContext extends Application implements AppForegroundObse } private void initializeRegistrationDependencies() { - org.signal.registration.RegistrationDependencies.Companion.provide( - new org.signal.registration.RegistrationDependencies( - new org.thoughtcrime.securesms.registration.v2.AppRegistrationNetworkController(this, AppDependencies.getPushServiceSocket()), - new org.thoughtcrime.securesms.registration.v2.AppRegistrationStorageController(this), + RegistrationDependencies.provide( + new RegistrationDependencies( + new AppRegistrationNetworkController(this, AppDependencies.getPushServiceSocket()), + new AppRegistrationStorageController(this), Environment.IS_LINK_AND_SYNC_AVAILABLE, null, context -> { diff --git a/app/src/main/java/org/thoughtcrime/securesms/PassphraseRequiredActivity.java b/app/src/main/java/org/thoughtcrime/securesms/PassphraseRequiredActivity.java index ada95e41ec..c90e4c7357 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/PassphraseRequiredActivity.java +++ b/app/src/main/java/org/thoughtcrime/securesms/PassphraseRequiredActivity.java @@ -135,12 +135,8 @@ public abstract class PassphraseRequiredActivity extends BaseActivity implements Intent intent = getIntentForState(applicationState); if (intent != null) { Log.d(TAG, "routeApplicationState(), intent: " + intent.getComponent()); - if (applicationState == STATE_WELCOME_PUSH_SCREEN && Environment.USE_NEW_REGISTRATION) { - startActivity(intent); - } else { startActivity(intent); finish(); - } } } @@ -227,7 +223,7 @@ public abstract class PassphraseRequiredActivity extends BaseActivity implements private Intent getPushRegistrationIntent() { if (Environment.USE_NEW_REGISTRATION) { - return org.signal.registration.RegistrationActivity.createIntent(this); + return org.signal.registration.RegistrationActivity.createIntent(this, MainActivity.clearTop(this)); } else { return RegistrationActivity.newIntentForNewRegistration(this, getIntent()); } diff --git a/app/src/main/java/org/thoughtcrime/securesms/megaphone/MegaphoneRepository.kt b/app/src/main/java/org/thoughtcrime/securesms/megaphone/MegaphoneRepository.kt index 5a0800d500..eeff237b89 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/megaphone/MegaphoneRepository.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/megaphone/MegaphoneRepository.kt @@ -7,6 +7,7 @@ import org.signal.core.util.concurrent.SignalExecutors import org.signal.core.util.logging.Log import org.thoughtcrime.securesms.database.MegaphoneDatabase import org.thoughtcrime.securesms.database.model.MegaphoneRecord +import org.thoughtcrime.securesms.keyvalue.SignalStore import java.util.concurrent.Executor import kotlin.time.Duration.Companion.days @@ -59,7 +60,7 @@ class MegaphoneRepository(private val context: Application) { @AnyThread fun getNextMegaphone(callback: Callback) { executor.execute { - if (enabled) { + if (enabled && SignalStore.account.isRegistered) { init() val currentTime = System.currentTimeMillis() diff --git a/app/src/main/java/org/thoughtcrime/securesms/registration/v2/AppRegistrationStorageController.kt b/app/src/main/java/org/thoughtcrime/securesms/registration/v2/AppRegistrationStorageController.kt index 9d572df105..cfc21f7c77 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/registration/v2/AppRegistrationStorageController.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/registration/v2/AppRegistrationStorageController.kt @@ -29,6 +29,7 @@ import org.signal.registration.PreExistingRegistrationData import org.signal.registration.StorageController import org.signal.registration.StoredProfileData import org.signal.registration.proto.RegistrationData +import org.signal.registration.proto.RestoreDecision import org.signal.registration.screens.localbackuprestore.LocalBackupInfo import org.signal.registration.screens.remotebackuprestore.RemoteBackupRestoreProgress import org.thoughtcrime.securesms.backup.FullBackupImporter @@ -41,7 +42,12 @@ import org.thoughtcrime.securesms.crypto.AttachmentSecretProvider import org.thoughtcrime.securesms.crypto.ProfileKeyUtil import org.thoughtcrime.securesms.database.SignalDatabase import org.thoughtcrime.securesms.database.model.databaseprotos.LocalRegistrationMetadata +import org.thoughtcrime.securesms.database.model.databaseprotos.RestoreDecisionState +import org.thoughtcrime.securesms.keyvalue.Completed +import org.thoughtcrime.securesms.keyvalue.NewAccount import org.thoughtcrime.securesms.keyvalue.SignalStore +import org.thoughtcrime.securesms.keyvalue.Skipped +import org.thoughtcrime.securesms.keyvalue.isDecisionPending import org.thoughtcrime.securesms.pin.SvrRepository import org.thoughtcrime.securesms.profiles.AvatarHelper import org.thoughtcrime.securesms.recipients.Recipient @@ -158,6 +164,16 @@ class AppRegistrationStorageController(private val context: Context) : StorageCo override suspend fun commitRegistrationData() = withContext(Dispatchers.IO) { val data = readInProgressRegistrationData() + // The account's master key is always the one derived from the AEP, which we expect to have by the time we commit. + // Restore it up-front so any master-key-derived state we touch below resolves against the correct value rather + // than lazily generating a new AEP. + val accountEntropyPool: AccountEntropyPool? = data.accountEntropyPool.takeIf { it.isNotEmpty() }?.let { AccountEntropyPool(it) } + if (accountEntropyPool != null) { + SignalStore.account.restoreAccountEntropyPool(accountEntropyPool) + } + + val masterKey: MasterKey? = accountEntropyPool?.deriveMasterKey() + // Build LocalRegistrationMetadata if we have enough data for account setup if (data.e164.isNotEmpty() && data.aci.isNotEmpty() && data.pni.isNotEmpty() && data.servicePassword.isNotEmpty()) { val profileKey = RegistrationRepository.getProfileKey(data.e164) @@ -190,9 +206,7 @@ class AppRegistrationStorageController(private val context: Context) : StorageCo hasPin = data.pin.isNotEmpty() if (data.pin.isNotEmpty()) { pin = data.pin - } - if (data.temporaryMasterKey.size > 0) { - masterKey = data.temporaryMasterKey + masterKey?.let { this.masterKey = it.serialize().toByteString() } } fcmEnabled = SignalStore.account.fcmEnabled fcmToken = SignalStore.account.fcmToken ?: "" @@ -202,15 +216,10 @@ class AppRegistrationStorageController(private val context: Context) : StorageCo // TODO [greyson] Should probably move this stuff into this file as we get closer to being done RegistrationRepository.registerAccountLocally(context, metadata) SignalStore.registration.localRegistrationMetadata = metadata - - if (data.accountEntropyPool.isNotEmpty()) { - SignalStore.account.restoreAccountEntropyPool(AccountEntropyPool(data.accountEntropyPool)) - } } // Handle PIN/master key - if (data.pin.isNotEmpty() && data.temporaryMasterKey.size > 0) { - val masterKey = MasterKey(data.temporaryMasterKey.toByteArray()) + if (data.pin.isNotEmpty() && masterKey != null) { SvrRepository.onRegistrationComplete( masterKey, data.pin, @@ -223,9 +232,37 @@ class AppRegistrationStorageController(private val context: Context) : StorageCo SvrRepository.optOutOfPin(rotateAep = false) } + // The temporaryMasterKey is the one-time key restored from SVR during re-registration. The account's own master key + // is always the AEP-derived one above, so this is retained separately as the initial-restore key (used for the + // first storage service sync + recovery password). It must be set last, as onRegistrationComplete will have cleared + // the initial-restore key after recognizing the AEP-derived master key as our own. + if (data.temporaryMasterKey.size > 0) { + SignalStore.svr.masterKeyForInitialDataRestore = MasterKey(data.temporaryMasterKey.toByteArray()) + } + + applyRestoreDecision(data.restoreDecision) + Unit } + /** + * Translates the registration module's [RestoreDecision] into the app's [RestoreDecisionState] so the rest of the app + * knows whether we're a fresh account, skipped a restore, or successfully restored data. Only applied while the + * decision is still pending, since the state machine is otherwise terminal. + */ + private fun applyRestoreDecision(decision: RestoreDecision) { + if (!SignalStore.registration.restoreDecisionState.isDecisionPending) { + return + } + + SignalStore.registration.restoreDecisionState = when (decision) { + RestoreDecision.NEW_ACCOUNT -> RestoreDecisionState.NewAccount + RestoreDecision.SKIPPED -> RestoreDecisionState.Skipped + RestoreDecision.COMPLETED -> RestoreDecisionState.Completed + RestoreDecision.UNSET -> return + } + } + override fun restoreLocalBackupV1(uri: Uri, passphrase: String): Flow = flow { // TODO [greyson] better progress Log.d(TAG, "Starting V1 local backup restore from: $uri") diff --git a/feature/registration/src/main/java/org/signal/registration/RegistrationActivity.kt b/feature/registration/src/main/java/org/signal/registration/RegistrationActivity.kt index cc1047166e..209ff84172 100644 --- a/feature/registration/src/main/java/org/signal/registration/RegistrationActivity.kt +++ b/feature/registration/src/main/java/org/signal/registration/RegistrationActivity.kt @@ -10,6 +10,7 @@ import androidx.activity.result.contract.ActivityResultContract import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.material3.Surface import androidx.compose.ui.Modifier +import androidx.core.content.IntentCompat import com.google.accompanist.permissions.ExperimentalPermissionsApi import org.signal.core.ui.compose.theme.SignalTheme @@ -17,14 +18,27 @@ import org.signal.core.ui.compose.theme.SignalTheme * Activity entry point for the registration flow. * * This activity can be launched from the main app to start the registration process. - * Upon successful completion, it will return RESULT_OK. + * Upon successful completion, it will return RESULT_OK and, if provided via [createIntent], launch the next intent to + * route the user back into the main app. */ class RegistrationActivity : ComponentActivity() { companion object { + private const val NEXT_INTENT_EXTRA = "next_intent" + + /** + * @param nextIntent An optional intent to launch once registration completes successfully. This is how the caller + * (which lives outside this module) routes the user back into the main app, since the launching activity will + * typically have finished itself. + */ @JvmStatic - fun createIntent(context: Context): Intent { - return Intent(context, RegistrationActivity::class.java) + @JvmOverloads + fun createIntent(context: Context, nextIntent: Intent? = null): Intent { + return Intent(context, RegistrationActivity::class.java).apply { + if (nextIntent != null) { + putExtra(NEXT_INTENT_EXTRA, nextIntent) + } + } } } @@ -50,6 +64,7 @@ class RegistrationActivity : ComponentActivity() { modifier = Modifier.fillMaxSize(), onRegistrationComplete = { setResult(RESULT_OK) + IntentCompat.getParcelableExtra(intent, NEXT_INTENT_EXTRA, Intent::class.java)?.let { startActivity(it) } finish() } ) diff --git a/feature/registration/src/main/java/org/signal/registration/RegistrationDependencies.kt b/feature/registration/src/main/java/org/signal/registration/RegistrationDependencies.kt index 1da0b8e64e..006ced4c15 100644 --- a/feature/registration/src/main/java/org/signal/registration/RegistrationDependencies.kt +++ b/feature/registration/src/main/java/org/signal/registration/RegistrationDependencies.kt @@ -26,6 +26,7 @@ class RegistrationDependencies( companion object { lateinit var dependencies: RegistrationDependencies + @JvmStatic fun provide(registrationDependencies: RegistrationDependencies) { dependencies = registrationDependencies SensitiveLog.init(dependencies.sensitiveLogger) diff --git a/feature/registration/src/main/java/org/signal/registration/RegistrationRepository.kt b/feature/registration/src/main/java/org/signal/registration/RegistrationRepository.kt index 2f5123cf97..997cd1bb98 100644 --- a/feature/registration/src/main/java/org/signal/registration/RegistrationRepository.kt +++ b/feature/registration/src/main/java/org/signal/registration/RegistrationRepository.kt @@ -41,6 +41,7 @@ import org.signal.registration.NetworkController.SessionMetadata import org.signal.registration.NetworkController.SvrCredentials import org.signal.registration.NetworkController.UpdateSessionError import org.signal.registration.proto.ProvisioningData +import org.signal.registration.proto.RestoreDecision import org.signal.registration.proto.SvrCredential import org.signal.registration.screens.localbackuprestore.LocalBackupInfo import org.signal.registration.screens.remotebackuprestore.RemoteBackupRestoreProgress @@ -533,6 +534,19 @@ class RegistrationRepository(val context: Context, val networkController: Networ storageController.commitRegistrationData() } + /** + * Records the terminal restore decision the user reached (new account, skipped a restore, or successfully restored) + * and commits it. The app translates this into its own restore-decision state so the rest of the app knows what + * happened during registration. + */ + suspend fun setRestoreDecision(decision: RestoreDecision): Unit = withContext(Dispatchers.IO) { + Log.i(TAG, "[setRestoreDecision] Recording restore decision: $decision") + storageController.updateInProgressRegistrationData { + this.restoreDecision = decision + } + storageController.commitRegistrationData() + } + suspend fun getPreExistingRegistrationData(): PreExistingRegistrationData? { return storageController.getPreExistingRegistrationData() } diff --git a/feature/registration/src/main/java/org/signal/registration/screens/devicetransfer/complete/DeviceTransferCompleteViewModel.kt b/feature/registration/src/main/java/org/signal/registration/screens/devicetransfer/complete/DeviceTransferCompleteViewModel.kt index 426662309d..0224ddf1f4 100644 --- a/feature/registration/src/main/java/org/signal/registration/screens/devicetransfer/complete/DeviceTransferCompleteViewModel.kt +++ b/feature/registration/src/main/java/org/signal/registration/screens/devicetransfer/complete/DeviceTransferCompleteViewModel.kt @@ -13,6 +13,7 @@ import kotlinx.coroutines.flow.StateFlow import org.signal.core.util.logging.Log import org.signal.registration.RegistrationFlowEvent import org.signal.registration.RegistrationRepository +import org.signal.registration.proto.RestoreDecision import org.signal.registration.screens.EventDrivenViewModel class DeviceTransferCompleteViewModel( @@ -41,6 +42,7 @@ class DeviceTransferCompleteViewModel( ) { when (event) { DeviceTransferCompleteScreenEvents.ContinueClicked -> { + repository.setRestoreDecision(RestoreDecision.COMPLETED) repository.finishRegistrationOrCreateProfile(parentEventEmitter) } DeviceTransferCompleteScreenEvents.ConsumeOneTimeEvent -> { diff --git a/feature/registration/src/main/java/org/signal/registration/screens/localbackuprestore/LocalBackupRestoreViewModel.kt b/feature/registration/src/main/java/org/signal/registration/screens/localbackuprestore/LocalBackupRestoreViewModel.kt index c4dfff7d5f..f6bcba65bd 100644 --- a/feature/registration/src/main/java/org/signal/registration/screens/localbackuprestore/LocalBackupRestoreViewModel.kt +++ b/feature/registration/src/main/java/org/signal/registration/screens/localbackuprestore/LocalBackupRestoreViewModel.kt @@ -23,6 +23,7 @@ import org.signal.core.util.logging.Log import org.signal.registration.RegistrationFlowEvent import org.signal.registration.RegistrationRepository import org.signal.registration.RegistrationRoute +import org.signal.registration.proto.RestoreDecision import org.signal.registration.screens.EventDrivenViewModel import org.signal.registration.screens.util.navigateBack import org.signal.registration.screens.util.navigateTo @@ -118,6 +119,7 @@ class LocalBackupRestoreViewModel( resultBus.sendResult(resultKey, LocalBackupRestoreResult.Success(state.aep)) parentEventEmitter.navigateBack() } else { + repository.setRestoreDecision(RestoreDecision.COMPLETED) repository.finishRegistrationOrCreateProfile(parentEventEmitter) } } diff --git a/feature/registration/src/main/java/org/signal/registration/screens/pincreation/PinCreationViewModel.kt b/feature/registration/src/main/java/org/signal/registration/screens/pincreation/PinCreationViewModel.kt index 3a6f60c46a..9c760178af 100644 --- a/feature/registration/src/main/java/org/signal/registration/screens/pincreation/PinCreationViewModel.kt +++ b/feature/registration/src/main/java/org/signal/registration/screens/pincreation/PinCreationViewModel.kt @@ -21,6 +21,7 @@ import org.signal.registration.NetworkController import org.signal.registration.RegistrationFlowEvent import org.signal.registration.RegistrationFlowState import org.signal.registration.RegistrationRepository +import org.signal.registration.proto.RestoreDecision import org.signal.registration.screens.EventDrivenViewModel /** @@ -77,6 +78,7 @@ class PinCreationViewModel( private suspend fun applyOptOut() { Log.i(TAG, "[OptOut] User opted out of creating a PIN. Recording choice and completing registration.") repository.setPinOptedOut() + repository.setRestoreDecision(RestoreDecision.NEW_ACCOUNT) parentEventEmitter(RegistrationFlowEvent.RegistrationComplete) } @@ -99,6 +101,7 @@ class PinCreationViewModel( return when (val result = repository.setNewlyCreatedPin(pin, state.isAlphanumericKeyboard, masterKey)) { is RequestResult.Success -> { Log.i(TAG, "[PinSubmitted] Successfully backed up master key to SVR.") + repository.setRestoreDecision(RestoreDecision.NEW_ACCOUNT) repository.finishRegistrationOrCreateProfile(parentEventEmitter) state } diff --git a/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryForSvrRestoreViewModel.kt b/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryForSvrRestoreViewModel.kt index 275aae8eb0..c795d1cb4b 100644 --- a/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryForSvrRestoreViewModel.kt +++ b/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryForSvrRestoreViewModel.kt @@ -21,6 +21,7 @@ import org.signal.registration.RegistrationFlowEvent import org.signal.registration.RegistrationFlowState import org.signal.registration.RegistrationRepository import org.signal.registration.RegistrationRoute +import org.signal.registration.proto.RestoreDecision import org.signal.registration.screens.EventDrivenViewModel import org.signal.registration.screens.util.navigateTo @@ -115,6 +116,7 @@ class PinEntryForSvrRestoreViewModel( is RequestResult.Success -> { Log.i(TAG, "[PinEntered] Successfully restored master key from SVR.") repository.enqueueSvrResetGuessCountJob() + repository.setRestoreDecision(RestoreDecision.COMPLETED) parentEventEmitter(RegistrationFlowEvent.MasterKeyRestoredFromSvr(result.result.masterKey)) repository.finishRegistrationOrCreateProfile(parentEventEmitter) state @@ -146,6 +148,7 @@ class PinEntryForSvrRestoreViewModel( private suspend fun handleSkip() { Log.i(TAG, "[Skip] User opted out of restoring data and creating a PIN. Recording choice and completing registration.") repository.setPinOptedOut() + repository.setRestoreDecision(RestoreDecision.SKIPPED) parentEventEmitter(RegistrationFlowEvent.RegistrationComplete) } diff --git a/feature/registration/src/main/java/org/signal/registration/screens/remotebackuprestore/RemoteBackupRestoreViewModel.kt b/feature/registration/src/main/java/org/signal/registration/screens/remotebackuprestore/RemoteBackupRestoreViewModel.kt index 7ce286ce88..2d3f5e0e7c 100644 --- a/feature/registration/src/main/java/org/signal/registration/screens/remotebackuprestore/RemoteBackupRestoreViewModel.kt +++ b/feature/registration/src/main/java/org/signal/registration/screens/remotebackuprestore/RemoteBackupRestoreViewModel.kt @@ -23,6 +23,7 @@ import org.signal.libsignal.net.RequestResult import org.signal.registration.NetworkController import org.signal.registration.RegistrationFlowEvent import org.signal.registration.RegistrationRepository +import org.signal.registration.proto.RestoreDecision import org.signal.registration.screens.EventDrivenViewModel import org.signal.registration.screens.util.navigateBack import kotlin.coroutines.CoroutineContext @@ -116,6 +117,7 @@ class RemoteBackupRestoreViewModel( restoreProgress = null ) parentEventEmitter(RegistrationFlowEvent.UserSuppliedAepVerified(aep)) + repository.setRestoreDecision(RestoreDecision.COMPLETED) repository.finishRegistrationOrCreateProfile(parentEventEmitter) } is RemoteBackupRestoreProgress.NetworkError -> { diff --git a/feature/registration/src/main/java/org/signal/registration/screens/restoreselection/ArchiveRestoreSelectionViewModel.kt b/feature/registration/src/main/java/org/signal/registration/screens/restoreselection/ArchiveRestoreSelectionViewModel.kt index 6915340e29..fa609c08ea 100644 --- a/feature/registration/src/main/java/org/signal/registration/screens/restoreselection/ArchiveRestoreSelectionViewModel.kt +++ b/feature/registration/src/main/java/org/signal/registration/screens/restoreselection/ArchiveRestoreSelectionViewModel.kt @@ -23,6 +23,7 @@ import org.signal.registration.RegistrationFlowEvent import org.signal.registration.RegistrationFlowState import org.signal.registration.RegistrationRepository import org.signal.registration.RegistrationRoute +import org.signal.registration.proto.RestoreDecision import org.signal.registration.screens.EventDrivenViewModel import org.signal.registration.screens.util.navigateTo @@ -106,6 +107,7 @@ class ArchiveRestoreSelectionViewModel( } is ArchiveRestoreSelectionScreenEvents.ConfirmSkip -> { notifyOldDevice(state.restoreMethodToken, NetworkController.RestoreMethod.DECLINE) + repository.setRestoreDecision(RestoreDecision.SKIPPED) parentEventEmitter.navigateTo(RegistrationRoute.PinCreate) state.copy(showSkipWarningDialog = false) } diff --git a/feature/registration/src/main/protowire/Registration.proto b/feature/registration/src/main/protowire/Registration.proto index c3db3ef149..a1d5702db1 100644 --- a/feature/registration/src/main/protowire/Registration.proto +++ b/feature/registration/src/main/protowire/Registration.proto @@ -52,6 +52,20 @@ message RegistrationData { // JSON-serialized flow state snapshot (from saveFlowState/restoreFlowState) string flowStateJson = 21; + + // The terminal restore decision the user reached during this flow. The app translates this into its own + // RestoreDecisionState when committing, so the rest of the app knows whether we're a fresh account, skipped a + // restore, or successfully restored data. + RestoreDecision restoreDecision = 25; +} + +// Mirrors the terminal states of the app's RestoreDecisionState. We intentionally do not model the transient +// pending states (START / INTEND_TO_RESTORE) here, as the new flow performs any restore inline before completing. +enum RestoreDecision { + UNSET = 0; + NEW_ACCOUNT = 1; + SKIPPED = 2; + COMPLETED = 3; } message SvrCredential { diff --git a/feature/registration/src/test/java/org/signal/registration/screens/devicetransfer/complete/DeviceTransferCompleteViewModelTest.kt b/feature/registration/src/test/java/org/signal/registration/screens/devicetransfer/complete/DeviceTransferCompleteViewModelTest.kt index f63d5fbead..91ca8036d6 100644 --- a/feature/registration/src/test/java/org/signal/registration/screens/devicetransfer/complete/DeviceTransferCompleteViewModelTest.kt +++ b/feature/registration/src/test/java/org/signal/registration/screens/devicetransfer/complete/DeviceTransferCompleteViewModelTest.kt @@ -18,6 +18,7 @@ import org.junit.Before import org.junit.Test import org.signal.registration.RegistrationFlowEvent import org.signal.registration.RegistrationRepository +import org.signal.registration.proto.RestoreDecision @OptIn(ExperimentalCoroutinesApi::class) class DeviceTransferCompleteViewModelTest { @@ -57,6 +58,7 @@ class DeviceTransferCompleteViewModelTest { stateEmitter ) + coVerify { mockRepository.setRestoreDecision(RestoreDecision.COMPLETED) } coVerify { mockRepository.finishRegistrationOrCreateProfile(parentEventEmitter, any()) } } } diff --git a/feature/registration/src/test/java/org/signal/registration/screens/localbackuprestore/LocalBackupRestoreViewModelTest.kt b/feature/registration/src/test/java/org/signal/registration/screens/localbackuprestore/LocalBackupRestoreViewModelTest.kt index 0baff0dad4..8f1b0ebab8 100644 --- a/feature/registration/src/test/java/org/signal/registration/screens/localbackuprestore/LocalBackupRestoreViewModelTest.kt +++ b/feature/registration/src/test/java/org/signal/registration/screens/localbackuprestore/LocalBackupRestoreViewModelTest.kt @@ -15,18 +15,32 @@ import assertk.assertions.isNotNull import assertk.assertions.isNull import assertk.assertions.isTrue import assertk.assertions.prop +import io.mockk.coVerify +import io.mockk.every import io.mockk.mockk +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.flow.flowOf +import kotlinx.coroutines.test.UnconfinedTestDispatcher +import kotlinx.coroutines.test.resetMain import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.test.setMain +import org.junit.After import org.junit.Before import org.junit.Test +import org.signal.archive.LocalBackupRestoreProgress import org.signal.core.ui.navigation.ResultEventBus import org.signal.registration.RegistrationFlowEvent import org.signal.registration.RegistrationRepository import org.signal.registration.RegistrationRoute +import org.signal.registration.proto.RestoreDecision import java.time.LocalDateTime +@OptIn(ExperimentalCoroutinesApi::class) class LocalBackupRestoreViewModelTest { + private val testDispatcher = UnconfinedTestDispatcher() + private lateinit var mockRepository: RegistrationRepository private lateinit var resultBus: ResultEventBus private lateinit var emittedParentEvents: MutableList @@ -38,6 +52,7 @@ class LocalBackupRestoreViewModelTest { @Before fun setup() { + Dispatchers.setMain(testDispatcher) mockRepository = mockk(relaxed = true) resultBus = ResultEventBus() emittedParentEvents = mutableListOf() @@ -46,6 +61,11 @@ class LocalBackupRestoreViewModelTest { stateEmitter = { state -> emittedStates.add(state) } } + @After + fun tearDown() { + Dispatchers.resetMain() + } + private fun createViewModel(isPreRegistration: Boolean): LocalBackupRestoreViewModel { return LocalBackupRestoreViewModel( repository = mockRepository, @@ -227,4 +247,25 @@ class LocalBackupRestoreViewModelTest { assertThat(emittedParentEvents).hasSize(1) assertThat(emittedParentEvents.first()).isEqualTo(RegistrationFlowEvent.NavigateBack) } + + // ==================== Restore Completion Tests ==================== + + @Test + fun `successful V1 restore records COMPLETED restore decision and finishes registration`() = runTest(testDispatcher) { + val viewModel = createViewModel(isPreRegistration = false) + val backupInfo = LocalBackupInfo( + type = LocalBackupInfo.BackupType.V1, + date = LocalDateTime.now(), + name = "backup.backup", + uri = mockk() + ) + val initialState = LocalBackupRestoreState(backupInfo = backupInfo) + + every { mockRepository.restoreV1Backup(any(), any()) } returns flowOf(LocalBackupRestoreProgress.Complete) + + viewModel.applyEvent(initialState, LocalBackupRestoreEvents.PassphraseSubmitted("passphrase"), stateEmitter) + + coVerify { mockRepository.setRestoreDecision(RestoreDecision.COMPLETED) } + coVerify { mockRepository.finishRegistrationOrCreateProfile(parentEventEmitter, any()) } + } } diff --git a/feature/registration/src/test/java/org/signal/registration/screens/pincreation/PinCreationViewModelTest.kt b/feature/registration/src/test/java/org/signal/registration/screens/pincreation/PinCreationViewModelTest.kt index 48fff9a9ae..4a2e3b2d76 100644 --- a/feature/registration/src/test/java/org/signal/registration/screens/pincreation/PinCreationViewModelTest.kt +++ b/feature/registration/src/test/java/org/signal/registration/screens/pincreation/PinCreationViewModelTest.kt @@ -29,6 +29,7 @@ import org.signal.registration.NetworkController import org.signal.registration.RegistrationFlowEvent import org.signal.registration.RegistrationFlowState import org.signal.registration.RegistrationRepository +import org.signal.registration.proto.RestoreDecision @OptIn(ExperimentalCoroutinesApi::class) class PinCreationViewModelTest { @@ -72,6 +73,7 @@ class PinCreationViewModelTest { viewModel.applyEvent(initialState, PinCreationScreenEvents.PinSubmitted("123456")) + coVerify { mockRepository.setRestoreDecision(RestoreDecision.NEW_ACCOUNT) } coVerify { mockRepository.finishRegistrationOrCreateProfile(parentEventEmitter, any()) } } @@ -112,6 +114,7 @@ class PinCreationViewModelTest { viewModel.applyEvent(initialState, PinCreationScreenEvents.OptOut) coVerify { mockRepository.setPinOptedOut() } + coVerify { mockRepository.setRestoreDecision(RestoreDecision.NEW_ACCOUNT) } assertThat(emittedParentEvents).hasSize(1) assertThat(emittedParentEvents.first()).isEqualTo(RegistrationFlowEvent.RegistrationComplete) } diff --git a/feature/registration/src/test/java/org/signal/registration/screens/pinentry/PinEntryForSvrRestoreViewModelTest.kt b/feature/registration/src/test/java/org/signal/registration/screens/pinentry/PinEntryForSvrRestoreViewModelTest.kt index 54c6abe65a..5547676554 100644 --- a/feature/registration/src/test/java/org/signal/registration/screens/pinentry/PinEntryForSvrRestoreViewModelTest.kt +++ b/feature/registration/src/test/java/org/signal/registration/screens/pinentry/PinEntryForSvrRestoreViewModelTest.kt @@ -24,6 +24,7 @@ import org.signal.registration.RegistrationFlowEvent import org.signal.registration.RegistrationFlowState import org.signal.registration.RegistrationRepository import org.signal.registration.RegistrationRoute +import org.signal.registration.proto.RestoreDecision class PinEntryForSvrRestoreViewModelTest { @@ -75,6 +76,7 @@ class PinEntryForSvrRestoreViewModelTest { assertThat(emittedParentEvents).hasSize(1) assertThat(emittedParentEvents[0]).isInstanceOf() + coVerify { mockRepository.setRestoreDecision(RestoreDecision.COMPLETED) } coVerify { mockRepository.finishRegistrationOrCreateProfile(parentEventEmitter, any()) } } @@ -231,6 +233,7 @@ class PinEntryForSvrRestoreViewModelTest { viewModel.applyEvent(initialState, PinEntryScreenEvents.Skip, parentEventEmitter, stateEmitter) coVerify { mockRepository.setPinOptedOut() } + coVerify { mockRepository.setRestoreDecision(RestoreDecision.SKIPPED) } assertThat(emittedParentEvents).hasSize(1) assertThat(emittedParentEvents.first()).isEqualTo(RegistrationFlowEvent.RegistrationComplete) } diff --git a/feature/registration/src/test/java/org/signal/registration/screens/remotebackuprestore/RemoteBackupRestoreViewModelTest.kt b/feature/registration/src/test/java/org/signal/registration/screens/remotebackuprestore/RemoteBackupRestoreViewModelTest.kt index 9df23a9d9e..ee9bec012a 100644 --- a/feature/registration/src/test/java/org/signal/registration/screens/remotebackuprestore/RemoteBackupRestoreViewModelTest.kt +++ b/feature/registration/src/test/java/org/signal/registration/screens/remotebackuprestore/RemoteBackupRestoreViewModelTest.kt @@ -30,6 +30,7 @@ import org.signal.libsignal.net.RequestResult import org.signal.registration.NetworkController import org.signal.registration.RegistrationFlowEvent import org.signal.registration.RegistrationRepository +import org.signal.registration.proto.RestoreDecision @OptIn(ExperimentalCoroutinesApi::class) class RemoteBackupRestoreViewModelTest { @@ -173,6 +174,7 @@ class RemoteBackupRestoreViewModelTest { assertThat(emittedParentEvents).hasSize(1) assertThat(emittedParentEvents[0]).isInstanceOf() + coVerify { mockRepository.setRestoreDecision(RestoreDecision.COMPLETED) } coVerify { mockRepository.finishRegistrationOrCreateProfile(parentEventEmitter, any()) } } diff --git a/feature/registration/src/test/java/org/signal/registration/screens/restoreselection/ArchiveRestoreSelectionViewModelTest.kt b/feature/registration/src/test/java/org/signal/registration/screens/restoreselection/ArchiveRestoreSelectionViewModelTest.kt index f1943176e3..d6d9693035 100644 --- a/feature/registration/src/test/java/org/signal/registration/screens/restoreselection/ArchiveRestoreSelectionViewModelTest.kt +++ b/feature/registration/src/test/java/org/signal/registration/screens/restoreselection/ArchiveRestoreSelectionViewModelTest.kt @@ -12,6 +12,7 @@ import assertk.assertions.isFalse import assertk.assertions.isInstanceOf import assertk.assertions.isTrue import assertk.assertions.prop +import io.mockk.coVerify import io.mockk.mockk import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.test.runTest @@ -22,9 +23,11 @@ import org.signal.registration.RegistrationFlowEvent import org.signal.registration.RegistrationFlowState import org.signal.registration.RegistrationRepository import org.signal.registration.RegistrationRoute +import org.signal.registration.proto.RestoreDecision class ArchiveRestoreSelectionViewModelTest { + private lateinit var mockRepository: RegistrationRepository private lateinit var emittedParentEvents: MutableList private lateinit var parentEventEmitter: (RegistrationFlowEvent) -> Unit private lateinit var emittedStates: MutableList @@ -32,6 +35,7 @@ class ArchiveRestoreSelectionViewModelTest { @Before fun setup() { + mockRepository = mockk(relaxed = true) emittedParentEvents = mutableListOf() parentEventEmitter = { event -> emittedParentEvents.add(event) } emittedStates = mutableListOf() @@ -49,7 +53,7 @@ class ArchiveRestoreSelectionViewModelTest { return ArchiveRestoreSelectionViewModel( restoreOptions = restoreOptions, isPreRegistration = isPreRegistration, - repository = mockk(relaxed = true), + repository = mockRepository, parentState = MutableStateFlow(RegistrationFlowState()), parentEventEmitter = parentEventEmitter ) @@ -192,6 +196,7 @@ class ArchiveRestoreSelectionViewModelTest { viewModel.applyEvent(initialState, ArchiveRestoreSelectionScreenEvents.ConfirmSkip, stateEmitter) + coVerify { mockRepository.setRestoreDecision(RestoreDecision.SKIPPED) } assertThat(emittedParentEvents).hasSize(1) assertThat(emittedParentEvents.first()) .isInstanceOf()