Convert AttachmentCipherInputStream to kotlin.

This commit is contained in:
Greyson Parrelli
2025-06-16 17:13:23 -04:00
committed by Michelle Tang
parent 2e79e257a3
commit ee0ee98cb6
6 changed files with 445 additions and 449 deletions

View File

@@ -90,6 +90,11 @@ class PartDataSource implements DataSource {
try {
long streamLength = AttachmentCipherStreamUtil.getCiphertextLength(PaddingInputStream.getPaddedSize(attachment.size));
AttachmentCipherInputStream.StreamSupplier streamSupplier = () -> new TailerInputStream(() -> new FileInputStream(transferFile), streamLength);
if (attachment.remoteDigest == null) {
throw new InvalidMessageException("Missing digest!");
}
this.inputStream = AttachmentCipherInputStream.createForAttachment(streamSupplier, streamLength, attachment.size, decode, attachment.remoteDigest, attachment.getIncrementalDigest(), attachment.incrementalMacChunkSize, false);
} catch (InvalidMessageException e) {
throw new IOException("Error decrypting attachment stream!", e);

View File

@@ -1,399 +0,0 @@
/*
* Copyright (C) 2014-2017 Open Whisper Systems
*
* Licensed according to the LICENSE file in this repository.
*/
package org.whispersystems.signalservice.api.crypto;
import org.signal.core.util.stream.LimitedInputStream;
import org.signal.libsignal.protocol.InvalidMessageException;
import org.signal.libsignal.protocol.incrementalmac.ChunkSizeChoice;
import org.signal.libsignal.protocol.incrementalmac.IncrementalMacInputStream;
import org.signal.libsignal.protocol.kdf.HKDF;
import org.whispersystems.signalservice.api.backup.MediaRootBackupKey;
import org.whispersystems.signalservice.internal.util.Util;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Arrays;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.Mac;
import javax.crypto.NoSuchPaddingException;
import javax.crypto.ShortBufferException;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
/**
* Class for streaming an encrypted push attachment off disk.
*
* @author Moxie Marlinspike
*/
public class AttachmentCipherInputStream extends FilterInputStream {
private static final int BLOCK_SIZE = 16;
private static final int CIPHER_KEY_SIZE = 32;
private static final int MAC_KEY_SIZE = 32;
private final Cipher cipher;
private final long totalDataSize;
private boolean done;
private long totalRead;
private byte[] overflowBuffer;
/**
* Passing in a null incrementalDigest and/or 0 for the chunk size at the call site disables incremental mac validation.
*/
public static LimitedInputStream createForAttachment(File file, long plaintextLength, byte[] combinedKeyMaterial, byte[] digest, byte[] incrementalDigest, int incrementalMacChunkSize)
throws InvalidMessageException, IOException {
return createForAttachment(file, plaintextLength, combinedKeyMaterial, digest, incrementalDigest, incrementalMacChunkSize, false);
}
/**
* Passing in a null incrementalDigest and/or 0 for the chunk size at the call site disables incremental mac validation.
*
* Passing in true for ignoreDigest DOES NOT VERIFY THE DIGEST
*/
public static LimitedInputStream createForAttachment(File file, long plaintextLength, byte[] combinedKeyMaterial, byte[] digest, byte[] incrementalDigest, int incrementalMacChunkSize, boolean ignoreDigest)
throws InvalidMessageException, IOException
{
return createForAttachment(() -> new FileInputStream(file), file.length(), plaintextLength, combinedKeyMaterial, digest, incrementalDigest, incrementalMacChunkSize, ignoreDigest);
}
/**
* Passing in a null incrementalDigest and/or 0 for the chunk size at the call site disables incremental mac validation.
*
* Passing in true for ignoreDigest DOES NOT VERIFY THE DIGEST
*/
public static LimitedInputStream createForAttachment(StreamSupplier streamSupplier, long streamLength, long plaintextLength, byte[] combinedKeyMaterial, byte[] digest, byte[] incrementalDigest, int incrementalMacChunkSize, boolean ignoreDigest)
throws InvalidMessageException, IOException
{
byte[][] parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE);
Mac mac = initMac(parts[1]);
if (streamLength <= BLOCK_SIZE + mac.getMacLength()) {
throw new InvalidMessageException("Message shorter than crypto overhead! length: " + streamLength);
}
if (!ignoreDigest && digest == null) {
throw new InvalidMessageException("Missing digest!");
}
final InputStream wrappedStream;
final boolean hasIncrementalMac = incrementalDigest != null && incrementalDigest.length > 0 && incrementalMacChunkSize > 0;
if (!hasIncrementalMac) {
try (InputStream macVerificationStream = streamSupplier.openStream()) {
verifyMac(macVerificationStream, streamLength, mac, digest);
}
wrappedStream = streamSupplier.openStream();
} else {
wrappedStream = new IncrementalMacInputStream(
new IncrementalMacAdditionalValidationsInputStream(
streamSupplier.openStream(),
streamLength,
mac,
digest
),
parts[1],
ChunkSizeChoice.everyNthByte(incrementalMacChunkSize),
incrementalDigest);
}
InputStream inputStream = new AttachmentCipherInputStream(wrappedStream, parts[0], streamLength - BLOCK_SIZE - mac.getMacLength());
if (plaintextLength != 0) {
return new LimitedInputStream(inputStream, plaintextLength);
} else {
return LimitedInputStream.withoutLimits(inputStream);
}
}
/**
* Decrypt archived media to it's original attachment encrypted blob.
*/
public static LimitedInputStream createForArchivedMedia(MediaRootBackupKey.MediaKeyMaterial archivedMediaKeyMaterial, File file, long originalCipherTextLength)
throws InvalidMessageException, IOException
{
Mac mac = initMac(archivedMediaKeyMaterial.getMacKey());
if (file.length() <= BLOCK_SIZE + mac.getMacLength()) {
throw new InvalidMessageException("Message shorter than crypto overhead!");
}
try (FileInputStream macVerificationStream = new FileInputStream(file)) {
verifyMac(macVerificationStream, file.length(), mac, null);
}
InputStream inputStream = new AttachmentCipherInputStream(new FileInputStream(file), archivedMediaKeyMaterial.getAesKey(), file.length() - BLOCK_SIZE - mac.getMacLength());
if (originalCipherTextLength != 0) {
return new LimitedInputStream(inputStream, originalCipherTextLength);
} else {
return LimitedInputStream.withoutLimits(inputStream);
}
}
public static LimitedInputStream createStreamingForArchivedAttachment(MediaRootBackupKey.MediaKeyMaterial archivedMediaKeyMaterial, File file, long originalCipherTextLength, long plaintextLength, byte[] combinedKeyMaterial, byte[] digest, byte[] incrementalDigest, int incrementalMacChunkSize)
throws InvalidMessageException, IOException
{
final InputStream archiveStream = createForArchivedMedia(archivedMediaKeyMaterial, file, originalCipherTextLength);
byte[][] parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE);
Mac mac = initMac(parts[1]);
if (originalCipherTextLength <= BLOCK_SIZE + mac.getMacLength()) {
throw new InvalidMessageException("Message shorter than crypto overhead!");
}
if (digest == null) {
throw new InvalidMessageException("Missing digest!");
}
final InputStream wrappedStream;
wrappedStream = new IncrementalMacInputStream(
new IncrementalMacAdditionalValidationsInputStream(
archiveStream,
file.length(),
mac,
digest
),
parts[1],
ChunkSizeChoice.everyNthByte(incrementalMacChunkSize),
incrementalDigest);
InputStream inputStream = new AttachmentCipherInputStream(wrappedStream, parts[0], file.length() - BLOCK_SIZE - mac.getMacLength());
if (plaintextLength != 0) {
return new LimitedInputStream(inputStream, plaintextLength);
} else {
return LimitedInputStream.withoutLimits(inputStream);
}
}
public static InputStream createForStickerData(byte[] data, byte[] packKey)
throws InvalidMessageException, IOException
{
byte[] combinedKeyMaterial = HKDF.deriveSecrets(packKey, "Sticker Pack".getBytes(), 64);
byte[][] parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE);
Mac mac = initMac(parts[1]);
if (data.length <= BLOCK_SIZE + mac.getMacLength()) {
throw new InvalidMessageException("Message shorter than crypto overhead!");
}
try (InputStream inputStream = new ByteArrayInputStream(data)) {
verifyMac(inputStream, data.length, mac, null);
}
return new AttachmentCipherInputStream(new ByteArrayInputStream(data), parts[0], data.length - BLOCK_SIZE - mac.getMacLength());
}
private AttachmentCipherInputStream(InputStream inputStream, byte[] aesKey, long totalDataSize)
throws IOException
{
super(inputStream);
try {
byte[] iv = new byte[BLOCK_SIZE];
readFully(iv);
this.cipher = Cipher.getInstance("AES/CBC/PKCS5Padding");
this.cipher.init(Cipher.DECRYPT_MODE, new SecretKeySpec(aesKey, "AES"), new IvParameterSpec(iv));
this.done = false;
this.totalRead = 0;
this.totalDataSize = totalDataSize;
} catch (NoSuchAlgorithmException | InvalidKeyException | NoSuchPaddingException | InvalidAlgorithmParameterException e) {
throw new AssertionError(e);
}
}
@Override
public int read() throws IOException {
byte[] buffer = new byte[1];
int read;
//noinspection StatementWithEmptyBody
while ((read = read(buffer)) == 0) ;
return (read == -1) ? -1 : ((int) buffer[0]) & 0xFF;
}
@Override
public int read(@Nonnull byte[] buffer) throws IOException {
return read(buffer, 0, buffer.length);
}
@Override
public int read(@Nonnull byte[] buffer, int offset, int length) throws IOException {
if (totalRead != totalDataSize) {
return readIncremental(buffer, offset, length);
} else if (!done) {
return readFinal(buffer, offset, length);
} else {
return -1;
}
}
@Override
public boolean markSupported() {
return false;
}
@Override
public long skip(long byteCount) throws IOException {
long skipped = 0L;
while (skipped < byteCount) {
byte[] buf = new byte[Math.min(4096, (int) (byteCount - skipped))];
int read = read(buf);
skipped += read;
}
return skipped;
}
private int readFinal(byte[] buffer, int offset, int length) throws IOException {
try {
byte[] internal = new byte[buffer.length];
int actualLength = Math.min(length, cipher.doFinal(internal, 0));
System.arraycopy(internal, 0, buffer, offset, actualLength);
done = true;
return actualLength;
} catch (IllegalBlockSizeException | BadPaddingException | ShortBufferException e) {
throw new IOException(e);
}
}
private int readIncremental(byte[] buffer, int offset, int length) throws IOException {
int readLength = 0;
if (null != overflowBuffer) {
if (overflowBuffer.length > length) {
System.arraycopy(overflowBuffer, 0, buffer, offset, length);
overflowBuffer = Arrays.copyOfRange(overflowBuffer, length, overflowBuffer.length);
return length;
} else if (overflowBuffer.length == length) {
System.arraycopy(overflowBuffer, 0, buffer, offset, length);
overflowBuffer = null;
return length;
} else {
System.arraycopy(overflowBuffer, 0, buffer, offset, overflowBuffer.length);
readLength += overflowBuffer.length;
offset += readLength;
length -= readLength;
overflowBuffer = null;
}
}
if (length + totalRead > totalDataSize)
length = (int) (totalDataSize - totalRead);
byte[] internalBuffer = new byte[length];
int read = super.read(internalBuffer, 0, internalBuffer.length <= cipher.getBlockSize() ? internalBuffer.length : internalBuffer.length - cipher.getBlockSize());
totalRead += read;
try {
int outputLen = cipher.getOutputSize(read);
if (outputLen <= length) {
readLength += cipher.update(internalBuffer, 0, read, buffer, offset);
return readLength;
}
byte[] transientBuffer = new byte[outputLen];
outputLen = cipher.update(internalBuffer, 0, read, transientBuffer, 0);
if (outputLen <= length) {
System.arraycopy(transientBuffer, 0, buffer, offset, outputLen);
readLength += outputLen;
} else {
System.arraycopy(transientBuffer, 0, buffer, offset, length);
overflowBuffer = Arrays.copyOfRange(transientBuffer, length, outputLen);
readLength += length;
}
return readLength;
} catch (ShortBufferException e) {
throw new AssertionError(e);
}
}
private static Mac initMac(byte[] key) {
try {
Mac mac = Mac.getInstance("HmacSHA256");
mac.init(new SecretKeySpec(key, "HmacSHA256"));
return mac;
} catch (NoSuchAlgorithmException | InvalidKeyException e) {
throw new AssertionError(e);
}
}
private static void verifyMac(@Nonnull InputStream inputStream, long length, @Nonnull Mac mac, @Nullable byte[] theirDigest)
throws InvalidMessageException
{
try {
MessageDigest digest = MessageDigest.getInstance("SHA256");
int remainingData = Util.toIntExact(length) - mac.getMacLength();
byte[] buffer = new byte[4096];
while (remainingData > 0) {
int read = inputStream.read(buffer, 0, Math.min(buffer.length, remainingData));
mac.update(buffer, 0, read);
digest.update(buffer, 0, read);
remainingData -= read;
}
byte[] ourMac = mac.doFinal();
byte[] theirMac = new byte[mac.getMacLength()];
Util.readFully(inputStream, theirMac);
if (!MessageDigest.isEqual(ourMac, theirMac)) {
throw new InvalidMessageException("MAC doesn't match!");
}
byte[] ourDigest = digest.digest(theirMac);
if (theirDigest != null && !MessageDigest.isEqual(ourDigest, theirDigest)) {
throw new InvalidMessageException("Digest doesn't match!");
}
} catch (IOException | ArithmeticException e1) {
throw new InvalidMessageException(e1);
} catch (NoSuchAlgorithmException e) {
throw new AssertionError(e);
}
}
private void readFully(byte[] buffer) throws IOException {
int offset = 0;
for (; ; ) {
int read = super.read(buffer, offset, buffer.length - offset);
if (read + offset < buffer.length) {
offset += read;
} else {
return;
}
}
}
public interface StreamSupplier {
@Nonnull InputStream openStream() throws IOException;
}
}

