Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make ResultIteration refuse unsafe operation (option 1) #86

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ type FilterIterator struct {
iter ResultIterator
}

// NewFilterIterator wraps a ResultIterator. The filter function is applied
// to each value returned from a call to wrap.Next().
//
// See the documentation for ResultIterator for correct usage of FilterIterator.
func NewFilterIterator(wrap ResultIterator, filter FilterFunc) *FilterIterator {
return &FilterIterator{
filter: filter,
Expand All @@ -23,7 +27,7 @@ func NewFilterIterator(wrap ResultIterator, filter FilterFunc) *FilterIterator {
// WatchCh returns the watch channel of the wrapped iterator.
func (f *FilterIterator) WatchCh() <-chan struct{} { return f.iter.WatchCh() }

// Next returns the next non-filtered result from the wrapped iterator
// Next returns the next non-filtered result from the wrapped iterator.
func (f *FilterIterator) Next() interface{} {
for {
if value := f.iter.Next(); value == nil || !f.filter(value) {
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@ require (
github.com/hashicorp/go-immutable-radix v1.3.0
github.com/hashicorp/golang-lru v0.5.4 // indirect
)

replace github.com/hashicorp/go-immutable-radix => github.com/hashicorp/go-immutable-radix v1.3.1-0.20210121185740-67e10d480dcf
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leaving a note as a reminder that this should be changed to an official release, before merging this PR, if we are interested in this approach.

6 changes: 4 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
github.com/hashicorp/go-immutable-radix v1.3.0 h1:8exGP7ego3OmkfksihtSouGMZ+hQrhxx+FVELeXpVPE=
github.com/hashicorp/go-immutable-radix v1.3.0/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60=
github.com/hashicorp/go-immutable-radix v1.3.1-0.20210121184832-1f3613208fff h1:pCJq6LOTMatCem80dkrOG/po5/rQUM9EeIvzpcKU6mg=
github.com/hashicorp/go-immutable-radix v1.3.1-0.20210121184832-1f3613208fff/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60=
github.com/hashicorp/go-immutable-radix v1.3.1-0.20210121185740-67e10d480dcf h1:xxNhSdLiKRbDHbxlo7UkV0Jzp0NoTGSJKCvZVVBQ3fI=
github.com/hashicorp/go-immutable-radix v1.3.1-0.20210121185740-67e10d480dcf/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60=
github.com/hashicorp/go-uuid v1.0.0 h1:RS8zrF7PhGwyNPOtxSClXXj9HA8feRnJzgnI1RJCSnM=
github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/hashicorp/golang-lru v0.5.0 h1:CL2msUPvZTLb5O648aiLNJw3hnBxN2+1Jq8rCOH9wdo=
Expand Down
71 changes: 50 additions & 21 deletions txn.go
Original file line number Diff line number Diff line change
Expand Up @@ -663,35 +663,54 @@ func (txn *Txn) getIndexValue(table, index string, args ...interface{}) (*IndexS
return indexSchema, val, err
}

// ResultIterator is used to iterate over a list of results
// from a Get query on a table.
// ResultIterator is used to iterate over a list of results from a query.
//
// Once a ResultIterator has been created it is no longer safe to modify
// (Insert or Delete) the table that ResultIterator is using. If any modifications
// are made to the table, subsequent calls to Next may panic or return unexpected
// values.
//
// To safely use a ResultIterator with a write transaction always capture the
// results of the iteration in a slice or map, then perform any modifications
// after the final call to Next.
type ResultIterator interface {
WatchCh() <-chan struct{}
// Next returns the next result from the iterator. If there are no more results
// nil is returned.
// If any modification (Insert or Delete) is made to the table after the
// ResultIterator is created, Next may panic or return unexpected values.
Next() interface{}
}

// Get is used to construct a ResultIterator over all the
// rows that match the given constraints of an index.
// Get is used to construct a ResultIterator over all the rows that match the
// given constraints of an index.
// See the documentation for ResultIterator for correct usage of the returned
// ResultIterator.
func (txn *Txn) Get(table, index string, args ...interface{}) (ResultIterator, error) {
indexIter, val, err := txn.getIndexIterator(table, index, args...)
indexTxn, val, err := txn.getIndexTxn(table, index, args...)
if err != nil {
return nil, err
}

indexIter := indexTxn.Root().Iterator()

// Seek the iterator to the appropriate sub-set
watchCh := indexIter.SeekPrefixWatch(val)

// Create an iterator
iter := &radixIterator{
iter: indexIter,
watchCh: watchCh,
txn: indexTxn,
modifyIndex: indexTxn.ModifyIndex(),
iter: indexIter,
watchCh: watchCh,
}
return iter, nil
}

// GetReverse is used to construct a Reverse ResultIterator over all the
// rows that match the given constraints of an index.
// The returned ResultIterator's Next() will return the next Previous value
// The returned ResultIterator's Next() will return the Previous value.
// See the documentation for ResultIterator for correct usage of the returned
// ResultIterator.
func (txn *Txn) GetReverse(table, index string, args ...interface{}) (ResultIterator, error) {
indexIter, val, err := txn.getIndexIteratorReverse(table, index, args...)
if err != nil {
Expand All @@ -714,19 +733,25 @@ func (txn *Txn) GetReverse(table, index string, args ...interface{}) (ResultIter
// Calling this then iterating until the rows are larger than required allows
// range scans within an index. It is not possible to watch the resulting
// iterator since the radix tree doesn't efficiently allow watching on lower
// bound changes. The WatchCh returned will be nill and so will block forever.
// bound changes. The WatchCh returned will be nil and so will block forever.
//
// See the documentation for ResultIterator for correct usage of the returned
// ResultIterator.
func (txn *Txn) LowerBound(table, index string, args ...interface{}) (ResultIterator, error) {
indexIter, val, err := txn.getIndexIterator(table, index, args...)
indexTxn, val, err := txn.getIndexTxn(table, index, args...)
if err != nil {
return nil, err
}

indexIter := indexTxn.Root().Iterator()

// Seek the iterator to the appropriate sub-set
indexIter.SeekLowerBound(val)

// Create an iterator
iter := &radixIterator{
iter: indexIter,
txn: indexTxn,
modifyIndex: indexTxn.ModifyIndex(),
iter: indexIter,
}
return iter, nil
}
Expand All @@ -738,6 +763,9 @@ func (txn *Txn) LowerBound(table, index string, args ...interface{}) (ResultIter
// resulting iterator since the radix tree doesn't efficiently allow watching
// on lower bound changes. The WatchCh returned will be nill and so will block
// forever.
//
// See the documentation for ResultIterator for correct usage of the returned
// ResultIterator.
func (txn *Txn) ReverseLowerBound(table, index string, args ...interface{}) (ResultIterator, error) {
indexIter, val, err := txn.getIndexIteratorReverse(table, index, args...)
if err != nil {
Expand Down Expand Up @@ -839,7 +867,7 @@ func (txn *Txn) Changes() Changes {
return cs
}

func (txn *Txn) getIndexIterator(table, index string, args ...interface{}) (*iradix.Iterator, []byte, error) {
func (txn *Txn) getIndexTxn(table, index string, args ...interface{}) (*iradix.Txn, []byte, error) {
// Get the index value to scan
indexSchema, val, err := txn.getIndexValue(table, index, args...)
if err != nil {
Expand All @@ -848,11 +876,7 @@ func (txn *Txn) getIndexIterator(table, index string, args ...interface{}) (*ira

// Get the index itself
indexTxn := txn.readableIndex(table, indexSchema.Name)
indexRoot := indexTxn.Root()

// Get an interator over the index
indexIter := indexRoot.Iterator()
return indexIter, val, nil
return indexTxn, val, nil
}

func (txn *Txn) getIndexIteratorReverse(table, index string, args ...interface{}) (*iradix.ReverseIterator, []byte, error) {
Expand Down Expand Up @@ -883,15 +907,20 @@ func (txn *Txn) Defer(fn func()) {
// This is much more efficient than a sliceIterator as we are not
// materializing the entire view.
type radixIterator struct {
iter *iradix.Iterator
watchCh <-chan struct{}
txn *iradix.Txn
iter *iradix.Iterator
watchCh <-chan struct{}
modifyIndex int64
}

func (r *radixIterator) WatchCh() <-chan struct{} {
return r.watchCh
}

func (r *radixIterator) Next() interface{} {
if r.modifyIndex != r.txn.ModifyIndex() {
panic("unsafe call to ResultIterator.Next, Txn has modifications after iteration creation")
}
_, value, ok := r.iter.Next()
if !ok {
return nil
Expand Down
65 changes: 65 additions & 0 deletions txn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2178,3 +2178,68 @@ func TestTxn_Changes(t *testing.T) {
})
}
}

func TestTxn_GetIterAndDelete(t *testing.T) {
schema := &DBSchema{
Tables: map[string]*TableSchema{
"main": {
Name: "main",
Indexes: map[string]*IndexSchema{
"id": {
Name: "id",
Unique: true,
Indexer: &StringFieldIndex{Field: "ID"},
},
"foo": {
Name: "foo",
Indexer: &StringFieldIndex{Field: "Foo"},
},
},
},
},
}
db, err := NewMemDB(schema)
assertNilError(t, err)

key := "aaaa"
txn := db.Txn(true)
assertNilError(t, txn.Insert("main", &TestObject{ID: "1", Foo: key}))
assertNilError(t, txn.Insert("main", &TestObject{ID: "123", Foo: key}))
assertNilError(t, txn.Insert("main", &TestObject{ID: "2", Foo: key}))
txn.Commit()

txn = db.Txn(true)
// Delete something
assertNilError(t, txn.Delete("main", &TestObject{ID: "123", Foo: key}))

iter, err := txn.Get("main", "foo", key)
assertNilError(t, err)

// Modify the table after the iterator is created.
assertNilError(t, txn.Insert("main", &TestObject{ID: "3", Foo: key}))

var panicMsg interface{}
func() {
defer func() {
panicMsg = recover()
}()
_ = iter.Next()
}()

msg, ok := panicMsg.(string)
if !ok {
t.Fatal("expected iter.Next() to panic")
}

expected := "unsafe call to ResultIterator.Next"
if !strings.HasPrefix(msg, expected) {
t.Fatalf("expected panic with message %v, got %v", expected, msg)
}
}

func assertNilError(t *testing.T, err error) {
t.Helper()
if err != nil {
t.Fatalf("expected nil error, got %v", err)
}
}