diff --git a/app/src/androidTest/java/org/thoughtcrime/securesms/database/AttachmentTableTest.kt b/app/src/androidTest/java/org/thoughtcrime/securesms/database/AttachmentTableTest.kt index ec4efe7922..d72ce25d9f 100644 --- a/app/src/androidTest/java/org/thoughtcrime/securesms/database/AttachmentTableTest.kt +++ b/app/src/androidTest/java/org/thoughtcrime/securesms/database/AttachmentTableTest.kt @@ -1,15 +1,23 @@ package org.thoughtcrime.securesms.database +import android.content.Context import android.net.Uri import androidx.test.ext.junit.runners.AndroidJUnit4 import androidx.test.filters.FlakyTest +import androidx.test.platform.app.InstrumentationRegistry +import org.junit.Assert.assertArrayEquals import org.junit.Assert.assertEquals import org.junit.Assert.assertNotEquals import org.junit.Before import org.junit.Ignore import org.junit.Test import org.junit.runner.RunWith +import org.signal.core.util.copyTo +import org.signal.core.util.readFully +import org.signal.core.util.stream.NullOutputStream +import org.thoughtcrime.securesms.attachments.Attachment import org.thoughtcrime.securesms.attachments.AttachmentId +import org.thoughtcrime.securesms.attachments.PointerAttachment import org.thoughtcrime.securesms.attachments.UriAttachment import org.thoughtcrime.securesms.mms.MediaStream import org.thoughtcrime.securesms.mms.SentMediaQuality @@ -17,6 +25,15 @@ import org.thoughtcrime.securesms.providers.BlobProvider import org.thoughtcrime.securesms.testing.assertIs import org.thoughtcrime.securesms.testing.assertIsNot import org.thoughtcrime.securesms.util.MediaUtil +import org.thoughtcrime.securesms.util.Util +import org.whispersystems.signalservice.api.crypto.AttachmentCipherInputStream +import org.whispersystems.signalservice.api.crypto.AttachmentCipherOutputStream +import org.whispersystems.signalservice.api.crypto.NoCipherOutputStream +import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentPointer +import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentRemoteId +import org.whispersystems.signalservice.internal.crypto.PaddingInputStream +import java.io.ByteArrayOutputStream +import java.io.File import java.util.Optional @RunWith(AndroidJUnit4::class) @@ -163,6 +180,91 @@ class AttachmentTableTest { highInfo.file.exists() assertIs true } + @Test + fun finalizeAttachmentAfterDownload_fixDigestOnNonZeroPadding() { + // Insert attachment metadata for badly-padded attachment + val plaintext = byteArrayOf(1, 2, 3, 4) + val key = Util.getSecretBytes(64) + val iv = Util.getSecretBytes(16) + + val badlyPaddedPlaintext = PaddingInputStream(plaintext.inputStream(), plaintext.size.toLong()).readFully().also { it[it.size - 1] = 0x42 } + val badlyPaddedCiphertext = encryptPrePaddedBytes(badlyPaddedPlaintext, key, iv) + val badlyPaddedDigest = getDigest(badlyPaddedCiphertext) + + val cipherFile = getTempFile() + cipherFile.writeBytes(badlyPaddedCiphertext) + + val mmsId = -1L + val attachmentId = SignalDatabase.attachments.insertAttachmentsForMessage(mmsId, listOf(createAttachmentPointer(key, badlyPaddedDigest, plaintext.size)), emptyList()).values.first() + + // Give data to attachment table + val cipherInputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintext.size.toLong(), key, badlyPaddedDigest, null, 4, false) + SignalDatabase.attachments.finalizeAttachmentAfterDownload(mmsId, attachmentId, cipherInputStream, iv) + + // Verify the digest has been updated to the properly padded one + val properlyPaddedPlaintext = PaddingInputStream(plaintext.inputStream(), plaintext.size.toLong()).readFully() + val properlyPaddedCiphertext = encryptPrePaddedBytes(properlyPaddedPlaintext, key, iv) + val properlyPaddedDigest = getDigest(properlyPaddedCiphertext) + + val newDigest = SignalDatabase.attachments.getAttachment(attachmentId)!!.remoteDigest!! + + assertArrayEquals(properlyPaddedDigest, newDigest) + } + + @Test + fun finalizeAttachmentAfterDownload_leaveDigestAloneForAllZeroPadding() { + // Insert attachment metadata for properly-padded attachment + val plaintext = byteArrayOf(1, 2, 3, 4) + val key = Util.getSecretBytes(64) + val iv = Util.getSecretBytes(16) + + val paddedPlaintext = PaddingInputStream(plaintext.inputStream(), plaintext.size.toLong()).readFully() + val ciphertext = encryptPrePaddedBytes(paddedPlaintext, key, iv) + val digest = getDigest(ciphertext) + + val cipherFile = getTempFile() + cipherFile.writeBytes(ciphertext) + + val mmsId = -1L + val attachmentId = SignalDatabase.attachments.insertAttachmentsForMessage(mmsId, listOf(createAttachmentPointer(key, digest, plaintext.size)), emptyList()).values.first() + + // Give data to attachment table + val cipherInputStream = AttachmentCipherInputStream.createForAttachment(cipherFile, plaintext.size.toLong(), key, digest, null, 4, false) + SignalDatabase.attachments.finalizeAttachmentAfterDownload(mmsId, attachmentId, cipherInputStream, iv) + + // Verify the digest hasn't changed + val newDigest = SignalDatabase.attachments.getAttachment(attachmentId)!!.remoteDigest!! + assertArrayEquals(digest, newDigest) + } + + private fun createAttachmentPointer(key: ByteArray, digest: ByteArray, size: Int): Attachment { + return PointerAttachment.forPointer( + pointer = Optional.of( + SignalServiceAttachmentPointer( + cdnNumber = 3, + remoteId = SignalServiceAttachmentRemoteId.V4("asdf"), + contentType = MediaUtil.IMAGE_JPEG, + key = key, + size = Optional.of(size), + preview = Optional.empty(), + width = 2, + height = 2, + digest = Optional.of(digest), + incrementalDigest = Optional.empty(), + incrementalMacChunkSize = 0, + fileName = Optional.of("file.jpg"), + voiceNote = false, + isBorderless = false, + isGif = false, + caption = Optional.empty(), + blurHash = Optional.empty(), + uploadTimestamp = 0, + uuid = null + ) + ) + ).get() + } + private fun createAttachment(id: Long, uri: Uri, transformProperties: AttachmentTable.TransformProperties): UriAttachment { return UriAttachmentBuilder.build( id, @@ -179,4 +281,24 @@ class AttachmentTableTest { private fun createMediaStream(byteArray: ByteArray): MediaStream { return MediaStream(byteArray.inputStream(), MediaUtil.IMAGE_JPEG, 2, 2) } + + private fun getDigest(ciphertext: ByteArray): ByteArray { + val digestStream = NoCipherOutputStream(NullOutputStream) + ciphertext.inputStream().copyTo(digestStream) + return digestStream.transmittedDigest + } + + private fun encryptPrePaddedBytes(plaintext: ByteArray, key: ByteArray, iv: ByteArray): ByteArray { + val outputStream = ByteArrayOutputStream() + val cipherStream = AttachmentCipherOutputStream(key, iv, outputStream) + plaintext.inputStream().copyTo(cipherStream) + + return outputStream.toByteArray() + } + + private fun getTempFile(): File { + val dir = InstrumentationRegistry.getInstrumentation().targetContext.getDir("temp", Context.MODE_PRIVATE) + dir.mkdir() + return File.createTempFile("transfer", ".mms", dir) + } } diff --git a/app/src/main/java/org/thoughtcrime/securesms/database/AttachmentTable.kt b/app/src/main/java/org/thoughtcrime/securesms/database/AttachmentTable.kt index 716c9f8998..0b43ce68c4 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/database/AttachmentTable.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/database/AttachmentTable.kt @@ -36,6 +36,8 @@ import org.signal.core.util.Base64 import org.signal.core.util.SqlUtil import org.signal.core.util.StreamUtil import org.signal.core.util.ThreadUtil +import org.signal.core.util.allMatch +import org.signal.core.util.copyTo import org.signal.core.util.count import org.signal.core.util.delete import org.signal.core.util.deleteAll @@ -59,6 +61,8 @@ import org.signal.core.util.requireNonNullString import org.signal.core.util.requireObject import org.signal.core.util.requireString import org.signal.core.util.select +import org.signal.core.util.stream.LimitedInputStream +import org.signal.core.util.stream.NullOutputStream import org.signal.core.util.toInt import org.signal.core.util.update import org.signal.core.util.withinTransaction @@ -94,7 +98,9 @@ import org.thoughtcrime.securesms.util.StorageUtil import org.thoughtcrime.securesms.util.Util import org.thoughtcrime.securesms.video.EncryptedMediaDataSource import org.whispersystems.signalservice.api.attachment.AttachmentUploadResult +import org.whispersystems.signalservice.api.crypto.AttachmentCipherOutputStream import org.whispersystems.signalservice.api.util.UuidUtil +import org.whispersystems.signalservice.internal.crypto.PaddingInputStream import org.whispersystems.signalservice.internal.util.JsonUtil import java.io.File import java.io.FileNotFoundException @@ -963,14 +969,32 @@ class AttachmentTable( * that the content of the attachment will never change. */ @Throws(MmsException::class) - fun finalizeAttachmentAfterDownload(mmsId: Long, attachmentId: AttachmentId, inputStream: InputStream, iv: ByteArray?) { + fun finalizeAttachmentAfterDownload(mmsId: Long, attachmentId: AttachmentId, inputStream: LimitedInputStream, iv: ByteArray?) { Log.i(TAG, "[finalizeAttachmentAfterDownload] Finalizing downloaded data for $attachmentId. (MessageId: $mmsId, $attachmentId)") val existingPlaceholder: DatabaseAttachment = getAttachment(attachmentId) ?: throw MmsException("No attachment found for id: $attachmentId") - val fileWriteResult: DataFileWriteResult = writeToDataFile(newDataFile(context), inputStream, TransformProperties.empty()) + val fileWriteResult: DataFileWriteResult = writeToDataFile(newDataFile(context), inputStream, TransformProperties.empty(), closeInputStream = false) val transferFile: File? = getTransferFile(databaseHelper.signalReadableDatabase, attachmentId) + val paddingAllZeroes = inputStream.use { limitStream -> + limitStream.leftoverStream().allMatch { it == 0x00.toByte() } + } + + val digest = if (paddingAllZeroes) { + Log.d(TAG, "[finalizeAttachmentAfterDownload] $attachmentId has all-zero padding. Digest is good.") + existingPlaceholder.remoteDigest!! + } else { + Log.w(TAG, "[finalizeAttachmentAfterDownload] $attachmentId has non-zero padding bytes. Recomputing digest.") + + val stream = PaddingInputStream(getDataStream(fileWriteResult.file, fileWriteResult.random, 0), fileWriteResult.length) + val key = Base64.decode(existingPlaceholder.remoteKey!!) + val cipherOutputStream = AttachmentCipherOutputStream(key, iv, NullOutputStream) + + StreamUtil.copy(stream, cipherOutputStream) + cipherOutputStream.transmittedDigest + } + val foundDuplicate = writableDatabase.withinTransaction { db -> // We can look and see if we have any exact matches on hash_ends and dedupe the file if we see one. // We don't look at hash_start here because that could result in us matching on a file that got compressed down to something smaller, effectively lowering @@ -1013,6 +1037,7 @@ class AttachmentTable( values.put(TRANSFORM_PROPERTIES, TransformProperties.forSkipTransform().serialize()) values.put(ARCHIVE_TRANSFER_FILE, null as String?) values.put(REMOTE_IV, iv) + values.put(REMOTE_DIGEST, digest) db.update(TABLE_NAME) .values(values) @@ -1878,7 +1903,7 @@ class AttachmentTable( * Reads the entire stream and saves to disk and returns a bunch of metadat about the write. */ @Throws(MmsException::class, IllegalStateException::class) - private fun writeToDataFile(destination: File, inputStream: InputStream, transformProperties: TransformProperties): DataFileWriteResult { + private fun writeToDataFile(destination: File, inputStream: InputStream, transformProperties: TransformProperties, closeInputStream: Boolean = true): DataFileWriteResult { return try { // Sometimes the destination is a file that's already in use, sometimes it's not. // To avoid writing to a file while it's in-use, we write to a temp file and then rename it to the destination file at the end. @@ -1890,7 +1915,7 @@ class AttachmentTable( val random = encryptingStreamData.first val encryptingOutputStream = encryptingStreamData.second - val length = StreamUtil.copy(digestInputStream, encryptingOutputStream) + val length = digestInputStream.copyTo(encryptingOutputStream, closeInputStream) val hash = Base64.encodeWithPadding(digestInputStream.messageDigest.digest()) if (!tempFile.renameTo(destination)) { diff --git a/app/src/main/java/org/thoughtcrime/securesms/jobs/AttachmentDownloadJob.kt b/app/src/main/java/org/thoughtcrime/securesms/jobs/AttachmentDownloadJob.kt index da6edae410..77ebcdac38 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/jobs/AttachmentDownloadJob.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/jobs/AttachmentDownloadJob.kt @@ -12,6 +12,7 @@ import org.greenrobot.eventbus.EventBus import org.signal.core.util.Base64 import org.signal.core.util.Hex import org.signal.core.util.logging.Log +import org.signal.core.util.stream.LimitedInputStream import org.signal.libsignal.protocol.InvalidMacException import org.signal.libsignal.protocol.InvalidMessageException import org.thoughtcrime.securesms.attachments.Attachment @@ -415,7 +416,12 @@ class AttachmentDownloadJob private constructor( if (body.contentLength() > RemoteConfig.maxAttachmentReceiveSizeBytes) { throw MmsException("Attachment too large, failing download") } - SignalDatabase.attachments.finalizeAttachmentAfterDownload(messageId, attachmentId, (body.source() as Source).buffer().inputStream(), iv = null) + SignalDatabase.attachments.finalizeAttachmentAfterDownload( + messageId, + attachmentId, + LimitedInputStream.withoutLimits((body.source() as Source).buffer().inputStream()), + iv = null + ) } } } catch (e: MmsException) { diff --git a/core-util-jvm/src/main/java/org/signal/core/util/InputStreamExtensions.kt b/core-util-jvm/src/main/java/org/signal/core/util/InputStreamExtensions.kt index fc7a0eb7f2..645e48de72 100644 --- a/core-util-jvm/src/main/java/org/signal/core/util/InputStreamExtensions.kt +++ b/core-util-jvm/src/main/java/org/signal/core/util/InputStreamExtensions.kt @@ -5,9 +5,11 @@ package org.signal.core.util +import org.signal.core.util.stream.LimitedInputStream import java.io.ByteArrayOutputStream import java.io.IOException import java.io.InputStream +import java.io.OutputStream import kotlin.math.min /** @@ -112,3 +114,37 @@ fun InputStream.readLength(): Long { fun InputStream.drain() { this.readLength() } + +/** + * Returns a [LimitedInputStream] that will limit the number of bytes read from this stream to [limit]. + */ +fun InputStream.limit(limit: Long): LimitedInputStream { + return LimitedInputStream(this, limit) +} + +/** + * Copies the contents of this stream to the [outputStream]. + * + * @param closeInputStream If true, the input stream will be closed after the copy is complete. + */ +fun InputStream.copyTo(outputStream: OutputStream, closeInputStream: Boolean = true): Long { + return StreamUtil.copy(this, outputStream, closeInputStream) +} + +/** + * Returns true if every byte in this stream matches the predicate, otherwise false. + */ +fun InputStream.allMatch(predicate: (Byte) -> Boolean): Boolean { + val buffer = ByteArray(4096) + + var readCount: Int + while (this.read(buffer).also { readCount = it } != -1) { + for (i in 0 until readCount) { + if (!predicate(buffer[i])) { + return false + } + } + } + + return true +} diff --git a/core-util-jvm/src/main/java/org/signal/core/util/StreamUtil.java b/core-util-jvm/src/main/java/org/signal/core/util/StreamUtil.java index 466b57919a..5c9c60617c 100644 --- a/core-util-jvm/src/main/java/org/signal/core/util/StreamUtil.java +++ b/core-util-jvm/src/main/java/org/signal/core/util/StreamUtil.java @@ -96,6 +96,10 @@ public final class StreamUtil { } public static long copy(InputStream in, OutputStream out) throws IOException { + return copy(in, out, true); + } + + public static long copy(InputStream in, OutputStream out, boolean closeInputStream) throws IOException { byte[] buffer = new byte[64 * 1024]; int read; long total = 0; @@ -105,7 +109,10 @@ public final class StreamUtil { total += read; } - in.close(); + if (closeInputStream) { + in.close(); + } + out.flush(); out.close(); diff --git a/core-util-jvm/src/main/java/org/signal/core/util/stream/LimitedInputStream.kt b/core-util-jvm/src/main/java/org/signal/core/util/stream/LimitedInputStream.kt index b3a4d64b10..67d6f91826 100644 --- a/core-util-jvm/src/main/java/org/signal/core/util/stream/LimitedInputStream.kt +++ b/core-util-jvm/src/main/java/org/signal/core/util/stream/LimitedInputStream.kt @@ -5,8 +5,6 @@ package org.signal.core.util.stream -import org.signal.core.util.readAtMostNBytes -import org.signal.core.util.readFully import java.io.FilterInputStream import java.io.InputStream import java.lang.UnsupportedOperationException @@ -22,8 +20,21 @@ class LimitedInputStream(private val wrapped: InputStream, private val maxBytes: private var totalBytesRead: Long = 0 private var lastMark = -1L + companion object { + + private const val UNLIMITED = -1L + + /** + * Returns a [LimitedInputStream] that doesn't limit the stream at all -- it'll allow reading the full thing. + */ + @JvmStatic + fun withoutLimits(wrapped: InputStream): LimitedInputStream { + return LimitedInputStream(wrapped = wrapped, maxBytes = UNLIMITED) + } + } + override fun read(): Int { - if (maxBytes == -1L) { + if (maxBytes == UNLIMITED) { return wrapped.read() } @@ -44,7 +55,7 @@ class LimitedInputStream(private val wrapped: InputStream, private val maxBytes: } override fun read(destination: ByteArray, offset: Int, length: Int): Int { - if (maxBytes == -1L) { + if (maxBytes == UNLIMITED) { return wrapped.read(destination, offset, length) } @@ -64,7 +75,7 @@ class LimitedInputStream(private val wrapped: InputStream, private val maxBytes: } override fun skip(requestedSkipCount: Long): Long { - if (maxBytes == -1L) { + if (maxBytes == UNLIMITED) { return wrapped.skip(requestedSkipCount) } @@ -78,7 +89,7 @@ class LimitedInputStream(private val wrapped: InputStream, private val maxBytes: } override fun available(): Int { - if (maxBytes == -1L) { + if (maxBytes == UNLIMITED) { return wrapped.available() } @@ -97,7 +108,7 @@ class LimitedInputStream(private val wrapped: InputStream, private val maxBytes: wrapped.mark(readlimit) - if (maxBytes == -1L) { + if (maxBytes == UNLIMITED) { return } @@ -109,13 +120,13 @@ class LimitedInputStream(private val wrapped: InputStream, private val maxBytes: throw UnsupportedOperationException("Mark not supported") } - if (lastMark == -1L) { + if (lastMark == UNLIMITED) { throw UnsupportedOperationException("Mark not set") } wrapped.reset() - if (maxBytes == -1L) { + if (maxBytes == UNLIMITED) { return } @@ -123,24 +134,18 @@ class LimitedInputStream(private val wrapped: InputStream, private val maxBytes: } /** - * If the stream has been fully read, this will return all bytes that were truncated from the stream. - * If the stream was setup with no limit, this will always return an empty array. - * - * @param byteLimit The maximum number of truncated bytes to read. Defaults to no limit. + * If the stream has been fully read, this will return a stream that contains the remaining bytes that were truncated. + * If the stream was setup with no limit, this will always return an empty stream. */ - fun readTruncatedBytes(byteLimit: Int = -1): ByteArray { - if (maxBytes == -1L) { - return ByteArray(0) + fun leftoverStream(): InputStream { + if (maxBytes == UNLIMITED) { + return ByteArray(0).inputStream() } if (totalBytesRead < maxBytes) { throw IllegalStateException("Stream has not been fully read") } - return if (byteLimit < 0) { - wrapped.readFully() - } else { - wrapped.readAtMostNBytes(byteLimit) - } + return wrapped } } diff --git a/core-util-jvm/src/test/java/org/signal/core/util/stream/LimitedInputStreamTest.kt b/core-util-jvm/src/test/java/org/signal/core/util/stream/LimitedInputStreamTest.kt index fbb6fd614a..35d338c5f5 100644 --- a/core-util-jvm/src/test/java/org/signal/core/util/stream/LimitedInputStreamTest.kt +++ b/core-util-jvm/src/test/java/org/signal/core/util/stream/LimitedInputStreamTest.kt @@ -22,7 +22,7 @@ class LimitedInputStreamTest { @Test fun `when I fully read the stream via a buffer with no limit, I should get all bytes`() { - val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = -1) + val inputStream = LimitedInputStream.withoutLimits(ByteArray(100).inputStream()) val data = inputStream.readFully() assertEquals(100, data.size) @@ -44,7 +44,7 @@ class LimitedInputStreamTest { @Test fun `when I fully read the stream one byte at a time with no limit, I should only get maxBytes`() { - val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = -1) + val inputStream = LimitedInputStream.withoutLimits(ByteArray(100).inputStream()) var count = 0 var lastRead = inputStream.read() @@ -88,35 +88,26 @@ class LimitedInputStreamTest { } @Test - fun `when I finish reading the stream, getTruncatedBytes gives me the rest`() { + fun `when I finish reading the stream, leftoverStream gives me the rest`() { val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75) inputStream.readFully() - val truncatedBytes = inputStream.readTruncatedBytes() + val truncatedBytes = inputStream.leftoverStream().readFully() assertEquals(25, truncatedBytes.size) } - @Test - fun `when I finish reading the stream, getTruncatedBytes gives me the rest, respecting the byte limit`() { - val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75) - inputStream.readFully() - - val truncatedBytes = inputStream.readTruncatedBytes(byteLimit = 10) - assertEquals(10, truncatedBytes.size) - } - @Test(expected = IllegalStateException::class) - fun `if I have not finished reading the stream, getTruncatedBytes throws IllegalStateException`() { + fun `if I have not finished reading the stream, leftoverStream throws IllegalStateException`() { val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75) - inputStream.readTruncatedBytes() + inputStream.leftoverStream() } @Test - fun `when call getTruncatedBytes on a stream with no limit, it returns an empty array`() { - val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = -1) + fun `when call leftoverStream on a stream with no limit, it returns an empty array`() { + val inputStream = LimitedInputStream.withoutLimits(ByteArray(100).inputStream()) inputStream.readFully() - val truncatedBytes = inputStream.readTruncatedBytes() + val truncatedBytes = inputStream.leftoverStream().readFully() assertEquals(0, truncatedBytes.size) } @@ -130,7 +121,7 @@ class LimitedInputStreamTest { @Test fun `when I call available with no limit, it should return the full length`() { - val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = -1) + val inputStream = LimitedInputStream.withoutLimits(ByteArray(100).inputStream()) val available = inputStream.available() assertEquals(100, available) diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageReceiver.java b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageReceiver.java index 74842e67d5..222dd729db 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageReceiver.java +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageReceiver.java @@ -10,6 +10,7 @@ import org.signal.core.util.StreamUtil; import org.signal.core.util.concurrent.FutureTransformers; import org.signal.core.util.concurrent.ListenableFuture; import org.signal.core.util.concurrent.SettableFuture; +import org.signal.core.util.stream.LimitedInputStream; import org.signal.libsignal.protocol.InvalidMessageException; import org.signal.libsignal.zkgroup.profiles.ClientZkProfileOperations; import org.signal.libsignal.zkgroup.profiles.ProfileKey; @@ -220,7 +221,7 @@ public class SignalServiceMessageReceiver { StreamUtil.readFully(tempStream, iv); } - InputStream dataStream = AttachmentCipherInputStream.createForAttachment( + LimitedInputStream dataStream = AttachmentCipherInputStream.createForAttachment( attachmentDestination, pointer.getSize().orElse(0), pointer.getKey(), diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/attachment/AttachmentDownloadResult.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/attachment/AttachmentDownloadResult.kt index a7d0a17c6f..850ea7c7f1 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/attachment/AttachmentDownloadResult.kt +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/attachment/AttachmentDownloadResult.kt @@ -5,12 +5,12 @@ package org.whispersystems.signalservice.api.attachment -import java.io.InputStream +import org.signal.core.util.stream.LimitedInputStream /** * Holds the result of an attachment download. */ class AttachmentDownloadResult( - val dataStream: InputStream, + val dataStream: LimitedInputStream, val iv: ByteArray ) diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherInputStream.java b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherInputStream.java index 6c85a6abe5..2c279d7d70 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherInputStream.java +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/crypto/AttachmentCipherInputStream.java @@ -59,7 +59,7 @@ public class AttachmentCipherInputStream extends FilterInputStream { /** * Passing in a null incrementalDigest and/or 0 for the chunk size at the call site disables incremental mac validation. */ - public static InputStream createForAttachment(File file, long plaintextLength, byte[] combinedKeyMaterial, byte[] digest, byte[] incrementalDigest, int incrementalMacChunkSize) + public static LimitedInputStream createForAttachment(File file, long plaintextLength, byte[] combinedKeyMaterial, byte[] digest, byte[] incrementalDigest, int incrementalMacChunkSize) throws InvalidMessageException, IOException { return createForAttachment(file, plaintextLength, combinedKeyMaterial, digest, incrementalDigest, incrementalMacChunkSize, false); } @@ -69,7 +69,7 @@ public class AttachmentCipherInputStream extends FilterInputStream { * * Passing in true for ignoreDigest DOES NOT VERIFY THE DIGEST */ - public static InputStream createForAttachment(File file, long plaintextLength, byte[] combinedKeyMaterial, byte[] digest, byte[] incrementalDigest, int incrementalMacChunkSize, boolean ignoreDigest) + public static LimitedInputStream createForAttachment(File file, long plaintextLength, byte[] combinedKeyMaterial, byte[] digest, byte[] incrementalDigest, int incrementalMacChunkSize, boolean ignoreDigest) throws InvalidMessageException, IOException { return createForAttachment(() -> new FileInputStream(file), file.length(), plaintextLength, combinedKeyMaterial, digest, incrementalDigest, incrementalMacChunkSize, ignoreDigest); @@ -80,7 +80,7 @@ public class AttachmentCipherInputStream extends FilterInputStream { * * Passing in true for ignoreDigest DOES NOT VERIFY THE DIGEST */ - public static InputStream createForAttachment(StreamSupplier streamSupplier, long streamLength, long plaintextLength, byte[] combinedKeyMaterial, byte[] digest, byte[] incrementalDigest, int incrementalMacChunkSize, boolean ignoreDigest) + public static LimitedInputStream createForAttachment(StreamSupplier streamSupplier, long streamLength, long plaintextLength, byte[] combinedKeyMaterial, byte[] digest, byte[] incrementalDigest, int incrementalMacChunkSize, boolean ignoreDigest) throws InvalidMessageException, IOException { byte[][] parts = Util.split(combinedKeyMaterial, CIPHER_KEY_SIZE, MAC_KEY_SIZE); @@ -117,16 +117,16 @@ public class AttachmentCipherInputStream extends FilterInputStream { InputStream inputStream = new AttachmentCipherInputStream(wrappedStream, parts[0], streamLength - BLOCK_SIZE - mac.getMacLength()); if (plaintextLength != 0) { - inputStream = new LimitedInputStream(inputStream, plaintextLength); + return new LimitedInputStream(inputStream, plaintextLength); + } else { + return LimitedInputStream.withoutLimits(inputStream); } - - return inputStream; } /** * Decrypt archived media to it's original attachment encrypted blob. */ - public static InputStream createForArchivedMedia(BackupKey.MediaKeyMaterial archivedMediaKeyMaterial, File file, long originalCipherTextLength) + public static LimitedInputStream createForArchivedMedia(BackupKey.MediaKeyMaterial archivedMediaKeyMaterial, File file, long originalCipherTextLength) throws InvalidMessageException, IOException { Mac mac = initMac(archivedMediaKeyMaterial.getMacKey()); @@ -142,13 +142,13 @@ public class AttachmentCipherInputStream extends FilterInputStream { InputStream inputStream = new AttachmentCipherInputStream(new FileInputStream(file), archivedMediaKeyMaterial.getCipherKey(), file.length() - BLOCK_SIZE - mac.getMacLength()); if (originalCipherTextLength != 0) { - inputStream = new LimitedInputStream(inputStream, originalCipherTextLength); + return new LimitedInputStream(inputStream, originalCipherTextLength); + } else { + return LimitedInputStream.withoutLimits(inputStream); } - - return inputStream; } - public static InputStream createStreamingForArchivedAttachment(BackupKey.MediaKeyMaterial archivedMediaKeyMaterial, File file, long originalCipherTextLength, long plaintextLength, byte[] combinedKeyMaterial, byte[] digest, byte[] incrementalDigest, int incrementalMacChunkSize) + public static LimitedInputStream createStreamingForArchivedAttachment(BackupKey.MediaKeyMaterial archivedMediaKeyMaterial, File file, long originalCipherTextLength, long plaintextLength, byte[] combinedKeyMaterial, byte[] digest, byte[] incrementalDigest, int incrementalMacChunkSize) throws InvalidMessageException, IOException { final InputStream archiveStream = createForArchivedMedia(archivedMediaKeyMaterial, file, originalCipherTextLength); @@ -179,10 +179,11 @@ public class AttachmentCipherInputStream extends FilterInputStream { InputStream inputStream = new AttachmentCipherInputStream(wrappedStream, parts[0], file.length() - BLOCK_SIZE - mac.getMacLength()); if (plaintextLength != 0) { - inputStream = new LimitedInputStream(inputStream, plaintextLength); + return new LimitedInputStream(inputStream, plaintextLength); + } else { + return LimitedInputStream.withoutLimits(inputStream); } - return inputStream; } public static InputStream createForStickerData(byte[] data, byte[] packKey)