View File

@@ -0,0 +1,410 @@
/*
* Copyright (C) 2014-2017 Open Whisper Systems
*
* Licensed according to the LICENSE file in this repository.
*/
package org.whispersystems.signalservice.api.crypto
import org.signal.core.util.stream.LimitedInputStream
import org.signal.core.util.stream.LimitedInputStream.Companion.withoutLimits
import org.signal.libsignal.protocol.InvalidMessageException
import org.signal.libsignal.protocol.incrementalmac.ChunkSizeChoice
import org.signal.libsignal.protocol.incrementalmac.IncrementalMacInputStream
import org.signal.libsignal.protocol.kdf.HKDF
import org.whispersystems.signalservice.api.backup.MediaRootBackupKey.MediaKeyMaterial
import org.whispersystems.signalservice.internal.util.Util
import java.io.ByteArrayInputStream
import java.io.File
import java.io.FileInputStream
import java.io.FilterInputStream
import java.io.IOException
import java.io.InputStream
import java.security.InvalidKeyException
import java.security.MessageDigest
import java.security.NoSuchAlgorithmException
import javax.annotation.Nonnull
import javax.crypto.BadPaddingException
import javax.crypto.Cipher
import javax.crypto.IllegalBlockSizeException
import javax.crypto.Mac
import javax.crypto.ShortBufferException
import javax.crypto.spec.IvParameterSpec
import javax.crypto.spec.SecretKeySpec
import kotlin.math.min
/**
* Class for streaming an encrypted push attachment off disk.
*
* @author Moxie Marlinspike
*/
class AttachmentCipherInputStream private constructor(
inputStream: InputStream,
aesKey: ByteArray,
private val totalDataSize: Long
) : FilterInputStream(inputStream) {
private val cipher: Cipher
private var done = false
private var totalRead: Long = 0
private var overflowBuffer: ByteArray? = null
init {
val iv = ByteArray(BLOCK_SIZE)
readFullyWithoutDecrypting(iv)
this.cipher = Cipher.getInstance("AES/CBC/PKCS5Padding")
cipher.init(Cipher.DECRYPT_MODE, SecretKeySpec(aesKey, "AES"), IvParameterSpec(iv))
}
@Throws(IOException::class)
override fun read(): Int {
val buffer = ByteArray(1)
var read: Int = read(buffer)
while (read == 0) {
read = read(buffer)
}
if (read == -1) {
return read
}
return buffer[0].toInt() and 0xFF
}
@Throws(IOException::class)
override fun read(@Nonnull buffer: ByteArray): Int {
return read(buffer, 0, buffer.size)
}
@Throws(IOException::class)
override fun read(@Nonnull buffer: ByteArray, offset: Int, length: Int): Int {
return if (totalRead != totalDataSize) {
readIncremental(buffer, offset, length)
} else if (!done) {
readFinal(buffer, offset, length)
} else {
-1
}
}
override fun markSupported(): Boolean = false
@Throws(IOException::class)
override fun skip(byteCount: Long): Long {
var skipped = 0L
while (skipped < byteCount) {
val remaining = byteCount - skipped
val buffer = ByteArray(min(4096, remaining.toInt()))
val read = read(buffer)
skipped += read.toLong()
}
return skipped
}
@Throws(IOException::class)
private fun readIncremental(outputBuffer: ByteArray, originalOffset: Int, originalLength: Int): Int {
var offset = originalOffset
var length = originalLength
var readLength = 0
overflowBuffer?.let { overflow ->
if (overflow.size > length) {
overflow.copyInto(destination = outputBuffer, destinationOffset = offset, endIndex = length)
overflowBuffer = overflow.copyOfRange(fromIndex = length, toIndex = overflow.size)
return length
} else if (overflow.size == length) {
overflow.copyInto(destination = outputBuffer, destinationOffset = offset)
overflowBuffer = null
return length
} else {
overflow.copyInto(destination = outputBuffer, destinationOffset = offset)
readLength += overflow.size
offset += readLength
length -= readLength
overflowBuffer = null
}
}
if (length + totalRead > totalDataSize) {
length = (totalDataSize - totalRead).toInt()
}
val ciphertextBuffer = ByteArray(length)
val ciphertextReadLength = if (ciphertextBuffer.size <= cipher.blockSize) {
ciphertextBuffer.size
} else {
// Ensure we leave the final block for readFinal()
ciphertextBuffer.size - cipher.blockSize
}
val ciphertextRead = super.read(ciphertextBuffer, 0, ciphertextReadLength)
totalRead += ciphertextRead.toLong()
try {
var plaintextLength = cipher.getOutputSize(ciphertextRead)
if (plaintextLength <= length) {
readLength += cipher.update(ciphertextBuffer, 0, ciphertextRead, outputBuffer, offset)
return readLength
}
val plaintextBuffer = ByteArray(plaintextLength)
plaintextLength = cipher.update(ciphertextBuffer, 0, ciphertextRead, plaintextBuffer, 0)
if (plaintextLength <= length) {
plaintextBuffer.copyInto(destination = outputBuffer, destinationOffset = offset, endIndex = plaintextLength)
readLength += plaintextLength
} else {
plaintextBuffer.copyInto(destination = outputBuffer, destinationOffset = offset, endIndex = length)
overflowBuffer = plaintextBuffer.copyOfRange(fromIndex = length, toIndex = plaintextLength)
readLength += length
}
return readLength
} catch (e: ShortBufferException) {
throw AssertionError(e)
}
}
@Throws(IOException::class)
private fun readFinal(buffer: ByteArray, offset: Int, length: Int): Int {
try {
val internal = ByteArray(buffer.size)
val actualLength = min(length, cipher.doFinal(internal, 0))
internal.copyInto(destination = buffer, destinationOffset = offset, endIndex = actualLength)
done = true
return actualLength
} catch (e: IllegalBlockSizeException) {
throw IOException(e)
} catch (e: BadPaddingException) {
throw IOException(e)
} catch (e: ShortBufferException) {
throw IOException(e)
}
}
@Throws(IOException::class)
private fun readFullyWithoutDecrypting(buffer: ByteArray) {
var offset = 0
while (true) {
val read = super.read(buffer, offset, buffer.size - offset)
if (read + offset < buffer.size) {
offset += read
} else {
return
}
}
}
fun interface StreamSupplier {
@Nonnull
@Throws(IOException::class)
fun openStream(): InputStream
}
companion object {
private const val BLOCK_SIZE = 16
private const val CIPHER_KEY_SIZE = 32
private const val MAC_KEY_SIZE = 32
/**
* Passing in a null incrementalDigest and/or 0 for the chunk size at the call site disables incremental mac validation.
*
* Passing in true for ignoreDigest DOES NOT VERIFY THE DIGEST
*/
@JvmStatic
@JvmOverloads
@Throws(InvalidMessageException::class, IOException::class)
fun createForAttachment(file: File, plaintextLength: Long, combinedKeyMaterial: ByteArray?, digest: ByteArray?, incrementalDigest: ByteArray?, incrementalMacChunkSize: Int, ignoreDigest: Boolean = false): LimitedInputStream {
return createForAttachment({ FileInputStream(file) }, file.length(), plaintextLength, combinedKeyMaterial, digest, incrementalDigest, incrementalMacChunkSize, ignoreDigest)
}
/**
* Passing in a null incrementalDigest and/or 0 for the chunk size at the call site disables incremental mac validation.
*
* Passing in true for ignoreDigest DOES NOT VERIFY THE DIGEST
*/
@JvmStatic
@Throws(InvalidMessageException::class, IOException::class)
fun createForAttachment(
streamSupplier: StreamSupplier,
streamLength: Long,
plaintextLength: Long,
combinedKeyMaterial: ByteArray?,
digest: ByteArray?,
incrementalDigest: ByteArray?,
incrementalMacChunkSize: Int,
ignoreDigest: Boolean
): LimitedInputStream {
val parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE)
val mac = initMac(parts[1])
if (streamLength <= BLOCK_SIZE + mac.macLength) {
throw InvalidMessageException("Message shorter than crypto overhead! length: $streamLength")
}
if (!ignoreDigest && digest == null) {
throw InvalidMessageException("Missing digest!")
}
val wrappedStream: InputStream
val hasIncrementalMac = incrementalDigest != null && incrementalDigest.isNotEmpty() && incrementalMacChunkSize > 0
if (!hasIncrementalMac) {
streamSupplier.openStream().use { macVerificationStream ->
verifyMac(macVerificationStream, streamLength, mac, digest)
}
wrappedStream = streamSupplier.openStream()
} else {
wrappedStream = IncrementalMacInputStream(
IncrementalMacAdditionalValidationsInputStream(
streamSupplier.openStream(),
streamLength,
mac,
digest!!
),
parts[1],
ChunkSizeChoice.everyNthByte(incrementalMacChunkSize),
incrementalDigest
)
}
val inputStream: InputStream = AttachmentCipherInputStream(wrappedStream, parts[0], streamLength - BLOCK_SIZE - mac.macLength)
return LimitedInputStream(inputStream, plaintextLength)
}
/**
* Decrypt archived media to it's original attachment encrypted blob.
*/
@JvmStatic
@Throws(InvalidMessageException::class, IOException::class)
fun createForArchivedMedia(archivedMediaKeyMaterial: MediaKeyMaterial, file: File, originalCipherTextLength: Long): LimitedInputStream {
val mac = initMac(archivedMediaKeyMaterial.macKey)
if (file.length() <= BLOCK_SIZE + mac.macLength) {
throw InvalidMessageException("Message shorter than crypto overhead!")
}
FileInputStream(file).use { macVerificationStream ->
verifyMac(macVerificationStream, file.length(), mac, null)
}
val inputStream: InputStream = AttachmentCipherInputStream(FileInputStream(file), archivedMediaKeyMaterial.aesKey, file.length() - BLOCK_SIZE - mac.macLength)
return LimitedInputStream(inputStream, originalCipherTextLength)
}
@JvmStatic
@Throws(InvalidMessageException::class, IOException::class)
fun createStreamingForArchivedAttachment(
archivedMediaKeyMaterial: MediaKeyMaterial,
file: File,
originalCipherTextLength: Long,
plaintextLength: Long,
combinedKeyMaterial: ByteArray?,
digest: ByteArray,
incrementalDigest: ByteArray?,
incrementalMacChunkSize: Int
): LimitedInputStream {
val archiveStream: InputStream = createForArchivedMedia(archivedMediaKeyMaterial, file, originalCipherTextLength)
val parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE)
val mac = initMac(parts[1])
if (originalCipherTextLength <= BLOCK_SIZE + mac.macLength) {
throw InvalidMessageException("Message shorter than crypto overhead!")
}
val wrappedStream: InputStream = IncrementalMacInputStream(
IncrementalMacAdditionalValidationsInputStream(
wrapped = archiveStream,
fileLength = file.length(),
mac = mac,
theirDigest = digest
),
parts[1],
ChunkSizeChoice.everyNthByte(incrementalMacChunkSize),
incrementalDigest
)
val inputStream: InputStream = AttachmentCipherInputStream(
inputStream = wrappedStream,
aesKey = parts[0],
totalDataSize = file.length() - BLOCK_SIZE - mac.macLength
)
return if (plaintextLength != 0L) {
LimitedInputStream(inputStream, plaintextLength)
} else {
withoutLimits(inputStream)
}
}
@JvmStatic
@Throws(InvalidMessageException::class, IOException::class)
fun createForStickerData(data: ByteArray, packKey: ByteArray?): InputStream {
val combinedKeyMaterial = HKDF.deriveSecrets(packKey, "Sticker Pack".toByteArray(), 64)
val parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE)
val mac = initMac(parts[1])
if (data.size <= BLOCK_SIZE + mac.macLength) {
throw InvalidMessageException("Message shorter than crypto overhead!")
}
ByteArrayInputStream(data).use { inputStream ->
verifyMac(inputStream, data.size.toLong(), mac, null)
}
return AttachmentCipherInputStream(ByteArrayInputStream(data), parts[0], (data.size - BLOCK_SIZE - mac.macLength).toLong())
}
private fun initMac(key: ByteArray): Mac {
try {
val mac = Mac.getInstance("HmacSHA256")
mac.init(SecretKeySpec(key, "HmacSHA256"))
return mac
} catch (e: NoSuchAlgorithmException) {
throw AssertionError(e)
} catch (e: InvalidKeyException) {
throw AssertionError(e)
}
}
@Throws(InvalidMessageException::class)
private fun verifyMac(@Nonnull inputStream: InputStream, length: Long, @Nonnull mac: Mac, theirDigest: ByteArray?) {
try {
val digest = MessageDigest.getInstance("SHA256")
var remainingData = Util.toIntExact(length) - mac.macLength
val buffer = ByteArray(4096)
while (remainingData > 0) {
val read = inputStream.read(buffer, 0, min(buffer.size, remainingData))
mac.update(buffer, 0, read)
digest.update(buffer, 0, read)
remainingData -= read
}
val ourMac = mac.doFinal()
val theirMac = ByteArray(mac.macLength)
Util.readFully(inputStream, theirMac)
if (!MessageDigest.isEqual(ourMac, theirMac)) {
throw InvalidMessageException("MAC doesn't match!")
}
val ourDigest = digest.digest(theirMac)
if (theirDigest != null && !MessageDigest.isEqual(ourDigest, theirDigest)) {
throw InvalidMessageException("Digest doesn't match!")
}
} catch (e: IOException) {
throw InvalidMessageException(e)
} catch (e: ArithmeticException) {
throw InvalidMessageException(e)
} catch (e: NoSuchAlgorithmException) {
throw AssertionError(e)
}
}
}
}

