Verify multiple APNG lengths to prevent bad input from crashing.

This commit is contained in:
Cody Henthorne
2026-06-09 16:40:28 -04:00
parent 029b91066f
commit e468156c4c
5 changed files with 166 additions and 22 deletions
@@ -25,7 +25,7 @@ class ApngInputStreamFactoryResourceDecoder : ResourceDecoder<InputStreamFactory
@Throws(IOException::class)
override fun decode(source: InputStreamFactory, width: Int, height: Int, options: Options): Resource<ApngDecoder>? {
val decoder = ApngDecoder.create { source.create() }
val decoder = ApngDecoder.create(contentLength = source.length()) { source.create() }
return ApngResource(decoder)
}
}
@@ -39,7 +39,7 @@ internal class EncryptedApngCacheDecoder(private val secret: ByteArray) : Encryp
@Throws(IOException::class)
override fun decode(source: File, width: Int, height: Int, options: Options): Resource<ApngDecoder>? {
val decoder = ApngDecoder.create { createEncryptedInputStream(secret, source) }
val decoder = ApngDecoder.create(contentLength = source.length()) { createEncryptedInputStream(secret, source) }
return ApngResource(decoder)
}
}
@@ -41,13 +41,16 @@ class ApngDecoder private constructor(
val metadata: Metadata,
val frames: List<Frame>,
private val ihdr: Chunk.IHDR,
private val prefixChunks: List<Chunk.ArbitraryChunk>
private val prefixChunks: List<Chunk.ArbitraryChunk>,
private val maxFrameDataSize: Long
) : Closeable {
companion object {
private val PNG_MAGIC = byteArrayOf(0x89.toByte(), 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A)
private const val MAX_DIMENSION: UInt = 4096u
private const val MAX_CHUNK_LENGTH: UInt = 52_428_800u // 50 MiB
@Throws(IOException::class)
fun isApng(inputStream: InputStream): Boolean {
val magic = inputStream.readNBytesOrThrow(8)
@@ -78,12 +81,14 @@ class ApngDecoder private constructor(
/**
* Scans the stream to build metadata, then closes it. No frame image data is retained.
*
* @param contentLength An upper bound on the number of bytes the stream will yield, if known
*/
@Throws(IOException::class)
fun create(streamFactory: () -> InputStream): ApngDecoder {
fun create(contentLength: Long? = null, streamFactory: () -> InputStream): ApngDecoder {
val inputStream = streamFactory()
try {
return scanMetadata(inputStream, streamFactory)
return scanMetadata(inputStream, streamFactory, contentLength)
} finally {
inputStream.close()
}
@@ -123,8 +128,9 @@ class ApngDecoder private constructor(
* Unlike the old approach (which read all frame data into memory), this method only records byte offsets into the stream where frame data lives.
* The actual frame data is read on demand in [decodeFrame].
*/
private fun scanMetadata(inputStream: InputStream, streamFactory: () -> InputStream): ApngDecoder {
val scanner = StreamScanner(inputStream)
private fun scanMetadata(inputStream: InputStream, streamFactory: () -> InputStream, contentLength: Long?): ApngDecoder {
val maxChunkLength: Long = contentLength?.takeIf { it > 0 }?.coerceAtMost(MAX_CHUNK_LENGTH.toLong()) ?: MAX_CHUNK_LENGTH.toLong()
val scanner = StreamScanner(inputStream, maxChunkLength)
// Read the magic bytes to verify that this is a PNG
val magic = scanner.readBytes(8)
@@ -134,11 +140,15 @@ class ApngDecoder private constructor(
// The IHDR chunk is the first chunk in a PNG file and contains metadata about the image.
// Per spec it must appear first, so if it's missing the file is invalid.
val ihdrLength = scanner.readUInt()
val ihdrLength = scanner.readChunkLength()
val ihdrType = scanner.readBytes(4).toString(Charsets.US_ASCII)
if (ihdrType != "IHDR") {
throw IOException("First chunk is not IHDR!")
}
if (ihdrLength != Chunk.IHDR.LENGTH) {
throw IOException("IHDR length ($ihdrLength) is not the expected ${Chunk.IHDR.LENGTH} bytes!")
}
val ihdrData = scanner.readBytes(ihdrLength.toInt())
scanner.skipBytes(4) // CRC
@@ -152,6 +162,10 @@ class ApngDecoder private constructor(
interlaceMethod = ihdrData[12]
)
if (ihdr.width !in 1u..MAX_DIMENSION || ihdr.height !in 1u..MAX_DIMENSION) {
throw IOException("IHDR canvas dimensions out of bounds: ${ihdr.width}x${ihdr.height}")
}
// Next, we want to read all of the chunks up to the first IDAT chunk.
// The first IDAT chunk represents the default image, and possibly the first frame of the animation (depending on the presence of an fcTL chunk).
// In order for this to be a valid APNG, there _must_ be an acTL chunk before the first IDAT chunk.
@@ -163,7 +177,7 @@ class ApngDecoder private constructor(
var chunkType: String
while (true) {
chunkLength = scanner.readUInt()
chunkLength = scanner.readChunkLength()
chunkType = scanner.readBytes(4).toString(Charsets.US_ASCII)
if (chunkType == "IDAT") {
@@ -174,6 +188,9 @@ class ApngDecoder private constructor(
"acTL" -> {
val data = scanner.readBytes(chunkLength.toInt())
scanner.skipBytes(4) // CRC
if (data.size < Chunk.acTL.LENGTH.toInt()) {
throw IOException("acTL chunk is too short: ${data.size} bytes")
}
earlyActl = Chunk.acTL(
numFrames = data.sliceArray(0 until 4).toUInt(),
numPlays = data.sliceArray(4 until 8).toUInt()
@@ -211,12 +228,12 @@ class ApngDecoder private constructor(
scanner.skipBytes(chunkLength.toLong() + 4) // data + CRC
// Collect more consecutive IDATs
chunkLength = scanner.readUInt()
chunkLength = scanner.readChunkLength()
chunkType = scanner.readBytes(4).toString(Charsets.US_ASCII)
while (chunkType == "IDAT") {
idatRegions += DataRegion(streamOffset = scanner.position, length = chunkLength.toLong())
scanner.skipBytes(chunkLength.toLong() + 4) // data + CRC
chunkLength = scanner.readUInt()
chunkLength = scanner.readChunkLength()
chunkType = scanner.readBytes(4).toString(Charsets.US_ASCII)
}
@@ -230,7 +247,7 @@ class ApngDecoder private constructor(
// Scan forward to the next fcTL
while (chunkType != "fcTL") {
scanner.skipBytes(chunkLength.toLong() + 4) // data + CRC
chunkLength = scanner.readUInt()
chunkLength = scanner.readChunkLength()
chunkType = scanner.readBytes(4).toString(Charsets.US_ASCII)
if (chunkType == "IEND") break
}
@@ -245,17 +262,19 @@ class ApngDecoder private constructor(
// Collect all consecutive fdAT data regions -- frames can span multiple fdATs per the spec
val fdatRegions = mutableListOf<DataRegion>()
chunkLength = scanner.readUInt()
chunkLength = scanner.readChunkLength()
chunkType = scanner.readBytes(4).toString(Charsets.US_ASCII)
while (chunkType == "fdAT") {
// fdAT data starts with 4-byte sequence number, then the actual image data
if (chunkLength < 4u) {
throw IOException("fdAT chunk is too short: $chunkLength bytes")
}
scanner.skipBytes(4) // sequence number
val imageDataLength = chunkLength.toLong() - 4
fdatRegions += DataRegion(streamOffset = scanner.position, length = imageDataLength)
scanner.skipBytes(imageDataLength + 4) // image data + CRC
chunkLength = scanner.readUInt()
chunkLength = scanner.readChunkLength()
chunkType = scanner.readBytes(4).toString(Charsets.US_ASCII)
}
@@ -269,18 +288,23 @@ class ApngDecoder private constructor(
metadata = metadata,
frames = frames,
ihdr = ihdr,
prefixChunks = framePrefixChunks
prefixChunks = framePrefixChunks,
maxFrameDataSize = maxChunkLength
)
}
private fun isValidFrame(fctl: Chunk.fcTL, ihdr: Chunk.IHDR): Boolean {
return fctl.width in 1u..MAX_DIMENSION &&
fctl.height in 1u..MAX_DIMENSION &&
fctl.xOffset + fctl.width <= ihdr.width &&
fctl.yOffset + fctl.height <= ihdr.height
fctl.xOffset.toLong() + fctl.width.toLong() <= ihdr.width.toLong() &&
fctl.yOffset.toLong() + fctl.height.toLong() <= ihdr.height.toLong()
}
private fun parseFctl(data: ByteArray): Chunk.fcTL {
if (data.size < Chunk.fcTL.LENGTH.toInt()) {
throw IOException("fcTL chunk is too short: ${data.size} bytes")
}
return Chunk.fcTL(
sequenceNumber = data.sliceArray(0 until 4).toUInt(),
width = data.sliceArray(4 until 8).toUInt(),
@@ -340,8 +364,14 @@ class ApngDecoder private constructor(
currentStreamPos = targetOffset
}
// Read all data regions for this frame
val totalDataSize = regions.sumOf { it.length.toInt() }
// Read all data regions for this frame. Each region's length was bounded by maxFrameDataSize during the metadata
// scan, but a frame can span many chunks, so we also bound the combined total before allocating to avoid a malicious
// file driving an unbounded allocation here.
val totalDataSizeLong = regions.sumOf { it.length }
if (totalDataSizeLong > maxFrameDataSize) {
throw IOException("Frame $index data size ($totalDataSizeLong) exceeds the maximum allowed size of $maxFrameDataSize bytes!")
}
val totalDataSize = totalDataSizeLong.toInt()
val frameData = ByteArray(totalDataSize)
var writeOffset = 0
@@ -431,7 +461,7 @@ class ApngDecoder private constructor(
/**
* Tracks position while reading through a stream during the metadata scan.
*/
private class StreamScanner(private val inputStream: InputStream) {
private class StreamScanner(private val inputStream: InputStream, private val maxChunkLength: Long) {
var position: Long = 0
private set
@@ -445,6 +475,18 @@ class ApngDecoder private constructor(
return readBytes(4).toUInt()
}
/**
* Reads a chunk length and verifies it is within [maxChunkLength] before it can be used to size an allocation or
* skip.
*/
fun readChunkLength(): UInt {
val length = readUInt()
if (length.toLong() > maxChunkLength) {
throw IOException("Declared chunk length ($length) exceeds the maximum allowed size of $maxChunkLength bytes!")
}
return length
}
fun skipBytes(n: Long) {
inputStream.skipNBytesOrThrow(n)
position += n
@@ -497,7 +539,11 @@ class ApngDecoder private constructor(
class acTL(
val numFrames: UInt,
val numPlays: UInt
) : Chunk()
) : Chunk() {
companion object {
val LENGTH: UInt = 8.toUInt()
}
}
/**
* Contains metadata about a single frame of the animation. Appears before each fdAT chunk.
@@ -513,6 +559,10 @@ class ApngDecoder private constructor(
val disposeOp: DisposeOp,
val blendOp: BlendOp
) : Chunk() {
companion object {
val LENGTH: UInt = 26.toUInt()
}
/**
* Describes how you should dispose of this frame before rendering the next one. That means that in order to render the current frame, you need to know
* the [disposeOp] of the _previous_ frame.
@@ -12,6 +12,9 @@ import org.junit.Assert.assertTrue
import org.junit.Test
import org.junit.runner.RunWith
import org.robolectric.RobolectricTestRunner
import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import java.io.IOException
import java.io.InputStream
@RunWith(RobolectricTestRunner::class)
@@ -382,6 +385,63 @@ class ApngDecoderTest {
result.frames.forEachIndexed { i, _ -> assertNotNull(result.decoder.decodeFrame(i)) }
}
// -- Malicious / malformed input --
@Test(expected = IOException::class)
fun `create rejects chunk with oversized declared length instead of allocating`() {
val malicious = ByteArrayOutputStream().apply {
write(byteArrayOf(0x89.toByte(), 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A)) // PNG magic
write(byteArrayOf(0x7F, 0xFF.toByte(), 0xFF.toByte(), 0xFF.toByte())) // declared chunk length = 0x7FFFFFFF
write("IHDR".toByteArray(Charsets.US_ASCII))
}.toByteArray()
ApngDecoder.create { ByteArrayInputStream(malicious) }
}
@Test(expected = IOException::class)
fun `create rejects chunk longer than the known content length even when under the absolute ceiling`() {
val malicious = ByteArrayOutputStream().apply {
write(byteArrayOf(0x89.toByte(), 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A)) // PNG magic
write(byteArrayOf(0x00, 0x0F, 0x42, 0x40)) // declared chunk length = 1,000,000
write("IHDR".toByteArray(Charsets.US_ASCII))
}.toByteArray()
ApngDecoder.create(contentLength = 1024) { ByteArrayInputStream(malicious) }
}
// -- Bounds checking --
@Test(expected = IOException::class)
fun `create rejects APNG with oversized IHDR canvas dimensions`() {
ApngDecoder.create { ByteArrayInputStream(apngWithIhdrDimensions(0x7FFFFFFF, 0x7FFFFFFF)) }
}
@Test(expected = IOException::class)
fun `create rejects IHDR with a declared length other than 13`() {
val malicious = ByteArrayOutputStream().apply {
write(byteArrayOf(0x89.toByte(), 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A)) // PNG magic
write(byteArrayOf(0x00, 0x00, 0x00, 0x09)) // declared IHDR length = 9 (should be 13)
write("IHDR".toByteArray(Charsets.US_ASCII))
write(ByteArray(9))
write(byteArrayOf(0x00, 0x00, 0x00, 0x00)) // CRC
}.toByteArray()
ApngDecoder.create { ByteArrayInputStream(malicious) }
}
@Test(expected = IOException::class)
fun `create rejects acTL that is too short`() {
val malicious = ByteArrayOutputStream().apply {
write(apngWithIhdrDimensions(1, 1)) // PNG magic + valid 1x1 IHDR
write(byteArrayOf(0x00, 0x00, 0x00, 0x04)) // declared acTL length = 4 (should be 8)
write("acTL".toByteArray(Charsets.US_ASCII))
write(ByteArray(4))
write(byteArrayOf(0x00, 0x00, 0x00, 0x00)) // CRC
}.toByteArray()
ApngDecoder.create { ByteArrayInputStream(malicious) }
}
// -- Helpers --
private fun open(filename: String): InputStream {
@@ -394,6 +454,33 @@ class ApngDecoderTest {
return DecodeResult(decoder, decoder.metadata, decoder.frames)
}
/**
* Builds the leading bytes of an APNG (PNG magic + a single IHDR chunk) with the given canvas dimensions.
* Enough for the decoder to read and validate the IHDR before anything else.
*/
private fun apngWithIhdrDimensions(width: Int, height: Int): ByteArray {
val ihdrData = ByteArray(13).apply {
this[0] = (width ushr 24).toByte()
this[1] = (width ushr 16).toByte()
this[2] = (width ushr 8).toByte()
this[3] = width.toByte()
this[4] = (height ushr 24).toByte()
this[5] = (height ushr 16).toByte()
this[6] = (height ushr 8).toByte()
this[7] = height.toByte()
this[8] = 8 // bitDepth
this[9] = 6 // colorType
}
return ByteArrayOutputStream().apply {
write(byteArrayOf(0x89.toByte(), 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A)) // PNG magic
write(byteArrayOf(0x00, 0x00, 0x00, 0x0D)) // IHDR length = 13
write("IHDR".toByteArray(Charsets.US_ASCII))
write(ihdrData)
write(byteArrayOf(0x00, 0x00, 0x00, 0x00)) // CRC (not validated by the decoder)
}.toByteArray()
}
private val ApngDecoder.Frame.delayMs: Long
get() {
val delayNumerator = fcTL.delayNum.toInt()
@@ -26,6 +26,11 @@ interface InputStreamFactory {
fun create(): InputStream
fun createRecyclable(byteArrayPool: ArrayPool): InputStream = RecyclableBufferedInputStream(create(), byteArrayPool)
/**
* An upper bound on the number of bytes [create] will yield, if cheaply knowable, else null.
*/
fun length(): Long? = null
}
/**
@@ -46,4 +51,6 @@ class FileInputStreamFactory(
throw e
}
}
override fun length(): Long? = file.length().takeIf { it > 0 }
}