Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

zstd: Fix reuse of huff0 when data hard to compress #173

Merged
merged 2 commits into from
Oct 24, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions huff0/compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ func compress(in []byte, s *Scratch, compressor func(src []byte) ([]byte, error)
canReuse = s.canUseTable(s.prevTable)
}

// We want the output size to be less than this:
wantSize := len(in)
if s.WantLogLess > 0 {
wantSize -= wantSize >> s.WantLogLess
}

// Reset for next run.
s.clearCount = true
s.maxCount = 0
Expand All @@ -77,7 +83,7 @@ func compress(in []byte, s *Scratch, compressor func(src []byte) ([]byte, error)
s.cTable = s.prevTable
s.Out, err = compressor(in)
s.cTable = keepTable
if err == nil && len(s.Out) < len(in) {
if err == nil && len(s.Out) < wantSize {
s.OutData = s.Out
return s.Out, true, nil
}
Expand All @@ -100,16 +106,16 @@ func compress(in []byte, s *Scratch, compressor func(src []byte) ([]byte, error)
hSize := len(s.Out)
oldSize := s.prevTable.estimateSize(s.count[:s.symbolLen])
newSize := s.cTable.estimateSize(s.count[:s.symbolLen])
if oldSize <= hSize+newSize || hSize+12 >= len(in) {
if oldSize <= hSize+newSize || hSize+12 >= wantSize {
// Retain cTable even if we re-use.
keepTable := s.cTable
s.cTable = s.prevTable
s.Out, err = compressor(in)
s.cTable = keepTable
if err != nil {
return nil, false, err
}
s.cTable = keepTable
if len(s.Out) >= len(in) {
if len(s.Out) >= wantSize {
return nil, false, ErrIncompressible
}
s.OutData = s.Out
Expand All @@ -131,7 +137,7 @@ func compress(in []byte, s *Scratch, compressor func(src []byte) ([]byte, error)
s.OutTable = nil
return nil, false, err
}
if len(s.Out) >= len(in) {
if len(s.Out) >= wantSize {
s.OutTable = nil
return nil, false, ErrIncompressible
}
Expand Down
49 changes: 33 additions & 16 deletions huff0/compress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,51 +91,60 @@ func init() {
func TestCompressRegression(t *testing.T) {
// Match the fuzz function
var testInput = func(data []byte) int {
var sc Scratch
comp, _, err := Compress1X(data, &sc)
var enc Scratch
enc.WantLogLess = 5
comp, _, err := Compress1X(data, &enc)
if err == ErrIncompressible || err == ErrUseRLE || err == ErrTooBig {
return 0
}
if err != nil {
panic(err)
}
s, remain, err := ReadTable(comp, nil)
if len(comp) >= len(data)-len(data)>>enc.WantLogLess {
panic(fmt.Errorf("too large output provided. got %d, but should be < %d", len(comp), len(data)-len(data)>>enc.WantLogLess))
}

dec, remain, err := ReadTable(comp, nil)
if err != nil {
panic(err)
}
out, err := s.Decompress1X(remain)
out, err := dec.Decompress1X(remain)
if err != nil {
panic(err)
}
if !bytes.Equal(out, data) {
panic("decompression 1x mismatch")
}
// Reuse as 4X
sc.Reuse = ReusePolicyAllow
comp, reUsed, err := Compress4X(data, &sc)
enc.Reuse = ReusePolicyAllow
comp, reUsed, err := Compress4X(data, &enc)
if err == ErrIncompressible || err == ErrUseRLE || err == ErrTooBig {
return 0
}
if err != nil {
panic(err)
}
if len(comp) >= len(data)-len(data)>>enc.WantLogLess {
panic(fmt.Errorf("too large output provided. got %d, but should be < %d", len(comp), len(data)-len(data)>>enc.WantLogLess))
}

remain = comp
if !reUsed {
s, remain, err = ReadTable(comp, s)
dec, remain, err = ReadTable(comp, dec)
if err != nil {
panic(err)
}
}
out, err = s.Decompress4X(remain, len(data))
out, err = dec.Decompress4X(remain, len(data))
if err != nil {
panic(err)
}
if !bytes.Equal(out, data) {
panic("decompression 4x with reuse mismatch")
}

s.Reuse = ReusePolicyNone
comp, reUsed, err = Compress4X(data, s)
enc.Reuse = ReusePolicyNone
comp, reUsed, err = Compress4X(data, &enc)
if err == ErrIncompressible || err == ErrUseRLE || err == ErrTooBig {
return 0
}
Expand All @@ -145,11 +154,15 @@ func TestCompressRegression(t *testing.T) {
if reUsed {
panic("reused when asked not to")
}
s, remain, err = ReadTable(comp, nil)
if len(comp) >= len(data)-len(data)>>enc.WantLogLess {
panic(fmt.Errorf("too large output provided. got %d, but should be < %d", len(comp), len(data)-len(data)>>enc.WantLogLess))
}

dec, remain, err = ReadTable(comp, dec)
if err != nil {
panic(err)
}
out, err = s.Decompress4X(remain, len(data))
out, err = dec.Decompress4X(remain, len(data))
if err != nil {
panic(err)
}
Expand All @@ -158,22 +171,26 @@ func TestCompressRegression(t *testing.T) {
}

// Reuse as 1X
s.Reuse = ReusePolicyAllow
comp, reUsed, err = Compress1X(data, &sc)
dec.Reuse = ReusePolicyAllow
comp, reUsed, err = Compress1X(data, &enc)
if err == ErrIncompressible || err == ErrUseRLE || err == ErrTooBig {
return 0
}
if err != nil {
panic(err)
}
if len(comp) >= len(data)-len(data)>>enc.WantLogLess {
panic(fmt.Errorf("too large output provided. got %d, but should be < %d", len(comp), len(data)-len(data)>>enc.WantLogLess))
}

remain = comp
if !reUsed {
s, remain, err = ReadTable(comp, s)
dec, remain, err = ReadTable(comp, dec)
if err != nil {
panic(err)
}
}
out, err = s.Decompress1X(remain)
out, err = dec.Decompress1X(remain)
if err != nil {
panic(err)
}
Expand Down
6 changes: 6 additions & 0 deletions huff0/huff0.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@ type Scratch struct {
// Reuse will specify the reuse policy
Reuse ReusePolicy

// WantLogLess allows specifying a log 2 reduction that should at least be achieved,
// otherwise the block will be returned as incompressible.
// The reduction should then at least be (input size >> WantLogLess)
// If WantLogLess == 0 any improvement will do.
WantLogLess uint8

// MaxDecodedSize will set the maximum allowed output size.
// This value will automatically be set to BlockSizeMax if not set.
// Decoders will return ErrMaxDecodedSizeExceeded if this limit is exceeded.
Expand Down
10 changes: 2 additions & 8 deletions zstd/blockenc.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func (b *blockEnc) init() {
b.coders.llEnc = &fseEncoder{}
b.coders.llPrev = &fseEncoder{}
}
b.litEnc = &huff0.Scratch{}
b.litEnc = &huff0.Scratch{WantLogLess: 4}
b.reset(nil)
}

Expand Down Expand Up @@ -415,16 +415,10 @@ func (b *blockEnc) encode() error {
if len(b.literals) >= 1024 {
// Use 4 Streams.
out, reUsed, err = huff0.Compress4X(b.literals, b.litEnc)
if len(out) > len(b.literals)-len(b.literals)>>4 {
err = huff0.ErrIncompressible
}
} else if len(b.literals) > 32 {
// Use 1 stream
single = true
out, reUsed, err = huff0.Compress1X(b.literals, b.litEnc)
if len(out) > len(b.literals)-len(b.literals)>>4 {
err = huff0.ErrIncompressible
}
} else {
err = huff0.ErrIncompressible
}
Expand Down Expand Up @@ -711,7 +705,7 @@ func (b *blockEnc) encode() error {
return nil
}

var errIncompressible = errors.New("uncompressible")
var errIncompressible = errors.New("incompressible")

func (b *blockEnc) genCodes() {
if len(b.sequences) == 0 {
Expand Down
6 changes: 3 additions & 3 deletions zstd/enc_dfast.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ encodeLoop:
if debug && s-t > e.maxMatchOff {
panic("s - t >e.maxMatchOff")
}
if debug {
if debugMatches {
println("long match")
}
break
Expand All @@ -259,7 +259,7 @@ encodeLoop:
// but the likelihood of both the first 4 bytes and the hash matching should be enough.
t = candidateL.offset - e.cur
s += checkAt
if debug {
if debugMatches {
println("long match (after short)")
}
break
Expand All @@ -275,7 +275,7 @@ encodeLoop:
if debug && t < 0 {
panic("t<0")
}
if debug {
if debugMatches {
println("short match")
}
break
Expand Down
1 change: 1 addition & 0 deletions zstd/zstd.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

const debug = false
const debugSequences = false
const debugMatches = false

// force encoder to use predefined tables.
const forcePreDef = false
Expand Down