Validate plaintext hashes for archived attachments.

This commit is contained in:
Greyson Parrelli
2025-06-20 15:26:23 -04:00
committed by Cody Henthorne
parent 38c8f852bf
commit 607b83d65b
21 changed files with 470 additions and 436 deletions

View File

@@ -182,37 +182,6 @@ class AttachmentTableTest {
assertThat(highInfo.file.exists()).isEqualTo(true)
}
@Test
fun finalizeAttachmentAfterDownload_fixDigestOnNonZeroPadding() {
// Insert attachment metadata for badly-padded attachment
val plaintext = byteArrayOf(1, 2, 3, 4)
val key = Util.getSecretBytes(64)
val iv = Util.getSecretBytes(16)
val badlyPaddedPlaintext = PaddingInputStream(plaintext.inputStream(), plaintext.size.toLong()).readFully().also { it[it.size - 1] = 0x42 }
val badlyPaddedCiphertext = encryptPrePaddedBytes(badlyPaddedPlaintext, key, iv)
val badlyPaddedDigest = getDigest(badlyPaddedCiphertext)
val cipherFile = getTempFile()
cipherFile.writeBytes(badlyPaddedCiphertext)
val mmsId = -1L
val attachmentId = SignalDatabase.attachments.insertAttachmentsForMessage(mmsId, listOf(createAttachmentPointer(key, badlyPaddedDigest, plaintext.size)), emptyList()).values.first()
// Give data to attachment table
val cipherInputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintext.size.toLong(), key, badlyPaddedDigest, null, 4)
SignalDatabase.attachments.finalizeAttachmentAfterDownload(mmsId, attachmentId, cipherInputStream)
// Verify the digest has been updated to the properly padded one
val properlyPaddedPlaintext = PaddingInputStream(plaintext.inputStream(), plaintext.size.toLong()).readFully()
val properlyPaddedCiphertext = encryptPrePaddedBytes(properlyPaddedPlaintext, key, iv)
val properlyPaddedDigest = getDigest(properlyPaddedCiphertext)
val newDigest = SignalDatabase.attachments.getAttachment(attachmentId)!!.remoteDigest!!
assertArrayEquals(properlyPaddedDigest, newDigest)
}
@Test
fun finalizeAttachmentAfterDownload_leaveDigestAloneForAllZeroPadding() {
// Insert attachment metadata for properly-padded attachment
@@ -241,14 +210,14 @@ class AttachmentTableTest {
@Test
fun resetArchiveTransferStateByPlaintextHashAndRemoteKey_singleMatch() {
// Given an attachment with some digest
// Given an attachment with some plaintextHash+remoteKey
val blob = BlobProvider.getInstance().forData(byteArrayOf(1, 2, 3, 4, 5)).createForSingleSessionInMemory()
val attachment = createAttachment(1, blob, AttachmentTable.TransformProperties.empty())
val attachmentId = SignalDatabase.attachments.insertAttachmentsForMessage(-1L, listOf(attachment), emptyList()).values.first()
SignalDatabase.attachments.finalizeAttachmentAfterUpload(attachmentId, AttachmentTableTestUtil.createUploadResult(attachmentId))
SignalDatabase.attachments.setArchiveTransferState(attachmentId, AttachmentTable.ArchiveTransferState.FINISHED)
// Reset the transfer state by digest
// Reset the transfer state by plaintextHash+remoteKey
val plaintextHash = SignalDatabase.attachments.getAttachment(attachmentId)!!.dataHash!!.decodeBase64OrThrow()
val remoteKey = SignalDatabase.attachments.getAttachment(attachmentId)!!.remoteKey!!.decodeBase64OrThrow()
SignalDatabase.attachments.resetArchiveTransferStateByPlaintextHashAndRemoteKey(plaintextHash, remoteKey)

View File

@@ -60,6 +60,10 @@ object DatabaseAttachmentArchiveUtil {
}
private fun hadIntegrityCheckPerformed(attachment: DatabaseAttachment): Boolean {
if (attachment.archiveTransferState == AttachmentTable.ArchiveTransferState.FINISHED) {
return true
}
return when (attachment.transferState) {
AttachmentTable.TRANSFER_PROGRESS_DONE,
AttachmentTable.TRANSFER_NEEDS_RESTORE,
@@ -92,8 +96,8 @@ fun DatabaseAttachment.createArchiveAttachmentPointer(useArchiveCdn: Boolean): S
throw InvalidAttachmentException("empty encrypted key")
}
if (remoteDigest == null) {
throw InvalidAttachmentException("no digest")
if (remoteDigest == null && dataHash == null) {
throw InvalidAttachmentException("no integrity check available")
}
return try {

View File

@@ -55,10 +55,10 @@ fun InternalBackupStatsTab(stats: InternalBackupPlaygroundViewModel.StatsState,
Spacer(modifier = Modifier.size(16.dp))
Text(text = "Unique/archived data files: ${stats.attachmentStats.attachmentFileCount}/${stats.attachmentStats.finishedAttachmentFileCount}")
Text(text = "Unique/archived verified digest count: ${stats.attachmentStats.attachmentDigestCount}/${stats.attachmentStats.finishedAttachmentDigestCount}")
Text(text = "Unique/archived verified digest count: ${stats.attachmentStats.attachmentPlaintextHashAndKeyCount}/${stats.attachmentStats.finishedAttachmentPlaintextHashAndKeyCount}")
Text(text = "Unique/expected thumbnail files: ${stats.attachmentStats.thumbnailFileCount}/${stats.attachmentStats.estimatedThumbnailCount}")
Text(text = "Local Total: ${stats.attachmentStats.attachmentFileCount + stats.attachmentStats.thumbnailFileCount}")
Text(text = "Expected remote total: ${stats.attachmentStats.estimatedThumbnailCount + stats.attachmentStats.finishedAttachmentDigestCount}")
Text(text = "Expected remote total: ${stats.attachmentStats.estimatedThumbnailCount + stats.attachmentStats.finishedAttachmentPlaintextHashAndKeyCount}")
Spacer(modifier = Modifier.size(16.dp))

View File

@@ -34,7 +34,6 @@ import org.json.JSONArray
import org.json.JSONException
import org.signal.core.util.Base64
import org.signal.core.util.SqlUtil
import org.signal.core.util.StreamUtil
import org.signal.core.util.ThreadUtil
import org.signal.core.util.copyTo
import org.signal.core.util.count
@@ -60,7 +59,6 @@ import org.signal.core.util.requireNonNullString
import org.signal.core.util.requireObject
import org.signal.core.util.requireString
import org.signal.core.util.select
import org.signal.core.util.stream.NullOutputStream
import org.signal.core.util.toInt
import org.signal.core.util.update
import org.signal.core.util.updateAll
@@ -98,7 +96,6 @@ import org.thoughtcrime.securesms.util.StorageUtil
import org.thoughtcrime.securesms.util.Util
import org.thoughtcrime.securesms.video.EncryptedMediaDataSource
import org.whispersystems.signalservice.api.attachment.AttachmentUploadResult
import org.whispersystems.signalservice.api.crypto.AttachmentCipherOutputStream
import org.whispersystems.signalservice.api.crypto.AttachmentCipherStreamUtil
import org.whispersystems.signalservice.api.util.UuidUtil
import org.whispersystems.signalservice.internal.crypto.PaddingInputStream
@@ -397,14 +394,14 @@ class AttachmentTable(
}
/**
* Returns a cursor (with just the digest+archive_cdn) for all attachments that are eligible for archive upload.
* In practice, this means that the attachments have a digest and have not hit a permanent archive upload failure.
* Returns a cursor (with just the plaintextHash+remoteKey+archive_cdn) for all attachments that are eligible for archive upload.
* In practice, this means that the attachments have a plaintextHash and have not hit a permanent archive upload failure.
*/
fun getAttachmentsEligibleForArchiveUpload(): Cursor {
return readableDatabase
.select(REMOTE_DIGEST, ARCHIVE_CDN)
.select(DATA_HASH_END, REMOTE_KEY, ARCHIVE_CDN)
.from(TABLE_NAME)
.where("$REMOTE_DIGEST IS NOT NULL AND $ARCHIVE_TRANSFER_STATE != ${ArchiveTransferState.PERMANENT_FAILURE.value}")
.where("$DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL AND $ARCHIVE_TRANSFER_STATE != ${ArchiveTransferState.PERMANENT_FAILURE.value}")
.run()
}
@@ -530,7 +527,7 @@ class AttachmentTable(
fun getRestorableOptimizedAttachments(): List<RestorableAttachment> {
return readableDatabase
.select(ID, MESSAGE_ID, DATA_SIZE, REMOTE_DIGEST, REMOTE_KEY)
.select(ID, MESSAGE_ID, DATA_SIZE, DATA_HASH_END, REMOTE_KEY)
.from(TABLE_NAME)
.where("$TRANSFER_STATE = ? AND $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL", TRANSFER_RESTORE_OFFLOADED)
.orderBy("$ID DESC")
@@ -689,7 +686,7 @@ class AttachmentTable(
ARCHIVE_TRANSFER_STATE to ArchiveTransferState.NONE.value,
ARCHIVE_CDN to null
)
.where("$DATA_HASH_END = ? AND $REMOTE_KEY = ?", plaintextHash, remoteKey)
.where("$DATA_HASH_END = ? AND $REMOTE_KEY = ?", Base64.encodeWithPadding(plaintextHash), Base64.encodeWithPadding(remoteKey))
.run()
}
@@ -768,11 +765,12 @@ class AttachmentTable(
"""
SELECT SUM($DATA_SIZE)
FROM (
SELECT DISTINCT $REMOTE_DIGEST, $DATA_SIZE
SELECT DISTINCT $DATA_HASH_END, $REMOTE_KEY, $DATA_SIZE
FROM $TABLE_NAME
WHERE
$DATA_FILE NOT NULL AND
$REMOTE_DIGEST NOT NULL AND
$DATA_HASH_END NOT NULL AND
$REMOTE_KEY NOT NULL AND
$ARCHIVE_TRANSFER_STATE NOT IN (${ArchiveTransferState.FINISHED.value}, ${ArchiveTransferState.PERMANENT_FAILURE.value})
)
""".trimIndent()
@@ -1257,7 +1255,7 @@ class AttachmentTable(
}
@Throws(IOException::class)
fun finalizeAttachmentThumbnailAfterDownload(attachmentId: AttachmentId, digest: ByteArray, inputStream: InputStream, transferFile: File) {
fun finalizeAttachmentThumbnailAfterDownload(attachmentId: AttachmentId, plaintextHash: String?, remoteKey: String?, inputStream: InputStream, transferFile: File) {
Log.i(TAG, "[finalizeAttachmentThumbnailAfterDownload] Finalizing downloaded data for $attachmentId.")
val fileWriteResult: DataFileWriteResult = writeToDataFile(newDataFile(context), inputStream, TransformProperties.empty())
@@ -1268,10 +1266,14 @@ class AttachmentTable(
THUMBNAIL_RESTORE_STATE to ThumbnailRestoreState.FINISHED.value
)
if (plaintextHash != null && remoteKey != null) {
db.update(TABLE_NAME)
.values(values)
.where("$REMOTE_DIGEST = ?", digest)
.where("$DATA_HASH_END = ? AND $REMOTE_KEY = ?", plaintextHash, remoteKey)
.run()
} else {
Log.w(TAG, "[finalizeAttachmentThumbnailAfterDownload] No plaintext hash or remote key provided for $attachmentId. Cannot update other possible thumbnails.")
}
}
notifyConversationListListeners()
@@ -1287,7 +1289,8 @@ class AttachmentTable(
*/
fun finalizeAttachmentThumbnailAfterUpload(
attachmentId: AttachmentId,
attachmentDigest: ByteArray,
attachmentPlaintextHash: String?,
attachmentRemoteKey: String?,
data: ByteArray
) {
Log.i(TAG, "[finalizeAttachmentThumbnailAfterUpload] Finalizing archive data for $attachmentId thumbnail.")
@@ -1300,10 +1303,14 @@ class AttachmentTable(
THUMBNAIL_RESTORE_STATE to ThumbnailRestoreState.FINISHED.value
)
if (attachmentPlaintextHash != null && attachmentRemoteKey != null) {
db.update(TABLE_NAME)
.values(values)
.where("$ID = ? OR $REMOTE_DIGEST = ?", attachmentId, attachmentDigest)
.where("$DATA_HASH_END = ? AND $REMOTE_KEY = ?", attachmentPlaintextHash, attachmentRemoteKey)
.run()
} else {
Log.w(TAG, "[finalizeAttachmentThumbnailAfterUpload] No plaintext hash or remote key provided for $attachmentId. Cannot update other possible thumbnails.")
}
}
}
@@ -1494,7 +1501,7 @@ class AttachmentTable(
}
/**
* As part of the digest backfill process, this updates the (key, IV, digest) tuple for all attachments that share a data file (and are done downloading).
* As part of the digest backfill process, this updates the (key, digest) tuple for all attachments that share a data file (and are done downloading).
*/
fun updateRemoteKeyAndDigestByDataFile(dataFile: String, key: ByteArray, digest: ByteArray) {
writableDatabase
@@ -1551,74 +1558,6 @@ class AttachmentTable(
return insertedAttachments
}
fun debugCopyAttachmentForArchiveRestore(
mmsId: Long,
attachment: DatabaseAttachment,
forThumbnail: Boolean
) {
val copy =
"""
INSERT INTO $TABLE_NAME
(
$MESSAGE_ID,
$CONTENT_TYPE,
$TRANSFER_STATE,
$CDN_NUMBER,
$REMOTE_LOCATION,
$REMOTE_DIGEST,
$REMOTE_INCREMENTAL_DIGEST,
$REMOTE_INCREMENTAL_DIGEST_CHUNK_SIZE,
$REMOTE_KEY,
$FILE_NAME,
$DATA_SIZE,
$VOICE_NOTE,
$BORDERLESS,
$VIDEO_GIF,
$WIDTH,
$HEIGHT,
$CAPTION,
$UPLOAD_TIMESTAMP,
$BLUR_HASH,
$DATA_SIZE,
$DATA_RANDOM,
$DATA_HASH_START,
$DATA_HASH_END,
$ARCHIVE_CDN,
$THUMBNAIL_RESTORE_STATE
)
SELECT
$mmsId,
$CONTENT_TYPE,
$TRANSFER_NEEDS_RESTORE,
$CDN_NUMBER,
$REMOTE_LOCATION,
$REMOTE_DIGEST,
$REMOTE_INCREMENTAL_DIGEST,
$REMOTE_INCREMENTAL_DIGEST_CHUNK_SIZE,
$REMOTE_KEY,
$FILE_NAME,
$DATA_SIZE,
$VOICE_NOTE,
$BORDERLESS,
$VIDEO_GIF,
$WIDTH,
$HEIGHT,
$CAPTION,
${System.currentTimeMillis()},
$BLUR_HASH,
$DATA_SIZE,
$DATA_RANDOM,
$DATA_HASH_START,
$DATA_HASH_END,
${attachment.archiveCdn},
${if (forThumbnail) ThumbnailRestoreState.NEEDS_RESTORE.value else ThumbnailRestoreState.NONE.value}
FROM $TABLE_NAME
WHERE $ID = ${attachment.attachmentId.id}
"""
writableDatabase.execSQL(copy)
}
/**
* Updates the data stored for an existing attachment. This happens after transformations, like transcoding.
*/
@@ -1921,56 +1860,47 @@ class AttachmentTable(
}
/**
* Sets the archive data for the specific attachment, as well as for any attachments that use the same underlying file.
* Sets the archive data for the specific attachment, as well as for any attachments that have the same mediaName (plaintextHash + remoteKey).
*/
fun setArchiveCdn(attachmentId: AttachmentId, archiveCdn: Int) {
writableDatabase.withinTransaction { db ->
val dataFile = db
.select(DATA_FILE)
val plaintextHashAndRemoteKey = db
.select(DATA_HASH_END, REMOTE_KEY)
.from(TABLE_NAME)
.where("$ID = ?", attachmentId.id)
.run()
.readToSingleObject { it.requireString(DATA_FILE) }
.readToSingleObject {
it.requireNonNullString(DATA_HASH_END) to it.requireNonNullString(REMOTE_KEY)
}
if (dataFile == null) {
if (plaintextHashAndRemoteKey == null) {
Log.w(TAG, "No data file found for attachment $attachmentId. Can't set archive data.")
return@withinTransaction
}
val (plaintextHash, remoteKey) = plaintextHashAndRemoteKey
db.update(TABLE_NAME)
.values(
ARCHIVE_CDN to archiveCdn,
ARCHIVE_TRANSFER_STATE to ArchiveTransferState.FINISHED.value
)
.where("$DATA_FILE = ?", dataFile)
.where("$DATA_HASH_END = ? AND $REMOTE_KEY = ?", plaintextHash, remoteKey)
.run()
}
}
/**
* Updates all attachments that share the same digest with the given archive CDN.
* Updates all attachments that share the same mediaName (plaintextHash + remoteKey) with the given archive CDN.
*/
fun setArchiveCdnByPlaintextHashAndRemoteKey(plaintextHash: ByteArray, remoteKey: ByteArray, archiveCdn: Int) {
writableDatabase
.update(TABLE_NAME)
.values(ARCHIVE_CDN to archiveCdn)
.where("$DATA_HASH_END= ? AND $REMOTE_KEY = ?", plaintextHash, remoteKey)
.where("$DATA_HASH_END = ? AND $REMOTE_KEY = ?", Base64.encodeWithPadding(plaintextHash), Base64.encodeWithPadding(remoteKey))
.run()
}
fun clearArchiveData(attachmentIds: List<AttachmentId>) {
SqlUtil.buildCollectionQuery(ID, attachmentIds.map { it.id })
.forEach { query ->
writableDatabase
.update(TABLE_NAME)
.values(
ARCHIVE_CDN to null
)
.where(query.where, query.whereArgs)
.run()
}
}
fun clearAllArchiveData() {
writableDatabase
.updateAll(TABLE_NAME)
@@ -1981,18 +1911,6 @@ class AttachmentTable(
.run()
}
private fun calculateDigest(fileInfo: DataFileWriteResult, key: ByteArray, iv: ByteArray = Util.getSecretBytes(16)): ByteArray {
return calculateDigest(file = fileInfo.file, random = fileInfo.random, length = fileInfo.length, key = key, iv = iv)
}
private fun calculateDigest(file: File, random: ByteArray, length: Long, key: ByteArray, iv: ByteArray): ByteArray {
val stream = PaddingInputStream(getDataStream(file, random, 0), length)
val cipherOutputStream = AttachmentCipherOutputStream(key, iv, NullOutputStream)
StreamUtil.copy(stream, cipherOutputStream)
return cipherOutputStream.transmittedDigest
}
/**
* Deletes the data file if there's no strong references to other attachments.
* If deleted, it will also clear all weak references (i.e. quotes) of the attachment.
@@ -2476,15 +2394,20 @@ class AttachmentTable(
fun getEstimatedArchiveMediaSize(): Long {
val estimatedThumbnailCount = readableDatabase
.select("COUNT(DISTINCT $REMOTE_DIGEST)")
.from(TABLE_NAME)
.where(
.select("COUNT(*)")
.from(
"""
(
SELECT DISTINCT $DATA_HASH_END, $REMOTE_KEY
FROM $TABLE_NAME
WHERE
$DATA_FILE NOT NULL AND
$REMOTE_DIGEST NOT NULL AND
$DATA_HASH_END NOT NULL AND
$REMOTE_KEY NOT NULL AND
$TRANSFER_STATE = $TRANSFER_PROGRESS_DONE AND
$ARCHIVE_TRANSFER_STATE != ${ArchiveTransferState.PERMANENT_FAILURE.value} AND
($CONTENT_TYPE LIKE 'image/%' OR $CONTENT_TYPE LIKE 'video/%')
)
"""
)
.run()
@@ -2495,11 +2418,12 @@ class AttachmentTable(
"""
SELECT $DATA_SIZE
FROM (
SELECT DISTINCT $REMOTE_DIGEST, $DATA_SIZE
SELECT DISTINCT $DATA_HASH_END, $REMOTE_KEY, $DATA_SIZE
FROM $TABLE_NAME
WHERE
$DATA_FILE NOT NULL AND
$REMOTE_DIGEST NOT NULL AND
$DATA_HASH_END NOT NULL AND
$REMOTE_KEY NOT NULL AND
$TRANSFER_STATE = $TRANSFER_PROGRESS_DONE AND
$ARCHIVE_TRANSFER_STATE != ${ArchiveTransferState.PERMANENT_FAILURE.value}
)
@@ -2657,23 +2581,23 @@ class AttachmentTable(
)
val transferStateCounts = transferStates
.map { (state, name) -> name to readableDatabase.count().from(TABLE_NAME).where("$TRANSFER_STATE = $state AND $REMOTE_DIGEST NOT NULL").run().readToSingleLong(-1L) }
.map { (state, name) -> name to readableDatabase.count().from(TABLE_NAME).where("$TRANSFER_STATE = $state AND $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL").run().readToSingleLong(-1L) }
.toMap()
val validForArchiveTransferStateCounts = transferStates
.map { (state, name) -> name to readableDatabase.count().from(TABLE_NAME).where("$TRANSFER_STATE = $state AND $REMOTE_DIGEST NOT NULL AND $DATA_FILE NOT NULL").run().readToSingleLong(-1L) }
.map { (state, name) -> name to readableDatabase.count().from(TABLE_NAME).where("$TRANSFER_STATE = $state AND $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL AND $DATA_FILE NOT NULL").run().readToSingleLong(-1L) }
.toMap()
val archiveStateCounts = ArchiveTransferState
.entries
.associate { it to readableDatabase.count().from(TABLE_NAME).where("$ARCHIVE_TRANSFER_STATE = ${it.value} AND $REMOTE_DIGEST NOT NULL").run().readToSingleLong(-1L) }
.associate { it to readableDatabase.count().from(TABLE_NAME).where("$ARCHIVE_TRANSFER_STATE = ${it.value} AND $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL").run().readToSingleLong(-1L) }
val attachmentFileCount = readableDatabase.query("SELECT COUNT(DISTINCT $DATA_FILE) FROM $TABLE_NAME WHERE $DATA_FILE NOT NULL AND $REMOTE_DIGEST NOT NULL").readToSingleLong(-1L)
val finishedAttachmentFileCount = readableDatabase.query("SELECT COUNT(DISTINCT $DATA_FILE) FROM $TABLE_NAME WHERE $DATA_FILE NOT NULL AND $REMOTE_DIGEST NOT NULL AND $ARCHIVE_TRANSFER_STATE = ${ArchiveTransferState.FINISHED.value}").readToSingleLong(-1L)
val attachmentDigestCount = readableDatabase.query("SELECT COUNT(DISTINCT $REMOTE_DIGEST) FROM $TABLE_NAME WHERE $REMOTE_DIGEST NOT NULL AND $TRANSFER_STATE in ($TRANSFER_PROGRESS_DONE, $TRANSFER_RESTORE_OFFLOADED, $TRANSFER_RESTORE_IN_PROGRESS, $TRANSFER_NEEDS_RESTORE)").readToSingleLong(-1L)
val finishedAttachmentDigestCount = readableDatabase.query("SELECT COUNT(DISTINCT $REMOTE_DIGEST) FROM $TABLE_NAME WHERE $REMOTE_DIGEST NOT NULL AND $ARCHIVE_TRANSFER_STATE = ${ArchiveTransferState.FINISHED.value}").readToSingleLong(-1L)
val attachmentFileCount = readableDatabase.query("SELECT COUNT(DISTINCT $DATA_FILE) FROM $TABLE_NAME WHERE $DATA_FILE NOT NULL AND $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL").readToSingleLong(-1L)
val finishedAttachmentFileCount = readableDatabase.query("SELECT COUNT(DISTINCT $DATA_FILE) FROM $TABLE_NAME WHERE $DATA_FILE NOT NULL AND $DATA_HASH_END NOT NULL $REMOTE_KEY NOT NULL AND $ARCHIVE_TRANSFER_STATE = ${ArchiveTransferState.FINISHED.value}").readToSingleLong(-1L)
val attachmentPlaintextHashAndKeyCount = readableDatabase.query("SELECT COUNT(*) FROM (SELECT DISTINCT $DATA_HASH_END, $REMOTE_KEY FROM $TABLE_NAME WHERE $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL AND $TRANSFER_STATE in ($TRANSFER_PROGRESS_DONE, $TRANSFER_RESTORE_OFFLOADED, $TRANSFER_RESTORE_IN_PROGRESS, $TRANSFER_NEEDS_RESTORE))").readToSingleLong(-1L)
val finishedAttachmentDigestCount = readableDatabase.query("SELECT COUNT(*) FROM (SELECT DISTINCT $DATA_HASH_END, $REMOTE_KEY) FROM $TABLE_NAME WHERE $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL AND $ARCHIVE_TRANSFER_STATE = ${ArchiveTransferState.FINISHED.value})").readToSingleLong(-1L)
val thumbnailFileCount = readableDatabase.query("SELECT COUNT(DISTINCT $THUMBNAIL_FILE) FROM $TABLE_NAME WHERE $THUMBNAIL_FILE IS NOT NULL").readToSingleLong(-1L)
val estimatedThumbnailCount = readableDatabase.query("SELECT COUNT(DISTINCT $REMOTE_DIGEST) FROM $TABLE_NAME WHERE $ARCHIVE_TRANSFER_STATE = ${ArchiveTransferState.FINISHED.value} AND $REMOTE_DIGEST NOT NULL AND ($CONTENT_TYPE LIKE 'image/%' OR $CONTENT_TYPE LIKE 'video/%')").readToSingleLong(-1L)
val estimatedThumbnailCount = readableDatabase.query("SELECT COUNT(*) FROM (SELECT DISTINCT $DATA_HASH_END, $REMOTE_KEY) FROM $TABLE_NAME WHERE $ARCHIVE_TRANSFER_STATE = ${ArchiveTransferState.FINISHED.value} AND $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL AND ($CONTENT_TYPE LIKE 'image/%' OR $CONTENT_TYPE LIKE 'video/%'))").readToSingleLong(-1L)
val pendingUploadBytes = getPendingArchiveUploadBytes()
val uploadedAttachmentBytes = readableDatabase
@@ -2681,11 +2605,12 @@ class AttachmentTable(
"""
SELECT $DATA_SIZE
FROM (
SELECT DISTINCT $REMOTE_DIGEST, $DATA_SIZE
SELECT DISTINCT $DATA_HASH_END, $REMOTE_KEY, $DATA_SIZE
FROM $TABLE_NAME
WHERE
$DATA_FILE NOT NULL AND
$REMOTE_DIGEST NOT NULL AND
$DATA_HASH_END NOT NULL AND
$REMOTE_KEY NOT NULL AND
$ARCHIVE_TRANSFER_STATE = ${ArchiveTransferState.FINISHED.value}
)
""".trimIndent()
@@ -2702,8 +2627,8 @@ class AttachmentTable(
archiveStateCounts = archiveStateCounts,
attachmentFileCount = attachmentFileCount,
finishedAttachmentFileCount = finishedAttachmentFileCount,
attachmentDigestCount = attachmentDigestCount,
finishedAttachmentDigestCount = finishedAttachmentDigestCount,
attachmentPlaintextHashAndKeyCount = attachmentPlaintextHashAndKeyCount,
finishedAttachmentPlaintextHashAndKeyCount = finishedAttachmentDigestCount,
thumbnailFileCount = thumbnailFileCount,
estimatedThumbnailCount = estimatedThumbnailCount,
pendingUploadBytes = pendingUploadBytes,
@@ -2953,8 +2878,8 @@ class AttachmentTable(
val archiveStateCounts: Map<ArchiveTransferState, Long> = emptyMap(),
val attachmentFileCount: Long = 0L,
val finishedAttachmentFileCount: Long = 0L,
val attachmentDigestCount: Long = 0L,
val finishedAttachmentDigestCount: Long,
val attachmentPlaintextHashAndKeyCount: Long = 0L,
val finishedAttachmentPlaintextHashAndKeyCount: Long,
val thumbnailFileCount: Long = 0L,
val pendingUploadBytes: Long = 0L,
val uploadedAttachmentBytes: Long = 0L,

View File

@@ -244,7 +244,7 @@ class BackupMediaSnapshotTable(context: Context, database: SignalDatabase) : Dat
return readableDatabase.rawQuery(
"""
WITH input_pairs($MEDIA_ID, $CDN) AS (VALUES $inputValues)
SELECT a.$PLAINTEXT_HASH, a.$REMOTE_KEY b.$CDN
SELECT a.$PLAINTEXT_HASH, a.$REMOTE_KEY, b.$CDN
FROM $TABLE_NAME a
JOIN input_pairs b ON a.$MEDIA_ID = b.$MEDIA_ID
WHERE a.$CDN != b.$CDN AND a.$IS_THUMBNAIL = 0 AND $SNAPSHOT_VERSION = $MAX_VERSION

View File

@@ -148,7 +148,8 @@ class ArchiveThumbnailUploadJob private constructor(
// save attachment thumbnail
SignalDatabase.attachments.finalizeAttachmentThumbnailAfterUpload(
attachmentId = attachmentId,
attachmentDigest = attachment.remoteDigest,
attachmentPlaintextHash = attachment.dataHash,
attachmentRemoteKey = attachment.remoteKey,
data = thumbnailResult.data
)

View File

@@ -94,7 +94,7 @@ class AttachmentDownloadJob private constructor(
AttachmentTable.TRANSFER_PROGRESS_PENDING,
AttachmentTable.TRANSFER_PROGRESS_FAILED -> {
if (SignalStore.backup.backsUpMedia && databaseAttachment.remoteLocation == null) {
if (SignalStore.backup.backsUpMedia && (databaseAttachment.remoteLocation == null || databaseAttachment.remoteDigest == null)) {
if (databaseAttachment.archiveTransferState == AttachmentTable.ArchiveTransferState.FINISHED) {
Log.i(TAG, "Trying to restore attachment from archive cdn")
RestoreAttachmentJob.restoreAttachment(databaseAttachment)

View File

@@ -131,8 +131,8 @@ class BackupMessagesJob private constructor(
val stopwatch = Stopwatch("BackupMessagesJob")
SignalDatabase.attachments.createRemoteKeyForAttachmentsThatNeedArchiveUpload().takeIf { it > 0 }?.let { count -> Log.w(TAG, "Needed to create $count key/iv/digests.") }
stopwatch.split("key-iv-digest")
SignalDatabase.attachments.createRemoteKeyForAttachmentsThatNeedArchiveUpload().takeIf { it > 0 }?.let { count -> Log.w(TAG, "Needed to create $count remote keys.") }
stopwatch.split("keygen")
if (isCanceled) {
return Result.failure()

View File

@@ -10,6 +10,7 @@ import android.content.Intent
import androidx.core.app.NotificationCompat
import androidx.core.app.NotificationManagerCompat
import org.greenrobot.eventbus.EventBus
import org.signal.core.util.Base64.decodeBase64OrThrow
import org.signal.core.util.PendingIntentFlags
import org.signal.core.util.logging.Log
import org.signal.libsignal.protocol.InvalidMacException
@@ -182,9 +183,10 @@ class RestoreAttachmentJob private constructor(
if (attachment.transferState != AttachmentTable.TRANSFER_NEEDS_RESTORE &&
attachment.transferState != AttachmentTable.TRANSFER_RESTORE_IN_PROGRESS &&
(attachment.transferState != AttachmentTable.TRANSFER_RESTORE_OFFLOADED)
attachment.transferState != AttachmentTable.TRANSFER_PROGRESS_FAILED &&
attachment.transferState != AttachmentTable.TRANSFER_RESTORE_OFFLOADED
) {
Log.w(TAG, "Attachment does not need to be restored.")
Log.w(TAG, "Attachment does not need to be restored. Current state: ${attachment.transferState}")
return
}
@@ -263,6 +265,7 @@ class RestoreAttachmentJob private constructor(
messageReceiver
.retrieveArchivedAttachment(
SignalStore.backup.mediaRootBackupKey.deriveMediaSecrets(attachment.requireMediaName()),
attachment.dataHash!!.decodeBase64OrThrow(),
cdnCredentials,
archiveFile,
pointer,

View File

@@ -113,8 +113,8 @@ class RestoreAttachmentThumbnailJob private constructor(
return
}
if (attachment.remoteDigest == null) {
Log.w(TAG, "$attachmentId has no digest! Cannot proceed.")
if (attachment.dataHash == null) {
Log.w(TAG, "$attachmentId has no plaintext hash! Cannot proceed.")
return
}
@@ -142,7 +142,7 @@ class RestoreAttachmentThumbnailJob private constructor(
progressListener
)
SignalDatabase.attachments.finalizeAttachmentThumbnailAfterDownload(attachmentId, attachment.remoteDigest, decryptingStream, thumbnailTransferFile)
SignalDatabase.attachments.finalizeAttachmentThumbnailAfterDownload(attachmentId, attachment.dataHash, attachment.remoteKey, decryptingStream, thumbnailTransferFile)
if (!SignalDatabase.messages.isStory(messageId)) {
AppDependencies.messageNotifier.updateNotification(context)

View File

@@ -1,77 +0,0 @@
package org.thoughtcrime.securesms.mms;
import androidx.annotation.NonNull;
import com.bumptech.glide.Priority;
import com.bumptech.glide.load.DataSource;
import com.bumptech.glide.load.data.DataFetcher;
import org.signal.core.util.logging.Log;
import org.signal.libsignal.protocol.InvalidMessageException;
import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Optional;
class AttachmentStreamLocalUriFetcher implements DataFetcher<InputStream> {
private static final String TAG = Log.tag(AttachmentStreamLocalUriFetcher.class);
private final File attachment;
private final byte[] key;
private final Optional<byte[]> digest;
private final long plaintextLength;
private InputStream is;
AttachmentStreamLocalUriFetcher(File attachment, long plaintextLength, byte[] key, Optional<byte[]> digest) {
this.attachment = attachment;
this.plaintextLength = plaintextLength;
this.digest = digest;
this.key = key;
}
@Override
public void loadData(@NonNull Priority priority, @NonNull DataCallback<? super InputStream> callback) {
try {
if (!digest.isPresent()) throw new InvalidMessageException("No attachment digest!");
is = AttachmentCipherInputStream.createForAttachment(attachment,
plaintextLength,
key,
digest.get(),
null,
0);
callback.onDataReady(is);
} catch (IOException | InvalidMessageException e) {
callback.onLoadFailed(e);
}
}
@Override
public void cleanup() {
try {
if (is != null) is.close();
is = null;
} catch (IOException ioe) {
Log.w(TAG, "ioe");
}
}
@Override
public void cancel() {}
@Override
public @NonNull Class<InputStream> getDataClass() {
return InputStream.class;
}
@Override
public @NonNull DataSource getDataSource() {
return DataSource.LOCAL;
}
}

View File

@@ -1,89 +0,0 @@
package org.thoughtcrime.securesms.mms;
import androidx.annotation.NonNull;
import androidx.annotation.Nullable;
import com.bumptech.glide.load.Key;
import com.bumptech.glide.load.Options;
import com.bumptech.glide.load.model.ModelLoader;
import com.bumptech.glide.load.model.ModelLoaderFactory;
import com.bumptech.glide.load.model.MultiModelLoaderFactory;
import org.thoughtcrime.securesms.mms.AttachmentStreamUriLoader.AttachmentModel;
import java.io.File;
import java.io.InputStream;
import java.security.MessageDigest;
import java.util.Optional;
public class AttachmentStreamUriLoader implements ModelLoader<AttachmentModel, InputStream> {
@Override
public @Nullable LoadData<InputStream> buildLoadData(@NonNull AttachmentModel attachmentModel, int width, int height, @NonNull Options options) {
return new LoadData<>(attachmentModel, new AttachmentStreamLocalUriFetcher(attachmentModel.attachment, attachmentModel.plaintextLength, attachmentModel.key, attachmentModel.digest));
}
@Override
public boolean handles(@NonNull AttachmentModel attachmentModel) {
return true;
}
static class Factory implements ModelLoaderFactory<AttachmentModel, InputStream> {
@Override
public @NonNull ModelLoader<AttachmentModel, InputStream> build(@NonNull MultiModelLoaderFactory multiFactory) {
return new AttachmentStreamUriLoader();
}
@Override
public void teardown() {
// Do nothing.
}
}
public static class AttachmentModel implements Key {
public @NonNull File attachment;
public @NonNull byte[] key;
public @NonNull Optional<byte[]> digest;
public @NonNull Optional<byte[]> incrementalDigest;
public int incrementalMacChunkSize;
public long plaintextLength;
public AttachmentModel(@NonNull File attachment,
@NonNull byte[] key,
long plaintextLength,
@NonNull Optional<byte[]> digest,
@NonNull Optional<byte[]> incrementalDigest,
int incrementalMacChunkSize)
{
this.attachment = attachment;
this.key = key;
this.digest = digest;
this.incrementalDigest = incrementalDigest;
this.incrementalMacChunkSize = incrementalMacChunkSize;
this.plaintextLength = plaintextLength;
}
@Override
public void updateDiskCacheKey(@NonNull MessageDigest messageDigest) {
messageDigest.update(attachment.toString().getBytes());
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
AttachmentModel that = (AttachmentModel)o;
return attachment.equals(that.attachment);
}
@Override
public int hashCode() {
return attachment.hashCode();
}
}
}

View File

@@ -41,7 +41,6 @@ import org.thoughtcrime.securesms.glide.cache.EncryptedCacheDecoder;
import org.thoughtcrime.securesms.glide.cache.EncryptedCacheEncoder;
import org.thoughtcrime.securesms.glide.cache.EncryptedGifDrawableResourceEncoder;
import org.thoughtcrime.securesms.glide.cache.WebpSanDecoder;
import org.thoughtcrime.securesms.mms.AttachmentStreamUriLoader.AttachmentModel;
import org.thoughtcrime.securesms.mms.DecryptableStreamUriLoader.DecryptableUri;
import org.thoughtcrime.securesms.stickers.StickerRemoteUri;
import org.thoughtcrime.securesms.stickers.StickerRemoteUriLoader;
@@ -96,7 +95,6 @@ public class SignalGlideComponents implements RegisterGlideComponents {
registry.append(ConversationShortcutPhoto.class, Bitmap.class, new ConversationShortcutPhoto.Loader.Factory(context));
registry.append(ContactPhoto.class, InputStream.class, new ContactPhotoLoader.Factory(context));
registry.append(DecryptableUri.class, InputStream.class, new DecryptableStreamUriLoader.Factory(context));
registry.append(AttachmentModel.class, InputStream.class, new AttachmentStreamUriLoader.Factory());
registry.append(ChunkedImageUrl.class, InputStream.class, new ChunkedImageUrlLoader.Factory());
registry.append(StickerRemoteUri.class, InputStream.class, new StickerRemoteUriLoader.Factory());
registry.append(BlurHash.class, BlurHash.class, new BlurHashModelLoader.Factory());

View File

@@ -86,7 +86,11 @@ class PartDataSource implements DataSource {
throw new InvalidMessageException("Missing digest!");
}
this.inputStream = AttachmentCipherInputStream.createForArchivedMedia(mediaKeyMaterial, archiveFile, originalCipherLength, attachment.size, decodedKey, attachment.remoteDigest, attachment.getIncrementalDigest(), attachment.incrementalMacChunkSize);
if (attachment.dataHash == null || attachment.dataHash.isEmpty()) {
throw new InvalidMessageException("Missing plaintextHash!");
}
this.inputStream = AttachmentCipherInputStream.createForArchivedMedia(mediaKeyMaterial, archiveFile, originalCipherLength, attachment.size, decodedKey, Base64.decodeOrThrow(attachment.dataHash), attachment.getIncrementalDigest(), attachment.incrementalMacChunkSize);
} catch (InvalidMessageException e) {
throw new IOException("Error decrypting attachment stream!", e);
}

View File

@@ -122,6 +122,7 @@ public class SignalServiceMessageReceiver {
* Retrieves an archived media attachment.
*
* @param archivedMediaKeyMaterial Decryption key material for decrypting outer layer of archived media.
* @param plaintextHash The plaintext hash of the attachment, used to verify the integrity of the downloaded content.
* @param readCredentialHeaders Headers to pass to the backup CDN to authorize the download
* @param archiveDestination The download destination for archived attachment. If this file exists, download will resume.
* @param pointer The {@link SignalServiceAttachmentPointer} received in a {@link SignalServiceDataMessage}.
@@ -131,6 +132,7 @@ public class SignalServiceMessageReceiver {
* @return An InputStream that streams the plaintext attachment contents.
*/
public InputStream retrieveArchivedAttachment(@Nonnull MediaRootBackupKey.MediaKeyMaterial archivedMediaKeyMaterial,
@Nonnull byte[] plaintextHash,
@Nonnull Map<String, String> readCredentialHeaders,
@Nonnull File archiveDestination,
@Nonnull SignalServiceAttachmentPointer pointer,
@@ -139,10 +141,6 @@ public class SignalServiceMessageReceiver {
@Nullable ProgressListener listener)
throws IOException, InvalidMessageException, MissingConfigurationException
{
if (pointer.getDigest().isEmpty()) {
throw new InvalidMessageException("No attachment digest!");
}
if (pointer.getKey() == null) {
throw new InvalidMessageException("No key!");
}
@@ -160,7 +158,7 @@ public class SignalServiceMessageReceiver {
originalCipherLength,
pointer.getSize().orElse(0),
pointer.getKey(),
pointer.getDigest().get(),
plaintextHash,
null,
0
);

View File

@@ -1,16 +0,0 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.signalservice.api.attachment
import org.signal.core.util.stream.LimitedInputStream
/**
* Holds the result of an attachment download.
*/
class AttachmentDownloadResult(
val dataStream: LimitedInputStream,
val iv: ByteArray
)

View File

@@ -54,13 +54,14 @@ object AttachmentCipherInputStream {
digest: ByteArray,
incrementalDigest: ByteArray?,
incrementalMacChunkSize: Int
): LimitedInputStream {
): InputStream {
return create(
streamSupplier = { FileInputStream(file) },
streamLength = file.length(),
plaintextLength = plaintextLength,
combinedKeyMaterial = combinedKeyMaterial,
digest = digest,
encryptedDigest = digest,
plaintextHash = null,
incrementalDigest = incrementalDigest,
incrementalMacChunkSize = incrementalMacChunkSize,
ignoreDigest = false
@@ -83,13 +84,14 @@ object AttachmentCipherInputStream {
digest: ByteArray,
incrementalDigest: ByteArray?,
incrementalMacChunkSize: Int
): LimitedInputStream {
): InputStream {
return create(
streamSupplier = streamSupplier,
streamLength = streamLength,
plaintextLength = plaintextLength,
combinedKeyMaterial = combinedKeyMaterial,
digest = digest,
encryptedDigest = digest,
plaintextHash = null,
incrementalDigest = incrementalDigest,
incrementalMacChunkSize = incrementalMacChunkSize,
ignoreDigest = false
@@ -112,10 +114,10 @@ object AttachmentCipherInputStream {
originalCipherTextLength: Long,
plaintextLength: Long,
combinedKeyMaterial: ByteArray,
digest: ByteArray,
plaintextHash: ByteArray,
incrementalDigest: ByteArray?,
incrementalMacChunkSize: Int
): LimitedInputStream {
): InputStream {
val keyMaterial = CombinedKeyMaterial.from(combinedKeyMaterial)
val mac = initMac(keyMaterial.macKey)
@@ -128,10 +130,11 @@ object AttachmentCipherInputStream {
streamLength = originalCipherTextLength,
plaintextLength = plaintextLength,
combinedKeyMaterial = combinedKeyMaterial,
digest = digest,
encryptedDigest = null,
plaintextHash = plaintextHash,
incrementalDigest = incrementalDigest,
incrementalMacChunkSize = incrementalMacChunkSize,
ignoreDigest = false
ignoreDigest = true
)
}
@@ -151,7 +154,7 @@ object AttachmentCipherInputStream {
originalCipherTextLength: Long,
plaintextLength: Long,
combinedKeyMaterial: ByteArray
): LimitedInputStream {
): InputStream {
val keyMaterial = CombinedKeyMaterial.from(combinedKeyMaterial)
val mac = initMac(keyMaterial.macKey)
@@ -164,7 +167,8 @@ object AttachmentCipherInputStream {
streamLength = originalCipherTextLength,
plaintextLength = plaintextLength,
combinedKeyMaterial = combinedKeyMaterial,
digest = null,
encryptedDigest = null,
plaintextHash = null,
incrementalDigest = null,
incrementalMacChunkSize = 0,
ignoreDigest = true
@@ -229,11 +233,12 @@ object AttachmentCipherInputStream {
streamLength: Long,
plaintextLength: Long,
combinedKeyMaterial: ByteArray,
digest: ByteArray?,
encryptedDigest: ByteArray?,
plaintextHash: ByteArray?,
incrementalDigest: ByteArray?,
incrementalMacChunkSize: Int,
ignoreDigest: Boolean
): LimitedInputStream {
): InputStream {
val keyMaterial = CombinedKeyMaterial.from(combinedKeyMaterial)
val mac = initMac(keyMaterial.macKey)
@@ -241,7 +246,7 @@ object AttachmentCipherInputStream {
throw InvalidMessageException("Message shorter than crypto overhead! length: $streamLength")
}
if (!ignoreDigest && digest == null) {
if (!ignoreDigest && encryptedDigest == null) {
throw InvalidMessageException("Missing digest!")
}
@@ -250,20 +255,25 @@ object AttachmentCipherInputStream {
if (!hasIncrementalMac) {
streamSupplier.openStream().use { macVerificationStream ->
verifyMac(macVerificationStream, streamLength, mac, digest)
verifyMac(macVerificationStream, streamLength, mac, encryptedDigest)
}
wrappedStream = streamSupplier.openStream()
} else {
if (digest == null) {
throw InvalidMessageException("Missing digest for incremental mac validation!")
if (encryptedDigest == null && plaintextHash == null) {
throw InvalidMessageException("Missing data (digest or plaintextHas) for incremental mac validation!")
}
val digestValidatingStream = if (encryptedDigest != null) {
DigestValidatingInputStream(streamSupplier.openStream(), sha256Digest(), encryptedDigest)
} else {
streamSupplier.openStream()
}
wrappedStream = IncrementalMacInputStream(
IncrementalMacAdditionalValidationsInputStream(
wrapped = streamSupplier.openStream(),
wrapped = digestValidatingStream,
fileLength = streamLength,
mac = mac,
theirDigest = digest
mac = mac
),
keyMaterial.macKey,
ChunkSizeChoice.everyNthByte(incrementalMacChunkSize),
@@ -274,8 +284,17 @@ object AttachmentCipherInputStream {
val encryptedStreamExcludingMac = LimitedInputStream(wrappedStream, streamLength - mac.macLength)
val cipher = createCipher(encryptedStreamExcludingMac, keyMaterial.aesKey)
val decryptingStream: InputStream = BetterCipherInputStream(encryptedStreamExcludingMac, cipher)
val paddinglessDecryptingStream = LimitedInputStream(decryptingStream, plaintextLength)
return LimitedInputStream(decryptingStream, plaintextLength)
return if (plaintextHash != null) {
if (plaintextHash.size != MessageDigest.getInstance("SHA-256").digestLength) {
throw InvalidMessageException("Invalid plaintext hash size: ${plaintextHash.size}")
}
DigestValidatingInputStream(paddinglessDecryptingStream, sha256Digest(), plaintextHash)
} else {
paddinglessDecryptingStream
}
}
private fun createCipher(inputStream: InputStream, aesKey: ByteArray): Cipher {
@@ -286,6 +305,14 @@ object AttachmentCipherInputStream {
}
}
private fun sha256Digest(): MessageDigest {
try {
return MessageDigest.getInstance("SHA-256")
} catch (e: NoSuchAlgorithmException) {
throw AssertionError(e)
}
}
private fun initMac(key: ByteArray): Mac {
try {
val mac = Mac.getInstance("HmacSHA256")

View File

@@ -0,0 +1,76 @@
/*
* Copyright (C) 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.signalservice.api.crypto
import org.signal.libsignal.protocol.InvalidMessageException
import java.io.FilterInputStream
import java.io.IOException
import java.io.InputStream
import java.security.MessageDigest
/**
* An InputStream that enforces hash validation by calculating a digest as data is read
* and verifying it against an expected hash when the stream is fully consumed.
*
* Important: The validation only occurs if you read the entire stream.
*
* @param inputStream The underlying InputStream to read from
* @param digest The MessageDigest instance to use for hash calculation
* @param expectedHash The expected hash value to validate against
*/
class DigestValidatingInputStream(
inputStream: InputStream,
private val digest: MessageDigest,
private val expectedHash: ByteArray
) : FilterInputStream(inputStream) {
var validationAttempted = false
private set
@Throws(IOException::class)
override fun read(): Int {
val byte = super.read()
if (byte != -1) {
digest.update(byte.toByte())
} else {
validateDigest()
}
return byte
}
@Throws(IOException::class)
override fun read(buffer: ByteArray): Int {
return read(buffer, 0, buffer.size)
}
@Throws(IOException::class)
override fun read(buffer: ByteArray, offset: Int, length: Int): Int {
val bytesRead = super.read(buffer, offset, length)
if (bytesRead > 0) {
digest.update(buffer, offset, bytesRead)
} else if (bytesRead == -1) {
validateDigest()
}
return bytesRead
}
/**
* Validates the calculated digest against the expected hash.
* Throws InvalidCiphertextException if they don't match.
*/
@Throws(InvalidMessageException::class)
private fun validateDigest() {
if (validationAttempted) {
return
}
validationAttempted = true
val calculatedHash = digest.digest()
if (!MessageDigest.isEqual(calculatedHash, expectedHash)) {
throw InvalidMessageException("Calculated digest does not match expected hash!")
}
}
}

View File

@@ -22,11 +22,9 @@ import kotlin.math.max
class IncrementalMacAdditionalValidationsInputStream(
wrapped: InputStream,
fileLength: Long,
private val mac: Mac,
private val theirDigest: ByteArray
private val mac: Mac
) : FilterInputStream(wrapped) {
private val digest: MessageDigest = MessageDigest.getInstance("SHA256")
private val macLength: Int = mac.macLength
private val macBuffer: ByteArray = ByteArray(macLength)
@@ -77,8 +75,6 @@ class IncrementalMacAdditionalValidationsInputStream(
mac.update(buffer, offset, bytesRead)
}
digest.update(buffer, offset, bytesRead)
if (bytesRemaining == 0) {
validate()
}
@@ -113,10 +109,5 @@ class IncrementalMacAdditionalValidationsInputStream(
if (!MessageDigest.isEqual(ourMac, theirMac)) {
throw InvalidMessageException("MAC doesn't match!")
}
val ourDigest = digest.digest()
if (!MessageDigest.isEqual(ourDigest, theirDigest)) {
throw InvalidMessageException("Digest doesn't match!")
}
}
}

View File

@@ -8,9 +8,7 @@ import org.conscrypt.Conscrypt
import org.junit.Assert
import org.junit.Test
import org.signal.core.util.StreamUtil
import org.signal.core.util.allMatch
import org.signal.core.util.readFully
import org.signal.core.util.stream.LimitedInputStream
import org.signal.libsignal.protocol.InvalidMessageException
import org.signal.libsignal.protocol.incrementalmac.ChunkSizeChoice
import org.signal.libsignal.protocol.incrementalmac.InvalidMacException
@@ -27,6 +25,8 @@ import java.io.FileOutputStream
import java.io.InputStream
import java.io.OutputStream
import java.lang.AssertionError
import java.security.DigestInputStream
import java.security.MessageDigest
import java.security.Security
import java.util.Random
@@ -63,11 +63,10 @@ class AttachmentCipherTest {
val encryptResult = encryptData(plaintextInput, key, incremental)
val cipherFile = writeToFile(encryptResult.ciphertext)
val inputStream: LimitedInputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, encryptResult.digest, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice)
val inputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, encryptResult.digest, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice)
val plaintextOutput = inputStream.readFully(autoClose = false)
assertThat(plaintextOutput).isEqualTo(plaintextInput)
assertThat(inputStream.leftoverStream().allMatch { it == 0.toByte() }).isTrue()
cipherFile.delete()
}
@@ -89,11 +88,10 @@ class AttachmentCipherTest {
val encryptResult = encryptData(plaintextInput, key, incremental)
val cipherFile = writeToFile(encryptResult.ciphertext)
val inputStream: LimitedInputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, encryptResult.digest, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice)
val inputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, encryptResult.digest, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice)
val plaintextOutput = inputStream.readFully(autoClose = false)
Assert.assertArrayEquals(plaintextInput, plaintextOutput)
assertThat(inputStream.leftoverStream().allMatch { it == 0.toByte() }).isTrue()
cipherFile.delete()
}
@@ -234,20 +232,19 @@ class AttachmentCipherTest {
val cipherFile = writeToFile(outerEncryptResult.ciphertext)
val keyMaterial = createMediaKeyMaterial(outerKey)
val decryptedStream: LimitedInputStream = AttachmentCipherInputStream.createForArchivedMedia(
val decryptedStream = AttachmentCipherInputStream.createForArchivedMedia(
archivedMediaKeyMaterial = keyMaterial,
file = cipherFile,
originalCipherTextLength = innerEncryptResult.ciphertext.size.toLong(),
plaintextLength = plaintextInput.size.toLong(),
combinedKeyMaterial = innerKey,
digest = innerEncryptResult.digest,
plaintextHash = innerEncryptResult.plaintextHash,
incrementalDigest = innerEncryptResult.incrementalDigest,
incrementalMacChunkSize = innerEncryptResult.chunkSizeChoice
)
val plaintextOutput = decryptedStream.readFully(autoClose = false)
assertThat(plaintextOutput).isEqualTo(plaintextInput)
assertThat(decryptedStream.leftoverStream().allMatch { it == 0.toByte() }).isTrue()
cipherFile.delete()
}
@@ -285,7 +282,7 @@ class AttachmentCipherTest {
originalCipherTextLength = innerEncryptResult.ciphertext.size.toLong(),
plaintextLength = plaintextInput.size.toLong(),
combinedKeyMaterial = innerKey,
digest = innerEncryptResult.digest,
plaintextHash = innerEncryptResult.digest,
incrementalDigest = innerEncryptResult.incrementalDigest,
incrementalMacChunkSize = innerEncryptResult.chunkSizeChoice
)
@@ -299,17 +296,17 @@ class AttachmentCipherTest {
}
@Test
fun archiveInnerAndOuter_encryptDecrypt_nonIncremental() {
fun archive_encryptDecrypt_nonIncremental() {
archiveInnerAndOuter_encryptDecrypt(incremental = false, fileSize = MEBIBYTE)
}
@Test
fun archiveInnerAndOuter_encryptDecrypt_incremental() {
fun archive_encryptDecrypt_incremental() {
archiveInnerAndOuter_encryptDecrypt(incremental = true, fileSize = MEBIBYTE)
}
@Test
fun archiveInnerAndOuter_encryptDecrypt_nonIncremental_manyFileSizes() {
fun archive_encryptDecrypt_nonIncremental_manyFileSizes() {
for (i in 0..99) {
archiveInnerAndOuter_encryptDecrypt(incremental = false, fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024))
}
@@ -334,20 +331,19 @@ class AttachmentCipherTest {
val cipherFile = writeToFile(outerEncryptResult.ciphertext)
val keyMaterial = createMediaKeyMaterial(outerKey)
val decryptedStream: LimitedInputStream = AttachmentCipherInputStream.createForArchivedMedia(
val decryptedStream = AttachmentCipherInputStream.createForArchivedMedia(
archivedMediaKeyMaterial = keyMaterial,
file = cipherFile,
originalCipherTextLength = innerEncryptResult.ciphertext.size.toLong(),
plaintextLength = plaintextInput.size.toLong(),
combinedKeyMaterial = innerKey,
digest = innerEncryptResult.digest,
plaintextHash = innerEncryptResult.plaintextHash,
incrementalDigest = innerEncryptResult.incrementalDigest,
incrementalMacChunkSize = innerEncryptResult.chunkSizeChoice
)
val plaintextOutput = decryptedStream.readFully(autoClose = false)
assertThat(plaintextOutput).isEqualTo(plaintextInput)
assertThat(decryptedStream.leftoverStream().allMatch { it == 0.toByte() }).isTrue()
cipherFile.delete()
}
@@ -380,7 +376,7 @@ class AttachmentCipherTest {
originalCipherTextLength = innerEncryptResult.ciphertext.size.toLong(),
plaintextLength = plaintextInput.size.toLong(),
combinedKeyMaterial = innerKey,
digest = innerEncryptResult.digest,
plaintextHash = innerEncryptResult.digest,
incrementalDigest = innerEncryptResult.incrementalDigest,
incrementalMacChunkSize = innerEncryptResult.chunkSizeChoice
)
@@ -607,7 +603,14 @@ class AttachmentCipherTest {
cipherFile.delete()
}
private class EncryptResult(val ciphertext: ByteArray, val digest: ByteArray, val incrementalDigest: ByteArray, val chunkSizeChoice: Int)
private class EncryptResult(
val ciphertext: ByteArray,
val digest: ByteArray,
val incrementalDigest: ByteArray,
val chunkSizeChoice: Int,
val plaintextHash: ByteArray
)
companion object {
init {
// https://github.com/google/conscrypt/issues/1034
@@ -619,12 +622,16 @@ class AttachmentCipherTest {
private const val MEBIBYTE = 1024 * 1024
private fun encryptData(data: ByteArray, keyMaterial: ByteArray, withIncremental: Boolean, padded: Boolean = true): EncryptResult {
val digestingStream = DigestInputStream(ByteArrayInputStream(data), MessageDigest.getInstance("SHA-256"))
val actualData = if (padded) {
PaddingInputStream(ByteArrayInputStream(data), data.size.toLong()).readFully()
PaddingInputStream(digestingStream, data.size.toLong()).readFully()
} else {
data
digestingStream.readFully()
}
val plaintextHash = digestingStream.messageDigest.digest()
val outputStream = ByteArrayOutputStream()
val incrementalDigestOut = ByteArrayOutputStream()
val iv = Util.getSecretBytes(16)
@@ -643,7 +650,13 @@ class AttachmentCipherTest {
encryptStream.close()
incrementalDigestOut.close()
return EncryptResult(outputStream.toByteArray(), encryptStream.transmittedDigest, incrementalDigestOut.toByteArray(), sizeChoice.sizeInBytes)
return EncryptResult(
ciphertext = outputStream.toByteArray(),
digest = encryptStream.transmittedDigest,
incrementalDigest = incrementalDigestOut.toByteArray(),
chunkSizeChoice = sizeChoice.sizeInBytes,
plaintextHash = plaintextHash
)
}
private fun writeToFile(data: ByteArray): File {

View File

@@ -0,0 +1,207 @@
package org.whispersystems.signalservice.api.crypto
import assertk.assertThat
import assertk.assertions.isEqualTo
import assertk.assertions.isTrue
import assertk.fail
import org.junit.Test
import org.signal.core.util.readFully
import org.signal.libsignal.protocol.InvalidMessageException
import java.io.ByteArrayInputStream
import java.security.MessageDigest
class DigestValidatingInputStreamTest {
@Test
fun `success - read byte by byte`() {
val data = "Hello, World!".toByteArray()
val digest = MessageDigest.getInstance("SHA-256")
val expectedHash = digest.digest(data)
val inputStream = ByteArrayInputStream(data)
val digestEnforcingStream = DigestValidatingInputStream(
inputStream = inputStream,
digest = MessageDigest.getInstance("SHA-256"),
expectedHash = expectedHash
)
val result = ByteArray(data.size)
var i = 0
var byteRead: Byte = 0
while (digestEnforcingStream.read().also { byteRead = it.toByte() } != -1) {
result[i] = byteRead
i++
}
assertThat(result).isEqualTo(data)
assertThat(digestEnforcingStream.validationAttempted).isTrue()
digestEnforcingStream.close()
}
@Test
fun `success - read byte array`() {
val data = "Hello, World! This is a longer message to test buffer reading.".toByteArray()
val digest = MessageDigest.getInstance("SHA-256")
val expectedHash = digest.digest(data)
val inputStream = ByteArrayInputStream(data)
val digestEnforcingStream = DigestValidatingInputStream(
inputStream = inputStream,
digest = MessageDigest.getInstance("SHA-256"),
expectedHash = expectedHash
)
val result = digestEnforcingStream.readFully()
assertThat(result.size).isEqualTo(data.size)
assertThat(result).isEqualTo(data)
assertThat(digestEnforcingStream.validationAttempted).isTrue()
digestEnforcingStream.close()
}
@Test
fun `success - read byte array with offset and length`() {
val data = "This is test data for offset and length reading.".toByteArray()
val digest = MessageDigest.getInstance("SHA-256")
val expectedHash = digest.digest(data)
val inputStream = ByteArrayInputStream(data)
val digestEnforcingStream = DigestValidatingInputStream(
inputStream = inputStream,
digest = MessageDigest.getInstance("SHA-256"),
expectedHash = expectedHash
)
val buffer = ByteArray(1024)
var totalBytesRead = 0
var bytesRead: Int
while (digestEnforcingStream.read(buffer, totalBytesRead, 10).also { bytesRead = it } > 0) {
totalBytesRead += bytesRead
}
val result = buffer.copyOf(totalBytesRead)
assertThat(result).isEqualTo(data)
assertThat(digestEnforcingStream.validationAttempted).isTrue()
digestEnforcingStream.close()
}
@Test
fun `success - empty data`() {
val data = ByteArray(0)
val digest = MessageDigest.getInstance("SHA-256")
val expectedHash = digest.digest(data)
val inputStream = ByteArrayInputStream(data)
val digestEnforcingStream = DigestValidatingInputStream(
inputStream = inputStream,
digest = MessageDigest.getInstance("SHA-256"),
expectedHash = expectedHash
)
// Should immediately return -1 and validate
val endByte = digestEnforcingStream.read()
assertThat(endByte).isEqualTo(-1)
assertThat(digestEnforcingStream.validationAttempted).isTrue()
digestEnforcingStream.close()
}
@Test
fun `success - alternative digest, md5`() {
val data = "Testing MD5 hash validation".toByteArray()
val digest = MessageDigest.getInstance("MD5")
val expectedHash = digest.digest(data)
val inputStream = ByteArrayInputStream(data)
val digestEnforcingStream = DigestValidatingInputStream(
inputStream = inputStream,
digest = MessageDigest.getInstance("MD5"),
expectedHash = expectedHash
)
val result = digestEnforcingStream.readFully()
assertThat(result).isEqualTo(data)
assertThat(digestEnforcingStream.validationAttempted).isTrue()
digestEnforcingStream.close()
}
@Test
fun `success - multiple reads after close`() {
val data = "Test multiple validation calls".toByteArray()
val digest = MessageDigest.getInstance("SHA-256")
val expectedHash = digest.digest(data)
val inputStream = ByteArrayInputStream(data)
val digestEnforcingStream = DigestValidatingInputStream(
inputStream = inputStream,
digest = MessageDigest.getInstance("SHA-256"),
expectedHash = expectedHash
)
val result = digestEnforcingStream.readFully()
// Multiple calls to read() after EOF should not cause issues
assertThat(digestEnforcingStream.read()).isEqualTo(-1)
assertThat(digestEnforcingStream.read()).isEqualTo(-1)
assertThat(digestEnforcingStream.read()).isEqualTo(-1)
assertThat(result).isEqualTo(data)
assertThat(digestEnforcingStream.validationAttempted).isTrue()
digestEnforcingStream.close()
}
@Test
fun `failure - read byte by byte`() {
val data = "Hello, World!".toByteArray()
val wrongHash = ByteArray(32) // All zeros - wrong hash
val inputStream = ByteArrayInputStream(data)
val digestEnforcingStream = DigestValidatingInputStream(
inputStream = inputStream,
digest = MessageDigest.getInstance("SHA-256"),
expectedHash = wrongHash
)
try {
while (digestEnforcingStream.read() != -1) {
// Reading byte by byte
}
fail("Expected InvalidCiphertextException to be thrown")
} catch (e: InvalidMessageException) {
// Expected exception
} finally {
digestEnforcingStream.close()
}
}
@Test
fun `failure - read byte array`() {
val data = "Hello, World! This is a test message.".toByteArray()
val wrongHash = ByteArray(32) // All zeros - wrong hash
val inputStream = ByteArrayInputStream(data)
val digestEnforcingStream = DigestValidatingInputStream(
inputStream = inputStream,
digest = MessageDigest.getInstance("SHA-256"),
expectedHash = wrongHash
)
try {
digestEnforcingStream.readFully()
fail("Expected InvalidCiphertextException to be thrown")
} catch (e: InvalidMessageException) {
// Expected exception
} finally {
digestEnforcingStream.close()
}
}
}