Skip to content

Commit

Permalink
Refactor zstd decoder (#498)
Browse files Browse the repository at this point in the history
# TLDR;

* Streams can now be decoded without goroutines using `WithDecoderConcurrency(1)`.

* `WithDecoderConcurrency(4)` is now default. If you need more concurrent `DecodeAll` operations, use `WithDecoderConcurrency(0)`.

Goroutines exit when streams have finished reading (either error or EOF).

Designed and tested to be compatible, but test before committing upgrade.

# Changes

Goroutines will now only be created on demand, and `WithDecoderConcurrency(1)` is now strictly synchronized.

Decompression will typically be about 2x faster when using multiple goroutines, and will prepare input for the upstream reader async to reads. This can lead to ~3x faster input in total than using no goroutines.

New default is now `WithDecoderConcurrency(4)` (or less, if GOMAXPROCS is less). Beyond 4, there is little benefit for streaming decompression.

* No goroutines created, unless streaming, and auto-closed at error/EOF.
* Synchronous stream decoding with  `WithDecoderConcurrency(1)`.
* Split sequence decoding/execution for streams up to 50% faster.
* Simplified error flow.
* Speedup on streams.
* More consistent error reporting.
* Improved error detection/compliance with reference decoder.
* Improved test coverage.

Fixes #477
  • Loading branch information
klauspost authored Feb 24, 2022
1 parent 0f44aaa commit 308a751
Show file tree
Hide file tree
Showing 18 changed files with 1,447 additions and 890 deletions.
121 changes: 15 additions & 106 deletions huff0/bitreader.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,115 +8,10 @@ package huff0
import (
"encoding/binary"
"errors"
"fmt"
"io"
)

// bitReader reads a bitstream in reverse.
// The last set bit indicates the start of the stream and is used
// for aligning the input.
type bitReader struct {
in []byte
off uint // next byte to read is at in[off - 1]
value uint64
bitsRead uint8
}

// init initializes and resets the bit reader.
func (b *bitReader) init(in []byte) error {
if len(in) < 1 {
return errors.New("corrupt stream: too short")
}
b.in = in
b.off = uint(len(in))
// The highest bit of the last byte indicates where to start
v := in[len(in)-1]
if v == 0 {
return errors.New("corrupt stream, did not find end of stream")
}
b.bitsRead = 64
b.value = 0
if len(in) >= 8 {
b.fillFastStart()
} else {
b.fill()
b.fill()
}
b.bitsRead += 8 - uint8(highBit32(uint32(v)))
return nil
}

// peekBitsFast requires that at least one bit is requested every time.
// There are no checks if the buffer is filled.
func (b *bitReader) peekBitsFast(n uint8) uint16 {
const regMask = 64 - 1
v := uint16((b.value << (b.bitsRead & regMask)) >> ((regMask + 1 - n) & regMask))
return v
}

// fillFast() will make sure at least 32 bits are available.
// There must be at least 4 bytes available.
func (b *bitReader) fillFast() {
if b.bitsRead < 32 {
return
}

// 2 bounds checks.
v := b.in[b.off-4 : b.off]
v = v[:4]
low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
b.value = (b.value << 32) | uint64(low)
b.bitsRead -= 32
b.off -= 4
}

func (b *bitReader) advance(n uint8) {
b.bitsRead += n
}

// fillFastStart() assumes the bitreader is empty and there is at least 8 bytes to read.
func (b *bitReader) fillFastStart() {
// Do single re-slice to avoid bounds checks.
b.value = binary.LittleEndian.Uint64(b.in[b.off-8:])
b.bitsRead = 0
b.off -= 8
}

// fill() will make sure at least 32 bits are available.
func (b *bitReader) fill() {
if b.bitsRead < 32 {
return
}
if b.off > 4 {
v := b.in[b.off-4:]
v = v[:4]
low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
b.value = (b.value << 32) | uint64(low)
b.bitsRead -= 32
b.off -= 4
return
}
for b.off > 0 {
b.value = (b.value << 8) | uint64(b.in[b.off-1])
b.bitsRead -= 8
b.off--
}
}

// finished returns true if all bits have been read from the bit stream.
func (b *bitReader) finished() bool {
return b.off == 0 && b.bitsRead >= 64
}

// close the bitstream and returns an error if out-of-buffer reads occurred.
func (b *bitReader) close() error {
// Release reference.
b.in = nil
if b.bitsRead > 64 {
return io.ErrUnexpectedEOF
}
return nil
}

// bitReader reads a bitstream in reverse.
// The last set bit indicates the start of the stream and is used
// for aligning the input.
Expand Down Expand Up @@ -213,10 +108,17 @@ func (b *bitReaderBytes) finished() bool {
return b.off == 0 && b.bitsRead >= 64
}

func (b *bitReaderBytes) remaining() uint {
return b.off*8 + uint(64-b.bitsRead)
}

// close the bitstream and returns an error if out-of-buffer reads occurred.
func (b *bitReaderBytes) close() error {
// Release reference.
b.in = nil
if b.remaining() > 0 {
return fmt.Errorf("corrupt input: %d bits remain on stream", b.remaining())
}
if b.bitsRead > 64 {
return io.ErrUnexpectedEOF
}
Expand Down Expand Up @@ -318,10 +220,17 @@ func (b *bitReaderShifted) finished() bool {
return b.off == 0 && b.bitsRead >= 64
}

func (b *bitReaderShifted) remaining() uint {
return b.off*8 + uint(64-b.bitsRead)
}

// close the bitstream and returns an error if out-of-buffer reads occurred.
func (b *bitReaderShifted) close() error {
// Release reference.
b.in = nil
if b.remaining() > 0 {
return fmt.Errorf("corrupt input: %d bits remain on stream", b.remaining())
}
if b.bitsRead > 64 {
return io.ErrUnexpectedEOF
}
Expand Down
63 changes: 38 additions & 25 deletions huff0/decompress.go
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,7 @@ func (d *Decoder) Decompress4X(dst, src []byte) ([]byte, error) {
}

var br [4]bitReaderShifted
// Decode "jump table"
start := 6
for i := 0; i < 3; i++ {
length := int(src[i*2]) | (int(src[i*2+1]) << 8)
Expand Down Expand Up @@ -865,30 +866,18 @@ func (d *Decoder) Decompress4X(dst, src []byte) ([]byte, error) {
}

// Decode remaining.
remainBytes := dstEvery - (decoded / 4)
for i := range br {
offset := dstEvery * i
endsAt := offset + remainBytes
if endsAt > len(out) {
endsAt = len(out)
}
br := &br[i]
bitsLeft := br.off*8 + uint(64-br.bitsRead)
bitsLeft := br.remaining()
for bitsLeft > 0 {
br.fill()
if false && br.bitsRead >= 32 {
if br.off >= 4 {
v := br.in[br.off-4:]
v = v[:4]
low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
br.value = (br.value << 32) | uint64(low)
br.bitsRead -= 32
br.off -= 4
} else {
for br.off > 0 {
br.value = (br.value << 8) | uint64(br.in[br.off-1])
br.bitsRead -= 8
br.off--
}
}
}
// end inline...
if offset >= len(out) {
if offset >= endsAt {
d.bufs.Put(buf)
return nil, errors.New("corruption detected: stream overrun 4")
}
Expand All @@ -902,6 +891,10 @@ func (d *Decoder) Decompress4X(dst, src []byte) ([]byte, error) {
out[offset] = uint8(v >> 8)
offset++
}
if offset != endsAt {
d.bufs.Put(buf)
return nil, fmt.Errorf("corruption detected: short output block %d, end %d != %d", i, offset, endsAt)
}
decoded += offset - dstEvery*i
err = br.close()
if err != nil {
Expand Down Expand Up @@ -1091,10 +1084,16 @@ func (d *Decoder) decompress4X8bit(dst, src []byte) ([]byte, error) {
}

// Decode remaining.
// Decode remaining.
remainBytes := dstEvery - (decoded / 4)
for i := range br {
offset := dstEvery * i
endsAt := offset + remainBytes
if endsAt > len(out) {
endsAt = len(out)
}
br := &br[i]
bitsLeft := int(br.off*8) + int(64-br.bitsRead)
bitsLeft := br.remaining()
for bitsLeft > 0 {
if br.finished() {
d.bufs.Put(buf)
Expand All @@ -1117,7 +1116,7 @@ func (d *Decoder) decompress4X8bit(dst, src []byte) ([]byte, error) {
}
}
// end inline...
if offset >= len(out) {
if offset >= endsAt {
d.bufs.Put(buf)
return nil, errors.New("corruption detected: stream overrun 4")
}
Expand All @@ -1126,10 +1125,14 @@ func (d *Decoder) decompress4X8bit(dst, src []byte) ([]byte, error) {
v := single[uint8(br.value>>shift)].entry
nBits := uint8(v)
br.advance(nBits)
bitsLeft -= int(nBits)
bitsLeft -= uint(nBits)
out[offset] = uint8(v >> 8)
offset++
}
if offset != endsAt {
d.bufs.Put(buf)
return nil, fmt.Errorf("corruption detected: short output block %d, end %d != %d", i, offset, endsAt)
}
decoded += offset - dstEvery*i
err = br.close()
if err != nil {
Expand Down Expand Up @@ -1315,10 +1318,15 @@ func (d *Decoder) decompress4X8bitExactly(dst, src []byte) ([]byte, error) {
}

// Decode remaining.
remainBytes := dstEvery - (decoded / 4)
for i := range br {
offset := dstEvery * i
endsAt := offset + remainBytes
if endsAt > len(out) {
endsAt = len(out)
}
br := &br[i]
bitsLeft := int(br.off*8) + int(64-br.bitsRead)
bitsLeft := br.remaining()
for bitsLeft > 0 {
if br.finished() {
d.bufs.Put(buf)
Expand All @@ -1341,7 +1349,7 @@ func (d *Decoder) decompress4X8bitExactly(dst, src []byte) ([]byte, error) {
}
}
// end inline...
if offset >= len(out) {
if offset >= endsAt {
d.bufs.Put(buf)
return nil, errors.New("corruption detected: stream overrun 4")
}
Expand All @@ -1350,10 +1358,15 @@ func (d *Decoder) decompress4X8bitExactly(dst, src []byte) ([]byte, error) {
v := single[br.peekByteFast()].entry
nBits := uint8(v)
br.advance(nBits)
bitsLeft -= int(nBits)
bitsLeft -= uint(nBits)
out[offset] = uint8(v >> 8)
offset++
}
if offset != endsAt {
d.bufs.Put(buf)
return nil, fmt.Errorf("corruption detected: short output block %d, end %d != %d", i, offset, endsAt)
}

decoded += offset - dstEvery*i
err = br.close()
if err != nil {
Expand Down
4 changes: 4 additions & 0 deletions zstd/bitreader.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package zstd
import (
"encoding/binary"
"errors"
"fmt"
"io"
"math/bits"
)
Expand Down Expand Up @@ -132,6 +133,9 @@ func (b *bitReader) remain() uint {
func (b *bitReader) close() error {
// Release reference.
b.in = nil
if !b.finished() {
return fmt.Errorf("%d extra bits on block, should be 0", b.remain())
}
if b.bitsRead > 64 {
return io.ErrUnexpectedEOF
}
Expand Down
Loading

0 comments on commit 308a751

Please sign in to comment.