Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 13 additions & 22 deletions flate/deflate.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,13 @@ type compressionLevel struct {
// See https://blog.klauspost.com/rebalancing-deflate-compression-levels/
var levels = []compressionLevel{
{}, // 0
// Level 1-4 uses specialized algorithm - values not used
// Level 1-6 uses specialized algorithm - values not used
{0, 0, 0, 0, 0, 1},
{0, 0, 0, 0, 0, 2},
{0, 0, 0, 0, 0, 3},
{0, 0, 0, 0, 0, 4},
// For levels 5-6 we don't bother trying with lazy matches.
// Lazy matching is at least 30% slower, with 1.5% increase.
{6, 0, 12, 8, 12, 5},
{8, 0, 24, 16, 16, 6},
{0, 0, 0, 0, 0, 5},
{0, 0, 0, 0, 0, 6},
// Levels 7-9 use increasingly more lazy matching
// and increasingly stringent conditions for "good enough".
{8, 8, 24, 16, skipNever, 7},
Expand Down Expand Up @@ -203,9 +201,8 @@ func (d *compressor) writeBlockSkip(tok *tokens, index int, eof bool) error {
// This is much faster than doing a full encode.
// Should only be used after a start/reset.
func (d *compressor) fillWindow(b []byte) {
// Do not fill window if we are in store-only mode,
// use constant or Snappy compression.
if d.level == 0 {
// Do not fill window if we are in store-only or huffman mode.
if d.level <= 0 {
return
}
if d.fast != nil {
Expand Down Expand Up @@ -667,6 +664,7 @@ func (d *compressor) init(w io.Writer, level int) (err error) {
default:
return fmt.Errorf("flate: invalid compression level %d: want value in range [-2, 9]", level)
}
d.level = level
return nil
}

Expand Down Expand Up @@ -720,6 +718,7 @@ func (d *compressor) close() error {
return d.w.err
}
d.w.flush()
d.w.reset(nil)
return d.w.err
}

Expand Down Expand Up @@ -750,8 +749,7 @@ func NewWriter(w io.Writer, level int) (*Writer, error) {
// can only be decompressed by a Reader initialized with the
// same dictionary.
func NewWriterDict(w io.Writer, level int, dict []byte) (*Writer, error) {
dw := &dictWriter{w}
zw, err := NewWriter(dw, level)
zw, err := NewWriter(w, level)
if err != nil {
return nil, err
}
Expand All @@ -760,14 +758,6 @@ func NewWriterDict(w io.Writer, level int, dict []byte) (*Writer, error) {
return zw, err
}

type dictWriter struct {
w io.Writer
}

func (w *dictWriter) Write(b []byte) (n int, err error) {
return w.w.Write(b)
}

// A Writer takes data written to it and writes the compressed
// form of that data to an underlying writer (see NewWriter).
type Writer struct {
Expand Down Expand Up @@ -805,11 +795,12 @@ func (w *Writer) Close() error {
// the result of NewWriter or NewWriterDict called with dst
// and w's level and dictionary.
func (w *Writer) Reset(dst io.Writer) {
if dw, ok := w.d.w.writer.(*dictWriter); ok {
if len(w.dict) > 0 {
// w was created with NewWriterDict
dw.w = dst
w.d.reset(dw)
w.d.fillWindow(w.dict)
w.d.reset(dst)
if dst != nil {
w.d.fillWindow(w.dict)
}
} else {
// w was created with NewWriter
w.d.reset(dst)
Expand Down
95 changes: 53 additions & 42 deletions flate/deflate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -516,54 +516,65 @@ func TestWriterReset(t *testing.T) {
t.Errorf("level %d Writer not reset after Reset", level)
}
}
testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriter(w, NoCompression) })
testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriter(w, DefaultCompression) })
testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriter(w, BestCompression) })
testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriter(w, ConstantCompression) })
dict := []byte("we are the world")
testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriterDict(w, NoCompression, dict) })
testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriterDict(w, DefaultCompression, dict) })
testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriterDict(w, BestCompression, dict) })
testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriterDict(w, ConstantCompression, dict) })
}

func testResetOutput(t *testing.T, newWriter func(w io.Writer) (*Writer, error)) {
buf := new(bytes.Buffer)
w, err := newWriter(buf)
if err != nil {
t.Fatalf("NewWriter: %v", err)
}
b := []byte("hello world")
for i := 0; i < 1024; i++ {
w.Write(b)
}
w.Close()
out1 := buf.Bytes()

buf2 := new(bytes.Buffer)
w.Reset(buf2)
for i := 0; i < 1024; i++ {
w.Write(b)
for i := HuffmanOnly; i <= BestCompression; i++ {
testResetOutput(t, fmt.Sprint("level-", i), func(w io.Writer) (*Writer, error) { return NewWriter(w, i) })
}
w.Close()
out2 := buf2.Bytes()

if len(out1) != len(out2) {
t.Errorf("got %d, expected %d bytes", len(out2), len(out1))
dict := []byte(strings.Repeat("we are the world - how are you?", 3))
for i := HuffmanOnly; i <= BestCompression; i++ {
testResetOutput(t, fmt.Sprint("dict-level-", i), func(w io.Writer) (*Writer, error) { return NewWriterDict(w, i, dict) })
}
if bytes.Compare(out1, out2) != 0 {
mm := 0
for i, b := range out1[:len(out2)] {
if b != out2[i] {
t.Errorf("mismatch index %d: %02x, expected %02x", i, out2[i], b)
for i := HuffmanOnly; i <= BestCompression; i++ {
testResetOutput(t, fmt.Sprint("dict-reset-level-", i), func(w io.Writer) (*Writer, error) {
w2, err := NewWriter(nil, i)
if err != nil {
return w2, err
}
mm++
if mm == 10 {
t.Fatal("Stopping")
w2.ResetDict(w, dict)
return w2, nil
})
}
}

func testResetOutput(t *testing.T, name string, newWriter func(w io.Writer) (*Writer, error)) {
t.Run(name, func(t *testing.T) {
buf := new(bytes.Buffer)
w, err := newWriter(buf)
if err != nil {
t.Fatalf("NewWriter: %v", err)
}
b := []byte("hello world - how are you doing?")
for i := 0; i < 1024; i++ {
w.Write(b)
}
w.Close()
out1 := buf.Bytes()

buf2 := new(bytes.Buffer)
w.Reset(buf2)
for i := 0; i < 1024; i++ {
w.Write(b)
}
w.Close()
out2 := buf2.Bytes()

if len(out1) != len(out2) {
t.Errorf("got %d, expected %d bytes", len(out2), len(out1))
}
if bytes.Compare(out1, out2) != 0 {
mm := 0
for i, b := range out1[:len(out2)] {
if b != out2[i] {
t.Errorf("mismatch index %d: %02x, expected %02x", i, out2[i], b)
}
mm++
if mm == 10 {
t.Fatal("Stopping")
}
}
}
}
t.Logf("got %d bytes", len(out1))
t.Logf("got %d bytes", len(out1))
})
}

// TestBestSpeed tests that round-tripping through deflate and then inflate
Expand Down