diff --git a/app/src/main/java/org/thoughtcrime/securesms/linkdevice/LinkDeviceRepository.kt b/app/src/main/java/org/thoughtcrime/securesms/linkdevice/LinkDeviceRepository.kt index 125f997e84..3428af33e9 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/linkdevice/LinkDeviceRepository.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/linkdevice/LinkDeviceRepository.kt @@ -273,7 +273,7 @@ object LinkDeviceRepository { stopwatch.split("validate-backup") Log.d(TAG, "[createAndUploadArchive] Fetching an upload form...") - val uploadForm = when (val result = SignalNetwork.attachments.getAttachmentV4UploadForm()) { + val uploadForm = when (val result = NetworkResult.withRetry { SignalNetwork.attachments.getAttachmentV4UploadForm() }) { is NetworkResult.Success -> result.result.logD(TAG, "[createAndUploadArchive] Successfully retrieved upload form.") is NetworkResult.ApplicationError -> throw result.throwable is NetworkResult.NetworkError -> return LinkUploadArchiveResult.NetworkError(result.exception).logW(TAG, "[createAndUploadArchive] Network error when fetching form.", result.exception) @@ -289,12 +289,14 @@ object LinkDeviceRepository { stopwatch.split("upload-backup") Log.d(TAG, "[createAndUploadArchive] Setting the transfer archive...") - val transferSetResult = SignalNetwork.linkDevice.setTransferArchive( - destinationDeviceId = deviceId, - destinationDeviceCreated = deviceCreatedAt, - cdn = uploadForm.cdn, - cdnKey = uploadForm.key - ) + val transferSetResult = NetworkResult.withRetry { + SignalNetwork.linkDevice.setTransferArchive( + destinationDeviceId = deviceId, + destinationDeviceCreated = deviceCreatedAt, + cdn = uploadForm.cdn, + cdnKey = uploadForm.key + ) + } when (transferSetResult) { is NetworkResult.Success -> Log.i(TAG, "[createAndUploadArchive] Successfully set transfer archive.") @@ -317,39 +319,32 @@ object LinkDeviceRepository { * Handles uploading the archive for [createAndUploadArchive]. Handles resumable uploads and making multiple upload attempts. */ private fun uploadArchive(backupFile: File, uploadForm: AttachmentUploadForm): NetworkResult { - val resumableUploadUrl = when (val result = SignalNetwork.attachments.getResumableUploadUrl(uploadForm)) { + val resumableUploadUrl = when (val result = NetworkResult.withRetry { SignalNetwork.attachments.getResumableUploadUrl(uploadForm) }) { is NetworkResult.Success -> result.result is NetworkResult.NetworkError -> return result.map { Unit }.logW(TAG, "Network error when fetching upload URL.", result.exception) is NetworkResult.StatusCodeError -> return result.map { Unit }.logW(TAG, "Status code error when fetching upload URL.", result.exception) is NetworkResult.ApplicationError -> throw result.throwable } - val maxRetries = 5 - var attemptCount = 0 - - while (attemptCount < maxRetries) { - Log.i(TAG, "Starting upload attempt ${attemptCount + 1}/$maxRetries") - val uploadResult = FileInputStream(backupFile).use { + val uploadResult = NetworkResult.withRetry( + logAttempt = { attempt, maxAttempts -> Log.i(TAG, "Starting upload attempt ${attempt + 1}/$maxAttempts") } + ) { + FileInputStream(backupFile).use { SignalNetwork.attachments.uploadPreEncryptedFileToAttachmentV4( uploadForm = uploadForm, resumableUploadUrl = resumableUploadUrl, - inputStream = backupFile.inputStream(), + inputStream = it, inputStreamLength = backupFile.length() ) } - - when (uploadResult) { - is NetworkResult.Success -> return uploadResult - is NetworkResult.NetworkError -> Log.w(TAG, "Hit network error while uploading. May retry.", uploadResult.exception) - is NetworkResult.StatusCodeError -> return uploadResult.logW(TAG, "Status code error when uploading archive.", uploadResult.exception) - is NetworkResult.ApplicationError -> throw uploadResult.throwable - } - - attemptCount++ } - Log.w(TAG, "Hit the max retry count of $maxRetries. Failing.") - return NetworkResult.NetworkError(IOException("Hit max retries!")) + return when (uploadResult) { + is NetworkResult.Success -> uploadResult + is NetworkResult.NetworkError -> uploadResult.logW(TAG, "Network error while uploading.", uploadResult.exception) + is NetworkResult.StatusCodeError -> uploadResult.logW(TAG, "Status code error when uploading archive.", uploadResult.exception) + is NetworkResult.ApplicationError -> throw uploadResult.throwable + } } /** diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/NetworkResult.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/NetworkResult.kt index 76342e1ad0..8cd49a6820 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/NetworkResult.kt +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/NetworkResult.kt @@ -96,6 +96,36 @@ sealed class NetworkResult( } catch (e: Throwable) { ApplicationError(e) } + + /** + * Runs [operation] to perform a network call. If [shouldRetry] returns false for the result, then returns it. Otherwise will call [operation] repeatedly + * until [shouldRetry] returns false or is called [maxAttempts] number of times. + * + * @param maxAttempts Max attempts to try the network operation, must be 1 or more, default is 5 + * @param shouldRetry Predicate to determine if network operation should be retried, default is any [NetworkError] result is retried + * @param logAttempt Log each attempt before [operation] is called, default is noop + * @param operation Network operation that can be called repeatedly for each attempt + */ + fun withRetry( + maxAttempts: Int = 5, + shouldRetry: (NetworkResult) -> Boolean = { it is NetworkError }, + logAttempt: (attempt: Int, maxAttempts: Int) -> Unit = { _, _ -> }, + operation: () -> NetworkResult + ): NetworkResult { + require(maxAttempts > 0) + + lateinit var result: NetworkResult + for (attempt in 0 until maxAttempts) { + logAttempt(attempt, maxAttempts) + result = operation() + + if (!shouldRetry(result)) { + return result + } + } + + return result + } } /** Indicates the request was successful */ @@ -160,6 +190,7 @@ sealed class NetworkResult( ApplicationError(e).runOnStatusCodeError(statusCodeErrorActions) } } + is NetworkError -> NetworkError(exception).runOnStatusCodeError(statusCodeErrorActions) is ApplicationError -> ApplicationError(throwable).runOnStatusCodeError(statusCodeErrorActions) is StatusCodeError -> StatusCodeError(code, stringBody, binaryBody, exception).runOnStatusCodeError(statusCodeErrorActions)