Improve regV5 restore flows.

This commit is contained in:
Greyson Parrelli
2026-04-15 14:12:28 -04:00
committed by jeffrey-signal
parent f1b61f8f7e
commit fcdbf93626
19 changed files with 331 additions and 164 deletions
@@ -57,6 +57,7 @@ import org.thoughtcrime.securesms.registration.fcm.PushChallengeRequest
import org.thoughtcrime.securesms.registration.viewmodel.SvrAuthCredentialSet
import org.whispersystems.signalservice.api.NetworkResult
import org.whispersystems.signalservice.api.SvrNoDataException
import org.whispersystems.signalservice.api.archive.ArchiveServiceAccess
import org.whispersystems.signalservice.api.provisioning.ProvisioningSocket
import org.whispersystems.signalservice.api.svr.SecureValueRecovery.BackupResponse
import org.whispersystems.signalservice.internal.crypto.SecondaryProvisioningCipher
@@ -65,6 +66,7 @@ import org.whispersystems.signalservice.internal.push.PushServiceSocket
import java.io.IOException
import java.util.Locale
import kotlin.time.Duration
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.Duration.Companion.seconds
import org.whispersystems.signalservice.api.account.AccountAttributes as ServiceAccountAttributes
import org.whispersystems.signalservice.api.account.PreKeyCollection as ServicePreKeyCollection
@@ -536,6 +538,57 @@ class AppRegistrationNetworkController(
}
}
override suspend fun getRemoteBackupInfo(): RequestResult<NetworkController.GetBackupInfoResponse, NetworkController.GetBackupInfoError> = withContext(Dispatchers.IO) {
val aci = SignalStore.account.aci ?: return@withContext RequestResult.ApplicationError(IllegalStateException("ACI not available"))
val currentTime = System.currentTimeMillis()
val messageCredential = SignalStore.backup.messageCredentials.byDay.getForCurrentTime(currentTime.milliseconds)
val access = if (messageCredential != null) {
ArchiveServiceAccess(messageCredential, SignalStore.backup.messageBackupKey)
} else {
when (val credResult = SignalNetwork.archive.getServiceCredentials(currentTime)) {
is NetworkResult.Success -> {
SignalStore.backup.messageCredentials.add(credResult.result.messageCredentials)
SignalStore.backup.messageCredentials.clearOlderThan(currentTime)
val credential = SignalStore.backup.messageCredentials.byDay.getForCurrentTime(currentTime.milliseconds)
?: return@withContext RequestResult.ApplicationError(IllegalStateException("Failed to obtain backup credentials after fetch"))
ArchiveServiceAccess(credential, SignalStore.backup.messageBackupKey)
}
is NetworkResult.StatusCodeError -> return@withContext RequestResult.ApplicationError(IllegalStateException("Failed to fetch backup credentials: ${credResult.code}"))
is NetworkResult.NetworkError -> return@withContext RequestResult.RetryableNetworkError(credResult.exception)
is NetworkResult.ApplicationError -> return@withContext RequestResult.ApplicationError(credResult.throwable)
}
}
when (val result = SignalNetwork.archive.getBackupInfo(aci, access)) {
is NetworkResult.Success -> {
val info = result.result
RequestResult.Success(
NetworkController.GetBackupInfoResponse(
cdn = info.cdn,
backupDir = info.backupDir,
mediaDir = info.mediaDir,
backupName = info.backupName,
usedSpace = info.usedSpace
)
)
}
is NetworkResult.StatusCodeError -> {
when (result.code) {
400 -> RequestResult.NonSuccess(NetworkController.GetBackupInfoError.BadArguments(result.stringBody))
401 -> RequestResult.NonSuccess(NetworkController.GetBackupInfoError.BadAuthCredential(result.stringBody))
403 -> RequestResult.NonSuccess(NetworkController.GetBackupInfoError.Forbidden(result.stringBody))
404 -> RequestResult.NonSuccess(NetworkController.GetBackupInfoError.NoBackup)
429 -> RequestResult.NonSuccess(NetworkController.GetBackupInfoError.RateLimited(0.seconds))
else -> RequestResult.ApplicationError(IllegalStateException("Unexpected response code: ${result.code}"))
}
}
is NetworkResult.NetworkError -> RequestResult.RetryableNetworkError(result.exception)
is NetworkResult.ApplicationError -> RequestResult.ApplicationError(result.throwable)
}
}
override fun startProvisioning(): Flow<ProvisioningEvent> = callbackFlow {
val socketHandles = mutableListOf<java.io.Closeable>()
val configuration = AppDependencies.signalServiceNetworkAccess.getConfiguration()
@@ -22,7 +22,6 @@ import org.signal.registration.PreExistingRegistrationData
import org.signal.registration.StorageController
import org.signal.registration.proto.RegistrationData
import org.signal.registration.screens.localbackuprestore.LocalBackupInfo
import org.signal.registration.screens.restoreselection.ArchiveRestoreOption
import org.thoughtcrime.securesms.backup.FullBackupImporter
import org.thoughtcrime.securesms.backup.v2.BackupRepository
import org.thoughtcrime.securesms.backup.v2.local.LocalArchiver
@@ -178,16 +177,6 @@ class AppRegistrationStorageController(private val context: Context) : StorageCo
Unit
}
override suspend fun getAvailableRestoreOptions(): Set<ArchiveRestoreOption> = withContext(Dispatchers.IO) {
// TODO [greyson] Real options
val options = mutableSetOf<ArchiveRestoreOption>()
options.add(ArchiveRestoreOption.LocalBackup)
options.add(ArchiveRestoreOption.DeviceTransfer)
options
}
override fun restoreLocalBackupV1(uri: Uri, passphrase: String): Flow<LocalBackupRestoreProgress> = flow {
// TODO [greyson] better progress
Log.d(TAG, "Starting V1 local backup restore from: $uri")
@@ -15,6 +15,8 @@ import org.signal.registration.NetworkController.BackupMasterKeyError
import org.signal.registration.NetworkController.CheckSvrCredentialsError
import org.signal.registration.NetworkController.CheckSvrCredentialsResponse
import org.signal.registration.NetworkController.CreateSessionError
import org.signal.registration.NetworkController.GetBackupInfoError
import org.signal.registration.NetworkController.GetBackupInfoResponse
import org.signal.registration.NetworkController.GetSessionStatusError
import org.signal.registration.NetworkController.GetSvrCredentialsError
import org.signal.registration.NetworkController.MasterKeyResponse
@@ -217,4 +219,12 @@ class DebugNetworkController(
return delegate.checkSvrCredentials(e164, credentials)
}
override suspend fun getRemoteBackupInfo(): RequestResult<GetBackupInfoResponse, GetBackupInfoError> {
NetworkDebugState.getOverride<RequestResult<GetBackupInfoResponse, GetBackupInfoError>>("getRemoteBackupInfo")?.let {
Log.d(TAG, "[getRemoteBackupInfo] Returning debug override")
return it
}
return delegate.getRemoteBackupInfo()
}
}
@@ -13,11 +13,13 @@ import kotlinx.coroutines.flow.callbackFlow
import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json
import okhttp3.MediaType.Companion.toMediaType
import okhttp3.RequestBody.Companion.toRequestBody
import okhttp3.Response
import org.signal.core.models.MasterKey
import org.signal.core.util.Base64
import org.signal.core.util.logging.Log
import org.signal.libsignal.net.Network
import org.signal.libsignal.net.RequestResult
@@ -25,6 +27,9 @@ import org.signal.libsignal.protocol.IdentityKey
import org.signal.libsignal.protocol.IdentityKeyPair
import org.signal.libsignal.protocol.ecc.ECPrivateKey
import org.signal.libsignal.protocol.util.Hex
import org.signal.libsignal.zkgroup.GenericServerPublicParams
import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialRequestContext
import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialResponse
import org.signal.registration.NetworkController
import org.signal.registration.NetworkController.AccountAttributes
import org.signal.registration.NetworkController.CheckSvrCredentialsRequest
@@ -62,8 +67,11 @@ import org.whispersystems.signalservice.internal.push.PushServiceSocket
import org.whispersystems.signalservice.internal.util.StaticCredentialsProvider
import org.whispersystems.signalservice.internal.websocket.LibSignalChatConnection
import java.io.IOException
import java.time.Instant
import java.util.Locale
import kotlin.time.Duration
import kotlin.time.Duration.Companion.days
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.Duration.Companion.seconds
import org.whispersystems.signalservice.api.account.AccountAttributes as ServiceAccountAttributes
import org.whispersystems.signalservice.api.account.PreKeyCollection as ServicePreKeyCollection
@@ -843,6 +851,126 @@ class DemoNetworkController(
}
}
override suspend fun getRemoteBackupInfo(): RequestResult<NetworkController.GetBackupInfoResponse, NetworkController.GetBackupInfoError> = withContext(Dispatchers.IO) {
val aci = RegistrationPreferences.aci
val password = RegistrationPreferences.servicePassword
val aep = RegistrationPreferences.aep
if (aci == null || password == null || aep == null) {
Log.w(TAG, "[getRemoteBackupInfo] Credentials not available")
return@withContext RequestResult.ApplicationError(IllegalStateException("Credentials not available"))
}
try {
val messageBackupKey = aep.deriveMessageBackupKey()
// Remember, this is a demo app
val credential = fetchArchiveServiceCredential(aci.toString(), password)
?: return@withContext RequestResult.RetryableNetworkError(IOException("Failed to fetch archive credentials"))
val headers = buildZkAuthHeaders(messageBackupKey, aci, credential)
val baseUrl = serviceConfiguration.signalServiceUrls[0].url
val request = okhttp3.Request.Builder()
.url("$baseUrl/v1/archives")
.get()
.apply { headers.forEach { (k, v) -> header(k, v) } }
.build()
okHttpClient.newCall(request).execute().use { response ->
when (response.code) {
200 -> {
val info = json.decodeFromString<NetworkController.GetBackupInfoResponse>(response.body.string())
RequestResult.Success(info)
}
400 -> RequestResult.NonSuccess(NetworkController.GetBackupInfoError.BadArguments(response.body.string()))
401 -> RequestResult.NonSuccess(NetworkController.GetBackupInfoError.BadAuthCredential(response.body.string()))
403 -> RequestResult.NonSuccess(NetworkController.GetBackupInfoError.Forbidden(response.body.string()))
404 -> RequestResult.NonSuccess(NetworkController.GetBackupInfoError.NoBackup)
429 -> RequestResult.NonSuccess(NetworkController.GetBackupInfoError.RateLimited(response.retryAfter()))
else -> RequestResult.ApplicationError(IllegalStateException("Unexpected response code: ${response.code}, body: ${response.body?.string()}"))
}
}
} catch (e: IOException) {
Log.w(TAG, "[getRemoteBackupInfo] IOException", e)
RequestResult.RetryableNetworkError(e)
} catch (e: Exception) {
Log.w(TAG, "[getRemoteBackupInfo] Exception", e)
RequestResult.ApplicationError(e)
}
}
/**
* Fetches an archive service credential for today by calling GET /v1/archives/auth on the authenticated channel.
*/
private fun fetchArchiveServiceCredential(aci: String, password: String): ArchiveCredential? {
val currentTime = System.currentTimeMillis()
val roundedToNearestDay = currentTime.milliseconds.inWholeDays.days
val endTime = roundedToNearestDay + 7.days
val startSeconds = roundedToNearestDay.inWholeSeconds
val endSeconds = endTime.inWholeSeconds
val credentials = okhttp3.Credentials.basic(aci, password)
val baseUrl = serviceConfiguration.signalServiceUrls[0].url
val request = okhttp3.Request.Builder()
.url("$baseUrl/v1/archives/auth?redemptionStartSeconds=$startSeconds&redemptionEndSeconds=$endSeconds")
.get()
.header("Authorization", credentials)
.build()
okHttpClient.newCall(request).execute().use { response ->
if (response.code != 200) {
Log.w(TAG, "[fetchArchiveServiceCredential] Unexpected response code: ${response.code}")
return null
}
val body = response.body.string()
val parsed = json.decodeFromString<ArchiveCredentialsResponse>(body)
val todaySeconds = roundedToNearestDay.inWholeSeconds
return parsed.credentials["messages"]?.firstOrNull { it.redemptionTime == todaySeconds }
}
}
/**
* Builds the ZK auth headers (X-Signal-ZK-Auth, X-Signal-ZK-Auth-Signature) needed for
* anonymous archive requests.
*/
private fun buildZkAuthHeaders(
messageBackupKey: org.signal.core.models.backup.MessageBackupKey,
aci: org.signal.core.models.ServiceId.ACI,
credential: ArchiveCredential
): Map<String, String> {
val backupServerPublicParams = GenericServerPublicParams(serviceConfiguration.backupServerPublicParams)
val backupRequestContext = BackupAuthCredentialRequestContext.create(messageBackupKey.value, aci.rawUuid)
val backupAuthCredentialResponse = BackupAuthCredentialResponse(Base64.decode(credential.credential))
val backupAuthCredential = backupRequestContext.receiveResponse(
backupAuthCredentialResponse,
Instant.ofEpochSecond(credential.redemptionTime),
backupServerPublicParams
)
val presentation = backupAuthCredential.present(backupServerPublicParams).serialize()
val privateKey = messageBackupKey.deriveAnonymousCredentialPrivateKey(aci)
val signedPresentation = privateKey.calculateSignature(presentation)
return mapOf(
"X-Signal-ZK-Auth" to Base64.encodeWithPadding(presentation),
"X-Signal-ZK-Auth-Signature" to Base64.encodeWithPadding(signedPresentation)
)
}
@Serializable
private data class ArchiveCredentialsResponse(
val credentials: Map<String, List<ArchiveCredential>>
)
@Serializable
private data class ArchiveCredential(
val credential: String,
val redemptionTime: Long
)
private fun AccountAttributes.toServiceAccountAttributes(): ServiceAccountAttributes {
return ServiceAccountAttributes(
signalingKey,
@@ -32,7 +32,6 @@ import org.signal.registration.proto.RegistrationData
import org.signal.registration.sample.storage.RegistrationDatabase
import org.signal.registration.sample.storage.RegistrationPreferences
import org.signal.registration.screens.localbackuprestore.LocalBackupInfo
import org.signal.registration.screens.restoreselection.ArchiveRestoreOption
import java.io.File
import java.time.LocalDateTime
@@ -181,14 +180,6 @@ class DemoStorageController(private val context: Context) : StorageController {
Unit
}
override suspend fun getAvailableRestoreOptions(): Set<ArchiveRestoreOption> {
return setOf(
ArchiveRestoreOption.SignalSecureBackup,
ArchiveRestoreOption.LocalBackup,
ArchiveRestoreOption.DeviceTransfer
)
}
override suspend fun scanLocalBackupFolder(folderUri: Uri): List<LocalBackupInfo> = withContext(Dispatchers.IO) {
val folder = DocumentFile.fromTreeUri(context, folderUri) ?: return@withContext emptyList()
val children = folder.listFiles()
@@ -182,6 +182,21 @@ interface NetworkController {
*/
suspend fun setAccountAttributes(attributes: AccountAttributes): RequestResult<Unit, SetAccountAttributesError>
/**
* Fetches metadata about your current backup. This will be different for different key/credential pairs. For example, message credentials will always
* return 0 for used space since that is stored under the media key/credential.
*
* GET /v1/archives
* - 200: Success
* - 400: Bad arguments. The request may have been made on an authenticated channel.
* - 401: The provided backup auth credential presentation could not be verified or the public key signature was invalid or there is no backup associated with
* the backup-id in the presentation or the credential was of the wrong type (messages/media)
* - 403: Forbidden
* - 404: No backup
* - 429: Rate limited
*/
suspend fun getRemoteBackupInfo(): RequestResult<GetBackupInfoResponse, GetBackupInfoError>
/**
* Starts a provisioning session for QR-based quick restore.
*
@@ -213,24 +228,24 @@ interface NetworkController {
// */
// suspend fun registerAsSecondaryDevice(verificationCode: String, attributes: AccountAttributes, aciPreKeys: PreKeyCollection, pniPreKeys: PreKeyCollection, fcmToken: String?)
sealed class CreateSessionError() : BadRequestError {
sealed class CreateSessionError : BadRequestError {
data class InvalidRequest(val message: String) : CreateSessionError()
data class RateLimited(val retryAfter: Duration) : CreateSessionError()
}
sealed class GetSessionStatusError() : BadRequestError {
sealed class GetSessionStatusError : BadRequestError {
data class InvalidSessionId(val message: String) : GetSessionStatusError()
data class SessionNotFound(val message: String) : GetSessionStatusError()
data class InvalidRequest(val message: String) : GetSessionStatusError()
}
sealed class UpdateSessionError() : BadRequestError {
sealed class UpdateSessionError : BadRequestError {
data class RejectedUpdate(val message: String) : UpdateSessionError()
data class InvalidRequest(val message: String) : UpdateSessionError()
data class RateLimited(val retryAfter: Duration, val session: SessionMetadata) : UpdateSessionError()
}
sealed class RequestVerificationCodeError() : BadRequestError {
sealed class RequestVerificationCodeError : BadRequestError {
data class InvalidSessionId(val message: String) : RequestVerificationCodeError()
data class SessionNotFound(val message: String) : RequestVerificationCodeError()
data class MissingRequestInformationOrAlreadyVerified(val session: SessionMetadata) : RequestVerificationCodeError()
@@ -240,14 +255,14 @@ interface NetworkController {
data class ThirdPartyServiceError(val data: ThirdPartyServiceErrorResponse) : RequestVerificationCodeError()
}
sealed class SubmitVerificationCodeError() : BadRequestError {
sealed class SubmitVerificationCodeError : BadRequestError {
data class InvalidSessionIdOrVerificationCode(val message: String) : SubmitVerificationCodeError()
data class SessionNotFound(val message: String) : SubmitVerificationCodeError()
data class SessionAlreadyVerifiedOrNoCodeRequested(val session: SessionMetadata) : SubmitVerificationCodeError()
data class RateLimited(val retryAfter: Duration, val session: SessionMetadata) : SubmitVerificationCodeError()
}
sealed class RegisterAccountError() : BadRequestError {
sealed class RegisterAccountError : BadRequestError {
data class SessionNotFoundOrNotVerified(val message: String) : RegisterAccountError()
data class RegistrationRecoveryPasswordIncorrect(val message: String) : RegisterAccountError()
data object DeviceTransferPossible : RegisterAccountError()
@@ -256,38 +271,46 @@ interface NetworkController {
data class RateLimited(val retryAfter: Duration) : RegisterAccountError()
}
sealed class RestoreMasterKeyError() : BadRequestError {
sealed class RestoreMasterKeyError : BadRequestError {
data class WrongPin(val triesRemaining: Int) : RestoreMasterKeyError()
data object NoDataFound : RestoreMasterKeyError()
}
sealed class BackupMasterKeyError() : BadRequestError {
sealed class BackupMasterKeyError : BadRequestError {
data object EnclaveNotFound : BackupMasterKeyError()
data object NotRegistered : BackupMasterKeyError()
}
sealed class SetRegistrationLockError() : BadRequestError {
sealed class SetRegistrationLockError : BadRequestError {
data class InvalidRequest(val message: String) : SetRegistrationLockError()
data object Unauthorized : SetRegistrationLockError()
data object NotRegistered : SetRegistrationLockError()
data object NoPinSet : SetRegistrationLockError()
}
sealed class SetAccountAttributesError() : BadRequestError {
sealed class SetAccountAttributesError : BadRequestError {
data class InvalidRequest(val message: String) : SetAccountAttributesError()
data object Unauthorized : SetAccountAttributesError()
}
sealed class GetSvrCredentialsError() : BadRequestError {
sealed class GetSvrCredentialsError : BadRequestError {
data object Unauthorized : GetSvrCredentialsError()
data object NoServiceCredentialsAvailable : GetSvrCredentialsError()
}
sealed class CheckSvrCredentialsError() : BadRequestError {
sealed class CheckSvrCredentialsError : BadRequestError {
data object Unauthorized : CheckSvrCredentialsError()
data class InvalidRequest(val message: String) : CheckSvrCredentialsError()
}
sealed class GetBackupInfoError : BadRequestError {
data class BadArguments(val body: String? = null) : GetBackupInfoError()
data class BadAuthCredential(val body: String? = null) : GetBackupInfoError()
data class Forbidden(val body: String? = null) : GetBackupInfoError()
data object NoBackup : GetBackupInfoError()
data class RateLimited(val retryAfter: Duration) : GetBackupInfoError()
}
data class MasterKeyResponse(
val masterKey: MasterKey
)
@@ -432,6 +455,15 @@ interface NetworkController {
SMS, VOICE
}
@Serializable
data class GetBackupInfoResponse(
val cdn: Int?,
val backupDir: String?,
val mediaDir: String?,
val backupName: String?,
val usedSpace: Long?
)
/**
* Data received from the old device during QR-based provisioning.
*/
@@ -60,6 +60,7 @@ import org.signal.registration.screens.pinentry.PinEntryForSvrRestoreViewModel
import org.signal.registration.screens.pinentry.PinEntryScreen
import org.signal.registration.screens.quickrestore.QuickRestoreQrScreen
import org.signal.registration.screens.quickrestore.QuickRestoreQrViewModel
import org.signal.registration.screens.restoreselection.ArchiveRestoreOption
import org.signal.registration.screens.restoreselection.ArchiveRestoreSelectionScreen
import org.signal.registration.screens.restoreselection.ArchiveRestoreSelectionViewModel
import org.signal.registration.screens.util.navigateBack
@@ -113,13 +114,42 @@ sealed interface RegistrationRoute : NavKey, Parcelable {
data object PinCreate : RegistrationRoute
@Serializable
data object ArchiveRestoreSelection : RegistrationRoute
data class ArchiveRestoreSelection(val restoreOptions: List<ArchiveRestoreOption>) : RegistrationRoute {
companion object {
fun forQuickRestore(hasRemoteBackup: Boolean): ArchiveRestoreSelection {
return ArchiveRestoreSelection(
buildList {
if (hasRemoteBackup) {
add(ArchiveRestoreOption.SignalSecureBackup)
}
add(ArchiveRestoreOption.LocalBackup)
add(ArchiveRestoreOption.DeviceTransfer)
}
)
}
@Serializable
data object ChooseRestoreOptionBeforeRegistration : RegistrationRoute
fun forManualRestore(): ArchiveRestoreSelection {
return ArchiveRestoreSelection(
buildList {
add(ArchiveRestoreOption.SignalSecureBackup)
add(ArchiveRestoreOption.LocalBackup)
add(ArchiveRestoreOption.DeviceTransfer)
}
)
}
@Serializable
data object ChooseRestoreOptionAfterRegistration : RegistrationRoute
fun forPostRegister(): ArchiveRestoreSelection {
return ArchiveRestoreSelection(
buildList {
add(ArchiveRestoreOption.SignalSecureBackup)
add(ArchiveRestoreOption.LocalBackup)
add(ArchiveRestoreOption.DeviceTransfer)
add(ArchiveRestoreOption.None)
}
)
}
}
}
@Serializable
data class LocalBackupRestore(val isPreRegistration: Boolean) : RegistrationRoute
@@ -248,9 +278,9 @@ private fun EntryProviderScope<NavKey>.navigationEntries(
onEvent = { event ->
when (event) {
WelcomeScreenEvents.Continue -> parentEventEmitter.navigateTo(RegistrationRoute.Permissions(nextRoute = RegistrationRoute.PhoneNumberEntry))
WelcomeScreenEvents.LinkDevice -> parentEventEmitter.navigateTo(RegistrationRoute.Permissions(nextRoute = RegistrationRoute.QuickRestoreQrScan)) // TODO - Replace this with the device-link QR code
WelcomeScreenEvents.LinkDevice -> throw NotImplementedError("Haven't implemented linked devices yet")
WelcomeScreenEvents.HasOldPhone -> parentEventEmitter.navigateTo(RegistrationRoute.Permissions(nextRoute = RegistrationRoute.QuickRestoreQrScan))
WelcomeScreenEvents.DoesNotHaveOldPhone -> parentEventEmitter.navigateTo(RegistrationRoute.Permissions(nextRoute = RegistrationRoute.ChooseRestoreOptionBeforeRegistration))
WelcomeScreenEvents.DoesNotHaveOldPhone -> parentEventEmitter.navigateTo(RegistrationRoute.Permissions(nextRoute = RegistrationRoute.ArchiveRestoreSelection.forManualRestore()))
}
}
)
@@ -447,14 +477,12 @@ private fun EntryProviderScope<NavKey>.navigationEntries(
)
}
// -- Archive Restore Selection Screen
entry<RegistrationRoute.ArchiveRestoreSelection> {
// -- Archive Restore Selection for Quick Restore Screen
entry<RegistrationRoute.ArchiveRestoreSelection> { key ->
val viewModel: ArchiveRestoreSelectionViewModel = viewModel(
factory = ArchiveRestoreSelectionViewModel.Factory(
repository = registrationRepository,
parentState = registrationViewModel.state,
parentEventEmitter = registrationViewModel::onEvent,
isPreRegistration = false
restoreOptions = key.restoreOptions,
parentEventEmitter = registrationViewModel::onEvent
)
)
val state by viewModel.state.collectAsStateWithLifecycle()
@@ -465,28 +493,6 @@ private fun EntryProviderScope<NavKey>.navigationEntries(
)
}
// -- Choose Restore Option Before Registration (saves selection, then navigates to phone number entry)
entry<RegistrationRoute.ChooseRestoreOptionBeforeRegistration> {
val viewModel: ArchiveRestoreSelectionViewModel = viewModel(
factory = ArchiveRestoreSelectionViewModel.Factory(
repository = registrationRepository,
parentState = registrationViewModel.state,
parentEventEmitter = registrationViewModel::onEvent,
isPreRegistration = true
)
)
val state by viewModel.state.collectAsStateWithLifecycle()
ArchiveRestoreSelectionScreen(
state = state,
onEvent = { viewModel.onEvent(it) }
)
}
entry<RegistrationRoute.ChooseRestoreOptionAfterRegistration> {
TODO("Implement RestoreScreen")
}
// -- Local Backup Restore Screen
entry<RegistrationRoute.LocalBackupRestore> { key ->
val viewModel: LocalBackupRestoreViewModel = viewModel(
@@ -526,6 +532,8 @@ private fun EntryProviderScope<NavKey>.navigationEntries(
)
}
// TODO I think we can re-use the screen but attach different viewmodels to progress forward rather than do for-result flows?
// -- Enter AEP
entry<RegistrationRoute.EnterAepScreen> {
val viewModel: EnterAepViewModel = viewModel(
@@ -43,7 +43,6 @@ import org.signal.registration.NetworkController.UpdateSessionError
import org.signal.registration.proto.ProvisioningData
import org.signal.registration.proto.SvrCredential
import org.signal.registration.screens.localbackuprestore.LocalBackupInfo
import org.signal.registration.screens.restoreselection.ArchiveRestoreOption
import org.signal.registration.util.SensitiveLog
import java.security.SecureRandom
import java.util.Locale
@@ -510,10 +509,6 @@ class RegistrationRepository(val context: Context, val networkController: Networ
data.aci.isNotEmpty() && data.pni.isNotEmpty()
}
suspend fun getAvailableRestoreOptions(): Set<ArchiveRestoreOption> = withContext(Dispatchers.IO) {
storageController.getAvailableRestoreOptions()
}
fun restoreV1Backup(uri: Uri, passphrase: String): Flow<LocalBackupRestoreProgress> {
return storageController.restoreLocalBackupV1(uri, passphrase)
}
@@ -19,7 +19,6 @@ import org.signal.libsignal.protocol.state.KyberPreKeyRecord
import org.signal.libsignal.protocol.state.SignedPreKeyRecord
import org.signal.registration.proto.RegistrationData
import org.signal.registration.screens.localbackuprestore.LocalBackupInfo
import org.signal.registration.screens.restoreselection.ArchiveRestoreOption
import org.signal.registration.util.ACIParceler
import org.signal.registration.util.AccountEntropyPoolParceler
import org.signal.registration.util.IdentityKeyPairParceler
@@ -83,12 +82,6 @@ interface StorageController {
*/
suspend fun commitRegistrationData()
/**
* Returns the set of restore options that are currently available to the user.
* For example, if a local backup file is present on the device, [ArchiveRestoreOption.LocalBackup] should be included.
*/
suspend fun getAvailableRestoreOptions(): Set<ArchiveRestoreOption>
/**
* Begins restoring from a V1 (.backup) file identified by the given [uri].
*
@@ -33,6 +33,7 @@ abstract class EventDrivenViewModel<E : DebugLoggable>(
}
fun onEvent(event: E) {
// Unlimited buffer means this will always succeed
eventChannel.trySend(event)
}
@@ -205,12 +205,14 @@ class PhoneNumberEntryViewModel(
// If the user selected a restore option before entering their phone number, navigate to the restore flow
if (state.pendingRestoreOption != null) {
parentEventEmitter(RegistrationFlowEvent.E164Chosen(e164))
Log.i(TAG, "Pending restore option: ${state.pendingRestoreOption}. Navigating to appropriate screen.")
when (state.pendingRestoreOption) {
PendingRestoreOption.LocalBackup -> parentEventEmitter.navigateTo(RegistrationRoute.LocalBackupRestore(isPreRegistration = true))
PendingRestoreOption.RemoteBackup -> {
Log.w(TAG, "[PendingRestore] Remote backup restore not yet implemented")
}
PendingRestoreOption.RemoteBackup -> parentEventEmitter.navigateTo(RegistrationRoute.EnterAepScreen)
}
return state
}
@@ -137,7 +137,7 @@ class PinEntryForRegistrationLockViewModel(
parentEventEmitter(RegistrationFlowEvent.Registered(keyMaterial.accountEntropyPool))
// TODO storage service restore + profile screen
when {
response.reregistration -> parentEventEmitter.navigateTo(RegistrationRoute.ChooseRestoreOptionAfterRegistration)
response.reregistration -> parentEventEmitter.navigateTo(RegistrationRoute.ArchiveRestoreSelection.forPostRegister())
else -> parentEventEmitter.navigateTo(RegistrationRoute.FullyComplete)
}
state
@@ -59,8 +59,7 @@ class QuickRestoreQrViewModel(
state
}
is QuickRestoreQrEvents.UseProxy -> {
// TODO [registration] - Navigate to proxy settings
state
throw NotImplementedError("Proxy settings not implemented!")
}
is QuickRestoreQrEvents.DismissError -> {
startProvisioning()
@@ -99,7 +98,7 @@ class QuickRestoreQrViewModel(
private suspend fun handleProvisioningMessage(message: NetworkController.ProvisioningMessage) {
if (message.platform == NetworkController.ProvisioningMessage.Platform.IOS && message.tier == null) {
// iOS without a backup tier cannot do a quick restore — navigate to the choose-restore screen
parentEventEmitter.navigateTo(RegistrationRoute.ChooseRestoreOptionBeforeRegistration)
parentEventEmitter.navigateTo(RegistrationRoute.ArchiveRestoreSelection.forManualRestore())
return
}
@@ -112,7 +111,7 @@ class QuickRestoreQrViewModel(
val (response, keyMaterial) = registerResult.result
Log.i(TAG, "[Register] Success! reregistration: ${response.reregistration}")
parentEventEmitter(RegistrationFlowEvent.Registered(keyMaterial.accountEntropyPool))
parentEventEmitter.navigateTo(RegistrationRoute.ChooseRestoreOptionAfterRegistration)
parentEventEmitter.navigateTo(RegistrationRoute.ArchiveRestoreSelection.forQuickRestore(hasRemoteBackup = message.tier != null))
}
is RequestResult.NonSuccess -> {
when (val error = registerResult.error) {
@@ -150,7 +149,7 @@ class QuickRestoreQrViewModel(
)
}
is NetworkController.RegisterAccountError.DeviceTransferPossible -> {
Log.w(TAG, "[Register] Device transfer possible. Resetting.")
Log.w(TAG, "[Register] Device transfer possible. We never set this flag, so we should never see it. Resetting.")
parentEventEmitter(RegistrationFlowEvent.ResetState)
}
is NetworkController.RegisterAccountError.InvalidRequest -> {
@@ -44,7 +44,7 @@ fun ArchiveRestoreSelectionScreen(
onEvent: (ArchiveRestoreSelectionScreenEvents) -> Unit,
modifier: Modifier = Modifier
) {
if (state.showSkipRestoreWarning) {
if (state.showSkipWarningDialog) {
Dialogs.SimpleAlertDialog(
title = stringResource(R.string.ArchiveRestoreSelectionScreen__skip_restore_dialog_title),
body = stringResource(R.string.ArchiveRestoreSelectionScreen__skip_restore_dialog_warning),
@@ -98,16 +98,18 @@ fun ArchiveRestoreSelectionScreen(
Spacer(modifier = Modifier.weight(1f))
TextButton(
onClick = { onEvent(ArchiveRestoreSelectionScreenEvents.Skip) },
modifier = Modifier
.padding(bottom = 32.dp)
.testTag(TestTags.ARCHIVE_RESTORE_SELECTION_SKIP)
) {
Text(
text = stringResource(R.string.ArchiveRestoreSelectionScreen__skip),
color = MaterialTheme.colorScheme.primary
)
if (state.showSkipButton) {
TextButton(
onClick = { onEvent(ArchiveRestoreSelectionScreenEvents.Skip) },
modifier = Modifier
.padding(bottom = 32.dp)
.testTag(TestTags.ARCHIVE_RESTORE_SELECTION_SKIP)
) {
Text(
text = stringResource(R.string.ArchiveRestoreSelectionScreen__skip),
color = MaterialTheme.colorScheme.primary
)
}
}
}
}
@@ -9,6 +9,7 @@ import org.signal.registration.util.DebugLoggableModel
data class ArchiveRestoreSelectionState(
val restoreOptions: List<ArchiveRestoreOption> = emptyList(),
val showSkipButton: Boolean = false,
val showSkipRestoreWarning: Boolean = false
) : DebugLoggableModel()
val showSkipWarningDialog: Boolean = false
) : DebugLoggableModel() {
val showSkipButton: Boolean get() = ArchiveRestoreOption.None !in restoreOptions
}
@@ -8,45 +8,35 @@ package org.signal.registration.screens.restoreselection
import androidx.annotation.VisibleForTesting
import androidx.lifecycle.ViewModel
import androidx.lifecycle.ViewModelProvider
import androidx.lifecycle.viewModelScope
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.SharingStarted
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.combine
import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.flow.stateIn
import kotlinx.coroutines.launch
import org.signal.core.util.logging.Log
import org.signal.registration.PendingRestoreOption
import org.signal.registration.RegistrationFlowEvent
import org.signal.registration.RegistrationFlowState
import org.signal.registration.RegistrationRepository
import org.signal.registration.RegistrationRoute
import org.signal.registration.screens.EventDrivenViewModel
import org.signal.registration.screens.util.navigateTo
/**
* A view model to be used with [ArchiveRestoreSelectionScreen] after a quick restore.
* To avoid spinners, we'll have the quick restore screen determine if a remote backup
* is available and tell us.
*/
class ArchiveRestoreSelectionViewModel(
private val repository: RegistrationRepository,
private val parentState: StateFlow<RegistrationFlowState>,
private val parentEventEmitter: (RegistrationFlowEvent) -> Unit,
private val isPreRegistration: Boolean
private val restoreOptions: List<ArchiveRestoreOption>,
private val parentEventEmitter: (RegistrationFlowEvent) -> Unit
) : EventDrivenViewModel<ArchiveRestoreSelectionScreenEvents>(TAG) {
companion object {
private val TAG = Log.tag(ArchiveRestoreSelectionViewModel::class)
}
private val _localState = MutableStateFlow(ArchiveRestoreSelectionState())
val state = combine(_localState, parentState) { state, parentState -> applyParentState(state, parentState) }
.onEach { Log.d(TAG, "[State] $it") }
.stateIn(viewModelScope, SharingStarted.WhileSubscribed(5000), ArchiveRestoreSelectionState())
private val _localState = MutableStateFlow(
ArchiveRestoreSelectionState(
restoreOptions = restoreOptions
)
)
init {
viewModelScope.launch {
val options = repository.getAvailableRestoreOptions()
_localState.value = _localState.value.copy(restoreOptions = options.toList())
}
}
val state: StateFlow<ArchiveRestoreSelectionState> = _localState
override suspend fun processEvent(event: ArchiveRestoreSelectionScreenEvents) {
applyEvent(state.value, event) { _localState.value = it }
@@ -58,21 +48,11 @@ class ArchiveRestoreSelectionViewModel(
is ArchiveRestoreSelectionScreenEvents.RestoreOptionSelected -> {
when (event.option) {
ArchiveRestoreOption.LocalBackup -> {
if (isPreRegistration) {
parentEventEmitter(RegistrationFlowEvent.PendingRestoreOptionSelected(PendingRestoreOption.LocalBackup))
parentEventEmitter.navigateTo(RegistrationRoute.PhoneNumberEntry)
} else {
parentEventEmitter.navigateTo(RegistrationRoute.LocalBackupRestore(isPreRegistration = false))
}
parentEventEmitter.navigateTo(RegistrationRoute.LocalBackupRestore(isPreRegistration = false))
state
}
ArchiveRestoreOption.SignalSecureBackup -> {
if (isPreRegistration) {
parentEventEmitter(RegistrationFlowEvent.PendingRestoreOptionSelected(PendingRestoreOption.RemoteBackup))
parentEventEmitter.navigateTo(RegistrationRoute.PhoneNumberEntry)
} else {
Log.w(TAG, "Signal secure backup restore not yet implemented")
}
Log.w(TAG, "Signal secure backup restore not yet implemented")
state
}
ArchiveRestoreOption.DeviceTransfer -> {
@@ -80,47 +60,30 @@ class ArchiveRestoreSelectionViewModel(
state
}
ArchiveRestoreOption.None -> {
if (isPreRegistration) {
parentEventEmitter.navigateTo(RegistrationRoute.PhoneNumberEntry)
state
} else {
state.copy(showSkipRestoreWarning = true)
}
state.copy(showSkipWarningDialog = true)
}
}
}
is ArchiveRestoreSelectionScreenEvents.Skip -> {
if (isPreRegistration) {
parentEventEmitter.navigateTo(RegistrationRoute.PhoneNumberEntry)
state
} else {
state.copy(showSkipRestoreWarning = true)
}
state.copy(showSkipWarningDialog = true)
}
is ArchiveRestoreSelectionScreenEvents.ConfirmSkip -> {
parentEventEmitter.navigateTo(RegistrationRoute.PinCreate)
state.copy(showSkipRestoreWarning = false)
state.copy(showSkipWarningDialog = false)
}
is ArchiveRestoreSelectionScreenEvents.DismissSkipWarning -> {
state.copy(showSkipRestoreWarning = false)
state.copy(showSkipWarningDialog = false)
}
}
stateEmitter(result)
}
@VisibleForTesting
fun applyParentState(state: ArchiveRestoreSelectionState, parentState: RegistrationFlowState): ArchiveRestoreSelectionState {
return state
}
class Factory(
private val repository: RegistrationRepository,
private val parentState: StateFlow<RegistrationFlowState>,
private val parentEventEmitter: (RegistrationFlowEvent) -> Unit,
private val isPreRegistration: Boolean
private val restoreOptions: List<ArchiveRestoreOption>,
private val parentEventEmitter: (RegistrationFlowEvent) -> Unit
) : ViewModelProvider.Factory {
override fun <T : ViewModel> create(modelClass: Class<T>): T {
return ArchiveRestoreSelectionViewModel(repository, parentState, parentEventEmitter, isPreRegistration) as T
return ArchiveRestoreSelectionViewModel(restoreOptions, parentEventEmitter) as T
}
}
}
@@ -176,7 +176,7 @@ class VerificationCodeViewModel(
parentEventEmitter(RegistrationFlowEvent.Registered(keyMaterial.accountEntropyPool))
when {
// response.reregistration -> parentEventEmitter.navigateTo(RegistrationRoute.ChooseRestoreOptionAfterRegistration)
response.reregistration -> parentEventEmitter.navigateTo(RegistrationRoute.ArchiveRestoreSelection.forPostRegister())
response.storageCapable -> parentEventEmitter.navigateTo(RegistrationRoute.PinEntryForSvrRestore)
else -> parentEventEmitter.navigateTo(RegistrationRoute.PinCreate)
}
@@ -87,7 +87,7 @@ class PersistedFlowStateTest {
backStack = listOf(
RegistrationRoute.Welcome,
RegistrationRoute.PinCreate,
RegistrationRoute.ArchiveRestoreSelection
RegistrationRoute.ArchiveRestoreSelection.forManualRestore()
),
sessionMetadata = null,
sessionE164 = "+15551234567",
@@ -147,7 +147,7 @@ class RegistrationViewModelRestoreTest {
backStack = listOf(
RegistrationRoute.Welcome,
RegistrationRoute.PinCreate,
RegistrationRoute.ArchiveRestoreSelection
RegistrationRoute.ArchiveRestoreSelection.forManualRestore()
),
sessionMetadata = null,
sessionE164 = "+15551234567",