Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

zstd: Fix amd64 not always detecting corrupt data #785

Merged
merged 2 commits into from
Mar 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions internal/fuzz/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ func AddFromZip(f *testing.F, filename string, t InputType, short bool) {
t = TypeRaw // Fallback
if len(b) >= 4 {
sz := binary.BigEndian.Uint32(b)
if sz == uint32(len(b))-4 {
f.Add(b[4:])
if sz <= uint32(len(b))-4 {
f.Add(b[4 : 4+sz])
continue
}
}
Expand Down
22 changes: 18 additions & 4 deletions zstd/_generate/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ const errorNotEnoughLiterals = 4
// error reported when capacity of `out` is too small
const errorNotEnoughSpace = 5

// error reported when bits are overread.
const errorOverread = 6

const maxMatchLen = 131074

// size of struct seqVals
Expand Down Expand Up @@ -247,8 +250,9 @@ func (o options) generateBody(name string, executeSingleTriple func(ctx *execute
{
brPointer := GP64()
MOVQ(brPointerStash, brPointer)

Comment("Fill bitreader to have enough for the offset and match length.")
o.bitreaderFill(name+"_fill", brValue, brBitsRead, brOffset, brPointer)
o.bitreaderFill(name+"_fill", brValue, brBitsRead, brOffset, brPointer, LabelRef("error_overread"))

Comment("Update offset")
// Up to 32 extra bits
Expand All @@ -261,7 +265,7 @@ func (o options) generateBody(name string, executeSingleTriple func(ctx *execute
// If we need more than 56 in total, we must refill here.
if !o.fiftysix {
Comment("Fill bitreader to have enough for the remaining")
o.bitreaderFill(name+"_fill_2", brValue, brBitsRead, brOffset, brPointer)
o.bitreaderFill(name+"_fill_2", brValue, brBitsRead, brOffset, brPointer, LabelRef("error_overread"))
}

Comment("Update literal length")
Expand Down Expand Up @@ -502,6 +506,12 @@ func (o options) generateBody(name string, executeSingleTriple func(ctx *execute
o.returnWithCode(errorNotEnoughLiterals)
}

Comment("Return with overread error")
{
Label("error_overread")
o.returnWithCode(errorOverread)
}

if !o.useSeqs {
Comment("Return with not enough output space error")
Label("error_not_enough_space")
Expand Down Expand Up @@ -529,7 +539,7 @@ func (o options) returnWithCode(returnCode uint32) {
}

// bitreaderFill will make sure at least 56 bits are available.
func (o options) bitreaderFill(name string, brValue, brBitsRead, brOffset, brPointer reg.GPVirtual) {
func (o options) bitreaderFill(name string, brValue, brBitsRead, brOffset, brPointer reg.GPVirtual, overread LabelRef) {
// bitreader_fill begin
CMPQ(brOffset, U8(8)) // b.off >= 8
JL(LabelRef(name + "_byte_by_byte"))
Expand All @@ -545,7 +555,7 @@ func (o options) bitreaderFill(name string, brValue, brBitsRead, brOffset, brPoi

Label(name + "_byte_by_byte")
CMPQ(brOffset, U8(0)) /* for b.off > 0 */
JLE(LabelRef(name + "_end"))
JLE(LabelRef(name + "_check_overread"))

CMPQ(brBitsRead, U8(7)) /* for brBitsRead > 7 */
JLE(LabelRef(name + "_end"))
Expand All @@ -565,6 +575,10 @@ func (o options) bitreaderFill(name string, brValue, brBitsRead, brOffset, brPoi
}
JMP(LabelRef(name + "_byte_by_byte"))

Label(name + "_check_overread")
CMPQ(brBitsRead, U8(64))
JA(overread)

Label(name + "_end")
}

Expand Down
12 changes: 10 additions & 2 deletions zstd/fuzz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ func FuzzDecAllNoBMI2(f *testing.F) {
func FuzzDecoder(f *testing.F) {
fuzz.AddFromZip(f, "testdata/fuzz/decode-corpus-raw.zip", fuzz.TypeRaw, testing.Short())
fuzz.AddFromZip(f, "testdata/fuzz/decode-corpus-encoded.zip", fuzz.TypeGoFuzz, testing.Short())
//fuzz.AddFromZip(f, "testdata/fuzz/decode-oss.zip", fuzz.TypeOSSFuzz, false)

brLow := newBytesReader(nil)
brHi := newBytesReader(nil)
Expand All @@ -92,18 +93,25 @@ func FuzzDecoder(f *testing.F) {
}
defer decHi.Close()

if debugDecoder {
fmt.Println("LOW CONCURRENT")
}
b1, err1 := io.ReadAll(decLow)

if debugDecoder {
fmt.Println("HI NOT CONCURRENT")
}
b2, err2 := io.ReadAll(decHi)
if err1 != err2 {
if (err1 == nil) != (err2 == nil) {
t.Errorf("err low: %v, hi: %v", err1, err2)
t.Errorf("err low concurrent: %v, hi: %v", err1, err2)
}
}
if err1 != nil {
b1, b2 = b1[:0], b2[:0]
}
if !bytes.Equal(b1, b2) {
t.Fatalf("Output mismatch, low: %v, hi: %v", err1, err2)
t.Fatalf("Output mismatch, low concurrent: %v, hi: %v", err1, err2)
}
})
}
Expand Down
5 changes: 4 additions & 1 deletion zstd/seqdec.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,12 @@ func (s *sequenceDecs) decodeSync(hist []byte) error {
maxBlockSize = s.windowSize
}

if debugDecoder {
println("decodeSync: decoding", seqs, "sequences", br.remain(), "bits remain on stream")
}
for i := seqs - 1; i >= 0; i-- {
if br.overread() {
printf("reading sequence %d, exceeded available data\n", seqs-i)
printf("reading sequence %d, exceeded available data. Overread by %d\n", seqs-i, -br.remain())
return io.ErrUnexpectedEOF
}
var ll, mo, ml int
Expand Down
16 changes: 16 additions & 0 deletions zstd/seqdec_amd64.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package zstd

import (
"fmt"
"io"

"github.com/klauspost/compress/internal/cpuinfo"
)
Expand Down Expand Up @@ -134,6 +135,9 @@ func (s *sequenceDecs) decodeSyncSimple(hist []byte) (bool, error) {
return true, fmt.Errorf("unexpected literal count, want %d bytes, but only %d is available",
ctx.ll, ctx.litRemain+ctx.ll)

case errorOverread:
return true, io.ErrUnexpectedEOF

case errorNotEnoughSpace:
size := ctx.outPosition + ctx.ll + ctx.ml
if debugDecoder {
Expand Down Expand Up @@ -202,6 +206,9 @@ const errorNotEnoughLiterals = 4
// error reported when capacity of `out` is too small
const errorNotEnoughSpace = 5

// error reported when bits are overread.
const errorOverread = 6

// sequenceDecs_decode implements the main loop of sequenceDecs in x86 asm.
//
// Please refer to seqdec_generic.go for the reference implementation.
Expand Down Expand Up @@ -247,6 +254,10 @@ func (s *sequenceDecs) decode(seqs []seqVals) error {
litRemain: len(s.literals),
}

if debugDecoder {
println("decode: decoding", len(seqs), "sequences", br.remain(), "bits remain on stream")
}

s.seqSize = 0
lte56bits := s.maxBits+s.offsets.fse.actualTableLog+s.matchLengths.fse.actualTableLog+s.litLengths.fse.actualTableLog <= 56
var errCode int
Expand Down Expand Up @@ -277,6 +288,8 @@ func (s *sequenceDecs) decode(seqs []seqVals) error {
case errorNotEnoughLiterals:
ll := ctx.seqs[i].ll
return fmt.Errorf("unexpected literal count, want %d bytes, but only %d is available", ll, ctx.litRemain+ll)
case errorOverread:
return io.ErrUnexpectedEOF
}

return fmt.Errorf("sequenceDecs_decode_amd64 returned erronous code %d", errCode)
Expand All @@ -291,6 +304,9 @@ func (s *sequenceDecs) decode(seqs []seqVals) error {
if s.seqSize > maxBlockSize {
return fmt.Errorf("output bigger than max block size (%d)", maxBlockSize)
}
if debugDecoder {
println("decode: ", br.remain(), "bits remain on stream. code:", errCode)
}
err := br.close()
if err != nil {
printf("Closing sequences: %v, %+v\n", err, *br)
Expand Down
Loading