View File

@@ -17,7 +17,7 @@ import kotlin.math.max
* This is meant as a helper stream to go along with [org.signal.libsignal.protocol.incrementalmac.IncrementalMacInputStream].
* That class does not validate the overall digest, nor the overall MAC. This class does that for us.
*
* To use, wrap the IncremtalMacInputStream around this class, and then this class should wrap the lowest-level data stream.
* To use, wrap the IncrementalMacInputStream around this class, and then this class should wrap the lowest-level data stream.
*/
class IncrementalMacAdditionalValidationsInputStream(
wrapped: InputStream,

View File

@@ -74,6 +74,7 @@ class DigestingRequestBody(
digestStream.close()
digestStream.toByteArray()
} else {
outputStream.close()
null
}

View File

@@ -8,7 +8,10 @@ import org.conscrypt.Conscrypt
import org.junit.Assert
import org.junit.Test
import org.signal.core.util.StreamUtil
import org.signal.core.util.allMatch
import org.signal.core.util.copyTo
import org.signal.core.util.readFully
import org.signal.core.util.stream.LimitedInputStream
import org.signal.libsignal.protocol.InvalidMessageException
import org.signal.libsignal.protocol.incrementalmac.ChunkSizeChoice
import org.signal.libsignal.protocol.incrementalmac.InvalidMacException
@@ -22,7 +25,6 @@ import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import java.io.File
import java.io.FileOutputStream
import java.io.IOException
import java.io.InputStream
import java.io.OutputStream
import java.lang.AssertionError
@@ -31,19 +33,23 @@ import java.util.Random
class AttachmentCipherTest {
@Test
@Throws(IOException::class, InvalidMessageException::class)
fun attachment_encryptDecrypt_nonIncremental() {
attachment_encryptDecrypt(incremental = false, fileSize = MEBIBYTE)
}
@Test
@Throws(IOException::class, InvalidMessageException::class)
fun attachment_encryptDecrypt_incremental() {
attachment_encryptDecrypt(incremental = true, fileSize = MEBIBYTE)
}
@Test
@Throws(IOException::class, InvalidMessageException::class)
fun attachment_encryptDecrypt_nonIncremental_manyFileSizes() {
for (i in 0..99) {
attachment_encryptDecrypt(incremental = false, fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024))
}
}
@Test
fun attachment_encryptDecrypt_incremental_manyFileSizes() {
// Designed to stress the various boundary conditions of reading the final mac
for (i in 0..99) {
@@ -51,7 +57,6 @@ class AttachmentCipherTest {
}
}
@Throws(IOException::class, InvalidMessageException::class)
private fun attachment_encryptDecrypt(incremental: Boolean, fileSize: Int) {
val key = Util.getSecretBytes(64)
val plaintextInput = Util.getSecretBytes(fileSize)
@@ -59,27 +64,25 @@ class AttachmentCipherTest {
val encryptResult = encryptData(plaintextInput, key, incremental)
val cipherFile = writeToFile(encryptResult.ciphertext)
val inputStream: InputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, encryptResult.digest, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice)
val plaintextOutput = readInputStreamFully(inputStream)
val inputStream: LimitedInputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, encryptResult.digest, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice)
val plaintextOutput = inputStream.readFully(autoClose = false)
assertThat(plaintextOutput).isEqualTo(plaintextInput)
assertThat(inputStream.leftoverStream().allMatch { it == 0.toByte() }).isTrue()
cipherFile.delete()
}
@Test
@Throws(IOException::class, InvalidMessageException::class)
fun attachment_encryptDecryptEmpty_nonIncremental() {
attachment_encryptDecryptEmpty(incremental = false)
}
@Test
@Throws(IOException::class, InvalidMessageException::class)
fun attachment_encryptDecryptEmpty_incremental() {
attachment_encryptDecryptEmpty(incremental = true)
}
@Throws(IOException::class, InvalidMessageException::class)
private fun attachment_encryptDecryptEmpty(incremental: Boolean) {
val key = Util.getSecretBytes(64)
val plaintextInput = "".toByteArray()
@@ -87,27 +90,25 @@ class AttachmentCipherTest {
val encryptResult = encryptData(plaintextInput, key, incremental)
val cipherFile = writeToFile(encryptResult.ciphertext)
val inputStream: InputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, encryptResult.digest, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice)
val plaintextOutput = readInputStreamFully(inputStream)
val inputStream: LimitedInputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, encryptResult.digest, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice)
val plaintextOutput = inputStream.readFully(autoClose = false)
Assert.assertArrayEquals(plaintextInput, plaintextOutput)
assertThat(inputStream.leftoverStream().allMatch { it == 0.toByte() }).isTrue()
cipherFile.delete()
}
@Test(expected = InvalidMessageException::class)
@Throws(IOException::class, InvalidMessageException::class)
fun attachment_decryptFailOnBadKey_nonIncremental() {
attachment_decryptFailOnBadKey(incremental = false)
}
@Test(expected = InvalidMessageException::class)
@Throws(IOException::class, InvalidMessageException::class)
fun attachment_decryptFailOnBadKey_incremental() {
attachment_decryptFailOnBadKey(incremental = true)
}
@Throws(IOException::class, InvalidMessageException::class)
private fun attachment_decryptFailOnBadKey(incremental: Boolean) {
var cipherFile: File? = null
@@ -126,18 +127,15 @@ class AttachmentCipherTest {
}
@Test(expected = InvalidMessageException::class)
@Throws(IOException::class, InvalidMessageException::class)
fun attachment_decryptFailOnBadMac_nonIncremental() {
attachment_decryptFailOnBadMac(incremental = false)
}
@Test(expected = InvalidMessageException::class)
@Throws(IOException::class, InvalidMessageException::class)
fun attachment_decryptFailOnBadMac_incremental() {
attachment_decryptFailOnBadMac(incremental = true)
}
@Throws(IOException::class, InvalidMessageException::class)
private fun attachment_decryptFailOnBadMac(incremental: Boolean) {
var cipherFile: File? = null
@@ -164,18 +162,15 @@ class AttachmentCipherTest {
}
@Test(expected = InvalidMessageException::class)
@Throws(IOException::class, InvalidMessageException::class)
fun attachment_decryptFailOnNullDigest_nonIncremental() {
attachment_decryptFailOnNullDigest(incremental = false)
}
@Test(expected = InvalidMessageException::class)
@Throws(IOException::class, InvalidMessageException::class)
fun attachment_decryptFailOnNullDigest_incremental() {
attachment_decryptFailOnNullDigest(incremental = true)
}
@Throws(IOException::class, InvalidMessageException::class)
private fun attachment_decryptFailOnNullDigest(incremental: Boolean) {
var cipherFile: File? = null
@@ -193,18 +188,15 @@ class AttachmentCipherTest {
}
@Test(expected = InvalidMessageException::class)
@Throws(IOException::class, InvalidMessageException::class)
fun attachment_decryptFailOnBadDigest_nonIncremental() {
attachment_decryptFailOnBadDigest(incremental = false)
}
@Test(expected = InvalidMessageException::class)
@Throws(IOException::class, InvalidMessageException::class)
fun attachment_decryptFailOnBadDigest_incremental() {
attachment_decryptFailOnBadDigest(incremental = true)
}
@Throws(IOException::class, InvalidMessageException::class)
private fun attachment_decryptFailOnBadDigest(incremental: Boolean) {
var cipherFile: File? = null
@@ -229,7 +221,6 @@ class AttachmentCipherTest {
}
@Test
@Throws(IOException::class)
fun attachment_decryptFailOnBadIncrementalDigest() {
var cipherFile: File? = null
var hitCorrectException = false
@@ -259,7 +250,6 @@ class AttachmentCipherTest {
}
@Test
@Throws(IOException::class, InvalidMessageException::class)
fun attachment_encryptDecryptPaddedContent() {
val lengths = intArrayOf(531, 600, 724, 1019, 1024)
@@ -295,7 +285,6 @@ class AttachmentCipherTest {
}
@Test
@Throws(IOException::class, InvalidMessageException::class)
fun archive_encryptDecrypt() {
val key = Util.getSecretBytes(64)
val keyMaterial = createMediaKeyMaterial(key)
@@ -313,7 +302,6 @@ class AttachmentCipherTest {
}
@Test
@Throws(IOException::class, InvalidMessageException::class)
fun archive_encryptDecryptEmpty() {
val key = Util.getSecretBytes(64)
val keyMaterial = createMediaKeyMaterial(key)
@@ -331,7 +319,6 @@ class AttachmentCipherTest {
}
@Test
@Throws(IOException::class)
fun archive_decryptFailOnBadKey() {
var cipherFile: File? = null
var hitCorrectException = false
@@ -356,7 +343,6 @@ class AttachmentCipherTest {
}
@Test
@Throws(IOException::class, InvalidMessageException::class)
fun archive_encryptDecryptPaddedContent() {
val lengths = intArrayOf(531, 600, 724, 1019, 1024)
@@ -392,7 +378,6 @@ class AttachmentCipherTest {
}
@Test
@Throws(IOException::class)
fun archive_decryptFailOnBadMac() {
var cipherFile: File? = null
var hitCorrectException = false
@@ -420,13 +405,12 @@ class AttachmentCipherTest {
}
@Test
@Throws(IOException::class, InvalidMessageException::class)
fun sticker_encryptDecrypt() {
LibSignalLibraryUtil.assumeLibSignalSupportedOnOS()
val packKey = Util.getSecretBytes(32)
val plaintextInput = Util.getSecretBytes(MEBIBYTE)
val encryptResult = encryptData(plaintextInput, expandPackKey(packKey), true)
val encryptResult = encryptData(plaintextInput, expandPackKey(packKey), withIncremental = false, padded = false)
val inputStream = AttachmentCipherInputStream.createForStickerData(encryptResult.ciphertext, packKey)
val plaintextOutput = readInputStreamFully(inputStream)
@@ -434,13 +418,12 @@ class AttachmentCipherTest {
}
@Test
@Throws(IOException::class, InvalidMessageException::class)
fun sticker_encryptDecryptEmpty() {
LibSignalLibraryUtil.assumeLibSignalSupportedOnOS()
val packKey = Util.getSecretBytes(32)
val plaintextInput = "".toByteArray()
val encryptResult = encryptData(plaintextInput, expandPackKey(packKey), true)
val encryptResult = encryptData(plaintextInput, expandPackKey(packKey), withIncremental = false, padded = false)
val inputStream = AttachmentCipherInputStream.createForStickerData(encryptResult.ciphertext, packKey)
val plaintextOutput = readInputStreamFully(inputStream)
@@ -448,7 +431,6 @@ class AttachmentCipherTest {
}
@Test
@Throws(IOException::class)
fun sticker_decryptFailOnBadKey() {
LibSignalLibraryUtil.assumeLibSignalSupportedOnOS()
@@ -469,7 +451,6 @@ class AttachmentCipherTest {
}
@Test
@Throws(IOException::class)
fun sticker_decryptFailOnBadMac() {
LibSignalLibraryUtil.assumeLibSignalSupportedOnOS()
@@ -492,7 +473,6 @@ class AttachmentCipherTest {
}
@Test
@Throws(IOException::class)
fun outputStream_writeAfterFlush() {
val key = Util.getSecretBytes(64)
val iv = Util.getSecretBytes(16)
@@ -521,7 +501,6 @@ class AttachmentCipherTest {
}
@Test
@Throws(IOException::class)
fun outputStream_flushMultipleTimes() {
val key = Util.getSecretBytes(64)
val iv = Util.getSecretBytes(16)
@@ -553,7 +532,6 @@ class AttachmentCipherTest {
}
@Test
@Throws(IOException::class)
fun outputStream_singleByteWrite() {
val key = Util.getSecretBytes(64)
val iv = Util.getSecretBytes(16)
@@ -579,7 +557,6 @@ class AttachmentCipherTest {
}
@Test
@Throws(IOException::class)
fun outputStream_mixedSingleByteAndArrayWrites() {
val key = Util.getSecretBytes(64)
val iv = Util.getSecretBytes(16)
@@ -611,7 +588,6 @@ class AttachmentCipherTest {
}
@Test
@Throws(IOException::class)
fun outputStream_singleByteWriteWithFlushes() {
val key = Util.getSecretBytes(64)
val iv = Util.getSecretBytes(16)
@@ -651,22 +627,27 @@ class AttachmentCipherTest {
private const val MEBIBYTE = 1024 * 1024
@Throws(IOException::class)
private fun encryptData(data: ByteArray, keyMaterial: ByteArray, withIncremental: Boolean): EncryptResult {
private fun encryptData(data: ByteArray, keyMaterial: ByteArray, withIncremental: Boolean, padded: Boolean = true): EncryptResult {
val actualData = if (padded) {
PaddingInputStream(ByteArrayInputStream(data), data.size.toLong()).readFully()
} else {
data
}
val outputStream = ByteArrayOutputStream()
val incrementalDigestOut = ByteArrayOutputStream()
val iv = Util.getSecretBytes(16)
val factory = AttachmentCipherOutputStreamFactory(keyMaterial, iv)
val encryptStream: DigestingOutputStream
val sizeChoice = ChunkSizeChoice.inferChunkSize(data.size)
val sizeChoice = ChunkSizeChoice.inferChunkSize(actualData.size)
encryptStream = if (withIncremental) {
factory.createIncrementalFor(outputStream, data.size.toLong(), sizeChoice, incrementalDigestOut)
factory.createIncrementalFor(outputStream, actualData.size.toLong(), sizeChoice, incrementalDigestOut)
} else {
factory.createFor(outputStream)
}
encryptStream.write(data)
encryptStream.write(actualData)
encryptStream.flush()
encryptStream.close()
incrementalDigestOut.close()
@@ -674,7 +655,6 @@ class AttachmentCipherTest {
return EncryptResult(outputStream.toByteArray(), encryptStream.transmittedDigest, incrementalDigestOut.toByteArray(), sizeChoice.sizeInBytes)
}
@Throws(IOException::class)
private fun writeToFile(data: ByteArray): File {
val file = File.createTempFile("temp", ".data")
val outputStream: OutputStream = FileOutputStream(file)
@@ -685,7 +665,6 @@ class AttachmentCipherTest {
return file
}
@Throws(IOException::class)
private fun readInputStreamFully(inputStream: InputStream): ByteArray {
return Util.readFullyAsBytes(inputStream)
}