diff --git a/app/src/androidTest/java/org/thoughtcrime/securesms/database/AttachmentTableTest.kt b/app/src/androidTest/java/org/thoughtcrime/securesms/database/AttachmentTableTest.kt index cf87832881..9844a98a81 100644 --- a/app/src/androidTest/java/org/thoughtcrime/securesms/database/AttachmentTableTest.kt +++ b/app/src/androidTest/java/org/thoughtcrime/securesms/database/AttachmentTableTest.kt @@ -182,37 +182,6 @@ class AttachmentTableTest { assertThat(highInfo.file.exists()).isEqualTo(true) } - @Test - fun finalizeAttachmentAfterDownload_fixDigestOnNonZeroPadding() { - // Insert attachment metadata for badly-padded attachment - val plaintext = byteArrayOf(1, 2, 3, 4) - val key = Util.getSecretBytes(64) - val iv = Util.getSecretBytes(16) - - val badlyPaddedPlaintext = PaddingInputStream(plaintext.inputStream(), plaintext.size.toLong()).readFully().also { it[it.size - 1] = 0x42 } - val badlyPaddedCiphertext = encryptPrePaddedBytes(badlyPaddedPlaintext, key, iv) - val badlyPaddedDigest = getDigest(badlyPaddedCiphertext) - - val cipherFile = getTempFile() - cipherFile.writeBytes(badlyPaddedCiphertext) - - val mmsId = -1L - val attachmentId = SignalDatabase.attachments.insertAttachmentsForMessage(mmsId, listOf(createAttachmentPointer(key, badlyPaddedDigest, plaintext.size)), emptyList()).values.first() - - // Give data to attachment table - val cipherInputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintext.size.toLong(), key, badlyPaddedDigest, null, 4) - SignalDatabase.attachments.finalizeAttachmentAfterDownload(mmsId, attachmentId, cipherInputStream) - - // Verify the digest has been updated to the properly padded one - val properlyPaddedPlaintext = PaddingInputStream(plaintext.inputStream(), plaintext.size.toLong()).readFully() - val properlyPaddedCiphertext = encryptPrePaddedBytes(properlyPaddedPlaintext, key, iv) - val properlyPaddedDigest = getDigest(properlyPaddedCiphertext) - - val newDigest = SignalDatabase.attachments.getAttachment(attachmentId)!!.remoteDigest!! - - assertArrayEquals(properlyPaddedDigest, newDigest) - } - @Test fun finalizeAttachmentAfterDownload_leaveDigestAloneForAllZeroPadding() { // Insert attachment metadata for properly-padded attachment @@ -241,14 +210,14 @@ class AttachmentTableTest { @Test fun resetArchiveTransferStateByPlaintextHashAndRemoteKey_singleMatch() { - // Given an attachment with some digest + // Given an attachment with some plaintextHash+remoteKey val blob = BlobProvider.getInstance().forData(byteArrayOf(1, 2, 3, 4, 5)).createForSingleSessionInMemory() val attachment = createAttachment(1, blob, AttachmentTable.TransformProperties.empty()) val attachmentId = SignalDatabase.attachments.insertAttachmentsForMessage(-1L, listOf(attachment), emptyList()).values.first() SignalDatabase.attachments.finalizeAttachmentAfterUpload(attachmentId, AttachmentTableTestUtil.createUploadResult(attachmentId)) SignalDatabase.attachments.setArchiveTransferState(attachmentId, AttachmentTable.ArchiveTransferState.FINISHED) - // Reset the transfer state by digest + // Reset the transfer state by plaintextHash+remoteKey val plaintextHash = SignalDatabase.attachments.getAttachment(attachmentId)!!.dataHash!!.decodeBase64OrThrow() val remoteKey = SignalDatabase.attachments.getAttachment(attachmentId)!!.remoteKey!!.decodeBase64OrThrow() SignalDatabase.attachments.resetArchiveTransferStateByPlaintextHashAndRemoteKey(plaintextHash, remoteKey) diff --git a/app/src/main/java/org/thoughtcrime/securesms/backup/v2/DatabaseAttachmentArchiveUtil.kt b/app/src/main/java/org/thoughtcrime/securesms/backup/v2/DatabaseAttachmentArchiveUtil.kt index 487b0ee4d4..0a3c628dcb 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/backup/v2/DatabaseAttachmentArchiveUtil.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/backup/v2/DatabaseAttachmentArchiveUtil.kt @@ -60,6 +60,10 @@ object DatabaseAttachmentArchiveUtil { } private fun hadIntegrityCheckPerformed(attachment: DatabaseAttachment): Boolean { + if (attachment.archiveTransferState == AttachmentTable.ArchiveTransferState.FINISHED) { + return true + } + return when (attachment.transferState) { AttachmentTable.TRANSFER_PROGRESS_DONE, AttachmentTable.TRANSFER_NEEDS_RESTORE, @@ -92,8 +96,8 @@ fun DatabaseAttachment.createArchiveAttachmentPointer(useArchiveCdn: Boolean): S throw InvalidAttachmentException("empty encrypted key") } - if (remoteDigest == null) { - throw InvalidAttachmentException("no digest") + if (remoteDigest == null && dataHash == null) { + throw InvalidAttachmentException("no integrity check available") } return try { diff --git a/app/src/main/java/org/thoughtcrime/securesms/components/settings/app/internal/backup/InternalBackupStatsTab.kt b/app/src/main/java/org/thoughtcrime/securesms/components/settings/app/internal/backup/InternalBackupStatsTab.kt index 5b1ab06187..2ffc4b787e 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/components/settings/app/internal/backup/InternalBackupStatsTab.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/components/settings/app/internal/backup/InternalBackupStatsTab.kt @@ -55,10 +55,10 @@ fun InternalBackupStatsTab(stats: InternalBackupPlaygroundViewModel.StatsState, Spacer(modifier = Modifier.size(16.dp)) Text(text = "Unique/archived data files: ${stats.attachmentStats.attachmentFileCount}/${stats.attachmentStats.finishedAttachmentFileCount}") - Text(text = "Unique/archived verified digest count: ${stats.attachmentStats.attachmentDigestCount}/${stats.attachmentStats.finishedAttachmentDigestCount}") + Text(text = "Unique/archived verified digest count: ${stats.attachmentStats.attachmentPlaintextHashAndKeyCount}/${stats.attachmentStats.finishedAttachmentPlaintextHashAndKeyCount}") Text(text = "Unique/expected thumbnail files: ${stats.attachmentStats.thumbnailFileCount}/${stats.attachmentStats.estimatedThumbnailCount}") Text(text = "Local Total: ${stats.attachmentStats.attachmentFileCount + stats.attachmentStats.thumbnailFileCount}") - Text(text = "Expected remote total: ${stats.attachmentStats.estimatedThumbnailCount + stats.attachmentStats.finishedAttachmentDigestCount}") + Text(text = "Expected remote total: ${stats.attachmentStats.estimatedThumbnailCount + stats.attachmentStats.finishedAttachmentPlaintextHashAndKeyCount}") Spacer(modifier = Modifier.size(16.dp)) diff --git a/app/src/main/java/org/thoughtcrime/securesms/database/AttachmentTable.kt b/app/src/main/java/org/thoughtcrime/securesms/database/AttachmentTable.kt index 0f3d237e2c..c8f98f107a 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/database/AttachmentTable.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/database/AttachmentTable.kt @@ -34,7 +34,6 @@ import org.json.JSONArray import org.json.JSONException import org.signal.core.util.Base64 import org.signal.core.util.SqlUtil -import org.signal.core.util.StreamUtil import org.signal.core.util.ThreadUtil import org.signal.core.util.copyTo import org.signal.core.util.count @@ -60,7 +59,6 @@ import org.signal.core.util.requireNonNullString import org.signal.core.util.requireObject import org.signal.core.util.requireString import org.signal.core.util.select -import org.signal.core.util.stream.NullOutputStream import org.signal.core.util.toInt import org.signal.core.util.update import org.signal.core.util.updateAll @@ -98,7 +96,6 @@ import org.thoughtcrime.securesms.util.StorageUtil import org.thoughtcrime.securesms.util.Util import org.thoughtcrime.securesms.video.EncryptedMediaDataSource import org.whispersystems.signalservice.api.attachment.AttachmentUploadResult -import org.whispersystems.signalservice.api.crypto.AttachmentCipherOutputStream import org.whispersystems.signalservice.api.crypto.AttachmentCipherStreamUtil import org.whispersystems.signalservice.api.util.UuidUtil import org.whispersystems.signalservice.internal.crypto.PaddingInputStream @@ -397,14 +394,14 @@ class AttachmentTable( } /** - * Returns a cursor (with just the digest+archive_cdn) for all attachments that are eligible for archive upload. - * In practice, this means that the attachments have a digest and have not hit a permanent archive upload failure. + * Returns a cursor (with just the plaintextHash+remoteKey+archive_cdn) for all attachments that are eligible for archive upload. + * In practice, this means that the attachments have a plaintextHash and have not hit a permanent archive upload failure. */ fun getAttachmentsEligibleForArchiveUpload(): Cursor { return readableDatabase - .select(REMOTE_DIGEST, ARCHIVE_CDN) + .select(DATA_HASH_END, REMOTE_KEY, ARCHIVE_CDN) .from(TABLE_NAME) - .where("$REMOTE_DIGEST IS NOT NULL AND $ARCHIVE_TRANSFER_STATE != ${ArchiveTransferState.PERMANENT_FAILURE.value}") + .where("$DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL AND $ARCHIVE_TRANSFER_STATE != ${ArchiveTransferState.PERMANENT_FAILURE.value}") .run() } @@ -530,7 +527,7 @@ class AttachmentTable( fun getRestorableOptimizedAttachments(): List { return readableDatabase - .select(ID, MESSAGE_ID, DATA_SIZE, REMOTE_DIGEST, REMOTE_KEY) + .select(ID, MESSAGE_ID, DATA_SIZE, DATA_HASH_END, REMOTE_KEY) .from(TABLE_NAME) .where("$TRANSFER_STATE = ? AND $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL", TRANSFER_RESTORE_OFFLOADED) .orderBy("$ID DESC") @@ -689,7 +686,7 @@ class AttachmentTable( ARCHIVE_TRANSFER_STATE to ArchiveTransferState.NONE.value, ARCHIVE_CDN to null ) - .where("$DATA_HASH_END = ? AND $REMOTE_KEY = ?", plaintextHash, remoteKey) + .where("$DATA_HASH_END = ? AND $REMOTE_KEY = ?", Base64.encodeWithPadding(plaintextHash), Base64.encodeWithPadding(remoteKey)) .run() } @@ -768,11 +765,12 @@ class AttachmentTable( """ SELECT SUM($DATA_SIZE) FROM ( - SELECT DISTINCT $REMOTE_DIGEST, $DATA_SIZE + SELECT DISTINCT $DATA_HASH_END, $REMOTE_KEY, $DATA_SIZE FROM $TABLE_NAME WHERE $DATA_FILE NOT NULL AND - $REMOTE_DIGEST NOT NULL AND + $DATA_HASH_END NOT NULL AND + $REMOTE_KEY NOT NULL AND $ARCHIVE_TRANSFER_STATE NOT IN (${ArchiveTransferState.FINISHED.value}, ${ArchiveTransferState.PERMANENT_FAILURE.value}) ) """.trimIndent() @@ -1257,7 +1255,7 @@ class AttachmentTable( } @Throws(IOException::class) - fun finalizeAttachmentThumbnailAfterDownload(attachmentId: AttachmentId, digest: ByteArray, inputStream: InputStream, transferFile: File) { + fun finalizeAttachmentThumbnailAfterDownload(attachmentId: AttachmentId, plaintextHash: String?, remoteKey: String?, inputStream: InputStream, transferFile: File) { Log.i(TAG, "[finalizeAttachmentThumbnailAfterDownload] Finalizing downloaded data for $attachmentId.") val fileWriteResult: DataFileWriteResult = writeToDataFile(newDataFile(context), inputStream, TransformProperties.empty()) @@ -1268,10 +1266,14 @@ class AttachmentTable( THUMBNAIL_RESTORE_STATE to ThumbnailRestoreState.FINISHED.value ) - db.update(TABLE_NAME) - .values(values) - .where("$REMOTE_DIGEST = ?", digest) - .run() + if (plaintextHash != null && remoteKey != null) { + db.update(TABLE_NAME) + .values(values) + .where("$DATA_HASH_END = ? AND $REMOTE_KEY = ?", plaintextHash, remoteKey) + .run() + } else { + Log.w(TAG, "[finalizeAttachmentThumbnailAfterDownload] No plaintext hash or remote key provided for $attachmentId. Cannot update other possible thumbnails.") + } } notifyConversationListListeners() @@ -1287,7 +1289,8 @@ class AttachmentTable( */ fun finalizeAttachmentThumbnailAfterUpload( attachmentId: AttachmentId, - attachmentDigest: ByteArray, + attachmentPlaintextHash: String?, + attachmentRemoteKey: String?, data: ByteArray ) { Log.i(TAG, "[finalizeAttachmentThumbnailAfterUpload] Finalizing archive data for $attachmentId thumbnail.") @@ -1300,10 +1303,14 @@ class AttachmentTable( THUMBNAIL_RESTORE_STATE to ThumbnailRestoreState.FINISHED.value ) - db.update(TABLE_NAME) - .values(values) - .where("$ID = ? OR $REMOTE_DIGEST = ?", attachmentId, attachmentDigest) - .run() + if (attachmentPlaintextHash != null && attachmentRemoteKey != null) { + db.update(TABLE_NAME) + .values(values) + .where("$DATA_HASH_END = ? AND $REMOTE_KEY = ?", attachmentPlaintextHash, attachmentRemoteKey) + .run() + } else { + Log.w(TAG, "[finalizeAttachmentThumbnailAfterUpload] No plaintext hash or remote key provided for $attachmentId. Cannot update other possible thumbnails.") + } } } @@ -1494,7 +1501,7 @@ class AttachmentTable( } /** - * As part of the digest backfill process, this updates the (key, IV, digest) tuple for all attachments that share a data file (and are done downloading). + * As part of the digest backfill process, this updates the (key, digest) tuple for all attachments that share a data file (and are done downloading). */ fun updateRemoteKeyAndDigestByDataFile(dataFile: String, key: ByteArray, digest: ByteArray) { writableDatabase @@ -1551,74 +1558,6 @@ class AttachmentTable( return insertedAttachments } - fun debugCopyAttachmentForArchiveRestore( - mmsId: Long, - attachment: DatabaseAttachment, - forThumbnail: Boolean - ) { - val copy = - """ - INSERT INTO $TABLE_NAME - ( - $MESSAGE_ID, - $CONTENT_TYPE, - $TRANSFER_STATE, - $CDN_NUMBER, - $REMOTE_LOCATION, - $REMOTE_DIGEST, - $REMOTE_INCREMENTAL_DIGEST, - $REMOTE_INCREMENTAL_DIGEST_CHUNK_SIZE, - $REMOTE_KEY, - $FILE_NAME, - $DATA_SIZE, - $VOICE_NOTE, - $BORDERLESS, - $VIDEO_GIF, - $WIDTH, - $HEIGHT, - $CAPTION, - $UPLOAD_TIMESTAMP, - $BLUR_HASH, - $DATA_SIZE, - $DATA_RANDOM, - $DATA_HASH_START, - $DATA_HASH_END, - $ARCHIVE_CDN, - $THUMBNAIL_RESTORE_STATE - ) - SELECT - $mmsId, - $CONTENT_TYPE, - $TRANSFER_NEEDS_RESTORE, - $CDN_NUMBER, - $REMOTE_LOCATION, - $REMOTE_DIGEST, - $REMOTE_INCREMENTAL_DIGEST, - $REMOTE_INCREMENTAL_DIGEST_CHUNK_SIZE, - $REMOTE_KEY, - $FILE_NAME, - $DATA_SIZE, - $VOICE_NOTE, - $BORDERLESS, - $VIDEO_GIF, - $WIDTH, - $HEIGHT, - $CAPTION, - ${System.currentTimeMillis()}, - $BLUR_HASH, - $DATA_SIZE, - $DATA_RANDOM, - $DATA_HASH_START, - $DATA_HASH_END, - ${attachment.archiveCdn}, - ${if (forThumbnail) ThumbnailRestoreState.NEEDS_RESTORE.value else ThumbnailRestoreState.NONE.value} - FROM $TABLE_NAME - WHERE $ID = ${attachment.attachmentId.id} - """ - - writableDatabase.execSQL(copy) - } - /** * Updates the data stored for an existing attachment. This happens after transformations, like transcoding. */ @@ -1921,56 +1860,47 @@ class AttachmentTable( } /** - * Sets the archive data for the specific attachment, as well as for any attachments that use the same underlying file. + * Sets the archive data for the specific attachment, as well as for any attachments that have the same mediaName (plaintextHash + remoteKey). */ fun setArchiveCdn(attachmentId: AttachmentId, archiveCdn: Int) { writableDatabase.withinTransaction { db -> - val dataFile = db - .select(DATA_FILE) + val plaintextHashAndRemoteKey = db + .select(DATA_HASH_END, REMOTE_KEY) .from(TABLE_NAME) .where("$ID = ?", attachmentId.id) .run() - .readToSingleObject { it.requireString(DATA_FILE) } + .readToSingleObject { + it.requireNonNullString(DATA_HASH_END) to it.requireNonNullString(REMOTE_KEY) + } - if (dataFile == null) { + if (plaintextHashAndRemoteKey == null) { Log.w(TAG, "No data file found for attachment $attachmentId. Can't set archive data.") return@withinTransaction } + val (plaintextHash, remoteKey) = plaintextHashAndRemoteKey + db.update(TABLE_NAME) .values( ARCHIVE_CDN to archiveCdn, ARCHIVE_TRANSFER_STATE to ArchiveTransferState.FINISHED.value ) - .where("$DATA_FILE = ?", dataFile) + .where("$DATA_HASH_END = ? AND $REMOTE_KEY = ?", plaintextHash, remoteKey) .run() } } /** - * Updates all attachments that share the same digest with the given archive CDN. + * Updates all attachments that share the same mediaName (plaintextHash + remoteKey) with the given archive CDN. */ fun setArchiveCdnByPlaintextHashAndRemoteKey(plaintextHash: ByteArray, remoteKey: ByteArray, archiveCdn: Int) { writableDatabase .update(TABLE_NAME) .values(ARCHIVE_CDN to archiveCdn) - .where("$DATA_HASH_END= ? AND $REMOTE_KEY = ?", plaintextHash, remoteKey) + .where("$DATA_HASH_END = ? AND $REMOTE_KEY = ?", Base64.encodeWithPadding(plaintextHash), Base64.encodeWithPadding(remoteKey)) .run() } - fun clearArchiveData(attachmentIds: List) { - SqlUtil.buildCollectionQuery(ID, attachmentIds.map { it.id }) - .forEach { query -> - writableDatabase - .update(TABLE_NAME) - .values( - ARCHIVE_CDN to null - ) - .where(query.where, query.whereArgs) - .run() - } - } - fun clearAllArchiveData() { writableDatabase .updateAll(TABLE_NAME) @@ -1981,18 +1911,6 @@ class AttachmentTable( .run() } - private fun calculateDigest(fileInfo: DataFileWriteResult, key: ByteArray, iv: ByteArray = Util.getSecretBytes(16)): ByteArray { - return calculateDigest(file = fileInfo.file, random = fileInfo.random, length = fileInfo.length, key = key, iv = iv) - } - - private fun calculateDigest(file: File, random: ByteArray, length: Long, key: ByteArray, iv: ByteArray): ByteArray { - val stream = PaddingInputStream(getDataStream(file, random, 0), length) - val cipherOutputStream = AttachmentCipherOutputStream(key, iv, NullOutputStream) - - StreamUtil.copy(stream, cipherOutputStream) - return cipherOutputStream.transmittedDigest - } - /** * Deletes the data file if there's no strong references to other attachments. * If deleted, it will also clear all weak references (i.e. quotes) of the attachment. @@ -2476,16 +2394,21 @@ class AttachmentTable( fun getEstimatedArchiveMediaSize(): Long { val estimatedThumbnailCount = readableDatabase - .select("COUNT(DISTINCT $REMOTE_DIGEST)") - .from(TABLE_NAME) - .where( + .select("COUNT(*)") + .from( + """ + ( + SELECT DISTINCT $DATA_HASH_END, $REMOTE_KEY + FROM $TABLE_NAME + WHERE + $DATA_FILE NOT NULL AND + $DATA_HASH_END NOT NULL AND + $REMOTE_KEY NOT NULL AND + $TRANSFER_STATE = $TRANSFER_PROGRESS_DONE AND + $ARCHIVE_TRANSFER_STATE != ${ArchiveTransferState.PERMANENT_FAILURE.value} AND + ($CONTENT_TYPE LIKE 'image/%' OR $CONTENT_TYPE LIKE 'video/%') + ) """ - $DATA_FILE NOT NULL AND - $REMOTE_DIGEST NOT NULL AND - $TRANSFER_STATE = $TRANSFER_PROGRESS_DONE AND - $ARCHIVE_TRANSFER_STATE != ${ArchiveTransferState.PERMANENT_FAILURE.value} AND - ($CONTENT_TYPE LIKE 'image/%' OR $CONTENT_TYPE LIKE 'video/%') - """ ) .run() .readToSingleLong(0L) @@ -2495,11 +2418,12 @@ class AttachmentTable( """ SELECT $DATA_SIZE FROM ( - SELECT DISTINCT $REMOTE_DIGEST, $DATA_SIZE + SELECT DISTINCT $DATA_HASH_END, $REMOTE_KEY, $DATA_SIZE FROM $TABLE_NAME WHERE $DATA_FILE NOT NULL AND - $REMOTE_DIGEST NOT NULL AND + $DATA_HASH_END NOT NULL AND + $REMOTE_KEY NOT NULL AND $TRANSFER_STATE = $TRANSFER_PROGRESS_DONE AND $ARCHIVE_TRANSFER_STATE != ${ArchiveTransferState.PERMANENT_FAILURE.value} ) @@ -2657,23 +2581,23 @@ class AttachmentTable( ) val transferStateCounts = transferStates - .map { (state, name) -> name to readableDatabase.count().from(TABLE_NAME).where("$TRANSFER_STATE = $state AND $REMOTE_DIGEST NOT NULL").run().readToSingleLong(-1L) } + .map { (state, name) -> name to readableDatabase.count().from(TABLE_NAME).where("$TRANSFER_STATE = $state AND $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL").run().readToSingleLong(-1L) } .toMap() val validForArchiveTransferStateCounts = transferStates - .map { (state, name) -> name to readableDatabase.count().from(TABLE_NAME).where("$TRANSFER_STATE = $state AND $REMOTE_DIGEST NOT NULL AND $DATA_FILE NOT NULL").run().readToSingleLong(-1L) } + .map { (state, name) -> name to readableDatabase.count().from(TABLE_NAME).where("$TRANSFER_STATE = $state AND $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL AND $DATA_FILE NOT NULL").run().readToSingleLong(-1L) } .toMap() val archiveStateCounts = ArchiveTransferState .entries - .associate { it to readableDatabase.count().from(TABLE_NAME).where("$ARCHIVE_TRANSFER_STATE = ${it.value} AND $REMOTE_DIGEST NOT NULL").run().readToSingleLong(-1L) } + .associate { it to readableDatabase.count().from(TABLE_NAME).where("$ARCHIVE_TRANSFER_STATE = ${it.value} AND $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL").run().readToSingleLong(-1L) } - val attachmentFileCount = readableDatabase.query("SELECT COUNT(DISTINCT $DATA_FILE) FROM $TABLE_NAME WHERE $DATA_FILE NOT NULL AND $REMOTE_DIGEST NOT NULL").readToSingleLong(-1L) - val finishedAttachmentFileCount = readableDatabase.query("SELECT COUNT(DISTINCT $DATA_FILE) FROM $TABLE_NAME WHERE $DATA_FILE NOT NULL AND $REMOTE_DIGEST NOT NULL AND $ARCHIVE_TRANSFER_STATE = ${ArchiveTransferState.FINISHED.value}").readToSingleLong(-1L) - val attachmentDigestCount = readableDatabase.query("SELECT COUNT(DISTINCT $REMOTE_DIGEST) FROM $TABLE_NAME WHERE $REMOTE_DIGEST NOT NULL AND $TRANSFER_STATE in ($TRANSFER_PROGRESS_DONE, $TRANSFER_RESTORE_OFFLOADED, $TRANSFER_RESTORE_IN_PROGRESS, $TRANSFER_NEEDS_RESTORE)").readToSingleLong(-1L) - val finishedAttachmentDigestCount = readableDatabase.query("SELECT COUNT(DISTINCT $REMOTE_DIGEST) FROM $TABLE_NAME WHERE $REMOTE_DIGEST NOT NULL AND $ARCHIVE_TRANSFER_STATE = ${ArchiveTransferState.FINISHED.value}").readToSingleLong(-1L) + val attachmentFileCount = readableDatabase.query("SELECT COUNT(DISTINCT $DATA_FILE) FROM $TABLE_NAME WHERE $DATA_FILE NOT NULL AND $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL").readToSingleLong(-1L) + val finishedAttachmentFileCount = readableDatabase.query("SELECT COUNT(DISTINCT $DATA_FILE) FROM $TABLE_NAME WHERE $DATA_FILE NOT NULL AND $DATA_HASH_END NOT NULL $REMOTE_KEY NOT NULL AND $ARCHIVE_TRANSFER_STATE = ${ArchiveTransferState.FINISHED.value}").readToSingleLong(-1L) + val attachmentPlaintextHashAndKeyCount = readableDatabase.query("SELECT COUNT(*) FROM (SELECT DISTINCT $DATA_HASH_END, $REMOTE_KEY FROM $TABLE_NAME WHERE $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL AND $TRANSFER_STATE in ($TRANSFER_PROGRESS_DONE, $TRANSFER_RESTORE_OFFLOADED, $TRANSFER_RESTORE_IN_PROGRESS, $TRANSFER_NEEDS_RESTORE))").readToSingleLong(-1L) + val finishedAttachmentDigestCount = readableDatabase.query("SELECT COUNT(*) FROM (SELECT DISTINCT $DATA_HASH_END, $REMOTE_KEY) FROM $TABLE_NAME WHERE $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL AND $ARCHIVE_TRANSFER_STATE = ${ArchiveTransferState.FINISHED.value})").readToSingleLong(-1L) val thumbnailFileCount = readableDatabase.query("SELECT COUNT(DISTINCT $THUMBNAIL_FILE) FROM $TABLE_NAME WHERE $THUMBNAIL_FILE IS NOT NULL").readToSingleLong(-1L) - val estimatedThumbnailCount = readableDatabase.query("SELECT COUNT(DISTINCT $REMOTE_DIGEST) FROM $TABLE_NAME WHERE $ARCHIVE_TRANSFER_STATE = ${ArchiveTransferState.FINISHED.value} AND $REMOTE_DIGEST NOT NULL AND ($CONTENT_TYPE LIKE 'image/%' OR $CONTENT_TYPE LIKE 'video/%')").readToSingleLong(-1L) + val estimatedThumbnailCount = readableDatabase.query("SELECT COUNT(*) FROM (SELECT DISTINCT $DATA_HASH_END, $REMOTE_KEY) FROM $TABLE_NAME WHERE $ARCHIVE_TRANSFER_STATE = ${ArchiveTransferState.FINISHED.value} AND $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL AND ($CONTENT_TYPE LIKE 'image/%' OR $CONTENT_TYPE LIKE 'video/%'))").readToSingleLong(-1L) val pendingUploadBytes = getPendingArchiveUploadBytes() val uploadedAttachmentBytes = readableDatabase @@ -2681,11 +2605,12 @@ class AttachmentTable( """ SELECT $DATA_SIZE FROM ( - SELECT DISTINCT $REMOTE_DIGEST, $DATA_SIZE + SELECT DISTINCT $DATA_HASH_END, $REMOTE_KEY, $DATA_SIZE FROM $TABLE_NAME WHERE $DATA_FILE NOT NULL AND - $REMOTE_DIGEST NOT NULL AND + $DATA_HASH_END NOT NULL AND + $REMOTE_KEY NOT NULL AND $ARCHIVE_TRANSFER_STATE = ${ArchiveTransferState.FINISHED.value} ) """.trimIndent() @@ -2702,8 +2627,8 @@ class AttachmentTable( archiveStateCounts = archiveStateCounts, attachmentFileCount = attachmentFileCount, finishedAttachmentFileCount = finishedAttachmentFileCount, - attachmentDigestCount = attachmentDigestCount, - finishedAttachmentDigestCount = finishedAttachmentDigestCount, + attachmentPlaintextHashAndKeyCount = attachmentPlaintextHashAndKeyCount, + finishedAttachmentPlaintextHashAndKeyCount = finishedAttachmentDigestCount, thumbnailFileCount = thumbnailFileCount, estimatedThumbnailCount = estimatedThumbnailCount, pendingUploadBytes = pendingUploadBytes, @@ -2953,8 +2878,8 @@ class AttachmentTable( val archiveStateCounts: Map = emptyMap(), val attachmentFileCount: Long = 0L, val finishedAttachmentFileCount: Long = 0L, - val attachmentDigestCount: Long = 0L, - val finishedAttachmentDigestCount: Long, + val attachmentPlaintextHashAndKeyCount: Long = 0L, + val finishedAttachmentPlaintextHashAndKeyCount: Long, val thumbnailFileCount: Long = 0L, val pendingUploadBytes: Long = 0L, val uploadedAttachmentBytes: Long = 0L, diff --git a/app/src/main/java/org/thoughtcrime/securesms/database/BackupMediaSnapshotTable.kt b/app/src/main/java/org/thoughtcrime/securesms/database/BackupMediaSnapshotTable.kt index 398e26e057..0a67cf538b 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/database/BackupMediaSnapshotTable.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/database/BackupMediaSnapshotTable.kt @@ -244,7 +244,7 @@ class BackupMediaSnapshotTable(context: Context, database: SignalDatabase) : Dat return readableDatabase.rawQuery( """ WITH input_pairs($MEDIA_ID, $CDN) AS (VALUES $inputValues) - SELECT a.$PLAINTEXT_HASH, a.$REMOTE_KEY b.$CDN + SELECT a.$PLAINTEXT_HASH, a.$REMOTE_KEY, b.$CDN FROM $TABLE_NAME a JOIN input_pairs b ON a.$MEDIA_ID = b.$MEDIA_ID WHERE a.$CDN != b.$CDN AND a.$IS_THUMBNAIL = 0 AND $SNAPSHOT_VERSION = $MAX_VERSION diff --git a/app/src/main/java/org/thoughtcrime/securesms/jobs/ArchiveThumbnailUploadJob.kt b/app/src/main/java/org/thoughtcrime/securesms/jobs/ArchiveThumbnailUploadJob.kt index 2aadf03c9d..64205cd55a 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/jobs/ArchiveThumbnailUploadJob.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/jobs/ArchiveThumbnailUploadJob.kt @@ -148,7 +148,8 @@ class ArchiveThumbnailUploadJob private constructor( // save attachment thumbnail SignalDatabase.attachments.finalizeAttachmentThumbnailAfterUpload( attachmentId = attachmentId, - attachmentDigest = attachment.remoteDigest, + attachmentPlaintextHash = attachment.dataHash, + attachmentRemoteKey = attachment.remoteKey, data = thumbnailResult.data ) diff --git a/app/src/main/java/org/thoughtcrime/securesms/jobs/AttachmentDownloadJob.kt b/app/src/main/java/org/thoughtcrime/securesms/jobs/AttachmentDownloadJob.kt index de4a84a565..331ff6ce15 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/jobs/AttachmentDownloadJob.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/jobs/AttachmentDownloadJob.kt @@ -94,7 +94,7 @@ class AttachmentDownloadJob private constructor( AttachmentTable.TRANSFER_PROGRESS_PENDING, AttachmentTable.TRANSFER_PROGRESS_FAILED -> { - if (SignalStore.backup.backsUpMedia && databaseAttachment.remoteLocation == null) { + if (SignalStore.backup.backsUpMedia && (databaseAttachment.remoteLocation == null || databaseAttachment.remoteDigest == null)) { if (databaseAttachment.archiveTransferState == AttachmentTable.ArchiveTransferState.FINISHED) { Log.i(TAG, "Trying to restore attachment from archive cdn") RestoreAttachmentJob.restoreAttachment(databaseAttachment) diff --git a/app/src/main/java/org/thoughtcrime/securesms/jobs/BackupMessagesJob.kt b/app/src/main/java/org/thoughtcrime/securesms/jobs/BackupMessagesJob.kt index f20da6c89d..4675ecf315 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/jobs/BackupMessagesJob.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/jobs/BackupMessagesJob.kt @@ -131,8 +131,8 @@ class BackupMessagesJob private constructor( val stopwatch = Stopwatch("BackupMessagesJob") - SignalDatabase.attachments.createRemoteKeyForAttachmentsThatNeedArchiveUpload().takeIf { it > 0 }?.let { count -> Log.w(TAG, "Needed to create $count key/iv/digests.") } - stopwatch.split("key-iv-digest") + SignalDatabase.attachments.createRemoteKeyForAttachmentsThatNeedArchiveUpload().takeIf { it > 0 }?.let { count -> Log.w(TAG, "Needed to create $count remote keys.") } + stopwatch.split("keygen") if (isCanceled) { return Result.failure() diff --git a/app/src/main/java/org/thoughtcrime/securesms/jobs/RestoreAttachmentJob.kt b/app/src/main/java/org/thoughtcrime/securesms/jobs/RestoreAttachmentJob.kt index 770450d711..2286241ee1 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/jobs/RestoreAttachmentJob.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/jobs/RestoreAttachmentJob.kt @@ -10,6 +10,7 @@ import android.content.Intent import androidx.core.app.NotificationCompat import androidx.core.app.NotificationManagerCompat import org.greenrobot.eventbus.EventBus +import org.signal.core.util.Base64.decodeBase64OrThrow import org.signal.core.util.PendingIntentFlags import org.signal.core.util.logging.Log import org.signal.libsignal.protocol.InvalidMacException @@ -182,9 +183,10 @@ class RestoreAttachmentJob private constructor( if (attachment.transferState != AttachmentTable.TRANSFER_NEEDS_RESTORE && attachment.transferState != AttachmentTable.TRANSFER_RESTORE_IN_PROGRESS && - (attachment.transferState != AttachmentTable.TRANSFER_RESTORE_OFFLOADED) + attachment.transferState != AttachmentTable.TRANSFER_PROGRESS_FAILED && + attachment.transferState != AttachmentTable.TRANSFER_RESTORE_OFFLOADED ) { - Log.w(TAG, "Attachment does not need to be restored.") + Log.w(TAG, "Attachment does not need to be restored. Current state: ${attachment.transferState}") return } @@ -263,6 +265,7 @@ class RestoreAttachmentJob private constructor( messageReceiver .retrieveArchivedAttachment( SignalStore.backup.mediaRootBackupKey.deriveMediaSecrets(attachment.requireMediaName()), + attachment.dataHash!!.decodeBase64OrThrow(), cdnCredentials, archiveFile, pointer, diff --git a/app/src/main/java/org/thoughtcrime/securesms/jobs/RestoreAttachmentThumbnailJob.kt b/app/src/main/java/org/thoughtcrime/securesms/jobs/RestoreAttachmentThumbnailJob.kt index c8185ef4cd..42e377aada 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/jobs/RestoreAttachmentThumbnailJob.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/jobs/RestoreAttachmentThumbnailJob.kt @@ -113,8 +113,8 @@ class RestoreAttachmentThumbnailJob private constructor( return } - if (attachment.remoteDigest == null) { - Log.w(TAG, "$attachmentId has no digest! Cannot proceed.") + if (attachment.dataHash == null) { + Log.w(TAG, "$attachmentId has no plaintext hash! Cannot proceed.") return } @@ -142,7 +142,7 @@ class RestoreAttachmentThumbnailJob private constructor( progressListener ) - SignalDatabase.attachments.finalizeAttachmentThumbnailAfterDownload(attachmentId, attachment.remoteDigest, decryptingStream, thumbnailTransferFile) + SignalDatabase.attachments.finalizeAttachmentThumbnailAfterDownload(attachmentId, attachment.dataHash, attachment.remoteKey, decryptingStream, thumbnailTransferFile) if (!SignalDatabase.messages.isStory(messageId)) { AppDependencies.messageNotifier.updateNotification(context) diff --git a/app/src/main/java/org/thoughtcrime/securesms/mms/AttachmentStreamLocalUriFetcher.java b/app/src/main/java/org/thoughtcrime/securesms/mms/AttachmentStreamLocalUriFetcher.java deleted file mode 100644 index 39719f9e14..0000000000 --- a/app/src/main/java/org/thoughtcrime/securesms/mms/AttachmentStreamLocalUriFetcher.java +++ /dev/null @@ -1,77 +0,0 @@ -package org.thoughtcrime.securesms.mms; - -import androidx.annotation.NonNull; - -import com.bumptech.glide.Priority; -import com.bumptech.glide.load.DataSource; -import com.bumptech.glide.load.data.DataFetcher; - -import org.signal.core.util.logging.Log; -import org.signal.libsignal.protocol.InvalidMessageException; -import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream; - -import java.io.File; -import java.io.FileInputStream; -import java.io.IOException; -import java.io.InputStream; -import java.util.Optional; - -class AttachmentStreamLocalUriFetcher implements DataFetcher { - - private static final String TAG = Log.tag(AttachmentStreamLocalUriFetcher.class); - - private final File attachment; - private final byte[] key; - private final Optional digest; - private final long plaintextLength; - - private InputStream is; - - AttachmentStreamLocalUriFetcher(File attachment, long plaintextLength, byte[] key, Optional digest) { - this.attachment = attachment; - this.plaintextLength = plaintextLength; - this.digest = digest; - this.key = key; - } - - @Override - public void loadData(@NonNull Priority priority, @NonNull DataCallback callback) { - try { - if (!digest.isPresent()) throw new InvalidMessageException("No attachment digest!"); - is = AttachmentCipherInputStream.createForAttachment(attachment, - plaintextLength, - key, - digest.get(), - null, - 0); - callback.onDataReady(is); - } catch (IOException | InvalidMessageException e) { - callback.onLoadFailed(e); - } - } - - @Override - public void cleanup() { - try { - if (is != null) is.close(); - is = null; - } catch (IOException ioe) { - Log.w(TAG, "ioe"); - } - } - - @Override - public void cancel() {} - - @Override - public @NonNull Class getDataClass() { - return InputStream.class; - } - - @Override - public @NonNull DataSource getDataSource() { - return DataSource.LOCAL; - } - - -} diff --git a/app/src/main/java/org/thoughtcrime/securesms/mms/AttachmentStreamUriLoader.java b/app/src/main/java/org/thoughtcrime/securesms/mms/AttachmentStreamUriLoader.java deleted file mode 100644 index cdb902f509..0000000000 --- a/app/src/main/java/org/thoughtcrime/securesms/mms/AttachmentStreamUriLoader.java +++ /dev/null @@ -1,89 +0,0 @@ -package org.thoughtcrime.securesms.mms; - -import androidx.annotation.NonNull; -import androidx.annotation.Nullable; - -import com.bumptech.glide.load.Key; -import com.bumptech.glide.load.Options; -import com.bumptech.glide.load.model.ModelLoader; -import com.bumptech.glide.load.model.ModelLoaderFactory; -import com.bumptech.glide.load.model.MultiModelLoaderFactory; - -import org.thoughtcrime.securesms.mms.AttachmentStreamUriLoader.AttachmentModel; - -import java.io.File; -import java.io.InputStream; -import java.security.MessageDigest; -import java.util.Optional; - -public class AttachmentStreamUriLoader implements ModelLoader { - - @Override - public @Nullable LoadData buildLoadData(@NonNull AttachmentModel attachmentModel, int width, int height, @NonNull Options options) { - return new LoadData<>(attachmentModel, new AttachmentStreamLocalUriFetcher(attachmentModel.attachment, attachmentModel.plaintextLength, attachmentModel.key, attachmentModel.digest)); - } - - @Override - public boolean handles(@NonNull AttachmentModel attachmentModel) { - return true; - } - - static class Factory implements ModelLoaderFactory { - - @Override - public @NonNull ModelLoader build(@NonNull MultiModelLoaderFactory multiFactory) { - return new AttachmentStreamUriLoader(); - } - - @Override - public void teardown() { - // Do nothing. - } - } - - public static class AttachmentModel implements Key { - public @NonNull File attachment; - public @NonNull byte[] key; - public @NonNull Optional digest; - public @NonNull Optional incrementalDigest; - public int incrementalMacChunkSize; - public long plaintextLength; - - public AttachmentModel(@NonNull File attachment, - @NonNull byte[] key, - long plaintextLength, - @NonNull Optional digest, - @NonNull Optional incrementalDigest, - int incrementalMacChunkSize) - { - this.attachment = attachment; - this.key = key; - this.digest = digest; - this.incrementalDigest = incrementalDigest; - this.incrementalMacChunkSize = incrementalMacChunkSize; - this.plaintextLength = plaintextLength; - } - - @Override - public void updateDiskCacheKey(@NonNull MessageDigest messageDigest) { - messageDigest.update(attachment.toString().getBytes()); - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - - AttachmentModel that = (AttachmentModel)o; - - return attachment.equals(that.attachment); - - } - - @Override - public int hashCode() { - return attachment.hashCode(); - } - } -} - diff --git a/app/src/main/java/org/thoughtcrime/securesms/mms/SignalGlideComponents.java b/app/src/main/java/org/thoughtcrime/securesms/mms/SignalGlideComponents.java index 715acba78a..c8383df567 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/mms/SignalGlideComponents.java +++ b/app/src/main/java/org/thoughtcrime/securesms/mms/SignalGlideComponents.java @@ -41,7 +41,6 @@ import org.thoughtcrime.securesms.glide.cache.EncryptedCacheDecoder; import org.thoughtcrime.securesms.glide.cache.EncryptedCacheEncoder; import org.thoughtcrime.securesms.glide.cache.EncryptedGifDrawableResourceEncoder; import org.thoughtcrime.securesms.glide.cache.WebpSanDecoder; -import org.thoughtcrime.securesms.mms.AttachmentStreamUriLoader.AttachmentModel; import org.thoughtcrime.securesms.mms.DecryptableStreamUriLoader.DecryptableUri; import org.thoughtcrime.securesms.stickers.StickerRemoteUri; import org.thoughtcrime.securesms.stickers.StickerRemoteUriLoader; @@ -96,7 +95,6 @@ public class SignalGlideComponents implements RegisterGlideComponents { registry.append(ConversationShortcutPhoto.class, Bitmap.class, new ConversationShortcutPhoto.Loader.Factory(context)); registry.append(ContactPhoto.class, InputStream.class, new ContactPhotoLoader.Factory(context)); registry.append(DecryptableUri.class, InputStream.class, new DecryptableStreamUriLoader.Factory(context)); - registry.append(AttachmentModel.class, InputStream.class, new AttachmentStreamUriLoader.Factory()); registry.append(ChunkedImageUrl.class, InputStream.class, new ChunkedImageUrlLoader.Factory()); registry.append(StickerRemoteUri.class, InputStream.class, new StickerRemoteUriLoader.Factory()); registry.append(BlurHash.class, BlurHash.class, new BlurHashModelLoader.Factory()); diff --git a/app/src/main/java/org/thoughtcrime/securesms/video/exo/PartDataSource.java b/app/src/main/java/org/thoughtcrime/securesms/video/exo/PartDataSource.java index 4011e50dc7..2c4d83b191 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/video/exo/PartDataSource.java +++ b/app/src/main/java/org/thoughtcrime/securesms/video/exo/PartDataSource.java @@ -86,7 +86,11 @@ class PartDataSource implements DataSource { throw new InvalidMessageException("Missing digest!"); } - this.inputStream = AttachmentCipherInputStream.createForArchivedMedia(mediaKeyMaterial, archiveFile, originalCipherLength, attachment.size, decodedKey, attachment.remoteDigest, attachment.getIncrementalDigest(), attachment.incrementalMacChunkSize); + if (attachment.dataHash == null || attachment.dataHash.isEmpty()) { + throw new InvalidMessageException("Missing plaintextHash!"); + } + + this.inputStream = AttachmentCipherInputStream.createForArchivedMedia(mediaKeyMaterial, archiveFile, originalCipherLength, attachment.size, decodedKey, Base64.decodeOrThrow(attachment.dataHash), attachment.getIncrementalDigest(), attachment.incrementalMacChunkSize); } catch (InvalidMessageException e) { throw new IOException("Error decrypting attachment stream!", e); } diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageReceiver.java b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageReceiver.java index 7fa3a2d66f..4bdcdd59c2 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageReceiver.java +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageReceiver.java @@ -122,6 +122,7 @@ public class SignalServiceMessageReceiver { * Retrieves an archived media attachment. * * @param archivedMediaKeyMaterial Decryption key material for decrypting outer layer of archived media. + * @param plaintextHash The plaintext hash of the attachment, used to verify the integrity of the downloaded content. * @param readCredentialHeaders Headers to pass to the backup CDN to authorize the download * @param archiveDestination The download destination for archived attachment. If this file exists, download will resume. * @param pointer The {@link SignalServiceAttachmentPointer} received in a {@link SignalServiceDataMessage}. @@ -131,6 +132,7 @@ public class SignalServiceMessageReceiver { * @return An InputStream that streams the plaintext attachment contents. */ public InputStream retrieveArchivedAttachment(@Nonnull MediaRootBackupKey.MediaKeyMaterial archivedMediaKeyMaterial, + @Nonnull byte[] plaintextHash, @Nonnull Map readCredentialHeaders, @Nonnull File archiveDestination, @Nonnull SignalServiceAttachmentPointer pointer, @@ -139,10 +141,6 @@ public class SignalServiceMessageReceiver { @Nullable ProgressListener listener) throws IOException, InvalidMessageException, MissingConfigurationException { - if (pointer.getDigest().isEmpty()) { - throw new InvalidMessageException("No attachment digest!"); - } - if (pointer.getKey() == null) { throw new InvalidMessageException("No key!"); } @@ -160,7 +158,7 @@ public class SignalServiceMessageReceiver { originalCipherLength, pointer.getSize().orElse(0), pointer.getKey(), - pointer.getDigest().get(), + plaintextHash, null, 0 ); diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/attachment/AttachmentDownloadResult.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/attachment/AttachmentDownloadResult.kt deleted file mode 100644 index 850ea7c7f1..0000000000 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/attachment/AttachmentDownloadResult.kt +++ /dev/null @@ -1,16 +0,0 @@ -/* - * Copyright 2024 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.signalservice.api.attachment - -import org.signal.core.util.stream.LimitedInputStream - -/** - * Holds the result of an attachment download. - */ -class AttachmentDownloadResult( - val dataStream: LimitedInputStream, - val iv: ByteArray -) diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherInputStream.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherInputStream.kt index e97ea7c1ec..70ad95ea71 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherInputStream.kt +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherInputStream.kt @@ -54,13 +54,14 @@ object AttachmentCipherInputStream { digest: ByteArray, incrementalDigest: ByteArray?, incrementalMacChunkSize: Int - ): LimitedInputStream { + ): InputStream { return create( streamSupplier = { FileInputStream(file) }, streamLength = file.length(), plaintextLength = plaintextLength, combinedKeyMaterial = combinedKeyMaterial, - digest = digest, + encryptedDigest = digest, + plaintextHash = null, incrementalDigest = incrementalDigest, incrementalMacChunkSize = incrementalMacChunkSize, ignoreDigest = false @@ -83,13 +84,14 @@ object AttachmentCipherInputStream { digest: ByteArray, incrementalDigest: ByteArray?, incrementalMacChunkSize: Int - ): LimitedInputStream { + ): InputStream { return create( streamSupplier = streamSupplier, streamLength = streamLength, plaintextLength = plaintextLength, combinedKeyMaterial = combinedKeyMaterial, - digest = digest, + encryptedDigest = digest, + plaintextHash = null, incrementalDigest = incrementalDigest, incrementalMacChunkSize = incrementalMacChunkSize, ignoreDigest = false @@ -112,10 +114,10 @@ object AttachmentCipherInputStream { originalCipherTextLength: Long, plaintextLength: Long, combinedKeyMaterial: ByteArray, - digest: ByteArray, + plaintextHash: ByteArray, incrementalDigest: ByteArray?, incrementalMacChunkSize: Int - ): LimitedInputStream { + ): InputStream { val keyMaterial = CombinedKeyMaterial.from(combinedKeyMaterial) val mac = initMac(keyMaterial.macKey) @@ -128,10 +130,11 @@ object AttachmentCipherInputStream { streamLength = originalCipherTextLength, plaintextLength = plaintextLength, combinedKeyMaterial = combinedKeyMaterial, - digest = digest, + encryptedDigest = null, + plaintextHash = plaintextHash, incrementalDigest = incrementalDigest, incrementalMacChunkSize = incrementalMacChunkSize, - ignoreDigest = false + ignoreDigest = true ) } @@ -151,7 +154,7 @@ object AttachmentCipherInputStream { originalCipherTextLength: Long, plaintextLength: Long, combinedKeyMaterial: ByteArray - ): LimitedInputStream { + ): InputStream { val keyMaterial = CombinedKeyMaterial.from(combinedKeyMaterial) val mac = initMac(keyMaterial.macKey) @@ -164,7 +167,8 @@ object AttachmentCipherInputStream { streamLength = originalCipherTextLength, plaintextLength = plaintextLength, combinedKeyMaterial = combinedKeyMaterial, - digest = null, + encryptedDigest = null, + plaintextHash = null, incrementalDigest = null, incrementalMacChunkSize = 0, ignoreDigest = true @@ -229,11 +233,12 @@ object AttachmentCipherInputStream { streamLength: Long, plaintextLength: Long, combinedKeyMaterial: ByteArray, - digest: ByteArray?, + encryptedDigest: ByteArray?, + plaintextHash: ByteArray?, incrementalDigest: ByteArray?, incrementalMacChunkSize: Int, ignoreDigest: Boolean - ): LimitedInputStream { + ): InputStream { val keyMaterial = CombinedKeyMaterial.from(combinedKeyMaterial) val mac = initMac(keyMaterial.macKey) @@ -241,7 +246,7 @@ object AttachmentCipherInputStream { throw InvalidMessageException("Message shorter than crypto overhead! length: $streamLength") } - if (!ignoreDigest && digest == null) { + if (!ignoreDigest && encryptedDigest == null) { throw InvalidMessageException("Missing digest!") } @@ -250,20 +255,25 @@ object AttachmentCipherInputStream { if (!hasIncrementalMac) { streamSupplier.openStream().use { macVerificationStream -> - verifyMac(macVerificationStream, streamLength, mac, digest) + verifyMac(macVerificationStream, streamLength, mac, encryptedDigest) } wrappedStream = streamSupplier.openStream() } else { - if (digest == null) { - throw InvalidMessageException("Missing digest for incremental mac validation!") + if (encryptedDigest == null && plaintextHash == null) { + throw InvalidMessageException("Missing data (digest or plaintextHas) for incremental mac validation!") + } + + val digestValidatingStream = if (encryptedDigest != null) { + DigestValidatingInputStream(streamSupplier.openStream(), sha256Digest(), encryptedDigest) + } else { + streamSupplier.openStream() } wrappedStream = IncrementalMacInputStream( IncrementalMacAdditionalValidationsInputStream( - wrapped = streamSupplier.openStream(), + wrapped = digestValidatingStream, fileLength = streamLength, - mac = mac, - theirDigest = digest + mac = mac ), keyMaterial.macKey, ChunkSizeChoice.everyNthByte(incrementalMacChunkSize), @@ -274,8 +284,17 @@ object AttachmentCipherInputStream { val encryptedStreamExcludingMac = LimitedInputStream(wrappedStream, streamLength - mac.macLength) val cipher = createCipher(encryptedStreamExcludingMac, keyMaterial.aesKey) val decryptingStream: InputStream = BetterCipherInputStream(encryptedStreamExcludingMac, cipher) + val paddinglessDecryptingStream = LimitedInputStream(decryptingStream, plaintextLength) - return LimitedInputStream(decryptingStream, plaintextLength) + return if (plaintextHash != null) { + if (plaintextHash.size != MessageDigest.getInstance("SHA-256").digestLength) { + throw InvalidMessageException("Invalid plaintext hash size: ${plaintextHash.size}") + } + + DigestValidatingInputStream(paddinglessDecryptingStream, sha256Digest(), plaintextHash) + } else { + paddinglessDecryptingStream + } } private fun createCipher(inputStream: InputStream, aesKey: ByteArray): Cipher { @@ -286,6 +305,14 @@ object AttachmentCipherInputStream { } } + private fun sha256Digest(): MessageDigest { + try { + return MessageDigest.getInstance("SHA-256") + } catch (e: NoSuchAlgorithmException) { + throw AssertionError(e) + } + } + private fun initMac(key: ByteArray): Mac { try { val mac = Mac.getInstance("HmacSHA256") diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/DigestValidatingInputStream.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/DigestValidatingInputStream.kt new file mode 100644 index 0000000000..9c85deaa0d --- /dev/null +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/DigestValidatingInputStream.kt @@ -0,0 +1,76 @@ +/* + * Copyright (C) 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.signalservice.api.crypto + +import org.signal.libsignal.protocol.InvalidMessageException +import java.io.FilterInputStream +import java.io.IOException +import java.io.InputStream +import java.security.MessageDigest + +/** + * An InputStream that enforces hash validation by calculating a digest as data is read + * and verifying it against an expected hash when the stream is fully consumed. + * + * Important: The validation only occurs if you read the entire stream. + * + * @param inputStream The underlying InputStream to read from + * @param digest The MessageDigest instance to use for hash calculation + * @param expectedHash The expected hash value to validate against + */ +class DigestValidatingInputStream( + inputStream: InputStream, + private val digest: MessageDigest, + private val expectedHash: ByteArray +) : FilterInputStream(inputStream) { + + var validationAttempted = false + private set + + @Throws(IOException::class) + override fun read(): Int { + val byte = super.read() + if (byte != -1) { + digest.update(byte.toByte()) + } else { + validateDigest() + } + return byte + } + + @Throws(IOException::class) + override fun read(buffer: ByteArray): Int { + return read(buffer, 0, buffer.size) + } + + @Throws(IOException::class) + override fun read(buffer: ByteArray, offset: Int, length: Int): Int { + val bytesRead = super.read(buffer, offset, length) + if (bytesRead > 0) { + digest.update(buffer, offset, bytesRead) + } else if (bytesRead == -1) { + validateDigest() + } + return bytesRead + } + + /** + * Validates the calculated digest against the expected hash. + * Throws InvalidCiphertextException if they don't match. + */ + @Throws(InvalidMessageException::class) + private fun validateDigest() { + if (validationAttempted) { + return + } + validationAttempted = true + + val calculatedHash = digest.digest() + if (!MessageDigest.isEqual(calculatedHash, expectedHash)) { + throw InvalidMessageException("Calculated digest does not match expected hash!") + } + } +} diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/IncrementalMacAdditionalValidationsInputStream.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/IncrementalMacAdditionalValidationsInputStream.kt index f24dd6cfe6..709e45071e 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/IncrementalMacAdditionalValidationsInputStream.kt +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/IncrementalMacAdditionalValidationsInputStream.kt @@ -22,11 +22,9 @@ import kotlin.math.max class IncrementalMacAdditionalValidationsInputStream( wrapped: InputStream, fileLength: Long, - private val mac: Mac, - private val theirDigest: ByteArray + private val mac: Mac ) : FilterInputStream(wrapped) { - private val digest: MessageDigest = MessageDigest.getInstance("SHA256") private val macLength: Int = mac.macLength private val macBuffer: ByteArray = ByteArray(macLength) @@ -77,8 +75,6 @@ class IncrementalMacAdditionalValidationsInputStream( mac.update(buffer, offset, bytesRead) } - digest.update(buffer, offset, bytesRead) - if (bytesRemaining == 0) { validate() } @@ -113,10 +109,5 @@ class IncrementalMacAdditionalValidationsInputStream( if (!MessageDigest.isEqual(ourMac, theirMac)) { throw InvalidMessageException("MAC doesn't match!") } - - val ourDigest = digest.digest() - if (!MessageDigest.isEqual(ourDigest, theirDigest)) { - throw InvalidMessageException("Digest doesn't match!") - } } } diff --git a/libsignal-service/src/test/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherTest.kt b/libsignal-service/src/test/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherTest.kt index cc47e7bed6..2a890c95ba 100644 --- a/libsignal-service/src/test/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherTest.kt +++ b/libsignal-service/src/test/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherTest.kt @@ -8,9 +8,7 @@ import org.conscrypt.Conscrypt import org.junit.Assert import org.junit.Test import org.signal.core.util.StreamUtil -import org.signal.core.util.allMatch import org.signal.core.util.readFully -import org.signal.core.util.stream.LimitedInputStream import org.signal.libsignal.protocol.InvalidMessageException import org.signal.libsignal.protocol.incrementalmac.ChunkSizeChoice import org.signal.libsignal.protocol.incrementalmac.InvalidMacException @@ -27,6 +25,8 @@ import java.io.FileOutputStream import java.io.InputStream import java.io.OutputStream import java.lang.AssertionError +import java.security.DigestInputStream +import java.security.MessageDigest import java.security.Security import java.util.Random @@ -63,11 +63,10 @@ class AttachmentCipherTest { val encryptResult = encryptData(plaintextInput, key, incremental) val cipherFile = writeToFile(encryptResult.ciphertext) - val inputStream: LimitedInputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, encryptResult.digest, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice) + val inputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, encryptResult.digest, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice) val plaintextOutput = inputStream.readFully(autoClose = false) assertThat(plaintextOutput).isEqualTo(plaintextInput) - assertThat(inputStream.leftoverStream().allMatch { it == 0.toByte() }).isTrue() cipherFile.delete() } @@ -89,11 +88,10 @@ class AttachmentCipherTest { val encryptResult = encryptData(plaintextInput, key, incremental) val cipherFile = writeToFile(encryptResult.ciphertext) - val inputStream: LimitedInputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, encryptResult.digest, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice) + val inputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, encryptResult.digest, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice) val plaintextOutput = inputStream.readFully(autoClose = false) Assert.assertArrayEquals(plaintextInput, plaintextOutput) - assertThat(inputStream.leftoverStream().allMatch { it == 0.toByte() }).isTrue() cipherFile.delete() } @@ -234,20 +232,19 @@ class AttachmentCipherTest { val cipherFile = writeToFile(outerEncryptResult.ciphertext) val keyMaterial = createMediaKeyMaterial(outerKey) - val decryptedStream: LimitedInputStream = AttachmentCipherInputStream.createForArchivedMedia( + val decryptedStream = AttachmentCipherInputStream.createForArchivedMedia( archivedMediaKeyMaterial = keyMaterial, file = cipherFile, originalCipherTextLength = innerEncryptResult.ciphertext.size.toLong(), plaintextLength = plaintextInput.size.toLong(), combinedKeyMaterial = innerKey, - digest = innerEncryptResult.digest, + plaintextHash = innerEncryptResult.plaintextHash, incrementalDigest = innerEncryptResult.incrementalDigest, incrementalMacChunkSize = innerEncryptResult.chunkSizeChoice ) val plaintextOutput = decryptedStream.readFully(autoClose = false) assertThat(plaintextOutput).isEqualTo(plaintextInput) - assertThat(decryptedStream.leftoverStream().allMatch { it == 0.toByte() }).isTrue() cipherFile.delete() } @@ -285,7 +282,7 @@ class AttachmentCipherTest { originalCipherTextLength = innerEncryptResult.ciphertext.size.toLong(), plaintextLength = plaintextInput.size.toLong(), combinedKeyMaterial = innerKey, - digest = innerEncryptResult.digest, + plaintextHash = innerEncryptResult.digest, incrementalDigest = innerEncryptResult.incrementalDigest, incrementalMacChunkSize = innerEncryptResult.chunkSizeChoice ) @@ -299,17 +296,17 @@ class AttachmentCipherTest { } @Test - fun archiveInnerAndOuter_encryptDecrypt_nonIncremental() { + fun archive_encryptDecrypt_nonIncremental() { archiveInnerAndOuter_encryptDecrypt(incremental = false, fileSize = MEBIBYTE) } @Test - fun archiveInnerAndOuter_encryptDecrypt_incremental() { + fun archive_encryptDecrypt_incremental() { archiveInnerAndOuter_encryptDecrypt(incremental = true, fileSize = MEBIBYTE) } @Test - fun archiveInnerAndOuter_encryptDecrypt_nonIncremental_manyFileSizes() { + fun archive_encryptDecrypt_nonIncremental_manyFileSizes() { for (i in 0..99) { archiveInnerAndOuter_encryptDecrypt(incremental = false, fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024)) } @@ -334,20 +331,19 @@ class AttachmentCipherTest { val cipherFile = writeToFile(outerEncryptResult.ciphertext) val keyMaterial = createMediaKeyMaterial(outerKey) - val decryptedStream: LimitedInputStream = AttachmentCipherInputStream.createForArchivedMedia( + val decryptedStream = AttachmentCipherInputStream.createForArchivedMedia( archivedMediaKeyMaterial = keyMaterial, file = cipherFile, originalCipherTextLength = innerEncryptResult.ciphertext.size.toLong(), plaintextLength = plaintextInput.size.toLong(), combinedKeyMaterial = innerKey, - digest = innerEncryptResult.digest, + plaintextHash = innerEncryptResult.plaintextHash, incrementalDigest = innerEncryptResult.incrementalDigest, incrementalMacChunkSize = innerEncryptResult.chunkSizeChoice ) val plaintextOutput = decryptedStream.readFully(autoClose = false) assertThat(plaintextOutput).isEqualTo(plaintextInput) - assertThat(decryptedStream.leftoverStream().allMatch { it == 0.toByte() }).isTrue() cipherFile.delete() } @@ -380,7 +376,7 @@ class AttachmentCipherTest { originalCipherTextLength = innerEncryptResult.ciphertext.size.toLong(), plaintextLength = plaintextInput.size.toLong(), combinedKeyMaterial = innerKey, - digest = innerEncryptResult.digest, + plaintextHash = innerEncryptResult.digest, incrementalDigest = innerEncryptResult.incrementalDigest, incrementalMacChunkSize = innerEncryptResult.chunkSizeChoice ) @@ -607,7 +603,14 @@ class AttachmentCipherTest { cipherFile.delete() } - private class EncryptResult(val ciphertext: ByteArray, val digest: ByteArray, val incrementalDigest: ByteArray, val chunkSizeChoice: Int) + private class EncryptResult( + val ciphertext: ByteArray, + val digest: ByteArray, + val incrementalDigest: ByteArray, + val chunkSizeChoice: Int, + val plaintextHash: ByteArray + ) + companion object { init { // https://github.com/google/conscrypt/issues/1034 @@ -619,12 +622,16 @@ class AttachmentCipherTest { private const val MEBIBYTE = 1024 * 1024 private fun encryptData(data: ByteArray, keyMaterial: ByteArray, withIncremental: Boolean, padded: Boolean = true): EncryptResult { + val digestingStream = DigestInputStream(ByteArrayInputStream(data), MessageDigest.getInstance("SHA-256")) + val actualData = if (padded) { - PaddingInputStream(ByteArrayInputStream(data), data.size.toLong()).readFully() + PaddingInputStream(digestingStream, data.size.toLong()).readFully() } else { - data + digestingStream.readFully() } + val plaintextHash = digestingStream.messageDigest.digest() + val outputStream = ByteArrayOutputStream() val incrementalDigestOut = ByteArrayOutputStream() val iv = Util.getSecretBytes(16) @@ -643,7 +650,13 @@ class AttachmentCipherTest { encryptStream.close() incrementalDigestOut.close() - return EncryptResult(outputStream.toByteArray(), encryptStream.transmittedDigest, incrementalDigestOut.toByteArray(), sizeChoice.sizeInBytes) + return EncryptResult( + ciphertext = outputStream.toByteArray(), + digest = encryptStream.transmittedDigest, + incrementalDigest = incrementalDigestOut.toByteArray(), + chunkSizeChoice = sizeChoice.sizeInBytes, + plaintextHash = plaintextHash + ) } private fun writeToFile(data: ByteArray): File { diff --git a/libsignal-service/src/test/java/org/whispersystems/signalservice/api/crypto/DigestValidatingInputStreamTest.kt b/libsignal-service/src/test/java/org/whispersystems/signalservice/api/crypto/DigestValidatingInputStreamTest.kt new file mode 100644 index 0000000000..754ea23d37 --- /dev/null +++ b/libsignal-service/src/test/java/org/whispersystems/signalservice/api/crypto/DigestValidatingInputStreamTest.kt @@ -0,0 +1,207 @@ +package org.whispersystems.signalservice.api.crypto + +import assertk.assertThat +import assertk.assertions.isEqualTo +import assertk.assertions.isTrue +import assertk.fail +import org.junit.Test +import org.signal.core.util.readFully +import org.signal.libsignal.protocol.InvalidMessageException +import java.io.ByteArrayInputStream +import java.security.MessageDigest + +class DigestValidatingInputStreamTest { + + @Test + fun `success - read byte by byte`() { + val data = "Hello, World!".toByteArray() + val digest = MessageDigest.getInstance("SHA-256") + val expectedHash = digest.digest(data) + + val inputStream = ByteArrayInputStream(data) + val digestEnforcingStream = DigestValidatingInputStream( + inputStream = inputStream, + digest = MessageDigest.getInstance("SHA-256"), + expectedHash = expectedHash + ) + + val result = ByteArray(data.size) + var i = 0 + var byteRead: Byte = 0 + while (digestEnforcingStream.read().also { byteRead = it.toByte() } != -1) { + result[i] = byteRead + i++ + } + + assertThat(result).isEqualTo(data) + assertThat(digestEnforcingStream.validationAttempted).isTrue() + + digestEnforcingStream.close() + } + + @Test + fun `success - read byte array`() { + val data = "Hello, World! This is a longer message to test buffer reading.".toByteArray() + val digest = MessageDigest.getInstance("SHA-256") + val expectedHash = digest.digest(data) + + val inputStream = ByteArrayInputStream(data) + val digestEnforcingStream = DigestValidatingInputStream( + inputStream = inputStream, + digest = MessageDigest.getInstance("SHA-256"), + expectedHash = expectedHash + ) + + val result = digestEnforcingStream.readFully() + + assertThat(result.size).isEqualTo(data.size) + assertThat(result).isEqualTo(data) + assertThat(digestEnforcingStream.validationAttempted).isTrue() + + digestEnforcingStream.close() + } + + @Test + fun `success - read byte array with offset and length`() { + val data = "This is test data for offset and length reading.".toByteArray() + val digest = MessageDigest.getInstance("SHA-256") + val expectedHash = digest.digest(data) + + val inputStream = ByteArrayInputStream(data) + val digestEnforcingStream = DigestValidatingInputStream( + inputStream = inputStream, + digest = MessageDigest.getInstance("SHA-256"), + expectedHash = expectedHash + ) + + val buffer = ByteArray(1024) + var totalBytesRead = 0 + var bytesRead: Int + + while (digestEnforcingStream.read(buffer, totalBytesRead, 10).also { bytesRead = it } > 0) { + totalBytesRead += bytesRead + } + + val result = buffer.copyOf(totalBytesRead) + assertThat(result).isEqualTo(data) + assertThat(digestEnforcingStream.validationAttempted).isTrue() + + digestEnforcingStream.close() + } + + @Test + fun `success - empty data`() { + val data = ByteArray(0) + val digest = MessageDigest.getInstance("SHA-256") + val expectedHash = digest.digest(data) + + val inputStream = ByteArrayInputStream(data) + val digestEnforcingStream = DigestValidatingInputStream( + inputStream = inputStream, + digest = MessageDigest.getInstance("SHA-256"), + expectedHash = expectedHash + ) + + // Should immediately return -1 and validate + val endByte = digestEnforcingStream.read() + assertThat(endByte).isEqualTo(-1) + assertThat(digestEnforcingStream.validationAttempted).isTrue() + + digestEnforcingStream.close() + } + + @Test + fun `success - alternative digest, md5`() { + val data = "Testing MD5 hash validation".toByteArray() + val digest = MessageDigest.getInstance("MD5") + val expectedHash = digest.digest(data) + + val inputStream = ByteArrayInputStream(data) + val digestEnforcingStream = DigestValidatingInputStream( + inputStream = inputStream, + digest = MessageDigest.getInstance("MD5"), + expectedHash = expectedHash + ) + + val result = digestEnforcingStream.readFully() + + assertThat(result).isEqualTo(data) + assertThat(digestEnforcingStream.validationAttempted).isTrue() + + digestEnforcingStream.close() + } + + @Test + fun `success - multiple reads after close`() { + val data = "Test multiple validation calls".toByteArray() + val digest = MessageDigest.getInstance("SHA-256") + val expectedHash = digest.digest(data) + + val inputStream = ByteArrayInputStream(data) + val digestEnforcingStream = DigestValidatingInputStream( + inputStream = inputStream, + digest = MessageDigest.getInstance("SHA-256"), + expectedHash = expectedHash + ) + + val result = digestEnforcingStream.readFully() + + // Multiple calls to read() after EOF should not cause issues + assertThat(digestEnforcingStream.read()).isEqualTo(-1) + assertThat(digestEnforcingStream.read()).isEqualTo(-1) + assertThat(digestEnforcingStream.read()).isEqualTo(-1) + + assertThat(result).isEqualTo(data) + assertThat(digestEnforcingStream.validationAttempted).isTrue() + + digestEnforcingStream.close() + } + + @Test + fun `failure - read byte by byte`() { + val data = "Hello, World!".toByteArray() + val wrongHash = ByteArray(32) // All zeros - wrong hash + + val inputStream = ByteArrayInputStream(data) + val digestEnforcingStream = DigestValidatingInputStream( + inputStream = inputStream, + digest = MessageDigest.getInstance("SHA-256"), + expectedHash = wrongHash + ) + + try { + while (digestEnforcingStream.read() != -1) { + // Reading byte by byte + } + + fail("Expected InvalidCiphertextException to be thrown") + } catch (e: InvalidMessageException) { + // Expected exception + } finally { + digestEnforcingStream.close() + } + } + + @Test + fun `failure - read byte array`() { + val data = "Hello, World! This is a test message.".toByteArray() + val wrongHash = ByteArray(32) // All zeros - wrong hash + + val inputStream = ByteArrayInputStream(data) + val digestEnforcingStream = DigestValidatingInputStream( + inputStream = inputStream, + digest = MessageDigest.getInstance("SHA-256"), + expectedHash = wrongHash + ) + + try { + digestEnforcingStream.readFully() + + fail("Expected InvalidCiphertextException to be thrown") + } catch (e: InvalidMessageException) { + // Expected exception + } finally { + digestEnforcingStream.close() + } + } +}