From be668c1de6625edd9d16b1bb4c45b5ead6cbf452 Mon Sep 17 00:00:00 2001 From: "Quentin McGaw (desktop)" Date: Thu, 27 Jan 2022 19:15:51 +0000 Subject: [PATCH] Replace sync.Map with mutex locked map --- dot/state/service_test.go | 4 +- dot/state/storage.go | 35 ++++--- dot/state/storage_test.go | 22 ++--- dot/state/tries.go | 60 ++++++++++++ dot/state/tries_test.go | 195 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 279 insertions(+), 37 deletions(-) create mode 100644 dot/state/tries.go create mode 100644 dot/state/tries_test.go diff --git a/dot/state/service_test.go b/dot/state/service_test.go index 11652ff8a6..466cc06593 100644 --- a/dot/state/service_test.go +++ b/dot/state/service_test.go @@ -293,8 +293,8 @@ func TestService_PruneStorage(t *testing.T) { time.Sleep(1 * time.Second) for _, v := range prunedArr { - _, has := serv.Storage.tries.Load(v.hash) - require.Equal(t, false, has) + tr := serv.Storage.tries.get(v.hash) + require.Nil(t, tr) } } diff --git a/dot/state/storage.go b/dot/state/storage.go index 56da2fc0b5..94772187df 100644 --- a/dot/state/storage.go +++ b/dot/state/storage.go @@ -30,7 +30,7 @@ func errTrieDoesNotExist(hash common.Hash) error { // StorageState is the struct that holds the trie, db and lock type StorageState struct { blockState *BlockState - tries *sync.Map // map[common.Hash]*trie.Trie // map of root -> trie + tries *tries db chaindb.Database sync.RWMutex @@ -52,8 +52,7 @@ func NewStorageState(db chaindb.Database, blockState *BlockState, return nil, fmt.Errorf("cannot have nil trie") } - tries := new(sync.Map) - tries.Store(t.MustHash(), t) + tries := newTries(t) storageTable := chaindb.NewTable(db, storagePrefix) @@ -79,14 +78,14 @@ func NewStorageState(db chaindb.Database, blockState *BlockState, func (s *StorageState) pruneKey(keyHeader *types.Header) { logger.Tracef("pruning trie, number=%d hash=%s", keyHeader.Number, keyHeader.Hash()) - s.tries.Delete(keyHeader.StateRoot) + s.tries.delete(keyHeader.StateRoot) } // StoreTrie stores the given trie in the StorageState and writes it to the database func (s *StorageState) StoreTrie(ts *rtstorage.TrieState, header *types.Header) error { root := ts.MustRoot() - _, _ = s.tries.LoadOrStore(root, ts.Trie()) + s.tries.softSet(root, ts.Trie()) if _, ok := s.pruner.(*pruner.FullNode); header == nil && ok { return fmt.Errorf("block cannot be empty for Full node pruner") @@ -127,20 +126,16 @@ func (s *StorageState) TrieState(root *common.Hash) (*rtstorage.TrieState, error root = &sr } - st, has := s.tries.Load(*root) - if !has { + t := s.tries.get(*root) + if t == nil { var err error - st, err = s.LoadFromDB(*root) + t, err = s.LoadFromDB(*root) if err != nil { return nil, err } - _, _ = s.tries.LoadOrStore(*root, st) - } - - t := st.(*trie.Trie) - - if has && t.MustHash() != *root { + s.tries.softSet(*root, t) + } else if t.MustHash() != *root { panic("trie does not have expected root") } @@ -162,7 +157,7 @@ func (s *StorageState) LoadFromDB(root common.Hash) (*trie.Trie, error) { return nil, err } - _, _ = s.tries.LoadOrStore(t.MustHash(), t) + s.tries.softSet(t.MustHash(), t) return t, nil } @@ -175,8 +170,9 @@ func (s *StorageState) loadTrie(root *common.Hash) (*trie.Trie, error) { root = &sr } - if t, has := s.tries.Load(*root); has && t != nil { - return t.(*trie.Trie), nil + t := s.tries.get(*root) + if t != nil { + return t, nil } tr, err := s.LoadFromDB(*root) @@ -205,8 +201,9 @@ func (s *StorageState) GetStorage(root *common.Hash, key []byte) ([]byte, error) root = &sr } - if t, has := s.tries.Load(*root); has { - val := t.(*trie.Trie).Get(key) + t := s.tries.get(*root) + if t != nil { + val := t.Get(key) return val, nil } diff --git a/dot/state/storage_test.go b/dot/state/storage_test.go index 7679329ad1..6954af5b32 100644 --- a/dot/state/storage_test.go +++ b/dot/state/storage_test.go @@ -5,7 +5,6 @@ package state import ( "math/big" - "sync" "testing" "time" @@ -99,7 +98,7 @@ func TestStorage_TrieState(t *testing.T) { time.Sleep(time.Millisecond * 100) // get trie from db - storage.tries.Delete(root) + storage.tries.delete(root) ts3, err := storage.TrieState(&root) require.NoError(t, err) require.Equal(t, ts.Trie().MustHash(), ts3.Trie().MustHash()) @@ -131,34 +130,25 @@ func TestStorage_LoadFromDB(t *testing.T) { require.NoError(t, err) // Clear trie from cache and fetch data from disk. - storage.tries.Delete(root) + storage.tries.delete(root) data, err := storage.GetStorage(&root, trieKV[0].key) require.NoError(t, err) require.Equal(t, trieKV[0].value, data) - storage.tries.Delete(root) + storage.tries.delete(root) prefixKeys, err := storage.GetKeysWithPrefix(&root, []byte("ke")) require.NoError(t, err) require.Equal(t, 2, len(prefixKeys)) - storage.tries.Delete(root) + storage.tries.delete(root) entries, err := storage.Entries(&root) require.NoError(t, err) require.Equal(t, 3, len(entries)) } -func syncMapLen(m *sync.Map) int { - l := 0 - m.Range(func(_, _ interface{}) bool { - l++ - return true - }) - return l -} - func TestStorage_StoreTrie_NotSyncing(t *testing.T) { storage := newTestStorageState(t) ts, err := storage.TrieState(&trie.EmptyHash) @@ -170,7 +160,7 @@ func TestStorage_StoreTrie_NotSyncing(t *testing.T) { err = storage.StoreTrie(ts, nil) require.NoError(t, err) - require.Equal(t, 2, syncMapLen(storage.tries)) + require.Equal(t, 2, storage.tries.len()) } func TestGetStorageChildAndGetStorageFromChild(t *testing.T) { @@ -217,7 +207,7 @@ func TestGetStorageChildAndGetStorageFromChild(t *testing.T) { require.NoError(t, err) // Clear trie from cache and fetch data from disk. - storage.tries.Delete(rootHash) + storage.tries.delete(rootHash) _, err = storage.GetStorageChild(&rootHash, []byte("keyToChild")) require.NoError(t, err) diff --git a/dot/state/tries.go b/dot/state/tries.go new file mode 100644 index 0000000000..e7afd3dbb1 --- /dev/null +++ b/dot/state/tries.go @@ -0,0 +1,60 @@ +// Copyright 2022 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package state + +import ( + "sync" + + "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/trie" +) + +type tries struct { + rootToTrie map[common.Hash]*trie.Trie + mapMutex sync.RWMutex +} + +func newTries(t *trie.Trie) *tries { + return &tries{ + rootToTrie: map[common.Hash]*trie.Trie{ + t.MustHash(): t, + }, + } +} + +// softSet sets the given trie at the given root hash +// in the memory map only if it is not already set. +func (t *tries) softSet(root common.Hash, trie *trie.Trie) { + t.mapMutex.Lock() + defer t.mapMutex.Unlock() + + _, has := t.rootToTrie[root] + if has { + return + } + + t.rootToTrie[root] = trie +} + +func (t *tries) delete(root common.Hash) { + t.mapMutex.Lock() + defer t.mapMutex.Unlock() + delete(t.rootToTrie, root) +} + +// get retrieves the trie corresponding to the root hash given +// from the in-memory thread safe map. +func (t *tries) get(root common.Hash) (tr *trie.Trie) { + t.mapMutex.RLock() + defer t.mapMutex.RUnlock() + return t.rootToTrie[root] +} + +// len returns the current numbers of tries +// stored in the in-memory map. +func (t *tries) len() int { + t.mapMutex.RLock() + defer t.mapMutex.RUnlock() + return len(t.rootToTrie) +} diff --git a/dot/state/tries_test.go b/dot/state/tries_test.go new file mode 100644 index 0000000000..0a0bc1d865 --- /dev/null +++ b/dot/state/tries_test.go @@ -0,0 +1,195 @@ +// Copyright 2022 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package state + +import ( + "testing" + + "github.com/ChainSafe/gossamer/internal/trie/node" + "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/trie" + "github.com/stretchr/testify/assert" +) + +func Test_newTries(t *testing.T) { + t.Parallel() + + tr := trie.NewEmptyTrie() + + rootToTrie := newTries(tr) + + expectedTries := &tries{ + rootToTrie: map[common.Hash]*trie.Trie{ + tr.MustHash(): tr, + }, + } + + assert.Equal(t, expectedTries, rootToTrie) +} + +func Test_tries_softSet(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + tries *tries + root common.Hash + trie *trie.Trie + expectedTries *tries + }{ + "set new in map": { + tries: &tries{ + rootToTrie: map[common.Hash]*trie.Trie{}, + }, + root: common.Hash{1, 2, 3}, + trie: trie.NewEmptyTrie(), + expectedTries: &tries{ + rootToTrie: map[common.Hash]*trie.Trie{ + {1, 2, 3}: trie.NewEmptyTrie(), + }, + }, + }, + "do not override in map": { + tries: &tries{ + rootToTrie: map[common.Hash]*trie.Trie{ + {1, 2, 3}: {}, + }, + }, + root: common.Hash{1, 2, 3}, + trie: trie.NewEmptyTrie(), + expectedTries: &tries{ + rootToTrie: map[common.Hash]*trie.Trie{ + {1, 2, 3}: {}, + }, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + testCase.tries.softSet(testCase.root, testCase.trie) + + assert.Equal(t, testCase.expectedTries, testCase.tries) + }) + } +} + +func Test_tries_delete(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + tries *tries + root common.Hash + expectedTries *tries + }{ + "not found": { + tries: &tries{ + rootToTrie: map[common.Hash]*trie.Trie{}, + }, + root: common.Hash{1, 2, 3}, + expectedTries: &tries{ + rootToTrie: map[common.Hash]*trie.Trie{}, + }, + }, + "deleted": { + tries: &tries{ + rootToTrie: map[common.Hash]*trie.Trie{ + {1, 2, 3}: {}, + }, + }, + root: common.Hash{1, 2, 3}, + expectedTries: &tries{ + rootToTrie: map[common.Hash]*trie.Trie{}, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + testCase.tries.delete(testCase.root) + + assert.Equal(t, testCase.expectedTries, testCase.tries) + }) + } +} +func Test_tries_get(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + tries *tries + root common.Hash + trie *trie.Trie + }{ + "found in map": { + tries: &tries{ + rootToTrie: map[common.Hash]*trie.Trie{ + {1, 2, 3}: trie.NewTrie(&node.Leaf{ + Key: []byte{1, 2, 3}, + }), + }, + }, + root: common.Hash{1, 2, 3}, + trie: trie.NewTrie(&node.Leaf{ + Key: []byte{1, 2, 3}, + }), + }, + "not found in map": { + // similar to not found in database + tries: &tries{ + rootToTrie: map[common.Hash]*trie.Trie{}, + }, + root: common.Hash{1, 2, 3}, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + trieFound := testCase.tries.get(testCase.root) + + assert.Equal(t, testCase.trie, trieFound) + }) + } +} + +func Test_tries_len(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + tries *tries + length int + }{ + "empty map": { + tries: &tries{ + rootToTrie: map[common.Hash]*trie.Trie{}, + }, + }, + "non empty map": { + tries: &tries{ + rootToTrie: map[common.Hash]*trie.Trie{ + {1, 2, 3}: {}, + }, + }, + length: 1, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + length := testCase.tries.len() + + assert.Equal(t, testCase.length, length) + }) + } +}