diff --git a/app/src/androidTest/java/org/thoughtcrime/securesms/database/AttachmentTableTest.kt b/app/src/androidTest/java/org/thoughtcrime/securesms/database/AttachmentTableTest.kt index fbf8006e23..2756e0964f 100644 --- a/app/src/androidTest/java/org/thoughtcrime/securesms/database/AttachmentTableTest.kt +++ b/app/src/androidTest/java/org/thoughtcrime/securesms/database/AttachmentTableTest.kt @@ -199,7 +199,7 @@ class AttachmentTableTest { val attachmentId = SignalDatabase.attachments.insertAttachmentsForMessage(mmsId, listOf(createAttachmentPointer(key, badlyPaddedDigest, plaintext.size)), emptyList()).values.first() // Give data to attachment table - val cipherInputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintext.size.toLong(), key, badlyPaddedDigest, null, 4, false) + val cipherInputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintext.size.toLong(), key, badlyPaddedDigest, null, 4) SignalDatabase.attachments.finalizeAttachmentAfterDownload(mmsId, attachmentId, cipherInputStream, iv) // Verify the digest has been updated to the properly padded one @@ -230,7 +230,7 @@ class AttachmentTableTest { val attachmentId = SignalDatabase.attachments.insertAttachmentsForMessage(mmsId, listOf(createAttachmentPointer(key, digest, plaintext.size)), emptyList()).values.first() // Give data to attachment table - val cipherInputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintext.size.toLong(), key, digest, null, 4, false) + val cipherInputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintext.size.toLong(), key, digest, null, 4) SignalDatabase.attachments.finalizeAttachmentAfterDownload(mmsId, attachmentId, cipherInputStream, iv) // Verify the digest hasn't changed diff --git a/app/src/main/java/org/thoughtcrime/securesms/jobs/RestoreAttachmentJob.kt b/app/src/main/java/org/thoughtcrime/securesms/jobs/RestoreAttachmentJob.kt index 8b868e0251..dfb312dc1d 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/jobs/RestoreAttachmentJob.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/jobs/RestoreAttachmentJob.kt @@ -252,7 +252,6 @@ class RestoreAttachmentJob private constructor( pointer, attachmentFile, maxReceiveSize, - false, progressListener ) } else { diff --git a/app/src/main/java/org/thoughtcrime/securesms/jobs/RestoreAttachmentThumbnailJob.kt b/app/src/main/java/org/thoughtcrime/securesms/jobs/RestoreAttachmentThumbnailJob.kt index 935e0efdbf..d4a4d953d6 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/jobs/RestoreAttachmentThumbnailJob.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/jobs/RestoreAttachmentThumbnailJob.kt @@ -132,14 +132,13 @@ class RestoreAttachmentThumbnailJob private constructor( Log.i(TAG, "Downloading thumbnail for $attachmentId") val downloadResult = AppDependencies.signalServiceMessageReceiver - .retrieveArchivedAttachment( + .retrieveArchivedThumbnail( SignalStore.backup.mediaRootBackupKey.deriveMediaSecrets(attachment.requireThumbnailMediaName()), cdnCredentials, thumbnailTransferFile, pointer, thumbnailFile, maxThumbnailSize, - true, progressListener ) diff --git a/app/src/main/java/org/thoughtcrime/securesms/jobs/RestoreLocalAttachmentJob.kt b/app/src/main/java/org/thoughtcrime/securesms/jobs/RestoreLocalAttachmentJob.kt index f26422754a..0d37246ceb 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/jobs/RestoreLocalAttachmentJob.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/jobs/RestoreLocalAttachmentJob.kt @@ -149,7 +149,15 @@ class RestoreLocalAttachmentJob private constructor( try { val iv = ByteArray(16) streamSupplier.openStream().use { StreamUtil.readFully(it, iv) } - AttachmentCipherInputStream.createForAttachment(streamSupplier, size, attachment.size, combinedKey, attachment.remoteDigest, null, 0, false).use { input -> + AttachmentCipherInputStream.createForAttachment( + streamSupplier = streamSupplier, + streamLength = size, + plaintextLength = attachment.size, + combinedKeyMaterial = combinedKey, + digest = attachment.remoteDigest, + incrementalDigest = null, + incrementalMacChunkSize = 0 + ).use { input -> SignalDatabase.attachments.finalizeAttachmentAfterDownload(attachment.mmsId, attachment.attachmentId, input, iv) } } catch (e: InvalidMessageException) { diff --git a/app/src/main/java/org/thoughtcrime/securesms/mms/AttachmentStreamLocalUriFetcher.java b/app/src/main/java/org/thoughtcrime/securesms/mms/AttachmentStreamLocalUriFetcher.java index cd0c654324..39719f9e14 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/mms/AttachmentStreamLocalUriFetcher.java +++ b/app/src/main/java/org/thoughtcrime/securesms/mms/AttachmentStreamLocalUriFetcher.java @@ -11,6 +11,7 @@ import org.signal.libsignal.protocol.InvalidMessageException; import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream; import java.io.File; +import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; import java.util.Optional; @@ -37,7 +38,12 @@ class AttachmentStreamLocalUriFetcher implements DataFetcher { public void loadData(@NonNull Priority priority, @NonNull DataCallback callback) { try { if (!digest.isPresent()) throw new InvalidMessageException("No attachment digest!"); - is = AttachmentCipherInputStream.createForAttachment(attachment, plaintextLength, key, digest.get(), null, 0); + is = AttachmentCipherInputStream.createForAttachment(attachment, + plaintextLength, + key, + digest.get(), + null, + 0); callback.onDataReady(is); } catch (IOException | InvalidMessageException e) { callback.onLoadFailed(e); diff --git a/app/src/main/java/org/thoughtcrime/securesms/video/exo/PartDataSource.java b/app/src/main/java/org/thoughtcrime/securesms/video/exo/PartDataSource.java index 55135c0562..3897e0ec68 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/video/exo/PartDataSource.java +++ b/app/src/main/java/org/thoughtcrime/securesms/video/exo/PartDataSource.java @@ -71,7 +71,8 @@ class PartDataSource implements DataSource { final boolean hasData = attachment.hasData; if (inProgress && !hasData && hasIncrementalDigest && attachmentKey != null) { - final byte[] decode = Base64.decode(attachmentKey); + final byte[] decodedKey = Base64.decode(attachmentKey); + if (attachment.transferState == AttachmentTable.TRANSFER_RESTORE_IN_PROGRESS && attachment.archiveTransferState == AttachmentTable.ArchiveTransferState.FINISHED) { final File archiveFile = attachmentDatabase.getOrCreateArchiveTransferFile(attachment.attachmentId); try { @@ -81,7 +82,11 @@ class PartDataSource implements DataSource { MediaRootBackupKey.MediaKeyMaterial mediaKeyMaterial = SignalStore.backup().getMediaRootBackupKey().deriveMediaSecretsFromMediaId(mediaId); long originalCipherLength = AttachmentCipherStreamUtil.getCiphertextLength(PaddingInputStream.getPaddedSize(attachment.size)); - this.inputStream = AttachmentCipherInputStream.createStreamingForArchivedAttachment(mediaKeyMaterial, archiveFile, originalCipherLength, attachment.size, attachment.remoteDigest, decode, attachment.getIncrementalDigest(), attachment.incrementalMacChunkSize); + if (attachment.remoteDigest == null) { + throw new InvalidMessageException("Missing digest!"); + } + + this.inputStream = AttachmentCipherInputStream.createForArchivedMediaOuterAndInnerLayers(mediaKeyMaterial, archiveFile, originalCipherLength, attachment.size, decodedKey, attachment.remoteDigest, attachment.getIncrementalDigest(), attachment.incrementalMacChunkSize); } catch (InvalidMessageException e) { throw new IOException("Error decrypting attachment stream!", e); } @@ -95,7 +100,7 @@ class PartDataSource implements DataSource { throw new InvalidMessageException("Missing digest!"); } - this.inputStream = AttachmentCipherInputStream.createForAttachment(streamSupplier, streamLength, attachment.size, decode, attachment.remoteDigest, attachment.getIncrementalDigest(), attachment.incrementalMacChunkSize, false); + this.inputStream = AttachmentCipherInputStream.createForAttachment(streamSupplier, streamLength, attachment.size, decodedKey, attachment.remoteDigest, attachment.getIncrementalDigest(), attachment.incrementalMacChunkSize); } catch (InvalidMessageException e) { throw new IOException("Error decrypting attachment stream!", e); } diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageReceiver.java b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageReceiver.java index 9424f4235b..6dbc82bb4a 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageReceiver.java +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageReceiver.java @@ -7,9 +7,6 @@ package org.whispersystems.signalservice.api; import org.signal.core.util.StreamUtil; -import org.signal.core.util.concurrent.FutureTransformers; -import org.signal.core.util.concurrent.ListenableFuture; -import org.signal.core.util.concurrent.SettableFuture; import org.signal.core.util.stream.LimitedInputStream; import org.signal.libsignal.protocol.InvalidMessageException; import org.signal.libsignal.zkgroup.profiles.ProfileKey; @@ -18,24 +15,15 @@ import org.whispersystems.signalservice.api.backup.MediaRootBackupKey; import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream; import org.whispersystems.signalservice.api.crypto.AttachmentCipherStreamUtil; import org.whispersystems.signalservice.api.crypto.ProfileCipherInputStream; -import org.whispersystems.signalservice.api.crypto.SealedSenderAccess; import org.whispersystems.signalservice.api.messages.SignalServiceAttachment.ProgressListener; import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentPointer; import org.whispersystems.signalservice.api.messages.SignalServiceDataMessage; import org.whispersystems.signalservice.api.messages.SignalServiceStickerManifest; -import org.whispersystems.signalservice.api.profiles.ProfileAndCredential; -import org.whispersystems.signalservice.api.profiles.SignalServiceProfile; -import org.whispersystems.signalservice.api.push.ServiceId.ACI; -import org.whispersystems.signalservice.api.push.SignalServiceAddress; import org.whispersystems.signalservice.api.push.exceptions.MissingConfigurationException; -import org.whispersystems.signalservice.internal.ServiceResponse; import org.whispersystems.signalservice.internal.crypto.PaddingInputStream; -import org.whispersystems.signalservice.internal.push.IdentityCheckRequest; -import org.whispersystems.signalservice.internal.push.IdentityCheckResponse; import org.whispersystems.signalservice.internal.push.PushServiceSocket; import org.whispersystems.signalservice.internal.sticker.Pack; import org.whispersystems.signalservice.internal.util.Util; -import org.whispersystems.signalservice.internal.websocket.ResponseMapper; import java.io.File; import java.io.FileInputStream; @@ -46,15 +34,11 @@ import java.time.ZonedDateTime; import java.util.ArrayList; import java.util.Collections; import java.util.List; -import java.util.Locale; import java.util.Map; -import java.util.Optional; import javax.annotation.Nonnull; import javax.annotation.Nullable; -import io.reactivex.rxjava3.core.Single; - /** * The primary interface for receiving Signal Service messages. * @@ -118,6 +102,7 @@ public class SignalServiceMessageReceiver { public AttachmentDownloadResult retrieveAttachment(SignalServiceAttachmentPointer pointer, File destination, long maxSizeBytes, ProgressListener listener) throws IOException, InvalidMessageException, MissingConfigurationException { if (!pointer.getDigest().isPresent()) throw new InvalidMessageException("No attachment digest!"); + if (pointer.getKey() == null) throw new InvalidMessageException("No key!"); socket.retrieveAttachment(pointer.getCdnNumber(), Collections.emptyMap(), pointer.getRemoteId(), destination, maxSizeBytes, listener); @@ -127,7 +112,14 @@ public class SignalServiceMessageReceiver { } return new AttachmentDownloadResult( - AttachmentCipherInputStream.createForAttachment(destination, pointer.getSize().orElse(0), pointer.getKey(), pointer.getDigest().get(), null, 0), + AttachmentCipherInputStream.createForAttachment( + destination, + pointer.getSize().orElse(0), + pointer.getKey(), + pointer.getDigest().get(), + null, + 0 + ), iv ); } @@ -150,14 +142,17 @@ public class SignalServiceMessageReceiver { @Nonnull SignalServiceAttachmentPointer pointer, @Nonnull File attachmentDestination, long maxSizeBytes, - boolean ignoreDigest, @Nullable ProgressListener listener) throws IOException, InvalidMessageException, MissingConfigurationException { - if (!ignoreDigest && pointer.getDigest().isEmpty()) { + if (pointer.getDigest().isEmpty()) { throw new InvalidMessageException("No attachment digest!"); } + if (pointer.getKey() == null) { + throw new InvalidMessageException("No key!"); + } + socket.retrieveAttachment(pointer.getCdnNumber(), readCredentialHeaders, pointer.getRemoteId(), archiveDestination, maxSizeBytes, listener); long originalCipherLength = pointer.getSize() @@ -166,7 +161,7 @@ public class SignalServiceMessageReceiver { .orElse(0L); // There's two layers of encryption -- one from the backup, and one from the attachment. This only strips the outermost backup encryption layer. - try (InputStream backupDecrypted = AttachmentCipherInputStream.createForArchivedMedia(archivedMediaKeyMaterial, archiveDestination, originalCipherLength)) { + try (InputStream backupDecrypted = AttachmentCipherInputStream.createForArchivedMediaOuterLayer(archivedMediaKeyMaterial, archiveDestination, originalCipherLength)) { try (FileOutputStream fos = new FileOutputStream(attachmentDestination)) { // TODO [backup] I don't think we should be doing the full copy here. This is basically doing the entire download inline in this single line. StreamUtil.copy(backupDecrypted, fos); @@ -182,10 +177,63 @@ public class SignalServiceMessageReceiver { attachmentDestination, pointer.getSize().orElse(0), pointer.getKey(), - ignoreDigest ? null : pointer.getDigest().get(), + pointer.getDigest().get(), null, - 0, - ignoreDigest + 0 + ); + + return new AttachmentDownloadResult(dataStream, iv); + } + + /** + * Retrieves an archived media attachment. + * + * @param archivedMediaKeyMaterial Decryption key material for decrypting outer layer of archived media. + * @param readCredentialHeaders Headers to pass to the backup CDN to authorize the download + * @param archiveDestination The download destination for archived attachment. If this file exists, download will resume. + * @param pointer The {@link SignalServiceAttachmentPointer} received in a {@link SignalServiceDataMessage}. + * @param attachmentDestination The download destination for this attachment. If this file exists, it is assumed that this is previously-downloaded content that can be resumed. + * @param listener An optional listener (may be null) to receive callbacks on download progress. + * + * @return An InputStream that streams the plaintext attachment contents. + */ + public AttachmentDownloadResult retrieveArchivedThumbnail(@Nonnull MediaRootBackupKey.MediaKeyMaterial archivedMediaKeyMaterial, + @Nonnull Map readCredentialHeaders, + @Nonnull File archiveDestination, + @Nonnull SignalServiceAttachmentPointer pointer, + @Nonnull File attachmentDestination, + long maxSizeBytes, + @Nullable ProgressListener listener) + throws IOException, InvalidMessageException, MissingConfigurationException + { + if (pointer.getKey() == null) { + throw new InvalidMessageException("No key!"); + } + + socket.retrieveAttachment(pointer.getCdnNumber(), readCredentialHeaders, pointer.getRemoteId(), archiveDestination, maxSizeBytes, listener); + + long originalCipherLength = pointer.getSize() + .filter(s -> s > 0) + .map(s -> AttachmentCipherStreamUtil.getCiphertextLength(PaddingInputStream.getPaddedSize(s))) + .orElse(0L); + + // There's two layers of encryption -- one from the backup, and one from the attachment. This only strips the outermost backup encryption layer. + try (InputStream backupDecrypted = AttachmentCipherInputStream.createForArchivedMediaOuterLayer(archivedMediaKeyMaterial, archiveDestination, originalCipherLength)) { + try (FileOutputStream fos = new FileOutputStream(attachmentDestination)) { + // TODO [backup] I don't think we should be doing the full copy here. This is basically doing the entire download inline in this single line. + StreamUtil.copy(backupDecrypted, fos); + } + } + + byte[] iv = new byte[16]; + try (InputStream tempStream = new FileInputStream(attachmentDestination)) { + StreamUtil.readFully(tempStream, iv); + } + + LimitedInputStream dataStream = AttachmentCipherInputStream.createForArchiveThumbnailInnerLayer( + attachmentDestination, + pointer.getSize().orElse(0), + pointer.getKey() ); return new AttachmentDownloadResult(dataStream, iv); diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherInputStream.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherInputStream.kt index 94ad69b961..22e4417169 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherInputStream.kt +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherInputStream.kt @@ -5,8 +5,8 @@ */ package org.whispersystems.signalservice.api.crypto +import org.signal.core.util.readNBytesOrThrow import org.signal.core.util.stream.LimitedInputStream -import org.signal.core.util.stream.LimitedInputStream.Companion.withoutLimits import org.signal.libsignal.protocol.InvalidMessageException import org.signal.libsignal.protocol.incrementalmac.ChunkSizeChoice import org.signal.libsignal.protocol.incrementalmac.IncrementalMacInputStream @@ -16,185 +16,316 @@ import org.whispersystems.signalservice.internal.util.Util import java.io.ByteArrayInputStream import java.io.File import java.io.FileInputStream -import java.io.FilterInputStream import java.io.IOException import java.io.InputStream import java.security.InvalidKeyException import java.security.MessageDigest import java.security.NoSuchAlgorithmException import javax.annotation.Nonnull -import javax.crypto.BadPaddingException import javax.crypto.Cipher -import javax.crypto.IllegalBlockSizeException import javax.crypto.Mac -import javax.crypto.ShortBufferException import javax.crypto.spec.IvParameterSpec import javax.crypto.spec.SecretKeySpec import kotlin.math.min /** - * Class for streaming an encrypted push attachment off disk. + * Decrypts an attachment stream that has been encrypted with AES/CBC/PKCS5Padding. * - * @author Moxie Marlinspike + * It assumes that the first 16 bytes of the stream are the IV, and that the rest of the stream is encrypted data. */ -class AttachmentCipherInputStream private constructor( - inputStream: InputStream, - aesKey: ByteArray, - private val totalDataSize: Long -) : FilterInputStream(inputStream) { +object AttachmentCipherInputStream { - private val cipher: Cipher + private const val BLOCK_SIZE = 16 + private const val CIPHER_KEY_SIZE = 32 + private const val MAC_KEY_SIZE = 32 - private var done = false - private var totalRead: Long = 0 - private var overflowBuffer: ByteArray? = null - - init { - val iv = ByteArray(BLOCK_SIZE) - readFullyWithoutDecrypting(iv) - - this.cipher = Cipher.getInstance("AES/CBC/PKCS5Padding") - cipher.init(Cipher.DECRYPT_MODE, SecretKeySpec(aesKey, "AES"), IvParameterSpec(iv)) + /** + * Creates a stream to decrypt a typical attachment via a [File]. + * + * @param incrementalDigest If null, incremental mac validation is disabled. + * @param incrementalMacChunkSize If 0, incremental mac validation is disabled. + */ + @JvmStatic + @Throws(InvalidMessageException::class, IOException::class) + fun createForAttachment( + file: File, + plaintextLength: Long, + combinedKeyMaterial: ByteArray, + digest: ByteArray, + incrementalDigest: ByteArray?, + incrementalMacChunkSize: Int + ): LimitedInputStream { + return create( + streamSupplier = { FileInputStream(file) }, + streamLength = file.length(), + plaintextLength = plaintextLength, + combinedKeyMaterial = combinedKeyMaterial, + digest = digest, + incrementalDigest = incrementalDigest, + incrementalMacChunkSize = incrementalMacChunkSize, + ignoreDigest = false + ) } - @Throws(IOException::class) - override fun read(): Int { - val buffer = ByteArray(1) - var read: Int = read(buffer) - while (read == 0) { - read = read(buffer) + /** + * Creates a stream to decrypt a typical attachment via a [StreamSupplier]. + * + * @param incrementalDigest If null, incremental mac validation is disabled. + * @param incrementalMacChunkSize If 0, incremental mac validation is disabled. + */ + @JvmStatic + @Throws(InvalidMessageException::class, IOException::class) + fun createForAttachment( + streamSupplier: StreamSupplier, + streamLength: Long, + plaintextLength: Long, + combinedKeyMaterial: ByteArray, + digest: ByteArray, + incrementalDigest: ByteArray?, + incrementalMacChunkSize: Int + ): LimitedInputStream { + return create( + streamSupplier = streamSupplier, + streamLength = streamLength, + plaintextLength = plaintextLength, + combinedKeyMaterial = combinedKeyMaterial, + digest = digest, + incrementalDigest = incrementalDigest, + incrementalMacChunkSize = incrementalMacChunkSize, + ignoreDigest = false + ) + } + + /** + * After removing the server layer of encryption using [createForArchivedMediaOuterLayer], use this to decrypt the inner layer of the attachment. + * Thumbnails have a special path because we don't do any additional digest/hash validation on them. + */ + @JvmStatic + @Throws(InvalidMessageException::class, IOException::class) + fun createForArchiveThumbnailInnerLayer( + file: File, + plaintextLength: Long, + combinedKeyMaterial: ByteArray + ): LimitedInputStream { + return create( + streamSupplier = { FileInputStream(file) }, + streamLength = file.length(), + plaintextLength = plaintextLength, + combinedKeyMaterial = combinedKeyMaterial, + digest = null, + incrementalDigest = null, + incrementalMacChunkSize = 0, + ignoreDigest = true + ) + } + + /** + * When you archive an attachment, you give the server an encrypted attachment, and the server wraps it in *another* layer of encryption. + * This will return a stream that unwraps the server's layer of encryption, giving you a stream that contains a "normally-encrypted" attachment. + * + * Because we're validating the encryptedDigest/plaintextHash of the inner layer, there's no additional out-of-band validation of this outer layer. + */ + @JvmStatic + @Throws(InvalidMessageException::class, IOException::class) + fun createForArchivedMediaOuterLayer(archivedMediaKeyMaterial: MediaKeyMaterial, file: File, originalCipherTextLength: Long): LimitedInputStream { + val mac = initMac(archivedMediaKeyMaterial.macKey) + + if (file.length() <= BLOCK_SIZE + mac.macLength) { + throw InvalidMessageException("Message shorter than crypto overhead!") } - if (read == -1) { - return read + FileInputStream(file).use { macVerificationStream -> + verifyMac(macVerificationStream, file.length(), mac, null) } - return buffer[0].toInt() and 0xFF + val encryptedStream = FileInputStream(file) + val encryptedStreamExcludingMac = LimitedInputStream(encryptedStream, file.length() - mac.macLength) + val cipher = createCipher(encryptedStreamExcludingMac, archivedMediaKeyMaterial.aesKey) + val inputStream: InputStream = BetterCipherInputStream(encryptedStreamExcludingMac, cipher) + + return LimitedInputStream(inputStream, originalCipherTextLength) } - @Throws(IOException::class) - override fun read(@Nonnull buffer: ByteArray): Int { - return read(buffer, 0, buffer.size) - } + /** + * When you archive an attachment, you give the server an encrypted attachment, and the server wraps it in *another* layer of encryption. + * + * This creates a stream decrypt both the inner and outer layers of an archived attachment at the same time by basically double-decrypting it. + * + * @param incrementalDigest If null, incremental mac validation is disabled. + * @param incrementalMacChunkSize If 0, incremental mac validation is disabled. + */ + @JvmStatic + @Throws(InvalidMessageException::class, IOException::class) + fun createForArchivedMediaOuterAndInnerLayers( + archivedMediaKeyMaterial: MediaKeyMaterial, + file: File, + originalCipherTextLength: Long, + plaintextLength: Long, + combinedKeyMaterial: ByteArray, + digest: ByteArray, + incrementalDigest: ByteArray?, + incrementalMacChunkSize: Int + ): LimitedInputStream { + val keyMaterial = CombinedKeyMaterial.from(combinedKeyMaterial) + val mac = initMac(keyMaterial.macKey) - @Throws(IOException::class) - override fun read(@Nonnull buffer: ByteArray, offset: Int, length: Int): Int { - return if (totalRead != totalDataSize) { - readIncremental(buffer, offset, length) - } else if (!done) { - readFinal(buffer, offset, length) - } else { - -1 - } - } - - override fun markSupported(): Boolean = false - - @Throws(IOException::class) - override fun skip(byteCount: Long): Long { - var skipped = 0L - while (skipped < byteCount) { - val remaining = byteCount - skipped - val buffer = ByteArray(min(4096, remaining.toInt())) - val read = read(buffer) - - skipped += read.toLong() + if (originalCipherTextLength <= BLOCK_SIZE + mac.macLength) { + throw InvalidMessageException("Message shorter than crypto overhead!") } - return skipped + return create( + streamSupplier = { createForArchivedMediaOuterLayer(archivedMediaKeyMaterial, file, originalCipherTextLength) }, + streamLength = originalCipherTextLength, + plaintextLength = plaintextLength, + combinedKeyMaterial = combinedKeyMaterial, + digest = digest, + incrementalDigest = incrementalDigest, + incrementalMacChunkSize = incrementalMacChunkSize, + ignoreDigest = false + ) } - @Throws(IOException::class) - private fun readIncremental(outputBuffer: ByteArray, originalOffset: Int, originalLength: Int): Int { - var offset = originalOffset - var length = originalLength - var readLength = 0 + /** + * Creates a stream to decrypt sticker data. Stickers have a special path because the key material is derived from the pack key. + */ + @JvmStatic + @Throws(InvalidMessageException::class, IOException::class) + fun createForStickerData(data: ByteArray, packKey: ByteArray): InputStream { + val keyMaterial = CombinedKeyMaterial.from(HKDF.deriveSecrets(packKey, "Sticker Pack".toByteArray(), 64)) + val mac = initMac(keyMaterial.macKey) - overflowBuffer?.let { overflow -> - if (overflow.size > length) { - overflow.copyInto(destination = outputBuffer, destinationOffset = offset, endIndex = length) - overflowBuffer = overflow.copyOfRange(fromIndex = length, toIndex = overflow.size) - return length - } else if (overflow.size == length) { - overflow.copyInto(destination = outputBuffer, destinationOffset = offset) - overflowBuffer = null - return length - } else { - overflow.copyInto(destination = outputBuffer, destinationOffset = offset) - readLength += overflow.size - offset += readLength - length -= readLength - overflowBuffer = null + if (data.size <= BLOCK_SIZE + mac.macLength) { + throw InvalidMessageException("Message shorter than crypto overhead!") + } + + ByteArrayInputStream(data).use { inputStream -> + verifyMac(inputStream, data.size.toLong(), mac, null) + } + + val encryptedStream = ByteArrayInputStream(data) + val encryptedStreamExcludingMac = LimitedInputStream(encryptedStream, data.size.toLong() - mac.macLength) + val cipher = createCipher(encryptedStreamExcludingMac, keyMaterial.aesKey) + + return BetterCipherInputStream(encryptedStreamExcludingMac, cipher) + } + + @JvmStatic + @Throws(InvalidMessageException::class, IOException::class) + private fun create( + streamSupplier: StreamSupplier, + streamLength: Long, + plaintextLength: Long, + combinedKeyMaterial: ByteArray, + digest: ByteArray?, + incrementalDigest: ByteArray?, + incrementalMacChunkSize: Int, + ignoreDigest: Boolean + ): LimitedInputStream { + val keyMaterial = CombinedKeyMaterial.from(combinedKeyMaterial) + val mac = initMac(keyMaterial.macKey) + + if (streamLength <= BLOCK_SIZE + mac.macLength) { + throw InvalidMessageException("Message shorter than crypto overhead! length: $streamLength") + } + + if (!ignoreDigest && digest == null) { + throw InvalidMessageException("Missing digest!") + } + + val wrappedStream: InputStream + val hasIncrementalMac = incrementalDigest != null && incrementalDigest.isNotEmpty() && incrementalMacChunkSize > 0 + + if (!hasIncrementalMac) { + streamSupplier.openStream().use { macVerificationStream -> + verifyMac(macVerificationStream, streamLength, mac, digest) } - } - - if (length + totalRead > totalDataSize) { - length = (totalDataSize - totalRead).toInt() - } - - val ciphertextBuffer = ByteArray(length) - val ciphertextReadLength = if (ciphertextBuffer.size <= cipher.blockSize) { - ciphertextBuffer.size + wrappedStream = streamSupplier.openStream() } else { - // Ensure we leave the final block for readFinal() - ciphertextBuffer.size - cipher.blockSize - } - val ciphertextRead = super.read(ciphertextBuffer, 0, ciphertextReadLength) - totalRead += ciphertextRead.toLong() + if (digest == null) { + throw InvalidMessageException("Missing digest for incremental mac validation!") + } + wrappedStream = IncrementalMacInputStream( + IncrementalMacAdditionalValidationsInputStream( + wrapped = streamSupplier.openStream(), + fileLength = streamLength, + mac = mac, + theirDigest = digest + ), + keyMaterial.macKey, + ChunkSizeChoice.everyNthByte(incrementalMacChunkSize), + incrementalDigest + ) + } + + val encryptedStreamExcludingMac = LimitedInputStream(wrappedStream, streamLength - mac.macLength) + val cipher = createCipher(encryptedStreamExcludingMac, keyMaterial.aesKey) + val decryptingStream: InputStream = BetterCipherInputStream(encryptedStreamExcludingMac, cipher) + + return LimitedInputStream(decryptingStream, plaintextLength) + } + + private fun createCipher(inputStream: InputStream, aesKey: ByteArray): Cipher { + val iv = inputStream.readNBytesOrThrow(BLOCK_SIZE) + + return Cipher.getInstance("AES/CBC/PKCS5Padding").apply { + init(Cipher.DECRYPT_MODE, SecretKeySpec(aesKey, "AES"), IvParameterSpec(iv)) + } + } + + private fun initMac(key: ByteArray): Mac { try { - var plaintextLength = cipher.getOutputSize(ciphertextRead) - - if (plaintextLength <= length) { - readLength += cipher.update(ciphertextBuffer, 0, ciphertextRead, outputBuffer, offset) - return readLength - } - - val plaintextBuffer = ByteArray(plaintextLength) - plaintextLength = cipher.update(ciphertextBuffer, 0, ciphertextRead, plaintextBuffer, 0) - if (plaintextLength <= length) { - plaintextBuffer.copyInto(destination = outputBuffer, destinationOffset = offset, endIndex = plaintextLength) - readLength += plaintextLength - } else { - plaintextBuffer.copyInto(destination = outputBuffer, destinationOffset = offset, endIndex = length) - overflowBuffer = plaintextBuffer.copyOfRange(fromIndex = length, toIndex = plaintextLength) - readLength += length - } - return readLength - } catch (e: ShortBufferException) { + val mac = Mac.getInstance("HmacSHA256") + mac.init(SecretKeySpec(key, "HmacSHA256")) + return mac + } catch (e: NoSuchAlgorithmException) { + throw AssertionError(e) + } catch (e: InvalidKeyException) { throw AssertionError(e) } } - @Throws(IOException::class) - private fun readFinal(buffer: ByteArray, offset: Int, length: Int): Int { + @Throws(InvalidMessageException::class) + private fun verifyMac(@Nonnull inputStream: InputStream, length: Long, @Nonnull mac: Mac, theirDigest: ByteArray?) { try { - val internal = ByteArray(buffer.size) - val actualLength = min(length, cipher.doFinal(internal, 0)) - internal.copyInto(destination = buffer, destinationOffset = offset, endIndex = actualLength) + val digest = MessageDigest.getInstance("SHA256") + var remainingData = Util.toIntExact(length) - mac.macLength + val buffer = ByteArray(4096) - done = true - return actualLength - } catch (e: IllegalBlockSizeException) { - throw IOException(e) - } catch (e: BadPaddingException) { - throw IOException(e) - } catch (e: ShortBufferException) { - throw IOException(e) + while (remainingData > 0) { + val read = inputStream.read(buffer, 0, min(buffer.size, remainingData)) + mac.update(buffer, 0, read) + digest.update(buffer, 0, read) + remainingData -= read + } + + val ourMac = mac.doFinal() + val theirMac = ByteArray(mac.macLength) + Util.readFully(inputStream, theirMac) + + if (!MessageDigest.isEqual(ourMac, theirMac)) { + throw InvalidMessageException("MAC doesn't match!") + } + + val ourDigest = digest.digest(theirMac) + + if (theirDigest != null && !MessageDigest.isEqual(ourDigest, theirDigest)) { + throw InvalidMessageException("Digest doesn't match!") + } + } catch (e: IOException) { + throw InvalidMessageException(e) + } catch (e: ArithmeticException) { + throw InvalidMessageException(e) + } catch (e: NoSuchAlgorithmException) { + throw AssertionError(e) } } - @Throws(IOException::class) - private fun readFullyWithoutDecrypting(buffer: ByteArray) { - var offset = 0 - - while (true) { - val read = super.read(buffer, offset, buffer.size - offset) - - if (read + offset < buffer.size) { - offset += read - } else { - return + private class CombinedKeyMaterial(val aesKey: ByteArray, val macKey: ByteArray) { + companion object { + fun from(combinedKeyMaterial: ByteArray): CombinedKeyMaterial { + val parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE) + return CombinedKeyMaterial(parts[0], parts[1]) } } } @@ -204,207 +335,4 @@ class AttachmentCipherInputStream private constructor( @Throws(IOException::class) fun openStream(): InputStream } - - companion object { - private const val BLOCK_SIZE = 16 - private const val CIPHER_KEY_SIZE = 32 - private const val MAC_KEY_SIZE = 32 - - /** - * Passing in a null incrementalDigest and/or 0 for the chunk size at the call site disables incremental mac validation. - * - * Passing in true for ignoreDigest DOES NOT VERIFY THE DIGEST - */ - @JvmStatic - @JvmOverloads - @Throws(InvalidMessageException::class, IOException::class) - fun createForAttachment(file: File, plaintextLength: Long, combinedKeyMaterial: ByteArray?, digest: ByteArray?, incrementalDigest: ByteArray?, incrementalMacChunkSize: Int, ignoreDigest: Boolean = false): LimitedInputStream { - return createForAttachment({ FileInputStream(file) }, file.length(), plaintextLength, combinedKeyMaterial, digest, incrementalDigest, incrementalMacChunkSize, ignoreDigest) - } - - /** - * Passing in a null incrementalDigest and/or 0 for the chunk size at the call site disables incremental mac validation. - * - * Passing in true for ignoreDigest DOES NOT VERIFY THE DIGEST - */ - @JvmStatic - @Throws(InvalidMessageException::class, IOException::class) - fun createForAttachment( - streamSupplier: StreamSupplier, - streamLength: Long, - plaintextLength: Long, - combinedKeyMaterial: ByteArray?, - digest: ByteArray?, - incrementalDigest: ByteArray?, - incrementalMacChunkSize: Int, - ignoreDigest: Boolean - ): LimitedInputStream { - val parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE) - val mac = initMac(parts[1]) - - if (streamLength <= BLOCK_SIZE + mac.macLength) { - throw InvalidMessageException("Message shorter than crypto overhead! length: $streamLength") - } - - if (!ignoreDigest && digest == null) { - throw InvalidMessageException("Missing digest!") - } - - val wrappedStream: InputStream - val hasIncrementalMac = incrementalDigest != null && incrementalDigest.isNotEmpty() && incrementalMacChunkSize > 0 - - if (!hasIncrementalMac) { - streamSupplier.openStream().use { macVerificationStream -> - verifyMac(macVerificationStream, streamLength, mac, digest) - } - wrappedStream = streamSupplier.openStream() - } else { - wrappedStream = IncrementalMacInputStream( - IncrementalMacAdditionalValidationsInputStream( - streamSupplier.openStream(), - streamLength, - mac, - digest!! - ), - parts[1], - ChunkSizeChoice.everyNthByte(incrementalMacChunkSize), - incrementalDigest - ) - } - val inputStream: InputStream = AttachmentCipherInputStream(wrappedStream, parts[0], streamLength - BLOCK_SIZE - mac.macLength) - - return LimitedInputStream(inputStream, plaintextLength) - } - - /** - * Decrypt archived media to it's original attachment encrypted blob. - */ - @JvmStatic - @Throws(InvalidMessageException::class, IOException::class) - fun createForArchivedMedia(archivedMediaKeyMaterial: MediaKeyMaterial, file: File, originalCipherTextLength: Long): LimitedInputStream { - val mac = initMac(archivedMediaKeyMaterial.macKey) - - if (file.length() <= BLOCK_SIZE + mac.macLength) { - throw InvalidMessageException("Message shorter than crypto overhead!") - } - - FileInputStream(file).use { macVerificationStream -> - verifyMac(macVerificationStream, file.length(), mac, null) - } - val inputStream: InputStream = AttachmentCipherInputStream(FileInputStream(file), archivedMediaKeyMaterial.aesKey, file.length() - BLOCK_SIZE - mac.macLength) - - return LimitedInputStream(inputStream, originalCipherTextLength) - } - - @JvmStatic - @Throws(InvalidMessageException::class, IOException::class) - fun createStreamingForArchivedAttachment( - archivedMediaKeyMaterial: MediaKeyMaterial, - file: File, - originalCipherTextLength: Long, - plaintextLength: Long, - combinedKeyMaterial: ByteArray?, - digest: ByteArray, - incrementalDigest: ByteArray?, - incrementalMacChunkSize: Int - ): LimitedInputStream { - val archiveStream: InputStream = createForArchivedMedia(archivedMediaKeyMaterial, file, originalCipherTextLength) - - val parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE) - val mac = initMac(parts[1]) - - if (originalCipherTextLength <= BLOCK_SIZE + mac.macLength) { - throw InvalidMessageException("Message shorter than crypto overhead!") - } - - val wrappedStream: InputStream = IncrementalMacInputStream( - IncrementalMacAdditionalValidationsInputStream( - wrapped = archiveStream, - fileLength = file.length(), - mac = mac, - theirDigest = digest - ), - parts[1], - ChunkSizeChoice.everyNthByte(incrementalMacChunkSize), - incrementalDigest - ) - - val inputStream: InputStream = AttachmentCipherInputStream( - inputStream = wrappedStream, - aesKey = parts[0], - totalDataSize = file.length() - BLOCK_SIZE - mac.macLength - ) - - return if (plaintextLength != 0L) { - LimitedInputStream(inputStream, plaintextLength) - } else { - withoutLimits(inputStream) - } - } - - @JvmStatic - @Throws(InvalidMessageException::class, IOException::class) - fun createForStickerData(data: ByteArray, packKey: ByteArray?): InputStream { - val combinedKeyMaterial = HKDF.deriveSecrets(packKey, "Sticker Pack".toByteArray(), 64) - val parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE) - val mac = initMac(parts[1]) - - if (data.size <= BLOCK_SIZE + mac.macLength) { - throw InvalidMessageException("Message shorter than crypto overhead!") - } - - ByteArrayInputStream(data).use { inputStream -> - verifyMac(inputStream, data.size.toLong(), mac, null) - } - return AttachmentCipherInputStream(ByteArrayInputStream(data), parts[0], (data.size - BLOCK_SIZE - mac.macLength).toLong()) - } - - private fun initMac(key: ByteArray): Mac { - try { - val mac = Mac.getInstance("HmacSHA256") - mac.init(SecretKeySpec(key, "HmacSHA256")) - return mac - } catch (e: NoSuchAlgorithmException) { - throw AssertionError(e) - } catch (e: InvalidKeyException) { - throw AssertionError(e) - } - } - - @Throws(InvalidMessageException::class) - private fun verifyMac(@Nonnull inputStream: InputStream, length: Long, @Nonnull mac: Mac, theirDigest: ByteArray?) { - try { - val digest = MessageDigest.getInstance("SHA256") - var remainingData = Util.toIntExact(length) - mac.macLength - val buffer = ByteArray(4096) - - while (remainingData > 0) { - val read = inputStream.read(buffer, 0, min(buffer.size, remainingData)) - mac.update(buffer, 0, read) - digest.update(buffer, 0, read) - remainingData -= read - } - - val ourMac = mac.doFinal() - val theirMac = ByteArray(mac.macLength) - Util.readFully(inputStream, theirMac) - - if (!MessageDigest.isEqual(ourMac, theirMac)) { - throw InvalidMessageException("MAC doesn't match!") - } - - val ourDigest = digest.digest(theirMac) - - if (theirDigest != null && !MessageDigest.isEqual(ourDigest, theirDigest)) { - throw InvalidMessageException("Digest doesn't match!") - } - } catch (e: IOException) { - throw InvalidMessageException(e) - } catch (e: ArithmeticException) { - throw InvalidMessageException(e) - } catch (e: NoSuchAlgorithmException) { - throw AssertionError(e) - } - } - } } diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/BetterCipherInputStream.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/BetterCipherInputStream.kt new file mode 100644 index 0000000000..1755f23861 --- /dev/null +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/BetterCipherInputStream.kt @@ -0,0 +1,148 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.signalservice.api.crypto + +import java.io.FilterInputStream +import java.io.IOException +import java.io.InputStream +import javax.annotation.Nonnull +import javax.crypto.BadPaddingException +import javax.crypto.Cipher +import javax.crypto.IllegalBlockSizeException +import javax.crypto.ShortBufferException +import kotlin.math.min + +/** + * This is similar to [javax.crypto.CipherInputStream], but it fixes various issues, including proper error propagation, + * and proper handling of boundary conditions. + */ +class BetterCipherInputStream( + inputStream: InputStream, + val cipher: Cipher +) : FilterInputStream(inputStream) { + + private var done = false + private var overflowBuffer: ByteArray? = null + + @Throws(IOException::class) + override fun read(): Int { + val buffer = ByteArray(1) + var read: Int = read(buffer) + while (read == 0) { + read = read(buffer) + } + + if (read == -1) { + return read + } + + return buffer[0].toInt() and 0xFF + } + + @Throws(IOException::class) + override fun read(@Nonnull buffer: ByteArray): Int { + return read(buffer, 0, buffer.size) + } + + @Throws(IOException::class) + override fun read(@Nonnull buffer: ByteArray, offset: Int, length: Int): Int { + return if (!done) { + readIncremental(buffer, offset, length) + } else { + -1 + } + } + + override fun markSupported(): Boolean = false + + @Throws(IOException::class) + override fun skip(byteCount: Long): Long { + val buffer = ByteArray(4096) + var skipped = 0L + + while (skipped < byteCount) { + val remaining = byteCount - skipped + val read = read(buffer, 0, remaining.toInt()) + + skipped += read.toLong() + } + + return skipped + } + + @Throws(IOException::class) + private fun readIncremental(outputBuffer: ByteArray, originalOffset: Int, originalLength: Int): Int { + var offset = originalOffset + var length = originalLength + var readLength = 0 + + overflowBuffer?.let { overflow -> + if (overflow.size > length) { + overflow.copyInto(destination = outputBuffer, destinationOffset = offset, endIndex = length) + overflowBuffer = overflow.copyOfRange(fromIndex = length, toIndex = overflow.size) + return length + } else if (overflow.size == length) { + overflow.copyInto(destination = outputBuffer, destinationOffset = offset) + overflowBuffer = null + return length + } else { + overflow.copyInto(destination = outputBuffer, destinationOffset = offset) + readLength += overflow.size + offset += readLength + length -= readLength + overflowBuffer = null + } + } + + val ciphertextBuffer = ByteArray(length) + val ciphertextRead = super.read(ciphertextBuffer, 0, ciphertextBuffer.size) + + if (ciphertextRead == -1) { + return readFinal(outputBuffer, offset, length) + } + + try { + var plaintextLength = cipher.getOutputSize(ciphertextRead) + + if (plaintextLength <= length) { + readLength += cipher.update(ciphertextBuffer, 0, ciphertextRead, outputBuffer, offset) + return readLength + } + + val plaintextBuffer = ByteArray(plaintextLength) + plaintextLength = cipher.update(ciphertextBuffer, 0, ciphertextRead, plaintextBuffer, 0) + if (plaintextLength <= length) { + plaintextBuffer.copyInto(destination = outputBuffer, destinationOffset = offset, endIndex = plaintextLength) + readLength += plaintextLength + } else { + plaintextBuffer.copyInto(destination = outputBuffer, destinationOffset = offset, endIndex = length) + overflowBuffer = plaintextBuffer.copyOfRange(fromIndex = length, toIndex = plaintextLength) + readLength += length + } + return readLength + } catch (e: ShortBufferException) { + throw AssertionError(e) + } + } + + @Throws(IOException::class) + private fun readFinal(buffer: ByteArray, offset: Int, length: Int): Int { + try { + val internal = ByteArray(buffer.size) + val actualLength = min(length, cipher.doFinal(internal, 0)) + internal.copyInto(destination = buffer, destinationOffset = offset, endIndex = actualLength) + + done = true + return actualLength + } catch (e: IllegalBlockSizeException) { + throw IOException(e) + } catch (e: BadPaddingException) { + throw IOException(e) + } catch (e: ShortBufferException) { + throw IOException(e) + } + } +} diff --git a/libsignal-service/src/test/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherTest.kt b/libsignal-service/src/test/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherTest.kt index 2af0614ba6..0ac129f145 100644 --- a/libsignal-service/src/test/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherTest.kt +++ b/libsignal-service/src/test/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherTest.kt @@ -161,32 +161,6 @@ class AttachmentCipherTest { } } - @Test(expected = InvalidMessageException::class) - fun attachment_decryptFailOnNullDigest_nonIncremental() { - attachment_decryptFailOnNullDigest(incremental = false) - } - - @Test(expected = InvalidMessageException::class) - fun attachment_decryptFailOnNullDigest_incremental() { - attachment_decryptFailOnNullDigest(incremental = true) - } - - private fun attachment_decryptFailOnNullDigest(incremental: Boolean) { - var cipherFile: File? = null - - try { - val key = Util.getSecretBytes(64) - val plaintextInput = Util.getSecretBytes(MEBIBYTE) - val encryptResult = encryptData(plaintextInput, key, incremental) - - cipherFile = writeToFile(encryptResult.ciphertext) - - AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, null, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice) - } finally { - cipherFile?.delete() - } - } - @Test(expected = InvalidMessageException::class) fun attachment_decryptFailOnBadDigest_nonIncremental() { attachment_decryptFailOnBadDigest(incremental = false) @@ -293,7 +267,7 @@ class AttachmentCipherTest { val encryptResult = encryptData(plaintextInput, key, false) val cipherFile = writeToFile(encryptResult.ciphertext) - val inputStream = AttachmentCipherInputStream.createForArchivedMedia(keyMaterial, cipherFile, plaintextInput.size.toLong()) + val inputStream = AttachmentCipherInputStream.createForArchivedMediaOuterLayer(keyMaterial, cipherFile, plaintextInput.size.toLong()) val plaintextOutput = readInputStreamFully(inputStream) assertThat(plaintextOutput).isEqualTo(plaintextInput) @@ -310,7 +284,7 @@ class AttachmentCipherTest { val encryptResult = encryptData(plaintextInput, key, false) val cipherFile = writeToFile(encryptResult.ciphertext) - val inputStream: InputStream = AttachmentCipherInputStream.createForArchivedMedia(keyMaterial, cipherFile, plaintextInput.size.toLong()) + val inputStream: InputStream = AttachmentCipherInputStream.createForArchivedMediaOuterLayer(keyMaterial, cipherFile, plaintextInput.size.toLong()) val plaintextOutput = readInputStreamFully(inputStream) assertThat(plaintextOutput).isEqualTo(plaintextInput) @@ -332,7 +306,7 @@ class AttachmentCipherTest { val encryptResult = encryptData(plaintextInput, key, false) cipherFile = writeToFile(encryptResult.ciphertext) - AttachmentCipherInputStream.createForArchivedMedia(keyMaterial, cipherFile, plaintextInput.size.toLong()) + AttachmentCipherInputStream.createForArchivedMediaOuterLayer(keyMaterial, cipherFile, plaintextInput.size.toLong()) } catch (e: InvalidMessageException) { hitCorrectException = true } finally { @@ -368,7 +342,7 @@ class AttachmentCipherTest { val cipherFile = writeToFile(encryptedData) val keyMaterial = createMediaKeyMaterial(key) - val decryptedStream: InputStream = AttachmentCipherInputStream.createForArchivedMedia(keyMaterial, cipherFile, length.toLong()) + val decryptedStream: InputStream = AttachmentCipherInputStream.createForArchivedMediaOuterLayer(keyMaterial, cipherFile, length.toLong()) val plaintextOutput = readInputStreamFully(decryptedStream) Assert.assertArrayEquals(plaintextInput, plaintextOutput) @@ -377,6 +351,60 @@ class AttachmentCipherTest { } } + @Test + fun archiveInnerAndOuter_encryptDecrypt_nonIncremental() { + archiveInnerAndOuter_encryptDecrypt(incremental = false, fileSize = MEBIBYTE) + } + + @Test + fun archiveInnerAndOuter_encryptDecrypt_incremental() { + archiveInnerAndOuter_encryptDecrypt(incremental = true, fileSize = MEBIBYTE) + } + + @Test + fun archiveInnerAndOuter_encryptDecrypt_nonIncremental_manyFileSizes() { + for (i in 0..99) { + archiveInnerAndOuter_encryptDecrypt(incremental = false, fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024)) + } + } + + @Test + fun archiveInnerAndOuter_encryptDecrypt_incremental_manyFileSizes() { + // Designed to stress the various boundary conditions of reading the final mac + for (i in 0..99) { + archiveInnerAndOuter_encryptDecrypt(incremental = true, fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024)) + } + } + + private fun archiveInnerAndOuter_encryptDecrypt(incremental: Boolean, fileSize: Int) { + val innerKey = Util.getSecretBytes(64) + val plaintextInput = Util.getSecretBytes(fileSize) + + val innerEncryptResult = encryptData(plaintextInput, innerKey, incremental) + val outerKey = Util.getSecretBytes(64) + + val outerEncryptResult = encryptData(innerEncryptResult.ciphertext, outerKey, false) + val cipherFile = writeToFile(outerEncryptResult.ciphertext) + + val keyMaterial = createMediaKeyMaterial(outerKey) + val decryptedStream: LimitedInputStream = AttachmentCipherInputStream.createForArchivedMediaOuterAndInnerLayers( + archivedMediaKeyMaterial = keyMaterial, + file = cipherFile, + originalCipherTextLength = innerEncryptResult.ciphertext.size.toLong(), + plaintextLength = plaintextInput.size.toLong(), + combinedKeyMaterial = innerKey, + digest = innerEncryptResult.digest, + incrementalDigest = innerEncryptResult.incrementalDigest, + incrementalMacChunkSize = innerEncryptResult.chunkSizeChoice + ) + val plaintextOutput = decryptedStream.readFully(autoClose = false) + + assertThat(plaintextOutput).isEqualTo(plaintextInput) + assertThat(decryptedStream.leftoverStream().allMatch { it == 0.toByte() }).isTrue() + + cipherFile.delete() + } + @Test fun archive_decryptFailOnBadMac() { var cipherFile: File? = null @@ -393,7 +421,7 @@ class AttachmentCipherTest { cipherFile = writeToFile(badMacCiphertext) val keyMaterial = createMediaKeyMaterial(key) - AttachmentCipherInputStream.createForArchivedMedia(keyMaterial, cipherFile, plaintextInput.size.toLong()) + AttachmentCipherInputStream.createForArchivedMediaOuterLayer(keyMaterial, cipherFile, plaintextInput.size.toLong()) Assert.fail() } catch (e: InvalidMessageException) { hitCorrectException = true