Skip to content

Commit

Permalink
zstd: Forward read errors (#373)
Browse files Browse the repository at this point in the history
* zstd: Forward read errors

Forward all read errors.

Fixes #372
  • Loading branch information
klauspost authored May 11, 2021
1 parent 20ca64c commit 2748482
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 38 deletions.
11 changes: 4 additions & 7 deletions zstd/blockdec.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,10 @@ func newBlockDec(lowMem bool) *blockDec {
// Input must be a start of a block and will be at the end of the block when returned.
func (b *blockDec) reset(br byteBuffer, windowSize uint64) error {
b.WindowSize = windowSize
tmp := br.readSmall(3)
if tmp == nil {
if debug {
println("Reading block header:", io.ErrUnexpectedEOF)
}
return io.ErrUnexpectedEOF
tmp, err := br.readSmall(3)
if err != nil {
println("Reading block header:", err)
return err
}
bh := uint32(tmp[0]) | (uint32(tmp[1]) << 8) | (uint32(tmp[2]) << 16)
b.Last = bh&1 != 0
Expand Down Expand Up @@ -179,7 +177,6 @@ func (b *blockDec) reset(br byteBuffer, windowSize uint64) error {
if cap(b.dst) <= maxSize {
b.dst = make([]byte, 0, maxSize+1)
}
var err error
b.data, err = br.readBig(cSize, b.dataStorage)
if err != nil {
if debug {
Expand Down
21 changes: 12 additions & 9 deletions zstd/bytebuf.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ import (

type byteBuffer interface {
// Read up to 8 bytes.
// Returns nil if no more input is available.
readSmall(n int) []byte
// Returns io.ErrUnexpectedEOF if this cannot be satisfied.
readSmall(n int) ([]byte, error)

// Read >8 bytes.
// MAY use the destination slice.
Expand All @@ -29,17 +29,17 @@ type byteBuffer interface {
// in-memory buffer
type byteBuf []byte

func (b *byteBuf) readSmall(n int) []byte {
func (b *byteBuf) readSmall(n int) ([]byte, error) {
if debugAsserts && n > 8 {
panic(fmt.Errorf("small read > 8 (%d). use readBig", n))
}
bb := *b
if len(bb) < n {
return nil
return nil, io.ErrUnexpectedEOF
}
r := bb[:n]
*b = bb[n:]
return r
return r, nil
}

func (b *byteBuf) readBig(n int, dst []byte) ([]byte, error) {
Expand Down Expand Up @@ -81,19 +81,22 @@ type readerWrapper struct {
tmp [8]byte
}

func (r *readerWrapper) readSmall(n int) []byte {
func (r *readerWrapper) readSmall(n int) ([]byte, error) {
if debugAsserts && n > 8 {
panic(fmt.Errorf("small read > 8 (%d). use readBig", n))
}
n2, err := io.ReadFull(r.r, r.tmp[:n])
// We only really care about the actual bytes read.
if n2 != n {
if err != nil {
if err == io.EOF {
return nil, io.ErrUnexpectedEOF
}
if debug {
println("readSmall: got", n2, "want", n, "err", err)
}
return nil
return nil, err
}
return r.tmp[:n]
return r.tmp[:n], nil
}

func (r *readerWrapper) readBig(n int, dst []byte) ([]byte, error) {
Expand Down
25 changes: 24 additions & 1 deletion zstd/decoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"bytes"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
"io"
"io/ioutil"
Expand Down Expand Up @@ -142,6 +143,28 @@ func TestNewReaderMismatch(t *testing.T) {
t.Log("Output matched")
}

type errorReader struct {
err error
}

func (r *errorReader) Read(p []byte) (int, error) {
return 0, r.err
}

func TestErrorReader(t *testing.T) {
wantErr := fmt.Errorf("i'm a failure")
zr, err := NewReader(&errorReader{err: wantErr})
if err != nil {
t.Fatal(err)
}
defer zr.Close()

_, err = ioutil.ReadAll(zr)
if !errors.Is(err, wantErr) {
t.Errorf("want error %v, got %v", wantErr, err)
}
}

func TestNewDecoder(t *testing.T) {
defer timeout(60 * time.Second)()
testDecoderFile(t, "testdata/decoder.zip")
Expand Down Expand Up @@ -426,7 +449,7 @@ func TestDecoderRegression(t *testing.T) {
}
defer dec.Close()
for i, tt := range zr.File {
if !strings.HasSuffix(tt.Name, "") || (testing.Short() && i > 10) {
if !strings.HasSuffix(tt.Name, "artifact (5)") || (testing.Short() && i > 10) {
continue
}
t.Run("Reader-"+tt.Name, func(t *testing.T) {
Expand Down
46 changes: 25 additions & 21 deletions zstd/framedec.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,14 @@ func (d *frameDec) reset(br byteBuffer) error {
d.WindowSize = 0
var b []byte
for {
b = br.readSmall(4)
if b == nil {
var err error
b, err = br.readSmall(4)
switch err {
case io.EOF, io.ErrUnexpectedEOF:
return io.EOF
default:
return err
case nil:
}
if !bytes.Equal(b[1:4], skippableFrameMagic) || b[0]&0xf0 != 0x50 {
if debug {
Expand All @@ -92,14 +97,14 @@ func (d *frameDec) reset(br byteBuffer) error {
break
}
// Read size to skip
b = br.readSmall(4)
if b == nil {
println("Reading Frame Size EOF")
return io.ErrUnexpectedEOF
b, err = br.readSmall(4)
if err != nil {
println("Reading Frame Size", err)
return err
}
n := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
println("Skipping frame with", n, "bytes.")
err := br.skipN(int(n))
err = br.skipN(int(n))
if err != nil {
if debug {
println("Reading discarded frame", err)
Expand Down Expand Up @@ -147,12 +152,11 @@ func (d *frameDec) reset(br byteBuffer) error {
if size == 3 {
size = 4
}
b = br.readSmall(int(size))
if b == nil {
if debug {
println("Reading Dictionary_ID", io.ErrUnexpectedEOF)
}
return io.ErrUnexpectedEOF

b, err = br.readSmall(int(size))
if err != nil {
println("Reading Dictionary_ID", err)
return err
}
var id uint32
switch size {
Expand Down Expand Up @@ -187,10 +191,10 @@ func (d *frameDec) reset(br byteBuffer) error {
}
d.FrameContentSize = 0
if fcsSize > 0 {
b := br.readSmall(fcsSize)
if b == nil {
println("Reading Frame content", io.ErrUnexpectedEOF)
return io.ErrUnexpectedEOF
b, err = br.readSmall(fcsSize)
if err != nil {
println("Reading Frame content", err)
return err
}
switch fcsSize {
case 1:
Expand Down Expand Up @@ -307,10 +311,10 @@ func (d *frameDec) checkCRC() error {
tmp[3] = byte(got >> 24)

// We can overwrite upper tmp now
want := d.rawInput.readSmall(4)
if want == nil {
println("CRC missing?")
return io.ErrUnexpectedEOF
want, err := d.rawInput.readSmall(4)
if err != nil {
println("CRC missing?", err)
return err
}

if !bytes.Equal(tmp[:], want) {
Expand Down

0 comments on commit 2748482

Please sign in to comment.