Skip to content

Commit

Permalink
Fix error propagation.
Browse files Browse the repository at this point in the history
Errors from the underlying writer was not being forwarded/returned, so that is now done, and tests have been added to test that functionality

Fixes #15
  • Loading branch information
klauspost committed Nov 17, 2015
1 parent 8717c82 commit a419316
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 18 deletions.
53 changes: 35 additions & 18 deletions flate/deflate.go
Original file line number Diff line number Diff line change
Expand Up @@ -425,9 +425,7 @@ Loop:
}
d.hash = newH
}

d.index = newIndex

} else {
// For matches this long, we don't bother inserting each individual
// item into the table.
Expand Down Expand Up @@ -480,15 +478,14 @@ func (d *compressor) deflateNoSkip() {
d.hash = d.hasher(d.window[d.index:d.index+minMatchLength]) & hashMask
}

Loop:
for {
if sanity && d.index > d.windowEnd {
panic("index > windowEnd")
}
lookahead := d.windowEnd - d.index
if lookahead < minMatchLength+maxMatchLength {
if !d.sync {
break Loop
return
}
if sanity && d.index > d.windowEnd {
panic("index > windowEnd")
Expand All @@ -507,7 +504,7 @@ Loop:
}
d.tokens.n = 0
}
break Loop
return
}
}
if d.index < d.maxInsertIndex {
Expand Down Expand Up @@ -538,6 +535,7 @@ Loop:
// not better. Output the previous match.
d.tokens.tokens[d.tokens.n] = matchToken(uint32(prevLength-3), uint32(prevOffset-minOffsetSize))
d.tokens.n++

// Insert in the hash table all strings up to the end of the match.
// index and index-1 are already inserted. If there is not enough
// lookahead, the last two strings are not inserted into the hash
Expand Down Expand Up @@ -573,7 +571,6 @@ Loop:
}

d.index = newIndex

d.byteAvailable = false
d.length = minMatchLength - 1
if d.tokens.n == maxFlateBlockTokens {
Expand All @@ -588,13 +585,13 @@ Loop:
if d.length >= minMatchLength {
d.ii = 0
}
// We have a byte waiting. Emit it.
if d.byteAvailable {
d.ii++
i := d.index - 1
d.tokens.tokens[d.tokens.n] = literalToken(uint32(d.window[i]))
d.tokens.tokens[d.tokens.n] = literalToken(uint32(d.window[d.index-1]))
d.tokens.n++
if d.tokens.n == maxFlateBlockTokens {
if d.err = d.writeBlock(d.tokens, i+1, false); d.err != nil {
if d.err = d.writeBlock(d.tokens, d.index, false); d.err != nil {
return
}
d.tokens.n = 0
Expand All @@ -604,29 +601,38 @@ Loop:
// If we have a long run of no matches, skip additional bytes
// Resets when d.ii overflows after 64KB.
if d.ii > 31 {
n := int(d.ii >> 5)
n := int(d.ii >> 6)
for j := 0; j < n; j++ {
i := d.index - 1
if d.index >= d.windowEnd-1 {
break
}

d.tokens.tokens[d.tokens.n] = literalToken(uint32(d.window[i]))
d.tokens.tokens[d.tokens.n] = literalToken(uint32(d.window[d.index-1]))
d.tokens.n++
if d.tokens.n == maxFlateBlockTokens {
if d.err = d.writeBlock(d.tokens, i+1, false); d.err != nil {
if d.err = d.writeBlock(d.tokens, d.index, false); d.err != nil {
return
}
d.tokens.n = 0
}
d.index++
}
// Flush last byte
d.tokens.tokens[d.tokens.n] = literalToken(uint32(d.window[d.index-1]))
d.tokens.n++
d.byteAvailable = false
// d.length = minMatchLength - 1 // not needed, since d.ii is reset above, so it should never be > minMatchLength
if d.tokens.n == maxFlateBlockTokens {
if d.err = d.writeBlock(d.tokens, d.index, false); d.err != nil {
return
}
d.tokens.n = 0
}
}

} else {
d.index++
d.byteAvailable = true
}
d.byteAvailable = true
}
}
}
Expand Down Expand Up @@ -659,6 +665,7 @@ func (d *compressor) storeHuff() {
return
}
d.w.writeBlockHuff(false, d.window[:d.windowEnd])
d.err = d.w.err
d.windowEnd = 0
}

