Reset backup auth credentials on verification failure.

This commit is contained in:
Cody Henthorne
2025-06-05 11:10:56 -04:00
committed by Greyson Parrelli
parent 297bca4c0f
commit df2e88eaac
4 changed files with 149 additions and 17 deletions

View File

@@ -31,6 +31,7 @@ import org.signal.core.util.requireNonNullBlob
import org.signal.core.util.stream.NonClosingOutputStream
import org.signal.core.util.urlEncode
import org.signal.core.util.withinTransaction
import org.signal.libsignal.zkgroup.VerificationFailedException
import org.signal.libsignal.zkgroup.backups.BackupLevel
import org.signal.libsignal.zkgroup.profiles.ProfileKey
import org.thoughtcrime.securesms.attachments.Attachment
@@ -85,6 +86,7 @@ import org.thoughtcrime.securesms.recipients.RecipientId
import org.thoughtcrime.securesms.util.RemoteConfig
import org.thoughtcrime.securesms.util.toMillis
import org.whispersystems.signalservice.api.AccountEntropyPool
import org.whispersystems.signalservice.api.ApplicationErrorAction
import org.whispersystems.signalservice.api.NetworkResult
import org.whispersystems.signalservice.api.StatusCodeErrorAction
import org.whispersystems.signalservice.api.archive.ArchiveGetMediaItemsResponse
@@ -150,6 +152,14 @@ object BackupRepository {
}
}
private val clearAuthCredentials: ApplicationErrorAction = { error ->
if (error.getCause() is VerificationFailedException) {
Log.w(TAG, "Unable to verify/receive credentials, clearing cache to fetch new.", error.getCause())
SignalStore.backup.messageCredentials.clearAll()
SignalStore.backup.mediaCredentials.clearAll()
}
}
/**
* Triggers backup id reservation. As documented, this is safe to perform multiple times.
*/
@@ -1504,7 +1514,9 @@ object BackupRepository {
return if (!RemoteConfig.messageBackups) {
NetworkResult.StatusCodeError(555, null, null, emptyMap(), NonSuccessfulResponseCodeException(555, "Backups disabled!"))
} else if (SignalStore.backup.backupsInitialized) {
getArchiveServiceAccessPair().runOnStatusCodeError(resetInitializedStateErrorAction)
getArchiveServiceAccessPair()
.runOnStatusCodeError(resetInitializedStateErrorAction)
.runOnApplicationError(clearAuthCredentials)
} else if (isPreRestoreDuringRegistration()) {
Log.w(TAG, "Requesting/using auth credentials in pre-restore state", Throwable())
getArchiveServiceAccessPair()
@@ -1519,6 +1531,7 @@ object BackupRepository {
.then { credential -> SignalNetwork.archive.setPublicKey(SignalStore.account.requireAci(), credential.mediaBackupAccess).map { credential } }
.runIfSuccessful { SignalStore.backup.backupsInitialized = true }
.runOnStatusCodeError(resetInitializedStateErrorAction)
.runOnApplicationError(clearAuthCredentials)
}
}

View File

@@ -1,6 +1,7 @@
package org.thoughtcrime.securesms.jobs
import org.signal.core.util.logging.Log
import org.signal.libsignal.zkgroup.VerificationFailedException
import org.thoughtcrime.securesms.attachments.AttachmentId
import org.thoughtcrime.securesms.attachments.Cdn
import org.thoughtcrime.securesms.attachments.DatabaseAttachment
@@ -134,8 +135,13 @@ class CopyAttachmentToArchiveJob private constructor(private val attachmentId: A
}
is NetworkResult.ApplicationError -> {
Log.w(TAG, "[$attachmentId] Encountered a fatal error when trying to upload!")
Result.fatalFailure(RuntimeException(archiveResult.throwable))
if (archiveResult.throwable is VerificationFailedException) {
Log.w(TAG, "[$attachmentId] Encountered a verification failure when trying to upload! Retrying.")
Result.retry(defaultBackoff())
} else {
Log.w(TAG, "[$attachmentId] Encountered a fatal error when trying to upload!")
Result.fatalFailure(RuntimeException(archiveResult.throwable))
}
}
}

View File

