Skip to content

Commit

Permalink
Add Cancel() to StreamWriter (hypermodeinc#1550)
Browse files Browse the repository at this point in the history
  • Loading branch information
ajeetdsouza authored Nov 2, 2020
1 parent cd98408 commit a516cb8
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 2 deletions.
36 changes: 34 additions & 2 deletions stream_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,12 @@ func (sw *StreamWriter) Prepare() error {
sw.writeLock.Lock()
defer sw.writeLock.Unlock()

var err error
sw.done, err = sw.db.dropAll()
done, err := sw.db.dropAll()

// Ensure that done() is never called more than once.
var once sync.Once
sw.done = func() { once.Do(done) }

return err
}

Expand Down Expand Up @@ -235,6 +239,34 @@ func (sw *StreamWriter) Flush() error {
return sw.db.lc.validate()
}

// Cancel signals all goroutines to exit. Calling defer sw.Cancel() immediately after creating a new StreamWriter
// ensures that writes are unblocked even upon early return. Note that dropAll() is not called here, so any
// partially written data will not be erased until a new StreamWriter is initialized.
func (sw *StreamWriter) Cancel() {
sw.writeLock.Lock()
defer sw.writeLock.Unlock()

for _, writer := range sw.writers {
if writer != nil {
writer.closer.Signal()
}
}
for _, writer := range sw.writers {
if writer != nil {
writer.closer.Wait()
}
}

if err := sw.throttle.Finish(); err != nil {
sw.db.opt.Errorf("error in throttle.Finish: %+v", err)
}

// Handle Cancel() being called before Prepare().
if sw.done != nil {
sw.done()
}
}

type sortedWriter struct {
db *DB
throttle *y.Throttle
Expand Down
30 changes: 30 additions & 0 deletions stream_writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,36 @@ func TestStreamWriter6(t *testing.T) {
})
}

// This test uses a StreamWriter without calling Flush() at the end.
func TestStreamWriterCancel(t *testing.T) {
runBadgerTest(t, nil, func(t *testing.T, db *DB) {
list := &pb.KVList{}
str := []string{"a", "a", "b", "b", "c", "c"}
ver := 1
for i := range str {
kv := &pb.KV{
Key: bytes.Repeat([]byte(str[i]), int(db.opt.MaxTableSize)),
Value: []byte("val"),
Version: uint64(ver),
}
list.Kv = append(list.Kv, kv)
ver = (ver + 1) % 2
}

sw := db.NewStreamWriter()
require.NoError(t, sw.Prepare(), "sw.Prepare() failed")
require.NoError(t, sw.Write(list), "sw.Write() failed")
sw.Cancel()

// Use the API incorrectly.
sw1 := db.NewStreamWriter()
defer sw1.Cancel()
require.NoError(t, sw1.Prepare())
defer sw1.Cancel()
sw1.Flush()
})
}

func TestStreamDone(t *testing.T) {
runBadgerTest(t, nil, func(t *testing.T, db *DB) {
sw := db.NewStreamWriter()
Expand Down

0 comments on commit a516cb8

Please sign in to comment.