feat: Support concurrency for IAVL and fix race conditions (#805)
Co-authored-by: Marko <marko@baricevic.me>
(cherry picked from commit ba6beb1)
mattverse authored and mergify[bot] committed Aug 23, 2023
1 parent 9d71b8a commit 78805ea
Showing 3 changed files with 107 additions and 73 deletions.
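
The substance of the change: the unsaved fast-node bookkeeping (additions and removals) moves from plain Go maps to *sync.Map, because a plain map read and written from different goroutines is a data race — the race detector flags it and the runtime can abort with "concurrent map read and map write". A minimal sketch of the difference (illustrative only, not code from this repository):

package main

import (
	"fmt"
	"sync"
)

func main() {
	var unsaved sync.Map // stands in for the unsaved fast-node map
	var wg sync.WaitGroup

	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			key := fmt.Sprintf("key-%d", i)
			unsaved.Store(key, i) // concurrent writes are safe on sync.Map
			if v, ok := unsaved.Load(key); ok {
				_ = v // concurrent reads are safe as well
			}
		}(i)
	}
	wg.Wait()
}

The same loop writing into a plain map[string]int without a mutex would fail under go test -race; switching the tree's internal maps to sync.Map removes that class of failure.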
21 changes: 15 additions & 6 deletions iterator_test.go
@@ -3,11 +3,11 @@ package iavl
import (
"math/rand"
"sort"
"sync"
"testing"

log "cosmossdk.io/log"
dbm "github.com/cosmos/cosmos-db"
"github.com/cosmos/iavl/fastnode"
"github.com/stretchr/testify/require"
)

@@ -37,7 +37,7 @@ func TestIterator_NewIterator_NilTree_Failure(t *testing.T) {
})

t.Run("Unsaved Fast Iterator", func(t *testing.T) {
itr := NewUnsavedFastIterator(start, end, ascending, nil, map[string]*fastnode.Node{}, map[string]interface{}{})
itr := NewUnsavedFastIterator(start, end, ascending, nil, &sync.Map{}, &sync.Map{})
performTest(t, itr)
require.ErrorIs(t, errFastIteratorNilNdbGiven, itr.Error())
})
@@ -292,14 +292,14 @@ func setupUnsavedFastIterator(t *testing.T, config *iteratorTestConfig) (dbm.Ite
require.NoError(t, err)

// No unsaved additions or removals should be present after saving
require.Equal(t, 0, len(tree.unsavedFastNodeAdditions))
require.Equal(t, 0, len(tree.unsavedFastNodeRemovals))
require.Equal(t, 0, syncMapCount(tree.unsavedFastNodeAdditions))
require.Equal(t, 0, syncMapCount(tree.unsavedFastNodeRemovals))

// Ensure that there are unsaved additions and removals present
secondHalfMirror := setupMirrorForIterator(t, &secondHalfConfig, tree)

require.True(t, len(tree.unsavedFastNodeAdditions) >= len(secondHalfMirror))
require.Equal(t, 0, len(tree.unsavedFastNodeRemovals))
require.True(t, syncMapCount(tree.unsavedFastNodeAdditions) >= len(secondHalfMirror))
require.Equal(t, 0, syncMapCount(tree.unsavedFastNodeRemovals))

// Merge the two halves
if config.ascending {
@@ -371,3 +371,12 @@ func TestNodeIterator_WithEmptyRoot(t *testing.T) {
require.NoError(t, err)
require.False(t, itr.Valid())
}

func syncMapCount(m *sync.Map) int {
count := 0
m.Range(func(_, _ interface{}) bool {
count++
return true
})
return count
}
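
sync.Map deliberately exposes no length, so the new syncMapCount helper above counts entries by ranging over the map. For reference, a sketch (an assumption of this write-up, not part of the commit) of how the plain-map operations used before the change map onto the sync.Map API:

package main

import (
	"fmt"
	"sync"
)

func main() {
	var m sync.Map

	m.Store("k", 1) // was: m["k"] = 1
	if v, ok := m.Load("k"); ok { // was: v, ok := m["k"]
		fmt.Println(v)
	}

	m.Range(func(k, v interface{}) bool { // was: for k, v := range m
		fmt.Println(k, v)
		return true // returning false stops the iteration early
	})

	m.Delete("k") // was: delete(m, "k")
	// There is no len(m); syncMapCount above walks the map instead.
}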
73 changes: 45 additions & 28 deletions mutable_tree.go
@@ -38,10 +38,10 @@ type Option func(*Options)
type MutableTree struct {
logger log.Logger

*ImmutableTree // The current, working tree.
lastSaved *ImmutableTree // The most recently saved tree.
unsavedFastNodeAdditions map[string]*fastnode.Node // FastNodes that have not yet been saved to disk
unsavedFastNodeRemovals map[string]interface{} // FastNodes that have not yet been removed from disk
*ImmutableTree // The current, working tree.
lastSaved *ImmutableTree // The most recently saved tree.
unsavedFastNodeAdditions *sync.Map // map[string]*FastNode FastNodes that have not yet been saved to disk
unsavedFastNodeRemovals *sync.Map // map[string]interface{} FastNodes that have not yet been removed from disk
ndb *nodeDB
skipFastStorageUpgrade bool // If true, the tree will work like no fast storage and always not upgrade fast storage

@@ -62,8 +62,8 @@ func NewMutableTree(db dbm.DB, cacheSize int, skipFastStorageUpgrade bool, lg lo
logger: lg,
ImmutableTree: head,
lastSaved: head.clone(),
unsavedFastNodeAdditions: make(map[string]*fastnode.Node),
unsavedFastNodeRemovals: make(map[string]interface{}),
unsavedFastNodeAdditions: &sync.Map{},
unsavedFastNodeRemovals: &sync.Map{},
ndb: ndb,
skipFastStorageUpgrade: skipFastStorageUpgrade,
}
@@ -176,11 +176,11 @@ func (tree *MutableTree) Get(key []byte) ([]byte, error) {
}

if !tree.skipFastStorageUpgrade {
if fastNode, ok := tree.unsavedFastNodeAdditions[string(key)]; ok {
return fastNode.GetValue(), nil
if fastNode, ok := tree.unsavedFastNodeAdditions.Load(ibytes.UnsafeBytesToStr(key)); ok {
return fastNode.(*fastnode.Node).GetValue(), nil
}
// check if node was deleted
if _, ok := tree.unsavedFastNodeRemovals[string(key)]; ok {
if _, ok := tree.unsavedFastNodeRemovals.Load(string(key)); ok {
return nil, nil
}
}
@@ -659,8 +659,8 @@ func (tree *MutableTree) Rollback() {
}
}
if !tree.skipFastStorageUpgrade {
tree.unsavedFastNodeAdditions = map[string]*fastnode.Node{}
tree.unsavedFastNodeRemovals = map[string]interface{}{}
tree.unsavedFastNodeAdditions = &sync.Map{}
tree.unsavedFastNodeRemovals = &sync.Map{}
}
}

@@ -778,8 +778,8 @@ func (tree *MutableTree) SaveVersion() ([]byte, int64, error) {
tree.ImmutableTree = tree.ImmutableTree.clone()
tree.lastSaved = tree.ImmutableTree.clone()
if !tree.skipFastStorageUpgrade {
tree.unsavedFastNodeAdditions = make(map[string]*fastnode.Node)
tree.unsavedFastNodeRemovals = make(map[string]interface{})
tree.unsavedFastNodeAdditions = &sync.Map{}
tree.unsavedFastNodeRemovals = &sync.Map{}
}

hash := tree.Hash()
@@ -797,30 +797,45 @@ func (tree *MutableTree) saveFastNodeVersion(isGenesis bool) error {
return tree.ndb.setFastStorageVersionToBatch()
}

// nolint: unused

func (tree *MutableTree) getUnsavedFastNodeAdditions() map[string]*fastnode.Node {
return tree.unsavedFastNodeAdditions
additions := make(map[string]*fastnode.Node)
tree.unsavedFastNodeAdditions.Range(func(key, value interface{}) bool {
additions[key.(string)] = value.(*fastnode.Node)
return true
})
return additions
}

// getUnsavedFastNodeRemovals returns unsaved FastNodes to remove

func (tree *MutableTree) getUnsavedFastNodeRemovals() map[string]interface{} {
return tree.unsavedFastNodeRemovals
removals := make(map[string]interface{})
tree.unsavedFastNodeRemovals.Range(func(key, value interface{}) bool {
removals[key.(string)] = value
return true
})
return removals
}

// addUnsavedAddition stores an addition into the unsaved additions map
func (tree *MutableTree) addUnsavedAddition(key []byte, node *fastnode.Node) {
delete(tree.unsavedFastNodeRemovals, ibytes.UnsafeBytesToStr(key))
tree.unsavedFastNodeAdditions[string(key)] = node
skey := ibytes.UnsafeBytesToStr(key)
tree.unsavedFastNodeRemovals.Delete(skey)
tree.unsavedFastNodeAdditions.Store(skey, node)
}

func (tree *MutableTree) saveFastNodeAdditions(batchCommmit bool) error {
keysToSort := make([]string, 0, len(tree.unsavedFastNodeAdditions))
for key := range tree.unsavedFastNodeAdditions {
keysToSort = append(keysToSort, key)
}
keysToSort := make([]string, 0)
tree.unsavedFastNodeAdditions.Range(func(k, v interface{}) bool {
keysToSort = append(keysToSort, k.(string))
return true
})
sort.Strings(keysToSort)

for _, key := range keysToSort {
if err := tree.ndb.SaveFastNode(tree.unsavedFastNodeAdditions[key]); err != nil {
val, _ := tree.unsavedFastNodeAdditions.Load(key)
if err := tree.ndb.SaveFastNode(val.(*fastnode.Node)); err != nil {
return err
}
if batchCommmit {
@@ -832,17 +847,19 @@ func (tree *MutableTree) saveFastNodeAdditions(batchCommmit bool) error {
return nil
}

// addUnsavedRemoval adds a removal to the unsaved removals map
func (tree *MutableTree) addUnsavedRemoval(key []byte) {
skey := ibytes.UnsafeBytesToStr(key)
delete(tree.unsavedFastNodeAdditions, skey)
tree.unsavedFastNodeRemovals[skey] = true
tree.unsavedFastNodeAdditions.Delete(skey)
tree.unsavedFastNodeRemovals.Store(skey, true)
}

func (tree *MutableTree) saveFastNodeRemovals() error {
keysToSort := make([]string, 0, len(tree.unsavedFastNodeRemovals))
for key := range tree.unsavedFastNodeRemovals {
keysToSort = append(keysToSort, key)
}
keysToSort := make([]string, 0)
tree.unsavedFastNodeRemovals.Range(func(k, v interface{}) bool {
keysToSort = append(keysToSort, k.(string))
return true
})
sort.Strings(keysToSort)

for _, key := range keysToSort {
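
sync.Map.Range visits keys in an unspecified order, so the save paths above first collect the keys, sort them, and only then load and write each node; that keeps the write batch deterministic even though the backing map is now concurrent. A standalone sketch of that collect-sort-apply pattern (hypothetical helper, assuming string keys and an arbitrary save callback):

package main

import (
	"fmt"
	"sort"
	"sync"
)

// saveInOrder writes every entry of m in ascending key order.
func saveInOrder(m *sync.Map, save func(key string, value interface{}) error) error {
	keys := make([]string, 0)
	m.Range(func(k, _ interface{}) bool { // Range order is unspecified
		keys = append(keys, k.(string))
		return true
	})
	sort.Strings(keys) // deterministic order for the write batch

	for _, k := range keys {
		if v, ok := m.Load(k); ok {
			if err := save(k, v); err != nil {
				return err
			}
		}
	}
	return nil
}

func main() {
	m := &sync.Map{}
	m.Store("b", 2)
	m.Store("a", 1)
	_ = saveInOrder(m, func(k string, v interface{}) error {
		fmt.Println(k, v) // prints "a 1", then "b 2"
		return nil
	})
}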
86 changes: 47 additions & 39 deletions unsaved_fast_iterator.go
@@ -4,9 +4,12 @@ import (
"bytes"
"errors"
"sort"
"sync"

dbm "github.com/cosmos/cosmos-db"

"github.com/cosmos/iavl/fastnode"
ibytes "github.com/cosmos/iavl/internal/bytes"
)

var (
@@ -29,14 +32,14 @@ type UnsavedFastIterator struct {
fastIterator dbm.Iterator

nextUnsavedNodeIdx int
unsavedFastNodeAdditions map[string]*fastnode.Node
unsavedFastNodeRemovals map[string]interface{}
unsavedFastNodesToSort [][]byte
unsavedFastNodeAdditions *sync.Map // map[string]*FastNode
unsavedFastNodeRemovals *sync.Map // map[string]interface{}
unsavedFastNodesToSort []string
}

var _ dbm.Iterator = (*UnsavedFastIterator)(nil)

func NewUnsavedFastIterator(start, end []byte, ascending bool, ndb *nodeDB, unsavedFastNodeAdditions map[string]*fastnode.Node, unsavedFastNodeRemovals map[string]interface{}) *UnsavedFastIterator {
func NewUnsavedFastIterator(start, end []byte, ascending bool, ndb *nodeDB, unsavedFastNodeAdditions, unsavedFastNodeRemovals *sync.Map) *UnsavedFastIterator {
iter := &UnsavedFastIterator{
start: start,
end: end,
@@ -50,29 +53,6 @@ func NewUnsavedFastIterator(start, end []byte, ascending bool, ndb *nodeDB, unsa
fastIterator: NewFastIterator(start, end, ascending, ndb),
}

// We need to ensure that we iterate over saved and unsaved state in order.
// The strategy is to sort the unsaved nodes; the fast nodes on disk are already sorted.
// Then, we keep a pointer to both the unsaved and saved nodes, and iterate over them in order efficiently.
for _, fastNode := range unsavedFastNodeAdditions {
if start != nil && bytes.Compare(fastNode.GetKey(), start) < 0 {
continue
}

if end != nil && bytes.Compare(fastNode.GetKey(), end) >= 0 {
continue
}

iter.unsavedFastNodesToSort = append(iter.unsavedFastNodesToSort, fastNode.GetKey())
}

sort.Slice(iter.unsavedFastNodesToSort, func(i, j int) bool {
cmp := bytes.Compare(iter.unsavedFastNodesToSort[i], iter.unsavedFastNodesToSort[j])
if ascending {
return cmp < 0
}
return cmp > 0
})

if iter.ndb == nil {
iter.err = errFastIteratorNilNdbGiven
iter.valid = false
@@ -90,8 +70,34 @@ func NewUnsavedFastIterator(start, end []byte, ascending bool, ndb *nodeDB, unsa
iter.valid = false
return iter
}
// We need to ensure that we iterate over saved and unsaved state in order.
// The strategy is to sort the unsaved nodes; the fast nodes on disk are already sorted.
// Then, we keep a pointer to both the unsaved and saved nodes, and iterate over them in order efficiently.
unsavedFastNodeAdditions.Range(func(k, v interface{}) bool {
fastNode := v.(*fastnode.Node)

if start != nil && bytes.Compare(fastNode.GetKey(), start) < 0 {
return true
}

if end != nil && bytes.Compare(fastNode.GetKey(), end) >= 0 {
return true
}

// keys are stored as strings; the type assertion should not fail in practice
iter.unsavedFastNodesToSort = append(iter.unsavedFastNodesToSort, k.(string))

// keep ranging over the remaining unsaved additions
return true
})

sort.Slice(iter.unsavedFastNodesToSort, func(i, j int) bool {
if ascending {
return iter.unsavedFastNodesToSort[i] < iter.unsavedFastNodesToSort[j]
}
return iter.unsavedFastNodesToSort[i] > iter.unsavedFastNodesToSort[j]
})

// Move to the first element
iter.Next()

return iter
@@ -134,31 +140,31 @@ func (iter *UnsavedFastIterator) Next() {
return
}

diskKeyStr := iter.fastIterator.Key()
diskKey := iter.fastIterator.Key()
diskKeyStr := ibytes.UnsafeBytesToStr(diskKey)
if iter.fastIterator.Valid() && iter.nextUnsavedNodeIdx < len(iter.unsavedFastNodesToSort) {

if iter.unsavedFastNodeRemovals[string(diskKeyStr)] != nil {
value, ok := iter.unsavedFastNodeRemovals.Load(diskKeyStr)
if ok && value != nil {
// If next fast node from disk is to be removed, skip it.
iter.fastIterator.Next()
iter.Next()
return
}

nextUnsavedKey := iter.unsavedFastNodesToSort[iter.nextUnsavedNodeIdx]
nextUnsavedNode := iter.unsavedFastNodeAdditions[string(nextUnsavedKey)]
nextUnsavedNodeVal, _ := iter.unsavedFastNodeAdditions.Load(nextUnsavedKey)
nextUnsavedNode := nextUnsavedNodeVal.(*fastnode.Node)

var isUnsavedNext bool
cmp := bytes.Compare(diskKeyStr, nextUnsavedKey)
if iter.ascending {
isUnsavedNext = cmp >= 0
isUnsavedNext = diskKeyStr >= nextUnsavedKey
} else {
isUnsavedNext = cmp <= 0
isUnsavedNext = diskKeyStr <= nextUnsavedKey
}

if isUnsavedNext {
// Unsaved node is next

if cmp == 0 {
if diskKeyStr == nextUnsavedKey {
// Unsaved update prevails over saved copy so we skip the copy from disk
iter.fastIterator.Next()
}
@@ -179,7 +185,8 @@ func (iter *UnsavedFastIterator) Next() {

// if only nodes on disk are left, we return them
if iter.fastIterator.Valid() {
if iter.unsavedFastNodeRemovals[string(diskKeyStr)] != nil {
value, ok := iter.unsavedFastNodeRemovals.Load(diskKeyStr)
if ok && value != nil {
// If next fast node from disk is to be removed, skip it.
iter.fastIterator.Next()
iter.Next()
Expand All @@ -196,7 +203,8 @@ func (iter *UnsavedFastIterator) Next() {
// if only unsaved nodes are left, we can just iterate
if iter.nextUnsavedNodeIdx < len(iter.unsavedFastNodesToSort) {
nextUnsavedKey := iter.unsavedFastNodesToSort[iter.nextUnsavedNodeIdx]
nextUnsavedNode := iter.unsavedFastNodeAdditions[string(nextUnsavedKey)]
nextUnsavedNodeVal, _ := iter.unsavedFastNodeAdditions.Load(nextUnsavedKey)
nextUnsavedNode := nextUnsavedNodeVal.(*fastnode.Node)

iter.nextKey = nextUnsavedNode.GetKey()
iter.nextVal = nextUnsavedNode.GetValue()
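
The iterator's job after this change is unchanged: merge two already-sorted sources — the on-disk fast iterator and the sorted slice of unsaved keys — letting unsaved additions win ties and skipping disk entries that have a pending removal. A simplified sketch of that merge over plain slices (illustrative only; the real iterator streams keys from disk instead of holding a slice):

package main

import "fmt"

// mergeKeys merges a sorted on-disk key list with a sorted unsaved key list.
// Unsaved keys win ties, and keys present in removed are dropped.
func mergeKeys(disk, unsaved []string, removed map[string]bool) []string {
	out := make([]string, 0, len(disk)+len(unsaved))
	i, j := 0, 0
	for i < len(disk) || j < len(unsaved) {
		switch {
		case i < len(disk) && removed[disk[i]]:
			i++ // a pending removal hides the on-disk entry
		case j == len(unsaved):
			out = append(out, disk[i])
			i++
		case i == len(disk):
			out = append(out, unsaved[j])
			j++
		case disk[i] < unsaved[j]:
			out = append(out, disk[i])
			i++
		default:
			if disk[i] == unsaved[j] {
				i++ // exact tie: the unsaved copy prevails over the disk copy
			}
			out = append(out, unsaved[j])
			j++
		}
	}
	return out
}

func main() {
	disk := []string{"a", "c", "d"}
	unsaved := []string{"b", "c"}
	removed := map[string]bool{"d": true}
	fmt.Println(mergeKeys(disk, unsaved, removed)) // [a b c]
}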
