Fix archive thumbnail decryption.

This commit is contained in:
Greyson Parrelli
2025-06-25 13:54:28 -04:00
committed by Cody Henthorne
parent b1063f69f9
commit c0340be3ce
8 changed files with 813 additions and 51 deletions

View File

@@ -131,7 +131,7 @@ private fun BackupsSettingsContent(
item {
Column(modifier = Modifier.padding(horizontal = dimensionResource(id = org.signal.core.ui.R.dimen.gutter))) {
Text(
text = "INTERNAL ONLY",
text = "ALPHA ONLY",
style = MaterialTheme.typography.titleMedium
)
Text(

View File

@@ -0,0 +1,148 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.stream
import org.signal.core.util.drain
import java.io.FilterInputStream
import java.io.IOException
import java.io.InputStream
import kotlin.math.min
/**
* An input stream that will read all but the last [trimSize] bytes of the stream.
*
* Important: we have to keep a buffer of size [trimSize] to ensure that we can avoid reading it.
* That means you should avoid using this for very large values of [trimSize].
*
* @param drain If true, the stream will be drained when it reaches the end (but bytes won't be returned). This is useful for ensuring that the underlying
* stream is fully consumed.
*/
class TrimmingInputStream(
private val inputStream: InputStream,
private val trimSize: Int,
private val drain: Boolean = false
) : FilterInputStream(inputStream) {
private val trimBuffer = ByteArray(trimSize)
private var trimBufferSize: Int = 0
private var streamEnded = false
private var hasDrained = false
private var internalBuffer = ByteArray(4096)
private var internalBufferPosition: Int = 0
private var internalBufferSize: Int = 0
@Throws(IOException::class)
override fun read(): Int {
val singleByteBuffer = ByteArray(1)
val bytesRead = read(singleByteBuffer, 0, 1)
return if (bytesRead == -1) {
-1
} else {
singleByteBuffer[0].toInt() and 0xFF
}
}
@Throws(IOException::class)
override fun read(b: ByteArray): Int {
return read(b, 0, b.size)
}
/**
* The general strategy is that we do bulk reads into an internal buffer (just for perf reasons), and then when new bytes are requested,
* we fill up a buffer of size [trimSize] with the most recent bytes, and then return the oldest byte from that buffer.
*
* This ensures that the last [trimSize] bytes are never returned, while still returning the rest of the bytes.
*
* When we hit the end of the stream, we stop returning bytes.
*/
@Throws(IOException::class)
override fun read(outputBuffer: ByteArray, outputOffset: Int, readLength: Int): Int {
if (streamEnded) {
return -1
}
if (trimSize == 0) {
return super.read(outputBuffer, outputOffset, readLength)
}
var outputCount = 0
while (outputCount < readLength) {
val nextByte = readNextByte()
if (nextByte == -1) {
streamEnded = true
drainIfNecessary()
break
}
if (trimBufferSize < trimSize) {
// Still filling the buffer - can't output anything yet
trimBuffer[trimBufferSize] = nextByte.toByte()
trimBufferSize++
} else {
// Buffer is full - output the oldest byte and add the new one
outputBuffer[outputOffset + outputCount] = trimBuffer[0]
outputCount++
// Shift buffer left and add new byte at the end. In practice, this is a tiny array and copies should be fast.
System.arraycopy(trimBuffer, 1, trimBuffer, 0, trimSize - 1)
trimBuffer[trimSize - 1] = nextByte.toByte()
}
}
return if (outputCount == 0) {
drainIfNecessary()
-1
} else {
outputCount
}
}
@Throws(IOException::class)
override fun skip(skipCount: Long): Long {
if (skipCount <= 0) return 0
var totalSkipped = 0L
val buffer = ByteArray(8192)
while (totalSkipped < skipCount) {
val toRead = min((skipCount - totalSkipped).toInt(), buffer.size)
val bytesRead = read(buffer, 0, toRead)
if (bytesRead == -1) {
break
}
totalSkipped += bytesRead
}
return totalSkipped
}
private fun readNextByte(): Int {
val hitEndOfStream = if (internalBufferPosition >= internalBufferSize) {
internalBufferPosition = 0
internalBufferSize = super.read(internalBuffer, 0, internalBuffer.size)
internalBufferSize <= 0
} else {
false
}
if (hitEndOfStream) {
drainIfNecessary()
return -1
}
return internalBuffer[internalBufferPosition++].toInt() and 0xFF
}
private fun drainIfNecessary() {
if (drain && !hasDrained) {
inputStream.drain()
hasDrained = true
}
}
}

View File

@@ -0,0 +1,140 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.stream
import assertk.assertThat
import assertk.assertions.isEqualTo
import org.junit.Test
import org.signal.core.util.readFully
import kotlin.math.min
import kotlin.random.Random
class TrimmingInputStreamTest {
@Test
fun `when I fully read the stream via a buffer, I should exclude the last trimSize bytes`() {
val initialData = testData(100)
val inputStream = TrimmingInputStream(initialData.inputStream(), trimSize = 25)
val data = inputStream.readFully()
assertThat(data.size).isEqualTo(75)
assertThat(data).isEqualTo(initialData.copyOfRange(0, 75))
}
@Test
fun `when I fully read the stream via a buffer, I should exclude the last trimSize bytes - many sizes`() {
for (i in 1..100) {
val arraySize = Random.nextInt(1024, 2 * 1024 * 1024)
val trimSize = min(arraySize, Random.nextInt(1024))
val initialData = testData(arraySize)
val innerStream = initialData.inputStream()
val inputStream = TrimmingInputStream(innerStream, trimSize = trimSize)
val data = inputStream.readFully()
assertThat(data.size).isEqualTo(arraySize - trimSize)
assertThat(data).isEqualTo(initialData.copyOfRange(0, arraySize - trimSize))
}
}
@Test
fun `when I fully read the stream via a buffer with drain set, I should exclude the last trimSize bytes but still drain the remaining stream - many sizes`() {
for (i in 1..100) {
val arraySize = Random.nextInt(1024, 2 * 1024 * 1024)
val trimSize = min(arraySize, Random.nextInt(1024))
val initialData = testData(arraySize)
val innerStream = initialData.inputStream()
val inputStream = TrimmingInputStream(innerStream, trimSize = trimSize, drain = true)
val data = inputStream.readFully()
assertThat(data.size).isEqualTo(arraySize - trimSize)
assertThat(data).isEqualTo(initialData.copyOfRange(0, arraySize - trimSize))
assertThat(innerStream.available()).isEqualTo(0)
}
}
@Test
fun `when I fully read the stream and the trimSize is greater than the stream length, I should get zero bytes`() {
val initialData = testData(100)
val inputStream = TrimmingInputStream(initialData.inputStream(), trimSize = 200)
val data = inputStream.readFully()
assertThat(data.size).isEqualTo(0)
}
@Test
fun `when I fully read the stream via a buffer with no trimSize, I should get all bytes`() {
val inputStream = TrimmingInputStream(ByteArray(100).inputStream(), trimSize = 0)
val data = inputStream.readFully()
assertThat(data.size).isEqualTo(100)
}
@Test
fun `when I fully read the stream one byte at a time, I should exclude the last trimSize bytes`() {
val inputStream = TrimmingInputStream(ByteArray(100).inputStream(), trimSize = 25)
var count = 0
var lastRead = inputStream.read()
while (lastRead != -1) {
count++
lastRead = inputStream.read()
}
assertThat(count).isEqualTo(75)
}
@Test
fun `when I fully read the stream one byte at a time with no trimSize, I should get all bytes`() {
val inputStream = TrimmingInputStream(ByteArray(100).inputStream(), trimSize = 0)
var count = 0
var lastRead = inputStream.read()
while (lastRead != -1) {
count++
lastRead = inputStream.read()
}
assertThat(count).isEqualTo(100)
}
@Test
fun `when I skip past the the trimSize, I should get -1`() {
val inputStream = TrimmingInputStream(ByteArray(100).inputStream(), trimSize = 25)
val skipCount = inputStream.skip(100)
val read = inputStream.read()
assertThat(skipCount).isEqualTo(75)
assertThat(read).isEqualTo(-1)
}
@Test
fun `when I skip, I should still truncate correctly afterwards`() {
val inputStream = TrimmingInputStream(ByteArray(100).inputStream(), trimSize = 25)
val skipCount = inputStream.skip(50)
val data = inputStream.readFully()
assertThat(skipCount).isEqualTo(50)
assertThat(data.size).isEqualTo(25)
}
@Test
fun `when I skip more than the remaining bytes, I still respect trimSize`() {
val initialData = testData(100)
val inputStream = TrimmingInputStream(initialData.inputStream(), trimSize = 25)
val skipCount = inputStream.skip(100)
assertThat(skipCount).isEqualTo(75)
}
private fun testData(length: Int): ByteArray {
return ByteArray(length) { (it % 0xFF).toByte() }
}
}

View File

@@ -189,16 +189,9 @@ public class SignalServiceMessageReceiver {
socket.retrieveAttachment(pointer.getCdnNumber(), readCredentialHeaders, pointer.getRemoteId(), archiveDestination, maxSizeBytes, listener);
long originalCipherLength = pointer.getSize()
.filter(s -> s > 0)
.map(s -> AttachmentCipherStreamUtil.getCiphertextLength(PaddingInputStream.getPaddedSize(s)))
.orElse(0L);
return AttachmentCipherInputStream.createForArchivedThumbnail(
archivedMediaKeyMaterial,
archiveDestination,
originalCipherLength,
pointer.getSize().orElse(0),
pointer.getKey()
);
}

View File

@@ -8,6 +8,7 @@ package org.whispersystems.signalservice.api.crypto
import org.signal.core.util.Base64
import org.signal.core.util.readNBytesOrThrow
import org.signal.core.util.stream.LimitedInputStream
import org.signal.core.util.stream.TrimmingInputStream
import org.signal.libsignal.protocol.InvalidMessageException
import org.signal.libsignal.protocol.incrementalmac.ChunkSizeChoice
import org.signal.libsignal.protocol.incrementalmac.IncrementalMacInputStream
@@ -134,38 +135,45 @@ object AttachmentCipherInputStream {
}
/**
* When you archive an attachment, you give the server an encrypted attachment, and the server wraps it in *another* layer of encryption.
* When you archive an attachment thumbnail, you give the server an encrypted attachment, and the server wraps it in *another* layer of encryption.
*
* This creates a stream decrypt both the inner and outer layers of an archived attachment at the same time by basically double-decrypting it.
*
* @param incrementalDigest If null, incremental mac validation is disabled.
* @param incrementalMacChunkSize If 0, incremental mac validation is disabled.
* Archive thumbnails are also special in that we:
* - don't know how long they are (meaning you'll get them back with padding at the end, image viewers are ok with this)
* - don't care about external integrity checks (we still validate the MACs)
*
* So there's some code duplication here just to avoid mucking up the reusable functions with special cases.
*/
@JvmStatic
@Throws(InvalidMessageException::class, IOException::class)
fun createForArchivedThumbnail(
archivedMediaKeyMaterial: MediaKeyMaterial,
file: File,
originalCipherTextLength: Long,
plaintextLength: Long,
combinedKeyMaterial: ByteArray
innerCombinedKeyMaterial: ByteArray
): InputStream {
val keyMaterial = CombinedKeyMaterial.from(combinedKeyMaterial)
val mac = initMac(keyMaterial.macKey)
val outerMac = initMac(archivedMediaKeyMaterial.macKey)
if (originalCipherTextLength <= BLOCK_SIZE + mac.macLength) {
throw InvalidMessageException("Message shorter than crypto overhead! Expected at least ${BLOCK_SIZE + mac.macLength} bytes, got $originalCipherTextLength")
if (file.length() <= BLOCK_SIZE + outerMac.macLength) {
throw InvalidMessageException("Message shorter than crypto overhead! Expected at least ${BLOCK_SIZE + outerMac.macLength} bytes, got ${file.length()}")
}
return create(
streamSupplier = { createForArchivedMediaOuterLayer(archivedMediaKeyMaterial, file, originalCipherTextLength) },
streamLength = originalCipherTextLength,
plaintextLength = plaintextLength,
combinedKeyMaterial = combinedKeyMaterial,
integrityCheck = null,
incrementalDigest = null,
incrementalMacChunkSize = 0
)
FileInputStream(file).use { macVerificationStream ->
verifyMacAndMaybeEncryptedDigest(macVerificationStream, file.length(), outerMac, null)
}
val outerEncryptedStreamExcludingMac = LimitedInputStream(FileInputStream(file), maxBytes = file.length() - outerMac.macLength)
val outerCipher = createCipher(outerEncryptedStreamExcludingMac, archivedMediaKeyMaterial.aesKey)
val innerEncryptedStream = BetterCipherInputStream(outerEncryptedStreamExcludingMac, outerCipher)
val innerKeyMaterial = CombinedKeyMaterial.from(innerCombinedKeyMaterial)
val innerMac = initMac(innerKeyMaterial.macKey)
val innerEncryptedStreamWithMac = MacValidatingInputStream(innerEncryptedStream, innerMac)
val innerEncryptedStreamExcludingMac = TrimmingInputStream(innerEncryptedStreamWithMac, trimSize = innerMac.macLength, drain = true)
val innerCipher = createCipher(innerEncryptedStreamExcludingMac, innerKeyMaterial.aesKey)
return BetterCipherInputStream(innerEncryptedStreamExcludingMac, innerCipher)
}
/**

View File

@@ -0,0 +1,140 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.signalservice.api.crypto
import org.jetbrains.annotations.VisibleForTesting
import org.signal.core.util.stream.LimitedInputStream
import org.signal.core.util.stream.TrimmingInputStream
import org.signal.libsignal.protocol.InvalidMessageException
import java.io.FilterInputStream
import java.io.IOException
import java.io.InputStream
import java.security.MessageDigest
import javax.crypto.Mac
/**
* An InputStream that validates a MAC appended to the end of the stream data.
* This stream will not exclude the MAC from the data it reads, meaning that you may want to pair this with a [LimitedInputStream] or a [TrimmingInputStream]
* if you don't want to read that data to be a part of it.
*
* Important: The MAC is only validated once the stream has been fully read.
*
* @param inputStream The underlying InputStream to read from
* @param mac The Mac instance to use for validation
*/
class MacValidatingInputStream(
inputStream: InputStream,
private val mac: Mac
) : FilterInputStream(inputStream) {
private val macBuffer = ByteArray(mac.macLength)
private val macLength = mac.macLength
private var macBufferPosition = 0
private var streamEnded = false
@VisibleForTesting
var validationAttempted = false
private set
@Throws(IOException::class)
override fun read(): Int {
val singleByteBuffer = ByteArray(1)
val bytesRead = read(singleByteBuffer, 0, 1)
return if (bytesRead == -1) -1 else singleByteBuffer[0].toInt() and 0xFF
}
@Throws(IOException::class)
override fun read(b: ByteArray): Int {
return read(b, 0, b.size)
}
@Throws(IOException::class)
override fun read(outputBuffer: ByteArray, outputOffset: Int, readLength: Int): Int {
if (streamEnded) {
return -1
}
val bytesRead = super.read(outputBuffer, outputOffset, readLength)
if (bytesRead == -1) {
// End of stream - check if we have enough data for MAC validation
if (macBufferPosition < macLength) {
throw InvalidMessageException("Stream ended before MAC could be read. Expected $macLength bytes, got $macBufferPosition")
}
validateMacAndMarkStreamEnded()
return -1
}
// If we've read more than `macLength` bytes, we can just snag the last `macLength` bytes and digest the rest
if (bytesRead >= macLength) {
// Before replacing the macBuffer, process any pre-existing data
if (macBufferPosition > 0) {
mac.update(macBuffer, 0, macBufferPosition)
macBufferPosition = 0
}
// Copy the last `macLength` bytes into the macBuffer
outputBuffer.copyInto(destination = macBuffer, destinationOffset = 0, startIndex = outputOffset + bytesRead - macLength, endIndex = outputOffset + bytesRead)
macBufferPosition = macLength
// Update the mac with the bytes that are not part of the MAC
if (bytesRead > macLength) {
mac.update(outputBuffer, outputOffset, bytesRead - macLength)
}
} else {
val totalBytesAvailable = macBufferPosition + bytesRead
// If the new bytes we've read don't overflow the buffer, we can just append them, and none of them will be digested
if (totalBytesAvailable <= macLength) {
outputBuffer.copyInto(destination = macBuffer, destinationOffset = macBufferPosition, startIndex = outputOffset, endIndex = outputOffset + bytesRead)
macBufferPosition = totalBytesAvailable
} else {
// If we have more bytes than we can hold in the buffer, keep the last `macLength` bytes and digest the rest
// We know that `bytesRead` is less than `macLength`, so we know all of `bytesRead` should go into the buffer
// And we know that the buffer usage + `bytesRead` is greater than `macLength`, so we're guaranteed to be able to digest the first chunk of the buffer.
// We also know that there can't possibly be 0 bytes in the buffer because of how the math of those conditions works out.
val bytesToDigest = totalBytesAvailable - macLength
val bytesOfBufferToDigest = minOf(macBufferPosition, bytesToDigest)
val bytesOfReadToDigest = bytesToDigest - bytesOfBufferToDigest
mac.update(macBuffer, 0, bytesOfBufferToDigest)
macBuffer.copyInto(destination = macBuffer, destinationOffset = 0, startIndex = bytesOfBufferToDigest, endIndex = macBufferPosition)
macBufferPosition -= bytesOfBufferToDigest
if (bytesOfReadToDigest > 0) {
mac.update(outputBuffer, outputOffset, bytesOfReadToDigest)
}
val bytesOfReadRemaining = bytesRead - bytesOfReadToDigest
if (bytesOfReadRemaining > 0) {
outputBuffer.copyInto(destination = macBuffer, destinationOffset = macBufferPosition, startIndex = outputOffset + bytesOfReadToDigest, endIndex = outputOffset + bytesRead)
macBufferPosition += bytesOfReadRemaining
}
}
}
return bytesRead
}
@Throws(InvalidMessageException::class)
private fun validateMacAndMarkStreamEnded() {
if (validationAttempted) {
return
}
validationAttempted = true
streamEnded = true
val calculatedMac = mac.doFinal()
if (!MessageDigest.isEqual(calculatedMac, macBuffer)) {
throw InvalidMessageException("MAC validation failed!")
}
}
private fun minOf(a: Int, b: Int): Int = if (a < b) a else b
}

View File

@@ -94,7 +94,7 @@ class AttachmentCipherTest {
)
}
val inputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, integrityCheck, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice)
val plaintextOutput = inputStream.readFully(autoClose = false)
val plaintextOutput = inputStream.readFully()
assertThat(plaintextOutput).isEqualTo(plaintextInput)
@@ -111,7 +111,7 @@ class AttachmentCipherTest {
val integrityCheck = IntegrityCheck.forPlaintextHash(plaintextHash)
val inputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, integrityCheck, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice)
val plaintextOutput = inputStream.readFully(autoClose = false)
val plaintextOutput = inputStream.readFully()
assertThat(plaintextOutput).isEqualTo(plaintextInput)
@@ -137,7 +137,7 @@ class AttachmentCipherTest {
val integrityCheck = IntegrityCheck.forEncryptedDigest(encryptResult.digest)
val inputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, integrityCheck, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice)
val plaintextOutput = inputStream.readFully(autoClose = false)
val plaintextOutput = inputStream.readFully()
Assert.assertArrayEquals(plaintextInput, plaintextOutput)
@@ -327,7 +327,7 @@ class AttachmentCipherTest {
incrementalDigest = innerEncryptResult.incrementalDigest,
incrementalMacChunkSize = innerEncryptResult.chunkSizeChoice
)
val plaintextOutput = decryptedStream.readFully(autoClose = false)
val plaintextOutput = decryptedStream.readFully()
assertThat(plaintextOutput).isEqualTo(plaintextInput)
@@ -382,37 +382,36 @@ class AttachmentCipherTest {
@Test
fun archive_encryptDecrypt_nonIncremental() {
archiveInnerAndOuter_encryptDecrypt(incremental = false, fileSize = MEBIBYTE)
archive_encryptDecrypt(incremental = false, fileSize = MEBIBYTE)
}
@Test
fun archive_encryptDecrypt_incremental() {
archiveInnerAndOuter_encryptDecrypt(incremental = true, fileSize = MEBIBYTE)
archive_encryptDecrypt(incremental = true, fileSize = MEBIBYTE)
}
@Test
fun archive_encryptDecrypt_nonIncremental_manyFileSizes() {
for (i in 0..99) {
archiveInnerAndOuter_encryptDecrypt(incremental = false, fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024))
archive_encryptDecrypt(incremental = false, fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024))
}
}
@Test
fun archiveInnerAndOuter_encryptDecrypt_incremental_manyFileSizes() {
fun archive_encryptDecrypt_incremental_manyFileSizes() {
// Designed to stress the various boundary conditions of reading the final mac
for (i in 0..99) {
archiveInnerAndOuter_encryptDecrypt(incremental = true, fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024))
archive_encryptDecrypt(incremental = true, fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024))
}
}
private fun archiveInnerAndOuter_encryptDecrypt(incremental: Boolean, fileSize: Int) {
private fun archive_encryptDecrypt(incremental: Boolean, fileSize: Int) {
val innerKey = Util.getSecretBytes(64)
val plaintextInput = Util.getSecretBytes(fileSize)
val innerEncryptResult = encryptData(plaintextInput, innerKey, incremental)
val outerKey = Util.getSecretBytes(64)
val outerEncryptResult = encryptData(innerEncryptResult.ciphertext, outerKey, false)
val outerKey = Util.getSecretBytes(64)
val outerEncryptResult = encryptData(innerEncryptResult.ciphertext, outerKey, withIncremental = false, padded = false) // Server doesn't pad
val cipherFile = writeToFile(outerEncryptResult.ciphertext)
val keyMaterial = createMediaKeyMaterial(outerKey)
@@ -426,7 +425,7 @@ class AttachmentCipherTest {
incrementalDigest = innerEncryptResult.incrementalDigest,
incrementalMacChunkSize = innerEncryptResult.chunkSizeChoice
)
val plaintextOutput = decryptedStream.readFully(autoClose = false)
val plaintextOutput = decryptedStream.readFully()
assertThat(plaintextOutput).isEqualTo(plaintextInput)
@@ -434,26 +433,56 @@ class AttachmentCipherTest {
}
@Test
fun archiveEncryptDecrypt_decryptFailOnBadMac() {
fun archiveThumbnail_encryptDecrypt_manyFileSizes() {
for (i in 0..99) {
archiveThumbnail_encryptDecrypt(fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024))
}
}
private fun archiveThumbnail_encryptDecrypt(fileSize: Int) {
val innerKey = Util.getSecretBytes(64)
val plaintextInput = Util.getSecretBytes(fileSize)
val innerEncryptResult = encryptData(plaintextInput, innerKey, withIncremental = false)
val outerKey = Util.getSecretBytes(64)
val outerEncryptResult = encryptData(innerEncryptResult.ciphertext, outerKey, withIncremental = false, padded = false) // Server doesn't pad
val cipherFile = writeToFile(outerEncryptResult.ciphertext)
val keyMaterial = createMediaKeyMaterial(outerKey)
val decryptedStream = AttachmentCipherInputStream.createForArchivedThumbnail(
archivedMediaKeyMaterial = keyMaterial,
file = cipherFile,
innerCombinedKeyMaterial = innerKey
)
val plaintextOutput = decryptedStream.readFully()
// We knowingly keep padding on thumbnails, so for the test to work, we strip the padding off, but check to make sure the sizes match up beforehand
assertThat(plaintextOutput.size).isEqualTo(PaddingInputStream.getPaddedSize(plaintextInput.size.toLong()).toInt())
assertThat(plaintextOutput.copyOfRange(fromIndex = 0, toIndex = plaintextInput.size)).isEqualTo(plaintextInput)
cipherFile.delete()
}
@Test
fun archiveEncryptDecrypt_decryptFailOnInnerBadMac() {
var cipherFile: File? = null
var hitCorrectException = false
try {
val innerKey = Util.getSecretBytes(64)
val badInnerKey = Util.getSecretBytes(64)
val plaintextInput = Util.getSecretBytes(MEBIBYTE)
val innerEncryptResult = encryptData(plaintextInput, innerKey, withIncremental = true)
val innerEncryptResult = encryptData(plaintextInput, innerKey, withIncremental = true, padded = false) // Server doesn't pad
val badMacInnerCipherText = innerEncryptResult.ciphertext.copyOf().also {
it[it.size - 1] = (it[it.size - 1] + 1).toByte()
}
val outerKey = Util.getSecretBytes(64)
val outerEncryptResult = encryptData(badMacInnerCipherText, outerKey, false)
val outerEncryptResult = encryptData(innerEncryptResult.ciphertext, outerKey, false)
val badMacOuterCiphertext = outerEncryptResult.ciphertext.copyOf(outerEncryptResult.ciphertext.size)
cipherFile = writeToFile(outerEncryptResult.ciphertext)
badMacOuterCiphertext[badMacOuterCiphertext.size - 1] = (badMacOuterCiphertext[badMacOuterCiphertext.size - 1] + 1).toByte()
cipherFile = writeToFile(badMacOuterCiphertext)
val keyMaterial = createMediaKeyMaterial(badInnerKey)
val keyMaterial = createMediaKeyMaterial(innerKey)
AttachmentCipherInputStream.createForArchivedMedia(
archivedMediaKeyMaterial = keyMaterial,
@@ -476,6 +505,122 @@ class AttachmentCipherTest {
Assert.assertTrue(hitCorrectException)
}
@Test
fun archiveEncryptDecrypt_decryptFailOnOuterMac() {
var cipherFile: File? = null
var hitCorrectException = false
try {
val innerKey = Util.getSecretBytes(64)
val plaintextInput = Util.getSecretBytes(MEBIBYTE)
val innerEncryptResult = encryptData(plaintextInput, innerKey, withIncremental = true, padded = false) // Server doesn't pad
val outerKey = Util.getSecretBytes(64)
val outerEncryptResult = encryptData(innerEncryptResult.ciphertext, outerKey, false)
val badMacOuterCiphertext = outerEncryptResult.ciphertext.copyOf().also {
it[it.size - 1] = (it[it.size - 1] + 1).toByte()
}
cipherFile = writeToFile(badMacOuterCiphertext)
val keyMaterial = createMediaKeyMaterial(innerKey)
AttachmentCipherInputStream.createForArchivedMedia(
archivedMediaKeyMaterial = keyMaterial,
file = cipherFile,
originalCipherTextLength = innerEncryptResult.ciphertext.size.toLong(),
plaintextLength = plaintextInput.size.toLong(),
combinedKeyMaterial = innerKey,
plaintextHash = innerEncryptResult.digest,
incrementalDigest = innerEncryptResult.incrementalDigest,
incrementalMacChunkSize = innerEncryptResult.chunkSizeChoice
)
Assert.fail()
} catch (e: InvalidMessageException) {
hitCorrectException = true
} finally {
cipherFile?.delete()
}
Assert.assertTrue(hitCorrectException)
}
@Test
fun archiveThumbnailEncryptDecrypt_decryptFailOnInnerBadMac() {
var cipherFile: File? = null
var hitCorrectException = false
try {
val innerKey = Util.getSecretBytes(64)
val plaintextInput = Util.getSecretBytes(MEBIBYTE)
val innerEncryptResult = encryptData(plaintextInput, innerKey, withIncremental = true, padded = false) // Server doesn't pad
val badMacInnerCipherText = innerEncryptResult.ciphertext.copyOf().also {
it[it.size - 1] = (it[it.size - 1] + 1).toByte()
}
val outerKey = Util.getSecretBytes(64)
val outerEncryptResult = encryptData(badMacInnerCipherText, outerKey, false)
cipherFile = writeToFile(outerEncryptResult.ciphertext)
val keyMaterial = createMediaKeyMaterial(innerKey)
AttachmentCipherInputStream.createForArchivedThumbnail(
archivedMediaKeyMaterial = keyMaterial,
file = cipherFile,
innerCombinedKeyMaterial = innerKey
).readFully()
Assert.fail()
} catch (e: InvalidMessageException) {
hitCorrectException = true
} finally {
cipherFile?.delete()
}
Assert.assertTrue(hitCorrectException)
}
@Test
fun archiveThumbnailEncryptDecrypt_decryptFailOnOuterMac() {
var cipherFile: File? = null
var hitCorrectException = false
try {
val innerKey = Util.getSecretBytes(64)
val plaintextInput = Util.getSecretBytes(MEBIBYTE)
val innerEncryptResult = encryptData(plaintextInput, innerKey, withIncremental = true, padded = false) // Server doesn't pad
val outerKey = Util.getSecretBytes(64)
val outerEncryptResult = encryptData(innerEncryptResult.ciphertext, outerKey, false)
val badMacOuterCiphertext = outerEncryptResult.ciphertext.copyOf().also {
it[it.size - 1] = (it[it.size - 1] + 1).toByte()
}
cipherFile = writeToFile(badMacOuterCiphertext)
val keyMaterial = createMediaKeyMaterial(innerKey)
AttachmentCipherInputStream.createForArchivedThumbnail(
archivedMediaKeyMaterial = keyMaterial,
file = cipherFile,
innerCombinedKeyMaterial = innerKey
)
Assert.fail()
} catch (e: InvalidMessageException) {
hitCorrectException = true
} finally {
cipherFile?.delete()
}
Assert.assertTrue(hitCorrectException)
}
@Test
fun sticker_encryptDecrypt() {
LibSignalLibraryUtil.assumeLibSignalSupportedOnOS()

View File

@@ -0,0 +1,188 @@
package org.whispersystems.signalservice.api.crypto
import assertk.assertThat
import assertk.assertions.isEqualTo
import assertk.assertions.isTrue
import assertk.fail
import org.junit.Test
import org.signal.core.util.kibiBytes
import org.signal.core.util.mebiBytes
import org.signal.core.util.readFully
import org.signal.libsignal.protocol.InvalidMessageException
import org.whispersystems.signalservice.internal.util.Util
import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import javax.crypto.Mac
import javax.crypto.spec.SecretKeySpec
import kotlin.random.Random
class MacValidatingInputStreamTest {
@Test
fun `success - simple byte array read`() {
val data = "Hello, World!".toByteArray()
val key = Util.getSecretBytes(32)
val dataWithMac = createDataWithMac(data, key)
val inputStream = ByteArrayInputStream(dataWithMac)
val mac = createMac(key)
val macValidatingStream = MacValidatingInputStream(inputStream, mac)
val result = macValidatingStream.readFully()
assertThat(result).isEqualTo(dataWithMac)
macValidatingStream.close()
}
@Test
fun `success - byte by byte`() {
val data = "Hello, World!".toByteArray()
val key = Util.getSecretBytes(32)
val dataWithMac = createDataWithMac(data, key)
val inputStream = ByteArrayInputStream(dataWithMac)
val mac = createMac(key)
val macValidatingStream = MacValidatingInputStream(inputStream, mac)
val out = ByteArrayOutputStream()
var read = -1
while (macValidatingStream.read().also { read = it } != -1) {
out.write(read)
}
val result = out.toByteArray()
assertThat(result).isEqualTo(dataWithMac)
macValidatingStream.close()
}
@Test
fun `success - many different sizes`() {
for (i in 1..100) {
val data = Util.getSecretBytes(Random.nextLong(from = 256.kibiBytes.bytes, until = 2.mebiBytes.bytes).toInt())
val key = Util.getSecretBytes(32)
val dataWithMac = createDataWithMac(data, key)
val inputStream = ByteArrayInputStream(dataWithMac)
val mac = createMac(key)
val macValidatingStream = MacValidatingInputStream(inputStream, mac)
val result = macValidatingStream.readFully()
assertThat(result).isEqualTo(dataWithMac)
assertThat(macValidatingStream.validationAttempted).isTrue()
macValidatingStream.close()
}
}
@Test
fun `success - empty data`() {
val data = ByteArray(0)
val key = Util.getSecretBytes(32)
val dataWithMac = createDataWithMac(data, key)
val inputStream = ByteArrayInputStream(dataWithMac)
val mac = createMac(key)
val macValidatingStream = MacValidatingInputStream(inputStream, mac)
val result = macValidatingStream.readFully()
assertThat(result).isEqualTo(dataWithMac)
assertThat(macValidatingStream.validationAttempted).isTrue()
macValidatingStream.close()
}
@Test
fun `success - data exactly MAC length`() {
val key = Util.getSecretBytes(32)
val mac = createMac(key)
val macLength = mac.macLength
val data = ByteArray(macLength) { (it % 256).toByte() } // Data same size as MAC
val dataWithMac = createDataWithMac(data, key)
val inputStream = ByteArrayInputStream(dataWithMac)
val mac2 = createMac(key)
val macValidatingStream = MacValidatingInputStream(inputStream, mac2)
val result = macValidatingStream.readFully()
assertThat(result).isEqualTo(dataWithMac)
assertThat(macValidatingStream.validationAttempted).isTrue()
macValidatingStream.close()
}
@Test
fun `success - multiple reads after end of stream`() {
val data = "Test multiple reads after EOF".toByteArray()
val key = Util.getSecretBytes(32)
val dataWithMac = createDataWithMac(data, key)
val inputStream = ByteArrayInputStream(dataWithMac)
val mac = createMac(key)
val macValidatingStream = MacValidatingInputStream(inputStream, mac)
val result = macValidatingStream.readFully()
// Multiple calls to read() after EOF should return -1
assertThat(macValidatingStream.read()).isEqualTo(-1)
assertThat(macValidatingStream.read()).isEqualTo(-1)
assertThat(macValidatingStream.read()).isEqualTo(-1)
assertThat(result).isEqualTo(dataWithMac)
assertThat(macValidatingStream.validationAttempted).isTrue()
macValidatingStream.close()
}
@Test
fun `failure - invalid MAC`() {
val data = "Hello, World!".toByteArray()
val key = Util.getSecretBytes(32)
val wrongKey = ByteArray(32) { 24 }
val dataWithMac = createDataWithMac(data, key)
val inputStream = ByteArrayInputStream(dataWithMac)
val mac = createMac(wrongKey) // Wrong key
val macValidatingStream = MacValidatingInputStream(inputStream, mac)
try {
macValidatingStream.readFully()
fail("Expected InvalidMessageException to be thrown")
} catch (e: InvalidMessageException) {
assertThat(e.message).isEqualTo("MAC validation failed!")
} finally {
macValidatingStream.close()
}
}
@Test
fun `failure - insufficient data for MAC`() {
val key = Util.getSecretBytes(32)
val mac = createMac(key)
val macLength = mac.macLength
val insufficientData = ByteArray(macLength - 1) { 5 } // Less than MAC length
val inputStream = ByteArrayInputStream(insufficientData)
val mac2 = createMac(key)
val macValidatingStream = MacValidatingInputStream(inputStream, mac2)
try {
macValidatingStream.readFully()
fail("Expected InvalidMessageException to be thrown")
} catch (e: InvalidMessageException) {
assertThat(e.message).isEqualTo("Stream ended before MAC could be read. Expected $macLength bytes, got ${insufficientData.size}")
} finally {
macValidatingStream.close()
}
}
private fun createMac(key: ByteArray): Mac {
val mac = Mac.getInstance("HmacSHA256")
mac.init(SecretKeySpec(key, "HmacSHA256"))
return mac
}
private fun createDataWithMac(data: ByteArray, key: ByteArray): ByteArray {
val mac = createMac(key)
val macBytes = mac.doFinal(data)
return data + macBytes
}
}