diff --git a/core-util-jvm/src/main/java/org/signal/core/util/stream/LimitedInputStream.kt b/core-util-jvm/src/main/java/org/signal/core/util/stream/LimitedInputStream.kt index 47b240a455..b3a4d64b10 100644 --- a/core-util-jvm/src/main/java/org/signal/core/util/stream/LimitedInputStream.kt +++ b/core-util-jvm/src/main/java/org/signal/core/util/stream/LimitedInputStream.kt @@ -14,6 +14,8 @@ import kotlin.math.min /** * An [InputStream] that will read from the target [InputStream] until it reaches the end, or until it has read [maxBytes] bytes. + * + * @param maxBytes The maximum number of bytes to read from the stream. If set to -1, there will be no limit. */ class LimitedInputStream(private val wrapped: InputStream, private val maxBytes: Long) : FilterInputStream(wrapped) { @@ -21,6 +23,10 @@ class LimitedInputStream(private val wrapped: InputStream, private val maxBytes: private var lastMark = -1L override fun read(): Int { + if (maxBytes == -1L) { + return wrapped.read() + } + if (totalBytesRead >= maxBytes) { return -1 } @@ -38,6 +44,10 @@ class LimitedInputStream(private val wrapped: InputStream, private val maxBytes: } override fun read(destination: ByteArray, offset: Int, length: Int): Int { + if (maxBytes == -1L) { + return wrapped.read(destination, offset, length) + } + if (totalBytesRead >= maxBytes) { return -1 } @@ -54,6 +64,10 @@ class LimitedInputStream(private val wrapped: InputStream, private val maxBytes: } override fun skip(requestedSkipCount: Long): Long { + if (maxBytes == -1L) { + return wrapped.skip(requestedSkipCount) + } + val bytesRemaining: Long = maxBytes - totalBytesRead val bytesToSkip: Long = min(bytesRemaining, requestedSkipCount) val skipCount = super.skip(bytesToSkip) @@ -64,6 +78,10 @@ class LimitedInputStream(private val wrapped: InputStream, private val maxBytes: } override fun available(): Int { + if (maxBytes == -1L) { + return wrapped.available() + } + val bytesRemaining = Math.toIntExact(maxBytes - totalBytesRead) return min(bytesRemaining, wrapped.available()) } @@ -78,6 +96,11 @@ class LimitedInputStream(private val wrapped: InputStream, private val maxBytes: } wrapped.mark(readlimit) + + if (maxBytes == -1L) { + return + } + lastMark = totalBytesRead } @@ -91,15 +114,25 @@ class LimitedInputStream(private val wrapped: InputStream, private val maxBytes: } wrapped.reset() + + if (maxBytes == -1L) { + return + } + totalBytesRead = lastMark } /** * If the stream has been fully read, this will return all bytes that were truncated from the stream. + * If the stream was setup with no limit, this will always return an empty array. * * @param byteLimit The maximum number of truncated bytes to read. Defaults to no limit. */ fun readTruncatedBytes(byteLimit: Int = -1): ByteArray { + if (maxBytes == -1L) { + return ByteArray(0) + } + if (totalBytesRead < maxBytes) { throw IllegalStateException("Stream has not been fully read") } diff --git a/core-util-jvm/src/test/java/org/signal/core/util/stream/LimitedInputStreamTest.kt b/core-util-jvm/src/test/java/org/signal/core/util/stream/LimitedInputStreamTest.kt index 7879463749..fbb6fd614a 100644 --- a/core-util-jvm/src/test/java/org/signal/core/util/stream/LimitedInputStreamTest.kt +++ b/core-util-jvm/src/test/java/org/signal/core/util/stream/LimitedInputStreamTest.kt @@ -20,6 +20,14 @@ class LimitedInputStreamTest { assertEquals(75, data.size) } + @Test + fun `when I fully read the stream via a buffer with no limit, I should get all bytes`() { + val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = -1) + val data = inputStream.readFully() + + assertEquals(100, data.size) + } + @Test fun `when I fully read the stream one byte at a time, I should only get maxBytes`() { val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75) @@ -34,6 +42,20 @@ class LimitedInputStreamTest { assertEquals(75, count) } + @Test + fun `when I fully read the stream one byte at a time with no limit, I should only get maxBytes`() { + val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = -1) + + var count = 0 + var lastRead = inputStream.read() + while (lastRead != -1) { + count++ + lastRead = inputStream.read() + } + + assertEquals(100, count) + } + @Test fun `when I skip past the maxBytes, I should get -1`() { val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75) @@ -89,6 +111,15 @@ class LimitedInputStreamTest { inputStream.readTruncatedBytes() } + @Test + fun `when call getTruncatedBytes on a stream with no limit, it returns an empty array`() { + val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = -1) + inputStream.readFully() + + val truncatedBytes = inputStream.readTruncatedBytes() + assertEquals(0, truncatedBytes.size) + } + @Test fun `when I call available, it should respect the maxBytes`() { val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75) @@ -97,6 +128,14 @@ class LimitedInputStreamTest { assertEquals(75, available) } + @Test + fun `when I call available with no limit, it should return the full length`() { + val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = -1) + val available = inputStream.available() + + assertEquals(100, available) + } + @Test fun `when I call available after reading some bytes, it should respect the maxBytes`() { val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75)