Speed up huff0 table decode (#184)
* Speed up huff0 table decode

Big speedup on small blocks.

```
name                                   old time/op    new time/op    delta
Decompress4XTable/digits-12               243µs ± 0%     228µs ± 1%   -5.87%          (p=0.008 n=5+5)
Decompress4XTable/gettysburg-12          5.81µs ± 0%    5.42µs ± 0%   -6.79%          (p=0.008 n=5+5)
Decompress4XTable/twain-12                735µs ± 1%     696µs ± 1%   -5.35%          (p=0.008 n=5+5)
Decompress4XTable/low-ent.10k-12         73.1µs ± 1%    69.0µs ± 0%   -5.64%          (p=0.008 n=5+5)
Decompress4XTable/superlow-ent-10k-12    20.4µs ± 0%    19.2µs ± 1%   -5.49%          (p=0.008 n=5+5)
Decompress4XTable/case1-12               2.51µs ± 1%    2.16µs ± 1%  -14.01%          (p=0.008 n=5+5)
Decompress4XTable/case2-12               2.47µs ± 0%    2.13µs ± 0%  -13.74%          (p=0.016 n=5+4)
Decompress4XTable/case3-12               2.50µs ± 0%    2.15µs ± 0%  -14.13%          (p=0.008 n=5+5)
Decompress4XTable/pngdata.001-12         98.9µs ± 1%    94.9µs ± 2%   -4.08%          (p=0.008 n=5+5)
Decompress4XTable/normcount2-12          1.60µs ± 0%    1.51µs ± 0%   -5.55%          (p=0.008 n=5+5)

name                                   old speed      new speed      delta
Decompress4XTable/digits-12             412MB/s ± 0%   438MB/s ± 1%   +6.24%          (p=0.008 n=5+5)
Decompress4XTable/gettysburg-12         266MB/s ± 0%   286MB/s ± 0%   +7.28%          (p=0.008 n=5+5)
Decompress4XTable/twain-12              356MB/s ± 1%   377MB/s ± 1%   +5.65%          (p=0.008 n=5+5)
Decompress4XTable/low-ent.10k-12        547MB/s ± 1%   580MB/s ± 0%   +5.98%          (p=0.008 n=5+5)
Decompress4XTable/superlow-ent-10k-12   516MB/s ± 0%   546MB/s ± 1%   +5.81%          (p=0.008 n=5+5)
Decompress4XTable/case1-12             21.9MB/s ± 1%  25.5MB/s ± 1%  +16.31%          (p=0.008 n=5+5)
Decompress4XTable/case2-12             18.2MB/s ± 0%  21.1MB/s ± 0%  +15.94%          (p=0.016 n=5+4)
Decompress4XTable/case3-12             19.2MB/s ± 0%  22.3MB/s ± 0%  +16.48%          (p=0.008 n=5+5)
Decompress4XTable/pngdata.001-12        518MB/s ± 1%   540MB/s ± 2%   +4.26%          (p=0.008 n=5+5)
Decompress4XTable/normcount2-12        54.3MB/s ± 0%  57.5MB/s ± 0%   +5.85%          (p=0.008 n=5+5)
```

* Small speedup by unrolling
klauspost authored Nov 22, 2019
1 parent c791a01 commit 7892b3d
Showing 3 changed files with 122 additions and 35 deletions.
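Editorial note: the main change visible in the diff below is that dEntrySingle's two fields (byte, nBits) are packed into a single uint16, with the number of bits to consume in the low byte and the decoded symbol in the high byte. Presumably a lone uint16 is cheaper for the compiler to load, copy into the table, and store than a two-field struct, which is where the table-decode speedup comes from. A minimal sketch of that layout (packEntry/unpackEntry are illustrative helpers, not part of the huff0 API):

```go
// Sketch of the packed dEntrySingle layout introduced in this commit:
// the decode bit count lives in the low byte and the symbol in the high
// byte of one uint16. packEntry/unpackEntry are illustrative helpers,
// not part of the huff0 API.
package main

import "fmt"

func packEntry(symbol byte, nBits uint8) uint16 {
	// Mirrors: entry: uint16(s.actualTableLog+1-w) | (uint16(n) << 8)
	return uint16(nBits) | uint16(symbol)<<8
}

func unpackEntry(entry uint16) (symbol byte, nBits uint8) {
	// Mirrors the decode loops: uint8(v.entry>>8) is the symbol,
	// uint8(v.entry) is the number of bits consumed.
	return uint8(entry >> 8), uint8(entry)
}

func main() {
	e := packEntry('A', 5)
	sym, bits := unpackEntry(e)
	fmt.Printf("entry=%#04x symbol=%c nBits=%d\n", e, sym, bits)
}
```

The "unrolling" bullet refers to the Decompress4X inner loop, where the per-byte decode closure calls are replaced by four manually expanded per-stream blocks that each decode two symbols between refills. The benchmark table above follows benchstat's old/new comparison format; the raw numbers come from the BenchmarkDecompress4XTable added in huff0/decompress_test.go below.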
119 changes: 84 additions & 35 deletions huff0/decompress.go
@@ -15,8 +15,7 @@ type dTable struct {

// single-symbols decoding
type dEntrySingle struct {
byte uint8
nBits uint8
entry uint16
}

// double-symbols decoding
@@ -76,14 +75,15 @@ func ReadTable(in []byte, s *Scratch) (s2 *Scratch, remain []byte, err error) {
}

// collect weight stats
var rankStats [tableLogMax + 1]uint32
var rankStats [16]uint32
weightTotal := uint32(0)
for _, v := range s.huffWeight[:s.symbolLen] {
if v > tableLogMax {
return s, nil, errors.New("corrupt input: weight too large")
}
rankStats[v]++
weightTotal += (1 << (v & 15)) >> 1
v2 := v & 15
rankStats[v2]++
weightTotal += (1 << v2) >> 1
}
if weightTotal == 0 {
return s, nil, errors.New("corrupt input: weights zero")
@@ -134,15 +134,17 @@ func ReadTable(in []byte, s *Scratch) (s2 *Scratch, remain []byte, err error) {
if len(s.dt.single) != tSize {
s.dt.single = make([]dEntrySingle, tSize)
}

for n, w := range s.huffWeight[:s.symbolLen] {
if w == 0 {
continue
}
length := (uint32(1) << w) >> 1
d := dEntrySingle{
byte: uint8(n),
nBits: s.actualTableLog + 1 - w,
entry: uint16(s.actualTableLog+1-w) | (uint16(n) << 8),
}
for u := rankStats[w]; u < rankStats[w]+length; u++ {
s.dt.single[u] = d
single := s.dt.single[rankStats[w] : rankStats[w]+length]
for i := range single {
single[i] = d
}
rankStats[w] += length
}
Expand All @@ -167,12 +169,12 @@ func (s *Scratch) Decompress1X(in []byte) (out []byte, err error) {
decode := func() byte {
val := br.peekBitsFast(s.actualTableLog) /* note : actualTableLog >= 1 */
v := s.dt.single[val]
br.bitsRead += v.nBits
return v.byte
br.bitsRead += uint8(v.entry)
return uint8(v.entry >> 8)
}
hasDec := func(v dEntrySingle) byte {
br.bitsRead += v.nBits
return v.byte
br.bitsRead += uint8(v.entry)
return uint8(v.entry >> 8)
}

// Avoid bounds check by always having full sized table.
@@ -269,8 +271,8 @@ func (s *Scratch) Decompress4X(in []byte, dstSize int) (out []byte, err error) {
decode := func(br *bitReader) byte {
val := br.peekBitsFast(s.actualTableLog) /* note : actualTableLog >= 1 */
v := single[val&tlMask]
br.bitsRead += v.nBits
return v.byte
br.bitsRead += uint8(v.entry)
return uint8(v.entry >> 8)
}

// Use temp table to avoid bound checks/append penalty.
@@ -283,20 +285,67 @@ func (s *Scratch) Decompress4X(in []byte, dstSize int) (out []byte, err error) {
bigloop:
for {
for i := range br {
if br[i].off < 4 {
br := &br[i]
if br.off < 4 {
break bigloop
}
br[i].fillFast()
br.fillFast()
}

{
const stream = 0
val := br[stream].peekBitsFast(s.actualTableLog)
v := single[val&tlMask]
br[stream].bitsRead += uint8(v.entry)

val2 := br[stream].peekBitsFast(s.actualTableLog)
v2 := single[val2&tlMask]
tmp[off+bufoff*stream+1] = uint8(v2.entry >> 8)
tmp[off+bufoff*stream] = uint8(v.entry >> 8)
br[stream].bitsRead += uint8(v2.entry)
}

{
const stream = 1
val := br[stream].peekBitsFast(s.actualTableLog)
v := single[val&tlMask]
br[stream].bitsRead += uint8(v.entry)

val2 := br[stream].peekBitsFast(s.actualTableLog)
v2 := single[val2&tlMask]
tmp[off+bufoff*stream+1] = uint8(v2.entry >> 8)
tmp[off+bufoff*stream] = uint8(v.entry >> 8)
br[stream].bitsRead += uint8(v2.entry)
}

{
const stream = 2
val := br[stream].peekBitsFast(s.actualTableLog)
v := single[val&tlMask]
br[stream].bitsRead += uint8(v.entry)

val2 := br[stream].peekBitsFast(s.actualTableLog)
v2 := single[val2&tlMask]
tmp[off+bufoff*stream+1] = uint8(v2.entry >> 8)
tmp[off+bufoff*stream] = uint8(v.entry >> 8)
br[stream].bitsRead += uint8(v2.entry)
}
tmp[off] = decode(&br[0])
tmp[off+bufoff] = decode(&br[1])
tmp[off+bufoff*2] = decode(&br[2])
tmp[off+bufoff*3] = decode(&br[3])
tmp[off+1] = decode(&br[0])
tmp[off+1+bufoff] = decode(&br[1])
tmp[off+1+bufoff*2] = decode(&br[2])
tmp[off+1+bufoff*3] = decode(&br[3])

{
const stream = 3
val := br[stream].peekBitsFast(s.actualTableLog)
v := single[val&tlMask]
br[stream].bitsRead += uint8(v.entry)

val2 := br[stream].peekBitsFast(s.actualTableLog)
v2 := single[val2&tlMask]
tmp[off+bufoff*stream+1] = uint8(v2.entry >> 8)
tmp[off+bufoff*stream] = uint8(v.entry >> 8)
br[stream].bitsRead += uint8(v2.entry)
}

off += 2

if off == bufoff {
if bufoff > dstEvery {
return nil, errors.New("corruption detected: stream overrun 1")
@@ -367,7 +416,7 @@ func (s *Scratch) matches(ct cTable, w io.Writer) {
broken++
if enc.nBits == 0 {
for _, dec := range dt {
if dec.byte == byte(sym) {
if uint8(dec.entry>>8) == byte(sym) {
fmt.Fprintf(w, "symbol %x has decoder, but no encoder\n", sym)
errs++
break
Expand All @@ -383,12 +432,12 @@ func (s *Scratch) matches(ct cTable, w io.Writer) {
top := enc.val << ub
// decoder looks at top bits.
dec := dt[top]
if dec.nBits != enc.nBits {
fmt.Fprintf(w, "symbol 0x%x bit size mismatch (enc: %d, dec:%d).\n", sym, enc.nBits, dec.nBits)
if uint8(dec.entry) != enc.nBits {
fmt.Fprintf(w, "symbol 0x%x bit size mismatch (enc: %d, dec:%d).\n", sym, enc.nBits, uint8(dec.entry))
errs++
}
if dec.byte != uint8(sym) {
fmt.Fprintf(w, "symbol 0x%x decoder output mismatch (enc: %d, dec:%d).\n", sym, sym, dec.byte)
if uint8(dec.entry>>8) != uint8(sym) {
fmt.Fprintf(w, "symbol 0x%x decoder output mismatch (enc: %d, dec:%d).\n", sym, sym, uint8(dec.entry>>8))
errs++
}
if errs > 0 {
@@ -399,12 +448,12 @@ func (s *Scratch) matches(ct cTable, w io.Writer) {
for i := uint16(0); i < (1 << ub); i++ {
vval := top | i
dec := dt[vval]
if dec.nBits != enc.nBits {
fmt.Fprintf(w, "symbol 0x%x bit size mismatch (enc: %d, dec:%d).\n", vval, enc.nBits, dec.nBits)
if uint8(dec.entry) != enc.nBits {
fmt.Fprintf(w, "symbol 0x%x bit size mismatch (enc: %d, dec:%d).\n", vval, enc.nBits, uint8(dec.entry))
errs++
}
if dec.byte != uint8(sym) {
fmt.Fprintf(w, "symbol 0x%x decoder output mismatch (enc: %d, dec:%d).\n", vval, sym, dec.byte)
if uint8(dec.entry>>8) != uint8(sym) {
fmt.Fprintf(w, "symbol 0x%x decoder output mismatch (enc: %d, dec:%d).\n", vval, sym, uint8(dec.entry>>8))
errs++
}
if errs > 20 {
38 changes: 38 additions & 0 deletions huff0/decompress_test.go
@@ -478,3 +478,41 @@ func BenchmarkDecompress4XNoTable(b *testing.B) {
})
}
}

func BenchmarkDecompress4XTable(b *testing.B) {
for _, tt := range testfiles {
test := tt
if test.err4X != nil {
continue
}
b.Run(test.name, func(b *testing.B) {
var s = &Scratch{}
s.Reuse = ReusePolicyNone
buf0, err := test.fn()
if err != nil {
b.Fatal(err)
}
if len(buf0) > BlockSizeMax {
buf0 = buf0[:BlockSizeMax]
}
compressed, _, err := Compress4X(buf0, s)
if err != test.err1X {
b.Fatal("unexpected error:", err)
}
s.Out = nil
b.ResetTimer()
b.ReportAllocs()
b.SetBytes(int64(len(buf0)))
for i := 0; i < b.N; i++ {
s, remain, err := ReadTable(compressed, s)
if err != nil {
b.Fatal(err)
}
_, err = s.Decompress4X(remain, len(buf0))
if err != nil {
b.Fatal(err)
}
}
})
}
}
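For context on the API the new benchmark exercises: ReadTable parses the serialized Huffman table from a compressed block and returns the remaining payload, which Decompress4X then expands into a buffer of known size. A minimal round-trip sketch, assuming the github.com/klauspost/compress/huff0 import path; the sample input is illustrative:

```go
// Minimal round trip for the path benchmarked above: Compress4X produces a
// four-stream huff0 block, ReadTable parses the Huffman table back and
// returns the remaining payload, and Decompress4X expands it. The input
// below is illustrative; real callers must know the decompressed size.
package main

import (
	"bytes"
	"fmt"

	"github.com/klauspost/compress/huff0"
)

func main() {
	// Skewed symbol distribution so Huffman coding is worthwhile.
	in := bytes.Repeat([]byte("aaaaabbbcc dd e"), 100)

	// A nil *Scratch lets huff0 allocate one internally.
	comp, _, err := huff0.Compress4X(in, nil)
	if err != nil {
		fmt.Println("compress:", err)
		return
	}

	// This is the work measured by BenchmarkDecompress4XTable.
	s, remain, err := huff0.ReadTable(comp, nil)
	if err != nil {
		fmt.Println("read table:", err)
		return
	}
	out, err := s.Decompress4X(remain, len(in))
	fmt.Println("roundtrip ok:", err == nil && bytes.Equal(in, out))
}
```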
Binary file modified zstd/testdata/benchdecoder.zip
