diff --git a/app/src/main/java/org/thoughtcrime/securesms/registration/v2/AppRegistrationNetworkController.kt b/app/src/main/java/org/thoughtcrime/securesms/registration/v2/AppRegistrationNetworkController.kt index 0af3746999..5135241220 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/registration/v2/AppRegistrationNetworkController.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/registration/v2/AppRegistrationNetworkController.kt @@ -57,6 +57,7 @@ import org.thoughtcrime.securesms.registration.fcm.PushChallengeRequest import org.thoughtcrime.securesms.registration.viewmodel.SvrAuthCredentialSet import org.whispersystems.signalservice.api.NetworkResult import org.whispersystems.signalservice.api.SvrNoDataException +import org.whispersystems.signalservice.api.archive.ArchiveServiceAccess import org.whispersystems.signalservice.api.provisioning.ProvisioningSocket import org.whispersystems.signalservice.api.svr.SecureValueRecovery.BackupResponse import org.whispersystems.signalservice.internal.crypto.SecondaryProvisioningCipher @@ -65,6 +66,7 @@ import org.whispersystems.signalservice.internal.push.PushServiceSocket import java.io.IOException import java.util.Locale import kotlin.time.Duration +import kotlin.time.Duration.Companion.milliseconds import kotlin.time.Duration.Companion.seconds import org.whispersystems.signalservice.api.account.AccountAttributes as ServiceAccountAttributes import org.whispersystems.signalservice.api.account.PreKeyCollection as ServicePreKeyCollection @@ -536,6 +538,57 @@ class AppRegistrationNetworkController( } } + override suspend fun getRemoteBackupInfo(): RequestResult = withContext(Dispatchers.IO) { + val aci = SignalStore.account.aci ?: return@withContext RequestResult.ApplicationError(IllegalStateException("ACI not available")) + + val currentTime = System.currentTimeMillis() + val messageCredential = SignalStore.backup.messageCredentials.byDay.getForCurrentTime(currentTime.milliseconds) + + val access = if (messageCredential != null) { + ArchiveServiceAccess(messageCredential, SignalStore.backup.messageBackupKey) + } else { + when (val credResult = SignalNetwork.archive.getServiceCredentials(currentTime)) { + is NetworkResult.Success -> { + SignalStore.backup.messageCredentials.add(credResult.result.messageCredentials) + SignalStore.backup.messageCredentials.clearOlderThan(currentTime) + val credential = SignalStore.backup.messageCredentials.byDay.getForCurrentTime(currentTime.milliseconds) + ?: return@withContext RequestResult.ApplicationError(IllegalStateException("Failed to obtain backup credentials after fetch")) + ArchiveServiceAccess(credential, SignalStore.backup.messageBackupKey) + } + is NetworkResult.StatusCodeError -> return@withContext RequestResult.ApplicationError(IllegalStateException("Failed to fetch backup credentials: ${credResult.code}")) + is NetworkResult.NetworkError -> return@withContext RequestResult.RetryableNetworkError(credResult.exception) + is NetworkResult.ApplicationError -> return@withContext RequestResult.ApplicationError(credResult.throwable) + } + } + + when (val result = SignalNetwork.archive.getBackupInfo(aci, access)) { + is NetworkResult.Success -> { + val info = result.result + RequestResult.Success( + NetworkController.GetBackupInfoResponse( + cdn = info.cdn, + backupDir = info.backupDir, + mediaDir = info.mediaDir, + backupName = info.backupName, + usedSpace = info.usedSpace + ) + ) + } + is NetworkResult.StatusCodeError -> { + when (result.code) { + 400 -> RequestResult.NonSuccess(NetworkController.GetBackupInfoError.BadArguments(result.stringBody)) + 401 -> RequestResult.NonSuccess(NetworkController.GetBackupInfoError.BadAuthCredential(result.stringBody)) + 403 -> RequestResult.NonSuccess(NetworkController.GetBackupInfoError.Forbidden(result.stringBody)) + 404 -> RequestResult.NonSuccess(NetworkController.GetBackupInfoError.NoBackup) + 429 -> RequestResult.NonSuccess(NetworkController.GetBackupInfoError.RateLimited(0.seconds)) + else -> RequestResult.ApplicationError(IllegalStateException("Unexpected response code: ${result.code}")) + } + } + is NetworkResult.NetworkError -> RequestResult.RetryableNetworkError(result.exception) + is NetworkResult.ApplicationError -> RequestResult.ApplicationError(result.throwable) + } + } + override fun startProvisioning(): Flow = callbackFlow { val socketHandles = mutableListOf() val configuration = AppDependencies.signalServiceNetworkAccess.getConfiguration() diff --git a/app/src/main/java/org/thoughtcrime/securesms/registration/v2/AppRegistrationStorageController.kt b/app/src/main/java/org/thoughtcrime/securesms/registration/v2/AppRegistrationStorageController.kt index 1195070486..bd05a6f843 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/registration/v2/AppRegistrationStorageController.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/registration/v2/AppRegistrationStorageController.kt @@ -22,7 +22,6 @@ import org.signal.registration.PreExistingRegistrationData import org.signal.registration.StorageController import org.signal.registration.proto.RegistrationData import org.signal.registration.screens.localbackuprestore.LocalBackupInfo -import org.signal.registration.screens.restoreselection.ArchiveRestoreOption import org.thoughtcrime.securesms.backup.FullBackupImporter import org.thoughtcrime.securesms.backup.v2.BackupRepository import org.thoughtcrime.securesms.backup.v2.local.LocalArchiver @@ -178,16 +177,6 @@ class AppRegistrationStorageController(private val context: Context) : StorageCo Unit } - override suspend fun getAvailableRestoreOptions(): Set = withContext(Dispatchers.IO) { - // TODO [greyson] Real options - val options = mutableSetOf() - - options.add(ArchiveRestoreOption.LocalBackup) - options.add(ArchiveRestoreOption.DeviceTransfer) - - options - } - override fun restoreLocalBackupV1(uri: Uri, passphrase: String): Flow = flow { // TODO [greyson] better progress Log.d(TAG, "Starting V1 local backup restore from: $uri") diff --git a/demo/registration/src/main/java/org/signal/registration/sample/debug/DebugNetworkController.kt b/demo/registration/src/main/java/org/signal/registration/sample/debug/DebugNetworkController.kt index c2bbb13e63..8e5e113f72 100644 --- a/demo/registration/src/main/java/org/signal/registration/sample/debug/DebugNetworkController.kt +++ b/demo/registration/src/main/java/org/signal/registration/sample/debug/DebugNetworkController.kt @@ -15,6 +15,8 @@ import org.signal.registration.NetworkController.BackupMasterKeyError import org.signal.registration.NetworkController.CheckSvrCredentialsError import org.signal.registration.NetworkController.CheckSvrCredentialsResponse import org.signal.registration.NetworkController.CreateSessionError +import org.signal.registration.NetworkController.GetBackupInfoError +import org.signal.registration.NetworkController.GetBackupInfoResponse import org.signal.registration.NetworkController.GetSessionStatusError import org.signal.registration.NetworkController.GetSvrCredentialsError import org.signal.registration.NetworkController.MasterKeyResponse @@ -217,4 +219,12 @@ class DebugNetworkController( return delegate.checkSvrCredentials(e164, credentials) } + + override suspend fun getRemoteBackupInfo(): RequestResult { + NetworkDebugState.getOverride>("getRemoteBackupInfo")?.let { + Log.d(TAG, "[getRemoteBackupInfo] Returning debug override") + return it + } + return delegate.getRemoteBackupInfo() + } } diff --git a/demo/registration/src/main/java/org/signal/registration/sample/dependencies/DemoNetworkController.kt b/demo/registration/src/main/java/org/signal/registration/sample/dependencies/DemoNetworkController.kt index 81702ae67b..6468f92a00 100644 --- a/demo/registration/src/main/java/org/signal/registration/sample/dependencies/DemoNetworkController.kt +++ b/demo/registration/src/main/java/org/signal/registration/sample/dependencies/DemoNetworkController.kt @@ -13,11 +13,13 @@ import kotlinx.coroutines.flow.callbackFlow import kotlinx.coroutines.isActive import kotlinx.coroutines.launch import kotlinx.coroutines.withContext +import kotlinx.serialization.Serializable import kotlinx.serialization.json.Json import okhttp3.MediaType.Companion.toMediaType import okhttp3.RequestBody.Companion.toRequestBody import okhttp3.Response import org.signal.core.models.MasterKey +import org.signal.core.util.Base64 import org.signal.core.util.logging.Log import org.signal.libsignal.net.Network import org.signal.libsignal.net.RequestResult @@ -25,6 +27,9 @@ import org.signal.libsignal.protocol.IdentityKey import org.signal.libsignal.protocol.IdentityKeyPair import org.signal.libsignal.protocol.ecc.ECPrivateKey import org.signal.libsignal.protocol.util.Hex +import org.signal.libsignal.zkgroup.GenericServerPublicParams +import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialRequestContext +import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialResponse import org.signal.registration.NetworkController import org.signal.registration.NetworkController.AccountAttributes import org.signal.registration.NetworkController.CheckSvrCredentialsRequest @@ -62,8 +67,11 @@ import org.whispersystems.signalservice.internal.push.PushServiceSocket import org.whispersystems.signalservice.internal.util.StaticCredentialsProvider import org.whispersystems.signalservice.internal.websocket.LibSignalChatConnection import java.io.IOException +import java.time.Instant import java.util.Locale import kotlin.time.Duration +import kotlin.time.Duration.Companion.days +import kotlin.time.Duration.Companion.milliseconds import kotlin.time.Duration.Companion.seconds import org.whispersystems.signalservice.api.account.AccountAttributes as ServiceAccountAttributes import org.whispersystems.signalservice.api.account.PreKeyCollection as ServicePreKeyCollection @@ -843,6 +851,126 @@ class DemoNetworkController( } } + override suspend fun getRemoteBackupInfo(): RequestResult = withContext(Dispatchers.IO) { + val aci = RegistrationPreferences.aci + val password = RegistrationPreferences.servicePassword + val aep = RegistrationPreferences.aep + + if (aci == null || password == null || aep == null) { + Log.w(TAG, "[getRemoteBackupInfo] Credentials not available") + return@withContext RequestResult.ApplicationError(IllegalStateException("Credentials not available")) + } + + try { + val messageBackupKey = aep.deriveMessageBackupKey() + + // Remember, this is a demo app + val credential = fetchArchiveServiceCredential(aci.toString(), password) + ?: return@withContext RequestResult.RetryableNetworkError(IOException("Failed to fetch archive credentials")) + + val headers = buildZkAuthHeaders(messageBackupKey, aci, credential) + + val baseUrl = serviceConfiguration.signalServiceUrls[0].url + val request = okhttp3.Request.Builder() + .url("$baseUrl/v1/archives") + .get() + .apply { headers.forEach { (k, v) -> header(k, v) } } + .build() + + okHttpClient.newCall(request).execute().use { response -> + when (response.code) { + 200 -> { + val info = json.decodeFromString(response.body.string()) + RequestResult.Success(info) + } + 400 -> RequestResult.NonSuccess(NetworkController.GetBackupInfoError.BadArguments(response.body.string())) + 401 -> RequestResult.NonSuccess(NetworkController.GetBackupInfoError.BadAuthCredential(response.body.string())) + 403 -> RequestResult.NonSuccess(NetworkController.GetBackupInfoError.Forbidden(response.body.string())) + 404 -> RequestResult.NonSuccess(NetworkController.GetBackupInfoError.NoBackup) + 429 -> RequestResult.NonSuccess(NetworkController.GetBackupInfoError.RateLimited(response.retryAfter())) + else -> RequestResult.ApplicationError(IllegalStateException("Unexpected response code: ${response.code}, body: ${response.body?.string()}")) + } + } + } catch (e: IOException) { + Log.w(TAG, "[getRemoteBackupInfo] IOException", e) + RequestResult.RetryableNetworkError(e) + } catch (e: Exception) { + Log.w(TAG, "[getRemoteBackupInfo] Exception", e) + RequestResult.ApplicationError(e) + } + } + + /** + * Fetches an archive service credential for today by calling GET /v1/archives/auth on the authenticated channel. + */ + private fun fetchArchiveServiceCredential(aci: String, password: String): ArchiveCredential? { + val currentTime = System.currentTimeMillis() + val roundedToNearestDay = currentTime.milliseconds.inWholeDays.days + val endTime = roundedToNearestDay + 7.days + val startSeconds = roundedToNearestDay.inWholeSeconds + val endSeconds = endTime.inWholeSeconds + + val credentials = okhttp3.Credentials.basic(aci, password) + val baseUrl = serviceConfiguration.signalServiceUrls[0].url + val request = okhttp3.Request.Builder() + .url("$baseUrl/v1/archives/auth?redemptionStartSeconds=$startSeconds&redemptionEndSeconds=$endSeconds") + .get() + .header("Authorization", credentials) + .build() + + okHttpClient.newCall(request).execute().use { response -> + if (response.code != 200) { + Log.w(TAG, "[fetchArchiveServiceCredential] Unexpected response code: ${response.code}") + return null + } + + val body = response.body.string() + val parsed = json.decodeFromString(body) + val todaySeconds = roundedToNearestDay.inWholeSeconds + + return parsed.credentials["messages"]?.firstOrNull { it.redemptionTime == todaySeconds } + } + } + + /** + * Builds the ZK auth headers (X-Signal-ZK-Auth, X-Signal-ZK-Auth-Signature) needed for + * anonymous archive requests. + */ + private fun buildZkAuthHeaders( + messageBackupKey: org.signal.core.models.backup.MessageBackupKey, + aci: org.signal.core.models.ServiceId.ACI, + credential: ArchiveCredential + ): Map { + val backupServerPublicParams = GenericServerPublicParams(serviceConfiguration.backupServerPublicParams) + val backupRequestContext = BackupAuthCredentialRequestContext.create(messageBackupKey.value, aci.rawUuid) + val backupAuthCredentialResponse = BackupAuthCredentialResponse(Base64.decode(credential.credential)) + val backupAuthCredential = backupRequestContext.receiveResponse( + backupAuthCredentialResponse, + Instant.ofEpochSecond(credential.redemptionTime), + backupServerPublicParams + ) + + val presentation = backupAuthCredential.present(backupServerPublicParams).serialize() + val privateKey = messageBackupKey.deriveAnonymousCredentialPrivateKey(aci) + val signedPresentation = privateKey.calculateSignature(presentation) + + return mapOf( + "X-Signal-ZK-Auth" to Base64.encodeWithPadding(presentation), + "X-Signal-ZK-Auth-Signature" to Base64.encodeWithPadding(signedPresentation) + ) + } + + @Serializable + private data class ArchiveCredentialsResponse( + val credentials: Map> + ) + + @Serializable + private data class ArchiveCredential( + val credential: String, + val redemptionTime: Long + ) + private fun AccountAttributes.toServiceAccountAttributes(): ServiceAccountAttributes { return ServiceAccountAttributes( signalingKey, diff --git a/demo/registration/src/main/java/org/signal/registration/sample/dependencies/DemoStorageController.kt b/demo/registration/src/main/java/org/signal/registration/sample/dependencies/DemoStorageController.kt index ce3f63fc53..ed9e9c2623 100644 --- a/demo/registration/src/main/java/org/signal/registration/sample/dependencies/DemoStorageController.kt +++ b/demo/registration/src/main/java/org/signal/registration/sample/dependencies/DemoStorageController.kt @@ -32,7 +32,6 @@ import org.signal.registration.proto.RegistrationData import org.signal.registration.sample.storage.RegistrationDatabase import org.signal.registration.sample.storage.RegistrationPreferences import org.signal.registration.screens.localbackuprestore.LocalBackupInfo -import org.signal.registration.screens.restoreselection.ArchiveRestoreOption import java.io.File import java.time.LocalDateTime @@ -181,14 +180,6 @@ class DemoStorageController(private val context: Context) : StorageController { Unit } - override suspend fun getAvailableRestoreOptions(): Set { - return setOf( - ArchiveRestoreOption.SignalSecureBackup, - ArchiveRestoreOption.LocalBackup, - ArchiveRestoreOption.DeviceTransfer - ) - } - override suspend fun scanLocalBackupFolder(folderUri: Uri): List = withContext(Dispatchers.IO) { val folder = DocumentFile.fromTreeUri(context, folderUri) ?: return@withContext emptyList() val children = folder.listFiles() diff --git a/feature/registration/src/main/java/org/signal/registration/NetworkController.kt b/feature/registration/src/main/java/org/signal/registration/NetworkController.kt index 9fb888eb2e..1c83ad369a 100644 --- a/feature/registration/src/main/java/org/signal/registration/NetworkController.kt +++ b/feature/registration/src/main/java/org/signal/registration/NetworkController.kt @@ -182,6 +182,21 @@ interface NetworkController { */ suspend fun setAccountAttributes(attributes: AccountAttributes): RequestResult + /** + * Fetches metadata about your current backup. This will be different for different key/credential pairs. For example, message credentials will always + * return 0 for used space since that is stored under the media key/credential. + * + * GET /v1/archives + * - 200: Success + * - 400: Bad arguments. The request may have been made on an authenticated channel. + * - 401: The provided backup auth credential presentation could not be verified or the public key signature was invalid or there is no backup associated with + * the backup-id in the presentation or the credential was of the wrong type (messages/media) + * - 403: Forbidden + * - 404: No backup + * - 429: Rate limited + */ + suspend fun getRemoteBackupInfo(): RequestResult + /** * Starts a provisioning session for QR-based quick restore. * @@ -213,24 +228,24 @@ interface NetworkController { // */ // suspend fun registerAsSecondaryDevice(verificationCode: String, attributes: AccountAttributes, aciPreKeys: PreKeyCollection, pniPreKeys: PreKeyCollection, fcmToken: String?) - sealed class CreateSessionError() : BadRequestError { + sealed class CreateSessionError : BadRequestError { data class InvalidRequest(val message: String) : CreateSessionError() data class RateLimited(val retryAfter: Duration) : CreateSessionError() } - sealed class GetSessionStatusError() : BadRequestError { + sealed class GetSessionStatusError : BadRequestError { data class InvalidSessionId(val message: String) : GetSessionStatusError() data class SessionNotFound(val message: String) : GetSessionStatusError() data class InvalidRequest(val message: String) : GetSessionStatusError() } - sealed class UpdateSessionError() : BadRequestError { + sealed class UpdateSessionError : BadRequestError { data class RejectedUpdate(val message: String) : UpdateSessionError() data class InvalidRequest(val message: String) : UpdateSessionError() data class RateLimited(val retryAfter: Duration, val session: SessionMetadata) : UpdateSessionError() } - sealed class RequestVerificationCodeError() : BadRequestError { + sealed class RequestVerificationCodeError : BadRequestError { data class InvalidSessionId(val message: String) : RequestVerificationCodeError() data class SessionNotFound(val message: String) : RequestVerificationCodeError() data class MissingRequestInformationOrAlreadyVerified(val session: SessionMetadata) : RequestVerificationCodeError() @@ -240,14 +255,14 @@ interface NetworkController { data class ThirdPartyServiceError(val data: ThirdPartyServiceErrorResponse) : RequestVerificationCodeError() } - sealed class SubmitVerificationCodeError() : BadRequestError { + sealed class SubmitVerificationCodeError : BadRequestError { data class InvalidSessionIdOrVerificationCode(val message: String) : SubmitVerificationCodeError() data class SessionNotFound(val message: String) : SubmitVerificationCodeError() data class SessionAlreadyVerifiedOrNoCodeRequested(val session: SessionMetadata) : SubmitVerificationCodeError() data class RateLimited(val retryAfter: Duration, val session: SessionMetadata) : SubmitVerificationCodeError() } - sealed class RegisterAccountError() : BadRequestError { + sealed class RegisterAccountError : BadRequestError { data class SessionNotFoundOrNotVerified(val message: String) : RegisterAccountError() data class RegistrationRecoveryPasswordIncorrect(val message: String) : RegisterAccountError() data object DeviceTransferPossible : RegisterAccountError() @@ -256,38 +271,46 @@ interface NetworkController { data class RateLimited(val retryAfter: Duration) : RegisterAccountError() } - sealed class RestoreMasterKeyError() : BadRequestError { + sealed class RestoreMasterKeyError : BadRequestError { data class WrongPin(val triesRemaining: Int) : RestoreMasterKeyError() data object NoDataFound : RestoreMasterKeyError() } - sealed class BackupMasterKeyError() : BadRequestError { + sealed class BackupMasterKeyError : BadRequestError { data object EnclaveNotFound : BackupMasterKeyError() data object NotRegistered : BackupMasterKeyError() } - sealed class SetRegistrationLockError() : BadRequestError { + sealed class SetRegistrationLockError : BadRequestError { data class InvalidRequest(val message: String) : SetRegistrationLockError() data object Unauthorized : SetRegistrationLockError() data object NotRegistered : SetRegistrationLockError() data object NoPinSet : SetRegistrationLockError() } - sealed class SetAccountAttributesError() : BadRequestError { + sealed class SetAccountAttributesError : BadRequestError { data class InvalidRequest(val message: String) : SetAccountAttributesError() data object Unauthorized : SetAccountAttributesError() } - sealed class GetSvrCredentialsError() : BadRequestError { + sealed class GetSvrCredentialsError : BadRequestError { data object Unauthorized : GetSvrCredentialsError() data object NoServiceCredentialsAvailable : GetSvrCredentialsError() } - sealed class CheckSvrCredentialsError() : BadRequestError { + sealed class CheckSvrCredentialsError : BadRequestError { data object Unauthorized : CheckSvrCredentialsError() data class InvalidRequest(val message: String) : CheckSvrCredentialsError() } + sealed class GetBackupInfoError : BadRequestError { + data class BadArguments(val body: String? = null) : GetBackupInfoError() + data class BadAuthCredential(val body: String? = null) : GetBackupInfoError() + data class Forbidden(val body: String? = null) : GetBackupInfoError() + data object NoBackup : GetBackupInfoError() + data class RateLimited(val retryAfter: Duration) : GetBackupInfoError() + } + data class MasterKeyResponse( val masterKey: MasterKey ) @@ -432,6 +455,15 @@ interface NetworkController { SMS, VOICE } + @Serializable + data class GetBackupInfoResponse( + val cdn: Int?, + val backupDir: String?, + val mediaDir: String?, + val backupName: String?, + val usedSpace: Long? + ) + /** * Data received from the old device during QR-based provisioning. */ diff --git a/feature/registration/src/main/java/org/signal/registration/RegistrationNavigation.kt b/feature/registration/src/main/java/org/signal/registration/RegistrationNavigation.kt index 6eec14c2ef..38d4df0f89 100644 --- a/feature/registration/src/main/java/org/signal/registration/RegistrationNavigation.kt +++ b/feature/registration/src/main/java/org/signal/registration/RegistrationNavigation.kt @@ -60,6 +60,7 @@ import org.signal.registration.screens.pinentry.PinEntryForSvrRestoreViewModel import org.signal.registration.screens.pinentry.PinEntryScreen import org.signal.registration.screens.quickrestore.QuickRestoreQrScreen import org.signal.registration.screens.quickrestore.QuickRestoreQrViewModel +import org.signal.registration.screens.restoreselection.ArchiveRestoreOption import org.signal.registration.screens.restoreselection.ArchiveRestoreSelectionScreen import org.signal.registration.screens.restoreselection.ArchiveRestoreSelectionViewModel import org.signal.registration.screens.util.navigateBack @@ -113,13 +114,42 @@ sealed interface RegistrationRoute : NavKey, Parcelable { data object PinCreate : RegistrationRoute @Serializable - data object ArchiveRestoreSelection : RegistrationRoute + data class ArchiveRestoreSelection(val restoreOptions: List) : RegistrationRoute { + companion object { + fun forQuickRestore(hasRemoteBackup: Boolean): ArchiveRestoreSelection { + return ArchiveRestoreSelection( + buildList { + if (hasRemoteBackup) { + add(ArchiveRestoreOption.SignalSecureBackup) + } + add(ArchiveRestoreOption.LocalBackup) + add(ArchiveRestoreOption.DeviceTransfer) + } + ) + } - @Serializable - data object ChooseRestoreOptionBeforeRegistration : RegistrationRoute + fun forManualRestore(): ArchiveRestoreSelection { + return ArchiveRestoreSelection( + buildList { + add(ArchiveRestoreOption.SignalSecureBackup) + add(ArchiveRestoreOption.LocalBackup) + add(ArchiveRestoreOption.DeviceTransfer) + } + ) + } - @Serializable - data object ChooseRestoreOptionAfterRegistration : RegistrationRoute + fun forPostRegister(): ArchiveRestoreSelection { + return ArchiveRestoreSelection( + buildList { + add(ArchiveRestoreOption.SignalSecureBackup) + add(ArchiveRestoreOption.LocalBackup) + add(ArchiveRestoreOption.DeviceTransfer) + add(ArchiveRestoreOption.None) + } + ) + } + } + } @Serializable data class LocalBackupRestore(val isPreRegistration: Boolean) : RegistrationRoute @@ -248,9 +278,9 @@ private fun EntryProviderScope.navigationEntries( onEvent = { event -> when (event) { WelcomeScreenEvents.Continue -> parentEventEmitter.navigateTo(RegistrationRoute.Permissions(nextRoute = RegistrationRoute.PhoneNumberEntry)) - WelcomeScreenEvents.LinkDevice -> parentEventEmitter.navigateTo(RegistrationRoute.Permissions(nextRoute = RegistrationRoute.QuickRestoreQrScan)) // TODO - Replace this with the device-link QR code + WelcomeScreenEvents.LinkDevice -> throw NotImplementedError("Haven't implemented linked devices yet") WelcomeScreenEvents.HasOldPhone -> parentEventEmitter.navigateTo(RegistrationRoute.Permissions(nextRoute = RegistrationRoute.QuickRestoreQrScan)) - WelcomeScreenEvents.DoesNotHaveOldPhone -> parentEventEmitter.navigateTo(RegistrationRoute.Permissions(nextRoute = RegistrationRoute.ChooseRestoreOptionBeforeRegistration)) + WelcomeScreenEvents.DoesNotHaveOldPhone -> parentEventEmitter.navigateTo(RegistrationRoute.Permissions(nextRoute = RegistrationRoute.ArchiveRestoreSelection.forManualRestore())) } } ) @@ -447,14 +477,12 @@ private fun EntryProviderScope.navigationEntries( ) } - // -- Archive Restore Selection Screen - entry { + // -- Archive Restore Selection for Quick Restore Screen + entry { key -> val viewModel: ArchiveRestoreSelectionViewModel = viewModel( factory = ArchiveRestoreSelectionViewModel.Factory( - repository = registrationRepository, - parentState = registrationViewModel.state, - parentEventEmitter = registrationViewModel::onEvent, - isPreRegistration = false + restoreOptions = key.restoreOptions, + parentEventEmitter = registrationViewModel::onEvent ) ) val state by viewModel.state.collectAsStateWithLifecycle() @@ -465,28 +493,6 @@ private fun EntryProviderScope.navigationEntries( ) } - // -- Choose Restore Option Before Registration (saves selection, then navigates to phone number entry) - entry { - val viewModel: ArchiveRestoreSelectionViewModel = viewModel( - factory = ArchiveRestoreSelectionViewModel.Factory( - repository = registrationRepository, - parentState = registrationViewModel.state, - parentEventEmitter = registrationViewModel::onEvent, - isPreRegistration = true - ) - ) - val state by viewModel.state.collectAsStateWithLifecycle() - - ArchiveRestoreSelectionScreen( - state = state, - onEvent = { viewModel.onEvent(it) } - ) - } - - entry { - TODO("Implement RestoreScreen") - } - // -- Local Backup Restore Screen entry { key -> val viewModel: LocalBackupRestoreViewModel = viewModel( @@ -526,6 +532,8 @@ private fun EntryProviderScope.navigationEntries( ) } + // TODO I think we can re-use the screen but attach different viewmodels to progress forward rather than do for-result flows? + // -- Enter AEP entry { val viewModel: EnterAepViewModel = viewModel( diff --git a/feature/registration/src/main/java/org/signal/registration/RegistrationRepository.kt b/feature/registration/src/main/java/org/signal/registration/RegistrationRepository.kt index 6890ac8116..325630941e 100644 --- a/feature/registration/src/main/java/org/signal/registration/RegistrationRepository.kt +++ b/feature/registration/src/main/java/org/signal/registration/RegistrationRepository.kt @@ -43,7 +43,6 @@ import org.signal.registration.NetworkController.UpdateSessionError import org.signal.registration.proto.ProvisioningData import org.signal.registration.proto.SvrCredential import org.signal.registration.screens.localbackuprestore.LocalBackupInfo -import org.signal.registration.screens.restoreselection.ArchiveRestoreOption import org.signal.registration.util.SensitiveLog import java.security.SecureRandom import java.util.Locale @@ -510,10 +509,6 @@ class RegistrationRepository(val context: Context, val networkController: Networ data.aci.isNotEmpty() && data.pni.isNotEmpty() } - suspend fun getAvailableRestoreOptions(): Set = withContext(Dispatchers.IO) { - storageController.getAvailableRestoreOptions() - } - fun restoreV1Backup(uri: Uri, passphrase: String): Flow { return storageController.restoreLocalBackupV1(uri, passphrase) } diff --git a/feature/registration/src/main/java/org/signal/registration/StorageController.kt b/feature/registration/src/main/java/org/signal/registration/StorageController.kt index 450b3d40ca..e17e4ec188 100644 --- a/feature/registration/src/main/java/org/signal/registration/StorageController.kt +++ b/feature/registration/src/main/java/org/signal/registration/StorageController.kt @@ -19,7 +19,6 @@ import org.signal.libsignal.protocol.state.KyberPreKeyRecord import org.signal.libsignal.protocol.state.SignedPreKeyRecord import org.signal.registration.proto.RegistrationData import org.signal.registration.screens.localbackuprestore.LocalBackupInfo -import org.signal.registration.screens.restoreselection.ArchiveRestoreOption import org.signal.registration.util.ACIParceler import org.signal.registration.util.AccountEntropyPoolParceler import org.signal.registration.util.IdentityKeyPairParceler @@ -83,12 +82,6 @@ interface StorageController { */ suspend fun commitRegistrationData() - /** - * Returns the set of restore options that are currently available to the user. - * For example, if a local backup file is present on the device, [ArchiveRestoreOption.LocalBackup] should be included. - */ - suspend fun getAvailableRestoreOptions(): Set - /** * Begins restoring from a V1 (.backup) file identified by the given [uri]. * diff --git a/feature/registration/src/main/java/org/signal/registration/screens/EventDrivenViewModel.kt b/feature/registration/src/main/java/org/signal/registration/screens/EventDrivenViewModel.kt index 670238a08a..7dea861736 100644 --- a/feature/registration/src/main/java/org/signal/registration/screens/EventDrivenViewModel.kt +++ b/feature/registration/src/main/java/org/signal/registration/screens/EventDrivenViewModel.kt @@ -33,6 +33,7 @@ abstract class EventDrivenViewModel( } fun onEvent(event: E) { + // Unlimited buffer means this will always succeed eventChannel.trySend(event) } diff --git a/feature/registration/src/main/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryViewModel.kt b/feature/registration/src/main/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryViewModel.kt index 41aa90140f..884381234f 100644 --- a/feature/registration/src/main/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryViewModel.kt +++ b/feature/registration/src/main/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryViewModel.kt @@ -205,12 +205,14 @@ class PhoneNumberEntryViewModel( // If the user selected a restore option before entering their phone number, navigate to the restore flow if (state.pendingRestoreOption != null) { parentEventEmitter(RegistrationFlowEvent.E164Chosen(e164)) + + Log.i(TAG, "Pending restore option: ${state.pendingRestoreOption}. Navigating to appropriate screen.") + when (state.pendingRestoreOption) { PendingRestoreOption.LocalBackup -> parentEventEmitter.navigateTo(RegistrationRoute.LocalBackupRestore(isPreRegistration = true)) - PendingRestoreOption.RemoteBackup -> { - Log.w(TAG, "[PendingRestore] Remote backup restore not yet implemented") - } + PendingRestoreOption.RemoteBackup -> parentEventEmitter.navigateTo(RegistrationRoute.EnterAepScreen) } + return state } diff --git a/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryForRegistrationLockViewModel.kt b/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryForRegistrationLockViewModel.kt index dc27af9c7c..0f53a585e0 100644 --- a/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryForRegistrationLockViewModel.kt +++ b/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryForRegistrationLockViewModel.kt @@ -137,7 +137,7 @@ class PinEntryForRegistrationLockViewModel( parentEventEmitter(RegistrationFlowEvent.Registered(keyMaterial.accountEntropyPool)) // TODO storage service restore + profile screen when { - response.reregistration -> parentEventEmitter.navigateTo(RegistrationRoute.ChooseRestoreOptionAfterRegistration) + response.reregistration -> parentEventEmitter.navigateTo(RegistrationRoute.ArchiveRestoreSelection.forPostRegister()) else -> parentEventEmitter.navigateTo(RegistrationRoute.FullyComplete) } state diff --git a/feature/registration/src/main/java/org/signal/registration/screens/quickrestore/QuickRestoreQrViewModel.kt b/feature/registration/src/main/java/org/signal/registration/screens/quickrestore/QuickRestoreQrViewModel.kt index 0799f58212..159fc54704 100644 --- a/feature/registration/src/main/java/org/signal/registration/screens/quickrestore/QuickRestoreQrViewModel.kt +++ b/feature/registration/src/main/java/org/signal/registration/screens/quickrestore/QuickRestoreQrViewModel.kt @@ -59,8 +59,7 @@ class QuickRestoreQrViewModel( state } is QuickRestoreQrEvents.UseProxy -> { - // TODO [registration] - Navigate to proxy settings - state + throw NotImplementedError("Proxy settings not implemented!") } is QuickRestoreQrEvents.DismissError -> { startProvisioning() @@ -99,7 +98,7 @@ class QuickRestoreQrViewModel( private suspend fun handleProvisioningMessage(message: NetworkController.ProvisioningMessage) { if (message.platform == NetworkController.ProvisioningMessage.Platform.IOS && message.tier == null) { // iOS without a backup tier cannot do a quick restore — navigate to the choose-restore screen - parentEventEmitter.navigateTo(RegistrationRoute.ChooseRestoreOptionBeforeRegistration) + parentEventEmitter.navigateTo(RegistrationRoute.ArchiveRestoreSelection.forManualRestore()) return } @@ -112,7 +111,7 @@ class QuickRestoreQrViewModel( val (response, keyMaterial) = registerResult.result Log.i(TAG, "[Register] Success! reregistration: ${response.reregistration}") parentEventEmitter(RegistrationFlowEvent.Registered(keyMaterial.accountEntropyPool)) - parentEventEmitter.navigateTo(RegistrationRoute.ChooseRestoreOptionAfterRegistration) + parentEventEmitter.navigateTo(RegistrationRoute.ArchiveRestoreSelection.forQuickRestore(hasRemoteBackup = message.tier != null)) } is RequestResult.NonSuccess -> { when (val error = registerResult.error) { @@ -150,7 +149,7 @@ class QuickRestoreQrViewModel( ) } is NetworkController.RegisterAccountError.DeviceTransferPossible -> { - Log.w(TAG, "[Register] Device transfer possible. Resetting.") + Log.w(TAG, "[Register] Device transfer possible. We never set this flag, so we should never see it. Resetting.") parentEventEmitter(RegistrationFlowEvent.ResetState) } is NetworkController.RegisterAccountError.InvalidRequest -> { diff --git a/feature/registration/src/main/java/org/signal/registration/screens/restoreselection/ArchiveRestoreSelectionScreen.kt b/feature/registration/src/main/java/org/signal/registration/screens/restoreselection/ArchiveRestoreSelectionScreen.kt index 4ae0459859..aeaf7872bc 100644 --- a/feature/registration/src/main/java/org/signal/registration/screens/restoreselection/ArchiveRestoreSelectionScreen.kt +++ b/feature/registration/src/main/java/org/signal/registration/screens/restoreselection/ArchiveRestoreSelectionScreen.kt @@ -44,7 +44,7 @@ fun ArchiveRestoreSelectionScreen( onEvent: (ArchiveRestoreSelectionScreenEvents) -> Unit, modifier: Modifier = Modifier ) { - if (state.showSkipRestoreWarning) { + if (state.showSkipWarningDialog) { Dialogs.SimpleAlertDialog( title = stringResource(R.string.ArchiveRestoreSelectionScreen__skip_restore_dialog_title), body = stringResource(R.string.ArchiveRestoreSelectionScreen__skip_restore_dialog_warning), @@ -98,16 +98,18 @@ fun ArchiveRestoreSelectionScreen( Spacer(modifier = Modifier.weight(1f)) - TextButton( - onClick = { onEvent(ArchiveRestoreSelectionScreenEvents.Skip) }, - modifier = Modifier - .padding(bottom = 32.dp) - .testTag(TestTags.ARCHIVE_RESTORE_SELECTION_SKIP) - ) { - Text( - text = stringResource(R.string.ArchiveRestoreSelectionScreen__skip), - color = MaterialTheme.colorScheme.primary - ) + if (state.showSkipButton) { + TextButton( + onClick = { onEvent(ArchiveRestoreSelectionScreenEvents.Skip) }, + modifier = Modifier + .padding(bottom = 32.dp) + .testTag(TestTags.ARCHIVE_RESTORE_SELECTION_SKIP) + ) { + Text( + text = stringResource(R.string.ArchiveRestoreSelectionScreen__skip), + color = MaterialTheme.colorScheme.primary + ) + } } } } diff --git a/feature/registration/src/main/java/org/signal/registration/screens/restoreselection/ArchiveRestoreSelectionState.kt b/feature/registration/src/main/java/org/signal/registration/screens/restoreselection/ArchiveRestoreSelectionState.kt index 61249dbf88..e5174d6f64 100644 --- a/feature/registration/src/main/java/org/signal/registration/screens/restoreselection/ArchiveRestoreSelectionState.kt +++ b/feature/registration/src/main/java/org/signal/registration/screens/restoreselection/ArchiveRestoreSelectionState.kt @@ -9,6 +9,7 @@ import org.signal.registration.util.DebugLoggableModel data class ArchiveRestoreSelectionState( val restoreOptions: List = emptyList(), - val showSkipButton: Boolean = false, - val showSkipRestoreWarning: Boolean = false -) : DebugLoggableModel() + val showSkipWarningDialog: Boolean = false +) : DebugLoggableModel() { + val showSkipButton: Boolean get() = ArchiveRestoreOption.None !in restoreOptions +} diff --git a/feature/registration/src/main/java/org/signal/registration/screens/restoreselection/ArchiveRestoreSelectionViewModel.kt b/feature/registration/src/main/java/org/signal/registration/screens/restoreselection/ArchiveRestoreSelectionViewModel.kt index 1e09c6b4ec..846ae9998a 100644 --- a/feature/registration/src/main/java/org/signal/registration/screens/restoreselection/ArchiveRestoreSelectionViewModel.kt +++ b/feature/registration/src/main/java/org/signal/registration/screens/restoreselection/ArchiveRestoreSelectionViewModel.kt @@ -8,45 +8,35 @@ package org.signal.registration.screens.restoreselection import androidx.annotation.VisibleForTesting import androidx.lifecycle.ViewModel import androidx.lifecycle.ViewModelProvider -import androidx.lifecycle.viewModelScope import kotlinx.coroutines.flow.MutableStateFlow -import kotlinx.coroutines.flow.SharingStarted import kotlinx.coroutines.flow.StateFlow -import kotlinx.coroutines.flow.combine -import kotlinx.coroutines.flow.onEach -import kotlinx.coroutines.flow.stateIn -import kotlinx.coroutines.launch import org.signal.core.util.logging.Log -import org.signal.registration.PendingRestoreOption import org.signal.registration.RegistrationFlowEvent -import org.signal.registration.RegistrationFlowState -import org.signal.registration.RegistrationRepository import org.signal.registration.RegistrationRoute import org.signal.registration.screens.EventDrivenViewModel import org.signal.registration.screens.util.navigateTo +/** + * A view model to be used with [ArchiveRestoreSelectionScreen] after a quick restore. + * To avoid spinners, we'll have the quick restore screen determine if a remote backup + * is available and tell us. + */ class ArchiveRestoreSelectionViewModel( - private val repository: RegistrationRepository, - private val parentState: StateFlow, - private val parentEventEmitter: (RegistrationFlowEvent) -> Unit, - private val isPreRegistration: Boolean + private val restoreOptions: List, + private val parentEventEmitter: (RegistrationFlowEvent) -> Unit ) : EventDrivenViewModel(TAG) { companion object { private val TAG = Log.tag(ArchiveRestoreSelectionViewModel::class) } - private val _localState = MutableStateFlow(ArchiveRestoreSelectionState()) - val state = combine(_localState, parentState) { state, parentState -> applyParentState(state, parentState) } - .onEach { Log.d(TAG, "[State] $it") } - .stateIn(viewModelScope, SharingStarted.WhileSubscribed(5000), ArchiveRestoreSelectionState()) + private val _localState = MutableStateFlow( + ArchiveRestoreSelectionState( + restoreOptions = restoreOptions + ) + ) - init { - viewModelScope.launch { - val options = repository.getAvailableRestoreOptions() - _localState.value = _localState.value.copy(restoreOptions = options.toList()) - } - } + val state: StateFlow = _localState override suspend fun processEvent(event: ArchiveRestoreSelectionScreenEvents) { applyEvent(state.value, event) { _localState.value = it } @@ -58,21 +48,11 @@ class ArchiveRestoreSelectionViewModel( is ArchiveRestoreSelectionScreenEvents.RestoreOptionSelected -> { when (event.option) { ArchiveRestoreOption.LocalBackup -> { - if (isPreRegistration) { - parentEventEmitter(RegistrationFlowEvent.PendingRestoreOptionSelected(PendingRestoreOption.LocalBackup)) - parentEventEmitter.navigateTo(RegistrationRoute.PhoneNumberEntry) - } else { - parentEventEmitter.navigateTo(RegistrationRoute.LocalBackupRestore(isPreRegistration = false)) - } + parentEventEmitter.navigateTo(RegistrationRoute.LocalBackupRestore(isPreRegistration = false)) state } ArchiveRestoreOption.SignalSecureBackup -> { - if (isPreRegistration) { - parentEventEmitter(RegistrationFlowEvent.PendingRestoreOptionSelected(PendingRestoreOption.RemoteBackup)) - parentEventEmitter.navigateTo(RegistrationRoute.PhoneNumberEntry) - } else { - Log.w(TAG, "Signal secure backup restore not yet implemented") - } + Log.w(TAG, "Signal secure backup restore not yet implemented") state } ArchiveRestoreOption.DeviceTransfer -> { @@ -80,47 +60,30 @@ class ArchiveRestoreSelectionViewModel( state } ArchiveRestoreOption.None -> { - if (isPreRegistration) { - parentEventEmitter.navigateTo(RegistrationRoute.PhoneNumberEntry) - state - } else { - state.copy(showSkipRestoreWarning = true) - } + state.copy(showSkipWarningDialog = true) } } } is ArchiveRestoreSelectionScreenEvents.Skip -> { - if (isPreRegistration) { - parentEventEmitter.navigateTo(RegistrationRoute.PhoneNumberEntry) - state - } else { - state.copy(showSkipRestoreWarning = true) - } + state.copy(showSkipWarningDialog = true) } is ArchiveRestoreSelectionScreenEvents.ConfirmSkip -> { parentEventEmitter.navigateTo(RegistrationRoute.PinCreate) - state.copy(showSkipRestoreWarning = false) + state.copy(showSkipWarningDialog = false) } is ArchiveRestoreSelectionScreenEvents.DismissSkipWarning -> { - state.copy(showSkipRestoreWarning = false) + state.copy(showSkipWarningDialog = false) } } stateEmitter(result) } - @VisibleForTesting - fun applyParentState(state: ArchiveRestoreSelectionState, parentState: RegistrationFlowState): ArchiveRestoreSelectionState { - return state - } - class Factory( - private val repository: RegistrationRepository, - private val parentState: StateFlow, - private val parentEventEmitter: (RegistrationFlowEvent) -> Unit, - private val isPreRegistration: Boolean + private val restoreOptions: List, + private val parentEventEmitter: (RegistrationFlowEvent) -> Unit ) : ViewModelProvider.Factory { override fun create(modelClass: Class): T { - return ArchiveRestoreSelectionViewModel(repository, parentState, parentEventEmitter, isPreRegistration) as T + return ArchiveRestoreSelectionViewModel(restoreOptions, parentEventEmitter) as T } } } diff --git a/feature/registration/src/main/java/org/signal/registration/screens/verificationcode/VerificationCodeViewModel.kt b/feature/registration/src/main/java/org/signal/registration/screens/verificationcode/VerificationCodeViewModel.kt index 3561d2fa27..d052066c03 100644 --- a/feature/registration/src/main/java/org/signal/registration/screens/verificationcode/VerificationCodeViewModel.kt +++ b/feature/registration/src/main/java/org/signal/registration/screens/verificationcode/VerificationCodeViewModel.kt @@ -176,7 +176,7 @@ class VerificationCodeViewModel( parentEventEmitter(RegistrationFlowEvent.Registered(keyMaterial.accountEntropyPool)) when { -// response.reregistration -> parentEventEmitter.navigateTo(RegistrationRoute.ChooseRestoreOptionAfterRegistration) + response.reregistration -> parentEventEmitter.navigateTo(RegistrationRoute.ArchiveRestoreSelection.forPostRegister()) response.storageCapable -> parentEventEmitter.navigateTo(RegistrationRoute.PinEntryForSvrRestore) else -> parentEventEmitter.navigateTo(RegistrationRoute.PinCreate) } diff --git a/feature/registration/src/test/java/org/signal/registration/PersistedFlowStateTest.kt b/feature/registration/src/test/java/org/signal/registration/PersistedFlowStateTest.kt index c9ab0270f7..b6ffee6706 100644 --- a/feature/registration/src/test/java/org/signal/registration/PersistedFlowStateTest.kt +++ b/feature/registration/src/test/java/org/signal/registration/PersistedFlowStateTest.kt @@ -87,7 +87,7 @@ class PersistedFlowStateTest { backStack = listOf( RegistrationRoute.Welcome, RegistrationRoute.PinCreate, - RegistrationRoute.ArchiveRestoreSelection + RegistrationRoute.ArchiveRestoreSelection.forManualRestore() ), sessionMetadata = null, sessionE164 = "+15551234567", diff --git a/feature/registration/src/test/java/org/signal/registration/RegistrationViewModelRestoreTest.kt b/feature/registration/src/test/java/org/signal/registration/RegistrationViewModelRestoreTest.kt index 152b2a9826..4d1de1b345 100644 --- a/feature/registration/src/test/java/org/signal/registration/RegistrationViewModelRestoreTest.kt +++ b/feature/registration/src/test/java/org/signal/registration/RegistrationViewModelRestoreTest.kt @@ -147,7 +147,7 @@ class RegistrationViewModelRestoreTest { backStack = listOf( RegistrationRoute.Welcome, RegistrationRoute.PinCreate, - RegistrationRoute.ArchiveRestoreSelection + RegistrationRoute.ArchiveRestoreSelection.forManualRestore() ), sessionMetadata = null, sessionE164 = "+15551234567",