diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/crypto/PaddingInputStream.java b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/crypto/PaddingInputStream.java index cf4e4bee2b..f34ff2ead9 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/crypto/PaddingInputStream.java +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/crypto/PaddingInputStream.java @@ -6,6 +6,7 @@ import org.whispersystems.signalservice.internal.util.Util; import java.io.FilterInputStream; import java.io.IOException; import java.io.InputStream; +import java.util.Arrays; public class PaddingInputStream extends FilterInputStream { @@ -36,6 +37,7 @@ public class PaddingInputStream extends FilterInputStream { if (paddingRemaining > 0) { length = Math.min(length, Util.toIntExact(paddingRemaining)); + Arrays.fill(buffer, offset, length, (byte) 0x00); paddingRemaining -= length; return length; } diff --git a/libsignal-service/src/test/java/org/whispersystems/signalservice/internal/crypto/PaddingInputStreamTest.kt b/libsignal-service/src/test/java/org/whispersystems/signalservice/internal/crypto/PaddingInputStreamTest.kt new file mode 100644 index 0000000000..fe15bed961 --- /dev/null +++ b/libsignal-service/src/test/java/org/whispersystems/signalservice/internal/crypto/PaddingInputStreamTest.kt @@ -0,0 +1,41 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.signalservice.internal.crypto + +import org.hamcrest.MatcherAssert.assertThat +import org.hamcrest.core.Is.`is` +import org.junit.Test +import org.signal.core.util.StreamUtil +import java.io.ByteArrayInputStream +import java.io.ByteArrayOutputStream + +class PaddingInputStreamTest { + + /** + * Small stress test to confirm padding input only returns the source stream data + * followed strictly by zeros. + */ + @Test + fun stressTest() { + (0..2048).forEach { length -> + val source = ByteArray(length).apply { fill(42) } + val sourceInput = ByteArrayInputStream(source) + val paddingInput = PaddingInputStream(sourceInput, length.toLong()) + + val paddedData = ByteArrayOutputStream().let { + StreamUtil.copy(paddingInput, it) + it.toByteArray() + } + + paddedData.forEachIndexed { index, byte -> + if (index < length) { + assertThat(byte, `is`(source[index])) + } else { + assertThat(byte, `is`(0x00)) + } + } + } + } +}