Allow normal attachments to be validated with plaintextHashes.

This commit is contained in:
Greyson Parrelli
2025-06-23 12:13:30 -04:00
committed by Cody Henthorne
parent 607b83d65b
commit ec5452744d
23 changed files with 319 additions and 185 deletions

View File

@@ -8,7 +8,6 @@ import androidx.test.platform.app.InstrumentationRegistry
import assertk.assertThat
import assertk.assertions.isEqualTo
import assertk.assertions.isNotEqualTo
import org.junit.Assert.assertArrayEquals
import org.junit.Assert.assertEquals
import org.junit.Assert.assertNotEquals
import org.junit.Before
@@ -17,7 +16,6 @@ import org.junit.Test
import org.junit.runner.RunWith
import org.signal.core.util.Base64.decodeBase64OrThrow
import org.signal.core.util.copyTo
import org.signal.core.util.readFully
import org.signal.core.util.stream.NullOutputStream
import org.thoughtcrime.securesms.attachments.Attachment
import org.thoughtcrime.securesms.attachments.AttachmentId
@@ -27,13 +25,10 @@ import org.thoughtcrime.securesms.mms.MediaStream
import org.thoughtcrime.securesms.mms.SentMediaQuality
import org.thoughtcrime.securesms.providers.BlobProvider
import org.thoughtcrime.securesms.util.MediaUtil
import org.thoughtcrime.securesms.util.Util
import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream
import org.whispersystems.signalservice.api.crypto.AttachmentCipherOutputStream
import org.whispersystems.signalservice.api.crypto.NoCipherOutputStream
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentPointer
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentRemoteId
import org.whispersystems.signalservice.internal.crypto.PaddingInputStream
import java.io.ByteArrayOutputStream
import java.io.File
import java.util.Optional
@@ -182,32 +177,6 @@ class AttachmentTableTest {
assertThat(highInfo.file.exists()).isEqualTo(true)
}
@Test
fun finalizeAttachmentAfterDownload_leaveDigestAloneForAllZeroPadding() {
// Insert attachment metadata for properly-padded attachment
val plaintext = byteArrayOf(1, 2, 3, 4)
val key = Util.getSecretBytes(64)
val iv = Util.getSecretBytes(16)
val paddedPlaintext = PaddingInputStream(plaintext.inputStream(), plaintext.size.toLong()).readFully()
val ciphertext = encryptPrePaddedBytes(paddedPlaintext, key, iv)
val digest = getDigest(ciphertext)
val cipherFile = getTempFile()
cipherFile.writeBytes(ciphertext)
val mmsId = -1L
val attachmentId = SignalDatabase.attachments.insertAttachmentsForMessage(mmsId, listOf(createAttachmentPointer(key, digest, plaintext.size)), emptyList()).values.first()
// Give data to attachment table
val cipherInputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintext.size.toLong(), key, digest, null, 4)
SignalDatabase.attachments.finalizeAttachmentAfterDownload(mmsId, attachmentId, cipherInputStream)
// Verify the digest hasn't changed
val newDigest = SignalDatabase.attachments.getAttachment(attachmentId)!!.remoteDigest!!
assertArrayEquals(digest, newDigest)
}
@Test
fun resetArchiveTransferStateByPlaintextHashAndRemoteKey_singleMatch() {
// Given an attachment with some plaintextHash+remoteKey

View File

@@ -110,6 +110,7 @@ import org.thoughtcrime.securesms.util.TextSecurePreferences;
import org.thoughtcrime.securesms.util.Util;
import org.thoughtcrime.securesms.util.VersionTracker;
import org.thoughtcrime.securesms.util.dynamiclanguage.DynamicLanguageContextWrapper;
import org.whispersystems.signalservice.api.backup.MediaName;
import org.whispersystems.signalservice.api.websocket.SignalWebSocket;
import java.io.InterruptedIOException;

View File

@@ -19,6 +19,7 @@ import okio.ByteString
import okio.ByteString.Companion.toByteString
import org.greenrobot.eventbus.EventBus
import org.signal.core.util.Base64
import org.signal.core.util.Base64.decodeBase64OrThrow
import org.signal.core.util.ByteSize
import org.signal.core.util.CursorUtil
import org.signal.core.util.EventTimer
@@ -37,7 +38,7 @@ import org.signal.core.util.getForeignKeyViolations
import org.signal.core.util.logging.Log
import org.signal.core.util.money.FiatMoney
import org.signal.core.util.requireIntOrNull
import org.signal.core.util.requireNonNullBlob
import org.signal.core.util.requireNonNullString
import org.signal.core.util.stream.NonClosingOutputStream
import org.signal.core.util.urlEncode
import org.signal.core.util.withinTransaction
@@ -1290,11 +1291,11 @@ object BackupRepository {
return initBackupAndFetchAuth()
.then { credential ->
SignalNetwork.archive.getMessageBackupUploadForm(SignalStore.account.requireAci(), credential.messageBackupAccess)
.also { Log.i(TAG, "UploadFormResult: $it") }
.also { Log.i(TAG, "UploadFormResult: ${it::class.simpleName}") }
}
.then { form ->
SignalNetwork.archive.getBackupResumableUploadUrl(form)
.also { Log.i(TAG, "ResumableUploadUrlResult: $it") }
.also { Log.i(TAG, "ResumableUploadUrlResult: ${it::class.simpleName}") }
.map { ResumableMessagesBackupUploadSpec(attachmentUploadForm = form, resumableUri = it) }
}
}
@@ -1307,7 +1308,7 @@ object BackupRepository {
): NetworkResult<Unit> {
val (form, resumableUploadUrl) = resumableSpec
return SignalNetwork.archive.uploadBackupFile(form, resumableUploadUrl, backupStream, backupStreamLength, progressListener)
.also { Log.i(TAG, "UploadBackupFileResult: $it") }
.also { Log.i(TAG, "UploadBackupFileResult: ${it::class.simpleName}") }
}
fun downloadBackupFile(destination: File, listener: ProgressListener? = null): NetworkResult<Unit> {
@@ -1395,7 +1396,7 @@ object BackupRepository {
.map { response ->
SignalDatabase.attachments.setArchiveCdn(attachmentId = attachment.attachmentId, archiveCdn = response.cdn)
}
.also { Log.i(TAG, "archiveMediaResult: $it") }
.also { Log.i(TAG, "archiveMediaResult: ${it::class.simpleName}") }
}
fun deleteAbandonedMediaObjects(mediaObjects: Collection<ArchivedMediaObject>): NetworkResult<Unit> {
@@ -1421,7 +1422,7 @@ object BackupRepository {
mediaToDelete = mediaToDelete
)
}
.also { Log.i(TAG, "deleteAbandonedMediaObjectsResult: $it") }
.also { Log.i(TAG, "deleteAbandonedMediaObjectsResult: ${it::class.simpleName}") }
}
fun deleteBackup(): NetworkResult<Unit> {
@@ -1477,7 +1478,7 @@ object BackupRepository {
.map {
SignalDatabase.attachments.clearAllArchiveData()
}
.also { Log.i(TAG, "debugDeleteAllArchivedMediaResult: $it") }
.also { Log.i(TAG, "debugDeleteAllArchivedMediaResult: ${it::class.simpleName}") }
}
/**
@@ -1512,7 +1513,7 @@ object BackupRepository {
credentialStore.cdnReadCredentials = it.result
}
}
.also { Log.i(TAG, "getCdnReadCredentialsResult: $it") }
.also { Log.i(TAG, "getCdnReadCredentialsResult: ${it::class.simpleName}") }
}
fun restoreBackupTier(aci: ACI): MessageBackupTier? {
@@ -1965,8 +1966,8 @@ class ArchiveMediaItemIterator(private val cursor: Cursor) : Iterator<ArchiveMed
override fun hasNext(): Boolean = !cursor.isAfterLast
override fun next(): ArchiveMediaItem {
val plaintextHash = cursor.requireNonNullBlob(AttachmentTable.DATA_HASH_END)
val remoteKey = cursor.requireNonNullBlob(AttachmentTable.REMOTE_KEY)
val plaintextHash = cursor.requireNonNullString(AttachmentTable.DATA_HASH_END).decodeBase64OrThrow()
val remoteKey = cursor.requireNonNullString(AttachmentTable.REMOTE_KEY).decodeBase64OrThrow()
val cdn = cursor.requireIntOrNull(AttachmentTable.ARCHIVE_CDN)
val mediaId = MediaName.fromPlaintextHashAndRemoteKey(plaintextHash, remoteKey).toMediaId(SignalStore.backup.mediaRootBackupKey).encode()

View File

@@ -25,7 +25,6 @@ import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.RadioButton
import androidx.compose.material3.Scaffold
import androidx.compose.material3.SnackbarHostState
import androidx.compose.material3.Surface
@@ -77,6 +76,8 @@ import org.thoughtcrime.securesms.backup.v2.ui.BackupAlertBottomSheet
import org.thoughtcrime.securesms.components.settings.app.internal.backup.InternalBackupPlaygroundViewModel.DialogState
import org.thoughtcrime.securesms.components.settings.app.internal.backup.InternalBackupPlaygroundViewModel.ScreenState
import org.thoughtcrime.securesms.compose.ComposeFragment
import org.thoughtcrime.securesms.dependencies.AppDependencies
import org.thoughtcrime.securesms.jobs.ArchiveAttachmentReconciliationJob
import org.thoughtcrime.securesms.jobs.LocalBackupJob
import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.thoughtcrime.securesms.util.Util
@@ -149,9 +150,9 @@ class InternalBackupPlaygroundFragment : ComposeFragment() {
mainContent = {
Screen(
state = state,
onBackupTierSelected = { tier -> viewModel.onBackupTierSelected(tier) },
onCheckRemoteBackupStateClicked = { viewModel.checkRemoteBackupState() },
onEnqueueRemoteBackupClicked = { viewModel.triggerBackupJob() },
onEnqueueReconciliationClicked = { AppDependencies.jobManager.add(ArchiveAttachmentReconciliationJob(forced = true)) },
onHaltAllBackupJobsClicked = { viewModel.haltAllJobs() },
onValidateBackupClicked = { viewModel.validateBackup() },
onSaveEncryptedBackupToDiskClicked = {
@@ -222,7 +223,7 @@ class InternalBackupPlaygroundFragment : ComposeFragment() {
onDeleteRemoteBackup = {
MaterialAlertDialogBuilder(context)
.setTitle("Are you sure?")
.setMessage("This will delete all of your remote backup data?")
.setMessage("This will delete all of your remote backup data!")
.setPositiveButton("Delete remote data") { _, _ ->
lifecycleScope.launch {
val success = viewModel.deleteRemoteBackupData()
@@ -234,6 +235,21 @@ class InternalBackupPlaygroundFragment : ComposeFragment() {
.setNegativeButton("Cancel", null)
.show()
},
onClearLocalMediaBackupState = {
MaterialAlertDialogBuilder(context)
.setTitle("Are you sure?")
.setMessage("This will cause you to have to re-upload all of your media!")
.setPositiveButton("Clear local media state") { _, _ ->
lifecycleScope.launch {
viewModel.clearLocalMediaBackupState()
withContext(Dispatchers.Main) {
Toast.makeText(requireContext(), "Done!", Toast.LENGTH_SHORT).show()
}
}
}
.setNegativeButton("Cancel", null)
.show()
},
onDisplayInitialBackupFailureSheet = {
BackupRepository.displayInitialBackupFailureNotification()
BackupAlertBottomSheet
@@ -312,8 +328,8 @@ fun Screen(
onImportNewStyleLocalBackupClicked: () -> Unit = {},
onCheckRemoteBackupStateClicked: () -> Unit = {},
onEnqueueRemoteBackupClicked: () -> Unit = {},
onEnqueueReconciliationClicked: () -> Unit = {},
onWipeDataAndRestoreFromRemoteClicked: () -> Unit = {},
onBackupTierSelected: (MessageBackupTier?) -> Unit = {},
onHaltAllBackupJobsClicked: () -> Unit = {},
onSavePlaintextCopyOfRemoteBackupClicked: () -> Unit = {},
onValidateBackupClicked: () -> Unit = {},
@@ -322,6 +338,7 @@ fun Screen(
onImportEncryptedBackupFromDiskClicked: () -> Unit = {},
onImportEncryptedBackupFromDiskDismissed: () -> Unit = {},
onImportEncryptedBackupFromDiskConfirmed: (aci: String, backupKey: String) -> Unit = { _, _ -> },
onClearLocalMediaBackupState: () -> Unit = {},
onDeleteRemoteBackup: () -> Unit = {},
onDisplayInitialBackupFailureSheet: () -> Unit = {}
) {
@@ -353,21 +370,6 @@ fun Screen(
.fillMaxSize()
.verticalScroll(scrollState)
) {
Row(verticalAlignment = Alignment.CenterVertically) {
Text("Tier", fontWeight = FontWeight.Bold)
options.forEach { option ->
Row(verticalAlignment = Alignment.CenterVertically) {
RadioButton(
selected = option.value == state.backupTier,
onClick = { onBackupTierSelected(option.value) }
)
Text(option.key)
}
}
}
Dividers.Default()
Rows.TextRow(
text = {
Text(
@@ -392,6 +394,12 @@ fun Screen(
onClick = onEnqueueRemoteBackupClicked
)
Rows.TextRow(
text = "Enqueue reconciliation job",
label = "Schedules a job that will ensure local and remote media state are in sync.",
onClick = onEnqueueReconciliationClicked
)
Rows.TextRow(
text = "Halt all backup jobs",
label = "Stops all backup-related jobs to the best of our ability.",
@@ -513,6 +521,12 @@ fun Screen(
onClick = onDeleteRemoteBackup
)
Rows.TextRow(
text = "Clear local media backup state",
label = "Resets local state tracking so you think you haven't uploaded any media. The media still exists on the server.",
onClick = onClearLocalMediaBackupState
)
Dividers.Default()
Rows.TextRow(

View File

@@ -292,11 +292,6 @@ class InternalBackupPlaygroundViewModel : ViewModel() {
}
}
fun onBackupTierSelected(backupTier: MessageBackupTier?) {
SignalStore.backup.backupTier = backupTier
_state.value = _state.value.copy(backupTier = backupTier)
}
fun onImportSelected() {
_state.value = _state.value.copy(dialog = DialogState.ImportCredentials)
}
@@ -398,6 +393,10 @@ class InternalBackupPlaygroundViewModel : ViewModel() {
return@withContext false
}
suspend fun clearLocalMediaBackupState() = withContext(Dispatchers.IO) {
SignalDatabase.attachments.clearAllArchiveData()
}
override fun onCleared() {
disposables.clear()
}

View File

@@ -55,7 +55,7 @@ fun InternalBackupStatsTab(stats: InternalBackupPlaygroundViewModel.StatsState,
Spacer(modifier = Modifier.size(16.dp))
Text(text = "Unique/archived data files: ${stats.attachmentStats.attachmentFileCount}/${stats.attachmentStats.finishedAttachmentFileCount}")
Text(text = "Unique/archived verified digest count: ${stats.attachmentStats.attachmentPlaintextHashAndKeyCount}/${stats.attachmentStats.finishedAttachmentPlaintextHashAndKeyCount}")
Text(text = "Unique/archived verified plaintextHash count: ${stats.attachmentStats.attachmentPlaintextHashAndKeyCount}/${stats.attachmentStats.finishedAttachmentPlaintextHashAndKeyCount}")
Text(text = "Unique/expected thumbnail files: ${stats.attachmentStats.thumbnailFileCount}/${stats.attachmentStats.estimatedThumbnailCount}")
Text(text = "Local Total: ${stats.attachmentStats.attachmentFileCount + stats.attachmentStats.thumbnailFileCount}")
Text(text = "Expected remote total: ${stats.attachmentStats.estimatedThumbnailCount + stats.attachmentStats.finishedAttachmentPlaintextHashAndKeyCount}")

View File

@@ -679,15 +679,15 @@ class AttachmentTable(
/**
* Sets the archive transfer state for the given attachment by digest.
*/
fun resetArchiveTransferStateByPlaintextHashAndRemoteKey(plaintextHash: ByteArray, remoteKey: ByteArray) {
writableDatabase
fun resetArchiveTransferStateByPlaintextHashAndRemoteKey(plaintextHash: ByteArray, remoteKey: ByteArray): Boolean {
return writableDatabase
.update(TABLE_NAME)
.values(
ARCHIVE_TRANSFER_STATE to ArchiveTransferState.NONE.value,
ARCHIVE_CDN to null
)
.where("$DATA_HASH_END = ? AND $REMOTE_KEY = ?", Base64.encodeWithPadding(plaintextHash), Base64.encodeWithPadding(remoteKey))
.run()
.run() > 0
}
/**
@@ -2593,11 +2593,11 @@ class AttachmentTable(
.associate { it to readableDatabase.count().from(TABLE_NAME).where("$ARCHIVE_TRANSFER_STATE = ${it.value} AND $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL").run().readToSingleLong(-1L) }
val attachmentFileCount = readableDatabase.query("SELECT COUNT(DISTINCT $DATA_FILE) FROM $TABLE_NAME WHERE $DATA_FILE NOT NULL AND $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL").readToSingleLong(-1L)
val finishedAttachmentFileCount = readableDatabase.query("SELECT COUNT(DISTINCT $DATA_FILE) FROM $TABLE_NAME WHERE $DATA_FILE NOT NULL AND $DATA_HASH_END NOT NULL $REMOTE_KEY NOT NULL AND $ARCHIVE_TRANSFER_STATE = ${ArchiveTransferState.FINISHED.value}").readToSingleLong(-1L)
val finishedAttachmentFileCount = readableDatabase.query("SELECT COUNT(DISTINCT $DATA_FILE) FROM $TABLE_NAME WHERE $DATA_FILE NOT NULL AND $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL AND $ARCHIVE_TRANSFER_STATE = ${ArchiveTransferState.FINISHED.value}").readToSingleLong(-1L)
val attachmentPlaintextHashAndKeyCount = readableDatabase.query("SELECT COUNT(*) FROM (SELECT DISTINCT $DATA_HASH_END, $REMOTE_KEY FROM $TABLE_NAME WHERE $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL AND $TRANSFER_STATE in ($TRANSFER_PROGRESS_DONE, $TRANSFER_RESTORE_OFFLOADED, $TRANSFER_RESTORE_IN_PROGRESS, $TRANSFER_NEEDS_RESTORE))").readToSingleLong(-1L)
val finishedAttachmentDigestCount = readableDatabase.query("SELECT COUNT(*) FROM (SELECT DISTINCT $DATA_HASH_END, $REMOTE_KEY) FROM $TABLE_NAME WHERE $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL AND $ARCHIVE_TRANSFER_STATE = ${ArchiveTransferState.FINISHED.value})").readToSingleLong(-1L)
val finishedAttachmentDigestCount = readableDatabase.query("SELECT COUNT(*) FROM (SELECT DISTINCT $DATA_HASH_END, $REMOTE_KEY FROM $TABLE_NAME WHERE $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL AND $ARCHIVE_TRANSFER_STATE = ${ArchiveTransferState.FINISHED.value})").readToSingleLong(-1L)
val thumbnailFileCount = readableDatabase.query("SELECT COUNT(DISTINCT $THUMBNAIL_FILE) FROM $TABLE_NAME WHERE $THUMBNAIL_FILE IS NOT NULL").readToSingleLong(-1L)
val estimatedThumbnailCount = readableDatabase.query("SELECT COUNT(*) FROM (SELECT DISTINCT $DATA_HASH_END, $REMOTE_KEY) FROM $TABLE_NAME WHERE $ARCHIVE_TRANSFER_STATE = ${ArchiveTransferState.FINISHED.value} AND $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL AND ($CONTENT_TYPE LIKE 'image/%' OR $CONTENT_TYPE LIKE 'video/%'))").readToSingleLong(-1L)
val estimatedThumbnailCount = readableDatabase.query("SELECT COUNT(*) FROM (SELECT DISTINCT $DATA_HASH_END, $REMOTE_KEY FROM $TABLE_NAME WHERE $ARCHIVE_TRANSFER_STATE = ${ArchiveTransferState.FINISHED.value} AND $DATA_HASH_END NOT NULL AND $REMOTE_KEY NOT NULL AND ($CONTENT_TYPE LIKE 'image/%' OR $CONTENT_TYPE LIKE 'video/%'))").readToSingleLong(-1L)
val pendingUploadBytes = getPendingArchiveUploadBytes()
val uploadedAttachmentBytes = readableDatabase

View File

@@ -127,7 +127,13 @@ class ArchiveAttachmentReconciliationJob private constructor(
val entry = BackupMediaSnapshotTable.MediaEntry.fromCursor(it)
// TODO [backup] Re-enqueue thumbnail uploads if necessary
if (!entry.isThumbnail) {
SignalDatabase.attachments.resetArchiveTransferStateByPlaintextHashAndRemoteKey(entry.plaintextHash, entry.remoteKey)
val success = SignalDatabase.attachments.resetArchiveTransferStateByPlaintextHashAndRemoteKey(entry.plaintextHash, entry.remoteKey)
if (!success) {
Log.e(TAG, "Failed to reset archive transfer state by remote hash/key!")
if (RemoteConfig.internalUser) {
throw RuntimeException("Failed to reset archive transfer state by remote hash/key!")
}
}
}
}

View File

@@ -85,8 +85,8 @@ class ArchiveThumbnailUploadJob private constructor(
return Result.success()
}
if (attachment.remoteDigest == null) {
Log.w(TAG, "$attachmentId was never uploaded! Cannot proceed.")
if (attachment.remoteDigest == null && attachment.dataHash == null) {
Log.w(TAG, "$attachmentId has no integrity check! Cannot proceed.")
return Result.success()
}
@@ -153,7 +153,7 @@ class ArchiveThumbnailUploadJob private constructor(
data = thumbnailResult.data
)
Log.d(TAG, "Successfully archived thumbnail for $attachmentId mediaName=${attachment.requireThumbnailMediaName()}")
Log.d(TAG, "Successfully archived thumbnail for $attachmentId")
Result.success()
}

View File

@@ -35,6 +35,7 @@ import org.thoughtcrime.securesms.transport.RetryLaterException
import org.thoughtcrime.securesms.util.AttachmentUtil
import org.thoughtcrime.securesms.util.RemoteConfig
import org.thoughtcrime.securesms.util.Util
import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream.IntegrityCheck
import org.whispersystems.signalservice.api.messages.AttachmentTransferProgress
import org.whispersystems.signalservice.api.messages.SignalServiceAttachment
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentPointer
@@ -277,12 +278,18 @@ class AttachmentDownloadJob private constructor(
}
}
if (attachment.remoteDigest == null && attachment.dataHash == null) {
Log.w(TAG, "Attachment has no integrity check!")
throw InvalidAttachmentException("Attachment has no integrity check!")
}
val decryptingStream = AppDependencies
.signalServiceMessageReceiver
.retrieveAttachment(
pointer,
attachmentFile,
maxReceiveSize,
IntegrityCheck.forEncryptedDigestAndPlaintextHash(attachment.remoteDigest, attachment.dataHash),
progressListener
)

View File

@@ -16,6 +16,8 @@ import org.thoughtcrime.securesms.jobmanager.impl.NetworkConstraint;
import org.thoughtcrime.securesms.profiles.AvatarHelper;
import org.signal.core.util.Hex;
import org.whispersystems.signalservice.api.SignalServiceMessageReceiver;
import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream;
import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream.IntegrityCheck;
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentPointer;
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentRemoteId;
import org.whispersystems.signalservice.api.push.exceptions.MissingConfigurationException;
@@ -84,9 +86,16 @@ public final class AvatarGroupsV1DownloadJob extends BaseJob {
attachment = File.createTempFile("avatar", "tmp", context.getCacheDir());
attachment.deleteOnExit();
SignalServiceMessageReceiver receiver = AppDependencies.getSignalServiceMessageReceiver();
SignalServiceAttachmentPointer pointer = new SignalServiceAttachmentPointer(0, new SignalServiceAttachmentRemoteId.V2(avatarId), contentType, key, Optional.of(0), Optional.empty(), 0, 0, digest, Optional.empty(), 0, fileName, false, false, false, Optional.empty(), Optional.empty(), System.currentTimeMillis(), null);
InputStream inputStream = receiver.retrieveAttachment(pointer, attachment, AvatarHelper.AVATAR_DOWNLOAD_FAILSAFE_MAX_SIZE);
SignalServiceMessageReceiver receiver = AppDependencies.getSignalServiceMessageReceiver();
SignalServiceAttachmentPointer pointer = new SignalServiceAttachmentPointer(0, new SignalServiceAttachmentRemoteId.V2(avatarId), contentType, key, Optional.of(0), Optional.empty(), 0, 0, digest, Optional.empty(), 0, fileName, false, false, false, Optional.empty(), Optional.empty(), System.currentTimeMillis(), null);
if (pointer.getDigest().isEmpty()) {
throw new InvalidMessageException("Missing digest!");
}
IntegrityCheck integrityCheck = IntegrityCheck.forEncryptedDigest(pointer.getDigest().get());
InputStream inputStream = receiver.retrieveAttachment(pointer, attachment, AvatarHelper.AVATAR_DOWNLOAD_FAILSAFE_MAX_SIZE, integrityCheck);
AvatarHelper.setAvatar(context, record.get().getRecipientId(), inputStream);
SignalDatabase.groups().onAvatarUpdated(groupId, true);

View File

@@ -12,6 +12,7 @@ import org.thoughtcrime.securesms.net.NotPushRegisteredException
import org.thoughtcrime.securesms.profiles.AvatarHelper
import org.thoughtcrime.securesms.providers.BlobProvider
import org.thoughtcrime.securesms.recipients.Recipient
import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream.IntegrityCheck
import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentPointer
import org.whispersystems.signalservice.api.messages.multidevice.DeviceContact
import org.whispersystems.signalservice.api.messages.multidevice.DeviceContactsInputStream
@@ -59,7 +60,7 @@ class MultiDeviceContactSyncJob(parameters: Parameters, private val attachmentPo
try {
val contactsFile: File = BlobProvider.getInstance().forNonAutoEncryptingSingleSessionOnDisk(context)
AppDependencies.signalServiceMessageReceiver
.retrieveAttachment(contactAttachment, contactsFile, MAX_ATTACHMENT_SIZE)
.retrieveAttachment(contactAttachment, contactsFile, MAX_ATTACHMENT_SIZE, IntegrityCheck.forEncryptedDigest(contactAttachment.digest.get()))
.use(this::processContactFile)
} catch (e: MissingConfigurationException) {
throw IOException(e)

View File

@@ -40,6 +40,7 @@ import org.thoughtcrime.securesms.notifications.NotificationChannels
import org.thoughtcrime.securesms.notifications.NotificationIds
import org.thoughtcrime.securesms.transport.RetryLaterException
import org.thoughtcrime.securesms.util.RemoteConfig
import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream.IntegrityCheck
import org.whispersystems.signalservice.api.messages.AttachmentTransferProgress
import org.whispersystems.signalservice.api.messages.SignalServiceAttachment
import org.whispersystems.signalservice.api.push.exceptions.MissingConfigurationException
@@ -49,6 +50,7 @@ import org.whispersystems.signalservice.api.push.exceptions.RangeException
import java.io.File
import java.io.IOException
import java.util.concurrent.TimeUnit
import kotlin.jvm.optionals.getOrNull
import kotlin.math.max
import kotlin.math.pow
import kotlin.time.Duration.Companion.days
@@ -172,12 +174,12 @@ class RestoreAttachmentJob private constructor(
val attachment = SignalDatabase.attachments.getAttachment(attachmentId)
if (attachment == null) {
Log.w(TAG, "attachment no longer exists.")
Log.w(TAG, "[$attachmentId] Attachment no longer exists.")
return
}
if (attachment.isPermanentlyFailed) {
Log.w(TAG, "Attachment was marked as a permanent failure. Refusing to download.")
Log.w(TAG, "[$attachmentId] Attachment was marked as a permanent failure. Refusing to download.")
return
}
@@ -186,7 +188,7 @@ class RestoreAttachmentJob private constructor(
attachment.transferState != AttachmentTable.TRANSFER_PROGRESS_FAILED &&
attachment.transferState != AttachmentTable.TRANSFER_RESTORE_OFFLOADED
) {
Log.w(TAG, "Attachment does not need to be restored. Current state: ${attachment.transferState}")
Log.w(TAG, "[$attachmentId] Attachment does not need to be restored. Current state: ${attachment.transferState}")
return
}
@@ -231,14 +233,20 @@ class RestoreAttachmentJob private constructor(
var archiveFile: File? = null
var useArchiveCdn = false
if (attachment.remoteDigest == null && attachment.dataHash == null) {
Log.w(TAG, "[$attachmentId] Attachment has no integrity check! Cannot proceed.")
markPermanentlyFailed(attachmentId)
return
}
try {
if (attachment.size > maxReceiveSize) {
throw MmsException("Attachment too large, failing download")
throw MmsException("[$attachmentId] Attachment too large, failing download")
}
useArchiveCdn = if (SignalStore.backup.backsUpMedia && !forceTransitTier) {
if (attachment.archiveTransferState != AttachmentTable.ArchiveTransferState.FINISHED) {
throw InvalidAttachmentException("Invalid attachment configuration! backsUpMedia: ${SignalStore.backup.backsUpMedia}, forceTransitTier: $forceTransitTier, archiveTransferState: ${attachment.archiveTransferState}")
throw InvalidAttachmentException("[$attachmentId] Invalid attachment configuration! backsUpMedia: ${SignalStore.backup.backsUpMedia}, forceTransitTier: $forceTransitTier, archiveTransferState: ${attachment.archiveTransferState}")
}
true
} else {
@@ -259,7 +267,9 @@ class RestoreAttachmentJob private constructor(
}
val decryptingStream = if (useArchiveCdn) {
archiveFile = SignalDatabase.attachments.getOrCreateArchiveTransferFile(attachmentId)
// TODO next PR: remove archive transfer file and just use the regular attachment file
archiveFile = attachmentFile
// archiveFile = SignalDatabase.attachments.getOrCreateArchiveTransferFile(attachmentId)
val cdnCredentials = BackupRepository.getCdnReadCredentials(BackupRepository.CredentialType.MEDIA, attachment.archiveCdn ?: RemoteConfig.backupFallbackArchiveCdn).successOrThrow().headers
messageReceiver
@@ -269,7 +279,6 @@ class RestoreAttachmentJob private constructor(
cdnCredentials,
archiveFile,
pointer,
attachmentFile,
maxReceiveSize,
progressListener
)
@@ -279,6 +288,7 @@ class RestoreAttachmentJob private constructor(
pointer,
attachmentFile,
maxReceiveSize,
IntegrityCheck.forEncryptedDigestAndPlaintextHash(pointer.digest.getOrNull(), attachment.dataHash),
progressListener
)
}
@@ -286,7 +296,7 @@ class RestoreAttachmentJob private constructor(
SignalDatabase.attachments.finalizeAttachmentAfterDownload(messageId, attachmentId, decryptingStream, if (manual) System.currentTimeMillis().milliseconds else null)
} catch (e: RangeException) {
val transferFile = archiveFile ?: attachmentFile
Log.w(TAG, "Range exception, file size " + transferFile.length(), e)
Log.w(TAG, "[$attachmentId] Range exception, file size " + transferFile.length(), e)
if (transferFile.delete()) {
Log.i(TAG, "Deleted temp download file to recover")
throw RetryLaterException(e)
@@ -299,7 +309,7 @@ class RestoreAttachmentJob private constructor(
} catch (e: NonSuccessfulResponseCodeException) {
if (SignalStore.backup.backsUpMedia) {
if (e.code == 404 && !forceTransitTier && attachment.remoteLocation?.isNotBlank() == true) {
Log.i(TAG, "Failed to download attachment from archive! Should only happen for recent attachments in a narrow window. Retrying download from transit CDN.")
Log.i(TAG, "[$attachmentId] Failed to download attachment from archive! Should only happen for recent attachments in a narrow window. Retrying download from transit CDN.")
if (RemoteConfig.internalUser) {
postFailedToDownloadFromArchiveNotification()
}
@@ -316,18 +326,18 @@ class RestoreAttachmentJob private constructor(
}
}
Log.w(TAG, "Experienced exception while trying to download an attachment.", e)
Log.w(TAG, "[$attachmentId] Experienced exception while trying to download an attachment.", e)
markFailed(attachmentId)
} catch (e: MmsException) {
Log.w(TAG, "Experienced exception while trying to download an attachment.", e)
Log.w(TAG, "[$attachmentId] Experienced exception while trying to download an attachment.", e)
markFailed(attachmentId)
} catch (e: MissingConfigurationException) {
Log.w(TAG, "Experienced exception while trying to download an attachment.", e)
Log.w(TAG, "[$attachmentId] Experienced exception while trying to download an attachment.", e)
markFailed(attachmentId)
} catch (e: InvalidMessageException) {
Log.w(TAG, "Experienced an InvalidMessageException while trying to download an attachment.", e)
Log.w(TAG, "[$attachmentId] Experienced an InvalidMessageException while trying to download an attachment.", e)
if (e.cause is InvalidMacException) {
Log.w(TAG, "Detected an invalid mac. Treating as a permanent failure.")
Log.w(TAG, "[$attachmentId] Detected an invalid mac. Treating as a permanent failure.")
markPermanentlyFailed(attachmentId)
} else {
markFailed(attachmentId)

View File

@@ -120,7 +120,6 @@ class RestoreAttachmentThumbnailJob private constructor(
val maxThumbnailSize: Long = RemoteConfig.maxAttachmentReceiveSizeBytes
val thumbnailTransferFile: File = SignalDatabase.attachments.createArchiveThumbnailTransferFile()
val thumbnailFile: File = SignalDatabase.attachments.createArchiveThumbnailTransferFile()
val progressListener = object : SignalServiceAttachment.ProgressListener {
override fun onAttachmentProgress(progress: AttachmentTransferProgress) = Unit
@@ -137,7 +136,6 @@ class RestoreAttachmentThumbnailJob private constructor(
cdnCredentials,
thumbnailTransferFile,
pointer,
thumbnailFile,
maxThumbnailSize,
progressListener
)

View File

@@ -23,6 +23,7 @@ import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.thoughtcrime.securesms.mms.MmsException
import org.whispersystems.signalservice.api.backup.MediaName
import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream
import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream.IntegrityCheck
import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream.StreamSupplier
import java.io.IOException
@@ -154,7 +155,10 @@ class RestoreLocalAttachmentJob private constructor(
streamLength = size,
plaintextLength = attachment.size,
combinedKeyMaterial = combinedKey,
digest = attachment.remoteDigest,
integrityCheck = IntegrityCheck.forEncryptedDigestAndPlaintextHash(
encryptedDigest = attachment.remoteDigest,
plaintextHash = attachment.dataHash
),
incrementalDigest = null,
incrementalMacChunkSize = 0
).use { input ->

View File

@@ -23,6 +23,7 @@ import org.signal.core.util.Base64;
import org.whispersystems.signalservice.api.backup.MediaName;
import org.whispersystems.signalservice.api.backup.MediaRootBackupKey;
import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream;
import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream.IntegrityCheck;
import org.whispersystems.signalservice.api.crypto.AttachmentCipherStreamUtil;
import org.signal.core.util.stream.TailerInputStream;
import org.whispersystems.signalservice.internal.crypto.PaddingInputStream;
@@ -100,11 +101,13 @@ class PartDataSource implements DataSource {
long streamLength = AttachmentCipherStreamUtil.getCiphertextLength(PaddingInputStream.getPaddedSize(attachment.size));
AttachmentCipherInputStream.StreamSupplier streamSupplier = () -> new TailerInputStream(() -> new FileInputStream(transferFile), streamLength);
if (attachment.remoteDigest == null) {
if (attachment.remoteDigest == null && attachment.dataHash == null) {
throw new InvalidMessageException("Missing digest!");
}
this.inputStream = AttachmentCipherInputStream.createForAttachment(streamSupplier, streamLength, attachment.size, decodedKey, attachment.remoteDigest, attachment.getIncrementalDigest(), attachment.incrementalMacChunkSize);
IntegrityCheck integrityCheck = IntegrityCheck.forEncryptedDigestAndPlaintextHash(attachment.remoteDigest, attachment.dataHash);
this.inputStream = AttachmentCipherInputStream.createForAttachment(streamSupplier, streamLength, attachment.size, decodedKey, integrityCheck, attachment.getIncrementalDigest(), attachment.incrementalMacChunkSize);
} catch (InvalidMessageException e) {
throw new IOException("Error decrypting attachment stream!", e);
}

View File

@@ -7,14 +7,16 @@ import io.mockk.mockk
import io.mockk.mockkObject
import org.junit.Before
import org.junit.Test
import org.signal.core.util.Base64
import org.thoughtcrime.securesms.MockCursor
import org.thoughtcrime.securesms.keyvalue.BackupValues
import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.thoughtcrime.securesms.util.Util
import org.whispersystems.signalservice.api.backup.MediaRootBackupKey
class ArchivedMediaObjectIteratorTest {
private val cursor = mockk<MockCursor>(relaxed = true) {
every { getString(any()) } returns "A"
every { getString(any()) } returns Base64.encodeWithPadding(Util.getSecretBytes(32))
every { moveToPosition(any()) } answers { callOriginal() }
every { moveToNext() } answers { callOriginal() }
every { position } answers { callOriginal() }

View File

@@ -7,10 +7,12 @@
package org.whispersystems.signalservice.api;
import org.signal.core.util.StreamUtil;
import org.signal.core.util.logging.Log;
import org.signal.libsignal.protocol.InvalidMessageException;
import org.signal.libsignal.zkgroup.profiles.ProfileKey;
import org.whispersystems.signalservice.api.backup.MediaRootBackupKey;
import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream;
import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream.IntegrityCheck;
import org.whispersystems.signalservice.api.crypto.AttachmentCipherStreamUtil;
import org.whispersystems.signalservice.api.crypto.ProfileCipherInputStream;
import org.whispersystems.signalservice.api.messages.SignalServiceAttachment.ProgressListener;
@@ -63,9 +65,9 @@ public class SignalServiceMessageReceiver {
* @throws IOException
* @throws InvalidMessageException
*/
public InputStream retrieveAttachment(SignalServiceAttachmentPointer pointer, File destination, long maxSizeBytes)
public InputStream retrieveAttachment(SignalServiceAttachmentPointer pointer, File destination, long maxSizeBytes, IntegrityCheck integrityCheck)
throws IOException, InvalidMessageException, MissingConfigurationException {
return retrieveAttachment(pointer, destination, maxSizeBytes, null);
return retrieveAttachment(pointer, destination, maxSizeBytes, integrityCheck, null);
}
public InputStream retrieveProfileAvatar(String path, File destination, ProfileKey profileKey, long maxSizeBytes)
@@ -96,9 +98,9 @@ public class SignalServiceMessageReceiver {
* @throws IOException
* @throws InvalidMessageException
*/
public InputStream retrieveAttachment(SignalServiceAttachmentPointer pointer, File destination, long maxSizeBytes, ProgressListener listener)
public InputStream retrieveAttachment(SignalServiceAttachmentPointer pointer, File destination, long maxSizeBytes, IntegrityCheck integrityCheck, ProgressListener listener)
throws IOException, InvalidMessageException, MissingConfigurationException {
if (pointer.getDigest().isEmpty()) throw new InvalidMessageException("No attachment digest!");
if (integrityCheck == null) throw new InvalidMessageException("No integrity check!");
if (pointer.getKey() == null) throw new InvalidMessageException("No key!");
socket.retrieveAttachment(pointer.getCdnNumber(), Collections.emptyMap(), pointer.getRemoteId(), destination, maxSizeBytes, listener);
@@ -112,7 +114,7 @@ public class SignalServiceMessageReceiver {
destination,
pointer.getSize().orElse(0),
pointer.getKey(),
pointer.getDigest().get(),
integrityCheck,
null,
0
);
@@ -126,7 +128,6 @@ public class SignalServiceMessageReceiver {
* @param readCredentialHeaders Headers to pass to the backup CDN to authorize the download
* @param archiveDestination The download destination for archived attachment. If this file exists, download will resume.
* @param pointer The {@link SignalServiceAttachmentPointer} received in a {@link SignalServiceDataMessage}.
* @param attachmentDestination The download destination for this attachment. If this file exists, it is assumed that this is previously-downloaded content that can be resumed.
* @param listener An optional listener (may be null) to receive callbacks on download progress.
*
* @return An InputStream that streams the plaintext attachment contents.
@@ -136,7 +137,6 @@ public class SignalServiceMessageReceiver {
@Nonnull Map<String, String> readCredentialHeaders,
@Nonnull File archiveDestination,
@Nonnull SignalServiceAttachmentPointer pointer,
@Nonnull File attachmentDestination,
long maxSizeBytes,
@Nullable ProgressListener listener)
throws IOException, InvalidMessageException, MissingConfigurationException
@@ -154,7 +154,7 @@ public class SignalServiceMessageReceiver {
return AttachmentCipherInputStream.createForArchivedMedia(
archivedMediaKeyMaterial,
attachmentDestination,
archiveDestination,
originalCipherLength,
pointer.getSize().orElse(0),
pointer.getKey(),
@@ -171,7 +171,6 @@ public class SignalServiceMessageReceiver {
* @param readCredentialHeaders Headers to pass to the backup CDN to authorize the download
* @param archiveDestination The download destination for archived attachment. If this file exists, download will resume.
* @param pointer The {@link SignalServiceAttachmentPointer} received in a {@link SignalServiceDataMessage}.
* @param attachmentDestination The download destination for this attachment. If this file exists, it is assumed that this is previously-downloaded content that can be resumed.
* @param listener An optional listener (may be null) to receive callbacks on download progress.
*
* @return An InputStream that streams the plaintext attachment contents.
@@ -180,7 +179,6 @@ public class SignalServiceMessageReceiver {
@Nonnull Map<String, String> readCredentialHeaders,
@Nonnull File archiveDestination,
@Nonnull SignalServiceAttachmentPointer pointer,
@Nonnull File attachmentDestination,
long maxSizeBytes,
@Nullable ProgressListener listener)
throws IOException, InvalidMessageException, MissingConfigurationException
@@ -198,7 +196,7 @@ public class SignalServiceMessageReceiver {
return AttachmentCipherInputStream.createForArchivedThumbnail(
archivedMediaKeyMaterial,
attachmentDestination,
archiveDestination,
originalCipherLength,
pointer.getSize().orElse(0),
pointer.getKey()

View File

@@ -16,7 +16,7 @@ class ArchiveGetMediaItemsResponse(
@JsonProperty val mediaDir: String?,
@JsonProperty val cursor: String?
) {
class StoredMediaObject(
data class StoredMediaObject(
@JsonProperty val cdn: Int,
@JsonProperty val mediaId: String,
@JsonProperty val objectLength: Long

View File

@@ -13,7 +13,7 @@ import com.fasterxml.jackson.annotation.JsonProperty
class DeleteArchivedMediaRequest(
@JsonProperty val mediaToDelete: List<ArchivedMediaObject>
) {
class ArchivedMediaObject(
data class ArchivedMediaObject(
@JsonProperty val cdn: Int,
@JsonProperty val mediaId: String
)

View File

@@ -31,10 +31,6 @@ value class MediaName(val name: String) {
return mediaRootBackupKey.deriveMediaId(this)
}
fun toByteArray(): ByteArray {
return name.toByteArray()
}
override fun toString(): String {
return name
}

View File

@@ -5,6 +5,7 @@
*/
package org.whispersystems.signalservice.api.crypto
import org.signal.core.util.Base64
import org.signal.core.util.readNBytesOrThrow
import org.signal.core.util.stream.LimitedInputStream
import org.signal.libsignal.protocol.InvalidMessageException
@@ -51,7 +52,7 @@ object AttachmentCipherInputStream {
file: File,
plaintextLength: Long,
combinedKeyMaterial: ByteArray,
digest: ByteArray,
integrityCheck: IntegrityCheck,
incrementalDigest: ByteArray?,
incrementalMacChunkSize: Int
): InputStream {
@@ -60,11 +61,9 @@ object AttachmentCipherInputStream {
streamLength = file.length(),
plaintextLength = plaintextLength,
combinedKeyMaterial = combinedKeyMaterial,
encryptedDigest = digest,
plaintextHash = null,
integrityCheck = integrityCheck,
incrementalDigest = incrementalDigest,
incrementalMacChunkSize = incrementalMacChunkSize,
ignoreDigest = false
incrementalMacChunkSize = incrementalMacChunkSize
)
}
@@ -81,7 +80,7 @@ object AttachmentCipherInputStream {
streamLength: Long,
plaintextLength: Long,
combinedKeyMaterial: ByteArray,
digest: ByteArray,
integrityCheck: IntegrityCheck,
incrementalDigest: ByteArray?,
incrementalMacChunkSize: Int
): InputStream {
@@ -90,11 +89,9 @@ object AttachmentCipherInputStream {
streamLength = streamLength,
plaintextLength = plaintextLength,
combinedKeyMaterial = combinedKeyMaterial,
encryptedDigest = digest,
plaintextHash = null,
integrityCheck = integrityCheck,
incrementalDigest = incrementalDigest,
incrementalMacChunkSize = incrementalMacChunkSize,
ignoreDigest = false
incrementalMacChunkSize = incrementalMacChunkSize
)
}
@@ -130,11 +127,9 @@ object AttachmentCipherInputStream {
streamLength = originalCipherTextLength,
plaintextLength = plaintextLength,
combinedKeyMaterial = combinedKeyMaterial,
encryptedDigest = null,
plaintextHash = plaintextHash,
integrityCheck = IntegrityCheck(plaintextHash = plaintextHash, encryptedDigest = null),
incrementalDigest = incrementalDigest,
incrementalMacChunkSize = incrementalMacChunkSize,
ignoreDigest = true
incrementalMacChunkSize = incrementalMacChunkSize
)
}
@@ -159,7 +154,7 @@ object AttachmentCipherInputStream {
val mac = initMac(keyMaterial.macKey)
if (originalCipherTextLength <= BLOCK_SIZE + mac.macLength) {
throw InvalidMessageException("Message shorter than crypto overhead!")
throw InvalidMessageException("Message shorter than crypto overhead! Expected at least ${BLOCK_SIZE + mac.macLength} bytes, got $originalCipherTextLength")
}
return create(
@@ -167,11 +162,9 @@ object AttachmentCipherInputStream {
streamLength = originalCipherTextLength,
plaintextLength = plaintextLength,
combinedKeyMaterial = combinedKeyMaterial,
encryptedDigest = null,
plaintextHash = null,
integrityCheck = null,
incrementalDigest = null,
incrementalMacChunkSize = 0,
ignoreDigest = true
incrementalMacChunkSize = 0
)
}
@@ -189,7 +182,7 @@ object AttachmentCipherInputStream {
}
ByteArrayInputStream(data).use { inputStream ->
verifyMac(inputStream, data.size.toLong(), mac, null)
verifyMacAndMaybeEncryptedDigest(inputStream, data.size.toLong(), mac, null)
}
val encryptedStream = ByteArrayInputStream(data)
@@ -211,11 +204,11 @@ object AttachmentCipherInputStream {
val mac = initMac(archivedMediaKeyMaterial.macKey)
if (file.length() <= BLOCK_SIZE + mac.macLength) {
throw InvalidMessageException("Message shorter than crypto overhead!")
throw InvalidMessageException("Message shorter than crypto overhead! Expected at least ${BLOCK_SIZE + mac.macLength} bytes, got ${file.length()}")
}
FileInputStream(file).use { macVerificationStream ->
verifyMac(macVerificationStream, file.length(), mac, null)
verifyMacAndMaybeEncryptedDigest(macVerificationStream, file.length(), mac, null)
}
val encryptedStream = FileInputStream(file)
@@ -226,6 +219,10 @@ object AttachmentCipherInputStream {
return LimitedInputStream(inputStream, originalCipherTextLength)
}
/**
* @param integrityCheck If null, no integrity check is performed! This is a private method, so it's assumed that care has been taken to ensure that this is
* the correct course of action. Public methods should properly enforce when integrity checks are required.
*/
@JvmStatic
@Throws(InvalidMessageException::class, IOException::class)
private fun create(
@@ -233,11 +230,9 @@ object AttachmentCipherInputStream {
streamLength: Long,
plaintextLength: Long,
combinedKeyMaterial: ByteArray,
encryptedDigest: ByteArray?,
plaintextHash: ByteArray?,
integrityCheck: IntegrityCheck?,
incrementalDigest: ByteArray?,
incrementalMacChunkSize: Int,
ignoreDigest: Boolean
incrementalMacChunkSize: Int
): InputStream {
val keyMaterial = CombinedKeyMaterial.from(combinedKeyMaterial)
val mac = initMac(keyMaterial.macKey)
@@ -246,25 +241,16 @@ object AttachmentCipherInputStream {
throw InvalidMessageException("Message shorter than crypto overhead! length: $streamLength")
}
if (!ignoreDigest && encryptedDigest == null) {
throw InvalidMessageException("Missing digest!")
}
val wrappedStream: InputStream
val hasIncrementalMac = incrementalDigest != null && incrementalDigest.isNotEmpty() && incrementalMacChunkSize > 0
if (!hasIncrementalMac) {
streamSupplier.openStream().use { macVerificationStream ->
verifyMac(macVerificationStream, streamLength, mac, encryptedDigest)
}
wrappedStream = streamSupplier.openStream()
} else {
if (encryptedDigest == null && plaintextHash == null) {
throw InvalidMessageException("Missing data (digest or plaintextHas) for incremental mac validation!")
if (hasIncrementalMac) {
if (integrityCheck == null) {
throw InvalidMessageException("Missing integrityCheck for incremental mac validation!")
}
val digestValidatingStream = if (encryptedDigest != null) {
DigestValidatingInputStream(streamSupplier.openStream(), sha256Digest(), encryptedDigest)
val digestValidatingStream = if (integrityCheck.encryptedDigest != null) {
DigestValidatingInputStream(streamSupplier.openStream(), sha256Digest(), integrityCheck.encryptedDigest)
} else {
streamSupplier.openStream()
}
@@ -279,6 +265,11 @@ object AttachmentCipherInputStream {
ChunkSizeChoice.everyNthByte(incrementalMacChunkSize),
incrementalDigest
)
} else {
streamSupplier.openStream().use { macVerificationStream ->
verifyMacAndMaybeEncryptedDigest(macVerificationStream, streamLength, mac, integrityCheck?.encryptedDigest)
}
wrappedStream = streamSupplier.openStream()
}
val encryptedStreamExcludingMac = LimitedInputStream(wrappedStream, streamLength - mac.macLength)
@@ -286,12 +277,12 @@ object AttachmentCipherInputStream {
val decryptingStream: InputStream = BetterCipherInputStream(encryptedStreamExcludingMac, cipher)
val paddinglessDecryptingStream = LimitedInputStream(decryptingStream, plaintextLength)
return if (plaintextHash != null) {
if (plaintextHash.size != MessageDigest.getInstance("SHA-256").digestLength) {
throw InvalidMessageException("Invalid plaintext hash size: ${plaintextHash.size}")
return if (integrityCheck?.plaintextHash != null) {
if (integrityCheck.plaintextHash.size != MessageDigest.getInstance("SHA-256").digestLength) {
throw InvalidMessageException("Invalid plaintext hash size: ${integrityCheck.plaintextHash.size}")
}
DigestValidatingInputStream(paddinglessDecryptingStream, sha256Digest(), plaintextHash)
DigestValidatingInputStream(paddinglessDecryptingStream, sha256Digest(), integrityCheck.plaintextHash)
} else {
paddinglessDecryptingStream
}
@@ -326,7 +317,7 @@ object AttachmentCipherInputStream {
}
@Throws(InvalidMessageException::class)
private fun verifyMac(@Nonnull inputStream: InputStream, length: Long, @Nonnull mac: Mac, theirDigest: ByteArray?) {
private fun verifyMacAndMaybeEncryptedDigest(@Nonnull inputStream: InputStream, length: Long, @Nonnull mac: Mac, theirDigest: ByteArray?) {
try {
val digest = MessageDigest.getInstance("SHA256")
var remainingData = Util.toIntExact(length) - mac.macLength
@@ -375,4 +366,33 @@ object AttachmentCipherInputStream {
@Throws(IOException::class)
fun openStream(): InputStream
}
class IntegrityCheck(
val encryptedDigest: ByteArray?,
val plaintextHash: ByteArray?
) {
init {
if (encryptedDigest == null && plaintextHash == null) {
throw IllegalArgumentException("At least one of encryptedDigest or plaintextHash must be provided")
}
}
companion object {
@JvmStatic
fun forEncryptedDigest(encryptedDigest: ByteArray): IntegrityCheck {
return IntegrityCheck(encryptedDigest, null)
}
@JvmStatic
fun forPlaintextHash(plaintextHash: ByteArray): IntegrityCheck {
return IntegrityCheck(null, plaintextHash)
}
@JvmStatic
fun forEncryptedDigestAndPlaintextHash(encryptedDigest: ByteArray?, plaintextHash: String?): IntegrityCheck {
val plaintextHashBytes = plaintextHash?.let { Base64.decode(it) }
return IntegrityCheck(encryptedDigest, plaintextHashBytes)
}
}
}
}

View File

@@ -13,6 +13,7 @@ import org.signal.libsignal.protocol.InvalidMessageException
import org.signal.libsignal.protocol.incrementalmac.ChunkSizeChoice
import org.signal.libsignal.protocol.incrementalmac.InvalidMacException
import org.signal.libsignal.protocol.kdf.HKDF
import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream.IntegrityCheck
import org.whispersystems.signalservice.api.crypto.AttachmentCipherTestHelper.createMediaKeyMaterial
import org.whispersystems.signalservice.internal.crypto.PaddingInputStream
import org.whispersystems.signalservice.internal.push.http.AttachmentCipherOutputStreamFactory
@@ -32,19 +33,39 @@ import java.util.Random
class AttachmentCipherTest {
@Test
fun attachment_encryptDecrypt_nonIncremental() {
attachment_encryptDecrypt(incremental = false, fileSize = MEBIBYTE)
fun attachment_encryptDecrypt_nonIncremental_encryptedDigest() {
attachment_encryptDecrypt(incremental = false, fileSize = MEBIBYTE, integrityCheckMode = IntegrityCheckMode.ENCRYPTED_DIGEST)
}
@Test
fun attachment_encryptDecrypt_incremental() {
attachment_encryptDecrypt(incremental = true, fileSize = MEBIBYTE)
fun attachment_encryptDecrypt_nonIncremental_plaintextHash() {
attachment_encryptDecrypt(incremental = false, fileSize = MEBIBYTE, integrityCheckMode = IntegrityCheckMode.PLAINTEXT_HASH)
}
@Test
fun attachment_encryptDecrypt_nonIncremental_bothIntegrityChecks() {
attachment_encryptDecrypt(incremental = false, fileSize = MEBIBYTE, integrityCheckMode = IntegrityCheckMode.BOTH)
}
@Test
fun attachment_encryptDecrypt_incremental_encryptedDigest() {
attachment_encryptDecrypt(incremental = true, fileSize = MEBIBYTE, integrityCheckMode = IntegrityCheckMode.ENCRYPTED_DIGEST)
}
@Test
fun attachment_encryptDecrypt_incremental_plaintextHash() {
attachment_encryptDecrypt(incremental = true, fileSize = MEBIBYTE, integrityCheckMode = IntegrityCheckMode.PLAINTEXT_HASH)
}
@Test
fun attachment_encryptDecrypt_incremental_bothIntegrityChecks() {
attachment_encryptDecrypt(incremental = true, fileSize = MEBIBYTE, integrityCheckMode = IntegrityCheckMode.BOTH)
}
@Test
fun attachment_encryptDecrypt_nonIncremental_manyFileSizes() {
for (i in 0..99) {
attachment_encryptDecrypt(incremental = false, fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024))
attachment_encryptDecrypt(incremental = false, fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024), integrityCheckMode = IntegrityCheckMode.BOTH)
}
}
@@ -52,18 +73,44 @@ class AttachmentCipherTest {
fun attachment_encryptDecrypt_incremental_manyFileSizes() {
// Designed to stress the various boundary conditions of reading the final mac
for (i in 0..99) {
attachment_encryptDecrypt(incremental = true, fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024))
attachment_encryptDecrypt(incremental = true, fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024), IntegrityCheckMode.BOTH)
}
}
private fun attachment_encryptDecrypt(incremental: Boolean, fileSize: Int) {
private fun attachment_encryptDecrypt(incremental: Boolean, fileSize: Int, integrityCheckMode: IntegrityCheckMode) {
val key = Util.getSecretBytes(64)
val plaintextInput = Util.getSecretBytes(fileSize)
val plaintextHash = MessageDigest.getInstance("SHA-256").digest(plaintextInput)
val encryptResult = encryptData(plaintextInput, key, incremental)
val cipherFile = writeToFile(encryptResult.ciphertext)
val inputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, encryptResult.digest, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice)
val integrityCheck = when (integrityCheckMode) {
IntegrityCheckMode.ENCRYPTED_DIGEST -> IntegrityCheck.forEncryptedDigest(encryptResult.digest)
IntegrityCheckMode.PLAINTEXT_HASH -> IntegrityCheck.forPlaintextHash(plaintextHash)
IntegrityCheckMode.BOTH -> IntegrityCheck(
encryptedDigest = encryptResult.digest,
plaintextHash = plaintextHash
)
}
val inputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, integrityCheck, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice)
val plaintextOutput = inputStream.readFully(autoClose = false)
assertThat(plaintextOutput).isEqualTo(plaintextInput)
cipherFile.delete()
}
private fun attachment_encryptDecrypt_plaintextHash(incremental: Boolean, fileSize: Int) {
val key = Util.getSecretBytes(64)
val plaintextInput = Util.getSecretBytes(fileSize)
val plaintextHash = MessageDigest.getInstance("SHA-256").digest(plaintextInput)
val encryptResult = encryptData(plaintextInput, key, incremental)
val cipherFile = writeToFile(encryptResult.ciphertext)
val integrityCheck = IntegrityCheck.forPlaintextHash(plaintextHash)
val inputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, integrityCheck, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice)
val plaintextOutput = inputStream.readFully(autoClose = false)
assertThat(plaintextOutput).isEqualTo(plaintextInput)
@@ -88,7 +135,8 @@ class AttachmentCipherTest {
val encryptResult = encryptData(plaintextInput, key, incremental)
val cipherFile = writeToFile(encryptResult.ciphertext)
val inputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, encryptResult.digest, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice)
val integrityCheck = IntegrityCheck.forEncryptedDigest(encryptResult.digest)
val inputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, integrityCheck, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice)
val plaintextOutput = inputStream.readFully(autoClose = false)
Assert.assertArrayEquals(plaintextInput, plaintextOutput)
@@ -117,7 +165,8 @@ class AttachmentCipherTest {
cipherFile = writeToFile(encryptResult.ciphertext)
val badKey = ByteArray(64)
AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), badKey, encryptResult.digest, null, 0)
val integrityCheck = IntegrityCheck.forEncryptedDigest(encryptResult.digest)
AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), badKey, integrityCheck, null, 0)
} finally {
cipherFile?.delete()
}
@@ -147,7 +196,8 @@ class AttachmentCipherTest {
cipherFile = writeToFile(badMacCiphertext)
val stream: InputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, encryptResult.digest, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice)
val integrityCheck = IntegrityCheck.forEncryptedDigest(encryptResult.digest)
val stream: InputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, integrityCheck, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice)
// In incremental mode, we'll only check the digest after reading the whole thing
if (incremental) {
@@ -159,16 +209,26 @@ class AttachmentCipherTest {
}
@Test(expected = InvalidMessageException::class)
fun attachment_decryptFailOnBadDigest_nonIncremental() {
attachment_decryptFailOnBadDigest(incremental = false)
fun attachment_decryptFailOnBadEncryptedDigest_nonIncremental() {
attachment_decryptFailOnBadEncryptedDigest(incremental = false)
}
@Test(expected = InvalidMessageException::class)
fun attachment_decryptFailOnBadDigest_incremental() {
attachment_decryptFailOnBadDigest(incremental = true)
fun attachment_decryptFailOnBadEncryptedDigest_incremental() {
attachment_decryptFailOnBadEncryptedDigest(incremental = true)
}
private fun attachment_decryptFailOnBadDigest(incremental: Boolean) {
@Test(expected = InvalidMessageException::class)
fun attachment_decryptFailOnBadPlaintextHash_nonIncremental() {
attachment_decryptFailOnBadPlaintextHash(incremental = false)
}
@Test(expected = InvalidMessageException::class)
fun attachment_decryptFailOnBadPlaintextHash_incremental() {
attachment_decryptFailOnBadPlaintextHash(incremental = true)
}
private fun attachment_decryptFailOnBadEncryptedDigest(incremental: Boolean) {
var cipherFile: File? = null
try {
@@ -180,7 +240,8 @@ class AttachmentCipherTest {
cipherFile = writeToFile(encryptResult.ciphertext)
val stream: InputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, badDigest, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice)
val integrityCheck = IntegrityCheck.forEncryptedDigest(badDigest)
val stream: InputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, integrityCheck, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice)
// In incremental mode, we'll only check the digest after reading the whole thing
if (incremental) {
@@ -191,6 +252,29 @@ class AttachmentCipherTest {
}
}
private fun attachment_decryptFailOnBadPlaintextHash(incremental: Boolean) {
var cipherFile: File? = null
try {
val key = Util.getSecretBytes(64)
val plaintextInput = Util.getSecretBytes(MEBIBYTE)
val badPlaintextHash = MessageDigest.getInstance("SHA-256").digest(plaintextInput).apply {
this[0] = (this[0] + 1).toByte()
}
val encryptResult = encryptData(plaintextInput, key, incremental)
cipherFile = writeToFile(encryptResult.ciphertext)
val integrityCheck = IntegrityCheck.forPlaintextHash(badPlaintextHash)
val stream: InputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, integrityCheck, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice)
StreamUtil.readFully(stream)
} finally {
cipherFile?.delete()
}
}
@Test
fun attachment_decryptFailOnBadIncrementalDigest() {
var cipherFile: File? = null
@@ -205,7 +289,8 @@ class AttachmentCipherTest {
cipherFile = writeToFile(encryptResult.ciphertext)
val decryptedStream: InputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, encryptResult.digest, badDigest, encryptResult.chunkSizeChoice)
val integrityCheck = IntegrityCheck.forEncryptedDigest(encryptResult.digest)
val decryptedStream: InputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, integrityCheck, badDigest, encryptResult.chunkSizeChoice)
val plaintextOutput = readInputStreamFully(decryptedStream)
fail(AssertionError("Expected to fail before hitting this line"))
@@ -480,7 +565,8 @@ class AttachmentCipherTest {
val combinedData = plaintextInput1 + plaintextInput2
val cipherFile = writeToFile(encryptedData)
val decryptedStream = AttachmentCipherInputStream.createForAttachment(cipherFile, combinedData.size.toLong(), key, digest, null, 0)
val integrityCheck = IntegrityCheck.forEncryptedDigest(digest)
val decryptedStream = AttachmentCipherInputStream.createForAttachment(cipherFile, combinedData.size.toLong(), key, integrityCheck, null, 0)
val plaintextOutput = readInputStreamFully(decryptedStream)
assertThat(plaintextOutput).isEqualTo(combinedData)
@@ -511,7 +597,8 @@ class AttachmentCipherTest {
val combinedData = plaintextInput1 + plaintextInput2
val cipherFile = writeToFile(encryptedData)
val decryptedStream = AttachmentCipherInputStream.createForAttachment(cipherFile, combinedData.size.toLong(), key, digest, null, 0)
val integrityCheck = IntegrityCheck.forEncryptedDigest(digest)
val decryptedStream = AttachmentCipherInputStream.createForAttachment(cipherFile, combinedData.size.toLong(), key, integrityCheck, null, 0)
val plaintextOutput = readInputStreamFully(decryptedStream)
assertThat(plaintextOutput).isEqualTo(combinedData)
@@ -536,7 +623,8 @@ class AttachmentCipherTest {
val digest = encryptingOutputStream.transmittedDigest
val cipherFile = writeToFile(encryptedData)
val decryptedStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, digest, null, 0)
val integrityCheck = IntegrityCheck.forEncryptedDigest(digest)
val decryptedStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, integrityCheck, null, 0)
val plaintextOutput = readInputStreamFully(decryptedStream)
assertThat(plaintextOutput).isEqualTo(plaintextInput)
@@ -567,7 +655,8 @@ class AttachmentCipherTest {
val digest = encryptingOutputStream.transmittedDigest
val cipherFile = writeToFile(encryptedData)
val decryptedStream = AttachmentCipherInputStream.createForAttachment(cipherFile, expectedData.size.toLong(), key, digest, null, 0)
val integrityCheck = IntegrityCheck.forEncryptedDigest(digest)
val decryptedStream = AttachmentCipherInputStream.createForAttachment(cipherFile, expectedData.size.toLong(), key, integrityCheck, null, 0)
val plaintextOutput = readInputStreamFully(decryptedStream)
assertThat(plaintextOutput).isEqualTo(expectedData)
@@ -596,7 +685,8 @@ class AttachmentCipherTest {
val digest = encryptingOutputStream.transmittedDigest
val cipherFile = writeToFile(encryptedData)
val decryptedStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, digest, null, 0)
val integrityCheck = IntegrityCheck.forEncryptedDigest(digest)
val decryptedStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, integrityCheck, null, 0)
val plaintextOutput = readInputStreamFully(decryptedStream)
assertThat(plaintextOutput).isEqualTo(plaintextInput)
@@ -677,4 +767,10 @@ class AttachmentCipherTest {
return HKDF.deriveSecrets(shortKey, "Sticker Pack".toByteArray(), 64)
}
}
enum class IntegrityCheckMode {
ENCRYPTED_DIGEST,
PLAINTEXT_HASH,
BOTH
}
}