diff --git a/app/src/main/java/org/thoughtcrime/securesms/video/exo/PartDataSource.java b/app/src/main/java/org/thoughtcrime/securesms/video/exo/PartDataSource.java index 2fd334a2e0..55135c0562 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/video/exo/PartDataSource.java +++ b/app/src/main/java/org/thoughtcrime/securesms/video/exo/PartDataSource.java @@ -90,6 +90,11 @@ class PartDataSource implements DataSource { try { long streamLength = AttachmentCipherStreamUtil.getCiphertextLength(PaddingInputStream.getPaddedSize(attachment.size)); AttachmentCipherInputStream.StreamSupplier streamSupplier = () -> new TailerInputStream(() -> new FileInputStream(transferFile), streamLength); + + if (attachment.remoteDigest == null) { + throw new InvalidMessageException("Missing digest!"); + } + this.inputStream = AttachmentCipherInputStream.createForAttachment(streamSupplier, streamLength, attachment.size, decode, attachment.remoteDigest, attachment.getIncrementalDigest(), attachment.incrementalMacChunkSize, false); } catch (InvalidMessageException e) { throw new IOException("Error decrypting attachment stream!", e); diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherInputStream.java b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherInputStream.java deleted file mode 100644 index 5d099238c8..0000000000 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherInputStream.java +++ /dev/null @@ -1,399 +0,0 @@ -/* - * Copyright (C) 2014-2017 Open Whisper Systems - * - * Licensed according to the LICENSE file in this repository. - */ - -package org.whispersystems.signalservice.api.crypto; - -import org.signal.core.util.stream.LimitedInputStream; -import org.signal.libsignal.protocol.InvalidMessageException; -import org.signal.libsignal.protocol.incrementalmac.ChunkSizeChoice; -import org.signal.libsignal.protocol.incrementalmac.IncrementalMacInputStream; -import org.signal.libsignal.protocol.kdf.HKDF; -import org.whispersystems.signalservice.api.backup.MediaRootBackupKey; -import org.whispersystems.signalservice.internal.util.Util; - -import java.io.ByteArrayInputStream; -import java.io.File; -import java.io.FileInputStream; -import java.io.FilterInputStream; -import java.io.IOException; -import java.io.InputStream; -import java.security.InvalidAlgorithmParameterException; -import java.security.InvalidKeyException; -import java.security.MessageDigest; -import java.security.NoSuchAlgorithmException; -import java.util.Arrays; - -import javax.annotation.Nonnull; -import javax.annotation.Nullable; -import javax.crypto.BadPaddingException; -import javax.crypto.Cipher; -import javax.crypto.IllegalBlockSizeException; -import javax.crypto.Mac; -import javax.crypto.NoSuchPaddingException; -import javax.crypto.ShortBufferException; -import javax.crypto.spec.IvParameterSpec; -import javax.crypto.spec.SecretKeySpec; - -/** - * Class for streaming an encrypted push attachment off disk. - * - * @author Moxie Marlinspike - */ - -public class AttachmentCipherInputStream extends FilterInputStream { - - private static final int BLOCK_SIZE = 16; - private static final int CIPHER_KEY_SIZE = 32; - private static final int MAC_KEY_SIZE = 32; - - private final Cipher cipher; - private final long totalDataSize; - - private boolean done; - private long totalRead; - private byte[] overflowBuffer; - - /** - * Passing in a null incrementalDigest and/or 0 for the chunk size at the call site disables incremental mac validation. - */ - public static LimitedInputStream createForAttachment(File file, long plaintextLength, byte[] combinedKeyMaterial, byte[] digest, byte[] incrementalDigest, int incrementalMacChunkSize) - throws InvalidMessageException, IOException { - return createForAttachment(file, plaintextLength, combinedKeyMaterial, digest, incrementalDigest, incrementalMacChunkSize, false); - } - - /** - * Passing in a null incrementalDigest and/or 0 for the chunk size at the call site disables incremental mac validation. - * - * Passing in true for ignoreDigest DOES NOT VERIFY THE DIGEST - */ - public static LimitedInputStream createForAttachment(File file, long plaintextLength, byte[] combinedKeyMaterial, byte[] digest, byte[] incrementalDigest, int incrementalMacChunkSize, boolean ignoreDigest) - throws InvalidMessageException, IOException - { - return createForAttachment(() -> new FileInputStream(file), file.length(), plaintextLength, combinedKeyMaterial, digest, incrementalDigest, incrementalMacChunkSize, ignoreDigest); - } - - /** - * Passing in a null incrementalDigest and/or 0 for the chunk size at the call site disables incremental mac validation. - * - * Passing in true for ignoreDigest DOES NOT VERIFY THE DIGEST - */ - public static LimitedInputStream createForAttachment(StreamSupplier streamSupplier, long streamLength, long plaintextLength, byte[] combinedKeyMaterial, byte[] digest, byte[] incrementalDigest, int incrementalMacChunkSize, boolean ignoreDigest) - throws InvalidMessageException, IOException - { - byte[][] parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE); - Mac mac = initMac(parts[1]); - - if (streamLength <= BLOCK_SIZE + mac.getMacLength()) { - throw new InvalidMessageException("Message shorter than crypto overhead! length: " + streamLength); - } - - if (!ignoreDigest && digest == null) { - throw new InvalidMessageException("Missing digest!"); - } - - final InputStream wrappedStream; - final boolean hasIncrementalMac = incrementalDigest != null && incrementalDigest.length > 0 && incrementalMacChunkSize > 0; - - if (!hasIncrementalMac) { - try (InputStream macVerificationStream = streamSupplier.openStream()) { - verifyMac(macVerificationStream, streamLength, mac, digest); - } - wrappedStream = streamSupplier.openStream(); - } else { - wrappedStream = new IncrementalMacInputStream( - new IncrementalMacAdditionalValidationsInputStream( - streamSupplier.openStream(), - streamLength, - mac, - digest - ), - parts[1], - ChunkSizeChoice.everyNthByte(incrementalMacChunkSize), - incrementalDigest); - } - InputStream inputStream = new AttachmentCipherInputStream(wrappedStream, parts[0], streamLength - BLOCK_SIZE - mac.getMacLength()); - - if (plaintextLength != 0) { - return new LimitedInputStream(inputStream, plaintextLength); - } else { - return LimitedInputStream.withoutLimits(inputStream); - } - } - - /** - * Decrypt archived media to it's original attachment encrypted blob. - */ - public static LimitedInputStream createForArchivedMedia(MediaRootBackupKey.MediaKeyMaterial archivedMediaKeyMaterial, File file, long originalCipherTextLength) - throws InvalidMessageException, IOException - { - Mac mac = initMac(archivedMediaKeyMaterial.getMacKey()); - - if (file.length() <= BLOCK_SIZE + mac.getMacLength()) { - throw new InvalidMessageException("Message shorter than crypto overhead!"); - } - - try (FileInputStream macVerificationStream = new FileInputStream(file)) { - verifyMac(macVerificationStream, file.length(), mac, null); - } - - InputStream inputStream = new AttachmentCipherInputStream(new FileInputStream(file), archivedMediaKeyMaterial.getAesKey(), file.length() - BLOCK_SIZE - mac.getMacLength()); - - if (originalCipherTextLength != 0) { - return new LimitedInputStream(inputStream, originalCipherTextLength); - } else { - return LimitedInputStream.withoutLimits(inputStream); - } - } - - public static LimitedInputStream createStreamingForArchivedAttachment(MediaRootBackupKey.MediaKeyMaterial archivedMediaKeyMaterial, File file, long originalCipherTextLength, long plaintextLength, byte[] combinedKeyMaterial, byte[] digest, byte[] incrementalDigest, int incrementalMacChunkSize) - throws InvalidMessageException, IOException - { - final InputStream archiveStream = createForArchivedMedia(archivedMediaKeyMaterial, file, originalCipherTextLength); - - byte[][] parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE); - Mac mac = initMac(parts[1]); - - if (originalCipherTextLength <= BLOCK_SIZE + mac.getMacLength()) { - throw new InvalidMessageException("Message shorter than crypto overhead!"); - } - - if (digest == null) { - throw new InvalidMessageException("Missing digest!"); - } - - final InputStream wrappedStream; - wrappedStream = new IncrementalMacInputStream( - new IncrementalMacAdditionalValidationsInputStream( - archiveStream, - file.length(), - mac, - digest - ), - parts[1], - ChunkSizeChoice.everyNthByte(incrementalMacChunkSize), - incrementalDigest); - - InputStream inputStream = new AttachmentCipherInputStream(wrappedStream, parts[0], file.length() - BLOCK_SIZE - mac.getMacLength()); - - if (plaintextLength != 0) { - return new LimitedInputStream(inputStream, plaintextLength); - } else { - return LimitedInputStream.withoutLimits(inputStream); - } - - } - - public static InputStream createForStickerData(byte[] data, byte[] packKey) - throws InvalidMessageException, IOException - { - byte[] combinedKeyMaterial = HKDF.deriveSecrets(packKey, "Sticker Pack".getBytes(), 64); - byte[][] parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE); - Mac mac = initMac(parts[1]); - - if (data.length <= BLOCK_SIZE + mac.getMacLength()) { - throw new InvalidMessageException("Message shorter than crypto overhead!"); - } - - try (InputStream inputStream = new ByteArrayInputStream(data)) { - verifyMac(inputStream, data.length, mac, null); - } - - return new AttachmentCipherInputStream(new ByteArrayInputStream(data), parts[0], data.length - BLOCK_SIZE - mac.getMacLength()); - } - - private AttachmentCipherInputStream(InputStream inputStream, byte[] aesKey, long totalDataSize) - throws IOException - { - super(inputStream); - - try { - byte[] iv = new byte[BLOCK_SIZE]; - readFully(iv); - - this.cipher = Cipher.getInstance("AES/CBC/PKCS5Padding"); - this.cipher.init(Cipher.DECRYPT_MODE, new SecretKeySpec(aesKey, "AES"), new IvParameterSpec(iv)); - - this.done = false; - this.totalRead = 0; - this.totalDataSize = totalDataSize; - } catch (NoSuchAlgorithmException | InvalidKeyException | NoSuchPaddingException | InvalidAlgorithmParameterException e) { - throw new AssertionError(e); - } - } - - @Override - public int read() throws IOException { - byte[] buffer = new byte[1]; - int read; - - //noinspection StatementWithEmptyBody - while ((read = read(buffer)) == 0) ; - - return (read == -1) ? -1 : ((int) buffer[0]) & 0xFF; - } - - @Override - public int read(@Nonnull byte[] buffer) throws IOException { - return read(buffer, 0, buffer.length); - } - - @Override - public int read(@Nonnull byte[] buffer, int offset, int length) throws IOException { - if (totalRead != totalDataSize) { - return readIncremental(buffer, offset, length); - } else if (!done) { - return readFinal(buffer, offset, length); - } else { - return -1; - } - } - - @Override - public boolean markSupported() { - return false; - } - - @Override - public long skip(long byteCount) throws IOException { - long skipped = 0L; - while (skipped < byteCount) { - byte[] buf = new byte[Math.min(4096, (int) (byteCount - skipped))]; - int read = read(buf); - - skipped += read; - } - - return skipped; - } - - private int readFinal(byte[] buffer, int offset, int length) throws IOException { - try { - byte[] internal = new byte[buffer.length]; - int actualLength = Math.min(length, cipher.doFinal(internal, 0)); - System.arraycopy(internal, 0, buffer, offset, actualLength); - - done = true; - return actualLength; - } catch (IllegalBlockSizeException | BadPaddingException | ShortBufferException e) { - throw new IOException(e); - } - } - - private int readIncremental(byte[] buffer, int offset, int length) throws IOException { - int readLength = 0; - if (null != overflowBuffer) { - if (overflowBuffer.length > length) { - System.arraycopy(overflowBuffer, 0, buffer, offset, length); - overflowBuffer = Arrays.copyOfRange(overflowBuffer, length, overflowBuffer.length); - return length; - } else if (overflowBuffer.length == length) { - System.arraycopy(overflowBuffer, 0, buffer, offset, length); - overflowBuffer = null; - return length; - } else { - System.arraycopy(overflowBuffer, 0, buffer, offset, overflowBuffer.length); - readLength += overflowBuffer.length; - offset += readLength; - length -= readLength; - overflowBuffer = null; - } - } - - if (length + totalRead > totalDataSize) - length = (int) (totalDataSize - totalRead); - - byte[] internalBuffer = new byte[length]; - int read = super.read(internalBuffer, 0, internalBuffer.length <= cipher.getBlockSize() ? internalBuffer.length : internalBuffer.length - cipher.getBlockSize()); - totalRead += read; - - try { - int outputLen = cipher.getOutputSize(read); - - if (outputLen <= length) { - readLength += cipher.update(internalBuffer, 0, read, buffer, offset); - return readLength; - } - - byte[] transientBuffer = new byte[outputLen]; - outputLen = cipher.update(internalBuffer, 0, read, transientBuffer, 0); - if (outputLen <= length) { - System.arraycopy(transientBuffer, 0, buffer, offset, outputLen); - readLength += outputLen; - } else { - System.arraycopy(transientBuffer, 0, buffer, offset, length); - overflowBuffer = Arrays.copyOfRange(transientBuffer, length, outputLen); - readLength += length; - } - return readLength; - } catch (ShortBufferException e) { - throw new AssertionError(e); - } - } - - private static Mac initMac(byte[] key) { - try { - Mac mac = Mac.getInstance("HmacSHA256"); - mac.init(new SecretKeySpec(key, "HmacSHA256")); - return mac; - } catch (NoSuchAlgorithmException | InvalidKeyException e) { - throw new AssertionError(e); - } - } - - private static void verifyMac(@Nonnull InputStream inputStream, long length, @Nonnull Mac mac, @Nullable byte[] theirDigest) - throws InvalidMessageException - { - try { - MessageDigest digest = MessageDigest.getInstance("SHA256"); - int remainingData = Util.toIntExact(length) - mac.getMacLength(); - byte[] buffer = new byte[4096]; - - while (remainingData > 0) { - int read = inputStream.read(buffer, 0, Math.min(buffer.length, remainingData)); - mac.update(buffer, 0, read); - digest.update(buffer, 0, read); - remainingData -= read; - } - - byte[] ourMac = mac.doFinal(); - byte[] theirMac = new byte[mac.getMacLength()]; - Util.readFully(inputStream, theirMac); - - if (!MessageDigest.isEqual(ourMac, theirMac)) { - throw new InvalidMessageException("MAC doesn't match!"); - } - - byte[] ourDigest = digest.digest(theirMac); - - if (theirDigest != null && !MessageDigest.isEqual(ourDigest, theirDigest)) { - throw new InvalidMessageException("Digest doesn't match!"); - } - - } catch (IOException | ArithmeticException e1) { - throw new InvalidMessageException(e1); - } catch (NoSuchAlgorithmException e) { - throw new AssertionError(e); - } - } - - private void readFully(byte[] buffer) throws IOException { - int offset = 0; - - for (; ; ) { - int read = super.read(buffer, offset, buffer.length - offset); - - if (read + offset < buffer.length) { - offset += read; - } else { - return; - } - } - } - - public interface StreamSupplier { - @Nonnull InputStream openStream() throws IOException; - } -} diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherInputStream.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherInputStream.kt new file mode 100644 index 0000000000..94ad69b961 --- /dev/null +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherInputStream.kt @@ -0,0 +1,410 @@ +/* + * Copyright (C) 2014-2017 Open Whisper Systems + * + * Licensed according to the LICENSE file in this repository. + */ +package org.whispersystems.signalservice.api.crypto + +import org.signal.core.util.stream.LimitedInputStream +import org.signal.core.util.stream.LimitedInputStream.Companion.withoutLimits +import org.signal.libsignal.protocol.InvalidMessageException +import org.signal.libsignal.protocol.incrementalmac.ChunkSizeChoice +import org.signal.libsignal.protocol.incrementalmac.IncrementalMacInputStream +import org.signal.libsignal.protocol.kdf.HKDF +import org.whispersystems.signalservice.api.backup.MediaRootBackupKey.MediaKeyMaterial +import org.whispersystems.signalservice.internal.util.Util +import java.io.ByteArrayInputStream +import java.io.File +import java.io.FileInputStream +import java.io.FilterInputStream +import java.io.IOException +import java.io.InputStream +import java.security.InvalidKeyException +import java.security.MessageDigest +import java.security.NoSuchAlgorithmException +import javax.annotation.Nonnull +import javax.crypto.BadPaddingException +import javax.crypto.Cipher +import javax.crypto.IllegalBlockSizeException +import javax.crypto.Mac +import javax.crypto.ShortBufferException +import javax.crypto.spec.IvParameterSpec +import javax.crypto.spec.SecretKeySpec +import kotlin.math.min + +/** + * Class for streaming an encrypted push attachment off disk. + * + * @author Moxie Marlinspike + */ +class AttachmentCipherInputStream private constructor( + inputStream: InputStream, + aesKey: ByteArray, + private val totalDataSize: Long +) : FilterInputStream(inputStream) { + + private val cipher: Cipher + + private var done = false + private var totalRead: Long = 0 + private var overflowBuffer: ByteArray? = null + + init { + val iv = ByteArray(BLOCK_SIZE) + readFullyWithoutDecrypting(iv) + + this.cipher = Cipher.getInstance("AES/CBC/PKCS5Padding") + cipher.init(Cipher.DECRYPT_MODE, SecretKeySpec(aesKey, "AES"), IvParameterSpec(iv)) + } + + @Throws(IOException::class) + override fun read(): Int { + val buffer = ByteArray(1) + var read: Int = read(buffer) + while (read == 0) { + read = read(buffer) + } + + if (read == -1) { + return read + } + + return buffer[0].toInt() and 0xFF + } + + @Throws(IOException::class) + override fun read(@Nonnull buffer: ByteArray): Int { + return read(buffer, 0, buffer.size) + } + + @Throws(IOException::class) + override fun read(@Nonnull buffer: ByteArray, offset: Int, length: Int): Int { + return if (totalRead != totalDataSize) { + readIncremental(buffer, offset, length) + } else if (!done) { + readFinal(buffer, offset, length) + } else { + -1 + } + } + + override fun markSupported(): Boolean = false + + @Throws(IOException::class) + override fun skip(byteCount: Long): Long { + var skipped = 0L + while (skipped < byteCount) { + val remaining = byteCount - skipped + val buffer = ByteArray(min(4096, remaining.toInt())) + val read = read(buffer) + + skipped += read.toLong() + } + + return skipped + } + + @Throws(IOException::class) + private fun readIncremental(outputBuffer: ByteArray, originalOffset: Int, originalLength: Int): Int { + var offset = originalOffset + var length = originalLength + var readLength = 0 + + overflowBuffer?.let { overflow -> + if (overflow.size > length) { + overflow.copyInto(destination = outputBuffer, destinationOffset = offset, endIndex = length) + overflowBuffer = overflow.copyOfRange(fromIndex = length, toIndex = overflow.size) + return length + } else if (overflow.size == length) { + overflow.copyInto(destination = outputBuffer, destinationOffset = offset) + overflowBuffer = null + return length + } else { + overflow.copyInto(destination = outputBuffer, destinationOffset = offset) + readLength += overflow.size + offset += readLength + length -= readLength + overflowBuffer = null + } + } + + if (length + totalRead > totalDataSize) { + length = (totalDataSize - totalRead).toInt() + } + + val ciphertextBuffer = ByteArray(length) + val ciphertextReadLength = if (ciphertextBuffer.size <= cipher.blockSize) { + ciphertextBuffer.size + } else { + // Ensure we leave the final block for readFinal() + ciphertextBuffer.size - cipher.blockSize + } + val ciphertextRead = super.read(ciphertextBuffer, 0, ciphertextReadLength) + totalRead += ciphertextRead.toLong() + + try { + var plaintextLength = cipher.getOutputSize(ciphertextRead) + + if (plaintextLength <= length) { + readLength += cipher.update(ciphertextBuffer, 0, ciphertextRead, outputBuffer, offset) + return readLength + } + + val plaintextBuffer = ByteArray(plaintextLength) + plaintextLength = cipher.update(ciphertextBuffer, 0, ciphertextRead, plaintextBuffer, 0) + if (plaintextLength <= length) { + plaintextBuffer.copyInto(destination = outputBuffer, destinationOffset = offset, endIndex = plaintextLength) + readLength += plaintextLength + } else { + plaintextBuffer.copyInto(destination = outputBuffer, destinationOffset = offset, endIndex = length) + overflowBuffer = plaintextBuffer.copyOfRange(fromIndex = length, toIndex = plaintextLength) + readLength += length + } + return readLength + } catch (e: ShortBufferException) { + throw AssertionError(e) + } + } + + @Throws(IOException::class) + private fun readFinal(buffer: ByteArray, offset: Int, length: Int): Int { + try { + val internal = ByteArray(buffer.size) + val actualLength = min(length, cipher.doFinal(internal, 0)) + internal.copyInto(destination = buffer, destinationOffset = offset, endIndex = actualLength) + + done = true + return actualLength + } catch (e: IllegalBlockSizeException) { + throw IOException(e) + } catch (e: BadPaddingException) { + throw IOException(e) + } catch (e: ShortBufferException) { + throw IOException(e) + } + } + + @Throws(IOException::class) + private fun readFullyWithoutDecrypting(buffer: ByteArray) { + var offset = 0 + + while (true) { + val read = super.read(buffer, offset, buffer.size - offset) + + if (read + offset < buffer.size) { + offset += read + } else { + return + } + } + } + + fun interface StreamSupplier { + @Nonnull + @Throws(IOException::class) + fun openStream(): InputStream + } + + companion object { + private const val BLOCK_SIZE = 16 + private const val CIPHER_KEY_SIZE = 32 + private const val MAC_KEY_SIZE = 32 + + /** + * Passing in a null incrementalDigest and/or 0 for the chunk size at the call site disables incremental mac validation. + * + * Passing in true for ignoreDigest DOES NOT VERIFY THE DIGEST + */ + @JvmStatic + @JvmOverloads + @Throws(InvalidMessageException::class, IOException::class) + fun createForAttachment(file: File, plaintextLength: Long, combinedKeyMaterial: ByteArray?, digest: ByteArray?, incrementalDigest: ByteArray?, incrementalMacChunkSize: Int, ignoreDigest: Boolean = false): LimitedInputStream { + return createForAttachment({ FileInputStream(file) }, file.length(), plaintextLength, combinedKeyMaterial, digest, incrementalDigest, incrementalMacChunkSize, ignoreDigest) + } + + /** + * Passing in a null incrementalDigest and/or 0 for the chunk size at the call site disables incremental mac validation. + * + * Passing in true for ignoreDigest DOES NOT VERIFY THE DIGEST + */ + @JvmStatic + @Throws(InvalidMessageException::class, IOException::class) + fun createForAttachment( + streamSupplier: StreamSupplier, + streamLength: Long, + plaintextLength: Long, + combinedKeyMaterial: ByteArray?, + digest: ByteArray?, + incrementalDigest: ByteArray?, + incrementalMacChunkSize: Int, + ignoreDigest: Boolean + ): LimitedInputStream { + val parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE) + val mac = initMac(parts[1]) + + if (streamLength <= BLOCK_SIZE + mac.macLength) { + throw InvalidMessageException("Message shorter than crypto overhead! length: $streamLength") + } + + if (!ignoreDigest && digest == null) { + throw InvalidMessageException("Missing digest!") + } + + val wrappedStream: InputStream + val hasIncrementalMac = incrementalDigest != null && incrementalDigest.isNotEmpty() && incrementalMacChunkSize > 0 + + if (!hasIncrementalMac) { + streamSupplier.openStream().use { macVerificationStream -> + verifyMac(macVerificationStream, streamLength, mac, digest) + } + wrappedStream = streamSupplier.openStream() + } else { + wrappedStream = IncrementalMacInputStream( + IncrementalMacAdditionalValidationsInputStream( + streamSupplier.openStream(), + streamLength, + mac, + digest!! + ), + parts[1], + ChunkSizeChoice.everyNthByte(incrementalMacChunkSize), + incrementalDigest + ) + } + val inputStream: InputStream = AttachmentCipherInputStream(wrappedStream, parts[0], streamLength - BLOCK_SIZE - mac.macLength) + + return LimitedInputStream(inputStream, plaintextLength) + } + + /** + * Decrypt archived media to it's original attachment encrypted blob. + */ + @JvmStatic + @Throws(InvalidMessageException::class, IOException::class) + fun createForArchivedMedia(archivedMediaKeyMaterial: MediaKeyMaterial, file: File, originalCipherTextLength: Long): LimitedInputStream { + val mac = initMac(archivedMediaKeyMaterial.macKey) + + if (file.length() <= BLOCK_SIZE + mac.macLength) { + throw InvalidMessageException("Message shorter than crypto overhead!") + } + + FileInputStream(file).use { macVerificationStream -> + verifyMac(macVerificationStream, file.length(), mac, null) + } + val inputStream: InputStream = AttachmentCipherInputStream(FileInputStream(file), archivedMediaKeyMaterial.aesKey, file.length() - BLOCK_SIZE - mac.macLength) + + return LimitedInputStream(inputStream, originalCipherTextLength) + } + + @JvmStatic + @Throws(InvalidMessageException::class, IOException::class) + fun createStreamingForArchivedAttachment( + archivedMediaKeyMaterial: MediaKeyMaterial, + file: File, + originalCipherTextLength: Long, + plaintextLength: Long, + combinedKeyMaterial: ByteArray?, + digest: ByteArray, + incrementalDigest: ByteArray?, + incrementalMacChunkSize: Int + ): LimitedInputStream { + val archiveStream: InputStream = createForArchivedMedia(archivedMediaKeyMaterial, file, originalCipherTextLength) + + val parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE) + val mac = initMac(parts[1]) + + if (originalCipherTextLength <= BLOCK_SIZE + mac.macLength) { + throw InvalidMessageException("Message shorter than crypto overhead!") + } + + val wrappedStream: InputStream = IncrementalMacInputStream( + IncrementalMacAdditionalValidationsInputStream( + wrapped = archiveStream, + fileLength = file.length(), + mac = mac, + theirDigest = digest + ), + parts[1], + ChunkSizeChoice.everyNthByte(incrementalMacChunkSize), + incrementalDigest + ) + + val inputStream: InputStream = AttachmentCipherInputStream( + inputStream = wrappedStream, + aesKey = parts[0], + totalDataSize = file.length() - BLOCK_SIZE - mac.macLength + ) + + return if (plaintextLength != 0L) { + LimitedInputStream(inputStream, plaintextLength) + } else { + withoutLimits(inputStream) + } + } + + @JvmStatic + @Throws(InvalidMessageException::class, IOException::class) + fun createForStickerData(data: ByteArray, packKey: ByteArray?): InputStream { + val combinedKeyMaterial = HKDF.deriveSecrets(packKey, "Sticker Pack".toByteArray(), 64) + val parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE) + val mac = initMac(parts[1]) + + if (data.size <= BLOCK_SIZE + mac.macLength) { + throw InvalidMessageException("Message shorter than crypto overhead!") + } + + ByteArrayInputStream(data).use { inputStream -> + verifyMac(inputStream, data.size.toLong(), mac, null) + } + return AttachmentCipherInputStream(ByteArrayInputStream(data), parts[0], (data.size - BLOCK_SIZE - mac.macLength).toLong()) + } + + private fun initMac(key: ByteArray): Mac { + try { + val mac = Mac.getInstance("HmacSHA256") + mac.init(SecretKeySpec(key, "HmacSHA256")) + return mac + } catch (e: NoSuchAlgorithmException) { + throw AssertionError(e) + } catch (e: InvalidKeyException) { + throw AssertionError(e) + } + } + + @Throws(InvalidMessageException::class) + private fun verifyMac(@Nonnull inputStream: InputStream, length: Long, @Nonnull mac: Mac, theirDigest: ByteArray?) { + try { + val digest = MessageDigest.getInstance("SHA256") + var remainingData = Util.toIntExact(length) - mac.macLength + val buffer = ByteArray(4096) + + while (remainingData > 0) { + val read = inputStream.read(buffer, 0, min(buffer.size, remainingData)) + mac.update(buffer, 0, read) + digest.update(buffer, 0, read) + remainingData -= read + } + + val ourMac = mac.doFinal() + val theirMac = ByteArray(mac.macLength) + Util.readFully(inputStream, theirMac) + + if (!MessageDigest.isEqual(ourMac, theirMac)) { + throw InvalidMessageException("MAC doesn't match!") + } + + val ourDigest = digest.digest(theirMac) + + if (theirDigest != null && !MessageDigest.isEqual(ourDigest, theirDigest)) { + throw InvalidMessageException("Digest doesn't match!") + } + } catch (e: IOException) { + throw InvalidMessageException(e) + } catch (e: ArithmeticException) { + throw InvalidMessageException(e) + } catch (e: NoSuchAlgorithmException) { + throw AssertionError(e) + } + } + } +} diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/IncrementalMacAdditionalValidationsInputStream.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/IncrementalMacAdditionalValidationsInputStream.kt index 4e55b1ff50..f24dd6cfe6 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/IncrementalMacAdditionalValidationsInputStream.kt +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/IncrementalMacAdditionalValidationsInputStream.kt @@ -17,7 +17,7 @@ import kotlin.math.max * This is meant as a helper stream to go along with [org.signal.libsignal.protocol.incrementalmac.IncrementalMacInputStream]. * That class does not validate the overall digest, nor the overall MAC. This class does that for us. * - * To use, wrap the IncremtalMacInputStream around this class, and then this class should wrap the lowest-level data stream. + * To use, wrap the IncrementalMacInputStream around this class, and then this class should wrap the lowest-level data stream. */ class IncrementalMacAdditionalValidationsInputStream( wrapped: InputStream, diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/push/http/DigestingRequestBody.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/push/http/DigestingRequestBody.kt index fe4930f5de..c582d9108b 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/push/http/DigestingRequestBody.kt +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/push/http/DigestingRequestBody.kt @@ -74,6 +74,7 @@ class DigestingRequestBody( digestStream.close() digestStream.toByteArray() } else { + outputStream.close() null } diff --git a/libsignal-service/src/test/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherTest.kt b/libsignal-service/src/test/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherTest.kt index 3b7630c2d6..2af0614ba6 100644 --- a/libsignal-service/src/test/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherTest.kt +++ b/libsignal-service/src/test/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherTest.kt @@ -8,7 +8,10 @@ import org.conscrypt.Conscrypt import org.junit.Assert import org.junit.Test import org.signal.core.util.StreamUtil +import org.signal.core.util.allMatch import org.signal.core.util.copyTo +import org.signal.core.util.readFully +import org.signal.core.util.stream.LimitedInputStream import org.signal.libsignal.protocol.InvalidMessageException import org.signal.libsignal.protocol.incrementalmac.ChunkSizeChoice import org.signal.libsignal.protocol.incrementalmac.InvalidMacException @@ -22,7 +25,6 @@ import java.io.ByteArrayInputStream import java.io.ByteArrayOutputStream import java.io.File import java.io.FileOutputStream -import java.io.IOException import java.io.InputStream import java.io.OutputStream import java.lang.AssertionError @@ -31,19 +33,23 @@ import java.util.Random class AttachmentCipherTest { @Test - @Throws(IOException::class, InvalidMessageException::class) fun attachment_encryptDecrypt_nonIncremental() { attachment_encryptDecrypt(incremental = false, fileSize = MEBIBYTE) } @Test - @Throws(IOException::class, InvalidMessageException::class) fun attachment_encryptDecrypt_incremental() { attachment_encryptDecrypt(incremental = true, fileSize = MEBIBYTE) } @Test - @Throws(IOException::class, InvalidMessageException::class) + fun attachment_encryptDecrypt_nonIncremental_manyFileSizes() { + for (i in 0..99) { + attachment_encryptDecrypt(incremental = false, fileSize = MEBIBYTE + Random().nextInt(1, 64 * 1024)) + } + } + + @Test fun attachment_encryptDecrypt_incremental_manyFileSizes() { // Designed to stress the various boundary conditions of reading the final mac for (i in 0..99) { @@ -51,7 +57,6 @@ class AttachmentCipherTest { } } - @Throws(IOException::class, InvalidMessageException::class) private fun attachment_encryptDecrypt(incremental: Boolean, fileSize: Int) { val key = Util.getSecretBytes(64) val plaintextInput = Util.getSecretBytes(fileSize) @@ -59,27 +64,25 @@ class AttachmentCipherTest { val encryptResult = encryptData(plaintextInput, key, incremental) val cipherFile = writeToFile(encryptResult.ciphertext) - val inputStream: InputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, encryptResult.digest, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice) - val plaintextOutput = readInputStreamFully(inputStream) + val inputStream: LimitedInputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, encryptResult.digest, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice) + val plaintextOutput = inputStream.readFully(autoClose = false) assertThat(plaintextOutput).isEqualTo(plaintextInput) + assertThat(inputStream.leftoverStream().allMatch { it == 0.toByte() }).isTrue() cipherFile.delete() } @Test - @Throws(IOException::class, InvalidMessageException::class) fun attachment_encryptDecryptEmpty_nonIncremental() { attachment_encryptDecryptEmpty(incremental = false) } @Test - @Throws(IOException::class, InvalidMessageException::class) fun attachment_encryptDecryptEmpty_incremental() { attachment_encryptDecryptEmpty(incremental = true) } - @Throws(IOException::class, InvalidMessageException::class) private fun attachment_encryptDecryptEmpty(incremental: Boolean) { val key = Util.getSecretBytes(64) val plaintextInput = "".toByteArray() @@ -87,27 +90,25 @@ class AttachmentCipherTest { val encryptResult = encryptData(plaintextInput, key, incremental) val cipherFile = writeToFile(encryptResult.ciphertext) - val inputStream: InputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, encryptResult.digest, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice) - val plaintextOutput = readInputStreamFully(inputStream) + val inputStream: LimitedInputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintextInput.size.toLong(), key, encryptResult.digest, encryptResult.incrementalDigest, encryptResult.chunkSizeChoice) + val plaintextOutput = inputStream.readFully(autoClose = false) Assert.assertArrayEquals(plaintextInput, plaintextOutput) + assertThat(inputStream.leftoverStream().allMatch { it == 0.toByte() }).isTrue() cipherFile.delete() } @Test(expected = InvalidMessageException::class) - @Throws(IOException::class, InvalidMessageException::class) fun attachment_decryptFailOnBadKey_nonIncremental() { attachment_decryptFailOnBadKey(incremental = false) } @Test(expected = InvalidMessageException::class) - @Throws(IOException::class, InvalidMessageException::class) fun attachment_decryptFailOnBadKey_incremental() { attachment_decryptFailOnBadKey(incremental = true) } - @Throws(IOException::class, InvalidMessageException::class) private fun attachment_decryptFailOnBadKey(incremental: Boolean) { var cipherFile: File? = null @@ -126,18 +127,15 @@ class AttachmentCipherTest { } @Test(expected = InvalidMessageException::class) - @Throws(IOException::class, InvalidMessageException::class) fun attachment_decryptFailOnBadMac_nonIncremental() { attachment_decryptFailOnBadMac(incremental = false) } @Test(expected = InvalidMessageException::class) - @Throws(IOException::class, InvalidMessageException::class) fun attachment_decryptFailOnBadMac_incremental() { attachment_decryptFailOnBadMac(incremental = true) } - @Throws(IOException::class, InvalidMessageException::class) private fun attachment_decryptFailOnBadMac(incremental: Boolean) { var cipherFile: File? = null @@ -164,18 +162,15 @@ class AttachmentCipherTest { } @Test(expected = InvalidMessageException::class) - @Throws(IOException::class, InvalidMessageException::class) fun attachment_decryptFailOnNullDigest_nonIncremental() { attachment_decryptFailOnNullDigest(incremental = false) } @Test(expected = InvalidMessageException::class) - @Throws(IOException::class, InvalidMessageException::class) fun attachment_decryptFailOnNullDigest_incremental() { attachment_decryptFailOnNullDigest(incremental = true) } - @Throws(IOException::class, InvalidMessageException::class) private fun attachment_decryptFailOnNullDigest(incremental: Boolean) { var cipherFile: File? = null @@ -193,18 +188,15 @@ class AttachmentCipherTest { } @Test(expected = InvalidMessageException::class) - @Throws(IOException::class, InvalidMessageException::class) fun attachment_decryptFailOnBadDigest_nonIncremental() { attachment_decryptFailOnBadDigest(incremental = false) } @Test(expected = InvalidMessageException::class) - @Throws(IOException::class, InvalidMessageException::class) fun attachment_decryptFailOnBadDigest_incremental() { attachment_decryptFailOnBadDigest(incremental = true) } - @Throws(IOException::class, InvalidMessageException::class) private fun attachment_decryptFailOnBadDigest(incremental: Boolean) { var cipherFile: File? = null @@ -229,7 +221,6 @@ class AttachmentCipherTest { } @Test - @Throws(IOException::class) fun attachment_decryptFailOnBadIncrementalDigest() { var cipherFile: File? = null var hitCorrectException = false @@ -259,7 +250,6 @@ class AttachmentCipherTest { } @Test - @Throws(IOException::class, InvalidMessageException::class) fun attachment_encryptDecryptPaddedContent() { val lengths = intArrayOf(531, 600, 724, 1019, 1024) @@ -295,7 +285,6 @@ class AttachmentCipherTest { } @Test - @Throws(IOException::class, InvalidMessageException::class) fun archive_encryptDecrypt() { val key = Util.getSecretBytes(64) val keyMaterial = createMediaKeyMaterial(key) @@ -313,7 +302,6 @@ class AttachmentCipherTest { } @Test - @Throws(IOException::class, InvalidMessageException::class) fun archive_encryptDecryptEmpty() { val key = Util.getSecretBytes(64) val keyMaterial = createMediaKeyMaterial(key) @@ -331,7 +319,6 @@ class AttachmentCipherTest { } @Test - @Throws(IOException::class) fun archive_decryptFailOnBadKey() { var cipherFile: File? = null var hitCorrectException = false @@ -356,7 +343,6 @@ class AttachmentCipherTest { } @Test - @Throws(IOException::class, InvalidMessageException::class) fun archive_encryptDecryptPaddedContent() { val lengths = intArrayOf(531, 600, 724, 1019, 1024) @@ -392,7 +378,6 @@ class AttachmentCipherTest { } @Test - @Throws(IOException::class) fun archive_decryptFailOnBadMac() { var cipherFile: File? = null var hitCorrectException = false @@ -420,13 +405,12 @@ class AttachmentCipherTest { } @Test - @Throws(IOException::class, InvalidMessageException::class) fun sticker_encryptDecrypt() { LibSignalLibraryUtil.assumeLibSignalSupportedOnOS() val packKey = Util.getSecretBytes(32) val plaintextInput = Util.getSecretBytes(MEBIBYTE) - val encryptResult = encryptData(plaintextInput, expandPackKey(packKey), true) + val encryptResult = encryptData(plaintextInput, expandPackKey(packKey), withIncremental = false, padded = false) val inputStream = AttachmentCipherInputStream.createForStickerData(encryptResult.ciphertext, packKey) val plaintextOutput = readInputStreamFully(inputStream) @@ -434,13 +418,12 @@ class AttachmentCipherTest { } @Test - @Throws(IOException::class, InvalidMessageException::class) fun sticker_encryptDecryptEmpty() { LibSignalLibraryUtil.assumeLibSignalSupportedOnOS() val packKey = Util.getSecretBytes(32) val plaintextInput = "".toByteArray() - val encryptResult = encryptData(plaintextInput, expandPackKey(packKey), true) + val encryptResult = encryptData(plaintextInput, expandPackKey(packKey), withIncremental = false, padded = false) val inputStream = AttachmentCipherInputStream.createForStickerData(encryptResult.ciphertext, packKey) val plaintextOutput = readInputStreamFully(inputStream) @@ -448,7 +431,6 @@ class AttachmentCipherTest { } @Test - @Throws(IOException::class) fun sticker_decryptFailOnBadKey() { LibSignalLibraryUtil.assumeLibSignalSupportedOnOS() @@ -469,7 +451,6 @@ class AttachmentCipherTest { } @Test - @Throws(IOException::class) fun sticker_decryptFailOnBadMac() { LibSignalLibraryUtil.assumeLibSignalSupportedOnOS() @@ -492,7 +473,6 @@ class AttachmentCipherTest { } @Test - @Throws(IOException::class) fun outputStream_writeAfterFlush() { val key = Util.getSecretBytes(64) val iv = Util.getSecretBytes(16) @@ -521,7 +501,6 @@ class AttachmentCipherTest { } @Test - @Throws(IOException::class) fun outputStream_flushMultipleTimes() { val key = Util.getSecretBytes(64) val iv = Util.getSecretBytes(16) @@ -553,7 +532,6 @@ class AttachmentCipherTest { } @Test - @Throws(IOException::class) fun outputStream_singleByteWrite() { val key = Util.getSecretBytes(64) val iv = Util.getSecretBytes(16) @@ -579,7 +557,6 @@ class AttachmentCipherTest { } @Test - @Throws(IOException::class) fun outputStream_mixedSingleByteAndArrayWrites() { val key = Util.getSecretBytes(64) val iv = Util.getSecretBytes(16) @@ -611,7 +588,6 @@ class AttachmentCipherTest { } @Test - @Throws(IOException::class) fun outputStream_singleByteWriteWithFlushes() { val key = Util.getSecretBytes(64) val iv = Util.getSecretBytes(16) @@ -651,22 +627,27 @@ class AttachmentCipherTest { private const val MEBIBYTE = 1024 * 1024 - @Throws(IOException::class) - private fun encryptData(data: ByteArray, keyMaterial: ByteArray, withIncremental: Boolean): EncryptResult { + private fun encryptData(data: ByteArray, keyMaterial: ByteArray, withIncremental: Boolean, padded: Boolean = true): EncryptResult { + val actualData = if (padded) { + PaddingInputStream(ByteArrayInputStream(data), data.size.toLong()).readFully() + } else { + data + } + val outputStream = ByteArrayOutputStream() val incrementalDigestOut = ByteArrayOutputStream() val iv = Util.getSecretBytes(16) val factory = AttachmentCipherOutputStreamFactory(keyMaterial, iv) val encryptStream: DigestingOutputStream - val sizeChoice = ChunkSizeChoice.inferChunkSize(data.size) + val sizeChoice = ChunkSizeChoice.inferChunkSize(actualData.size) encryptStream = if (withIncremental) { - factory.createIncrementalFor(outputStream, data.size.toLong(), sizeChoice, incrementalDigestOut) + factory.createIncrementalFor(outputStream, actualData.size.toLong(), sizeChoice, incrementalDigestOut) } else { factory.createFor(outputStream) } - encryptStream.write(data) + encryptStream.write(actualData) encryptStream.flush() encryptStream.close() incrementalDigestOut.close() @@ -674,7 +655,6 @@ class AttachmentCipherTest { return EncryptResult(outputStream.toByteArray(), encryptStream.transmittedDigest, incrementalDigestOut.toByteArray(), sizeChoice.sizeInBytes) } - @Throws(IOException::class) private fun writeToFile(data: ByteArray): File { val file = File.createTempFile("temp", ".data") val outputStream: OutputStream = FileOutputStream(file) @@ -685,7 +665,6 @@ class AttachmentCipherTest { return file } - @Throws(IOException::class) private fun readInputStreamFully(inputStream: InputStream): ByteArray { return Util.readFullyAsBytes(inputStream) }