diff --git a/app/src/main/java/org/thoughtcrime/securesms/components/settings/app/backups/BackupsSettingsFragment.kt b/app/src/main/java/org/thoughtcrime/securesms/components/settings/app/backups/BackupsSettingsFragment.kt index 2320fcef0b..2adeb6a92a 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/components/settings/app/backups/BackupsSettingsFragment.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/components/settings/app/backups/BackupsSettingsFragment.kt @@ -131,7 +131,7 @@ private fun BackupsSettingsContent( item { Column(modifier = Modifier.padding(horizontal = dimensionResource(id = org.signal.core.ui.R.dimen.gutter))) { Text( - text = "INTERNAL ONLY", + text = "ALPHA ONLY", style = MaterialTheme.typography.titleMedium ) Text( diff --git a/core-util-jvm/src/main/java/org/signal/core/util/stream/TrimmingInputStream.kt b/core-util-jvm/src/main/java/org/signal/core/util/stream/TrimmingInputStream.kt new file mode 100644 index 0000000000..82b49c4add --- /dev/null +++ b/core-util-jvm/src/main/java/org/signal/core/util/stream/TrimmingInputStream.kt @@ -0,0 +1,148 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.signal.core.util.stream + +import org.signal.core.util.drain +import java.io.FilterInputStream +import java.io.IOException +import java.io.InputStream +import kotlin.math.min + +/** + * An input stream that will read all but the last [trimSize] bytes of the stream. + * + * Important: we have to keep a buffer of size [trimSize] to ensure that we can avoid reading it. + * That means you should avoid using this for very large values of [trimSize]. + * + * @param drain If true, the stream will be drained when it reaches the end (but bytes won't be returned). This is useful for ensuring that the underlying + * stream is fully consumed. + */ +class TrimmingInputStream( + private val inputStream: InputStream, + private val trimSize: Int, + private val drain: Boolean = false +) : FilterInputStream(inputStream) { + + private val trimBuffer = ByteArray(trimSize) + private var trimBufferSize: Int = 0 + private var streamEnded = false + private var hasDrained = false + + private var internalBuffer = ByteArray(4096) + private var internalBufferPosition: Int = 0 + private var internalBufferSize: Int = 0 + + @Throws(IOException::class) + override fun read(): Int { + val singleByteBuffer = ByteArray(1) + val bytesRead = read(singleByteBuffer, 0, 1) + return if (bytesRead == -1) { + -1 + } else { + singleByteBuffer[0].toInt() and 0xFF + } + } + + @Throws(IOException::class) + override fun read(b: ByteArray): Int { + return read(b, 0, b.size) + } + + /** + * The general strategy is that we do bulk reads into an internal buffer (just for perf reasons), and then when new bytes are requested, + * we fill up a buffer of size [trimSize] with the most recent bytes, and then return the oldest byte from that buffer. + * + * This ensures that the last [trimSize] bytes are never returned, while still returning the rest of the bytes. + * + * When we hit the end of the stream, we stop returning bytes. + */ + @Throws(IOException::class) + override fun read(outputBuffer: ByteArray, outputOffset: Int, readLength: Int): Int { + if (streamEnded) { + return -1 + } + + if (trimSize == 0) { + return super.read(outputBuffer, outputOffset, readLength) + } + + var outputCount = 0 + + while (outputCount < readLength) { + val nextByte = readNextByte() + + if (nextByte == -1) { + streamEnded = true + drainIfNecessary() + break + } + + if (trimBufferSize < trimSize) { + // Still filling the buffer - can't output anything yet + trimBuffer[trimBufferSize] = nextByte.toByte() + trimBufferSize++ + } else { + // Buffer is full - output the oldest byte and add the new one + outputBuffer[outputOffset + outputCount] = trimBuffer[0] + outputCount++ + + // Shift buffer left and add new byte at the end. In practice, this is a tiny array and copies should be fast. + System.arraycopy(trimBuffer, 1, trimBuffer, 0, trimSize - 1) + trimBuffer[trimSize - 1] = nextByte.toByte() + } + } + + return if (outputCount == 0) { + drainIfNecessary() + -1 + } else { + outputCount + } + } + + @Throws(IOException::class) + override fun skip(skipCount: Long): Long { + if (skipCount <= 0) return 0 + + var totalSkipped = 0L + val buffer = ByteArray(8192) + + while (totalSkipped < skipCount) { + val toRead = min((skipCount - totalSkipped).toInt(), buffer.size) + val bytesRead = read(buffer, 0, toRead) + if (bytesRead == -1) { + break + } + totalSkipped += bytesRead + } + + return totalSkipped + } + + private fun readNextByte(): Int { + val hitEndOfStream = if (internalBufferPosition >= internalBufferSize) { + internalBufferPosition = 0 + internalBufferSize = super.read(internalBuffer, 0, internalBuffer.size) + internalBufferSize <= 0 + } else { + false + } + + if (hitEndOfStream) { + drainIfNecessary() + return -1 + } + + return internalBuffer[internalBufferPosition++].toInt() and 0xFF + } + + private fun drainIfNecessary() { + if (drain && !hasDrained) { + inputStream.drain() + hasDrained = true + } + } +} diff --git a/core-util-jvm/src/test/java/org/signal/core/util/stream/TrimmingInputStreamTest.kt b/core-util-jvm/src/test/java/org/signal/core/util/stream/TrimmingInputStreamTest.kt new file mode 100644 index 0000000000..b8f9ca36cc --- /dev/null +++ b/core-util-jvm/src/test/java/org/signal/core/util/stream/TrimmingInputStreamTest.kt @@ -0,0 +1,140 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.signal.core.util.stream + +import assertk.assertThat +import assertk.assertions.isEqualTo +import org.junit.Test +import org.signal.core.util.readFully +import kotlin.math.min +import kotlin.random.Random + +class TrimmingInputStreamTest { + + @Test + fun `when I fully read the stream via a buffer, I should exclude the last trimSize bytes`() { + val initialData = testData(100) + val inputStream = TrimmingInputStream(initialData.inputStream(), trimSize = 25) + val data = inputStream.readFully() + + assertThat(data.size).isEqualTo(75) + assertThat(data).isEqualTo(initialData.copyOfRange(0, 75)) + } + + @Test + fun `when I fully read the stream via a buffer, I should exclude the last trimSize bytes - many sizes`() { + for (i in 1..100) { + val arraySize = Random.nextInt(1024, 2 * 1024 * 1024) + val trimSize = min(arraySize, Random.nextInt(1024)) + + val initialData = testData(arraySize) + val innerStream = initialData.inputStream() + val inputStream = TrimmingInputStream(innerStream, trimSize = trimSize) + val data = inputStream.readFully() + + assertThat(data.size).isEqualTo(arraySize - trimSize) + assertThat(data).isEqualTo(initialData.copyOfRange(0, arraySize - trimSize)) + } + } + + @Test + fun `when I fully read the stream via a buffer with drain set, I should exclude the last trimSize bytes but still drain the remaining stream - many sizes`() { + for (i in 1..100) { + val arraySize = Random.nextInt(1024, 2 * 1024 * 1024) + val trimSize = min(arraySize, Random.nextInt(1024)) + + val initialData = testData(arraySize) + val innerStream = initialData.inputStream() + val inputStream = TrimmingInputStream(innerStream, trimSize = trimSize, drain = true) + val data = inputStream.readFully() + + assertThat(data.size).isEqualTo(arraySize - trimSize) + assertThat(data).isEqualTo(initialData.copyOfRange(0, arraySize - trimSize)) + assertThat(innerStream.available()).isEqualTo(0) + } + } + + @Test + fun `when I fully read the stream and the trimSize is greater than the stream length, I should get zero bytes`() { + val initialData = testData(100) + val inputStream = TrimmingInputStream(initialData.inputStream(), trimSize = 200) + val data = inputStream.readFully() + + assertThat(data.size).isEqualTo(0) + } + + @Test + fun `when I fully read the stream via a buffer with no trimSize, I should get all bytes`() { + val inputStream = TrimmingInputStream(ByteArray(100).inputStream(), trimSize = 0) + val data = inputStream.readFully() + + assertThat(data.size).isEqualTo(100) + } + + @Test + fun `when I fully read the stream one byte at a time, I should exclude the last trimSize bytes`() { + val inputStream = TrimmingInputStream(ByteArray(100).inputStream(), trimSize = 25) + + var count = 0 + var lastRead = inputStream.read() + while (lastRead != -1) { + count++ + lastRead = inputStream.read() + } + + assertThat(count).isEqualTo(75) + } + + @Test + fun `when I fully read the stream one byte at a time with no trimSize, I should get all bytes`() { + val inputStream = TrimmingInputStream(ByteArray(100).inputStream(), trimSize = 0) + + var count = 0 + var lastRead = inputStream.read() + while (lastRead != -1) { + count++ + lastRead = inputStream.read() + } + + assertThat(count).isEqualTo(100) + } + + @Test + fun `when I skip past the the trimSize, I should get -1`() { + val inputStream = TrimmingInputStream(ByteArray(100).inputStream(), trimSize = 25) + + val skipCount = inputStream.skip(100) + val read = inputStream.read() + + assertThat(skipCount).isEqualTo(75) + assertThat(read).isEqualTo(-1) + } + + @Test + fun `when I skip, I should still truncate correctly afterwards`() { + val inputStream = TrimmingInputStream(ByteArray(100).inputStream(), trimSize = 25) + + val skipCount = inputStream.skip(50) + val data = inputStream.readFully() + + assertThat(skipCount).isEqualTo(50) + assertThat(data.size).isEqualTo(25) + } + + @Test + fun `when I skip more than the remaining bytes, I still respect trimSize`() { + val initialData = testData(100) + val inputStream = TrimmingInputStream(initialData.inputStream(), trimSize = 25) + + val skipCount = inputStream.skip(100) + + assertThat(skipCount).isEqualTo(75) + } + + private fun testData(length: Int): ByteArray { + return ByteArray(length) { (it % 0xFF).toByte() } + } +} diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageReceiver.java b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageReceiver.java index 785c494211..52ab2be071 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageReceiver.java +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageReceiver.java @@ -189,16 +189,9 @@ public class SignalServiceMessageReceiver { socket.retrieveAttachment(pointer.getCdnNumber(), readCredentialHeaders, pointer.getRemoteId(), archiveDestination, maxSizeBytes, listener); - long originalCipherLength = pointer.getSize() - .filter(s -> s > 0) - .map(s -> AttachmentCipherStreamUtil.getCiphertextLength(PaddingInputStream.getPaddedSize(s))) - .orElse(0L); - return AttachmentCipherInputStream.createForArchivedThumbnail( archivedMediaKeyMaterial, archiveDestination, - originalCipherLength, - pointer.getSize().orElse(0), pointer.getKey() ); } diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherInputStream.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherInputStream.kt index 4fe85d127d..cfc8256d21 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherInputStream.kt +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherInputStream.kt @@ -8,6 +8,7 @@ package org.whispersystems.signalservice.api.crypto import org.signal.core.util.Base64 import org.signal.core.util.readNBytesOrThrow import org.signal.core.util.stream.LimitedInputStream +import org.signal.core.util.stream.TrimmingInputStream import org.signal.libsignal.protocol.InvalidMessageException import org.signal.libsignal.protocol.incrementalmac.ChunkSizeChoice import org.signal.libsignal.protocol.incrementalmac.IncrementalMacInputStream @@ -134,38 +135,45 @@ object AttachmentCipherInputStream { } /** - * When you archive an attachment, you give the server an encrypted attachment, and the server wraps it in *another* layer of encryption. + * When you archive an attachment thumbnail, you give the server an encrypted attachment, and the server wraps it in *another* layer of encryption. * * This creates a stream decrypt both the inner and outer layers of an archived attachment at the same time by basically double-decrypting it. * - * @param incrementalDigest If null, incremental mac validation is disabled. - * @param incrementalMacChunkSize If 0, incremental mac validation is disabled. + * Archive thumbnails are also special in that we: + * - don't know how long they are (meaning you'll get them back with padding at the end, image viewers are ok with this) + * - don't care about external integrity checks (we still validate the MACs) + * + * So there's some code duplication here just to avoid mucking up the reusable functions with special cases. */ @JvmStatic @Throws(InvalidMessageException::class, IOException::class) fun createForArchivedThumbnail( archivedMediaKeyMaterial: MediaKeyMaterial, file: File, - originalCipherTextLength: Long, - plaintextLength: Long, - combinedKeyMaterial: ByteArray + innerCombinedKeyMaterial: ByteArray ): InputStream { - val keyMaterial = CombinedKeyMaterial.from(combinedKeyMaterial) - val mac = initMac(keyMaterial.macKey) + val outerMac = initMac(archivedMediaKeyMaterial.macKey) - if (originalCipherTextLength <= BLOCK_SIZE + mac.macLength) { - throw InvalidMessageException("Message shorter than crypto overhead! Expected at least ${BLOCK_SIZE + mac.macLength} bytes, got $originalCipherTextLength") + if (file.length() <= BLOCK_SIZE + outerMac.macLength) { + throw InvalidMessageException("Message shorter than crypto overhead! Expected at least ${BLOCK_SIZE + outerMac.macLength} bytes, got ${file.length()}") } - return create( - streamSupplier = { createForArchivedMediaOuterLayer(archivedMediaKeyMaterial, file, originalCipherTextLength) }, - streamLength = originalCipherTextLength, - plaintextLength = plaintextLength, - combinedKeyMaterial = combinedKeyMaterial, - integrityCheck = null, - incrementalDigest = null, - incrementalMacChunkSize = 0 - ) + FileInputStream(file).use { macVerificationStream -> + verifyMacAndMaybeEncryptedDigest(macVerificationStream, file.length(), outerMac, null) + } + + val outerEncryptedStreamExcludingMac = LimitedInputStream(FileInputStream(file), maxBytes = file.length() - outerMac.macLength) + val outerCipher = createCipher(outerEncryptedStreamExcludingMac, archivedMediaKeyMaterial.aesKey) + val innerEncryptedStream = BetterCipherInputStream(outerEncryptedStreamExcludingMac, outerCipher) + + val innerKeyMaterial = CombinedKeyMaterial.from(innerCombinedKeyMaterial) + val innerMac = initMac(innerKeyMaterial.macKey) + + val innerEncryptedStreamWithMac = MacValidatingInputStream(innerEncryptedStream, innerMac) + val innerEncryptedStreamExcludingMac = TrimmingInputStream(innerEncryptedStreamWithMac, trimSize = innerMac.macLength, drain = true) + val innerCipher = createCipher(innerEncryptedStreamExcludingMac, innerKeyMaterial.aesKey) + + return BetterCipherInputStream(innerEncryptedStreamExcludingMac, innerCipher) } /** diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/MacValidatingInputStream.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/MacValidatingInputStream.kt new file mode 100644 index 0000000000..7dd773a67c --- /dev/null +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/MacValidatingInputStream.kt @@ -0,0 +1,140 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.signalservice.api.crypto + +import org.jetbrains.annotations.VisibleForTesting +import org.signal.core.util.stream.LimitedInputStream +import org.signal.core.util.stream.TrimmingInputStream +import org.signal.libsignal.protocol.InvalidMessageException +import java.io.FilterInputStream +import java.io.IOException +import java.io.InputStream +import java.security.MessageDigest +import javax.crypto.Mac + +/** + * An InputStream that validates a MAC appended to the end of the stream data. + * This stream will not exclude the MAC from the data it reads, meaning that you may want to pair this with a [LimitedInputStream] or a [TrimmingInputStream] + * if you don't want to read that data to be a part of it. + * + * Important: The MAC is only validated once the stream has been fully read. + * + * @param inputStream The underlying InputStream to read from + * @param mac The Mac instance to use for validation + */ +class MacValidatingInputStream( + inputStream: InputStream, + private val mac: Mac +) : FilterInputStream(inputStream) { + + private val macBuffer = ByteArray(mac.macLength) + private val macLength = mac.macLength + private var macBufferPosition = 0 + private var streamEnded = false + + @VisibleForTesting + var validationAttempted = false + private set + + @Throws(IOException::class) + override fun read(): Int { + val singleByteBuffer = ByteArray(1) + val bytesRead = read(singleByteBuffer, 0, 1) + return if (bytesRead == -1) -1 else singleByteBuffer[0].toInt() and 0xFF + } + + @Throws(IOException::class) + override fun read(b: ByteArray): Int { + return read(b, 0, b.size) + } + + @Throws(IOException::class) + override fun read(outputBuffer: ByteArray, outputOffset: Int, readLength: Int): Int { + if (streamEnded) { + return -1 + } + + val bytesRead = super.read(outputBuffer, outputOffset, readLength) + + if (bytesRead == -1) { + // End of stream - check if we have enough data for MAC validation + if (macBufferPosition < macLength) { + throw InvalidMessageException("Stream ended before MAC could be read. Expected $macLength bytes, got $macBufferPosition") + } + validateMacAndMarkStreamEnded() + return -1 + } + + // If we've read more than `macLength` bytes, we can just snag the last `macLength` bytes and digest the rest + if (bytesRead >= macLength) { + // Before replacing the macBuffer, process any pre-existing data + if (macBufferPosition > 0) { + mac.update(macBuffer, 0, macBufferPosition) + macBufferPosition = 0 + } + + // Copy the last `macLength` bytes into the macBuffer + outputBuffer.copyInto(destination = macBuffer, destinationOffset = 0, startIndex = outputOffset + bytesRead - macLength, endIndex = outputOffset + bytesRead) + macBufferPosition = macLength + + // Update the mac with the bytes that are not part of the MAC + if (bytesRead > macLength) { + mac.update(outputBuffer, outputOffset, bytesRead - macLength) + } + } else { + val totalBytesAvailable = macBufferPosition + bytesRead + + // If the new bytes we've read don't overflow the buffer, we can just append them, and none of them will be digested + if (totalBytesAvailable <= macLength) { + outputBuffer.copyInto(destination = macBuffer, destinationOffset = macBufferPosition, startIndex = outputOffset, endIndex = outputOffset + bytesRead) + macBufferPosition = totalBytesAvailable + } else { + // If we have more bytes than we can hold in the buffer, keep the last `macLength` bytes and digest the rest + + // We know that `bytesRead` is less than `macLength`, so we know all of `bytesRead` should go into the buffer + // And we know that the buffer usage + `bytesRead` is greater than `macLength`, so we're guaranteed to be able to digest the first chunk of the buffer. + // We also know that there can't possibly be 0 bytes in the buffer because of how the math of those conditions works out. + + val bytesToDigest = totalBytesAvailable - macLength + + val bytesOfBufferToDigest = minOf(macBufferPosition, bytesToDigest) + val bytesOfReadToDigest = bytesToDigest - bytesOfBufferToDigest + + mac.update(macBuffer, 0, bytesOfBufferToDigest) + macBuffer.copyInto(destination = macBuffer, destinationOffset = 0, startIndex = bytesOfBufferToDigest, endIndex = macBufferPosition) + macBufferPosition -= bytesOfBufferToDigest + + if (bytesOfReadToDigest > 0) { + mac.update(outputBuffer, outputOffset, bytesOfReadToDigest) + } + + val bytesOfReadRemaining = bytesRead - bytesOfReadToDigest + if (bytesOfReadRemaining > 0) { + outputBuffer.copyInto(destination = macBuffer, destinationOffset = macBufferPosition, startIndex = outputOffset + bytesOfReadToDigest, endIndex = outputOffset + bytesRead) + macBufferPosition += bytesOfReadRemaining + } + } + } + + return bytesRead + } + + @Throws(InvalidMessageException::class) + private fun validateMacAndMarkStreamEnded() { + if (validationAttempted) { + return + } + validationAttempted = true + streamEnded = true + + val calculatedMac = mac.doFinal() + if (!MessageDigest.isEqual(calculatedMac, macBuffer)) { + throw InvalidMessageException("MAC validation failed!") + } + } + + private fun minOf(a: Int, b: Int): Int = if (a < b) a else b +} diff --git a/libsignal-service/src/test/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherTest.kt b/libsignal-service/src/test/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherTest.kt index eda1409164..fa1645341d 100644 --- a/libsignal-service/src/test/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherTest.kt +++ b/libsignal-service/src/test/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherTest.kt @@ -94,7 +94,7 @@ class AttachmentCipherTest { ) } val inputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, integrityCheck, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice) - val plaintextOutput = inputStream.readFully(autoClose = false) + val plaintextOutput = inputStream.readFully() assertThat(plaintextOutput).isEqualTo(plaintextInput) @@ -111,7 +111,7 @@ class AttachmentCipherTest { val integrityCheck = IntegrityCheck.forPlaintextHash(plaintextHash) val inputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, integrityCheck, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice) - val plaintextOutput = inputStream.readFully(autoClose = false) + val plaintextOutput = inputStream.readFully() assertThat(plaintextOutput).isEqualTo(plaintextInput) @@ -137,7 +137,7 @@ class AttachmentCipherTest { val integrityCheck = IntegrityCheck.forEncryptedDigest(encryptResult.digest) val inputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, integrityCheck, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice) - val plaintextOutput = inputStream.readFully(autoClose = false) + val plaintextOutput = inputStream.readFully() Assert.assertArrayEquals(plaintextInput, plaintextOutput) @@ -327,7 +327,7 @@ class AttachmentCipherTest { incrementalDigest = innerEncryptResult.incrementalDigest, incrementalMacChunkSize = innerEncryptResult.chunkSizeChoice ) - val plaintextOutput = decryptedStream.readFully(autoClose = false) + val plaintextOutput = decryptedStream.readFully() assertThat(plaintextOutput).isEqualTo(plaintextInput) @@ -382,37 +382,36 @@ class AttachmentCipherTest { @Test fun archive_encryptDecrypt_nonIncremental() { - archiveInnerAndOuter_encryptDecrypt(incremental = false, fileSize = MEBIBYTE) + archive_encryptDecrypt(incremental = false, fileSize = MEBIBYTE) } @Test fun archive_encryptDecrypt_incremental() { - archiveInnerAndOuter_encryptDecrypt(incremental = true, fileSize = MEBIBYTE) + archive_encryptDecrypt(incremental = true, fileSize = MEBIBYTE) } @Test fun archive_encryptDecrypt_nonIncremental_manyFileSizes() { for (i in 0..99) { - archiveInnerAndOuter_encryptDecrypt(incremental = false, fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024)) + archive_encryptDecrypt(incremental = false, fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024)) } } @Test - fun archiveInnerAndOuter_encryptDecrypt_incremental_manyFileSizes() { + fun archive_encryptDecrypt_incremental_manyFileSizes() { // Designed to stress the various boundary conditions of reading the final mac for (i in 0..99) { - archiveInnerAndOuter_encryptDecrypt(incremental = true, fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024)) + archive_encryptDecrypt(incremental = true, fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024)) } } - private fun archiveInnerAndOuter_encryptDecrypt(incremental: Boolean, fileSize: Int) { + private fun archive_encryptDecrypt(incremental: Boolean, fileSize: Int) { val innerKey = Util.getSecretBytes(64) val plaintextInput = Util.getSecretBytes(fileSize) - val innerEncryptResult = encryptData(plaintextInput, innerKey, incremental) - val outerKey = Util.getSecretBytes(64) - val outerEncryptResult = encryptData(innerEncryptResult.ciphertext, outerKey, false) + val outerKey = Util.getSecretBytes(64) + val outerEncryptResult = encryptData(innerEncryptResult.ciphertext, outerKey, withIncremental = false, padded = false) // Server doesn't pad val cipherFile = writeToFile(outerEncryptResult.ciphertext) val keyMaterial = createMediaKeyMaterial(outerKey) @@ -426,7 +425,7 @@ class AttachmentCipherTest { incrementalDigest = innerEncryptResult.incrementalDigest, incrementalMacChunkSize = innerEncryptResult.chunkSizeChoice ) - val plaintextOutput = decryptedStream.readFully(autoClose = false) + val plaintextOutput = decryptedStream.readFully() assertThat(plaintextOutput).isEqualTo(plaintextInput) @@ -434,26 +433,56 @@ class AttachmentCipherTest { } @Test - fun archiveEncryptDecrypt_decryptFailOnBadMac() { + fun archiveThumbnail_encryptDecrypt_manyFileSizes() { + for (i in 0..99) { + archiveThumbnail_encryptDecrypt(fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024)) + } + } + + private fun archiveThumbnail_encryptDecrypt(fileSize: Int) { + val innerKey = Util.getSecretBytes(64) + val plaintextInput = Util.getSecretBytes(fileSize) + val innerEncryptResult = encryptData(plaintextInput, innerKey, withIncremental = false) + + val outerKey = Util.getSecretBytes(64) + val outerEncryptResult = encryptData(innerEncryptResult.ciphertext, outerKey, withIncremental = false, padded = false) // Server doesn't pad + val cipherFile = writeToFile(outerEncryptResult.ciphertext) + + val keyMaterial = createMediaKeyMaterial(outerKey) + val decryptedStream = AttachmentCipherInputStream.createForArchivedThumbnail( + archivedMediaKeyMaterial = keyMaterial, + file = cipherFile, + innerCombinedKeyMaterial = innerKey + ) + val plaintextOutput = decryptedStream.readFully() + + // We knowingly keep padding on thumbnails, so for the test to work, we strip the padding off, but check to make sure the sizes match up beforehand + assertThat(plaintextOutput.size).isEqualTo(PaddingInputStream.getPaddedSize(plaintextInput.size.toLong()).toInt()) + assertThat(plaintextOutput.copyOfRange(fromIndex = 0, toIndex = plaintextInput.size)).isEqualTo(plaintextInput) + + cipherFile.delete() + } + + @Test + fun archiveEncryptDecrypt_decryptFailOnInnerBadMac() { var cipherFile: File? = null var hitCorrectException = false try { val innerKey = Util.getSecretBytes(64) - val badInnerKey = Util.getSecretBytes(64) val plaintextInput = Util.getSecretBytes(MEBIBYTE) - val innerEncryptResult = encryptData(plaintextInput, innerKey, withIncremental = true) + val innerEncryptResult = encryptData(plaintextInput, innerKey, withIncremental = true, padded = false) // Server doesn't pad + val badMacInnerCipherText = innerEncryptResult.ciphertext.copyOf().also { + it[it.size - 1] = (it[it.size - 1] + 1).toByte() + } + val outerKey = Util.getSecretBytes(64) + val outerEncryptResult = encryptData(badMacInnerCipherText, outerKey, false) - val outerEncryptResult = encryptData(innerEncryptResult.ciphertext, outerKey, false) - val badMacOuterCiphertext = outerEncryptResult.ciphertext.copyOf(outerEncryptResult.ciphertext.size) + cipherFile = writeToFile(outerEncryptResult.ciphertext) - badMacOuterCiphertext[badMacOuterCiphertext.size - 1] = (badMacOuterCiphertext[badMacOuterCiphertext.size - 1] + 1).toByte() - - cipherFile = writeToFile(badMacOuterCiphertext) - - val keyMaterial = createMediaKeyMaterial(badInnerKey) + val keyMaterial = createMediaKeyMaterial(innerKey) AttachmentCipherInputStream.createForArchivedMedia( archivedMediaKeyMaterial = keyMaterial, @@ -476,6 +505,122 @@ class AttachmentCipherTest { Assert.assertTrue(hitCorrectException) } + @Test + fun archiveEncryptDecrypt_decryptFailOnOuterMac() { + var cipherFile: File? = null + var hitCorrectException = false + + try { + val innerKey = Util.getSecretBytes(64) + val plaintextInput = Util.getSecretBytes(MEBIBYTE) + + val innerEncryptResult = encryptData(plaintextInput, innerKey, withIncremental = true, padded = false) // Server doesn't pad + val outerKey = Util.getSecretBytes(64) + + val outerEncryptResult = encryptData(innerEncryptResult.ciphertext, outerKey, false) + val badMacOuterCiphertext = outerEncryptResult.ciphertext.copyOf().also { + it[it.size - 1] = (it[it.size - 1] + 1).toByte() + } + + cipherFile = writeToFile(badMacOuterCiphertext) + + val keyMaterial = createMediaKeyMaterial(innerKey) + + AttachmentCipherInputStream.createForArchivedMedia( + archivedMediaKeyMaterial = keyMaterial, + file = cipherFile, + originalCipherTextLength = innerEncryptResult.ciphertext.size.toLong(), + plaintextLength = plaintextInput.size.toLong(), + combinedKeyMaterial = innerKey, + plaintextHash = innerEncryptResult.digest, + incrementalDigest = innerEncryptResult.incrementalDigest, + incrementalMacChunkSize = innerEncryptResult.chunkSizeChoice + ) + + Assert.fail() + } catch (e: InvalidMessageException) { + hitCorrectException = true + } finally { + cipherFile?.delete() + } + + Assert.assertTrue(hitCorrectException) + } + + @Test + fun archiveThumbnailEncryptDecrypt_decryptFailOnInnerBadMac() { + var cipherFile: File? = null + var hitCorrectException = false + + try { + val innerKey = Util.getSecretBytes(64) + val plaintextInput = Util.getSecretBytes(MEBIBYTE) + + val innerEncryptResult = encryptData(plaintextInput, innerKey, withIncremental = true, padded = false) // Server doesn't pad + val badMacInnerCipherText = innerEncryptResult.ciphertext.copyOf().also { + it[it.size - 1] = (it[it.size - 1] + 1).toByte() + } + + val outerKey = Util.getSecretBytes(64) + val outerEncryptResult = encryptData(badMacInnerCipherText, outerKey, false) + + cipherFile = writeToFile(outerEncryptResult.ciphertext) + + val keyMaterial = createMediaKeyMaterial(innerKey) + + AttachmentCipherInputStream.createForArchivedThumbnail( + archivedMediaKeyMaterial = keyMaterial, + file = cipherFile, + innerCombinedKeyMaterial = innerKey + ).readFully() + + Assert.fail() + } catch (e: InvalidMessageException) { + hitCorrectException = true + } finally { + cipherFile?.delete() + } + + Assert.assertTrue(hitCorrectException) + } + + @Test + fun archiveThumbnailEncryptDecrypt_decryptFailOnOuterMac() { + var cipherFile: File? = null + var hitCorrectException = false + + try { + val innerKey = Util.getSecretBytes(64) + val plaintextInput = Util.getSecretBytes(MEBIBYTE) + + val innerEncryptResult = encryptData(plaintextInput, innerKey, withIncremental = true, padded = false) // Server doesn't pad + val outerKey = Util.getSecretBytes(64) + + val outerEncryptResult = encryptData(innerEncryptResult.ciphertext, outerKey, false) + val badMacOuterCiphertext = outerEncryptResult.ciphertext.copyOf().also { + it[it.size - 1] = (it[it.size - 1] + 1).toByte() + } + + cipherFile = writeToFile(badMacOuterCiphertext) + + val keyMaterial = createMediaKeyMaterial(innerKey) + + AttachmentCipherInputStream.createForArchivedThumbnail( + archivedMediaKeyMaterial = keyMaterial, + file = cipherFile, + innerCombinedKeyMaterial = innerKey + ) + + Assert.fail() + } catch (e: InvalidMessageException) { + hitCorrectException = true + } finally { + cipherFile?.delete() + } + + Assert.assertTrue(hitCorrectException) + } + @Test fun sticker_encryptDecrypt() { LibSignalLibraryUtil.assumeLibSignalSupportedOnOS() diff --git a/libsignal-service/src/test/java/org/whispersystems/signalservice/api/crypto/MacValidatingInputStreamTest.kt b/libsignal-service/src/test/java/org/whispersystems/signalservice/api/crypto/MacValidatingInputStreamTest.kt new file mode 100644 index 0000000000..5d64ced390 --- /dev/null +++ b/libsignal-service/src/test/java/org/whispersystems/signalservice/api/crypto/MacValidatingInputStreamTest.kt @@ -0,0 +1,188 @@ +package org.whispersystems.signalservice.api.crypto + +import assertk.assertThat +import assertk.assertions.isEqualTo +import assertk.assertions.isTrue +import assertk.fail +import org.junit.Test +import org.signal.core.util.kibiBytes +import org.signal.core.util.mebiBytes +import org.signal.core.util.readFully +import org.signal.libsignal.protocol.InvalidMessageException +import org.whispersystems.signalservice.internal.util.Util +import java.io.ByteArrayInputStream +import java.io.ByteArrayOutputStream +import javax.crypto.Mac +import javax.crypto.spec.SecretKeySpec +import kotlin.random.Random + +class MacValidatingInputStreamTest { + + @Test + fun `success - simple byte array read`() { + val data = "Hello, World!".toByteArray() + val key = Util.getSecretBytes(32) + val dataWithMac = createDataWithMac(data, key) + + val inputStream = ByteArrayInputStream(dataWithMac) + val mac = createMac(key) + val macValidatingStream = MacValidatingInputStream(inputStream, mac) + + val result = macValidatingStream.readFully() + + assertThat(result).isEqualTo(dataWithMac) + macValidatingStream.close() + } + + @Test + fun `success - byte by byte`() { + val data = "Hello, World!".toByteArray() + val key = Util.getSecretBytes(32) + val dataWithMac = createDataWithMac(data, key) + + val inputStream = ByteArrayInputStream(dataWithMac) + val mac = createMac(key) + val macValidatingStream = MacValidatingInputStream(inputStream, mac) + + val out = ByteArrayOutputStream() + var read = -1 + while (macValidatingStream.read().also { read = it } != -1) { + out.write(read) + } + val result = out.toByteArray() + + assertThat(result).isEqualTo(dataWithMac) + macValidatingStream.close() + } + + @Test + fun `success - many different sizes`() { + for (i in 1..100) { + val data = Util.getSecretBytes(Random.nextLong(from = 256.kibiBytes.bytes, until = 2.mebiBytes.bytes).toInt()) + val key = Util.getSecretBytes(32) + val dataWithMac = createDataWithMac(data, key) + + val inputStream = ByteArrayInputStream(dataWithMac) + val mac = createMac(key) + val macValidatingStream = MacValidatingInputStream(inputStream, mac) + + val result = macValidatingStream.readFully() + + assertThat(result).isEqualTo(dataWithMac) + assertThat(macValidatingStream.validationAttempted).isTrue() + macValidatingStream.close() + } + } + + @Test + fun `success - empty data`() { + val data = ByteArray(0) + val key = Util.getSecretBytes(32) + val dataWithMac = createDataWithMac(data, key) + + val inputStream = ByteArrayInputStream(dataWithMac) + val mac = createMac(key) + val macValidatingStream = MacValidatingInputStream(inputStream, mac) + + val result = macValidatingStream.readFully() + + assertThat(result).isEqualTo(dataWithMac) + assertThat(macValidatingStream.validationAttempted).isTrue() + macValidatingStream.close() + } + + @Test + fun `success - data exactly MAC length`() { + val key = Util.getSecretBytes(32) + val mac = createMac(key) + val macLength = mac.macLength + val data = ByteArray(macLength) { (it % 256).toByte() } // Data same size as MAC + val dataWithMac = createDataWithMac(data, key) + + val inputStream = ByteArrayInputStream(dataWithMac) + val mac2 = createMac(key) + val macValidatingStream = MacValidatingInputStream(inputStream, mac2) + + val result = macValidatingStream.readFully() + + assertThat(result).isEqualTo(dataWithMac) + assertThat(macValidatingStream.validationAttempted).isTrue() + macValidatingStream.close() + } + + @Test + fun `success - multiple reads after end of stream`() { + val data = "Test multiple reads after EOF".toByteArray() + val key = Util.getSecretBytes(32) + val dataWithMac = createDataWithMac(data, key) + + val inputStream = ByteArrayInputStream(dataWithMac) + val mac = createMac(key) + val macValidatingStream = MacValidatingInputStream(inputStream, mac) + + val result = macValidatingStream.readFully() + + // Multiple calls to read() after EOF should return -1 + assertThat(macValidatingStream.read()).isEqualTo(-1) + assertThat(macValidatingStream.read()).isEqualTo(-1) + assertThat(macValidatingStream.read()).isEqualTo(-1) + + assertThat(result).isEqualTo(dataWithMac) + assertThat(macValidatingStream.validationAttempted).isTrue() + macValidatingStream.close() + } + + @Test + fun `failure - invalid MAC`() { + val data = "Hello, World!".toByteArray() + val key = Util.getSecretBytes(32) + val wrongKey = ByteArray(32) { 24 } + val dataWithMac = createDataWithMac(data, key) + + val inputStream = ByteArrayInputStream(dataWithMac) + val mac = createMac(wrongKey) // Wrong key + val macValidatingStream = MacValidatingInputStream(inputStream, mac) + + try { + macValidatingStream.readFully() + fail("Expected InvalidMessageException to be thrown") + } catch (e: InvalidMessageException) { + assertThat(e.message).isEqualTo("MAC validation failed!") + } finally { + macValidatingStream.close() + } + } + + @Test + fun `failure - insufficient data for MAC`() { + val key = Util.getSecretBytes(32) + val mac = createMac(key) + val macLength = mac.macLength + val insufficientData = ByteArray(macLength - 1) { 5 } // Less than MAC length + + val inputStream = ByteArrayInputStream(insufficientData) + val mac2 = createMac(key) + val macValidatingStream = MacValidatingInputStream(inputStream, mac2) + + try { + macValidatingStream.readFully() + fail("Expected InvalidMessageException to be thrown") + } catch (e: InvalidMessageException) { + assertThat(e.message).isEqualTo("Stream ended before MAC could be read. Expected $macLength bytes, got ${insufficientData.size}") + } finally { + macValidatingStream.close() + } + } + + private fun createMac(key: ByteArray): Mac { + val mac = Mac.getInstance("HmacSHA256") + mac.init(SecretKeySpec(key, "HmacSHA256")) + return mac + } + + private fun createDataWithMac(data: ByteArray, key: ByteArray): ByteArray { + val mac = createMac(key) + val macBytes = mac.doFinal(data) + return data + macBytes + } +}