Expand All @@ -672,22 +679,31 @@ func (d *compressor) storeSnappy() {
}
snappyEncode(&d.tokens, d.window[:d.windowEnd])
d.w.writeBlock(d.tokens, false, d.window[:d.windowEnd])
d.err = d.w.err
d.tokens.n = 0
d.windowEnd = 0
}

func (d *compressor) write(b []byte) (n int, err error) {
if d.err != nil {
return 0, d.err
}
n = len(b)
b = b[d.fill(d, b):]
for len(b) > 0 {
d.step(d)
b = b[d.fill(d, b):]
if d.err != nil {
return 0, d.err
}
}
return n, d.err
}

func (d *compressor) syncFlush() error {
d.sync = true
if d.err != nil {
return d.err
}
d.step(d)
if d.err == nil {
d.w.writeStoredHeader(0, false)
Expand Down Expand Up @@ -733,9 +749,7 @@ func (d *compressor) init(w io.Writer, level int) (err error) {
return nil
}

var zeroes [64]int
var hzeroes [256]hashid
var bzeroes [256]byte

func (d *compressor) reset(w io.Writer) {
d.w.reset(w)
Expand Down Expand Up @@ -769,6 +783,9 @@ func (d *compressor) reset(w io.Writer) {
}

func (d *compressor) close() error {
if d.err != nil {
return d.err
}
d.sync = true
d.step(d)
if d.err != nil {
Expand Down
66 changes: 66 additions & 0 deletions flate/deflate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,8 @@ func TestWriterReset(t *testing.T) {
w.d.hasher, wref.d.hasher = nil, nil
w.d.bulkHasher, wref.d.bulkHasher = nil, nil
w.d.matcher, wref.d.matcher = nil, nil
// hashMatch is always overwritten when used.
copy(w.d.hashMatch[:], wref.d.hashMatch[:])
if w.d.tokens.n != 0 {
t.Errorf("level %d Writer not reset after Reset. %d tokens were present", level, w.d.tokens.n)
}
Expand Down Expand Up @@ -601,3 +603,67 @@ func testResetOutput(t *testing.T, newWriter func(w io.Writer) (*Writer, error))
}
t.Logf("got %d bytes", len(out1))
}

// A writer that fails after N writes.
type errorWriter struct {
N int
}

func (e *errorWriter) Write(b []byte) (int, error) {
if e.N <= 0 {
return 0, io.ErrClosedPipe
}
e.N--
return len(b), nil
}

// Test if errors from the underlying writer is passed upwards.
func TestWriteError(t *testing.T) {
buf := new(bytes.Buffer)
for i := 0; i < 1024*1024; i++ {
buf.WriteString(fmt.Sprintf("asdasfasf%d%dfghfgujyut%dyutyu\n", i, i, i))
}
in := buf.Bytes()
for l := -2; l < 10; l++ {
for fail := 1; fail <= 512; fail *= 2 {
// Fail after 2 writes
ew := &errorWriter{N: fail}
w, err := NewWriter(ew, l)
if err != nil {
t.Errorf("NewWriter: level %d: %v", l, err)
}
n, err := io.Copy(w, bytes.NewBuffer(in))
if err == nil {
t.Errorf("Level %d: Expected an error, writer was %#v", l, ew)
}
n2, err := w.Write([]byte{1, 2, 2, 3, 4, 5})
if n2 != 0 {
t.Error("Level", l, "Expected 0 length write, got", n)
}
if err == nil {
t.Error("Level", l, "Expected an error")
}
err = w.Flush()
if err == nil {
t.Error("Level", l, "Expected an error on close")
}
err = w.Close()
if err == nil {
t.Error("Level", l, "Expected an error on close")
}

w.Reset(ioutil.Discard)
n2, err = w.Write([]byte{1, 2, 3, 4, 5, 6})
if err != nil {
t.Error("Level", l, "Got unexpected error after reset:", err)
}
if n2 == 0 {
t.Error("Level", l, "Got 0 length write, expected > 0")
}
if testing.Short() {
return
}
}
}

}

0 comments on commit a419316

Please sign in to comment.