diff --git a/app/src/androidTest/java/org/thoughtcrime/securesms/database/AttachmentTableTest.kt b/app/src/androidTest/java/org/thoughtcrime/securesms/database/AttachmentTableTest.kt index 9844a98a81..037b8b5de3 100644 --- a/app/src/androidTest/java/org/thoughtcrime/securesms/database/AttachmentTableTest.kt +++ b/app/src/androidTest/java/org/thoughtcrime/securesms/database/AttachmentTableTest.kt @@ -8,7 +8,6 @@ import androidx.test.platform.app.InstrumentationRegistry import assertk.assertThat import assertk.assertions.isEqualTo import assertk.assertions.isNotEqualTo -import org.junit.Assert.assertArrayEquals import org.junit.Assert.assertEquals import org.junit.Assert.assertNotEquals import org.junit.Before @@ -17,7 +16,6 @@ import org.junit.Test import org.junit.runner.RunWith import org.signal.core.util.Base64.decodeBase64OrThrow import org.signal.core.util.copyTo -import org.signal.core.util.readFully import org.signal.core.util.stream.NullOutputStream import org.thoughtcrime.securesms.attachments.Attachment import org.thoughtcrime.securesms.attachments.AttachmentId @@ -27,13 +25,10 @@ import org.thoughtcrime.securesms.mms.MediaStream import org.thoughtcrime.securesms.mms.SentMediaQuality import org.thoughtcrime.securesms.providers.BlobProvider import org.thoughtcrime.securesms.util.MediaUtil -import org.thoughtcrime.securesms.util.Util -import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream import org.whispersystems.signalservice.api.crypto.AttachmentCipherOutputStream import org.whispersystems.signalservice.api.crypto.NoCipherOutputStream import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentPointer import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentRemoteId -import org.whispersystems.signalservice.internal.crypto.PaddingInputStream import java.io.ByteArrayOutputStream import java.io.File import java.util.Optional @@ -182,32 +177,6 @@ class AttachmentTableTest { assertThat(highInfo.file.exists()).isEqualTo(true) } - @Test - fun finalizeAttachmentAfterDownload_leaveDigestAloneForAllZeroPadding() { - // Insert attachment metadata for properly-padded attachment - val plaintext = byteArrayOf(1, 2, 3, 4) - val key = Util.getSecretBytes(64) - val iv = Util.getSecretBytes(16) - - val paddedPlaintext = PaddingInputStream(plaintext.inputStream(), plaintext.size.toLong()).readFully() - val ciphertext = encryptPrePaddedBytes(paddedPlaintext, key, iv) - val digest = getDigest(ciphertext) - - val cipherFile = getTempFile() - cipherFile.writeBytes(ciphertext) - - val mmsId = -1L - val attachmentId = SignalDatabase.attachments.insertAttachmentsForMessage(mmsId, listOf(createAttachmentPointer(key, digest, plaintext.size)), emptyList()).values.first() - - // Give data to attachment table - val cipherInputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintext.size.toLong(), key, digest, null, 4) - SignalDatabase.attachments.finalizeAttachmentAfterDownload(mmsId, attachmentId, cipherInputStream) - - // Verify the digest hasn't changed - val newDigest = SignalDatabase.attachments.getAttachment(attachmentId)!!.remoteDigest!! - assertArrayEquals(digest, newDigest) - } - @Test fun resetArchiveTransferStateByPlaintextHashAndRemoteKey_singleMatch() { // Given an attachment with some plaintextHash+remoteKey diff --git a/app/src/main/java/org/thoughtcrime/securesms/ApplicationContext.java b/app/src/main/java/org/thoughtcrime/securesms/ApplicationContext.java index 3b21fad208..0e743ae48a 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/ApplicationContext.java +++ b/app/src/main/java/org/thoughtcrime/securesms/ApplicationContext.java @@ -110,6 +110,7 @@ import org.thoughtcrime.securesms.util.TextSecurePreferences; import org.thoughtcrime.securesms.util.Util; import org.thoughtcrime.securesms.util.VersionTracker; import org.thoughtcrime.securesms.util.dynamiclanguage.DynamicLanguageContextWrapper; +import org.whispersystems.signalservice.api.backup.MediaName; import org.whispersystems.signalservice.api.websocket.SignalWebSocket; import java.io.InterruptedIOException; diff --git a/app/src/main/java/org/thoughtcrime/securesms/backup/v2/BackupRepository.kt b/app/src/main/java/org/thoughtcrime/securesms/backup/v2/BackupRepository.kt index 30215f8d53..483629e397 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/backup/v2/BackupRepository.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/backup/v2/BackupRepository.kt @@ -19,6 +19,7 @@ import okio.ByteString import okio.ByteString.Companion.toByteString import org.greenrobot.eventbus.EventBus import org.signal.core.util.Base64 +import org.signal.core.util.Base64.decodeBase64OrThrow import org.signal.core.util.ByteSize import org.signal.core.util.CursorUtil import org.signal.core.util.EventTimer @@ -37,7 +38,7 @@ import org.signal.core.util.getForeignKeyViolations import org.signal.core.util.logging.Log import org.signal.core.util.money.FiatMoney import org.signal.core.util.requireIntOrNull -import org.signal.core.util.requireNonNullBlob +import org.signal.core.util.requireNonNullString import org.signal.core.util.stream.NonClosingOutputStream import org.signal.core.util.urlEncode import org.signal.core.util.withinTransaction @@ -1290,11 +1291,11 @@ object BackupRepository { return initBackupAndFetchAuth() .then { credential -> SignalNetwork.archive.getMessageBackupUploadForm(SignalStore.account.requireAci(), credential.messageBackupAccess) - .also { Log.i(TAG, "UploadFormResult: $it") } + .also { Log.i(TAG, "UploadFormResult: ${it::class.simpleName}") } } .then { form -> SignalNetwork.archive.getBackupResumableUploadUrl(form) - .also { Log.i(TAG, "ResumableUploadUrlResult: $it") } + .also { Log.i(TAG, "ResumableUploadUrlResult: ${it::class.simpleName}") } .map { ResumableMessagesBackupUploadSpec(attachmentUploadForm = form, resumableUri = it) } } } @@ -1307,7 +1308,7 @@ object BackupRepository { ): NetworkResult { val (form, resumableUploadUrl) = resumableSpec return SignalNetwork.archive.uploadBackupFile(form, resumableUploadUrl, backupStream, backupStreamLength, progressListener) - .also { Log.i(TAG, "UploadBackupFileResult: $it") } + .also { Log.i(TAG, "UploadBackupFileResult: ${it::class.simpleName}") } } fun downloadBackupFile(destination: File, listener: ProgressListener? = null): NetworkResult { @@ -1395,7 +1396,7 @@ object BackupRepository { .map { response -> SignalDatabase.attachments.setArchiveCdn(attachmentId = attachment.attachmentId, archiveCdn = response.cdn) } - .also { Log.i(TAG, "archiveMediaResult: $it") } + .also { Log.i(TAG, "archiveMediaResult: ${it::class.simpleName}") } } fun deleteAbandonedMediaObjects(mediaObjects: Collection): NetworkResult { @@ -1421,7 +1422,7 @@ object BackupRepository { mediaToDelete = mediaToDelete ) } - .also { Log.i(TAG, "deleteAbandonedMediaObjectsResult: $it") } + .also { Log.i(TAG, "deleteAbandonedMediaObjectsResult: ${it::class.simpleName}") } } fun deleteBackup(): NetworkResult { @@ -1477,7 +1478,7 @@ object BackupRepository { .map { SignalDatabase.attachments.clearAllArchiveData() } - .also { Log.i(TAG, "debugDeleteAllArchivedMediaResult: $it") } + .also { Log.i(TAG, "debugDeleteAllArchivedMediaResult: ${it::class.simpleName}") } } /** @@ -1512,7 +1513,7 @@ object BackupRepository { credentialStore.cdnReadCredentials = it.result } } - .also { Log.i(TAG, "getCdnReadCredentialsResult: $it") } + .also { Log.i(TAG, "getCdnReadCredentialsResult: ${it::class.simpleName}") } } fun restoreBackupTier(aci: ACI): MessageBackupTier? { @@ -1965,8 +1966,8 @@ class ArchiveMediaItemIterator(private val cursor: Cursor) : Iterator viewModel.onBackupTierSelected(tier) }, onCheckRemoteBackupStateClicked = { viewModel.checkRemoteBackupState() }, onEnqueueRemoteBackupClicked = { viewModel.triggerBackupJob() }, + onEnqueueReconciliationClicked = { AppDependencies.jobManager.add(ArchiveAttachmentReconciliationJob(forced = true)) }, onHaltAllBackupJobsClicked = { viewModel.haltAllJobs() }, onValidateBackupClicked = { viewModel.validateBackup() }, onSaveEncryptedBackupToDiskClicked = { @@ -222,7 +223,7 @@ class InternalBackupPlaygroundFragment : ComposeFragment() { onDeleteRemoteBackup = { MaterialAlertDialogBuilder(context) .setTitle("Are you sure?") - .setMessage("This will delete all of your remote backup data?") + .setMessage("This will delete all of your remote backup data!") .setPositiveButton("Delete remote data") { _, _ -> lifecycleScope.launch { val success = viewModel.deleteRemoteBackupData() @@ -234,6 +235,21 @@ class InternalBackupPlaygroundFragment : ComposeFragment() { .setNegativeButton("Cancel", null) .show() }, + onClearLocalMediaBackupState = { + MaterialAlertDialogBuilder(context) + .setTitle("Are you sure?") + .setMessage("This will cause you to have to re-upload all of your media!") + .setPositiveButton("Clear local media state") { _, _ -> + lifecycleScope.launch { + viewModel.clearLocalMediaBackupState() + withContext(Dispatchers.Main) { + Toast.makeText(requireContext(), "Done!", Toast.LENGTH_SHORT).show() + } + } + } + .setNegativeButton("Cancel", null) + .show() + }, onDisplayInitialBackupFailureSheet = { BackupRepository.displayInitialBackupFailureNotification() BackupAlertBottomSheet @@ -312,8 +328,8 @@ fun Screen( onImportNewStyleLocalBackupClicked: () -> Unit = {}, onCheckRemoteBackupStateClicked: () -> Unit = {}, onEnqueueRemoteBackupClicked: () -> Unit = {}, + onEnqueueReconciliationClicked: () -> Unit = {}, onWipeDataAndRestoreFromRemoteClicked: () -> Unit = {}, - onBackupTierSelected: (MessageBackupTier?) -> Unit = {}, onHaltAllBackupJobsClicked: () -> Unit = {}, onSavePlaintextCopyOfRemoteBackupClicked: () -> Unit = {}, onValidateBackupClicked: () -> Unit = {}, @@ -322,6 +338,7 @@ fun Screen( onImportEncryptedBackupFromDiskClicked: () -> Unit = {}, onImportEncryptedBackupFromDiskDismissed: () -> Unit = {}, onImportEncryptedBackupFromDiskConfirmed: (aci: String, backupKey: String) -> Unit = { _, _ -> }, + onClearLocalMediaBackupState: () -> Unit = {}, onDeleteRemoteBackup: () -> Unit = {}, onDisplayInitialBackupFailureSheet: () -> Unit = {} ) { @@ -353,21 +370,6 @@ fun Screen( .fillMaxSize() .verticalScroll(scrollState) ) { - Row(verticalAlignment = Alignment.CenterVertically) { - Text("Tier", fontWeight = FontWeight.Bold) - options.forEach { option -> - Row(verticalAlignment = Alignment.CenterVertically) { - RadioButton( - selected = option.value == state.backupTier, - onClick = { onBackupTierSelected(option.value) } - ) - Text(option.key) - } - } - } - - Dividers.Default() - Rows.TextRow( text = { Text( @@ -392,6 +394,12 @@ fun Screen( onClick = onEnqueueRemoteBackupClicked ) + Rows.TextRow( + text = "Enqueue reconciliation job", + label = "Schedules a job that will ensure local and remote media state are in sync.", + onClick = onEnqueueReconciliationClicked + ) + Rows.TextRow( text = "Halt all backup jobs", label = "Stops all backup-related jobs to the best of our ability.", @@ -513,6 +521,12 @@ fun Screen( onClick = onDeleteRemoteBackup ) + Rows.TextRow( + text = "Clear local media backup state", + label = "Resets local state tracking so you think you haven't uploaded any media. The media still exists on the server.", + onClick = onClearLocalMediaBackupState + ) + Dividers.Default() Rows.TextRow( diff --git a/app/src/main/java/org/thoughtcrime/securesms/components/settings/app/internal/backup/InternalBackupPlaygroundViewModel.kt b/app/src/main/java/org/thoughtcrime/securesms/components/settings/app/internal/backup/InternalBackupPlaygroundViewModel.kt index 1c62c3d1e7..de69a87067 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/components/settings/app/internal/backup/InternalBackupPlaygroundViewModel.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/components/settings/app/internal/backup/InternalBackupPlaygroundViewModel.kt @@ -292,11 +292,6 @@ class InternalBackupPlaygroundViewModel : ViewModel() { } } - fun onBackupTierSelected(backupTier: MessageBackupTier?) { - SignalStore.backup.backupTier = backupTier - _state.value = _state.value.copy(backupTier = backupTier) - } - fun onImportSelected() { _state.value = _state.value.copy(dialog = DialogState.ImportCredentials) } @@ -398,6 +393,10 @@ class InternalBackupPlaygroundViewModel : ViewModel() { return@withContext false } + suspend fun clearLocalMediaBackupState() = withContext(Dispatchers.IO) { + SignalDatabase.attachments.clearAllArchiveData() + } + override fun onCleared() { disposables.clear() } diff --git a/app/src/main/java/org/thoughtcrime/securesms/components/settings/app/internal/backup/InternalBackupStatsTab.kt b/app/src/main/java/org/thoughtcrime/securesms/components/settings/app/internal/backup/InternalBackupStatsTab.kt index 2ffc4b787e..cd00a213d5 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/components/settings/app/internal/backup/InternalBackupStatsTab.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/components/settings/app/internal/backup/InternalBackupStatsTab.kt @@ -55,7 +55,7 @@ fun InternalBackupStatsTab(stats: InternalBackupPlaygroundViewModel.StatsState, Spacer(modifier = Modifier.size(16.dp)) Text(text = "Unique/archived data files: ${stats.attachmentStats.attachmentFileCount}/${stats.attachmentStats.finishedAttachmentFileCount}") - Text(text = "Unique/archived verified digest count: ${stats.attachmentStats.attachmentPlaintextHashAndKeyCount}/${stats.attachmentStats.finishedAttachmentPlaintextHashAndKeyCount}") + Text(text = "Unique/archived verified plaintextHash count: ${stats.attachmentStats.attachmentPlaintextHashAndKeyCount}/${stats.attachmentStats.finishedAttachmentPlaintextHashAndKeyCount}") Text(text = "Unique/expected thumbnail files: ${stats.attachmentStats.thumbnailFileCount}/${stats.attachmentStats.estimatedThumbnailCount}") Text(text = "Local Total: ${stats.attachmentStats.attachmentFileCount + stats.attachmentStats.thumbnailFileCount}") Text(text = "Expected remote total: ${stats.attachmentStats.estimatedThumbnailCount + stats.attachmentStats.finishedAttachmentPlaintextHashAndKeyCount}") diff --git a/app/src/main/java/org/thoughtcrime/securesms/database/AttachmentTable.kt b/app/src/main/java/org/thoughtcrime/securesms/database/AttachmentTable.kt index c8f98f107a..d9ee392b94 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/database/AttachmentTable.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/database/AttachmentTable.kt @@ -679,15 +679,15 @@ class AttachmentTable( /** * Sets the archive transfer state for the given attachment by digest. */ - fun resetArchiveTransferStateByPlaintextHashAndRemoteKey(plaintextHash: ByteArray, remoteKey: ByteArray) { - writableDatabase + fun resetArchiveTransferStateByPlaintextHashAndRemoteKey(plaintextHash: ByteArray, remoteKey: ByteArray): Boolean { + return writableDatabase .update(TABLE_NAME) .values( ARCHIVE_TRANSFER_STATE to ArchiveTransferState.NONE.value, ARCHIVE_CDN to null ) .where("$DATA_HASH_END = ? AND $REMOTE_KEY = ?", Base64.encodeWithPadding(plaintextHash), Base64.encodeWithPadding(remoteKey)) - .run() + .run() > 0 } /** @@ -2593,11 +2593,11 @@ class AttachmentTable( .associate { it to readableDatabase.count().from(TABLE_NAME).where("$ARCHIVE_TRANSFER_STATE = ${it.value} AND $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL").run().readToSingleLong(-1L) } val attachmentFileCount = readableDatabase.query("SELECT COUNT(DISTINCT $DATA_FILE) FROM $TABLE_NAME WHERE $DATA_FILE NOT NULL AND $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL").readToSingleLong(-1L) - val finishedAttachmentFileCount = readableDatabase.query("SELECT COUNT(DISTINCT $DATA_FILE) FROM $TABLE_NAME WHERE $DATA_FILE NOT NULL AND $DATA_HASH_END NOT NULL $REMOTE_KEY NOT NULL AND $ARCHIVE_TRANSFER_STATE = ${ArchiveTransferState.FINISHED.value}").readToSingleLong(-1L) + val finishedAttachmentFileCount = readableDatabase.query("SELECT COUNT(DISTINCT $DATA_FILE) FROM $TABLE_NAME WHERE $DATA_FILE NOT NULL AND $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL AND $ARCHIVE_TRANSFER_STATE = ${ArchiveTransferState.FINISHED.value}").readToSingleLong(-1L) val attachmentPlaintextHashAndKeyCount = readableDatabase.query("SELECT COUNT(*) FROM (SELECT DISTINCT $DATA_HASH_END, $REMOTE_KEY FROM $TABLE_NAME WHERE $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL AND $TRANSFER_STATE in ($TRANSFER_PROGRESS_DONE, $TRANSFER_RESTORE_OFFLOADED, $TRANSFER_RESTORE_IN_PROGRESS, $TRANSFER_NEEDS_RESTORE))").readToSingleLong(-1L) - val finishedAttachmentDigestCount = readableDatabase.query("SELECT COUNT(*) FROM (SELECT DISTINCT $DATA_HASH_END, $REMOTE_KEY) FROM $TABLE_NAME WHERE $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL AND $ARCHIVE_TRANSFER_STATE = ${ArchiveTransferState.FINISHED.value})").readToSingleLong(-1L) + val finishedAttachmentDigestCount = readableDatabase.query("SELECT COUNT(*) FROM (SELECT DISTINCT $DATA_HASH_END, $REMOTE_KEY FROM $TABLE_NAME WHERE $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL AND $ARCHIVE_TRANSFER_STATE = ${ArchiveTransferState.FINISHED.value})").readToSingleLong(-1L) val thumbnailFileCount = readableDatabase.query("SELECT COUNT(DISTINCT $THUMBNAIL_FILE) FROM $TABLE_NAME WHERE $THUMBNAIL_FILE IS NOT NULL").readToSingleLong(-1L) - val estimatedThumbnailCount = readableDatabase.query("SELECT COUNT(*) FROM (SELECT DISTINCT $DATA_HASH_END, $REMOTE_KEY) FROM $TABLE_NAME WHERE $ARCHIVE_TRANSFER_STATE = ${ArchiveTransferState.FINISHED.value} AND $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL AND ($CONTENT_TYPE LIKE 'image/%' OR $CONTENT_TYPE LIKE 'video/%'))").readToSingleLong(-1L) + val estimatedThumbnailCount = readableDatabase.query("SELECT COUNT(*) FROM (SELECT DISTINCT $DATA_HASH_END, $REMOTE_KEY FROM $TABLE_NAME WHERE $ARCHIVE_TRANSFER_STATE = ${ArchiveTransferState.FINISHED.value} AND $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL AND ($CONTENT_TYPE LIKE 'image/%' OR $CONTENT_TYPE LIKE 'video/%'))").readToSingleLong(-1L) val pendingUploadBytes = getPendingArchiveUploadBytes() val uploadedAttachmentBytes = readableDatabase diff --git a/app/src/main/java/org/thoughtcrime/securesms/jobs/ArchiveAttachmentReconciliationJob.kt b/app/src/main/java/org/thoughtcrime/securesms/jobs/ArchiveAttachmentReconciliationJob.kt index 8470a53ca3..2ce2c12c0e 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/jobs/ArchiveAttachmentReconciliationJob.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/jobs/ArchiveAttachmentReconciliationJob.kt @@ -127,7 +127,13 @@ class ArchiveAttachmentReconciliationJob private constructor( val entry = BackupMediaSnapshotTable.MediaEntry.fromCursor(it) // TODO [backup] Re-enqueue thumbnail uploads if necessary if (!entry.isThumbnail) { - SignalDatabase.attachments.resetArchiveTransferStateByPlaintextHashAndRemoteKey(entry.plaintextHash, entry.remoteKey) + val success = SignalDatabase.attachments.resetArchiveTransferStateByPlaintextHashAndRemoteKey(entry.plaintextHash, entry.remoteKey) + if (!success) { + Log.e(TAG, "Failed to reset archive transfer state by remote hash/key!") + if (RemoteConfig.internalUser) { + throw RuntimeException("Failed to reset archive transfer state by remote hash/key!") + } + } } } diff --git a/app/src/main/java/org/thoughtcrime/securesms/jobs/ArchiveThumbnailUploadJob.kt b/app/src/main/java/org/thoughtcrime/securesms/jobs/ArchiveThumbnailUploadJob.kt index 64205cd55a..f88a3d2b71 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/jobs/ArchiveThumbnailUploadJob.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/jobs/ArchiveThumbnailUploadJob.kt @@ -85,8 +85,8 @@ class ArchiveThumbnailUploadJob private constructor( return Result.success() } - if (attachment.remoteDigest == null) { - Log.w(TAG, "$attachmentId was never uploaded! Cannot proceed.") + if (attachment.remoteDigest == null && attachment.dataHash == null) { + Log.w(TAG, "$attachmentId has no integrity check! Cannot proceed.") return Result.success() } @@ -153,7 +153,7 @@ class ArchiveThumbnailUploadJob private constructor( data = thumbnailResult.data ) - Log.d(TAG, "Successfully archived thumbnail for $attachmentId mediaName=${attachment.requireThumbnailMediaName()}") + Log.d(TAG, "Successfully archived thumbnail for $attachmentId") Result.success() } diff --git a/app/src/main/java/org/thoughtcrime/securesms/jobs/AttachmentDownloadJob.kt b/app/src/main/java/org/thoughtcrime/securesms/jobs/AttachmentDownloadJob.kt index 331ff6ce15..93b5130ed0 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/jobs/AttachmentDownloadJob.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/jobs/AttachmentDownloadJob.kt @@ -35,6 +35,7 @@ import org.thoughtcrime.securesms.transport.RetryLaterException import org.thoughtcrime.securesms.util.AttachmentUtil import org.thoughtcrime.securesms.util.RemoteConfig import org.thoughtcrime.securesms.util.Util +import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream.IntegrityCheck import org.whispersystems.signalservice.api.messages.AttachmentTransferProgress import org.whispersystems.signalservice.api.messages.SignalServiceAttachment import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentPointer @@ -277,12 +278,18 @@ class AttachmentDownloadJob private constructor( } } + if (attachment.remoteDigest == null && attachment.dataHash == null) { + Log.w(TAG, "Attachment has no integrity check!") + throw InvalidAttachmentException("Attachment has no integrity check!") + } + val decryptingStream = AppDependencies .signalServiceMessageReceiver .retrieveAttachment( pointer, attachmentFile, maxReceiveSize, + IntegrityCheck.forEncryptedDigestAndPlaintextHash(attachment.remoteDigest, attachment.dataHash), progressListener ) diff --git a/app/src/main/java/org/thoughtcrime/securesms/jobs/AvatarGroupsV1DownloadJob.java b/app/src/main/java/org/thoughtcrime/securesms/jobs/AvatarGroupsV1DownloadJob.java index 3fa961e527..b459d33982 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/jobs/AvatarGroupsV1DownloadJob.java +++ b/app/src/main/java/org/thoughtcrime/securesms/jobs/AvatarGroupsV1DownloadJob.java @@ -16,6 +16,8 @@ import org.thoughtcrime.securesms.jobmanager.impl.NetworkConstraint; import org.thoughtcrime.securesms.profiles.AvatarHelper; import org.signal.core.util.Hex; import org.whispersystems.signalservice.api.SignalServiceMessageReceiver; +import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream; +import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream.IntegrityCheck; import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentPointer; import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentRemoteId; import org.whispersystems.signalservice.api.push.exceptions.MissingConfigurationException; @@ -84,9 +86,16 @@ public final class AvatarGroupsV1DownloadJob extends BaseJob { attachment = File.createTempFile("avatar", "tmp", context.getCacheDir()); attachment.deleteOnExit(); - SignalServiceMessageReceiver receiver = AppDependencies.getSignalServiceMessageReceiver(); - SignalServiceAttachmentPointer pointer = new SignalServiceAttachmentPointer(0, new SignalServiceAttachmentRemoteId.V2(avatarId), contentType, key, Optional.of(0), Optional.empty(), 0, 0, digest, Optional.empty(), 0, fileName, false, false, false, Optional.empty(), Optional.empty(), System.currentTimeMillis(), null); - InputStream inputStream = receiver.retrieveAttachment(pointer, attachment, AvatarHelper.AVATAR_DOWNLOAD_FAILSAFE_MAX_SIZE); + + SignalServiceMessageReceiver receiver = AppDependencies.getSignalServiceMessageReceiver(); + SignalServiceAttachmentPointer pointer = new SignalServiceAttachmentPointer(0, new SignalServiceAttachmentRemoteId.V2(avatarId), contentType, key, Optional.of(0), Optional.empty(), 0, 0, digest, Optional.empty(), 0, fileName, false, false, false, Optional.empty(), Optional.empty(), System.currentTimeMillis(), null); + + if (pointer.getDigest().isEmpty()) { + throw new InvalidMessageException("Missing digest!"); + } + + IntegrityCheck integrityCheck = IntegrityCheck.forEncryptedDigest(pointer.getDigest().get()); + InputStream inputStream = receiver.retrieveAttachment(pointer, attachment, AvatarHelper.AVATAR_DOWNLOAD_FAILSAFE_MAX_SIZE, integrityCheck); AvatarHelper.setAvatar(context, record.get().getRecipientId(), inputStream); SignalDatabase.groups().onAvatarUpdated(groupId, true); diff --git a/app/src/main/java/org/thoughtcrime/securesms/jobs/MultiDeviceContactSyncJob.kt b/app/src/main/java/org/thoughtcrime/securesms/jobs/MultiDeviceContactSyncJob.kt index 1e4bda053f..72d5b4a8c2 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/jobs/MultiDeviceContactSyncJob.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/jobs/MultiDeviceContactSyncJob.kt @@ -12,6 +12,7 @@ import org.thoughtcrime.securesms.net.NotPushRegisteredException import org.thoughtcrime.securesms.profiles.AvatarHelper import org.thoughtcrime.securesms.providers.BlobProvider import org.thoughtcrime.securesms.recipients.Recipient +import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream.IntegrityCheck import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentPointer import org.whispersystems.signalservice.api.messages.multidevice.DeviceContact import org.whispersystems.signalservice.api.messages.multidevice.DeviceContactsInputStream @@ -59,7 +60,7 @@ class MultiDeviceContactSyncJob(parameters: Parameters, private val attachmentPo try { val contactsFile: File = BlobProvider.getInstance().forNonAutoEncryptingSingleSessionOnDisk(context) AppDependencies.signalServiceMessageReceiver - .retrieveAttachment(contactAttachment, contactsFile, MAX_ATTACHMENT_SIZE) + .retrieveAttachment(contactAttachment, contactsFile, MAX_ATTACHMENT_SIZE, IntegrityCheck.forEncryptedDigest(contactAttachment.digest.get())) .use(this::processContactFile) } catch (e: MissingConfigurationException) { throw IOException(e) diff --git a/app/src/main/java/org/thoughtcrime/securesms/jobs/RestoreAttachmentJob.kt b/app/src/main/java/org/thoughtcrime/securesms/jobs/RestoreAttachmentJob.kt index 2286241ee1..76fe5f09ca 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/jobs/RestoreAttachmentJob.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/jobs/RestoreAttachmentJob.kt @@ -40,6 +40,7 @@ import org.thoughtcrime.securesms.notifications.NotificationChannels import org.thoughtcrime.securesms.notifications.NotificationIds import org.thoughtcrime.securesms.transport.RetryLaterException import org.thoughtcrime.securesms.util.RemoteConfig +import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream.IntegrityCheck import org.whispersystems.signalservice.api.messages.AttachmentTransferProgress import org.whispersystems.signalservice.api.messages.SignalServiceAttachment import org.whispersystems.signalservice.api.push.exceptions.MissingConfigurationException @@ -49,6 +50,7 @@ import org.whispersystems.signalservice.api.push.exceptions.RangeException import java.io.File import java.io.IOException import java.util.concurrent.TimeUnit +import kotlin.jvm.optionals.getOrNull import kotlin.math.max import kotlin.math.pow import kotlin.time.Duration.Companion.days @@ -172,12 +174,12 @@ class RestoreAttachmentJob private constructor( val attachment = SignalDatabase.attachments.getAttachment(attachmentId) if (attachment == null) { - Log.w(TAG, "attachment no longer exists.") + Log.w(TAG, "[$attachmentId] Attachment no longer exists.") return } if (attachment.isPermanentlyFailed) { - Log.w(TAG, "Attachment was marked as a permanent failure. Refusing to download.") + Log.w(TAG, "[$attachmentId] Attachment was marked as a permanent failure. Refusing to download.") return } @@ -186,7 +188,7 @@ class RestoreAttachmentJob private constructor( attachment.transferState != AttachmentTable.TRANSFER_PROGRESS_FAILED && attachment.transferState != AttachmentTable.TRANSFER_RESTORE_OFFLOADED ) { - Log.w(TAG, "Attachment does not need to be restored. Current state: ${attachment.transferState}") + Log.w(TAG, "[$attachmentId] Attachment does not need to be restored. Current state: ${attachment.transferState}") return } @@ -231,14 +233,20 @@ class RestoreAttachmentJob private constructor( var archiveFile: File? = null var useArchiveCdn = false + if (attachment.remoteDigest == null && attachment.dataHash == null) { + Log.w(TAG, "[$attachmentId] Attachment has no integrity check! Cannot proceed.") + markPermanentlyFailed(attachmentId) + return + } + try { if (attachment.size > maxReceiveSize) { - throw MmsException("Attachment too large, failing download") + throw MmsException("[$attachmentId] Attachment too large, failing download") } useArchiveCdn = if (SignalStore.backup.backsUpMedia && !forceTransitTier) { if (attachment.archiveTransferState != AttachmentTable.ArchiveTransferState.FINISHED) { - throw InvalidAttachmentException("Invalid attachment configuration! backsUpMedia: ${SignalStore.backup.backsUpMedia}, forceTransitTier: $forceTransitTier, archiveTransferState: ${attachment.archiveTransferState}") + throw InvalidAttachmentException("[$attachmentId] Invalid attachment configuration! backsUpMedia: ${SignalStore.backup.backsUpMedia}, forceTransitTier: $forceTransitTier, archiveTransferState: ${attachment.archiveTransferState}") } true } else { @@ -259,7 +267,9 @@ class RestoreAttachmentJob private constructor( } val decryptingStream = if (useArchiveCdn) { - archiveFile = SignalDatabase.attachments.getOrCreateArchiveTransferFile(attachmentId) + // TODO next PR: remove archive transfer file and just use the regular attachment file + archiveFile = attachmentFile +// archiveFile = SignalDatabase.attachments.getOrCreateArchiveTransferFile(attachmentId) val cdnCredentials = BackupRepository.getCdnReadCredentials(BackupRepository.CredentialType.MEDIA, attachment.archiveCdn ?: RemoteConfig.backupFallbackArchiveCdn).successOrThrow().headers messageReceiver @@ -269,7 +279,6 @@ class RestoreAttachmentJob private constructor( cdnCredentials, archiveFile, pointer, - attachmentFile, maxReceiveSize, progressListener ) @@ -279,6 +288,7 @@ class RestoreAttachmentJob private constructor( pointer, attachmentFile, maxReceiveSize, + IntegrityCheck.forEncryptedDigestAndPlaintextHash(pointer.digest.getOrNull(), attachment.dataHash), progressListener ) } @@ -286,7 +296,7 @@ class RestoreAttachmentJob private constructor( SignalDatabase.attachments.finalizeAttachmentAfterDownload(messageId, attachmentId, decryptingStream, if (manual) System.currentTimeMillis().milliseconds else null) } catch (e: RangeException) { val transferFile = archiveFile ?: attachmentFile - Log.w(TAG, "Range exception, file size " + transferFile.length(), e) + Log.w(TAG, "[$attachmentId] Range exception, file size " + transferFile.length(), e) if (transferFile.delete()) { Log.i(TAG, "Deleted temp download file to recover") throw RetryLaterException(e) @@ -299,7 +309,7 @@ class RestoreAttachmentJob private constructor( } catch (e: NonSuccessfulResponseCodeException) { if (SignalStore.backup.backsUpMedia) { if (e.code == 404 && !forceTransitTier && attachment.remoteLocation?.isNotBlank() == true) { - Log.i(TAG, "Failed to download attachment from archive! Should only happen for recent attachments in a narrow window. Retrying download from transit CDN.") + Log.i(TAG, "[$attachmentId] Failed to download attachment from archive! Should only happen for recent attachments in a narrow window. Retrying download from transit CDN.") if (RemoteConfig.internalUser) { postFailedToDownloadFromArchiveNotification() } @@ -316,18 +326,18 @@ class RestoreAttachmentJob private constructor( } } - Log.w(TAG, "Experienced exception while trying to download an attachment.", e) + Log.w(TAG, "[$attachmentId] Experienced exception while trying to download an attachment.", e) markFailed(attachmentId) } catch (e: MmsException) { - Log.w(TAG, "Experienced exception while trying to download an attachment.", e) + Log.w(TAG, "[$attachmentId] Experienced exception while trying to download an attachment.", e) markFailed(attachmentId) } catch (e: MissingConfigurationException) { - Log.w(TAG, "Experienced exception while trying to download an attachment.", e) + Log.w(TAG, "[$attachmentId] Experienced exception while trying to download an attachment.", e) markFailed(attachmentId) } catch (e: InvalidMessageException) { - Log.w(TAG, "Experienced an InvalidMessageException while trying to download an attachment.", e) + Log.w(TAG, "[$attachmentId] Experienced an InvalidMessageException while trying to download an attachment.", e) if (e.cause is InvalidMacException) { - Log.w(TAG, "Detected an invalid mac. Treating as a permanent failure.") + Log.w(TAG, "[$attachmentId] Detected an invalid mac. Treating as a permanent failure.") markPermanentlyFailed(attachmentId) } else { markFailed(attachmentId) diff --git a/app/src/main/java/org/thoughtcrime/securesms/jobs/RestoreAttachmentThumbnailJob.kt b/app/src/main/java/org/thoughtcrime/securesms/jobs/RestoreAttachmentThumbnailJob.kt index 42e377aada..eddd09ecdc 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/jobs/RestoreAttachmentThumbnailJob.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/jobs/RestoreAttachmentThumbnailJob.kt @@ -120,7 +120,6 @@ class RestoreAttachmentThumbnailJob private constructor( val maxThumbnailSize: Long = RemoteConfig.maxAttachmentReceiveSizeBytes val thumbnailTransferFile: File = SignalDatabase.attachments.createArchiveThumbnailTransferFile() - val thumbnailFile: File = SignalDatabase.attachments.createArchiveThumbnailTransferFile() val progressListener = object : SignalServiceAttachment.ProgressListener { override fun onAttachmentProgress(progress: AttachmentTransferProgress) = Unit @@ -137,7 +136,6 @@ class RestoreAttachmentThumbnailJob private constructor( cdnCredentials, thumbnailTransferFile, pointer, - thumbnailFile, maxThumbnailSize, progressListener ) diff --git a/app/src/main/java/org/thoughtcrime/securesms/jobs/RestoreLocalAttachmentJob.kt b/app/src/main/java/org/thoughtcrime/securesms/jobs/RestoreLocalAttachmentJob.kt index badd01a104..de231a7a0a 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/jobs/RestoreLocalAttachmentJob.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/jobs/RestoreLocalAttachmentJob.kt @@ -23,6 +23,7 @@ import org.thoughtcrime.securesms.keyvalue.SignalStore import org.thoughtcrime.securesms.mms.MmsException import org.whispersystems.signalservice.api.backup.MediaName import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream +import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream.IntegrityCheck import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream.StreamSupplier import java.io.IOException @@ -154,7 +155,10 @@ class RestoreLocalAttachmentJob private constructor( streamLength = size, plaintextLength = attachment.size, combinedKeyMaterial = combinedKey, - digest = attachment.remoteDigest, + integrityCheck = IntegrityCheck.forEncryptedDigestAndPlaintextHash( + encryptedDigest = attachment.remoteDigest, + plaintextHash = attachment.dataHash + ), incrementalDigest = null, incrementalMacChunkSize = 0 ).use { input -> diff --git a/app/src/main/java/org/thoughtcrime/securesms/video/exo/PartDataSource.java b/app/src/main/java/org/thoughtcrime/securesms/video/exo/PartDataSource.java index 2c4d83b191..200a12a081 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/video/exo/PartDataSource.java +++ b/app/src/main/java/org/thoughtcrime/securesms/video/exo/PartDataSource.java @@ -23,6 +23,7 @@ import org.signal.core.util.Base64; import org.whispersystems.signalservice.api.backup.MediaName; import org.whispersystems.signalservice.api.backup.MediaRootBackupKey; import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream; +import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream.IntegrityCheck; import org.whispersystems.signalservice.api.crypto.AttachmentCipherStreamUtil; import org.signal.core.util.stream.TailerInputStream; import org.whispersystems.signalservice.internal.crypto.PaddingInputStream; @@ -100,11 +101,13 @@ class PartDataSource implements DataSource { long streamLength = AttachmentCipherStreamUtil.getCiphertextLength(PaddingInputStream.getPaddedSize(attachment.size)); AttachmentCipherInputStream.StreamSupplier streamSupplier = () -> new TailerInputStream(() -> new FileInputStream(transferFile), streamLength); - if (attachment.remoteDigest == null) { + if (attachment.remoteDigest == null && attachment.dataHash == null) { throw new InvalidMessageException("Missing digest!"); } - this.inputStream = AttachmentCipherInputStream.createForAttachment(streamSupplier, streamLength, attachment.size, decodedKey, attachment.remoteDigest, attachment.getIncrementalDigest(), attachment.incrementalMacChunkSize); + IntegrityCheck integrityCheck = IntegrityCheck.forEncryptedDigestAndPlaintextHash(attachment.remoteDigest, attachment.dataHash); + + this.inputStream = AttachmentCipherInputStream.createForAttachment(streamSupplier, streamLength, attachment.size, decodedKey, integrityCheck, attachment.getIncrementalDigest(), attachment.incrementalMacChunkSize); } catch (InvalidMessageException e) { throw new IOException("Error decrypting attachment stream!", e); } diff --git a/app/src/test/java/org/thoughtcrime/securesms/backup/v2/ArchivedMediaObjectIteratorTest.kt b/app/src/test/java/org/thoughtcrime/securesms/backup/v2/ArchivedMediaObjectIteratorTest.kt index 17c0646553..35d0b8ed57 100644 --- a/app/src/test/java/org/thoughtcrime/securesms/backup/v2/ArchivedMediaObjectIteratorTest.kt +++ b/app/src/test/java/org/thoughtcrime/securesms/backup/v2/ArchivedMediaObjectIteratorTest.kt @@ -7,14 +7,16 @@ import io.mockk.mockk import io.mockk.mockkObject import org.junit.Before import org.junit.Test +import org.signal.core.util.Base64 import org.thoughtcrime.securesms.MockCursor import org.thoughtcrime.securesms.keyvalue.BackupValues import org.thoughtcrime.securesms.keyvalue.SignalStore +import org.thoughtcrime.securesms.util.Util import org.whispersystems.signalservice.api.backup.MediaRootBackupKey class ArchivedMediaObjectIteratorTest { private val cursor = mockk(relaxed = true) { - every { getString(any()) } returns "A" + every { getString(any()) } returns Base64.encodeWithPadding(Util.getSecretBytes(32)) every { moveToPosition(any()) } answers { callOriginal() } every { moveToNext() } answers { callOriginal() } every { position } answers { callOriginal() } diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageReceiver.java b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageReceiver.java index 4bdcdd59c2..785c494211 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageReceiver.java +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageReceiver.java @@ -7,10 +7,12 @@ package org.whispersystems.signalservice.api; import org.signal.core.util.StreamUtil; +import org.signal.core.util.logging.Log; import org.signal.libsignal.protocol.InvalidMessageException; import org.signal.libsignal.zkgroup.profiles.ProfileKey; import org.whispersystems.signalservice.api.backup.MediaRootBackupKey; import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream; +import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream.IntegrityCheck; import org.whispersystems.signalservice.api.crypto.AttachmentCipherStreamUtil; import org.whispersystems.signalservice.api.crypto.ProfileCipherInputStream; import org.whispersystems.signalservice.api.messages.SignalServiceAttachment.ProgressListener; @@ -63,9 +65,9 @@ public class SignalServiceMessageReceiver { * @throws IOException * @throws InvalidMessageException */ - public InputStream retrieveAttachment(SignalServiceAttachmentPointer pointer, File destination, long maxSizeBytes) + public InputStream retrieveAttachment(SignalServiceAttachmentPointer pointer, File destination, long maxSizeBytes, IntegrityCheck integrityCheck) throws IOException, InvalidMessageException, MissingConfigurationException { - return retrieveAttachment(pointer, destination, maxSizeBytes, null); + return retrieveAttachment(pointer, destination, maxSizeBytes, integrityCheck, null); } public InputStream retrieveProfileAvatar(String path, File destination, ProfileKey profileKey, long maxSizeBytes) @@ -96,9 +98,9 @@ public class SignalServiceMessageReceiver { * @throws IOException * @throws InvalidMessageException */ - public InputStream retrieveAttachment(SignalServiceAttachmentPointer pointer, File destination, long maxSizeBytes, ProgressListener listener) + public InputStream retrieveAttachment(SignalServiceAttachmentPointer pointer, File destination, long maxSizeBytes, IntegrityCheck integrityCheck, ProgressListener listener) throws IOException, InvalidMessageException, MissingConfigurationException { - if (pointer.getDigest().isEmpty()) throw new InvalidMessageException("No attachment digest!"); + if (integrityCheck == null) throw new InvalidMessageException("No integrity check!"); if (pointer.getKey() == null) throw new InvalidMessageException("No key!"); socket.retrieveAttachment(pointer.getCdnNumber(), Collections.emptyMap(), pointer.getRemoteId(), destination, maxSizeBytes, listener); @@ -112,7 +114,7 @@ public class SignalServiceMessageReceiver { destination, pointer.getSize().orElse(0), pointer.getKey(), - pointer.getDigest().get(), + integrityCheck, null, 0 ); @@ -126,7 +128,6 @@ public class SignalServiceMessageReceiver { * @param readCredentialHeaders Headers to pass to the backup CDN to authorize the download * @param archiveDestination The download destination for archived attachment. If this file exists, download will resume. * @param pointer The {@link SignalServiceAttachmentPointer} received in a {@link SignalServiceDataMessage}. - * @param attachmentDestination The download destination for this attachment. If this file exists, it is assumed that this is previously-downloaded content that can be resumed. * @param listener An optional listener (may be null) to receive callbacks on download progress. * * @return An InputStream that streams the plaintext attachment contents. @@ -136,7 +137,6 @@ public class SignalServiceMessageReceiver { @Nonnull Map readCredentialHeaders, @Nonnull File archiveDestination, @Nonnull SignalServiceAttachmentPointer pointer, - @Nonnull File attachmentDestination, long maxSizeBytes, @Nullable ProgressListener listener) throws IOException, InvalidMessageException, MissingConfigurationException @@ -154,7 +154,7 @@ public class SignalServiceMessageReceiver { return AttachmentCipherInputStream.createForArchivedMedia( archivedMediaKeyMaterial, - attachmentDestination, + archiveDestination, originalCipherLength, pointer.getSize().orElse(0), pointer.getKey(), @@ -171,7 +171,6 @@ public class SignalServiceMessageReceiver { * @param readCredentialHeaders Headers to pass to the backup CDN to authorize the download * @param archiveDestination The download destination for archived attachment. If this file exists, download will resume. * @param pointer The {@link SignalServiceAttachmentPointer} received in a {@link SignalServiceDataMessage}. - * @param attachmentDestination The download destination for this attachment. If this file exists, it is assumed that this is previously-downloaded content that can be resumed. * @param listener An optional listener (may be null) to receive callbacks on download progress. * * @return An InputStream that streams the plaintext attachment contents. @@ -180,7 +179,6 @@ public class SignalServiceMessageReceiver { @Nonnull Map readCredentialHeaders, @Nonnull File archiveDestination, @Nonnull SignalServiceAttachmentPointer pointer, - @Nonnull File attachmentDestination, long maxSizeBytes, @Nullable ProgressListener listener) throws IOException, InvalidMessageException, MissingConfigurationException @@ -198,7 +196,7 @@ public class SignalServiceMessageReceiver { return AttachmentCipherInputStream.createForArchivedThumbnail( archivedMediaKeyMaterial, - attachmentDestination, + archiveDestination, originalCipherLength, pointer.getSize().orElse(0), pointer.getKey() diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/archive/ArchiveGetMediaItemsResponse.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/archive/ArchiveGetMediaItemsResponse.kt index a456fad647..564e9702a6 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/archive/ArchiveGetMediaItemsResponse.kt +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/archive/ArchiveGetMediaItemsResponse.kt @@ -16,7 +16,7 @@ class ArchiveGetMediaItemsResponse( @JsonProperty val mediaDir: String?, @JsonProperty val cursor: String? ) { - class StoredMediaObject( + data class StoredMediaObject( @JsonProperty val cdn: Int, @JsonProperty val mediaId: String, @JsonProperty val objectLength: Long diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/archive/DeleteArchivedMediaRequest.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/archive/DeleteArchivedMediaRequest.kt index 4d9a5d73b3..a46c9a8ab8 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/archive/DeleteArchivedMediaRequest.kt +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/archive/DeleteArchivedMediaRequest.kt @@ -13,7 +13,7 @@ import com.fasterxml.jackson.annotation.JsonProperty class DeleteArchivedMediaRequest( @JsonProperty val mediaToDelete: List ) { - class ArchivedMediaObject( + data class ArchivedMediaObject( @JsonProperty val cdn: Int, @JsonProperty val mediaId: String ) diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/backup/MediaName.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/backup/MediaName.kt index cd4b5efd7b..65aed968a9 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/backup/MediaName.kt +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/backup/MediaName.kt @@ -31,10 +31,6 @@ value class MediaName(val name: String) { return mediaRootBackupKey.deriveMediaId(this) } - fun toByteArray(): ByteArray { - return name.toByteArray() - } - override fun toString(): String { return name } diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherInputStream.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherInputStream.kt index 70ad95ea71..4fe85d127d 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherInputStream.kt +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherInputStream.kt @@ -5,6 +5,7 @@ */ package org.whispersystems.signalservice.api.crypto +import org.signal.core.util.Base64 import org.signal.core.util.readNBytesOrThrow import org.signal.core.util.stream.LimitedInputStream import org.signal.libsignal.protocol.InvalidMessageException @@ -51,7 +52,7 @@ object AttachmentCipherInputStream { file: File, plaintextLength: Long, combinedKeyMaterial: ByteArray, - digest: ByteArray, + integrityCheck: IntegrityCheck, incrementalDigest: ByteArray?, incrementalMacChunkSize: Int ): InputStream { @@ -60,11 +61,9 @@ object AttachmentCipherInputStream { streamLength = file.length(), plaintextLength = plaintextLength, combinedKeyMaterial = combinedKeyMaterial, - encryptedDigest = digest, - plaintextHash = null, + integrityCheck = integrityCheck, incrementalDigest = incrementalDigest, - incrementalMacChunkSize = incrementalMacChunkSize, - ignoreDigest = false + incrementalMacChunkSize = incrementalMacChunkSize ) } @@ -81,7 +80,7 @@ object AttachmentCipherInputStream { streamLength: Long, plaintextLength: Long, combinedKeyMaterial: ByteArray, - digest: ByteArray, + integrityCheck: IntegrityCheck, incrementalDigest: ByteArray?, incrementalMacChunkSize: Int ): InputStream { @@ -90,11 +89,9 @@ object AttachmentCipherInputStream { streamLength = streamLength, plaintextLength = plaintextLength, combinedKeyMaterial = combinedKeyMaterial, - encryptedDigest = digest, - plaintextHash = null, + integrityCheck = integrityCheck, incrementalDigest = incrementalDigest, - incrementalMacChunkSize = incrementalMacChunkSize, - ignoreDigest = false + incrementalMacChunkSize = incrementalMacChunkSize ) } @@ -130,11 +127,9 @@ object AttachmentCipherInputStream { streamLength = originalCipherTextLength, plaintextLength = plaintextLength, combinedKeyMaterial = combinedKeyMaterial, - encryptedDigest = null, - plaintextHash = plaintextHash, + integrityCheck = IntegrityCheck(plaintextHash = plaintextHash, encryptedDigest = null), incrementalDigest = incrementalDigest, - incrementalMacChunkSize = incrementalMacChunkSize, - ignoreDigest = true + incrementalMacChunkSize = incrementalMacChunkSize ) } @@ -159,7 +154,7 @@ object AttachmentCipherInputStream { val mac = initMac(keyMaterial.macKey) if (originalCipherTextLength <= BLOCK_SIZE + mac.macLength) { - throw InvalidMessageException("Message shorter than crypto overhead!") + throw InvalidMessageException("Message shorter than crypto overhead! Expected at least ${BLOCK_SIZE + mac.macLength} bytes, got $originalCipherTextLength") } return create( @@ -167,11 +162,9 @@ object AttachmentCipherInputStream { streamLength = originalCipherTextLength, plaintextLength = plaintextLength, combinedKeyMaterial = combinedKeyMaterial, - encryptedDigest = null, - plaintextHash = null, + integrityCheck = null, incrementalDigest = null, - incrementalMacChunkSize = 0, - ignoreDigest = true + incrementalMacChunkSize = 0 ) } @@ -189,7 +182,7 @@ object AttachmentCipherInputStream { } ByteArrayInputStream(data).use { inputStream -> - verifyMac(inputStream, data.size.toLong(), mac, null) + verifyMacAndMaybeEncryptedDigest(inputStream, data.size.toLong(), mac, null) } val encryptedStream = ByteArrayInputStream(data) @@ -211,11 +204,11 @@ object AttachmentCipherInputStream { val mac = initMac(archivedMediaKeyMaterial.macKey) if (file.length() <= BLOCK_SIZE + mac.macLength) { - throw InvalidMessageException("Message shorter than crypto overhead!") + throw InvalidMessageException("Message shorter than crypto overhead! Expected at least ${BLOCK_SIZE + mac.macLength} bytes, got ${file.length()}") } FileInputStream(file).use { macVerificationStream -> - verifyMac(macVerificationStream, file.length(), mac, null) + verifyMacAndMaybeEncryptedDigest(macVerificationStream, file.length(), mac, null) } val encryptedStream = FileInputStream(file) @@ -226,6 +219,10 @@ object AttachmentCipherInputStream { return LimitedInputStream(inputStream, originalCipherTextLength) } + /** + * @param integrityCheck If null, no integrity check is performed! This is a private method, so it's assumed that care has been taken to ensure that this is + * the correct course of action. Public methods should properly enforce when integrity checks are required. + */ @JvmStatic @Throws(InvalidMessageException::class, IOException::class) private fun create( @@ -233,11 +230,9 @@ object AttachmentCipherInputStream { streamLength: Long, plaintextLength: Long, combinedKeyMaterial: ByteArray, - encryptedDigest: ByteArray?, - plaintextHash: ByteArray?, + integrityCheck: IntegrityCheck?, incrementalDigest: ByteArray?, - incrementalMacChunkSize: Int, - ignoreDigest: Boolean + incrementalMacChunkSize: Int ): InputStream { val keyMaterial = CombinedKeyMaterial.from(combinedKeyMaterial) val mac = initMac(keyMaterial.macKey) @@ -246,25 +241,16 @@ object AttachmentCipherInputStream { throw InvalidMessageException("Message shorter than crypto overhead! length: $streamLength") } - if (!ignoreDigest && encryptedDigest == null) { - throw InvalidMessageException("Missing digest!") - } - val wrappedStream: InputStream val hasIncrementalMac = incrementalDigest != null && incrementalDigest.isNotEmpty() && incrementalMacChunkSize > 0 - if (!hasIncrementalMac) { - streamSupplier.openStream().use { macVerificationStream -> - verifyMac(macVerificationStream, streamLength, mac, encryptedDigest) - } - wrappedStream = streamSupplier.openStream() - } else { - if (encryptedDigest == null && plaintextHash == null) { - throw InvalidMessageException("Missing data (digest or plaintextHas) for incremental mac validation!") + if (hasIncrementalMac) { + if (integrityCheck == null) { + throw InvalidMessageException("Missing integrityCheck for incremental mac validation!") } - val digestValidatingStream = if (encryptedDigest != null) { - DigestValidatingInputStream(streamSupplier.openStream(), sha256Digest(), encryptedDigest) + val digestValidatingStream = if (integrityCheck.encryptedDigest != null) { + DigestValidatingInputStream(streamSupplier.openStream(), sha256Digest(), integrityCheck.encryptedDigest) } else { streamSupplier.openStream() } @@ -279,6 +265,11 @@ object AttachmentCipherInputStream { ChunkSizeChoice.everyNthByte(incrementalMacChunkSize), incrementalDigest ) + } else { + streamSupplier.openStream().use { macVerificationStream -> + verifyMacAndMaybeEncryptedDigest(macVerificationStream, streamLength, mac, integrityCheck?.encryptedDigest) + } + wrappedStream = streamSupplier.openStream() } val encryptedStreamExcludingMac = LimitedInputStream(wrappedStream, streamLength - mac.macLength) @@ -286,12 +277,12 @@ object AttachmentCipherInputStream { val decryptingStream: InputStream = BetterCipherInputStream(encryptedStreamExcludingMac, cipher) val paddinglessDecryptingStream = LimitedInputStream(decryptingStream, plaintextLength) - return if (plaintextHash != null) { - if (plaintextHash.size != MessageDigest.getInstance("SHA-256").digestLength) { - throw InvalidMessageException("Invalid plaintext hash size: ${plaintextHash.size}") + return if (integrityCheck?.plaintextHash != null) { + if (integrityCheck.plaintextHash.size != MessageDigest.getInstance("SHA-256").digestLength) { + throw InvalidMessageException("Invalid plaintext hash size: ${integrityCheck.plaintextHash.size}") } - DigestValidatingInputStream(paddinglessDecryptingStream, sha256Digest(), plaintextHash) + DigestValidatingInputStream(paddinglessDecryptingStream, sha256Digest(), integrityCheck.plaintextHash) } else { paddinglessDecryptingStream } @@ -326,7 +317,7 @@ object AttachmentCipherInputStream { } @Throws(InvalidMessageException::class) - private fun verifyMac(@Nonnull inputStream: InputStream, length: Long, @Nonnull mac: Mac, theirDigest: ByteArray?) { + private fun verifyMacAndMaybeEncryptedDigest(@Nonnull inputStream: InputStream, length: Long, @Nonnull mac: Mac, theirDigest: ByteArray?) { try { val digest = MessageDigest.getInstance("SHA256") var remainingData = Util.toIntExact(length) - mac.macLength @@ -375,4 +366,33 @@ object AttachmentCipherInputStream { @Throws(IOException::class) fun openStream(): InputStream } + + class IntegrityCheck( + val encryptedDigest: ByteArray?, + val plaintextHash: ByteArray? + ) { + init { + if (encryptedDigest == null && plaintextHash == null) { + throw IllegalArgumentException("At least one of encryptedDigest or plaintextHash must be provided") + } + } + + companion object { + @JvmStatic + fun forEncryptedDigest(encryptedDigest: ByteArray): IntegrityCheck { + return IntegrityCheck(encryptedDigest, null) + } + + @JvmStatic + fun forPlaintextHash(plaintextHash: ByteArray): IntegrityCheck { + return IntegrityCheck(null, plaintextHash) + } + + @JvmStatic + fun forEncryptedDigestAndPlaintextHash(encryptedDigest: ByteArray?, plaintextHash: String?): IntegrityCheck { + val plaintextHashBytes = plaintextHash?.let { Base64.decode(it) } + return IntegrityCheck(encryptedDigest, plaintextHashBytes) + } + } + } } diff --git a/libsignal-service/src/test/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherTest.kt b/libsignal-service/src/test/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherTest.kt index 2a890c95ba..eda1409164 100644 --- a/libsignal-service/src/test/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherTest.kt +++ b/libsignal-service/src/test/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherTest.kt @@ -13,6 +13,7 @@ import org.signal.libsignal.protocol.InvalidMessageException import org.signal.libsignal.protocol.incrementalmac.ChunkSizeChoice import org.signal.libsignal.protocol.incrementalmac.InvalidMacException import org.signal.libsignal.protocol.kdf.HKDF +import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream.IntegrityCheck import org.whispersystems.signalservice.api.crypto.AttachmentCipherTestHelper.createMediaKeyMaterial import org.whispersystems.signalservice.internal.crypto.PaddingInputStream import org.whispersystems.signalservice.internal.push.http.AttachmentCipherOutputStreamFactory @@ -32,19 +33,39 @@ import java.util.Random class AttachmentCipherTest { @Test - fun attachment_encryptDecrypt_nonIncremental() { - attachment_encryptDecrypt(incremental = false, fileSize = MEBIBYTE) + fun attachment_encryptDecrypt_nonIncremental_encryptedDigest() { + attachment_encryptDecrypt(incremental = false, fileSize = MEBIBYTE, integrityCheckMode = IntegrityCheckMode.ENCRYPTED_DIGEST) } @Test - fun attachment_encryptDecrypt_incremental() { - attachment_encryptDecrypt(incremental = true, fileSize = MEBIBYTE) + fun attachment_encryptDecrypt_nonIncremental_plaintextHash() { + attachment_encryptDecrypt(incremental = false, fileSize = MEBIBYTE, integrityCheckMode = IntegrityCheckMode.PLAINTEXT_HASH) + } + + @Test + fun attachment_encryptDecrypt_nonIncremental_bothIntegrityChecks() { + attachment_encryptDecrypt(incremental = false, fileSize = MEBIBYTE, integrityCheckMode = IntegrityCheckMode.BOTH) + } + + @Test + fun attachment_encryptDecrypt_incremental_encryptedDigest() { + attachment_encryptDecrypt(incremental = true, fileSize = MEBIBYTE, integrityCheckMode = IntegrityCheckMode.ENCRYPTED_DIGEST) + } + + @Test + fun attachment_encryptDecrypt_incremental_plaintextHash() { + attachment_encryptDecrypt(incremental = true, fileSize = MEBIBYTE, integrityCheckMode = IntegrityCheckMode.PLAINTEXT_HASH) + } + + @Test + fun attachment_encryptDecrypt_incremental_bothIntegrityChecks() { + attachment_encryptDecrypt(incremental = true, fileSize = MEBIBYTE, integrityCheckMode = IntegrityCheckMode.BOTH) } @Test fun attachment_encryptDecrypt_nonIncremental_manyFileSizes() { for (i in 0..99) { - attachment_encryptDecrypt(incremental = false, fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024)) + attachment_encryptDecrypt(incremental = false, fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024), integrityCheckMode = IntegrityCheckMode.BOTH) } } @@ -52,18 +73,44 @@ class AttachmentCipherTest { fun attachment_encryptDecrypt_incremental_manyFileSizes() { // Designed to stress the various boundary conditions of reading the final mac for (i in 0..99) { - attachment_encryptDecrypt(incremental = true, fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024)) + attachment_encryptDecrypt(incremental = true, fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024), IntegrityCheckMode.BOTH) } } - private fun attachment_encryptDecrypt(incremental: Boolean, fileSize: Int) { + private fun attachment_encryptDecrypt(incremental: Boolean, fileSize: Int, integrityCheckMode: IntegrityCheckMode) { val key = Util.getSecretBytes(64) val plaintextInput = Util.getSecretBytes(fileSize) + val plaintextHash = MessageDigest.getInstance("SHA-256").digest(plaintextInput) val encryptResult = encryptData(plaintextInput, key, incremental) val cipherFile = writeToFile(encryptResult.ciphertext) - val inputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, encryptResult.digest, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice) + val integrityCheck = when (integrityCheckMode) { + IntegrityCheckMode.ENCRYPTED_DIGEST -> IntegrityCheck.forEncryptedDigest(encryptResult.digest) + IntegrityCheckMode.PLAINTEXT_HASH -> IntegrityCheck.forPlaintextHash(plaintextHash) + IntegrityCheckMode.BOTH -> IntegrityCheck( + encryptedDigest = encryptResult.digest, + plaintextHash = plaintextHash + ) + } + val inputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, integrityCheck, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice) + val plaintextOutput = inputStream.readFully(autoClose = false) + + assertThat(plaintextOutput).isEqualTo(plaintextInput) + + cipherFile.delete() + } + + private fun attachment_encryptDecrypt_plaintextHash(incremental: Boolean, fileSize: Int) { + val key = Util.getSecretBytes(64) + val plaintextInput = Util.getSecretBytes(fileSize) + val plaintextHash = MessageDigest.getInstance("SHA-256").digest(plaintextInput) + + val encryptResult = encryptData(plaintextInput, key, incremental) + val cipherFile = writeToFile(encryptResult.ciphertext) + + val integrityCheck = IntegrityCheck.forPlaintextHash(plaintextHash) + val inputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, integrityCheck, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice) val plaintextOutput = inputStream.readFully(autoClose = false) assertThat(plaintextOutput).isEqualTo(plaintextInput) @@ -88,7 +135,8 @@ class AttachmentCipherTest { val encryptResult = encryptData(plaintextInput, key, incremental) val cipherFile = writeToFile(encryptResult.ciphertext) - val inputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, encryptResult.digest, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice) + val integrityCheck = IntegrityCheck.forEncryptedDigest(encryptResult.digest) + val inputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, integrityCheck, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice) val plaintextOutput = inputStream.readFully(autoClose = false) Assert.assertArrayEquals(plaintextInput, plaintextOutput) @@ -117,7 +165,8 @@ class AttachmentCipherTest { cipherFile = writeToFile(encryptResult.ciphertext) val badKey = ByteArray(64) - AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), badKey, encryptResult.digest, null, 0) + val integrityCheck = IntegrityCheck.forEncryptedDigest(encryptResult.digest) + AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), badKey, integrityCheck, null, 0) } finally { cipherFile?.delete() } @@ -147,7 +196,8 @@ class AttachmentCipherTest { cipherFile = writeToFile(badMacCiphertext) - val stream: InputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, encryptResult.digest, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice) + val integrityCheck = IntegrityCheck.forEncryptedDigest(encryptResult.digest) + val stream: InputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, integrityCheck, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice) // In incremental mode, we'll only check the digest after reading the whole thing if (incremental) { @@ -159,16 +209,26 @@ class AttachmentCipherTest { } @Test(expected = InvalidMessageException::class) - fun attachment_decryptFailOnBadDigest_nonIncremental() { - attachment_decryptFailOnBadDigest(incremental = false) + fun attachment_decryptFailOnBadEncryptedDigest_nonIncremental() { + attachment_decryptFailOnBadEncryptedDigest(incremental = false) } @Test(expected = InvalidMessageException::class) - fun attachment_decryptFailOnBadDigest_incremental() { - attachment_decryptFailOnBadDigest(incremental = true) + fun attachment_decryptFailOnBadEncryptedDigest_incremental() { + attachment_decryptFailOnBadEncryptedDigest(incremental = true) } - private fun attachment_decryptFailOnBadDigest(incremental: Boolean) { + @Test(expected = InvalidMessageException::class) + fun attachment_decryptFailOnBadPlaintextHash_nonIncremental() { + attachment_decryptFailOnBadPlaintextHash(incremental = false) + } + + @Test(expected = InvalidMessageException::class) + fun attachment_decryptFailOnBadPlaintextHash_incremental() { + attachment_decryptFailOnBadPlaintextHash(incremental = true) + } + + private fun attachment_decryptFailOnBadEncryptedDigest(incremental: Boolean) { var cipherFile: File? = null try { @@ -180,7 +240,8 @@ class AttachmentCipherTest { cipherFile = writeToFile(encryptResult.ciphertext) - val stream: InputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, badDigest, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice) + val integrityCheck = IntegrityCheck.forEncryptedDigest(badDigest) + val stream: InputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, integrityCheck, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice) // In incremental mode, we'll only check the digest after reading the whole thing if (incremental) { @@ -191,6 +252,29 @@ class AttachmentCipherTest { } } + private fun attachment_decryptFailOnBadPlaintextHash(incremental: Boolean) { + var cipherFile: File? = null + + try { + val key = Util.getSecretBytes(64) + val plaintextInput = Util.getSecretBytes(MEBIBYTE) + val badPlaintextHash = MessageDigest.getInstance("SHA-256").digest(plaintextInput).apply { + this[0] = (this[0] + 1).toByte() + } + + val encryptResult = encryptData(plaintextInput, key, incremental) + + cipherFile = writeToFile(encryptResult.ciphertext) + + val integrityCheck = IntegrityCheck.forPlaintextHash(badPlaintextHash) + val stream: InputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, integrityCheck, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice) + + StreamUtil.readFully(stream) + } finally { + cipherFile?.delete() + } + } + @Test fun attachment_decryptFailOnBadIncrementalDigest() { var cipherFile: File? = null @@ -205,7 +289,8 @@ class AttachmentCipherTest { cipherFile = writeToFile(encryptResult.ciphertext) - val decryptedStream: InputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, encryptResult.digest, badDigest, encryptResult.chunkSizeChoice) + val integrityCheck = IntegrityCheck.forEncryptedDigest(encryptResult.digest) + val decryptedStream: InputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, integrityCheck, badDigest, encryptResult.chunkSizeChoice) val plaintextOutput = readInputStreamFully(decryptedStream) fail(AssertionError("Expected to fail before hitting this line")) @@ -480,7 +565,8 @@ class AttachmentCipherTest { val combinedData = plaintextInput1 + plaintextInput2 val cipherFile = writeToFile(encryptedData) - val decryptedStream = AttachmentCipherInputStream.createForAttachment(cipherFile, combinedData.size.toLong(), key, digest, null, 0) + val integrityCheck = IntegrityCheck.forEncryptedDigest(digest) + val decryptedStream = AttachmentCipherInputStream.createForAttachment(cipherFile, combinedData.size.toLong(), key, integrityCheck, null, 0) val plaintextOutput = readInputStreamFully(decryptedStream) assertThat(plaintextOutput).isEqualTo(combinedData) @@ -511,7 +597,8 @@ class AttachmentCipherTest { val combinedData = plaintextInput1 + plaintextInput2 val cipherFile = writeToFile(encryptedData) - val decryptedStream = AttachmentCipherInputStream.createForAttachment(cipherFile, combinedData.size.toLong(), key, digest, null, 0) + val integrityCheck = IntegrityCheck.forEncryptedDigest(digest) + val decryptedStream = AttachmentCipherInputStream.createForAttachment(cipherFile, combinedData.size.toLong(), key, integrityCheck, null, 0) val plaintextOutput = readInputStreamFully(decryptedStream) assertThat(plaintextOutput).isEqualTo(combinedData) @@ -536,7 +623,8 @@ class AttachmentCipherTest { val digest = encryptingOutputStream.transmittedDigest val cipherFile = writeToFile(encryptedData) - val decryptedStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, digest, null, 0) + val integrityCheck = IntegrityCheck.forEncryptedDigest(digest) + val decryptedStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, integrityCheck, null, 0) val plaintextOutput = readInputStreamFully(decryptedStream) assertThat(plaintextOutput).isEqualTo(plaintextInput) @@ -567,7 +655,8 @@ class AttachmentCipherTest { val digest = encryptingOutputStream.transmittedDigest val cipherFile = writeToFile(encryptedData) - val decryptedStream = AttachmentCipherInputStream.createForAttachment(cipherFile, expectedData.size.toLong(), key, digest, null, 0) + val integrityCheck = IntegrityCheck.forEncryptedDigest(digest) + val decryptedStream = AttachmentCipherInputStream.createForAttachment(cipherFile, expectedData.size.toLong(), key, integrityCheck, null, 0) val plaintextOutput = readInputStreamFully(decryptedStream) assertThat(plaintextOutput).isEqualTo(expectedData) @@ -596,7 +685,8 @@ class AttachmentCipherTest { val digest = encryptingOutputStream.transmittedDigest val cipherFile = writeToFile(encryptedData) - val decryptedStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, digest, null, 0) + val integrityCheck = IntegrityCheck.forEncryptedDigest(digest) + val decryptedStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, integrityCheck, null, 0) val plaintextOutput = readInputStreamFully(decryptedStream) assertThat(plaintextOutput).isEqualTo(plaintextInput) @@ -677,4 +767,10 @@ class AttachmentCipherTest { return HKDF.deriveSecrets(shortKey, "Sticker Pack".toByteArray(), 64) } } + + enum class IntegrityCheckMode { + ENCRYPTED_DIGEST, + PLAINTEXT_HASH, + BOTH + } }