Allow normal attachments to be validated with plaintextHashes.

This commit is contained in:
Greyson Parrelli
2025-06-23 12:13:30 -04:00
committed by Cody Henthorne
parent 607b83d65b
commit ec5452744d
23 changed files with 319 additions and 185 deletions

View File

@@ -8,7 +8,6 @@ import androidx.test.platform.app.InstrumentationRegistry
import assertk.assertThat import assertk.assertThat
import assertk.assertions.isEqualTo import assertk.assertions.isEqualTo
import assertk.assertions.isNotEqualTo import assertk.assertions.isNotEqualTo
import org.junit.Assert.assertArrayEquals
import org.junit.Assert.assertEquals import org.junit.Assert.assertEquals
import org.junit.Assert.assertNotEquals import org.junit.Assert.assertNotEquals
import org.junit.Before import org.junit.Before
@@ -17,7 +16,6 @@ import org.junit.Test
import org.junit.runner.RunWith import org.junit.runner.RunWith
import org.signal.core.util.Base64.decodeBase64OrThrow import org.signal.core.util.Base64.decodeBase64OrThrow
import org.signal.core.util.copyTo import org.signal.core.util.copyTo
import org.signal.core.util.readFully
import org.signal.core.util.stream.NullOutputStream import org.signal.core.util.stream.NullOutputStream
import org.thoughtcrime.securesms.attachments.Attachment import org.thoughtcrime.securesms.attachments.Attachment
import org.thoughtcrime.securesms.attachments.AttachmentId import org.thoughtcrime.securesms.attachments.AttachmentId
@@ -27,13 +25,10 @@ import org.thoughtcrime.securesms.mms.MediaStream
import org.thoughtcrime.securesms.mms.SentMediaQuality import org.thoughtcrime.securesms.mms.SentMediaQuality
import org.thoughtcrime.securesms.providers.BlobProvider import org.thoughtcrime.securesms.providers.BlobProvider
import org.thoughtcrime.securesms.util.MediaUtil import org.thoughtcrime.securesms.util.MediaUtil
import org.thoughtcrime.securesms.util.Util
import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream
import org.whispersystems.signalservice.api.crypto.AttachmentCipherOutputStream import org.whispersystems.signalservice.api.crypto.AttachmentCipherOutputStream
import org.whispersystems.signalservice.api.crypto.NoCipherOutputStream import org.whispersystems.signalservice.api.crypto.NoCipherOutputStream
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentPointer import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentPointer
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentRemoteId import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentRemoteId
import org.whispersystems.signalservice.internal.crypto.PaddingInputStream
import java.io.ByteArrayOutputStream import java.io.ByteArrayOutputStream
import java.io.File import java.io.File
import java.util.Optional import java.util.Optional
@@ -182,32 +177,6 @@ class AttachmentTableTest {
assertThat(highInfo.file.exists()).isEqualTo(true) assertThat(highInfo.file.exists()).isEqualTo(true)
} }
@Test
fun finalizeAttachmentAfterDownload_leaveDigestAloneForAllZeroPadding() {
// Insert attachment metadata for properly-padded attachment
val plaintext = byteArrayOf(1, 2, 3, 4)
val key = Util.getSecretBytes(64)
val iv = Util.getSecretBytes(16)
val paddedPlaintext = PaddingInputStream(plaintext.inputStream(), plaintext.size.toLong()).readFully()
val ciphertext = encryptPrePaddedBytes(paddedPlaintext, key, iv)
val digest = getDigest(ciphertext)
val cipherFile = getTempFile()
cipherFile.writeBytes(ciphertext)
val mmsId = -1L
val attachmentId = SignalDatabase.attachments.insertAttachmentsForMessage(mmsId, listOf(createAttachmentPointer(key, digest, plaintext.size)), emptyList()).values.first()
// Give data to attachment table
val cipherInputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintext.size.toLong(), key, digest, null, 4)
SignalDatabase.attachments.finalizeAttachmentAfterDownload(mmsId, attachmentId, cipherInputStream)
// Verify the digest hasn't changed
val newDigest = SignalDatabase.attachments.getAttachment(attachmentId)!!.remoteDigest!!
assertArrayEquals(digest, newDigest)
}
@Test @Test
fun resetArchiveTransferStateByPlaintextHashAndRemoteKey_singleMatch() { fun resetArchiveTransferStateByPlaintextHashAndRemoteKey_singleMatch() {
// Given an attachment with some plaintextHash+remoteKey // Given an attachment with some plaintextHash+remoteKey

View File

@@ -110,6 +110,7 @@ import org.thoughtcrime.securesms.util.TextSecurePreferences;
import org.thoughtcrime.securesms.util.Util; import org.thoughtcrime.securesms.util.Util;
import org.thoughtcrime.securesms.util.VersionTracker; import org.thoughtcrime.securesms.util.VersionTracker;
import org.thoughtcrime.securesms.util.dynamiclanguage.DynamicLanguageContextWrapper; import org.thoughtcrime.securesms.util.dynamiclanguage.DynamicLanguageContextWrapper;
import org.whispersystems.signalservice.api.backup.MediaName;
import org.whispersystems.signalservice.api.websocket.SignalWebSocket; import org.whispersystems.signalservice.api.websocket.SignalWebSocket;
import java.io.InterruptedIOException; import java.io.InterruptedIOException;

View File

@@ -19,6 +19,7 @@ import okio.ByteString
import okio.ByteString.Companion.toByteString import okio.ByteString.Companion.toByteString
import org.greenrobot.eventbus.EventBus import org.greenrobot.eventbus.EventBus
import org.signal.core.util.Base64 import org.signal.core.util.Base64
import org.signal.core.util.Base64.decodeBase64OrThrow
import org.signal.core.util.ByteSize import org.signal.core.util.ByteSize
import org.signal.core.util.CursorUtil import org.signal.core.util.CursorUtil
import org.signal.core.util.EventTimer import org.signal.core.util.EventTimer
@@ -37,7 +38,7 @@ import org.signal.core.util.getForeignKeyViolations
import org.signal.core.util.logging.Log import org.signal.core.util.logging.Log
import org.signal.core.util.money.FiatMoney import org.signal.core.util.money.FiatMoney
import org.signal.core.util.requireIntOrNull import org.signal.core.util.requireIntOrNull
import org.signal.core.util.requireNonNullBlob import org.signal.core.util.requireNonNullString
import org.signal.core.util.stream.NonClosingOutputStream import org.signal.core.util.stream.NonClosingOutputStream
import org.signal.core.util.urlEncode import org.signal.core.util.urlEncode
import org.signal.core.util.withinTransaction import org.signal.core.util.withinTransaction
@@ -1290,11 +1291,11 @@ object BackupRepository {
return initBackupAndFetchAuth() return initBackupAndFetchAuth()
.then { credential -> .then { credential ->
SignalNetwork.archive.getMessageBackupUploadForm(SignalStore.account.requireAci(), credential.messageBackupAccess) SignalNetwork.archive.getMessageBackupUploadForm(SignalStore.account.requireAci(), credential.messageBackupAccess)
.also { Log.i(TAG, "UploadFormResult: $it") } .also { Log.i(TAG, "UploadFormResult: ${it::class.simpleName}") }
} }
.then { form -> .then { form ->
SignalNetwork.archive.getBackupResumableUploadUrl(form) SignalNetwork.archive.getBackupResumableUploadUrl(form)
.also { Log.i(TAG, "ResumableUploadUrlResult: $it") } .also { Log.i(TAG, "ResumableUploadUrlResult: ${it::class.simpleName}") }
.map { ResumableMessagesBackupUploadSpec(attachmentUploadForm = form, resumableUri = it) } .map { ResumableMessagesBackupUploadSpec(attachmentUploadForm = form, resumableUri = it) }
} }
} }
@@ -1307,7 +1308,7 @@ object BackupRepository {
): NetworkResult<Unit> { ): NetworkResult<Unit> {
val (form, resumableUploadUrl) = resumableSpec val (form, resumableUploadUrl) = resumableSpec
return SignalNetwork.archive.uploadBackupFile(form, resumableUploadUrl, backupStream, backupStreamLength, progressListener) return SignalNetwork.archive.uploadBackupFile(form, resumableUploadUrl, backupStream, backupStreamLength, progressListener)
.also { Log.i(TAG, "UploadBackupFileResult: $it") } .also { Log.i(TAG, "UploadBackupFileResult: ${it::class.simpleName}") }
} }
fun downloadBackupFile(destination: File, listener: ProgressListener? = null): NetworkResult<Unit> { fun downloadBackupFile(destination: File, listener: ProgressListener? = null): NetworkResult<Unit> {
@@ -1395,7 +1396,7 @@ object BackupRepository {
.map { response -> .map { response ->
SignalDatabase.attachments.setArchiveCdn(attachmentId = attachment.attachmentId, archiveCdn = response.cdn) SignalDatabase.attachments.setArchiveCdn(attachmentId = attachment.attachmentId, archiveCdn = response.cdn)
} }
.also { Log.i(TAG, "archiveMediaResult: $it") } .also { Log.i(TAG, "archiveMediaResult: ${it::class.simpleName}") }
} }
fun deleteAbandonedMediaObjects(mediaObjects: Collection<ArchivedMediaObject>): NetworkResult<Unit> { fun deleteAbandonedMediaObjects(mediaObjects: Collection<ArchivedMediaObject>): NetworkResult<Unit> {
@@ -1421,7 +1422,7 @@ object BackupRepository {
mediaToDelete = mediaToDelete mediaToDelete = mediaToDelete
) )
} }
.also { Log.i(TAG, "deleteAbandonedMediaObjectsResult: $it") } .also { Log.i(TAG, "deleteAbandonedMediaObjectsResult: ${it::class.simpleName}") }
} }
fun deleteBackup(): NetworkResult<Unit> { fun deleteBackup(): NetworkResult<Unit> {
@@ -1477,7 +1478,7 @@ object BackupRepository {
.map { .map {
SignalDatabase.attachments.clearAllArchiveData() SignalDatabase.attachments.clearAllArchiveData()
} }
.also { Log.i(TAG, "debugDeleteAllArchivedMediaResult: $it") } .also { Log.i(TAG, "debugDeleteAllArchivedMediaResult: ${it::class.simpleName}") }
} }
/** /**
@@ -1512,7 +1513,7 @@ object BackupRepository {
credentialStore.cdnReadCredentials = it.result credentialStore.cdnReadCredentials = it.result
} }
} }
.also { Log.i(TAG, "getCdnReadCredentialsResult: $it") } .also { Log.i(TAG, "getCdnReadCredentialsResult: ${it::class.simpleName}") }
} }
fun restoreBackupTier(aci: ACI): MessageBackupTier? { fun restoreBackupTier(aci: ACI): MessageBackupTier? {
@@ -1965,8 +1966,8 @@ class ArchiveMediaItemIterator(private val cursor: Cursor) : Iterator<ArchiveMed
override fun hasNext(): Boolean = !cursor.isAfterLast override fun hasNext(): Boolean = !cursor.isAfterLast
override fun next(): ArchiveMediaItem { override fun next(): ArchiveMediaItem {
val plaintextHash = cursor.requireNonNullBlob(AttachmentTable.DATA_HASH_END) val plaintextHash = cursor.requireNonNullString(AttachmentTable.DATA_HASH_END).decodeBase64OrThrow()
val remoteKey = cursor.requireNonNullBlob(AttachmentTable.REMOTE_KEY) val remoteKey = cursor.requireNonNullString(AttachmentTable.REMOTE_KEY).decodeBase64OrThrow()
val cdn = cursor.requireIntOrNull(AttachmentTable.ARCHIVE_CDN) val cdn = cursor.requireIntOrNull(AttachmentTable.ARCHIVE_CDN)
val mediaId = MediaName.fromPlaintextHashAndRemoteKey(plaintextHash, remoteKey).toMediaId(SignalStore.backup.mediaRootBackupKey).encode() val mediaId = MediaName.fromPlaintextHashAndRemoteKey(plaintextHash, remoteKey).toMediaId(SignalStore.backup.mediaRootBackupKey).encode()

View File

@@ -25,7 +25,6 @@ import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton import androidx.compose.material3.IconButton
import androidx.compose.material3.MaterialTheme import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.RadioButton
import androidx.compose.material3.Scaffold import androidx.compose.material3.Scaffold
import androidx.compose.material3.SnackbarHostState import androidx.compose.material3.SnackbarHostState
import androidx.compose.material3.Surface import androidx.compose.material3.Surface
@@ -77,6 +76,8 @@ import org.thoughtcrime.securesms.backup.v2.ui.BackupAlertBottomSheet
import org.thoughtcrime.securesms.components.settings.app.internal.backup.InternalBackupPlaygroundViewModel.DialogState import org.thoughtcrime.securesms.components.settings.app.internal.backup.InternalBackupPlaygroundViewModel.DialogState
import org.thoughtcrime.securesms.components.settings.app.internal.backup.InternalBackupPlaygroundViewModel.ScreenState import org.thoughtcrime.securesms.components.settings.app.internal.backup.InternalBackupPlaygroundViewModel.ScreenState
import org.thoughtcrime.securesms.compose.ComposeFragment import org.thoughtcrime.securesms.compose.ComposeFragment
import org.thoughtcrime.securesms.dependencies.AppDependencies
import org.thoughtcrime.securesms.jobs.ArchiveAttachmentReconciliationJob
import org.thoughtcrime.securesms.jobs.LocalBackupJob import org.thoughtcrime.securesms.jobs.LocalBackupJob
import org.thoughtcrime.securesms.keyvalue.SignalStore import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.thoughtcrime.securesms.util.Util import org.thoughtcrime.securesms.util.Util
@@ -149,9 +150,9 @@ class InternalBackupPlaygroundFragment : ComposeFragment() {
mainContent = { mainContent = {
Screen( Screen(
state = state, state = state,
onBackupTierSelected = { tier -> viewModel.onBackupTierSelected(tier) },
onCheckRemoteBackupStateClicked = { viewModel.checkRemoteBackupState() }, onCheckRemoteBackupStateClicked = { viewModel.checkRemoteBackupState() },
onEnqueueRemoteBackupClicked = { viewModel.triggerBackupJob() }, onEnqueueRemoteBackupClicked = { viewModel.triggerBackupJob() },
onEnqueueReconciliationClicked = { AppDependencies.jobManager.add(ArchiveAttachmentReconciliationJob(forced = true)) },
onHaltAllBackupJobsClicked = { viewModel.haltAllJobs() }, onHaltAllBackupJobsClicked = { viewModel.haltAllJobs() },
onValidateBackupClicked = { viewModel.validateBackup() }, onValidateBackupClicked = { viewModel.validateBackup() },
onSaveEncryptedBackupToDiskClicked = { onSaveEncryptedBackupToDiskClicked = {
@@ -222,7 +223,7 @@ class InternalBackupPlaygroundFragment : ComposeFragment() {
onDeleteRemoteBackup = { onDeleteRemoteBackup = {
MaterialAlertDialogBuilder(context) MaterialAlertDialogBuilder(context)
.setTitle("Are you sure?") .setTitle("Are you sure?")
.setMessage("This will delete all of your remote backup data?") .setMessage("This will delete all of your remote backup data!")
.setPositiveButton("Delete remote data") { _, _ -> .setPositiveButton("Delete remote data") { _, _ ->
lifecycleScope.launch { lifecycleScope.launch {
val success = viewModel.deleteRemoteBackupData() val success = viewModel.deleteRemoteBackupData()
@@ -234,6 +235,21 @@ class InternalBackupPlaygroundFragment : ComposeFragment() {
.setNegativeButton("Cancel", null) .setNegativeButton("Cancel", null)
.show() .show()
}, },
onClearLocalMediaBackupState = {
MaterialAlertDialogBuilder(context)
.setTitle("Are you sure?")
.setMessage("This will cause you to have to re-upload all of your media!")
.setPositiveButton("Clear local media state") { _, _ ->
lifecycleScope.launch {
viewModel.clearLocalMediaBackupState()
withContext(Dispatchers.Main) {
Toast.makeText(requireContext(), "Done!", Toast.LENGTH_SHORT).show()
}
}
}
.setNegativeButton("Cancel", null)
.show()
},
onDisplayInitialBackupFailureSheet = { onDisplayInitialBackupFailureSheet = {
BackupRepository.displayInitialBackupFailureNotification() BackupRepository.displayInitialBackupFailureNotification()
BackupAlertBottomSheet BackupAlertBottomSheet
@@ -312,8 +328,8 @@ fun Screen(
onImportNewStyleLocalBackupClicked: () -> Unit = {}, onImportNewStyleLocalBackupClicked: () -> Unit = {},
onCheckRemoteBackupStateClicked: () -> Unit = {}, onCheckRemoteBackupStateClicked: () -> Unit = {},
onEnqueueRemoteBackupClicked: () -> Unit = {}, onEnqueueRemoteBackupClicked: () -> Unit = {},
onEnqueueReconciliationClicked: () -> Unit = {},
onWipeDataAndRestoreFromRemoteClicked: () -> Unit = {}, onWipeDataAndRestoreFromRemoteClicked: () -> Unit = {},
onBackupTierSelected: (MessageBackupTier?) -> Unit = {},
onHaltAllBackupJobsClicked: () -> Unit = {}, onHaltAllBackupJobsClicked: () -> Unit = {},
onSavePlaintextCopyOfRemoteBackupClicked: () -> Unit = {}, onSavePlaintextCopyOfRemoteBackupClicked: () -> Unit = {},
onValidateBackupClicked: () -> Unit = {}, onValidateBackupClicked: () -> Unit = {},
@@ -322,6 +338,7 @@ fun Screen(
onImportEncryptedBackupFromDiskClicked: () -> Unit = {}, onImportEncryptedBackupFromDiskClicked: () -> Unit = {},
onImportEncryptedBackupFromDiskDismissed: () -> Unit = {}, onImportEncryptedBackupFromDiskDismissed: () -> Unit = {},
onImportEncryptedBackupFromDiskConfirmed: (aci: String, backupKey: String) -> Unit = { _, _ -> }, onImportEncryptedBackupFromDiskConfirmed: (aci: String, backupKey: String) -> Unit = { _, _ -> },
onClearLocalMediaBackupState: () -> Unit = {},
onDeleteRemoteBackup: () -> Unit = {}, onDeleteRemoteBackup: () -> Unit = {},
onDisplayInitialBackupFailureSheet: () -> Unit = {} onDisplayInitialBackupFailureSheet: () -> Unit = {}
) { ) {
@@ -353,21 +370,6 @@ fun Screen(
.fillMaxSize() .fillMaxSize()
.verticalScroll(scrollState) .verticalScroll(scrollState)
) { ) {
Row(verticalAlignment = Alignment.CenterVertically) {
Text("Tier", fontWeight = FontWeight.Bold)
options.forEach { option ->
Row(verticalAlignment = Alignment.CenterVertically) {
RadioButton(
selected = option.value == state.backupTier,
onClick = { onBackupTierSelected(option.value) }
)
Text(option.key)
}
}
}
Dividers.Default()
Rows.TextRow( Rows.TextRow(
text = { text = {
Text( Text(
@@ -392,6 +394,12 @@ fun Screen(
onClick = onEnqueueRemoteBackupClicked onClick = onEnqueueRemoteBackupClicked
) )
Rows.TextRow(
text = "Enqueue reconciliation job",
label = "Schedules a job that will ensure local and remote media state are in sync.",
onClick = onEnqueueReconciliationClicked
)
Rows.TextRow( Rows.TextRow(
text = "Halt all backup jobs", text = "Halt all backup jobs",
label = "Stops all backup-related jobs to the best of our ability.", label = "Stops all backup-related jobs to the best of our ability.",
@@ -513,6 +521,12 @@ fun Screen(
onClick = onDeleteRemoteBackup onClick = onDeleteRemoteBackup
) )
Rows.TextRow(
text = "Clear local media backup state",
label = "Resets local state tracking so you think you haven't uploaded any media. The media still exists on the server.",
onClick = onClearLocalMediaBackupState
)
Dividers.Default() Dividers.Default()
Rows.TextRow( Rows.TextRow(

View File

@@ -292,11 +292,6 @@ class InternalBackupPlaygroundViewModel : ViewModel() {
} }
} }
fun onBackupTierSelected(backupTier: MessageBackupTier?) {
SignalStore.backup.backupTier = backupTier
_state.value = _state.value.copy(backupTier = backupTier)
}
fun onImportSelected() { fun onImportSelected() {
_state.value = _state.value.copy(dialog = DialogState.ImportCredentials) _state.value = _state.value.copy(dialog = DialogState.ImportCredentials)
} }
@@ -398,6 +393,10 @@ class InternalBackupPlaygroundViewModel : ViewModel() {
return@withContext false return@withContext false
} }
suspend fun clearLocalMediaBackupState() = withContext(Dispatchers.IO) {
SignalDatabase.attachments.clearAllArchiveData()
}
override fun onCleared() { override fun onCleared() {
disposables.clear() disposables.clear()
} }

View File

@@ -55,7 +55,7 @@ fun InternalBackupStatsTab(stats: InternalBackupPlaygroundViewModel.StatsState,
Spacer(modifier = Modifier.size(16.dp)) Spacer(modifier = Modifier.size(16.dp))
Text(text = "Unique/archived data files: ${stats.attachmentStats.attachmentFileCount}/${stats.attachmentStats.finishedAttachmentFileCount}") Text(text = "Unique/archived data files: ${stats.attachmentStats.attachmentFileCount}/${stats.attachmentStats.finishedAttachmentFileCount}")
Text(text = "Unique/archived verified digest count: ${stats.attachmentStats.attachmentPlaintextHashAndKeyCount}/${stats.attachmentStats.finishedAttachmentPlaintextHashAndKeyCount}") Text(text = "Unique/archived verified plaintextHash count: ${stats.attachmentStats.attachmentPlaintextHashAndKeyCount}/${stats.attachmentStats.finishedAttachmentPlaintextHashAndKeyCount}")
Text(text = "Unique/expected thumbnail files: ${stats.attachmentStats.thumbnailFileCount}/${stats.attachmentStats.estimatedThumbnailCount}") Text(text = "Unique/expected thumbnail files: ${stats.attachmentStats.thumbnailFileCount}/${stats.attachmentStats.estimatedThumbnailCount}")
Text(text = "Local Total: ${stats.attachmentStats.attachmentFileCount + stats.attachmentStats.thumbnailFileCount}") Text(text = "Local Total: ${stats.attachmentStats.attachmentFileCount + stats.attachmentStats.thumbnailFileCount}")
Text(text = "Expected remote total: ${stats.attachmentStats.estimatedThumbnailCount + stats.attachmentStats.finishedAttachmentPlaintextHashAndKeyCount}") Text(text = "Expected remote total: ${stats.attachmentStats.estimatedThumbnailCount + stats.attachmentStats.finishedAttachmentPlaintextHashAndKeyCount}")

View File

@@ -679,15 +679,15 @@ class AttachmentTable(
/** /**
* Sets the archive transfer state for the given attachment by digest. * Sets the archive transfer state for the given attachment by digest.
*/ */
fun resetArchiveTransferStateByPlaintextHashAndRemoteKey(plaintextHash: ByteArray, remoteKey: ByteArray) { fun resetArchiveTransferStateByPlaintextHashAndRemoteKey(plaintextHash: ByteArray, remoteKey: ByteArray): Boolean {
writableDatabase return writableDatabase
.update(TABLE_NAME) .update(TABLE_NAME)
.values( .values(
ARCHIVE_TRANSFER_STATE to ArchiveTransferState.NONE.value, ARCHIVE_TRANSFER_STATE to ArchiveTransferState.NONE.value,
ARCHIVE_CDN to null ARCHIVE_CDN to null
) )
.where("$DATA_HASH_END = ? AND $REMOTE_KEY = ?", Base64.encodeWithPadding(plaintextHash), Base64.encodeWithPadding(remoteKey)) .where("$DATA_HASH_END = ? AND $REMOTE_KEY = ?", Base64.encodeWithPadding(plaintextHash), Base64.encodeWithPadding(remoteKey))
.run() .run() > 0
} }
/** /**
@@ -2593,11 +2593,11 @@ class AttachmentTable(
.associate { it to readableDatabase.count().from(TABLE_NAME).where("$ARCHIVE_TRANSFER_STATE = ${it.value} AND $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL").run().readToSingleLong(-1L) } .associate { it to readableDatabase.count().from(TABLE_NAME).where("$ARCHIVE_TRANSFER_STATE = ${it.value} AND $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL").run().readToSingleLong(-1L) }
val attachmentFileCount = readableDatabase.query("SELECT COUNT(DISTINCT $DATA_FILE) FROM $TABLE_NAME WHERE $DATA_FILE NOT NULL AND $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL").readToSingleLong(-1L) val attachmentFileCount = readableDatabase.query("SELECT COUNT(DISTINCT $DATA_FILE) FROM $TABLE_NAME WHERE $DATA_FILE NOT NULL AND $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL").readToSingleLong(-1L)
val finishedAttachmentFileCount = readableDatabase.query("SELECT COUNT(DISTINCT $DATA_FILE) FROM $TABLE_NAME WHERE $DATA_FILE NOT NULL AND $DATA_HASH_END NOT NULL $REMOTE_KEY NOT NULL AND $ARCHIVE_TRANSFER_STATE = ${ArchiveTransferState.FINISHED.value}").readToSingleLong(-1L) val finishedAttachmentFileCount = readableDatabase.query("SELECT COUNT(DISTINCT $DATA_FILE) FROM $TABLE_NAME WHERE $DATA_FILE NOT NULL AND $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL AND $ARCHIVE_TRANSFER_STATE = ${ArchiveTransferState.FINISHED.value}").readToSingleLong(-1L)
val attachmentPlaintextHashAndKeyCount = readableDatabase.query("SELECT COUNT(*) FROM (SELECT DISTINCT $DATA_HASH_END, $REMOTE_KEY FROM $TABLE_NAME WHERE $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL AND $TRANSFER_STATE in ($TRANSFER_PROGRESS_DONE, $TRANSFER_RESTORE_OFFLOADED, $TRANSFER_RESTORE_IN_PROGRESS, $TRANSFER_NEEDS_RESTORE))").readToSingleLong(-1L) val attachmentPlaintextHashAndKeyCount = readableDatabase.query("SELECT COUNT(*) FROM (SELECT DISTINCT $DATA_HASH_END, $REMOTE_KEY FROM $TABLE_NAME WHERE $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL AND $TRANSFER_STATE in ($TRANSFER_PROGRESS_DONE, $TRANSFER_RESTORE_OFFLOADED, $TRANSFER_RESTORE_IN_PROGRESS, $TRANSFER_NEEDS_RESTORE))").readToSingleLong(-1L)
val finishedAttachmentDigestCount = readableDatabase.query("SELECT COUNT(*) FROM (SELECT DISTINCT $DATA_HASH_END, $REMOTE_KEY) FROM $TABLE_NAME WHERE $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL AND $ARCHIVE_TRANSFER_STATE = ${ArchiveTransferState.FINISHED.value})").readToSingleLong(-1L) val finishedAttachmentDigestCount = readableDatabase.query("SELECT COUNT(*) FROM (SELECT DISTINCT $DATA_HASH_END, $REMOTE_KEY FROM $TABLE_NAME WHERE $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL AND $ARCHIVE_TRANSFER_STATE = ${ArchiveTransferState.FINISHED.value})").readToSingleLong(-1L)
val thumbnailFileCount = readableDatabase.query("SELECT COUNT(DISTINCT $THUMBNAIL_FILE) FROM $TABLE_NAME WHERE $THUMBNAIL_FILE IS NOT NULL").readToSingleLong(-1L) val thumbnailFileCount = readableDatabase.query("SELECT COUNT(DISTINCT $THUMBNAIL_FILE) FROM $TABLE_NAME WHERE $THUMBNAIL_FILE IS NOT NULL").readToSingleLong(-1L)
val estimatedThumbnailCount = readableDatabase.query("SELECT COUNT(*) FROM (SELECT DISTINCT $DATA_HASH_END, $REMOTE_KEY) FROM $TABLE_NAME WHERE $ARCHIVE_TRANSFER_STATE = ${ArchiveTransferState.FINISHED.value} AND $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL AND ($CONTENT_TYPE LIKE 'image/%' OR $CONTENT_TYPE LIKE 'video/%'))").readToSingleLong(-1L) val estimatedThumbnailCount = readableDatabase.query("SELECT COUNT(*) FROM (SELECT DISTINCT $DATA_HASH_END, $REMOTE_KEY FROM $TABLE_NAME WHERE $ARCHIVE_TRANSFER_STATE = ${ArchiveTransferState.FINISHED.value} AND $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL AND ($CONTENT_TYPE LIKE 'image/%' OR $CONTENT_TYPE LIKE 'video/%'))").readToSingleLong(-1L)
val pendingUploadBytes = getPendingArchiveUploadBytes() val pendingUploadBytes = getPendingArchiveUploadBytes()
val uploadedAttachmentBytes = readableDatabase val uploadedAttachmentBytes = readableDatabase

View File

@@ -127,7 +127,13 @@ class ArchiveAttachmentReconciliationJob private constructor(
val entry = BackupMediaSnapshotTable.MediaEntry.fromCursor(it) val entry = BackupMediaSnapshotTable.MediaEntry.fromCursor(it)
// TODO [backup] Re-enqueue thumbnail uploads if necessary // TODO [backup] Re-enqueue thumbnail uploads if necessary
if (!entry.isThumbnail) { if (!entry.isThumbnail) {
SignalDatabase.attachments.resetArchiveTransferStateByPlaintextHashAndRemoteKey(entry.plaintextHash, entry.remoteKey) val success = SignalDatabase.attachments.resetArchiveTransferStateByPlaintextHashAndRemoteKey(entry.plaintextHash, entry.remoteKey)
if (!success) {
Log.e(TAG, "Failed to reset archive transfer state by remote hash/key!")
if (RemoteConfig.internalUser) {
throw RuntimeException("Failed to reset archive transfer state by remote hash/key!")
}
}
} }
} }

View File

@@ -85,8 +85,8 @@ class ArchiveThumbnailUploadJob private constructor(
return Result.success() return Result.success()
} }
if (attachment.remoteDigest == null) { if (attachment.remoteDigest == null && attachment.dataHash == null) {
Log.w(TAG, "$attachmentId was never uploaded! Cannot proceed.") Log.w(TAG, "$attachmentId has no integrity check! Cannot proceed.")
return Result.success() return Result.success()
} }
@@ -153,7 +153,7 @@ class ArchiveThumbnailUploadJob private constructor(
data = thumbnailResult.data data = thumbnailResult.data
) )
Log.d(TAG, "Successfully archived thumbnail for $attachmentId mediaName=${attachment.requireThumbnailMediaName()}") Log.d(TAG, "Successfully archived thumbnail for $attachmentId")
Result.success() Result.success()
} }

View File

@@ -35,6 +35,7 @@ import org.thoughtcrime.securesms.transport.RetryLaterException
import org.thoughtcrime.securesms.util.AttachmentUtil import org.thoughtcrime.securesms.util.AttachmentUtil
import org.thoughtcrime.securesms.util.RemoteConfig import org.thoughtcrime.securesms.util.RemoteConfig
import org.thoughtcrime.securesms.util.Util import org.thoughtcrime.securesms.util.Util
import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream.IntegrityCheck
import org.whispersystems.signalservice.api.messages.AttachmentTransferProgress import org.whispersystems.signalservice.api.messages.AttachmentTransferProgress
import org.whispersystems.signalservice.api.messages.SignalServiceAttachment import org.whispersystems.signalservice.api.messages.SignalServiceAttachment
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentPointer import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentPointer
@@ -277,12 +278,18 @@ class AttachmentDownloadJob private constructor(
} }
} }
if (attachment.remoteDigest == null && attachment.dataHash == null) {
Log.w(TAG, "Attachment has no integrity check!")
throw InvalidAttachmentException("Attachment has no integrity check!")
}
val decryptingStream = AppDependencies val decryptingStream = AppDependencies
.signalServiceMessageReceiver .signalServiceMessageReceiver
.retrieveAttachment( .retrieveAttachment(
pointer, pointer,
attachmentFile, attachmentFile,
maxReceiveSize, maxReceiveSize,
IntegrityCheck.forEncryptedDigestAndPlaintextHash(attachment.remoteDigest, attachment.dataHash),
progressListener progressListener
) )

View File

@@ -16,6 +16,8 @@ import org.thoughtcrime.securesms.jobmanager.impl.NetworkConstraint;
import org.thoughtcrime.securesms.profiles.AvatarHelper; import org.thoughtcrime.securesms.profiles.AvatarHelper;
import org.signal.core.util.Hex; import org.signal.core.util.Hex;
import org.whispersystems.signalservice.api.SignalServiceMessageReceiver; import org.whispersystems.signalservice.api.SignalServiceMessageReceiver;
import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream;
import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream.IntegrityCheck;
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentPointer; import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentPointer;
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentRemoteId; import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentRemoteId;
import org.whispersystems.signalservice.api.push.exceptions.MissingConfigurationException; import org.whispersystems.signalservice.api.push.exceptions.MissingConfigurationException;
@@ -84,9 +86,16 @@ public final class AvatarGroupsV1DownloadJob extends BaseJob {
attachment = File.createTempFile("avatar", "tmp", context.getCacheDir()); attachment = File.createTempFile("avatar", "tmp", context.getCacheDir());
attachment.deleteOnExit(); attachment.deleteOnExit();
SignalServiceMessageReceiver receiver = AppDependencies.getSignalServiceMessageReceiver(); SignalServiceMessageReceiver receiver = AppDependencies.getSignalServiceMessageReceiver();
SignalServiceAttachmentPointer pointer = new SignalServiceAttachmentPointer(0, new SignalServiceAttachmentRemoteId.V2(avatarId), contentType, key, Optional.of(0), Optional.empty(), 0, 0, digest, Optional.empty(), 0, fileName, false, false, false, Optional.empty(), Optional.empty(), System.currentTimeMillis(), null); SignalServiceAttachmentPointer pointer = new SignalServiceAttachmentPointer(0, new SignalServiceAttachmentRemoteId.V2(avatarId), contentType, key, Optional.of(0), Optional.empty(), 0, 0, digest, Optional.empty(), 0, fileName, false, false, false, Optional.empty(), Optional.empty(), System.currentTimeMillis(), null);
InputStream inputStream = receiver.retrieveAttachment(pointer, attachment, AvatarHelper.AVATAR_DOWNLOAD_FAILSAFE_MAX_SIZE);
if (pointer.getDigest().isEmpty()) {
throw new InvalidMessageException("Missing digest!");
}
IntegrityCheck integrityCheck = IntegrityCheck.forEncryptedDigest(pointer.getDigest().get());
InputStream inputStream = receiver.retrieveAttachment(pointer, attachment, AvatarHelper.AVATAR_DOWNLOAD_FAILSAFE_MAX_SIZE, integrityCheck);
AvatarHelper.setAvatar(context, record.get().getRecipientId(), inputStream); AvatarHelper.setAvatar(context, record.get().getRecipientId(), inputStream);
SignalDatabase.groups().onAvatarUpdated(groupId, true); SignalDatabase.groups().onAvatarUpdated(groupId, true);

View File

@@ -12,6 +12,7 @@ import org.thoughtcrime.securesms.net.NotPushRegisteredException
import org.thoughtcrime.securesms.profiles.AvatarHelper import org.thoughtcrime.securesms.profiles.AvatarHelper
import org.thoughtcrime.securesms.providers.BlobProvider import org.thoughtcrime.securesms.providers.BlobProvider
import org.thoughtcrime.securesms.recipients.Recipient import org.thoughtcrime.securesms.recipients.Recipient
import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream.IntegrityCheck
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentPointer import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentPointer
import org.whispersystems.signalservice.api.messages.multidevice.DeviceContact import org.whispersystems.signalservice.api.messages.multidevice.DeviceContact
import org.whispersystems.signalservice.api.messages.multidevice.DeviceContactsInputStream import org.whispersystems.signalservice.api.messages.multidevice.DeviceContactsInputStream
@@ -59,7 +60,7 @@ class MultiDeviceContactSyncJob(parameters: Parameters, private val attachmentPo
try { try {
val contactsFile: File = BlobProvider.getInstance().forNonAutoEncryptingSingleSessionOnDisk(context) val contactsFile: File = BlobProvider.getInstance().forNonAutoEncryptingSingleSessionOnDisk(context)
AppDependencies.signalServiceMessageReceiver AppDependencies.signalServiceMessageReceiver
.retrieveAttachment(contactAttachment, contactsFile, MAX_ATTACHMENT_SIZE) .retrieveAttachment(contactAttachment, contactsFile, MAX_ATTACHMENT_SIZE, IntegrityCheck.forEncryptedDigest(contactAttachment.digest.get()))
.use(this::processContactFile) .use(this::processContactFile)
} catch (e: MissingConfigurationException) { } catch (e: MissingConfigurationException) {
throw IOException(e) throw IOException(e)

View File

@@ -40,6 +40,7 @@ import org.thoughtcrime.securesms.notifications.NotificationChannels
import org.thoughtcrime.securesms.notifications.NotificationIds import org.thoughtcrime.securesms.notifications.NotificationIds
import org.thoughtcrime.securesms.transport.RetryLaterException import org.thoughtcrime.securesms.transport.RetryLaterException
import org.thoughtcrime.securesms.util.RemoteConfig import org.thoughtcrime.securesms.util.RemoteConfig
import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream.IntegrityCheck
import org.whispersystems.signalservice.api.messages.AttachmentTransferProgress import org.whispersystems.signalservice.api.messages.AttachmentTransferProgress
import org.whispersystems.signalservice.api.messages.SignalServiceAttachment import org.whispersystems.signalservice.api.messages.SignalServiceAttachment
import org.whispersystems.signalservice.api.push.exceptions.MissingConfigurationException import org.whispersystems.signalservice.api.push.exceptions.MissingConfigurationException
@@ -49,6 +50,7 @@ import org.whispersystems.signalservice.api.push.exceptions.RangeException
import java.io.File import java.io.File
import java.io.IOException import java.io.IOException
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import kotlin.jvm.optionals.getOrNull
import kotlin.math.max import kotlin.math.max
import kotlin.math.pow import kotlin.math.pow
import kotlin.time.Duration.Companion.days import kotlin.time.Duration.Companion.days
@@ -172,12 +174,12 @@ class RestoreAttachmentJob private constructor(
val attachment = SignalDatabase.attachments.getAttachment(attachmentId) val attachment = SignalDatabase.attachments.getAttachment(attachmentId)
if (attachment == null) { if (attachment == null) {
Log.w(TAG, "attachment no longer exists.") Log.w(TAG, "[$attachmentId] Attachment no longer exists.")
return return
} }
if (attachment.isPermanentlyFailed) { if (attachment.isPermanentlyFailed) {
Log.w(TAG, "Attachment was marked as a permanent failure. Refusing to download.") Log.w(TAG, "[$attachmentId] Attachment was marked as a permanent failure. Refusing to download.")
return return
} }
@@ -186,7 +188,7 @@ class RestoreAttachmentJob private constructor(
attachment.transferState != AttachmentTable.TRANSFER_PROGRESS_FAILED && attachment.transferState != AttachmentTable.TRANSFER_PROGRESS_FAILED &&
attachment.transferState != AttachmentTable.TRANSFER_RESTORE_OFFLOADED attachment.transferState != AttachmentTable.TRANSFER_RESTORE_OFFLOADED
) { ) {
Log.w(TAG, "Attachment does not need to be restored. Current state: ${attachment.transferState}") Log.w(TAG, "[$attachmentId] Attachment does not need to be restored. Current state: ${attachment.transferState}")
return return
} }
@@ -231,14 +233,20 @@ class RestoreAttachmentJob private constructor(
var archiveFile: File? = null var archiveFile: File? = null
var useArchiveCdn = false var useArchiveCdn = false
if (attachment.remoteDigest == null && attachment.dataHash == null) {
Log.w(TAG, "[$attachmentId] Attachment has no integrity check! Cannot proceed.")
markPermanentlyFailed(attachmentId)
return
}
try { try {
if (attachment.size > maxReceiveSize) { if (attachment.size > maxReceiveSize) {
throw MmsException("Attachment too large, failing download") throw MmsException("[$attachmentId] Attachment too large, failing download")
} }
useArchiveCdn = if (SignalStore.backup.backsUpMedia && !forceTransitTier) { useArchiveCdn = if (SignalStore.backup.backsUpMedia && !forceTransitTier) {
if (attachment.archiveTransferState != AttachmentTable.ArchiveTransferState.FINISHED) { if (attachment.archiveTransferState != AttachmentTable.ArchiveTransferState.FINISHED) {
throw InvalidAttachmentException("Invalid attachment configuration! backsUpMedia: ${SignalStore.backup.backsUpMedia}, forceTransitTier: $forceTransitTier, archiveTransferState: ${attachment.archiveTransferState}") throw InvalidAttachmentException("[$attachmentId] Invalid attachment configuration! backsUpMedia: ${SignalStore.backup.backsUpMedia}, forceTransitTier: $forceTransitTier, archiveTransferState: ${attachment.archiveTransferState}")
} }
true true
} else { } else {
@@ -259,7 +267,9 @@ class RestoreAttachmentJob private constructor(
} }
val decryptingStream = if (useArchiveCdn) { val decryptingStream = if (useArchiveCdn) {
archiveFile = SignalDatabase.attachments.getOrCreateArchiveTransferFile(attachmentId) // TODO next PR: remove archive transfer file and just use the regular attachment file
archiveFile = attachmentFile
// archiveFile = SignalDatabase.attachments.getOrCreateArchiveTransferFile(attachmentId)
val cdnCredentials = BackupRepository.getCdnReadCredentials(BackupRepository.CredentialType.MEDIA, attachment.archiveCdn ?: RemoteConfig.backupFallbackArchiveCdn).successOrThrow().headers val cdnCredentials = BackupRepository.getCdnReadCredentials(BackupRepository.CredentialType.MEDIA, attachment.archiveCdn ?: RemoteConfig.backupFallbackArchiveCdn).successOrThrow().headers
messageReceiver messageReceiver
@@ -269,7 +279,6 @@ class RestoreAttachmentJob private constructor(
cdnCredentials, cdnCredentials,
archiveFile, archiveFile,
pointer, pointer,
attachmentFile,
maxReceiveSize, maxReceiveSize,
progressListener progressListener
) )
@@ -279,6 +288,7 @@ class RestoreAttachmentJob private constructor(
pointer, pointer,
attachmentFile, attachmentFile,
maxReceiveSize, maxReceiveSize,
IntegrityCheck.forEncryptedDigestAndPlaintextHash(pointer.digest.getOrNull(), attachment.dataHash),
progressListener progressListener
) )
} }
@@ -286,7 +296,7 @@ class RestoreAttachmentJob private constructor(
SignalDatabase.attachments.finalizeAttachmentAfterDownload(messageId, attachmentId, decryptingStream, if (manual) System.currentTimeMillis().milliseconds else null) SignalDatabase.attachments.finalizeAttachmentAfterDownload(messageId, attachmentId, decryptingStream, if (manual) System.currentTimeMillis().milliseconds else null)
} catch (e: RangeException) { } catch (e: RangeException) {
val transferFile = archiveFile ?: attachmentFile val transferFile = archiveFile ?: attachmentFile
Log.w(TAG, "Range exception, file size " + transferFile.length(), e) Log.w(TAG, "[$attachmentId] Range exception, file size " + transferFile.length(), e)
if (transferFile.delete()) { if (transferFile.delete()) {
Log.i(TAG, "Deleted temp download file to recover") Log.i(TAG, "Deleted temp download file to recover")
throw RetryLaterException(e) throw RetryLaterException(e)
@@ -299,7 +309,7 @@ class RestoreAttachmentJob private constructor(
} catch (e: NonSuccessfulResponseCodeException) { } catch (e: NonSuccessfulResponseCodeException) {
if (SignalStore.backup.backsUpMedia) { if (SignalStore.backup.backsUpMedia) {
if (e.code == 404 && !forceTransitTier && attachment.remoteLocation?.isNotBlank() == true) { if (e.code == 404 && !forceTransitTier && attachment.remoteLocation?.isNotBlank() == true) {
Log.i(TAG, "Failed to download attachment from archive! Should only happen for recent attachments in a narrow window. Retrying download from transit CDN.") Log.i(TAG, "[$attachmentId] Failed to download attachment from archive! Should only happen for recent attachments in a narrow window. Retrying download from transit CDN.")
if (RemoteConfig.internalUser) { if (RemoteConfig.internalUser) {
postFailedToDownloadFromArchiveNotification() postFailedToDownloadFromArchiveNotification()
} }
@@ -316,18 +326,18 @@ class RestoreAttachmentJob private constructor(
} }
} }
Log.w(TAG, "Experienced exception while trying to download an attachment.", e) Log.w(TAG, "[$attachmentId] Experienced exception while trying to download an attachment.", e)
markFailed(attachmentId) markFailed(attachmentId)
} catch (e: MmsException) { } catch (e: MmsException) {
Log.w(TAG, "Experienced exception while trying to download an attachment.", e) Log.w(TAG, "[$attachmentId] Experienced exception while trying to download an attachment.", e)
markFailed(attachmentId) markFailed(attachmentId)
} catch (e: MissingConfigurationException) { } catch (e: MissingConfigurationException) {
Log.w(TAG, "Experienced exception while trying to download an attachment.", e) Log.w(TAG, "[$attachmentId] Experienced exception while trying to download an attachment.", e)
markFailed(attachmentId) markFailed(attachmentId)
} catch (e: InvalidMessageException) { } catch (e: InvalidMessageException) {
Log.w(TAG, "Experienced an InvalidMessageException while trying to download an attachment.", e) Log.w(TAG, "[$attachmentId] Experienced an InvalidMessageException while trying to download an attachment.", e)
if (e.cause is InvalidMacException) { if (e.cause is InvalidMacException) {
Log.w(TAG, "Detected an invalid mac. Treating as a permanent failure.") Log.w(TAG, "[$attachmentId] Detected an invalid mac. Treating as a permanent failure.")
markPermanentlyFailed(attachmentId) markPermanentlyFailed(attachmentId)
} else { } else {
markFailed(attachmentId) markFailed(attachmentId)

View File

@@ -120,7 +120,6 @@ class RestoreAttachmentThumbnailJob private constructor(
val maxThumbnailSize: Long = RemoteConfig.maxAttachmentReceiveSizeBytes val maxThumbnailSize: Long = RemoteConfig.maxAttachmentReceiveSizeBytes
val thumbnailTransferFile: File = SignalDatabase.attachments.createArchiveThumbnailTransferFile() val thumbnailTransferFile: File = SignalDatabase.attachments.createArchiveThumbnailTransferFile()
val thumbnailFile: File = SignalDatabase.attachments.createArchiveThumbnailTransferFile()
val progressListener = object : SignalServiceAttachment.ProgressListener { val progressListener = object : SignalServiceAttachment.ProgressListener {
override fun onAttachmentProgress(progress: AttachmentTransferProgress) = Unit override fun onAttachmentProgress(progress: AttachmentTransferProgress) = Unit
@@ -137,7 +136,6 @@ class RestoreAttachmentThumbnailJob private constructor(
cdnCredentials, cdnCredentials,
thumbnailTransferFile, thumbnailTransferFile,
pointer, pointer,
thumbnailFile,
maxThumbnailSize, maxThumbnailSize,
progressListener progressListener
) )

View File

@@ -23,6 +23,7 @@ import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.thoughtcrime.securesms.mms.MmsException import org.thoughtcrime.securesms.mms.MmsException
import org.whispersystems.signalservice.api.backup.MediaName import org.whispersystems.signalservice.api.backup.MediaName
import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream
import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream.IntegrityCheck
import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream.StreamSupplier import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream.StreamSupplier
import java.io.IOException import java.io.IOException
@@ -154,7 +155,10 @@ class RestoreLocalAttachmentJob private constructor(
streamLength = size, streamLength = size,
plaintextLength = attachment.size, plaintextLength = attachment.size,
combinedKeyMaterial = combinedKey, combinedKeyMaterial = combinedKey,
digest = attachment.remoteDigest, integrityCheck = IntegrityCheck.forEncryptedDigestAndPlaintextHash(
encryptedDigest = attachment.remoteDigest,
plaintextHash = attachment.dataHash
),
incrementalDigest = null, incrementalDigest = null,
incrementalMacChunkSize = 0 incrementalMacChunkSize = 0
).use { input -> ).use { input ->

View File

@@ -23,6 +23,7 @@ import org.signal.core.util.Base64;
import org.whispersystems.signalservice.api.backup.MediaName; import org.whispersystems.signalservice.api.backup.MediaName;
import org.whispersystems.signalservice.api.backup.MediaRootBackupKey; import org.whispersystems.signalservice.api.backup.MediaRootBackupKey;
import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream; import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream;
import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream.IntegrityCheck;
import org.whispersystems.signalservice.api.crypto.AttachmentCipherStreamUtil; import org.whispersystems.signalservice.api.crypto.AttachmentCipherStreamUtil;
import org.signal.core.util.stream.TailerInputStream; import org.signal.core.util.stream.TailerInputStream;
import org.whispersystems.signalservice.internal.crypto.PaddingInputStream; import org.whispersystems.signalservice.internal.crypto.PaddingInputStream;
@@ -100,11 +101,13 @@ class PartDataSource implements DataSource {
long streamLength = AttachmentCipherStreamUtil.getCiphertextLength(PaddingInputStream.getPaddedSize(attachment.size)); long streamLength = AttachmentCipherStreamUtil.getCiphertextLength(PaddingInputStream.getPaddedSize(attachment.size));
AttachmentCipherInputStream.StreamSupplier streamSupplier = () -> new TailerInputStream(() -> new FileInputStream(transferFile), streamLength); AttachmentCipherInputStream.StreamSupplier streamSupplier = () -> new TailerInputStream(() -> new FileInputStream(transferFile), streamLength);
if (attachment.remoteDigest == null) { if (attachment.remoteDigest == null && attachment.dataHash == null) {
throw new InvalidMessageException("Missing digest!"); throw new InvalidMessageException("Missing digest!");
} }
this.inputStream = AttachmentCipherInputStream.createForAttachment(streamSupplier, streamLength, attachment.size, decodedKey, attachment.remoteDigest, attachment.getIncrementalDigest(), attachment.incrementalMacChunkSize); IntegrityCheck integrityCheck = IntegrityCheck.forEncryptedDigestAndPlaintextHash(attachment.remoteDigest, attachment.dataHash);
this.inputStream = AttachmentCipherInputStream.createForAttachment(streamSupplier, streamLength, attachment.size, decodedKey, integrityCheck, attachment.getIncrementalDigest(), attachment.incrementalMacChunkSize);
} catch (InvalidMessageException e) { } catch (InvalidMessageException e) {
throw new IOException("Error decrypting attachment stream!", e); throw new IOException("Error decrypting attachment stream!", e);
} }

View File

@@ -7,14 +7,16 @@ import io.mockk.mockk
import io.mockk.mockkObject import io.mockk.mockkObject
import org.junit.Before import org.junit.Before
import org.junit.Test import org.junit.Test
import org.signal.core.util.Base64
import org.thoughtcrime.securesms.MockCursor import org.thoughtcrime.securesms.MockCursor
import org.thoughtcrime.securesms.keyvalue.BackupValues import org.thoughtcrime.securesms.keyvalue.BackupValues
import org.thoughtcrime.securesms.keyvalue.SignalStore import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.thoughtcrime.securesms.util.Util
import org.whispersystems.signalservice.api.backup.MediaRootBackupKey import org.whispersystems.signalservice.api.backup.MediaRootBackupKey
class ArchivedMediaObjectIteratorTest { class ArchivedMediaObjectIteratorTest {
private val cursor = mockk<MockCursor>(relaxed = true) { private val cursor = mockk<MockCursor>(relaxed = true) {
every { getString(any()) } returns "A" every { getString(any()) } returns Base64.encodeWithPadding(Util.getSecretBytes(32))
every { moveToPosition(any()) } answers { callOriginal() } every { moveToPosition(any()) } answers { callOriginal() }
every { moveToNext() } answers { callOriginal() } every { moveToNext() } answers { callOriginal() }
every { position } answers { callOriginal() } every { position } answers { callOriginal() }

View File

@@ -7,10 +7,12 @@
package org.whispersystems.signalservice.api; package org.whispersystems.signalservice.api;
import org.signal.core.util.StreamUtil; import org.signal.core.util.StreamUtil;
import org.signal.core.util.logging.Log;
import org.signal.libsignal.protocol.InvalidMessageException; import org.signal.libsignal.protocol.InvalidMessageException;
import org.signal.libsignal.zkgroup.profiles.ProfileKey; import org.signal.libsignal.zkgroup.profiles.ProfileKey;
import org.whispersystems.signalservice.api.backup.MediaRootBackupKey; import org.whispersystems.signalservice.api.backup.MediaRootBackupKey;
import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream; import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream;
import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream.IntegrityCheck;
import org.whispersystems.signalservice.api.crypto.AttachmentCipherStreamUtil; import org.whispersystems.signalservice.api.crypto.AttachmentCipherStreamUtil;
import org.whispersystems.signalservice.api.crypto.ProfileCipherInputStream; import org.whispersystems.signalservice.api.crypto.ProfileCipherInputStream;
import org.whispersystems.signalservice.api.messages.SignalServiceAttachment.ProgressListener; import org.whispersystems.signalservice.api.messages.SignalServiceAttachment.ProgressListener;
@@ -63,9 +65,9 @@ public class SignalServiceMessageReceiver {
* @throws IOException * @throws IOException
* @throws InvalidMessageException * @throws InvalidMessageException
*/ */
public InputStream retrieveAttachment(SignalServiceAttachmentPointer pointer, File destination, long maxSizeBytes) public InputStream retrieveAttachment(SignalServiceAttachmentPointer pointer, File destination, long maxSizeBytes, IntegrityCheck integrityCheck)
throws IOException, InvalidMessageException, MissingConfigurationException { throws IOException, InvalidMessageException, MissingConfigurationException {
return retrieveAttachment(pointer, destination, maxSizeBytes, null); return retrieveAttachment(pointer, destination, maxSizeBytes, integrityCheck, null);
} }
public InputStream retrieveProfileAvatar(String path, File destination, ProfileKey profileKey, long maxSizeBytes) public InputStream retrieveProfileAvatar(String path, File destination, ProfileKey profileKey, long maxSizeBytes)
@@ -96,9 +98,9 @@ public class SignalServiceMessageReceiver {
* @throws IOException * @throws IOException
* @throws InvalidMessageException * @throws InvalidMessageException
*/ */
public InputStream retrieveAttachment(SignalServiceAttachmentPointer pointer, File destination, long maxSizeBytes, ProgressListener listener) public InputStream retrieveAttachment(SignalServiceAttachmentPointer pointer, File destination, long maxSizeBytes, IntegrityCheck integrityCheck, ProgressListener listener)
throws IOException, InvalidMessageException, MissingConfigurationException { throws IOException, InvalidMessageException, MissingConfigurationException {
if (pointer.getDigest().isEmpty()) throw new InvalidMessageException("No attachment digest!"); if (integrityCheck == null) throw new InvalidMessageException("No integrity check!");
if (pointer.getKey() == null) throw new InvalidMessageException("No key!"); if (pointer.getKey() == null) throw new InvalidMessageException("No key!");
socket.retrieveAttachment(pointer.getCdnNumber(), Collections.emptyMap(), pointer.getRemoteId(), destination, maxSizeBytes, listener); socket.retrieveAttachment(pointer.getCdnNumber(), Collections.emptyMap(), pointer.getRemoteId(), destination, maxSizeBytes, listener);
@@ -112,7 +114,7 @@ public class SignalServiceMessageReceiver {
destination, destination,
pointer.getSize().orElse(0), pointer.getSize().orElse(0),
pointer.getKey(), pointer.getKey(),
pointer.getDigest().get(), integrityCheck,
null, null,
0 0
); );
@@ -126,7 +128,6 @@ public class SignalServiceMessageReceiver {
* @param readCredentialHeaders Headers to pass to the backup CDN to authorize the download * @param readCredentialHeaders Headers to pass to the backup CDN to authorize the download
* @param archiveDestination The download destination for archived attachment. If this file exists, download will resume. * @param archiveDestination The download destination for archived attachment. If this file exists, download will resume.
* @param pointer The {@link SignalServiceAttachmentPointer} received in a {@link SignalServiceDataMessage}. * @param pointer The {@link SignalServiceAttachmentPointer} received in a {@link SignalServiceDataMessage}.
* @param attachmentDestination The download destination for this attachment. If this file exists, it is assumed that this is previously-downloaded content that can be resumed.
* @param listener An optional listener (may be null) to receive callbacks on download progress. * @param listener An optional listener (may be null) to receive callbacks on download progress.
* *
* @return An InputStream that streams the plaintext attachment contents. * @return An InputStream that streams the plaintext attachment contents.
@@ -136,7 +137,6 @@ public class SignalServiceMessageReceiver {
@Nonnull Map<String, String> readCredentialHeaders, @Nonnull Map<String, String> readCredentialHeaders,
@Nonnull File archiveDestination, @Nonnull File archiveDestination,
@Nonnull SignalServiceAttachmentPointer pointer, @Nonnull SignalServiceAttachmentPointer pointer,
@Nonnull File attachmentDestination,
long maxSizeBytes, long maxSizeBytes,
@Nullable ProgressListener listener) @Nullable ProgressListener listener)
throws IOException, InvalidMessageException, MissingConfigurationException throws IOException, InvalidMessageException, MissingConfigurationException
@@ -154,7 +154,7 @@ public class SignalServiceMessageReceiver {
return AttachmentCipherInputStream.createForArchivedMedia( return AttachmentCipherInputStream.createForArchivedMedia(
archivedMediaKeyMaterial, archivedMediaKeyMaterial,
attachmentDestination, archiveDestination,
originalCipherLength, originalCipherLength,
pointer.getSize().orElse(0), pointer.getSize().orElse(0),
pointer.getKey(), pointer.getKey(),
@@ -171,7 +171,6 @@ public class SignalServiceMessageReceiver {
* @param readCredentialHeaders Headers to pass to the backup CDN to authorize the download * @param readCredentialHeaders Headers to pass to the backup CDN to authorize the download
* @param archiveDestination The download destination for archived attachment. If this file exists, download will resume. * @param archiveDestination The download destination for archived attachment. If this file exists, download will resume.
* @param pointer The {@link SignalServiceAttachmentPointer} received in a {@link SignalServiceDataMessage}. * @param pointer The {@link SignalServiceAttachmentPointer} received in a {@link SignalServiceDataMessage}.
* @param attachmentDestination The download destination for this attachment. If this file exists, it is assumed that this is previously-downloaded content that can be resumed.
* @param listener An optional listener (may be null) to receive callbacks on download progress. * @param listener An optional listener (may be null) to receive callbacks on download progress.
* *
* @return An InputStream that streams the plaintext attachment contents. * @return An InputStream that streams the plaintext attachment contents.
@@ -180,7 +179,6 @@ public class SignalServiceMessageReceiver {
@Nonnull Map<String, String> readCredentialHeaders, @Nonnull Map<String, String> readCredentialHeaders,
@Nonnull File archiveDestination, @Nonnull File archiveDestination,
@Nonnull SignalServiceAttachmentPointer pointer, @Nonnull SignalServiceAttachmentPointer pointer,
@Nonnull File attachmentDestination,
long maxSizeBytes, long maxSizeBytes,
@Nullable ProgressListener listener) @Nullable ProgressListener listener)
throws IOException, InvalidMessageException, MissingConfigurationException throws IOException, InvalidMessageException, MissingConfigurationException
@@ -198,7 +196,7 @@ public class SignalServiceMessageReceiver {
return AttachmentCipherInputStream.createForArchivedThumbnail( return AttachmentCipherInputStream.createForArchivedThumbnail(
archivedMediaKeyMaterial, archivedMediaKeyMaterial,
attachmentDestination, archiveDestination,
originalCipherLength, originalCipherLength,
pointer.getSize().orElse(0), pointer.getSize().orElse(0),
pointer.getKey() pointer.getKey()

View File

@@ -16,7 +16,7 @@ class ArchiveGetMediaItemsResponse(
@JsonProperty val mediaDir: String?, @JsonProperty val mediaDir: String?,
@JsonProperty val cursor: String? @JsonProperty val cursor: String?
) { ) {
class StoredMediaObject( data class StoredMediaObject(
@JsonProperty val cdn: Int, @JsonProperty val cdn: Int,
@JsonProperty val mediaId: String, @JsonProperty val mediaId: String,
@JsonProperty val objectLength: Long @JsonProperty val objectLength: Long

View File

@@ -13,7 +13,7 @@ import com.fasterxml.jackson.annotation.JsonProperty
class DeleteArchivedMediaRequest( class DeleteArchivedMediaRequest(
@JsonProperty val mediaToDelete: List<ArchivedMediaObject> @JsonProperty val mediaToDelete: List<ArchivedMediaObject>
) { ) {
class ArchivedMediaObject( data class ArchivedMediaObject(
@JsonProperty val cdn: Int, @JsonProperty val cdn: Int,
@JsonProperty val mediaId: String @JsonProperty val mediaId: String
) )

View File

@@ -31,10 +31,6 @@ value class MediaName(val name: String) {
return mediaRootBackupKey.deriveMediaId(this) return mediaRootBackupKey.deriveMediaId(this)
} }
fun toByteArray(): ByteArray {
return name.toByteArray()
}
override fun toString(): String { override fun toString(): String {
return name return name
} }

View File

@@ -5,6 +5,7 @@
*/ */
package org.whispersystems.signalservice.api.crypto package org.whispersystems.signalservice.api.crypto
import org.signal.core.util.Base64
import org.signal.core.util.readNBytesOrThrow import org.signal.core.util.readNBytesOrThrow
import org.signal.core.util.stream.LimitedInputStream import org.signal.core.util.stream.LimitedInputStream
import org.signal.libsignal.protocol.InvalidMessageException import org.signal.libsignal.protocol.InvalidMessageException
@@ -51,7 +52,7 @@ object AttachmentCipherInputStream {
file: File, file: File,
plaintextLength: Long, plaintextLength: Long,
combinedKeyMaterial: ByteArray, combinedKeyMaterial: ByteArray,
digest: ByteArray, integrityCheck: IntegrityCheck,
incrementalDigest: ByteArray?, incrementalDigest: ByteArray?,
incrementalMacChunkSize: Int incrementalMacChunkSize: Int
): InputStream { ): InputStream {
@@ -60,11 +61,9 @@ object AttachmentCipherInputStream {
streamLength = file.length(), streamLength = file.length(),
plaintextLength = plaintextLength, plaintextLength = plaintextLength,
combinedKeyMaterial = combinedKeyMaterial, combinedKeyMaterial = combinedKeyMaterial,
encryptedDigest = digest, integrityCheck = integrityCheck,
plaintextHash = null,
incrementalDigest = incrementalDigest, incrementalDigest = incrementalDigest,
incrementalMacChunkSize = incrementalMacChunkSize, incrementalMacChunkSize = incrementalMacChunkSize
ignoreDigest = false
) )
} }
@@ -81,7 +80,7 @@ object AttachmentCipherInputStream {
streamLength: Long, streamLength: Long,
plaintextLength: Long, plaintextLength: Long,
combinedKeyMaterial: ByteArray, combinedKeyMaterial: ByteArray,
digest: ByteArray, integrityCheck: IntegrityCheck,
incrementalDigest: ByteArray?, incrementalDigest: ByteArray?,
incrementalMacChunkSize: Int incrementalMacChunkSize: Int
): InputStream { ): InputStream {
@@ -90,11 +89,9 @@ object AttachmentCipherInputStream {
streamLength = streamLength, streamLength = streamLength,
plaintextLength = plaintextLength, plaintextLength = plaintextLength,
combinedKeyMaterial = combinedKeyMaterial, combinedKeyMaterial = combinedKeyMaterial,
encryptedDigest = digest, integrityCheck = integrityCheck,
plaintextHash = null,
incrementalDigest = incrementalDigest, incrementalDigest = incrementalDigest,
incrementalMacChunkSize = incrementalMacChunkSize, incrementalMacChunkSize = incrementalMacChunkSize
ignoreDigest = false
) )
} }
@@ -130,11 +127,9 @@ object AttachmentCipherInputStream {
streamLength = originalCipherTextLength, streamLength = originalCipherTextLength,
plaintextLength = plaintextLength, plaintextLength = plaintextLength,
combinedKeyMaterial = combinedKeyMaterial, combinedKeyMaterial = combinedKeyMaterial,
encryptedDigest = null, integrityCheck = IntegrityCheck(plaintextHash = plaintextHash, encryptedDigest = null),
plaintextHash = plaintextHash,
incrementalDigest = incrementalDigest, incrementalDigest = incrementalDigest,
incrementalMacChunkSize = incrementalMacChunkSize, incrementalMacChunkSize = incrementalMacChunkSize
ignoreDigest = true
) )
} }
@@ -159,7 +154,7 @@ object AttachmentCipherInputStream {
val mac = initMac(keyMaterial.macKey) val mac = initMac(keyMaterial.macKey)
if (originalCipherTextLength <= BLOCK_SIZE + mac.macLength) { if (originalCipherTextLength <= BLOCK_SIZE + mac.macLength) {
throw InvalidMessageException("Message shorter than crypto overhead!") throw InvalidMessageException("Message shorter than crypto overhead! Expected at least ${BLOCK_SIZE + mac.macLength} bytes, got $originalCipherTextLength")
} }
return create( return create(
@@ -167,11 +162,9 @@ object AttachmentCipherInputStream {
streamLength = originalCipherTextLength, streamLength = originalCipherTextLength,
plaintextLength = plaintextLength, plaintextLength = plaintextLength,
combinedKeyMaterial = combinedKeyMaterial, combinedKeyMaterial = combinedKeyMaterial,
encryptedDigest = null, integrityCheck = null,
plaintextHash = null,
incrementalDigest = null, incrementalDigest = null,
incrementalMacChunkSize = 0, incrementalMacChunkSize = 0
ignoreDigest = true
) )
} }
@@ -189,7 +182,7 @@ object AttachmentCipherInputStream {
} }
ByteArrayInputStream(data).use { inputStream -> ByteArrayInputStream(data).use { inputStream ->
verifyMac(inputStream, data.size.toLong(), mac, null) verifyMacAndMaybeEncryptedDigest(inputStream, data.size.toLong(), mac, null)
} }
val encryptedStream = ByteArrayInputStream(data) val encryptedStream = ByteArrayInputStream(data)
@@ -211,11 +204,11 @@ object AttachmentCipherInputStream {
val mac = initMac(archivedMediaKeyMaterial.macKey) val mac = initMac(archivedMediaKeyMaterial.macKey)
if (file.length() <= BLOCK_SIZE + mac.macLength) { if (file.length() <= BLOCK_SIZE + mac.macLength) {
throw InvalidMessageException("Message shorter than crypto overhead!") throw InvalidMessageException("Message shorter than crypto overhead! Expected at least ${BLOCK_SIZE + mac.macLength} bytes, got ${file.length()}")
} }
FileInputStream(file).use { macVerificationStream -> FileInputStream(file).use { macVerificationStream ->
verifyMac(macVerificationStream, file.length(), mac, null) verifyMacAndMaybeEncryptedDigest(macVerificationStream, file.length(), mac, null)
} }
val encryptedStream = FileInputStream(file) val encryptedStream = FileInputStream(file)
@@ -226,6 +219,10 @@ object AttachmentCipherInputStream {
return LimitedInputStream(inputStream, originalCipherTextLength) return LimitedInputStream(inputStream, originalCipherTextLength)
} }
/**
* @param integrityCheck If null, no integrity check is performed! This is a private method, so it's assumed that care has been taken to ensure that this is
* the correct course of action. Public methods should properly enforce when integrity checks are required.
*/
@JvmStatic @JvmStatic
@Throws(InvalidMessageException::class, IOException::class) @Throws(InvalidMessageException::class, IOException::class)
private fun create( private fun create(
@@ -233,11 +230,9 @@ object AttachmentCipherInputStream {
streamLength: Long, streamLength: Long,
plaintextLength: Long, plaintextLength: Long,
combinedKeyMaterial: ByteArray, combinedKeyMaterial: ByteArray,
encryptedDigest: ByteArray?, integrityCheck: IntegrityCheck?,
plaintextHash: ByteArray?,
incrementalDigest: ByteArray?, incrementalDigest: ByteArray?,
incrementalMacChunkSize: Int, incrementalMacChunkSize: Int
ignoreDigest: Boolean
): InputStream { ): InputStream {
val keyMaterial = CombinedKeyMaterial.from(combinedKeyMaterial) val keyMaterial = CombinedKeyMaterial.from(combinedKeyMaterial)
val mac = initMac(keyMaterial.macKey) val mac = initMac(keyMaterial.macKey)
@@ -246,25 +241,16 @@ object AttachmentCipherInputStream {
throw InvalidMessageException("Message shorter than crypto overhead! length: $streamLength") throw InvalidMessageException("Message shorter than crypto overhead! length: $streamLength")
} }
if (!ignoreDigest && encryptedDigest == null) {
throw InvalidMessageException("Missing digest!")
}
val wrappedStream: InputStream val wrappedStream: InputStream
val hasIncrementalMac = incrementalDigest != null && incrementalDigest.isNotEmpty() && incrementalMacChunkSize > 0 val hasIncrementalMac = incrementalDigest != null && incrementalDigest.isNotEmpty() && incrementalMacChunkSize > 0
if (!hasIncrementalMac) { if (hasIncrementalMac) {
streamSupplier.openStream().use { macVerificationStream -> if (integrityCheck == null) {
verifyMac(macVerificationStream, streamLength, mac, encryptedDigest) throw InvalidMessageException("Missing integrityCheck for incremental mac validation!")
}
wrappedStream = streamSupplier.openStream()
} else {
if (encryptedDigest == null && plaintextHash == null) {
throw InvalidMessageException("Missing data (digest or plaintextHas) for incremental mac validation!")
} }
val digestValidatingStream = if (encryptedDigest != null) { val digestValidatingStream = if (integrityCheck.encryptedDigest != null) {
DigestValidatingInputStream(streamSupplier.openStream(), sha256Digest(), encryptedDigest) DigestValidatingInputStream(streamSupplier.openStream(), sha256Digest(), integrityCheck.encryptedDigest)
} else { } else {
streamSupplier.openStream() streamSupplier.openStream()
} }
@@ -279,6 +265,11 @@ object AttachmentCipherInputStream {
ChunkSizeChoice.everyNthByte(incrementalMacChunkSize), ChunkSizeChoice.everyNthByte(incrementalMacChunkSize),
incrementalDigest incrementalDigest
) )
} else {
streamSupplier.openStream().use { macVerificationStream ->
verifyMacAndMaybeEncryptedDigest(macVerificationStream, streamLength, mac, integrityCheck?.encryptedDigest)
}
wrappedStream = streamSupplier.openStream()
} }
val encryptedStreamExcludingMac = LimitedInputStream(wrappedStream, streamLength - mac.macLength) val encryptedStreamExcludingMac = LimitedInputStream(wrappedStream, streamLength - mac.macLength)
@@ -286,12 +277,12 @@ object AttachmentCipherInputStream {
val decryptingStream: InputStream = BetterCipherInputStream(encryptedStreamExcludingMac, cipher) val decryptingStream: InputStream = BetterCipherInputStream(encryptedStreamExcludingMac, cipher)
val paddinglessDecryptingStream = LimitedInputStream(decryptingStream, plaintextLength) val paddinglessDecryptingStream = LimitedInputStream(decryptingStream, plaintextLength)
return if (plaintextHash != null) { return if (integrityCheck?.plaintextHash != null) {
if (plaintextHash.size != MessageDigest.getInstance("SHA-256").digestLength) { if (integrityCheck.plaintextHash.size != MessageDigest.getInstance("SHA-256").digestLength) {
throw InvalidMessageException("Invalid plaintext hash size: ${plaintextHash.size}") throw InvalidMessageException("Invalid plaintext hash size: ${integrityCheck.plaintextHash.size}")
} }
DigestValidatingInputStream(paddinglessDecryptingStream, sha256Digest(), plaintextHash) DigestValidatingInputStream(paddinglessDecryptingStream, sha256Digest(), integrityCheck.plaintextHash)
} else { } else {
paddinglessDecryptingStream paddinglessDecryptingStream
} }
@@ -326,7 +317,7 @@ object AttachmentCipherInputStream {
} }
@Throws(InvalidMessageException::class) @Throws(InvalidMessageException::class)
private fun verifyMac(@Nonnull inputStream: InputStream, length: Long, @Nonnull mac: Mac, theirDigest: ByteArray?) { private fun verifyMacAndMaybeEncryptedDigest(@Nonnull inputStream: InputStream, length: Long, @Nonnull mac: Mac, theirDigest: ByteArray?) {
try { try {
val digest = MessageDigest.getInstance("SHA256") val digest = MessageDigest.getInstance("SHA256")
var remainingData = Util.toIntExact(length) - mac.macLength var remainingData = Util.toIntExact(length) - mac.macLength
@@ -375,4 +366,33 @@ object AttachmentCipherInputStream {
@Throws(IOException::class) @Throws(IOException::class)
fun openStream(): InputStream fun openStream(): InputStream
} }
class IntegrityCheck(
val encryptedDigest: ByteArray?,
val plaintextHash: ByteArray?
) {
init {
if (encryptedDigest == null && plaintextHash == null) {
throw IllegalArgumentException("At least one of encryptedDigest or plaintextHash must be provided")
}
}
companion object {
@JvmStatic
fun forEncryptedDigest(encryptedDigest: ByteArray): IntegrityCheck {
return IntegrityCheck(encryptedDigest, null)
}
@JvmStatic
fun forPlaintextHash(plaintextHash: ByteArray): IntegrityCheck {
return IntegrityCheck(null, plaintextHash)
}
@JvmStatic
fun forEncryptedDigestAndPlaintextHash(encryptedDigest: ByteArray?, plaintextHash: String?): IntegrityCheck {
val plaintextHashBytes = plaintextHash?.let { Base64.decode(it) }
return IntegrityCheck(encryptedDigest, plaintextHashBytes)
}
}
}
} }

View File

@@ -13,6 +13,7 @@ import org.signal.libsignal.protocol.InvalidMessageException
import org.signal.libsignal.protocol.incrementalmac.ChunkSizeChoice import org.signal.libsignal.protocol.incrementalmac.ChunkSizeChoice
import org.signal.libsignal.protocol.incrementalmac.InvalidMacException import org.signal.libsignal.protocol.incrementalmac.InvalidMacException
import org.signal.libsignal.protocol.kdf.HKDF import org.signal.libsignal.protocol.kdf.HKDF
import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream.IntegrityCheck
import org.whispersystems.signalservice.api.crypto.AttachmentCipherTestHelper.createMediaKeyMaterial import org.whispersystems.signalservice.api.crypto.AttachmentCipherTestHelper.createMediaKeyMaterial
import org.whispersystems.signalservice.internal.crypto.PaddingInputStream import org.whispersystems.signalservice.internal.crypto.PaddingInputStream
import org.whispersystems.signalservice.internal.push.http.AttachmentCipherOutputStreamFactory import org.whispersystems.signalservice.internal.push.http.AttachmentCipherOutputStreamFactory
@@ -32,19 +33,39 @@ import java.util.Random
class AttachmentCipherTest { class AttachmentCipherTest {
@Test @Test
fun attachment_encryptDecrypt_nonIncremental() { fun attachment_encryptDecrypt_nonIncremental_encryptedDigest() {
attachment_encryptDecrypt(incremental = false, fileSize = MEBIBYTE) attachment_encryptDecrypt(incremental = false, fileSize = MEBIBYTE, integrityCheckMode = IntegrityCheckMode.ENCRYPTED_DIGEST)
} }
@Test @Test
fun attachment_encryptDecrypt_incremental() { fun attachment_encryptDecrypt_nonIncremental_plaintextHash() {
attachment_encryptDecrypt(incremental = true, fileSize = MEBIBYTE) attachment_encryptDecrypt(incremental = false, fileSize = MEBIBYTE, integrityCheckMode = IntegrityCheckMode.PLAINTEXT_HASH)
}
@Test
fun attachment_encryptDecrypt_nonIncremental_bothIntegrityChecks() {
attachment_encryptDecrypt(incremental = false, fileSize = MEBIBYTE, integrityCheckMode = IntegrityCheckMode.BOTH)
}
@Test
fun attachment_encryptDecrypt_incremental_encryptedDigest() {
attachment_encryptDecrypt(incremental = true, fileSize = MEBIBYTE, integrityCheckMode = IntegrityCheckMode.ENCRYPTED_DIGEST)
}
@Test
fun attachment_encryptDecrypt_incremental_plaintextHash() {
attachment_encryptDecrypt(incremental = true, fileSize = MEBIBYTE, integrityCheckMode = IntegrityCheckMode.PLAINTEXT_HASH)
}
@Test
fun attachment_encryptDecrypt_incremental_bothIntegrityChecks() {
attachment_encryptDecrypt(incremental = true, fileSize = MEBIBYTE, integrityCheckMode = IntegrityCheckMode.BOTH)
} }
@Test @Test
fun attachment_encryptDecrypt_nonIncremental_manyFileSizes() { fun attachment_encryptDecrypt_nonIncremental_manyFileSizes() {
for (i in 0..99) { for (i in 0..99) {
attachment_encryptDecrypt(incremental = false, fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024)) attachment_encryptDecrypt(incremental = false, fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024), integrityCheckMode = IntegrityCheckMode.BOTH)
} }
} }
@@ -52,18 +73,44 @@ class AttachmentCipherTest {
fun attachment_encryptDecrypt_incremental_manyFileSizes() { fun attachment_encryptDecrypt_incremental_manyFileSizes() {
// Designed to stress the various boundary conditions of reading the final mac // Designed to stress the various boundary conditions of reading the final mac
for (i in 0..99) { for (i in 0..99) {
attachment_encryptDecrypt(incremental = true, fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024)) attachment_encryptDecrypt(incremental = true, fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024), IntegrityCheckMode.BOTH)
} }
} }
private fun attachment_encryptDecrypt(incremental: Boolean, fileSize: Int) { private fun attachment_encryptDecrypt(incremental: Boolean, fileSize: Int, integrityCheckMode: IntegrityCheckMode) {
val key = Util.getSecretBytes(64) val key = Util.getSecretBytes(64)
val plaintextInput = Util.getSecretBytes(fileSize) val plaintextInput = Util.getSecretBytes(fileSize)
val plaintextHash = MessageDigest.getInstance("SHA-256").digest(plaintextInput)
val encryptResult = encryptData(plaintextInput, key, incremental) val encryptResult = encryptData(plaintextInput, key, incremental)
val cipherFile = writeToFile(encryptResult.ciphertext) val cipherFile = writeToFile(encryptResult.ciphertext)
val inputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, encryptResult.digest, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice) val integrityCheck = when (integrityCheckMode) {
IntegrityCheckMode.ENCRYPTED_DIGEST -> IntegrityCheck.forEncryptedDigest(encryptResult.digest)
IntegrityCheckMode.PLAINTEXT_HASH -> IntegrityCheck.forPlaintextHash(plaintextHash)
IntegrityCheckMode.BOTH -> IntegrityCheck(
encryptedDigest = encryptResult.digest,
plaintextHash = plaintextHash
)
}
val inputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, integrityCheck, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice)
val plaintextOutput = inputStream.readFully(autoClose = false)
assertThat(plaintextOutput).isEqualTo(plaintextInput)
cipherFile.delete()
}
private fun attachment_encryptDecrypt_plaintextHash(incremental: Boolean, fileSize: Int) {
val key = Util.getSecretBytes(64)
val plaintextInput = Util.getSecretBytes(fileSize)
val plaintextHash = MessageDigest.getInstance("SHA-256").digest(plaintextInput)
val encryptResult = encryptData(plaintextInput, key, incremental)
val cipherFile = writeToFile(encryptResult.ciphertext)
val integrityCheck = IntegrityCheck.forPlaintextHash(plaintextHash)
val inputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, integrityCheck, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice)
val plaintextOutput = inputStream.readFully(autoClose = false) val plaintextOutput = inputStream.readFully(autoClose = false)
assertThat(plaintextOutput).isEqualTo(plaintextInput) assertThat(plaintextOutput).isEqualTo(plaintextInput)
@@ -88,7 +135,8 @@ class AttachmentCipherTest {
val encryptResult = encryptData(plaintextInput, key, incremental) val encryptResult = encryptData(plaintextInput, key, incremental)
val cipherFile = writeToFile(encryptResult.ciphertext) val cipherFile = writeToFile(encryptResult.ciphertext)
val inputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, encryptResult.digest, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice) val integrityCheck = IntegrityCheck.forEncryptedDigest(encryptResult.digest)
val inputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, integrityCheck, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice)
val plaintextOutput = inputStream.readFully(autoClose = false) val plaintextOutput = inputStream.readFully(autoClose = false)
Assert.assertArrayEquals(plaintextInput, plaintextOutput) Assert.assertArrayEquals(plaintextInput, plaintextOutput)
@@ -117,7 +165,8 @@ class AttachmentCipherTest {
cipherFile = writeToFile(encryptResult.ciphertext) cipherFile = writeToFile(encryptResult.ciphertext)
val badKey = ByteArray(64) val badKey = ByteArray(64)
AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), badKey, encryptResult.digest, null, 0) val integrityCheck = IntegrityCheck.forEncryptedDigest(encryptResult.digest)
AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), badKey, integrityCheck, null, 0)
} finally { } finally {
cipherFile?.delete() cipherFile?.delete()
} }
@@ -147,7 +196,8 @@ class AttachmentCipherTest {
cipherFile = writeToFile(badMacCiphertext) cipherFile = writeToFile(badMacCiphertext)
val stream: InputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, encryptResult.digest, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice) val integrityCheck = IntegrityCheck.forEncryptedDigest(encryptResult.digest)
val stream: InputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, integrityCheck, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice)
// In incremental mode, we'll only check the digest after reading the whole thing // In incremental mode, we'll only check the digest after reading the whole thing
if (incremental) { if (incremental) {
@@ -159,16 +209,26 @@ class AttachmentCipherTest {
} }
@Test(expected = InvalidMessageException::class) @Test(expected = InvalidMessageException::class)
fun attachment_decryptFailOnBadDigest_nonIncremental() { fun attachment_decryptFailOnBadEncryptedDigest_nonIncremental() {
attachment_decryptFailOnBadDigest(incremental = false) attachment_decryptFailOnBadEncryptedDigest(incremental = false)
} }
@Test(expected = InvalidMessageException::class) @Test(expected = InvalidMessageException::class)
fun attachment_decryptFailOnBadDigest_incremental() { fun attachment_decryptFailOnBadEncryptedDigest_incremental() {
attachment_decryptFailOnBadDigest(incremental = true) attachment_decryptFailOnBadEncryptedDigest(incremental = true)
} }
private fun attachment_decryptFailOnBadDigest(incremental: Boolean) { @Test(expected = InvalidMessageException::class)
fun attachment_decryptFailOnBadPlaintextHash_nonIncremental() {
attachment_decryptFailOnBadPlaintextHash(incremental = false)
}
@Test(expected = InvalidMessageException::class)
fun attachment_decryptFailOnBadPlaintextHash_incremental() {
attachment_decryptFailOnBadPlaintextHash(incremental = true)
}
private fun attachment_decryptFailOnBadEncryptedDigest(incremental: Boolean) {
var cipherFile: File? = null var cipherFile: File? = null
try { try {
@@ -180,7 +240,8 @@ class AttachmentCipherTest {
cipherFile = writeToFile(encryptResult.ciphertext) cipherFile = writeToFile(encryptResult.ciphertext)
val stream: InputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, badDigest, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice) val integrityCheck = IntegrityCheck.forEncryptedDigest(badDigest)
val stream: InputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, integrityCheck, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice)
// In incremental mode, we'll only check the digest after reading the whole thing // In incremental mode, we'll only check the digest after reading the whole thing
if (incremental) { if (incremental) {
@@ -191,6 +252,29 @@ class AttachmentCipherTest {
} }
} }
private fun attachment_decryptFailOnBadPlaintextHash(incremental: Boolean) {
var cipherFile: File? = null
try {
val key = Util.getSecretBytes(64)
val plaintextInput = Util.getSecretBytes(MEBIBYTE)
val badPlaintextHash = MessageDigest.getInstance("SHA-256").digest(plaintextInput).apply {
this[0] = (this[0] + 1).toByte()
}
val encryptResult = encryptData(plaintextInput, key, incremental)
cipherFile = writeToFile(encryptResult.ciphertext)
val integrityCheck = IntegrityCheck.forPlaintextHash(badPlaintextHash)
val stream: InputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, integrityCheck, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice)
StreamUtil.readFully(stream)
} finally {
cipherFile?.delete()
}
}
@Test @Test
fun attachment_decryptFailOnBadIncrementalDigest() { fun attachment_decryptFailOnBadIncrementalDigest() {
var cipherFile: File? = null var cipherFile: File? = null
@@ -205,7 +289,8 @@ class AttachmentCipherTest {
cipherFile = writeToFile(encryptResult.ciphertext) cipherFile = writeToFile(encryptResult.ciphertext)
val decryptedStream: InputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, encryptResult.digest, badDigest, encryptResult.chunkSizeChoice) val integrityCheck = IntegrityCheck.forEncryptedDigest(encryptResult.digest)
val decryptedStream: InputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, integrityCheck, badDigest, encryptResult.chunkSizeChoice)
val plaintextOutput = readInputStreamFully(decryptedStream) val plaintextOutput = readInputStreamFully(decryptedStream)
fail(AssertionError("Expected to fail before hitting this line")) fail(AssertionError("Expected to fail before hitting this line"))
@@ -480,7 +565,8 @@ class AttachmentCipherTest {
val combinedData = plaintextInput1 + plaintextInput2 val combinedData = plaintextInput1 + plaintextInput2
val cipherFile = writeToFile(encryptedData) val cipherFile = writeToFile(encryptedData)
val decryptedStream = AttachmentCipherInputStream.createForAttachment(cipherFile, combinedData.size.toLong(), key, digest, null, 0) val integrityCheck = IntegrityCheck.forEncryptedDigest(digest)
val decryptedStream = AttachmentCipherInputStream.createForAttachment(cipherFile, combinedData.size.toLong(), key, integrityCheck, null, 0)
val plaintextOutput = readInputStreamFully(decryptedStream) val plaintextOutput = readInputStreamFully(decryptedStream)
assertThat(plaintextOutput).isEqualTo(combinedData) assertThat(plaintextOutput).isEqualTo(combinedData)
@@ -511,7 +597,8 @@ class AttachmentCipherTest {
val combinedData = plaintextInput1 + plaintextInput2 val combinedData = plaintextInput1 + plaintextInput2
val cipherFile = writeToFile(encryptedData) val cipherFile = writeToFile(encryptedData)
val decryptedStream = AttachmentCipherInputStream.createForAttachment(cipherFile, combinedData.size.toLong(), key, digest, null, 0) val integrityCheck = IntegrityCheck.forEncryptedDigest(digest)
val decryptedStream = AttachmentCipherInputStream.createForAttachment(cipherFile, combinedData.size.toLong(), key, integrityCheck, null, 0)
val plaintextOutput = readInputStreamFully(decryptedStream) val plaintextOutput = readInputStreamFully(decryptedStream)
assertThat(plaintextOutput).isEqualTo(combinedData) assertThat(plaintextOutput).isEqualTo(combinedData)
@@ -536,7 +623,8 @@ class AttachmentCipherTest {
val digest = encryptingOutputStream.transmittedDigest val digest = encryptingOutputStream.transmittedDigest
val cipherFile = writeToFile(encryptedData) val cipherFile = writeToFile(encryptedData)
val decryptedStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, digest, null, 0) val integrityCheck = IntegrityCheck.forEncryptedDigest(digest)
val decryptedStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, integrityCheck, null, 0)
val plaintextOutput = readInputStreamFully(decryptedStream) val plaintextOutput = readInputStreamFully(decryptedStream)
assertThat(plaintextOutput).isEqualTo(plaintextInput) assertThat(plaintextOutput).isEqualTo(plaintextInput)
@@ -567,7 +655,8 @@ class AttachmentCipherTest {
val digest = encryptingOutputStream.transmittedDigest val digest = encryptingOutputStream.transmittedDigest
val cipherFile = writeToFile(encryptedData) val cipherFile = writeToFile(encryptedData)
val decryptedStream = AttachmentCipherInputStream.createForAttachment(cipherFile, expectedData.size.toLong(), key, digest, null, 0) val integrityCheck = IntegrityCheck.forEncryptedDigest(digest)
val decryptedStream = AttachmentCipherInputStream.createForAttachment(cipherFile, expectedData.size.toLong(), key, integrityCheck, null, 0)
val plaintextOutput = readInputStreamFully(decryptedStream) val plaintextOutput = readInputStreamFully(decryptedStream)
assertThat(plaintextOutput).isEqualTo(expectedData) assertThat(plaintextOutput).isEqualTo(expectedData)
@@ -596,7 +685,8 @@ class AttachmentCipherTest {
val digest = encryptingOutputStream.transmittedDigest val digest = encryptingOutputStream.transmittedDigest
val cipherFile = writeToFile(encryptedData) val cipherFile = writeToFile(encryptedData)
val decryptedStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, digest, null, 0) val integrityCheck = IntegrityCheck.forEncryptedDigest(digest)
val decryptedStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, integrityCheck, null, 0)
val plaintextOutput = readInputStreamFully(decryptedStream) val plaintextOutput = readInputStreamFully(decryptedStream)
assertThat(plaintextOutput).isEqualTo(plaintextInput) assertThat(plaintextOutput).isEqualTo(plaintextInput)
@@ -677,4 +767,10 @@ class AttachmentCipherTest {
return HKDF.deriveSecrets(shortKey, "Sticker Pack".toByteArray(), 64) return HKDF.deriveSecrets(shortKey, "Sticker Pack".toByteArray(), 64)
} }
} }
enum class IntegrityCheckMode {
ENCRYPTED_DIGEST,
PLAINTEXT_HASH,
BOTH
}
} }