diff --git a/dot/state/roottottrie.go b/dot/state/roottottrie.go new file mode 100644 index 0000000000..6201133f3a --- /dev/null +++ b/dot/state/roottottrie.go @@ -0,0 +1,50 @@ +package state + +import ( + "sync" + + "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/trie" +) + +type rootToTrieMap struct { + rootToTrie map[common.Hash]*trie.Trie + sync.RWMutex +} + +func newRootToTrieMap() *rootToTrieMap { + return &rootToTrieMap{ + rootToTrie: make(map[common.Hash]*trie.Trie), + } +} + +func (r *rootToTrieMap) get(root common.Hash) (t *trie.Trie, has bool) { + r.RLock() + defer r.RUnlock() + t, has = r.rootToTrie[root] + return t, has +} + +func (r *rootToTrieMap) set(root common.Hash, trie *trie.Trie) { + r.Lock() + defer r.Unlock() + r.rootToTrie[root] = trie +} + +func (r *rootToTrieMap) setIfUnset(root common.Hash, trie *trie.Trie) { + r.Lock() + defer r.Unlock() + + _, exists := r.rootToTrie[root] + if exists { + return + } + + r.rootToTrie[root] = trie +} + +func (r *rootToTrieMap) delete(root common.Hash) { + r.Lock() + defer r.Unlock() + delete(r.rootToTrie, root) +} diff --git a/dot/state/service_test.go b/dot/state/service_test.go index 52a6f17c7c..34c80afe86 100644 --- a/dot/state/service_test.go +++ b/dot/state/service_test.go @@ -310,8 +310,8 @@ func TestService_PruneStorage(t *testing.T) { time.Sleep(1 * time.Second) for _, v := range prunedArr { - _, has := serv.Storage.tries.Load(v.hash) - require.Equal(t, false, has) + _, has := serv.Storage.rootToTrie.get(v.hash) + require.False(t, has) } } diff --git a/dot/state/storage.go b/dot/state/storage.go index f9f1328412..a826637cf9 100644 --- a/dot/state/storage.go +++ b/dot/state/storage.go @@ -30,7 +30,7 @@ func errTrieDoesNotExist(hash common.Hash) error { // StorageState is the struct that holds the trie, db and lock type StorageState struct { blockState *BlockState - tries *sync.Map // map[common.Hash]*trie.Trie // map of root -> trie + rootToTrie *rootToTrieMap db chaindb.Database sync.RWMutex @@ -52,8 +52,8 @@ func NewStorageState(db chaindb.Database, blockState *BlockState, return nil, fmt.Errorf("cannot have nil trie") } - tries := new(sync.Map) - tries.Store(t.MustHash(), t) + rootToTrie := newRootToTrieMap() + rootToTrie.set(t.MustHash(), t) storageTable := chaindb.NewTable(db, storagePrefix) @@ -70,7 +70,7 @@ func NewStorageState(db chaindb.Database, blockState *BlockState, return &StorageState{ blockState: blockState, - tries: tries, + rootToTrie: rootToTrie, db: storageTable, observerList: []Observer{}, pruner: p, @@ -78,14 +78,14 @@ func NewStorageState(db chaindb.Database, blockState *BlockState, } func (s *StorageState) pruneKey(keyHeader *types.Header) { - s.tries.Delete(keyHeader.StateRoot) + s.rootToTrie.delete(keyHeader.StateRoot) } // StoreTrie stores the given trie in the StorageState and writes it to the database func (s *StorageState) StoreTrie(ts *rtstorage.TrieState, header *types.Header) error { root := ts.MustRoot() - _, _ = s.tries.LoadOrStore(root, ts.Trie()) + s.rootToTrie.setIfUnset(root, ts.Trie()) if _, ok := s.pruner.(*pruner.FullNode); header == nil && ok { return fmt.Errorf("block cannot be empty for Full node pruner") @@ -126,19 +126,18 @@ func (s *StorageState) TrieState(root *common.Hash) (*rtstorage.TrieState, error root = &sr } - st, has := s.tries.Load(*root) + t, has := s.rootToTrie.get(*root) if !has { var err error - st, err = s.LoadFromDB(*root) + t, err = s.LoadFromDB(*root) if err != nil { return nil, err } - _, _ = s.tries.LoadOrStore(*root, st) + s.rootToTrie.setIfUnset(*root, t) + // TODO get and setIfUnset should be atomic } - t := st.(*trie.Trie) - if has && t.MustHash() != *root { panic("trie does not have expected root") } @@ -161,7 +160,7 @@ func (s *StorageState) LoadFromDB(root common.Hash) (*trie.Trie, error) { return nil, err } - _, _ = s.tries.LoadOrStore(t.MustHash(), t) + s.rootToTrie.setIfUnset(t.MustHash(), t) return t, nil } @@ -174,8 +173,8 @@ func (s *StorageState) loadTrie(root *common.Hash) (*trie.Trie, error) { root = &sr } - if t, has := s.tries.Load(*root); has && t != nil { - return t.(*trie.Trie), nil + if t, has := s.rootToTrie.get(*root); has && t != nil { + return t, nil } tr, err := s.LoadFromDB(*root) @@ -204,8 +203,8 @@ func (s *StorageState) GetStorage(root *common.Hash, key []byte) ([]byte, error) root = &sr } - if t, has := s.tries.Load(*root); has { - val := t.(*trie.Trie).Get(key) + if t, has := s.rootToTrie.get(*root); has { + val := t.Get(key) return val, nil } diff --git a/dot/state/storage_test.go b/dot/state/storage_test.go index 7679329ad1..b733195e84 100644 --- a/dot/state/storage_test.go +++ b/dot/state/storage_test.go @@ -5,7 +5,6 @@ package state import ( "math/big" - "sync" "testing" "time" @@ -99,7 +98,7 @@ func TestStorage_TrieState(t *testing.T) { time.Sleep(time.Millisecond * 100) // get trie from db - storage.tries.Delete(root) + storage.rootToTrie.delete(root) ts3, err := storage.TrieState(&root) require.NoError(t, err) require.Equal(t, ts.Trie().MustHash(), ts3.Trie().MustHash()) @@ -131,34 +130,25 @@ func TestStorage_LoadFromDB(t *testing.T) { require.NoError(t, err) // Clear trie from cache and fetch data from disk. - storage.tries.Delete(root) + storage.rootToTrie.delete(root) data, err := storage.GetStorage(&root, trieKV[0].key) require.NoError(t, err) require.Equal(t, trieKV[0].value, data) - storage.tries.Delete(root) + storage.rootToTrie.delete(root) prefixKeys, err := storage.GetKeysWithPrefix(&root, []byte("ke")) require.NoError(t, err) require.Equal(t, 2, len(prefixKeys)) - storage.tries.Delete(root) + storage.rootToTrie.delete(root) entries, err := storage.Entries(&root) require.NoError(t, err) require.Equal(t, 3, len(entries)) } -func syncMapLen(m *sync.Map) int { - l := 0 - m.Range(func(_, _ interface{}) bool { - l++ - return true - }) - return l -} - func TestStorage_StoreTrie_NotSyncing(t *testing.T) { storage := newTestStorageState(t) ts, err := storage.TrieState(&trie.EmptyHash) @@ -170,7 +160,7 @@ func TestStorage_StoreTrie_NotSyncing(t *testing.T) { err = storage.StoreTrie(ts, nil) require.NoError(t, err) - require.Equal(t, 2, syncMapLen(storage.tries)) + require.Equal(t, 2, len(storage.rootToTrie.rootToTrie)) } func TestGetStorageChildAndGetStorageFromChild(t *testing.T) { @@ -217,7 +207,7 @@ func TestGetStorageChildAndGetStorageFromChild(t *testing.T) { require.NoError(t, err) // Clear trie from cache and fetch data from disk. - storage.tries.Delete(rootHash) + storage.rootToTrie.delete(rootHash) _, err = storage.GetStorageChild(&rootHash, []byte("keyToChild")) require.NoError(t, err)