Fix use of media credentials for fetching/restoring media related data.

This commit is contained in:
Cody Henthorne
2024-11-04 16:42:15 -05:00
committed by Greyson Parrelli
parent d7c08690ee
commit f848a78365
11 changed files with 65 additions and 54 deletions

View File

@@ -23,6 +23,7 @@ import org.signal.core.util.getAllTriggerDefinitions
import org.signal.core.util.getForeignKeyViolations
import org.signal.core.util.logging.Log
import org.signal.core.util.stream.NonClosingOutputStream
import org.signal.core.util.urlEncode
import org.signal.core.util.withinTransaction
import org.signal.libsignal.messagebackup.MessageBackup
import org.signal.libsignal.messagebackup.MessageBackup.ValidationResult
@@ -112,6 +113,7 @@ object BackupRepository {
SignalStore.backup.backupsInitialized = false
SignalStore.backup.messageCredentials.clearAll()
SignalStore.backup.mediaCredentials.clearAll()
SignalStore.backup.cachedMediaCdnPath = null
}
403 -> {
@@ -716,7 +718,7 @@ object BackupRepository {
return initBackupAndFetchAuth(backupKey, mediaRootBackupKey)
.then { credential ->
SignalNetwork.archive.getBackupInfo(backupKey, SignalStore.account.requireAci(), credential.messageCredential)
SignalNetwork.archive.getBackupInfo(mediaRootBackupKey, SignalStore.account.requireAci(), credential.mediaCredential)
.map { it.usedSpace }
}
}
@@ -744,7 +746,7 @@ object BackupRepository {
return initBackupAndFetchAuth(backupKey, mediaRootBackupKey)
.map { credential ->
val zkCredential = SignalNetwork.archive.getZkCredential(backupKey, aci, credential.mediaCredential)
val zkCredential = SignalNetwork.archive.getZkCredential(backupKey, aci, credential.messageCredential)
if (zkCredential.backupLevel == BackupLevel.PAID) {
MessageBackupTier.PAID
} else {
@@ -762,16 +764,16 @@ object BackupRepository {
return initBackupAndFetchAuth(backupKey, mediaRootBackupKey)
.then { credential ->
SignalNetwork.archive.getBackupInfo(backupKey, SignalStore.account.requireAci(), credential.messageCredential)
SignalNetwork.archive.getBackupInfo(mediaRootBackupKey, SignalStore.account.requireAci(), credential.mediaCredential)
.map { it to credential }
}
.then { pair ->
val (info, credential) = pair
val (mediaBackupInfo, credential) = pair
SignalNetwork.archive.debugGetUploadedMediaItemMetadata(mediaRootBackupKey, SignalStore.account.requireAci(), credential.mediaCredential)
.also { Log.i(TAG, "MediaItemMetadataResult: $it") }
.map { mediaObjects ->
BackupMetadata(
usedSpace = info.usedSpace ?: 0,
usedSpace = mediaBackupInfo.usedSpace ?: 0,
mediaCount = mediaObjects.size.toLong()
)
}
@@ -1119,21 +1121,15 @@ object BackupRepository {
}
/**
* Retrieves backupDir and mediaDir, preferring cached value if available.
* Retrieves media-specific cdn path, preferring cached value if available.
*
* These will only ever change if the backup expires.
* This will change if the backup expires, a new backup-id is set, or the delete all endpoint is called.
*/
fun getCdnBackupDirectories(): NetworkResult<BackupDirectories> {
val cachedBackupDirectory = SignalStore.backup.cachedBackupDirectory
val cachedBackupMediaDirectory = SignalStore.backup.cachedBackupMediaDirectory
fun getArchivedMediaCdnPath(): NetworkResult<String> {
val cachedMediaPath = SignalStore.backup.cachedMediaCdnPath
if (cachedBackupDirectory != null && cachedBackupMediaDirectory != null) {
return NetworkResult.Success(
BackupDirectories(
backupDir = cachedBackupDirectory,
mediaDir = cachedBackupMediaDirectory
)
)
if (cachedMediaPath != null) {
return NetworkResult.Success(cachedMediaPath)
}
val backupKey = SignalStore.backup.messageBackupKey
@@ -1141,15 +1137,14 @@ object BackupRepository {
return initBackupAndFetchAuth(backupKey, mediaRootBackupKey)
.then { credential ->
SignalNetwork.archive.getBackupInfo(backupKey, SignalStore.account.requireAci(), credential.messageCredential).map {
SignalNetwork.archive.getBackupInfo(mediaRootBackupKey, SignalStore.account.requireAci(), credential.mediaCredential).map {
SignalStore.backup.usedBackupMediaSpace = it.usedSpace ?: 0L
BackupDirectories(it.backupDir!!, it.mediaDir!!)
"${it.backupDir!!.urlEncode()}/${it.mediaDir!!.urlEncode()}"
}
}
.also {
if (it is NetworkResult.Success) {
SignalStore.backup.cachedBackupDirectory = it.result.backupDir
SignalStore.backup.cachedBackupMediaDirectory = it.result.mediaDir
SignalStore.backup.cachedMediaCdnPath = it.result
}
}
}
@@ -1303,8 +1298,6 @@ object BackupRepository {
data class ArchivedMediaObject(val mediaId: String, val cdn: Int)
data class BackupDirectories(val backupDir: String, val mediaDir: String)
class ExportState(val backupTime: Long, val mediaBackupEnabled: Boolean) {
val recipientIds: MutableSet<Long> = hashSetOf()
val threadIds: MutableSet<Long> = hashSetOf()

View File

@@ -35,11 +35,10 @@ fun DatabaseAttachment.createArchiveAttachmentPointer(useArchiveCdn: Boolean): S
return try {
val (remoteId, cdnNumber) = if (useArchiveCdn) {
val mediaRootBackupKey = SignalStore.backup.mediaRootBackupKey
val backupDirectories = BackupRepository.getCdnBackupDirectories().successOrThrow()
val mediaCdnPath = BackupRepository.getArchivedMediaCdnPath().successOrThrow()
val id = SignalServiceAttachmentRemoteId.Backup(
backupDir = backupDirectories.backupDir,
mediaDir = backupDirectories.mediaDir,
mediaCdnPath = mediaCdnPath,
mediaId = mediaRootBackupKey.deriveMediaId(MediaName(archiveMediaName!!)).encode()
)
@@ -92,15 +91,14 @@ fun DatabaseAttachment.createArchiveThumbnailPointer(): SignalServiceAttachmentP
}
val mediaRootBackupKey = SignalStore.backup.mediaRootBackupKey
val backupDirectories = BackupRepository.getCdnBackupDirectories().successOrThrow()
val mediaCdnPath = BackupRepository.getArchivedMediaCdnPath().successOrThrow()
return try {
val key = mediaRootBackupKey.deriveThumbnailTransitKey(getThumbnailMediaName())
val mediaId = mediaRootBackupKey.deriveMediaId(getThumbnailMediaName()).encode()
SignalServiceAttachmentPointer(
cdnNumber = archiveCdn,
remoteId = SignalServiceAttachmentRemoteId.Backup(
backupDir = backupDirectories.backupDir,
mediaDir = backupDirectories.mediaDir,
mediaCdnPath = mediaCdnPath,
mediaId = mediaId
),
contentType = null,

View File

@@ -223,6 +223,7 @@ class InternalBackupPlaygroundViewModel : ViewModel() {
}
else -> {
Log.w(TAG, "Error checking remote backup state", result.getCause())
_state.value = _state.value.copy(remoteBackupState = RemoteBackupState.GeneralError)
}
}

View File

@@ -1583,7 +1583,7 @@ class AttachmentTable(
SELECT
$mmsId,
$CONTENT_TYPE,
$TRANSFER_PROGRESS_PENDING,
$TRANSFER_NEEDS_RESTORE,
$CDN_NUMBER,
$REMOTE_LOCATION,
$REMOTE_DIGEST,

View File

@@ -98,8 +98,14 @@ class BackupMessagesJob private constructor(parameters: Parameters) : Job(parame
Log.i(TAG, "Successfully uploaded backup file.")
SignalStore.backup.hasBackupBeenUploaded = true
}
is NetworkResult.NetworkError -> return Result.retry(defaultBackoff())
is NetworkResult.StatusCodeError -> return Result.retry(defaultBackoff())
is NetworkResult.NetworkError -> {
Log.i(TAG, "Network failure", result.getCause())
return Result.retry(defaultBackoff())
}
is NetworkResult.StatusCodeError -> {
Log.i(TAG, "Status code failure", result.getCause())
return Result.retry(defaultBackoff())
}
is NetworkResult.ApplicationError -> throw result.throwable
}
}

View File

@@ -221,7 +221,7 @@ class RestoreAttachmentJob private constructor(
val downloadResult = if (useArchiveCdn) {
archiveFile = SignalDatabase.attachments.getOrCreateArchiveTransferFile(attachmentId)
val cdnCredentials = BackupRepository.getCdnReadCredentials(BackupRepository.CredentialType.MESSAGE, attachment.archiveCdn).successOrThrow().headers
val cdnCredentials = BackupRepository.getCdnReadCredentials(BackupRepository.CredentialType.MEDIA, attachment.archiveCdn).successOrThrow().headers
messageReceiver
.retrieveArchivedAttachment(
@@ -265,6 +265,7 @@ class RestoreAttachmentJob private constructor(
return
} else if (e.code == 401 && useArchiveCdn) {
SignalStore.backup.mediaCredentials.cdnReadCredentials = null
SignalStore.backup.cachedMediaCdnPath = null
throw RetryLaterException(e)
}
}

View File

@@ -42,8 +42,7 @@ class BackupValues(store: KeyValueStore) : SignalStoreValues(store) {
private const val KEY_TOTAL_RESTORABLE_ATTACHMENT_SIZE = "backup.totalRestorableAttachmentSize"
private const val KEY_BACKUP_FREQUENCY = "backup.backupFrequency"
private const val KEY_CDN_BACKUP_DIRECTORY = "backup.cdn.directory"
private const val KEY_CDN_BACKUP_MEDIA_DIRECTORY = "backup.cdn.mediaDirectory"
private const val KEY_CDN_MEDIA_PATH = "backup.cdn.mediaPath"
private const val KEY_BACKUP_OVER_CELLULAR = "backup.useCellular"
private const val KEY_OPTIMIZE_STORAGE = "backup.optimizeStorage"
@@ -69,8 +68,7 @@ class BackupValues(store: KeyValueStore) : SignalStoreValues(store) {
override fun onFirstEverAppLaunch() = Unit
override fun getKeysToIncludeInBackup(): List<String> = emptyList()
var cachedBackupDirectory: String? by stringValue(KEY_CDN_BACKUP_DIRECTORY, null)
var cachedBackupMediaDirectory: String? by stringValue(KEY_CDN_BACKUP_MEDIA_DIRECTORY, null)
var cachedMediaCdnPath: String? by stringValue(KEY_CDN_MEDIA_PATH, null)
var usedBackupMediaSpace: Long by longValue(KEY_BACKUP_USED_MEDIA_SPACE, 0L)
var lastBackupProtoSize: Long by longValue(KEY_BACKUP_LAST_PROTO_SIZE, 0L)
@@ -116,6 +114,8 @@ class BackupValues(store: KeyValueStore) : SignalStoreValues(store) {
lock.withLock {
Log.i(TAG, "Setting MediaRootBackupKey", Throwable())
putBlob(KEY_MEDIA_ROOT_BACKUP_KEY, value.value)
mediaCredentials.clearAll()
cachedMediaCdnPath = null
}
}
@@ -240,6 +240,7 @@ class BackupValues(store: KeyValueStore) : SignalStoreValues(store) {
/** Clears all credentials. */
fun clearAll() {
putString(authKey, null)
cdnReadCredentials = null
}
/** Credentials to read from the CDN. */