mirror of
https://github.com/signalapp/Signal-Android.git
synced 2025-12-24 13:08:46 +00:00
Fix archive thumbnail decryption.
This commit is contained in:
committed by
Cody Henthorne
parent
b1063f69f9
commit
c0340be3ce
@@ -131,7 +131,7 @@ private fun BackupsSettingsContent(
|
||||
item {
|
||||
Column(modifier = Modifier.padding(horizontal = dimensionResource(id = org.signal.core.ui.R.dimen.gutter))) {
|
||||
Text(
|
||||
text = "INTERNAL ONLY",
|
||||
text = "ALPHA ONLY",
|
||||
style = MaterialTheme.typography.titleMedium
|
||||
)
|
||||
Text(
|
||||
|
||||
@@ -0,0 +1,148 @@
|
||||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.signal.core.util.stream
|
||||
|
||||
import org.signal.core.util.drain
|
||||
import java.io.FilterInputStream
|
||||
import java.io.IOException
|
||||
import java.io.InputStream
|
||||
import kotlin.math.min
|
||||
|
||||
/**
|
||||
* An input stream that will read all but the last [trimSize] bytes of the stream.
|
||||
*
|
||||
* Important: we have to keep a buffer of size [trimSize] to ensure that we can avoid reading it.
|
||||
* That means you should avoid using this for very large values of [trimSize].
|
||||
*
|
||||
* @param drain If true, the stream will be drained when it reaches the end (but bytes won't be returned). This is useful for ensuring that the underlying
|
||||
* stream is fully consumed.
|
||||
*/
|
||||
class TrimmingInputStream(
|
||||
private val inputStream: InputStream,
|
||||
private val trimSize: Int,
|
||||
private val drain: Boolean = false
|
||||
) : FilterInputStream(inputStream) {
|
||||
|
||||
private val trimBuffer = ByteArray(trimSize)
|
||||
private var trimBufferSize: Int = 0
|
||||
private var streamEnded = false
|
||||
private var hasDrained = false
|
||||
|
||||
private var internalBuffer = ByteArray(4096)
|
||||
private var internalBufferPosition: Int = 0
|
||||
private var internalBufferSize: Int = 0
|
||||
|
||||
@Throws(IOException::class)
|
||||
override fun read(): Int {
|
||||
val singleByteBuffer = ByteArray(1)
|
||||
val bytesRead = read(singleByteBuffer, 0, 1)
|
||||
return if (bytesRead == -1) {
|
||||
-1
|
||||
} else {
|
||||
singleByteBuffer[0].toInt() and 0xFF
|
||||
}
|
||||
}
|
||||
|
||||
@Throws(IOException::class)
|
||||
override fun read(b: ByteArray): Int {
|
||||
return read(b, 0, b.size)
|
||||
}
|
||||
|
||||
/**
|
||||
* The general strategy is that we do bulk reads into an internal buffer (just for perf reasons), and then when new bytes are requested,
|
||||
* we fill up a buffer of size [trimSize] with the most recent bytes, and then return the oldest byte from that buffer.
|
||||
*
|
||||
* This ensures that the last [trimSize] bytes are never returned, while still returning the rest of the bytes.
|
||||
*
|
||||
* When we hit the end of the stream, we stop returning bytes.
|
||||
*/
|
||||
@Throws(IOException::class)
|
||||
override fun read(outputBuffer: ByteArray, outputOffset: Int, readLength: Int): Int {
|
||||
if (streamEnded) {
|
||||
return -1
|
||||
}
|
||||
|
||||
if (trimSize == 0) {
|
||||
return super.read(outputBuffer, outputOffset, readLength)
|
||||
}
|
||||
|
||||
var outputCount = 0
|
||||
|
||||
while (outputCount < readLength) {
|
||||
val nextByte = readNextByte()
|
||||
|
||||
if (nextByte == -1) {
|
||||
streamEnded = true
|
||||
drainIfNecessary()
|
||||
break
|
||||
}
|
||||
|
||||
if (trimBufferSize < trimSize) {
|
||||
// Still filling the buffer - can't output anything yet
|
||||
trimBuffer[trimBufferSize] = nextByte.toByte()
|
||||
trimBufferSize++
|
||||
} else {
|
||||
// Buffer is full - output the oldest byte and add the new one
|
||||
outputBuffer[outputOffset + outputCount] = trimBuffer[0]
|
||||
outputCount++
|
||||
|
||||
// Shift buffer left and add new byte at the end. In practice, this is a tiny array and copies should be fast.
|
||||
System.arraycopy(trimBuffer, 1, trimBuffer, 0, trimSize - 1)
|
||||
trimBuffer[trimSize - 1] = nextByte.toByte()
|
||||
}
|
||||
}
|
||||
|
||||
return if (outputCount == 0) {
|
||||
drainIfNecessary()
|
||||
-1
|
||||
} else {
|
||||
outputCount
|
||||
}
|
||||
}
|
||||
|
||||
@Throws(IOException::class)
|
||||
override fun skip(skipCount: Long): Long {
|
||||
if (skipCount <= 0) return 0
|
||||
|
||||
var totalSkipped = 0L
|
||||
val buffer = ByteArray(8192)
|
||||
|
||||
while (totalSkipped < skipCount) {
|
||||
val toRead = min((skipCount - totalSkipped).toInt(), buffer.size)
|
||||
val bytesRead = read(buffer, 0, toRead)
|
||||
if (bytesRead == -1) {
|
||||
break
|
||||
}
|
||||
totalSkipped += bytesRead
|
||||
}
|
||||
|
||||
return totalSkipped
|
||||
}
|
||||
|
||||
private fun readNextByte(): Int {
|
||||
val hitEndOfStream = if (internalBufferPosition >= internalBufferSize) {
|
||||
internalBufferPosition = 0
|
||||
internalBufferSize = super.read(internalBuffer, 0, internalBuffer.size)
|
||||
internalBufferSize <= 0
|
||||
} else {
|
||||
false
|
||||
}
|
||||
|
||||
if (hitEndOfStream) {
|
||||
drainIfNecessary()
|
||||
return -1
|
||||
}
|
||||
|
||||
return internalBuffer[internalBufferPosition++].toInt() and 0xFF
|
||||
}
|
||||
|
||||
private fun drainIfNecessary() {
|
||||
if (drain && !hasDrained) {
|
||||
inputStream.drain()
|
||||
hasDrained = true
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,140 @@
|
||||
/*
|
||||
* Copyright 2023 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.signal.core.util.stream
|
||||
|
||||
import assertk.assertThat
|
||||
import assertk.assertions.isEqualTo
|
||||
import org.junit.Test
|
||||
import org.signal.core.util.readFully
|
||||
import kotlin.math.min
|
||||
import kotlin.random.Random
|
||||
|
||||
class TrimmingInputStreamTest {
|
||||
|
||||
@Test
|
||||
fun `when I fully read the stream via a buffer, I should exclude the last trimSize bytes`() {
|
||||
val initialData = testData(100)
|
||||
val inputStream = TrimmingInputStream(initialData.inputStream(), trimSize = 25)
|
||||
val data = inputStream.readFully()
|
||||
|
||||
assertThat(data.size).isEqualTo(75)
|
||||
assertThat(data).isEqualTo(initialData.copyOfRange(0, 75))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `when I fully read the stream via a buffer, I should exclude the last trimSize bytes - many sizes`() {
|
||||
for (i in 1..100) {
|
||||
val arraySize = Random.nextInt(1024, 2 * 1024 * 1024)
|
||||
val trimSize = min(arraySize, Random.nextInt(1024))
|
||||
|
||||
val initialData = testData(arraySize)
|
||||
val innerStream = initialData.inputStream()
|
||||
val inputStream = TrimmingInputStream(innerStream, trimSize = trimSize)
|
||||
val data = inputStream.readFully()
|
||||
|
||||
assertThat(data.size).isEqualTo(arraySize - trimSize)
|
||||
assertThat(data).isEqualTo(initialData.copyOfRange(0, arraySize - trimSize))
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `when I fully read the stream via a buffer with drain set, I should exclude the last trimSize bytes but still drain the remaining stream - many sizes`() {
|
||||
for (i in 1..100) {
|
||||
val arraySize = Random.nextInt(1024, 2 * 1024 * 1024)
|
||||
val trimSize = min(arraySize, Random.nextInt(1024))
|
||||
|
||||
val initialData = testData(arraySize)
|
||||
val innerStream = initialData.inputStream()
|
||||
val inputStream = TrimmingInputStream(innerStream, trimSize = trimSize, drain = true)
|
||||
val data = inputStream.readFully()
|
||||
|
||||
assertThat(data.size).isEqualTo(arraySize - trimSize)
|
||||
assertThat(data).isEqualTo(initialData.copyOfRange(0, arraySize - trimSize))
|
||||
assertThat(innerStream.available()).isEqualTo(0)
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `when I fully read the stream and the trimSize is greater than the stream length, I should get zero bytes`() {
|
||||
val initialData = testData(100)
|
||||
val inputStream = TrimmingInputStream(initialData.inputStream(), trimSize = 200)
|
||||
val data = inputStream.readFully()
|
||||
|
||||
assertThat(data.size).isEqualTo(0)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `when I fully read the stream via a buffer with no trimSize, I should get all bytes`() {
|
||||
val inputStream = TrimmingInputStream(ByteArray(100).inputStream(), trimSize = 0)
|
||||
val data = inputStream.readFully()
|
||||
|
||||
assertThat(data.size).isEqualTo(100)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `when I fully read the stream one byte at a time, I should exclude the last trimSize bytes`() {
|
||||
val inputStream = TrimmingInputStream(ByteArray(100).inputStream(), trimSize = 25)
|
||||
|
||||
var count = 0
|
||||
var lastRead = inputStream.read()
|
||||
while (lastRead != -1) {
|
||||
count++
|
||||
lastRead = inputStream.read()
|
||||
}
|
||||
|
||||
assertThat(count).isEqualTo(75)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `when I fully read the stream one byte at a time with no trimSize, I should get all bytes`() {
|
||||
val inputStream = TrimmingInputStream(ByteArray(100).inputStream(), trimSize = 0)
|
||||
|
||||
var count = 0
|
||||
var lastRead = inputStream.read()
|
||||
while (lastRead != -1) {
|
||||
count++
|
||||
lastRead = inputStream.read()
|
||||
}
|
||||
|
||||
assertThat(count).isEqualTo(100)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `when I skip past the the trimSize, I should get -1`() {
|
||||
val inputStream = TrimmingInputStream(ByteArray(100).inputStream(), trimSize = 25)
|
||||
|
||||
val skipCount = inputStream.skip(100)
|
||||
val read = inputStream.read()
|
||||
|
||||
assertThat(skipCount).isEqualTo(75)
|
||||
assertThat(read).isEqualTo(-1)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `when I skip, I should still truncate correctly afterwards`() {
|
||||
val inputStream = TrimmingInputStream(ByteArray(100).inputStream(), trimSize = 25)
|
||||
|
||||
val skipCount = inputStream.skip(50)
|
||||
val data = inputStream.readFully()
|
||||
|
||||
assertThat(skipCount).isEqualTo(50)
|
||||
assertThat(data.size).isEqualTo(25)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `when I skip more than the remaining bytes, I still respect trimSize`() {
|
||||
val initialData = testData(100)
|
||||
val inputStream = TrimmingInputStream(initialData.inputStream(), trimSize = 25)
|
||||
|
||||
val skipCount = inputStream.skip(100)
|
||||
|
||||
assertThat(skipCount).isEqualTo(75)
|
||||
}
|
||||
|
||||
private fun testData(length: Int): ByteArray {
|
||||
return ByteArray(length) { (it % 0xFF).toByte() }
|
||||
}
|
||||
}
|
||||
@@ -189,16 +189,9 @@ public class SignalServiceMessageReceiver {
|
||||
|
||||
socket.retrieveAttachment(pointer.getCdnNumber(), readCredentialHeaders, pointer.getRemoteId(), archiveDestination, maxSizeBytes, listener);
|
||||
|
||||
long originalCipherLength = pointer.getSize()
|
||||
.filter(s -> s > 0)
|
||||
.map(s -> AttachmentCipherStreamUtil.getCiphertextLength(PaddingInputStream.getPaddedSize(s)))
|
||||
.orElse(0L);
|
||||
|
||||
return AttachmentCipherInputStream.createForArchivedThumbnail(
|
||||
archivedMediaKeyMaterial,
|
||||
archiveDestination,
|
||||
originalCipherLength,
|
||||
pointer.getSize().orElse(0),
|
||||
pointer.getKey()
|
||||
);
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ package org.whispersystems.signalservice.api.crypto
|
||||
import org.signal.core.util.Base64
|
||||
import org.signal.core.util.readNBytesOrThrow
|
||||
import org.signal.core.util.stream.LimitedInputStream
|
||||
import org.signal.core.util.stream.TrimmingInputStream
|
||||
import org.signal.libsignal.protocol.InvalidMessageException
|
||||
import org.signal.libsignal.protocol.incrementalmac.ChunkSizeChoice
|
||||
import org.signal.libsignal.protocol.incrementalmac.IncrementalMacInputStream
|
||||
@@ -134,38 +135,45 @@ object AttachmentCipherInputStream {
|
||||
}
|
||||
|
||||
/**
|
||||
* When you archive an attachment, you give the server an encrypted attachment, and the server wraps it in *another* layer of encryption.
|
||||
* When you archive an attachment thumbnail, you give the server an encrypted attachment, and the server wraps it in *another* layer of encryption.
|
||||
*
|
||||
* This creates a stream decrypt both the inner and outer layers of an archived attachment at the same time by basically double-decrypting it.
|
||||
*
|
||||
* @param incrementalDigest If null, incremental mac validation is disabled.
|
||||
* @param incrementalMacChunkSize If 0, incremental mac validation is disabled.
|
||||
* Archive thumbnails are also special in that we:
|
||||
* - don't know how long they are (meaning you'll get them back with padding at the end, image viewers are ok with this)
|
||||
* - don't care about external integrity checks (we still validate the MACs)
|
||||
*
|
||||
* So there's some code duplication here just to avoid mucking up the reusable functions with special cases.
|
||||
*/
|
||||
@JvmStatic
|
||||
@Throws(InvalidMessageException::class, IOException::class)
|
||||
fun createForArchivedThumbnail(
|
||||
archivedMediaKeyMaterial: MediaKeyMaterial,
|
||||
file: File,
|
||||
originalCipherTextLength: Long,
|
||||
plaintextLength: Long,
|
||||
combinedKeyMaterial: ByteArray
|
||||
innerCombinedKeyMaterial: ByteArray
|
||||
): InputStream {
|
||||
val keyMaterial = CombinedKeyMaterial.from(combinedKeyMaterial)
|
||||
val mac = initMac(keyMaterial.macKey)
|
||||
val outerMac = initMac(archivedMediaKeyMaterial.macKey)
|
||||
|
||||
if (originalCipherTextLength <= BLOCK_SIZE + mac.macLength) {
|
||||
throw InvalidMessageException("Message shorter than crypto overhead! Expected at least ${BLOCK_SIZE + mac.macLength} bytes, got $originalCipherTextLength")
|
||||
if (file.length() <= BLOCK_SIZE + outerMac.macLength) {
|
||||
throw InvalidMessageException("Message shorter than crypto overhead! Expected at least ${BLOCK_SIZE + outerMac.macLength} bytes, got ${file.length()}")
|
||||
}
|
||||
|
||||
return create(
|
||||
streamSupplier = { createForArchivedMediaOuterLayer(archivedMediaKeyMaterial, file, originalCipherTextLength) },
|
||||
streamLength = originalCipherTextLength,
|
||||
plaintextLength = plaintextLength,
|
||||
combinedKeyMaterial = combinedKeyMaterial,
|
||||
integrityCheck = null,
|
||||
incrementalDigest = null,
|
||||
incrementalMacChunkSize = 0
|
||||
)
|
||||
FileInputStream(file).use { macVerificationStream ->
|
||||
verifyMacAndMaybeEncryptedDigest(macVerificationStream, file.length(), outerMac, null)
|
||||
}
|
||||
|
||||
val outerEncryptedStreamExcludingMac = LimitedInputStream(FileInputStream(file), maxBytes = file.length() - outerMac.macLength)
|
||||
val outerCipher = createCipher(outerEncryptedStreamExcludingMac, archivedMediaKeyMaterial.aesKey)
|
||||
val innerEncryptedStream = BetterCipherInputStream(outerEncryptedStreamExcludingMac, outerCipher)
|
||||
|
||||
val innerKeyMaterial = CombinedKeyMaterial.from(innerCombinedKeyMaterial)
|
||||
val innerMac = initMac(innerKeyMaterial.macKey)
|
||||
|
||||
val innerEncryptedStreamWithMac = MacValidatingInputStream(innerEncryptedStream, innerMac)
|
||||
val innerEncryptedStreamExcludingMac = TrimmingInputStream(innerEncryptedStreamWithMac, trimSize = innerMac.macLength, drain = true)
|
||||
val innerCipher = createCipher(innerEncryptedStreamExcludingMac, innerKeyMaterial.aesKey)
|
||||
|
||||
return BetterCipherInputStream(innerEncryptedStreamExcludingMac, innerCipher)
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.signalservice.api.crypto
|
||||
|
||||
import org.jetbrains.annotations.VisibleForTesting
|
||||
import org.signal.core.util.stream.LimitedInputStream
|
||||
import org.signal.core.util.stream.TrimmingInputStream
|
||||
import org.signal.libsignal.protocol.InvalidMessageException
|
||||
import java.io.FilterInputStream
|
||||
import java.io.IOException
|
||||
import java.io.InputStream
|
||||
import java.security.MessageDigest
|
||||
import javax.crypto.Mac
|
||||
|
||||
/**
|
||||
* An InputStream that validates a MAC appended to the end of the stream data.
|
||||
* This stream will not exclude the MAC from the data it reads, meaning that you may want to pair this with a [LimitedInputStream] or a [TrimmingInputStream]
|
||||
* if you don't want to read that data to be a part of it.
|
||||
*
|
||||
* Important: The MAC is only validated once the stream has been fully read.
|
||||
*
|
||||
* @param inputStream The underlying InputStream to read from
|
||||
* @param mac The Mac instance to use for validation
|
||||
*/
|
||||
class MacValidatingInputStream(
|
||||
inputStream: InputStream,
|
||||
private val mac: Mac
|
||||
) : FilterInputStream(inputStream) {
|
||||
|
||||
private val macBuffer = ByteArray(mac.macLength)
|
||||
private val macLength = mac.macLength
|
||||
private var macBufferPosition = 0
|
||||
private var streamEnded = false
|
||||
|
||||
@VisibleForTesting
|
||||
var validationAttempted = false
|
||||
private set
|
||||
|
||||
@Throws(IOException::class)
|
||||
override fun read(): Int {
|
||||
val singleByteBuffer = ByteArray(1)
|
||||
val bytesRead = read(singleByteBuffer, 0, 1)
|
||||
return if (bytesRead == -1) -1 else singleByteBuffer[0].toInt() and 0xFF
|
||||
}
|
||||
|
||||
@Throws(IOException::class)
|
||||
override fun read(b: ByteArray): Int {
|
||||
return read(b, 0, b.size)
|
||||
}
|
||||
|
||||
@Throws(IOException::class)
|
||||
override fun read(outputBuffer: ByteArray, outputOffset: Int, readLength: Int): Int {
|
||||
if (streamEnded) {
|
||||
return -1
|
||||
}
|
||||
|
||||
val bytesRead = super.read(outputBuffer, outputOffset, readLength)
|
||||
|
||||
if (bytesRead == -1) {
|
||||
// End of stream - check if we have enough data for MAC validation
|
||||
if (macBufferPosition < macLength) {
|
||||
throw InvalidMessageException("Stream ended before MAC could be read. Expected $macLength bytes, got $macBufferPosition")
|
||||
}
|
||||
validateMacAndMarkStreamEnded()
|
||||
return -1
|
||||
}
|
||||
|
||||
// If we've read more than `macLength` bytes, we can just snag the last `macLength` bytes and digest the rest
|
||||
if (bytesRead >= macLength) {
|
||||
// Before replacing the macBuffer, process any pre-existing data
|
||||
if (macBufferPosition > 0) {
|
||||
mac.update(macBuffer, 0, macBufferPosition)
|
||||
macBufferPosition = 0
|
||||
}
|
||||
|
||||
// Copy the last `macLength` bytes into the macBuffer
|
||||
outputBuffer.copyInto(destination = macBuffer, destinationOffset = 0, startIndex = outputOffset + bytesRead - macLength, endIndex = outputOffset + bytesRead)
|
||||
macBufferPosition = macLength
|
||||
|
||||
// Update the mac with the bytes that are not part of the MAC
|
||||
if (bytesRead > macLength) {
|
||||
mac.update(outputBuffer, outputOffset, bytesRead - macLength)
|
||||
}
|
||||
} else {
|
||||
val totalBytesAvailable = macBufferPosition + bytesRead
|
||||
|
||||
// If the new bytes we've read don't overflow the buffer, we can just append them, and none of them will be digested
|
||||
if (totalBytesAvailable <= macLength) {
|
||||
outputBuffer.copyInto(destination = macBuffer, destinationOffset = macBufferPosition, startIndex = outputOffset, endIndex = outputOffset + bytesRead)
|
||||
macBufferPosition = totalBytesAvailable
|
||||
} else {
|
||||
// If we have more bytes than we can hold in the buffer, keep the last `macLength` bytes and digest the rest
|
||||
|
||||
// We know that `bytesRead` is less than `macLength`, so we know all of `bytesRead` should go into the buffer
|
||||
// And we know that the buffer usage + `bytesRead` is greater than `macLength`, so we're guaranteed to be able to digest the first chunk of the buffer.
|
||||
// We also know that there can't possibly be 0 bytes in the buffer because of how the math of those conditions works out.
|
||||
|
||||
val bytesToDigest = totalBytesAvailable - macLength
|
||||
|
||||
val bytesOfBufferToDigest = minOf(macBufferPosition, bytesToDigest)
|
||||
val bytesOfReadToDigest = bytesToDigest - bytesOfBufferToDigest
|
||||
|
||||
mac.update(macBuffer, 0, bytesOfBufferToDigest)
|
||||
macBuffer.copyInto(destination = macBuffer, destinationOffset = 0, startIndex = bytesOfBufferToDigest, endIndex = macBufferPosition)
|
||||
macBufferPosition -= bytesOfBufferToDigest
|
||||
|
||||
if (bytesOfReadToDigest > 0) {
|
||||
mac.update(outputBuffer, outputOffset, bytesOfReadToDigest)
|
||||
}
|
||||
|
||||
val bytesOfReadRemaining = bytesRead - bytesOfReadToDigest
|
||||
if (bytesOfReadRemaining > 0) {
|
||||
outputBuffer.copyInto(destination = macBuffer, destinationOffset = macBufferPosition, startIndex = outputOffset + bytesOfReadToDigest, endIndex = outputOffset + bytesRead)
|
||||
macBufferPosition += bytesOfReadRemaining
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return bytesRead
|
||||
}
|
||||
|
||||
@Throws(InvalidMessageException::class)
|
||||
private fun validateMacAndMarkStreamEnded() {
|
||||
if (validationAttempted) {
|
||||
return
|
||||
}
|
||||
validationAttempted = true
|
||||
streamEnded = true
|
||||
|
||||
val calculatedMac = mac.doFinal()
|
||||
if (!MessageDigest.isEqual(calculatedMac, macBuffer)) {
|
||||
throw InvalidMessageException("MAC validation failed!")
|
||||
}
|
||||
}
|
||||
|
||||
private fun minOf(a: Int, b: Int): Int = if (a < b) a else b
|
||||
}
|
||||
@@ -94,7 +94,7 @@ class AttachmentCipherTest {
|
||||
)
|
||||
}
|
||||
val inputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, integrityCheck, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice)
|
||||
val plaintextOutput = inputStream.readFully(autoClose = false)
|
||||
val plaintextOutput = inputStream.readFully()
|
||||
|
||||
assertThat(plaintextOutput).isEqualTo(plaintextInput)
|
||||
|
||||
@@ -111,7 +111,7 @@ class AttachmentCipherTest {
|
||||
|
||||
val integrityCheck = IntegrityCheck.forPlaintextHash(plaintextHash)
|
||||
val inputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, integrityCheck, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice)
|
||||
val plaintextOutput = inputStream.readFully(autoClose = false)
|
||||
val plaintextOutput = inputStream.readFully()
|
||||
|
||||
assertThat(plaintextOutput).isEqualTo(plaintextInput)
|
||||
|
||||
@@ -137,7 +137,7 @@ class AttachmentCipherTest {
|
||||
|
||||
val integrityCheck = IntegrityCheck.forEncryptedDigest(encryptResult.digest)
|
||||
val inputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, integrityCheck, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice)
|
||||
val plaintextOutput = inputStream.readFully(autoClose = false)
|
||||
val plaintextOutput = inputStream.readFully()
|
||||
|
||||
Assert.assertArrayEquals(plaintextInput, plaintextOutput)
|
||||
|
||||
@@ -327,7 +327,7 @@ class AttachmentCipherTest {
|
||||
incrementalDigest = innerEncryptResult.incrementalDigest,
|
||||
incrementalMacChunkSize = innerEncryptResult.chunkSizeChoice
|
||||
)
|
||||
val plaintextOutput = decryptedStream.readFully(autoClose = false)
|
||||
val plaintextOutput = decryptedStream.readFully()
|
||||
|
||||
assertThat(plaintextOutput).isEqualTo(plaintextInput)
|
||||
|
||||
@@ -382,37 +382,36 @@ class AttachmentCipherTest {
|
||||
|
||||
@Test
|
||||
fun archive_encryptDecrypt_nonIncremental() {
|
||||
archiveInnerAndOuter_encryptDecrypt(incremental = false, fileSize = MEBIBYTE)
|
||||
archive_encryptDecrypt(incremental = false, fileSize = MEBIBYTE)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun archive_encryptDecrypt_incremental() {
|
||||
archiveInnerAndOuter_encryptDecrypt(incremental = true, fileSize = MEBIBYTE)
|
||||
archive_encryptDecrypt(incremental = true, fileSize = MEBIBYTE)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun archive_encryptDecrypt_nonIncremental_manyFileSizes() {
|
||||
for (i in 0..99) {
|
||||
archiveInnerAndOuter_encryptDecrypt(incremental = false, fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024))
|
||||
archive_encryptDecrypt(incremental = false, fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024))
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun archiveInnerAndOuter_encryptDecrypt_incremental_manyFileSizes() {
|
||||
fun archive_encryptDecrypt_incremental_manyFileSizes() {
|
||||
// Designed to stress the various boundary conditions of reading the final mac
|
||||
for (i in 0..99) {
|
||||
archiveInnerAndOuter_encryptDecrypt(incremental = true, fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024))
|
||||
archive_encryptDecrypt(incremental = true, fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024))
|
||||
}
|
||||
}
|
||||
|
||||
private fun archiveInnerAndOuter_encryptDecrypt(incremental: Boolean, fileSize: Int) {
|
||||
private fun archive_encryptDecrypt(incremental: Boolean, fileSize: Int) {
|
||||
val innerKey = Util.getSecretBytes(64)
|
||||
val plaintextInput = Util.getSecretBytes(fileSize)
|
||||
|
||||
val innerEncryptResult = encryptData(plaintextInput, innerKey, incremental)
|
||||
val outerKey = Util.getSecretBytes(64)
|
||||
|
||||
val outerEncryptResult = encryptData(innerEncryptResult.ciphertext, outerKey, false)
|
||||
val outerKey = Util.getSecretBytes(64)
|
||||
val outerEncryptResult = encryptData(innerEncryptResult.ciphertext, outerKey, withIncremental = false, padded = false) // Server doesn't pad
|
||||
val cipherFile = writeToFile(outerEncryptResult.ciphertext)
|
||||
|
||||
val keyMaterial = createMediaKeyMaterial(outerKey)
|
||||
@@ -426,7 +425,7 @@ class AttachmentCipherTest {
|
||||
incrementalDigest = innerEncryptResult.incrementalDigest,
|
||||
incrementalMacChunkSize = innerEncryptResult.chunkSizeChoice
|
||||
)
|
||||
val plaintextOutput = decryptedStream.readFully(autoClose = false)
|
||||
val plaintextOutput = decryptedStream.readFully()
|
||||
|
||||
assertThat(plaintextOutput).isEqualTo(plaintextInput)
|
||||
|
||||
@@ -434,26 +433,56 @@ class AttachmentCipherTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
fun archiveEncryptDecrypt_decryptFailOnBadMac() {
|
||||
fun archiveThumbnail_encryptDecrypt_manyFileSizes() {
|
||||
for (i in 0..99) {
|
||||
archiveThumbnail_encryptDecrypt(fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024))
|
||||
}
|
||||
}
|
||||
|
||||
private fun archiveThumbnail_encryptDecrypt(fileSize: Int) {
|
||||
val innerKey = Util.getSecretBytes(64)
|
||||
val plaintextInput = Util.getSecretBytes(fileSize)
|
||||
val innerEncryptResult = encryptData(plaintextInput, innerKey, withIncremental = false)
|
||||
|
||||
val outerKey = Util.getSecretBytes(64)
|
||||
val outerEncryptResult = encryptData(innerEncryptResult.ciphertext, outerKey, withIncremental = false, padded = false) // Server doesn't pad
|
||||
val cipherFile = writeToFile(outerEncryptResult.ciphertext)
|
||||
|
||||
val keyMaterial = createMediaKeyMaterial(outerKey)
|
||||
val decryptedStream = AttachmentCipherInputStream.createForArchivedThumbnail(
|
||||
archivedMediaKeyMaterial = keyMaterial,
|
||||
file = cipherFile,
|
||||
innerCombinedKeyMaterial = innerKey
|
||||
)
|
||||
val plaintextOutput = decryptedStream.readFully()
|
||||
|
||||
// We knowingly keep padding on thumbnails, so for the test to work, we strip the padding off, but check to make sure the sizes match up beforehand
|
||||
assertThat(plaintextOutput.size).isEqualTo(PaddingInputStream.getPaddedSize(plaintextInput.size.toLong()).toInt())
|
||||
assertThat(plaintextOutput.copyOfRange(fromIndex = 0, toIndex = plaintextInput.size)).isEqualTo(plaintextInput)
|
||||
|
||||
cipherFile.delete()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun archiveEncryptDecrypt_decryptFailOnInnerBadMac() {
|
||||
var cipherFile: File? = null
|
||||
var hitCorrectException = false
|
||||
|
||||
try {
|
||||
val innerKey = Util.getSecretBytes(64)
|
||||
val badInnerKey = Util.getSecretBytes(64)
|
||||
val plaintextInput = Util.getSecretBytes(MEBIBYTE)
|
||||
|
||||
val innerEncryptResult = encryptData(plaintextInput, innerKey, withIncremental = true)
|
||||
val innerEncryptResult = encryptData(plaintextInput, innerKey, withIncremental = true, padded = false) // Server doesn't pad
|
||||
val badMacInnerCipherText = innerEncryptResult.ciphertext.copyOf().also {
|
||||
it[it.size - 1] = (it[it.size - 1] + 1).toByte()
|
||||
}
|
||||
|
||||
val outerKey = Util.getSecretBytes(64)
|
||||
val outerEncryptResult = encryptData(badMacInnerCipherText, outerKey, false)
|
||||
|
||||
val outerEncryptResult = encryptData(innerEncryptResult.ciphertext, outerKey, false)
|
||||
val badMacOuterCiphertext = outerEncryptResult.ciphertext.copyOf(outerEncryptResult.ciphertext.size)
|
||||
cipherFile = writeToFile(outerEncryptResult.ciphertext)
|
||||
|
||||
badMacOuterCiphertext[badMacOuterCiphertext.size - 1] = (badMacOuterCiphertext[badMacOuterCiphertext.size - 1] + 1).toByte()
|
||||
|
||||
cipherFile = writeToFile(badMacOuterCiphertext)
|
||||
|
||||
val keyMaterial = createMediaKeyMaterial(badInnerKey)
|
||||
val keyMaterial = createMediaKeyMaterial(innerKey)
|
||||
|
||||
AttachmentCipherInputStream.createForArchivedMedia(
|
||||
archivedMediaKeyMaterial = keyMaterial,
|
||||
@@ -476,6 +505,122 @@ class AttachmentCipherTest {
|
||||
Assert.assertTrue(hitCorrectException)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun archiveEncryptDecrypt_decryptFailOnOuterMac() {
|
||||
var cipherFile: File? = null
|
||||
var hitCorrectException = false
|
||||
|
||||
try {
|
||||
val innerKey = Util.getSecretBytes(64)
|
||||
val plaintextInput = Util.getSecretBytes(MEBIBYTE)
|
||||
|
||||
val innerEncryptResult = encryptData(plaintextInput, innerKey, withIncremental = true, padded = false) // Server doesn't pad
|
||||
val outerKey = Util.getSecretBytes(64)
|
||||
|
||||
val outerEncryptResult = encryptData(innerEncryptResult.ciphertext, outerKey, false)
|
||||
val badMacOuterCiphertext = outerEncryptResult.ciphertext.copyOf().also {
|
||||
it[it.size - 1] = (it[it.size - 1] + 1).toByte()
|
||||
}
|
||||
|
||||
cipherFile = writeToFile(badMacOuterCiphertext)
|
||||
|
||||
val keyMaterial = createMediaKeyMaterial(innerKey)
|
||||
|
||||
AttachmentCipherInputStream.createForArchivedMedia(
|
||||
archivedMediaKeyMaterial = keyMaterial,
|
||||
file = cipherFile,
|
||||
originalCipherTextLength = innerEncryptResult.ciphertext.size.toLong(),
|
||||
plaintextLength = plaintextInput.size.toLong(),
|
||||
combinedKeyMaterial = innerKey,
|
||||
plaintextHash = innerEncryptResult.digest,
|
||||
incrementalDigest = innerEncryptResult.incrementalDigest,
|
||||
incrementalMacChunkSize = innerEncryptResult.chunkSizeChoice
|
||||
)
|
||||
|
||||
Assert.fail()
|
||||
} catch (e: InvalidMessageException) {
|
||||
hitCorrectException = true
|
||||
} finally {
|
||||
cipherFile?.delete()
|
||||
}
|
||||
|
||||
Assert.assertTrue(hitCorrectException)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun archiveThumbnailEncryptDecrypt_decryptFailOnInnerBadMac() {
|
||||
var cipherFile: File? = null
|
||||
var hitCorrectException = false
|
||||
|
||||
try {
|
||||
val innerKey = Util.getSecretBytes(64)
|
||||
val plaintextInput = Util.getSecretBytes(MEBIBYTE)
|
||||
|
||||
val innerEncryptResult = encryptData(plaintextInput, innerKey, withIncremental = true, padded = false) // Server doesn't pad
|
||||
val badMacInnerCipherText = innerEncryptResult.ciphertext.copyOf().also {
|
||||
it[it.size - 1] = (it[it.size - 1] + 1).toByte()
|
||||
}
|
||||
|
||||
val outerKey = Util.getSecretBytes(64)
|
||||
val outerEncryptResult = encryptData(badMacInnerCipherText, outerKey, false)
|
||||
|
||||
cipherFile = writeToFile(outerEncryptResult.ciphertext)
|
||||
|
||||
val keyMaterial = createMediaKeyMaterial(innerKey)
|
||||
|
||||
AttachmentCipherInputStream.createForArchivedThumbnail(
|
||||
archivedMediaKeyMaterial = keyMaterial,
|
||||
file = cipherFile,
|
||||
innerCombinedKeyMaterial = innerKey
|
||||
).readFully()
|
||||
|
||||
Assert.fail()
|
||||
} catch (e: InvalidMessageException) {
|
||||
hitCorrectException = true
|
||||
} finally {
|
||||
cipherFile?.delete()
|
||||
}
|
||||
|
||||
Assert.assertTrue(hitCorrectException)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun archiveThumbnailEncryptDecrypt_decryptFailOnOuterMac() {
|
||||
var cipherFile: File? = null
|
||||
var hitCorrectException = false
|
||||
|
||||
try {
|
||||
val innerKey = Util.getSecretBytes(64)
|
||||
val plaintextInput = Util.getSecretBytes(MEBIBYTE)
|
||||
|
||||
val innerEncryptResult = encryptData(plaintextInput, innerKey, withIncremental = true, padded = false) // Server doesn't pad
|
||||
val outerKey = Util.getSecretBytes(64)
|
||||
|
||||
val outerEncryptResult = encryptData(innerEncryptResult.ciphertext, outerKey, false)
|
||||
val badMacOuterCiphertext = outerEncryptResult.ciphertext.copyOf().also {
|
||||
it[it.size - 1] = (it[it.size - 1] + 1).toByte()
|
||||
}
|
||||
|
||||
cipherFile = writeToFile(badMacOuterCiphertext)
|
||||
|
||||
val keyMaterial = createMediaKeyMaterial(innerKey)
|
||||
|
||||
AttachmentCipherInputStream.createForArchivedThumbnail(
|
||||
archivedMediaKeyMaterial = keyMaterial,
|
||||
file = cipherFile,
|
||||
innerCombinedKeyMaterial = innerKey
|
||||
)
|
||||
|
||||
Assert.fail()
|
||||
} catch (e: InvalidMessageException) {
|
||||
hitCorrectException = true
|
||||
} finally {
|
||||
cipherFile?.delete()
|
||||
}
|
||||
|
||||
Assert.assertTrue(hitCorrectException)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun sticker_encryptDecrypt() {
|
||||
LibSignalLibraryUtil.assumeLibSignalSupportedOnOS()
|
||||
|
||||
@@ -0,0 +1,188 @@
|
||||
package org.whispersystems.signalservice.api.crypto
|
||||
|
||||
import assertk.assertThat
|
||||
import assertk.assertions.isEqualTo
|
||||
import assertk.assertions.isTrue
|
||||
import assertk.fail
|
||||
import org.junit.Test
|
||||
import org.signal.core.util.kibiBytes
|
||||
import org.signal.core.util.mebiBytes
|
||||
import org.signal.core.util.readFully
|
||||
import org.signal.libsignal.protocol.InvalidMessageException
|
||||
import org.whispersystems.signalservice.internal.util.Util
|
||||
import java.io.ByteArrayInputStream
|
||||
import java.io.ByteArrayOutputStream
|
||||
import javax.crypto.Mac
|
||||
import javax.crypto.spec.SecretKeySpec
|
||||
import kotlin.random.Random
|
||||
|
||||
class MacValidatingInputStreamTest {
|
||||
|
||||
@Test
|
||||
fun `success - simple byte array read`() {
|
||||
val data = "Hello, World!".toByteArray()
|
||||
val key = Util.getSecretBytes(32)
|
||||
val dataWithMac = createDataWithMac(data, key)
|
||||
|
||||
val inputStream = ByteArrayInputStream(dataWithMac)
|
||||
val mac = createMac(key)
|
||||
val macValidatingStream = MacValidatingInputStream(inputStream, mac)
|
||||
|
||||
val result = macValidatingStream.readFully()
|
||||
|
||||
assertThat(result).isEqualTo(dataWithMac)
|
||||
macValidatingStream.close()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `success - byte by byte`() {
|
||||
val data = "Hello, World!".toByteArray()
|
||||
val key = Util.getSecretBytes(32)
|
||||
val dataWithMac = createDataWithMac(data, key)
|
||||
|
||||
val inputStream = ByteArrayInputStream(dataWithMac)
|
||||
val mac = createMac(key)
|
||||
val macValidatingStream = MacValidatingInputStream(inputStream, mac)
|
||||
|
||||
val out = ByteArrayOutputStream()
|
||||
var read = -1
|
||||
while (macValidatingStream.read().also { read = it } != -1) {
|
||||
out.write(read)
|
||||
}
|
||||
val result = out.toByteArray()
|
||||
|
||||
assertThat(result).isEqualTo(dataWithMac)
|
||||
macValidatingStream.close()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `success - many different sizes`() {
|
||||
for (i in 1..100) {
|
||||
val data = Util.getSecretBytes(Random.nextLong(from = 256.kibiBytes.bytes, until = 2.mebiBytes.bytes).toInt())
|
||||
val key = Util.getSecretBytes(32)
|
||||
val dataWithMac = createDataWithMac(data, key)
|
||||
|
||||
val inputStream = ByteArrayInputStream(dataWithMac)
|
||||
val mac = createMac(key)
|
||||
val macValidatingStream = MacValidatingInputStream(inputStream, mac)
|
||||
|
||||
val result = macValidatingStream.readFully()
|
||||
|
||||
assertThat(result).isEqualTo(dataWithMac)
|
||||
assertThat(macValidatingStream.validationAttempted).isTrue()
|
||||
macValidatingStream.close()
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `success - empty data`() {
|
||||
val data = ByteArray(0)
|
||||
val key = Util.getSecretBytes(32)
|
||||
val dataWithMac = createDataWithMac(data, key)
|
||||
|
||||
val inputStream = ByteArrayInputStream(dataWithMac)
|
||||
val mac = createMac(key)
|
||||
val macValidatingStream = MacValidatingInputStream(inputStream, mac)
|
||||
|
||||
val result = macValidatingStream.readFully()
|
||||
|
||||
assertThat(result).isEqualTo(dataWithMac)
|
||||
assertThat(macValidatingStream.validationAttempted).isTrue()
|
||||
macValidatingStream.close()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `success - data exactly MAC length`() {
|
||||
val key = Util.getSecretBytes(32)
|
||||
val mac = createMac(key)
|
||||
val macLength = mac.macLength
|
||||
val data = ByteArray(macLength) { (it % 256).toByte() } // Data same size as MAC
|
||||
val dataWithMac = createDataWithMac(data, key)
|
||||
|
||||
val inputStream = ByteArrayInputStream(dataWithMac)
|
||||
val mac2 = createMac(key)
|
||||
val macValidatingStream = MacValidatingInputStream(inputStream, mac2)
|
||||
|
||||
val result = macValidatingStream.readFully()
|
||||
|
||||
assertThat(result).isEqualTo(dataWithMac)
|
||||
assertThat(macValidatingStream.validationAttempted).isTrue()
|
||||
macValidatingStream.close()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `success - multiple reads after end of stream`() {
|
||||
val data = "Test multiple reads after EOF".toByteArray()
|
||||
val key = Util.getSecretBytes(32)
|
||||
val dataWithMac = createDataWithMac(data, key)
|
||||
|
||||
val inputStream = ByteArrayInputStream(dataWithMac)
|
||||
val mac = createMac(key)
|
||||
val macValidatingStream = MacValidatingInputStream(inputStream, mac)
|
||||
|
||||
val result = macValidatingStream.readFully()
|
||||
|
||||
// Multiple calls to read() after EOF should return -1
|
||||
assertThat(macValidatingStream.read()).isEqualTo(-1)
|
||||
assertThat(macValidatingStream.read()).isEqualTo(-1)
|
||||
assertThat(macValidatingStream.read()).isEqualTo(-1)
|
||||
|
||||
assertThat(result).isEqualTo(dataWithMac)
|
||||
assertThat(macValidatingStream.validationAttempted).isTrue()
|
||||
macValidatingStream.close()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `failure - invalid MAC`() {
|
||||
val data = "Hello, World!".toByteArray()
|
||||
val key = Util.getSecretBytes(32)
|
||||
val wrongKey = ByteArray(32) { 24 }
|
||||
val dataWithMac = createDataWithMac(data, key)
|
||||
|
||||
val inputStream = ByteArrayInputStream(dataWithMac)
|
||||
val mac = createMac(wrongKey) // Wrong key
|
||||
val macValidatingStream = MacValidatingInputStream(inputStream, mac)
|
||||
|
||||
try {
|
||||
macValidatingStream.readFully()
|
||||
fail("Expected InvalidMessageException to be thrown")
|
||||
} catch (e: InvalidMessageException) {
|
||||
assertThat(e.message).isEqualTo("MAC validation failed!")
|
||||
} finally {
|
||||
macValidatingStream.close()
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `failure - insufficient data for MAC`() {
|
||||
val key = Util.getSecretBytes(32)
|
||||
val mac = createMac(key)
|
||||
val macLength = mac.macLength
|
||||
val insufficientData = ByteArray(macLength - 1) { 5 } // Less than MAC length
|
||||
|
||||
val inputStream = ByteArrayInputStream(insufficientData)
|
||||
val mac2 = createMac(key)
|
||||
val macValidatingStream = MacValidatingInputStream(inputStream, mac2)
|
||||
|
||||
try {
|
||||
macValidatingStream.readFully()
|
||||
fail("Expected InvalidMessageException to be thrown")
|
||||
} catch (e: InvalidMessageException) {
|
||||
assertThat(e.message).isEqualTo("Stream ended before MAC could be read. Expected $macLength bytes, got ${insufficientData.size}")
|
||||
} finally {
|
||||
macValidatingStream.close()
|
||||
}
|
||||
}
|
||||
|
||||
private fun createMac(key: ByteArray): Mac {
|
||||
val mac = Mac.getInstance("HmacSHA256")
|
||||
mac.init(SecretKeySpec(key, "HmacSHA256"))
|
||||
return mac
|
||||
}
|
||||
|
||||
private fun createDataWithMac(data: ByteArray, key: ByteArray): ByteArray {
|
||||
val mac = createMac(key)
|
||||
val macBytes = mac.doFinal(data)
|
||||
return data + macBytes
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user