@@ -7,6 +7,7 @@ package org.whispersystems.signalservice.api
import io.reactivex.rxjava3.core.Single
import org.signal.core.util.concurrent.safeBlockingGet
import org.whispersystems.signalservice.api.NetworkResult.ApplicationError
import org.whispersystems.signalservice.api.NetworkResult.StatusCodeError
import org.whispersystems.signalservice.api.push.exceptions.MalformedRequestException
import org.whispersystems.signalservice.api.push.exceptions.NonSuccessfulResponseCodeException
@@ -24,6 +25,7 @@ import kotlin.time.Duration
import kotlin.time.Duration.Companion.seconds
typealias StatusCodeErrorAction = (StatusCodeError<*>) -> Unit
typealias ApplicationErrorAction = (ApplicationError<*>) -> Unit
/**
* A helper class that wraps the result of a network request, turning common exceptions
@@ -40,7 +42,8 @@ typealias StatusCodeErrorAction = (StatusCodeError<*>) -> Unit
* the success case and the status code of the error, this can be quite convenient.
*/
sealed class NetworkResult<T>(
private val statusCodeErrorActions: MutableSet<StatusCodeErrorAction> = mutableSetOf()
private val statusCodeErrorActions: MutableSet<StatusCodeErrorAction> = mutableSetOf(),
private val applicationErrorActions: MutableSet<ApplicationErrorAction> = mutableSetOf()
) {
companion object {
/**
@@ -210,7 +213,7 @@ sealed class NetworkResult<T>(
} else {
null
}
} catch (e: MalformedRequestException) {
} catch (_: MalformedRequestException) {
null
}
}
@@ -268,19 +271,21 @@ sealed class NetworkResult<T>(
* ```
*/
fun <R> map(transform: (T) -> R): NetworkResult<R> {
return when (this) {
val map = when (this) {
is Success -> {
try {
Success(transform(this.result)).runOnStatusCodeError(statusCodeErrorActions)
Success(transform(this.result))
} catch (e: Throwable) {
ApplicationError<R>(e).runOnStatusCodeError(statusCodeErrorActions)
ApplicationError<R>(e)
}
}
is NetworkError -> NetworkError<R>(exception).runOnStatusCodeError(statusCodeErrorActions)
is ApplicationError -> ApplicationError<R>(throwable).runOnStatusCodeError(statusCodeErrorActions)
is StatusCodeError -> StatusCodeError<R>(code, stringBody, binaryBody, headers, exception).runOnStatusCodeError(statusCodeErrorActions)
is NetworkError -> NetworkError<R>(exception)
is ApplicationError -> ApplicationError<R>(throwable)
is StatusCodeError -> StatusCodeError<R>(code, stringBody, binaryBody, headers, exception)
}
return map.runOnStatusCodeError(statusCodeErrorActions).runOnApplicationError(applicationErrorActions)
}
/**
@@ -325,12 +330,14 @@ sealed class NetworkResult<T>(
* ```
*/
fun <R> then(result: (T) -> NetworkResult<R>): NetworkResult<R> {
return when (this) {
is Success -> result(this.result).runOnStatusCodeError(statusCodeErrorActions)
is NetworkError -> NetworkError<R>(exception).runOnStatusCodeError(statusCodeErrorActions)
is ApplicationError -> ApplicationError<R>(throwable).runOnStatusCodeError(statusCodeErrorActions)
is StatusCodeError -> StatusCodeError<R>(code, stringBody, binaryBody, headers, exception).runOnStatusCodeError(statusCodeErrorActions)
val then = when (this) {
is Success -> result(this.result)
is NetworkError -> NetworkError<R>(exception)
is ApplicationError -> ApplicationError<R>(throwable)
is StatusCodeError -> StatusCodeError<R>(code, stringBody, binaryBody, headers, exception)
}
return then.runOnStatusCodeError(statusCodeErrorActions).runOnApplicationError(applicationErrorActions)
}
/**
@@ -370,7 +377,7 @@ sealed class NetworkResult<T>(
return runOnStatusCodeError(setOf(action))
}
internal fun runOnStatusCodeError(actions: Collection<StatusCodeErrorAction>): NetworkResult<T> {
private fun runOnStatusCodeError(actions: Collection<StatusCodeErrorAction>): NetworkResult<T> {
if (actions.isEmpty()) {
return this
}
@@ -385,6 +392,41 @@ sealed class NetworkResult<T>(
return this
}
/**
* Specify an action to be run when a application error occurs. When a result is a [ApplicationErrorAction] or is transformed into one further down the chain via
* a future [map] or [then], this code will be run. There can only ever be a single application error in a chain, and therefore this lambda will only ever
* be run a single time.
*
* This is a low-visibility way of doing things, so use sparingly.
*
* ```kotlin
* val result = NetworkResult
* .fromFetch { getAuth() }
* .runOnApplicationError { error -> logError(error) }
* .then { credential ->
* NetworkResult.fromFetch { fetchUserDetails(credential) }
* }
* ```
*/
fun runOnApplicationError(action: ApplicationErrorAction): NetworkResult<T> {
return runOnApplicationError(setOf(action))
}
private fun runOnApplicationError(actions: Collection<ApplicationErrorAction>): NetworkResult<T> {
if (actions.isEmpty()) {
return this
}
applicationErrorActions += actions
if (this is ApplicationError) {
applicationErrorActions.forEach { it.invoke(this) }
applicationErrorActions.clear()
}
return this
}
fun interface Fetcher<T> {
@Throws(Exception::class)
fun fetch(): T

View File

@@ -218,4 +218,75 @@ class NetworkResultTest {
assertFalse(handled)
}
@Test
fun `runOnApplicationError - simple call`() {
var handled = false
NetworkResult
.fromFetch { throw RuntimeException() }
.runOnApplicationError { handled = true }
assertTrue(handled)
}
@Test
fun `runOnApplicationError - ensure only called once`() {
var handleCount = 0
NetworkResult
.fromFetch { throw RuntimeException() }
.runOnApplicationError { handleCount++ }
.map { 1 }
.then { NetworkResult.Success(2) }
.map { 3 }
assertEquals(1, handleCount)
}
@Test
fun `runOnApplicationError - called when placed before a failing then`() {
var handled = false
val result = NetworkResult
.fromFetch { }
.runOnApplicationError { handled = true }
.then { NetworkResult.fromFetch { throw RuntimeException() } }
assertTrue(handled)
assertTrue(result is NetworkResult.ApplicationError)
}
@Test
fun `runOnApplicationError - called when placed two spots before a failing then`() {
var handled = false
val result = NetworkResult
.fromFetch { }
.runOnApplicationError { handled = true }
.then { NetworkResult.Success(Unit) }
.then { NetworkResult.fromFetch { throw RuntimeException() } }
assertTrue(handled)
assertTrue(result is NetworkResult.ApplicationError)
}
@Test
fun `runOnApplicationError - should not be called for successful results`() {
var handled = false
NetworkResult
.fromFetch {}
.runOnApplicationError { handled = true }
NetworkResult
.fromFetch { throw NonSuccessfulResponseCodeException(404, "not found", "body") }
.runOnApplicationError { handled = true }
NetworkResult
.fromFetch { throw PushNetworkException("network error") }
.runOnApplicationError { handled = true }
assertFalse(handled)
}
}