diff --git a/src/main/java/io/airlift/compress/lz4/Lz4RawDecompressor.java b/src/main/java/io/airlift/compress/lz4/Lz4RawDecompressor.java index 60183d73..7993ad34 100644 --- a/src/main/java/io/airlift/compress/lz4/Lz4RawDecompressor.java +++ b/src/main/java/io/airlift/compress/lz4/Lz4RawDecompressor.java @@ -69,6 +69,9 @@ public static int decompress( } while (value == 255 && input < inputLimit - 15); } + if (literalLength < 0) { + throw new MalformedInputException(input - inputAddress); + } // copy literal long literalEnd = input + literalLength; @@ -127,6 +130,9 @@ public static int decompress( while (value == 255); } matchLength += MIN_MATCH; // implicit length from initial 4-byte match in encoder + if (matchLength < 0) { + throw new MalformedInputException(input - inputAddress); + } long matchOutputLimit = output + matchLength; diff --git a/src/main/java/io/airlift/compress/lzo/LzoRawDecompressor.java b/src/main/java/io/airlift/compress/lzo/LzoRawDecompressor.java index bb30c2d1..2a8dcc4f 100644 --- a/src/main/java/io/airlift/compress/lzo/LzoRawDecompressor.java +++ b/src/main/java/io/airlift/compress/lzo/LzoRawDecompressor.java @@ -248,6 +248,10 @@ else if ((command & 0b1100_0000) != 0) { } firstCommand = false; + if (matchLength < 0) { + throw new MalformedInputException(input - inputAddress); + } + // copy match if (matchLength != 0) { // lzo encodes match offset minus one @@ -316,6 +320,9 @@ else if ((command & 0b1100_0000) != 0) { } // copy literal + if (literalLength < 0) { + throw new MalformedInputException(input - inputAddress); + } long literalOutputLimit = output + literalLength; if (literalOutputLimit > fastOutputLimit || input + literalLength > inputLimit - SIZE_OF_LONG) { if (literalOutputLimit > outputLimit) { diff --git a/src/main/java/io/airlift/compress/snappy/SnappyRawDecompressor.java b/src/main/java/io/airlift/compress/snappy/SnappyRawDecompressor.java index 7cca9afa..2871b9a0 100644 --- 
a/src/main/java/io/airlift/compress/snappy/SnappyRawDecompressor.java +++ b/src/main/java/io/airlift/compress/snappy/SnappyRawDecompressor.java @@ -116,6 +116,9 @@ private static int uncompressAll( if ((opCode & 0x3) == LITERAL) { int literalLength = length + trailer; + if (literalLength < 0) { + throw new MalformedInputException(input - inputAddress); + } // copy literal long literalOutputLimit = output + literalLength; @@ -147,6 +150,9 @@ private static int uncompressAll( // bit 8). int matchOffset = entry & 0x700; matchOffset += trailer; + if (matchOffset < 0) { + throw new MalformedInputException(input - inputAddress); + } long matchAddress = output - matchOffset; if (matchAddress < outputAddress || output + length > outputLimit) { diff --git a/src/main/java/io/airlift/compress/zstd/Huffman.java b/src/main/java/io/airlift/compress/zstd/Huffman.java index b5e24e20..29fb74e3 100644 --- a/src/main/java/io/airlift/compress/zstd/Huffman.java +++ b/src/main/java/io/airlift/compress/zstd/Huffman.java @@ -172,6 +172,8 @@ public void decode4Streams(final Object inputBase, final long inputAddress, fina long start3 = start2 + (UNSAFE.getShort(inputBase, inputAddress + 2) & 0xFFFF); long start4 = start3 + (UNSAFE.getShort(inputBase, inputAddress + 4) & 0xFFFF); + verify(start2 < start3 && start3 < start4 && start4 < inputLimit, inputAddress, "Input is corrupted"); + BitInputStream.Initializer initializer = new BitInputStream.Initializer(inputBase, start1, start2); initializer.initialize(); int stream1bitsConsumed = initializer.getBitsConsumed(); diff --git a/src/test/java/io/airlift/compress/lz4/TestLz4.java b/src/test/java/io/airlift/compress/lz4/TestLz4.java index 3f9a13ec..03a6964a 100644 --- a/src/test/java/io/airlift/compress/lz4/TestLz4.java +++ b/src/test/java/io/airlift/compress/lz4/TestLz4.java @@ -16,9 +16,17 @@ import io.airlift.compress.AbstractTestCompression; import io.airlift.compress.Compressor; import io.airlift.compress.Decompressor; +import 
io.airlift.compress.MalformedInputException; import io.airlift.compress.thirdparty.JPountzLz4Compressor; import io.airlift.compress.thirdparty.JPountzLz4Decompressor; import net.jpountz.lz4.LZ4Factory; +import org.testng.annotations.Test; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.Arrays; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; public class TestLz4 extends AbstractTestCompression @@ -46,4 +54,45 @@ protected Decompressor getVerifyDecompressor() { return new JPountzLz4Decompressor(LZ4Factory.fastestInstance()); } + + @Test + public void testLiteralLengthOverflow() + throws IOException + { + ByteArrayOutputStream buffer = new ByteArrayOutputStream(); + buffer.write((byte) 0b1111_0000); // token + // Causes overflow for `literalLength` + byte[] literalLengthBytes = new byte[Integer.MAX_VALUE / 255 + 1]; // ~8.4MB + Arrays.fill(literalLengthBytes, (byte) 255); + buffer.write(literalLengthBytes); + buffer.write(1); + buffer.write(new byte[20]); + + byte[] data = buffer.toByteArray(); + + assertThatThrownBy(() -> new Lz4Decompressor().decompress(data, 0, data.length, new byte[2048], 0, 2048)) + .isInstanceOf(MalformedInputException.class); + } + + @Test + public void testMatchLengthOverflow() + throws IOException + { + ByteArrayOutputStream buffer = new ByteArrayOutputStream(); + buffer.write((byte) 0b0000_1111); // token + buffer.write(new byte[2]); // offset + + // Causes overflow for `matchLength` + byte[] matchLengthBytes = new byte[Integer.MAX_VALUE / 255 + 1]; // ~8.4MB + Arrays.fill(matchLengthBytes, (byte) 255); + buffer.write(matchLengthBytes); + buffer.write(1); + + buffer.write(new byte[10]); + + byte[] data = buffer.toByteArray(); + + assertThatThrownBy(() -> new Lz4Decompressor().decompress(data, 0, data.length, new byte[2048], 0, 2048)) + .isInstanceOf(MalformedInputException.class); + } } diff --git a/src/test/java/io/airlift/compress/lzo/TestLzo.java 
b/src/test/java/io/airlift/compress/lzo/TestLzo.java index b7dfe6f1..9aaa1a21 100644 --- a/src/test/java/io/airlift/compress/lzo/TestLzo.java +++ b/src/test/java/io/airlift/compress/lzo/TestLzo.java @@ -17,8 +17,15 @@ import io.airlift.compress.Compressor; import io.airlift.compress.Decompressor; import io.airlift.compress.HadoopNative; +import io.airlift.compress.MalformedInputException; import io.airlift.compress.thirdparty.HadoopLzoCompressor; import io.airlift.compress.thirdparty.HadoopLzoDecompressor; +import org.testng.annotations.Test; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; public class TestLzo extends AbstractTestCompression @@ -50,4 +57,78 @@ protected Decompressor getVerifyDecompressor() { return new HadoopLzoDecompressor(); } + + @Test + public void testLiteralLengthOverflow() + throws IOException + { + ByteArrayOutputStream buffer = new ByteArrayOutputStream(); + // Command + buffer.write(0); + // Causes overflow for `literalLength` + buffer.write(new byte[Integer.MAX_VALUE / 255 + 1]); // ~8.4MB + buffer.write(1); + + byte[] data = buffer.toByteArray(); + + assertThatThrownBy(() -> new LzoDecompressor().decompress(data, 0, data.length, new byte[20000], 0, 20000)) + .isInstanceOf(MalformedInputException.class); + } + + @Test + public void testMatchLengthOverflow1() + throws IOException + { + ByteArrayOutputStream buffer = new ByteArrayOutputStream(); + // Write some data so that `matchOffset` validation later passes + // Command + buffer.write(0); + buffer.write(new byte[66]); + buffer.write(8); + buffer.write(new byte[2107 * 8]); + + // Command + buffer.write(0b0001_0000); + // Causes overflow for `matchLength` + buffer.write(new byte[Integer.MAX_VALUE / 255 + 1]); // ~8.4MB + buffer.write(1); + // Trailer + buffer.write(0b0000_0000); + buffer.write(0b0000_0100); + + buffer.write(new byte[10]); + + byte[] data = buffer.toByteArray(); + + assertThatThrownBy(() 
-> new LzoDecompressor().decompress(data, 0, data.length, new byte[20000], 0, 20000)) + .isInstanceOf(MalformedInputException.class); + } + + @Test + public void testMatchLengthOverflow2() + throws IOException + { + ByteArrayOutputStream buffer = new ByteArrayOutputStream(); + // Write some data so that `matchOffset` validation later passes + // Command + buffer.write(0); + buffer.write(246); + buffer.write(new byte[264]); + + // Command + buffer.write(0b0010_0000); + // Causes overflow for `matchLength` + buffer.write(new byte[Integer.MAX_VALUE / 255 + 1]); // ~8.4MB + buffer.write(1); + // Trailer + buffer.write(0b0000_0000); + buffer.write(0b0000_0100); + + buffer.write(new byte[10]); + + byte[] data = buffer.toByteArray(); + + assertThatThrownBy(() -> new LzoDecompressor().decompress(data, 0, data.length, new byte[20000], 0, 20000)) + .isInstanceOf(MalformedInputException.class); + } } diff --git a/src/test/java/io/airlift/compress/snappy/TestSnappy.java b/src/test/java/io/airlift/compress/snappy/TestSnappy.java index 91c14cb5..b2e92d79 100644 --- a/src/test/java/io/airlift/compress/snappy/TestSnappy.java +++ b/src/test/java/io/airlift/compress/snappy/TestSnappy.java @@ -16,8 +16,12 @@ import io.airlift.compress.AbstractTestCompression; import io.airlift.compress.Compressor; import io.airlift.compress.Decompressor; +import io.airlift.compress.MalformedInputException; import io.airlift.compress.thirdparty.XerialSnappyCompressor; import io.airlift.compress.thirdparty.XerialSnappyDecompressor; +import org.testng.annotations.Test; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; public class TestSnappy extends AbstractTestCompression @@ -45,4 +49,22 @@ protected Decompressor getVerifyDecompressor() { return new XerialSnappyDecompressor(); } + + @Test + public void testInvalidLiteralLength() + { + byte[] data = { + // Encoded uncompressed length 1024 + -128, 8, + // op-code + (byte) 252, + // Trailer value Integer.MAX_VALUE + (byte) 0b1111_1111, 
(byte) 0b1111_1111, (byte) 0b1111_1111, (byte) 0b0111_1111, + // Some arbitrary data + 0, 0, 0, 0, 0, 0, 0, 0 + }; + + assertThatThrownBy(() -> new SnappyDecompressor().decompress(data, 0, data.length, new byte[1024], 0, 1024)) + .isInstanceOf(MalformedInputException.class); + } } diff --git a/src/test/java/io/airlift/compress/zstd/TestZstd.java b/src/test/java/io/airlift/compress/zstd/TestZstd.java index 8c93b46a..ba3dd22f 100644 --- a/src/test/java/io/airlift/compress/zstd/TestZstd.java +++ b/src/test/java/io/airlift/compress/zstd/TestZstd.java @@ -23,6 +23,7 @@ import io.airlift.compress.thirdparty.ZstdJniDecompressor; import org.testng.annotations.Test; +import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.UncheckedIOException; import java.util.Arrays; @@ -209,4 +210,54 @@ public void testDecompressIsMissingData() .matches(e -> e instanceof MalformedInputException || e instanceof UncheckedIOException) .hasMessageContaining("Not enough input bytes"); } + + @Test + public void testBadHuffmanData() + throws IOException + { + ByteArrayOutputStream buffer = new ByteArrayOutputStream(); + // Magic + buffer.write(new byte[] { + (byte) 0b0010_1000, + (byte) 0b1011_0101, + (byte) 0b0010_1111, + (byte) 0b1111_1101, + }); + // Frame header + buffer.write(0); + buffer.write(0); + // Block header COMPRESSED_BLOCK + buffer.write(new byte[] { + (byte) 0b1111_0100, + (byte) 0b0000_0000, + (byte) 0b0000_0000, + }); + // Literals header + buffer.write(new byte[] { + // literalsBlockType COMPRESSED_LITERALS_BLOCK + // + literals type + 0b0000_1010, + // ... 
header remainder + 0b0000_0000, + // compressedSize + 0b0011_1100, + 0b0000_0000, + }); + // Huffman inputSize + buffer.write(128); + // weight value + buffer.write(0b0001_0000); + // Bad start values + buffer.write(new byte[] {(byte) 255, (byte) 255}); + buffer.write(new byte[] {(byte) 255, (byte) 255}); + buffer.write(new byte[] {(byte) 255, (byte) 255}); + + buffer.write(new byte[10]); + + byte[] data = buffer.toByteArray(); + + assertThatThrownBy(() -> new ZstdDecompressor().decompress(data, 0, data.length, new byte[10], 0, 10)) + .isInstanceOf(MalformedInputException.class) + .hasMessageStartingWith("Input is corrupted"); + } }