From a51e5ab6e31a0aea3ed8ea96aab92a5d40a3b1a7 Mon Sep 17 00:00:00 2001 From: Timothy Wu Date: Tue, 28 Sep 2021 17:07:04 -0400 Subject: [PATCH] chore: Integrate scale package into lib/trie (#1804) * remove unused Encode/Decode * update scale integration * newHasher --- lib/trie/encode.go | 136 ---------------------------------------- lib/trie/encode_test.go | 59 ----------------- lib/trie/hash.go | 37 +++++------ lib/trie/hash_test.go | 8 +-- lib/trie/node.go | 4 +- lib/trie/node_test.go | 14 ++--- lib/trie/print.go | 4 +- lib/trie/trie.go | 2 +- lib/trie/trie_test.go | 8 --- 9 files changed, 32 insertions(+), 240 deletions(-) delete mode 100644 lib/trie/encode.go delete mode 100644 lib/trie/encode_test.go diff --git a/lib/trie/encode.go b/lib/trie/encode.go deleted file mode 100644 index 09f1eec99d..0000000000 --- a/lib/trie/encode.go +++ /dev/null @@ -1,136 +0,0 @@ -// Copyright 2019 ChainSafe Systems (ON) Corp. -// This file is part of gossamer. -// -// The gossamer library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The gossamer library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the gossamer library. If not, see . - -package trie - -import ( - "bytes" - "fmt" - "io" - - "github.com/ChainSafe/gossamer/lib/scale" -) - -// Encode traverses the trie recursively, encodes each node, SCALE encodes the encoded node, and appends them all together -func (t *Trie) Encode() ([]byte, error) { - return encodeRecursive(t.root, []byte{}) -} - -func encodeRecursive(n node, enc []byte) ([]byte, error) { - if n == nil { - return []byte{}, nil - } - - hasher := NewHasher(false) - defer hasher.returnToPool() - nenc, err := hasher.encode(n) - if err != nil { - return enc, err - } - - scnenc, err := scale.Encode(nenc) - if err != nil { - return nil, err - } - - enc = append(enc, scnenc...) - - if n, ok := n.(*branch); ok { - for _, child := range n.children { - if child != nil { - enc, err = encodeRecursive(child, enc) - if err != nil { - return enc, err - } - } - } - } - - return enc, nil -} - -// Decode decodes a trie from the DB and sets the receiver to it -// The encoded trie must have been encoded with t.Encode -func (t *Trie) Decode(enc []byte) error { - if bytes.Equal(enc, []byte{}) { - return nil - } - - r := &bytes.Buffer{} - _, err := r.Write(enc) - if err != nil { - return err - } - - sd := &scale.Decoder{Reader: r} - scroot, err := sd.Decode([]byte{}) - if err != nil { - return err - } - - n := &bytes.Buffer{} - _, err = n.Write(scroot.([]byte)) - if err != nil { - return err - } - - t.root, err = decode(n) - if err != nil { - return err - } - - return decodeRecursive(r, t.root) -} - -func decodeRecursive(r io.Reader, prev node) error { - sd := &scale.Decoder{Reader: r} - - if b, ok := prev.(*branch); ok { - for i, child := range b.children { - if child != nil { - // there's supposed to be a child here, decode the next node and place it - // when we decode a branch node, we only know if a child is supposed to exist at a certain index (due to the - // bitmap). we also have the hashes of the children, but we can't reconstruct the children from that. so - // instead, we put an empty leaf node where the child should be, so when we reconstruct it in this function, - // we can see that it's non-nil and we should decode the next node from the reader and place it here - scnode, err := sd.Decode([]byte{}) - if err != nil { - return err - } - - n := &bytes.Buffer{} - _, err = n.Write(scnode.([]byte)) - if err != nil { - return err - } - - b.children[i], err = decode(n) - if err != nil { - return fmt.Errorf("could not decode child at %d: %s", i, err) - } - - b.children[i].setDirty(true) - - err = decodeRecursive(r, b.children[i]) - if err != nil { - return err - } - } - } - } - - return nil -} diff --git a/lib/trie/encode_test.go b/lib/trie/encode_test.go deleted file mode 100644 index c4755c5716..0000000000 --- a/lib/trie/encode_test.go +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright 2019 ChainSafe Systems (ON) Corp. -// This file is part of gossamer. -// -// The gossamer library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The gossamer library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the gossamer library. If not, see . - -package trie - -import ( - "bytes" - "testing" -) - -func TestEncodeAndDecode(t *testing.T) { - trie := &Trie{} - - tests := []Test{ - {key: []byte{0x01, 0x35}, value: []byte("pen")}, - {key: []byte{0x01, 0x35, 0x79}, value: []byte("penguin")}, - {key: []byte{0x01, 0x35, 0x7}, value: []byte("g")}, - {key: []byte{0xf2}, value: []byte("feather")}, - {key: []byte{0xf2, 0x3}, value: []byte("f")}, - {key: []byte{0x09, 0xd3}, value: []byte("noot")}, - {key: []byte{0x07}, value: []byte("ramen")}, - {key: []byte{0}, value: nil}, - } - - for _, test := range tests { - trie.Put(test.key, test.value) - } - - enc, err := trie.Encode() - if err != nil { - t.Fatal(err) - } - - testTrie := &Trie{} - err = testTrie.Decode(enc) - if err != nil { - testTrie.Print() - t.Fatal(err) - } - - expected := trie.MustHash() - res := testTrie.MustHash() - if !bytes.Equal(expected[:], res[:]) { - t.Errorf("Fail: got\n %s expected\n %s", testTrie.String(), trie.String()) - } -} diff --git a/lib/trie/hash.go b/lib/trie/hash.go index 8b9d8374b8..4b31bade6c 100644 --- a/lib/trie/hash.go +++ b/lib/trie/hash.go @@ -23,13 +23,13 @@ import ( "sync" "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/lib/scale" + "github.com/ChainSafe/gossamer/pkg/scale" "golang.org/x/crypto/blake2b" "golang.org/x/sync/errgroup" ) // Hasher is a wrapper around a hash function -type Hasher struct { +type hasher struct { hash hash.Hash tmp bytes.Buffer parallel bool // Whether to use parallel threads when hashing @@ -43,7 +43,7 @@ var hasherPool = sync.Pool{ // This allocation will be helpful for encoding keys. This is the min buffer size. buf.Grow(700) - return &Hasher{ + return &hasher{ tmp: buf, hash: h, } @@ -51,20 +51,20 @@ var hasherPool = sync.Pool{ } // NewHasher create new Hasher instance -func NewHasher(parallel bool) *Hasher { - h := hasherPool.Get().(*Hasher) +func newHasher(parallel bool) *hasher { + h := hasherPool.Get().(*hasher) h.parallel = parallel return h } -func (h *Hasher) returnToPool() { +func (h *hasher) returnToPool() { h.tmp.Reset() h.hash.Reset() hasherPool.Put(h) } // Hash encodes the node and then hashes it if its encoded length is > 32 bytes -func (h *Hasher) Hash(n node) (res []byte, err error) { +func (h *hasher) Hash(n node) (res []byte, err error) { encNode, err := h.encode(n) if err != nil { return nil, err @@ -88,7 +88,7 @@ func (h *Hasher) Hash(n node) (res []byte, err error) { // encode is the high-level function wrapping the encoding for different node types // encoding has the following format: // NodeHeader | Extra partial key length | Partial Key | Value -func (h *Hasher) encode(n node) ([]byte, error) { +func (h *hasher) encode(n node) ([]byte, error) { switch n := n.(type) { case *branch: return h.encodeBranch(n) @@ -102,7 +102,7 @@ func (h *Hasher) encode(n node) ([]byte, error) { } func encodeAndHash(n node) ([]byte, error) { - h := NewHasher(false) + h := newHasher(false) defer h.returnToPool() encChild, err := h.Hash(n) @@ -110,7 +110,7 @@ func encodeAndHash(n node) ([]byte, error) { return nil, err } - scEncChild, err := scale.Encode(encChild) + scEncChild, err := scale.Marshal(encChild) if err != nil { return nil, err } @@ -118,7 +118,7 @@ func encodeAndHash(n node) ([]byte, error) { } // encodeBranch encodes a branch with the encoding specified at the top of this package -func (h *Hasher) encodeBranch(b *branch) ([]byte, error) { +func (h *hasher) encodeBranch(b *branch) ([]byte, error) { if !b.dirty && b.encoding != nil { return b.encoding, nil } @@ -134,13 +134,11 @@ func (h *Hasher) encodeBranch(b *branch) ([]byte, error) { h.tmp.Write(common.Uint16ToBytes(b.childrenBitmap())) if b.value != nil { - buffer := bytes.Buffer{} - se := scale.Encoder{Writer: &buffer} - _, err = se.Encode(b.value) + bytes, err := scale.Marshal(b.value) if err != nil { return nil, err } - h.tmp.Write(buffer.Bytes()) + h.tmp.Write(bytes) } if h.parallel { @@ -188,7 +186,7 @@ func (h *Hasher) encodeBranch(b *branch) ([]byte, error) { } // encodeLeaf encodes a leaf with the encoding specified at the top of this package -func (h *Hasher) encodeLeaf(l *leaf) ([]byte, error) { +func (h *hasher) encodeLeaf(l *leaf) ([]byte, error) { if !l.dirty && l.encoding != nil { return l.encoding, nil } @@ -203,15 +201,12 @@ func (h *Hasher) encodeLeaf(l *leaf) ([]byte, error) { h.tmp.Write(nibblesToKeyLE(l.key)) - buffer := bytes.Buffer{} - se := scale.Encoder{Writer: &buffer} - - _, err = se.Encode(l.value) + bytes, err := scale.Marshal(l.value) if err != nil { return nil, err } - h.tmp.Write(buffer.Bytes()) + h.tmp.Write(bytes) l.encoding = h.tmp.Bytes() return h.tmp.Bytes(), nil } diff --git a/lib/trie/hash_test.go b/lib/trie/hash_test.go index 94c43e1adf..78ba43fa6a 100644 --- a/lib/trie/hash_test.go +++ b/lib/trie/hash_test.go @@ -41,7 +41,7 @@ func generateRand(size int) [][]byte { } func TestNewHasher(t *testing.T) { - hasher := NewHasher(false) + hasher := newHasher(false) defer hasher.returnToPool() _, err := hasher.hash.Write([]byte("noot")) @@ -58,7 +58,7 @@ func TestNewHasher(t *testing.T) { } func TestHashLeaf(t *testing.T) { - hasher := NewHasher(false) + hasher := newHasher(false) defer hasher.returnToPool() n := &leaf{key: generateRandBytes(380), value: generateRandBytes(64)} @@ -71,7 +71,7 @@ func TestHashLeaf(t *testing.T) { } func TestHashBranch(t *testing.T) { - hasher := NewHasher(false) + hasher := newHasher(false) defer hasher.returnToPool() n := &branch{key: generateRandBytes(380), value: generateRandBytes(380)} @@ -85,7 +85,7 @@ func TestHashBranch(t *testing.T) { } func TestHashShort(t *testing.T) { - hasher := NewHasher(false) + hasher := newHasher(false) defer hasher.returnToPool() n := &leaf{key: generateRandBytes(2), value: generateRandBytes(3)} diff --git a/lib/trie/node.go b/lib/trie/node.go index 3e1f091f02..f99901f0ee 100644 --- a/lib/trie/node.go +++ b/lib/trie/node.go @@ -229,7 +229,7 @@ func (b *branch) encodeAndHash() ([]byte, []byte, error) { return b.encoding, b.hash, nil } - hasher := NewHasher(false) + hasher := newHasher(false) enc, err := hasher.encodeBranch(b) if err != nil { return nil, nil, err @@ -255,7 +255,7 @@ func (l *leaf) encodeAndHash() ([]byte, []byte, error) { if !l.isDirty() && l.encoding != nil && l.hash != nil { return l.encoding, l.hash, nil } - hasher := NewHasher(false) + hasher := newHasher(false) enc, err := hasher.encodeLeaf(l) if err != nil { diff --git a/lib/trie/node_test.go b/lib/trie/node_test.go index b894bde450..aff5b9a54b 100644 --- a/lib/trie/node_test.go +++ b/lib/trie/node_test.go @@ -171,7 +171,7 @@ func TestBranchEncode(t *testing.T) { for _, child := range b.children { if child != nil { - hasher := NewHasher(false) + hasher := newHasher(false) defer hasher.returnToPool() encChild, er := hasher.Hash(child) if er != nil { @@ -181,7 +181,7 @@ func TestBranchEncode(t *testing.T) { } } - hasher := NewHasher(false) + hasher := newHasher(false) defer hasher.returnToPool() res, err := hasher.encodeBranch(b) if !bytes.Equal(res, expected) { @@ -216,7 +216,7 @@ func TestLeafEncode(t *testing.T) { expected = append(expected, buf.Bytes()...) - hasher := NewHasher(false) + hasher := newHasher(false) defer hasher.returnToPool() res, err := hasher.encodeLeaf(l) if !bytes.Equal(res, expected) { @@ -240,7 +240,7 @@ func TestEncodeRoot(t *testing.T) { t.Errorf("Fail to get key %x with value %x: got %x", test.key, test.value, val) } - hasher := NewHasher(false) + hasher := newHasher(false) defer hasher.returnToPool() _, err := hasher.encode(trie.root) if err != nil { @@ -267,7 +267,7 @@ func TestBranchDecode(t *testing.T) { {key: byteArray(573), children: [16]node{}, value: []byte{0x01}}, } - hasher := NewHasher(false) + hasher := newHasher(false) defer hasher.returnToPool() for _, test := range tests { enc, err := hasher.encodeBranch(test) @@ -298,7 +298,7 @@ func TestLeafDecode(t *testing.T) { {key: byteArray(573), value: []byte{0x01}, dirty: true}, } - hasher := NewHasher(false) + hasher := newHasher(false) defer hasher.returnToPool() for _, test := range tests { enc, err := hasher.encodeLeaf(test) @@ -337,7 +337,7 @@ func TestDecode(t *testing.T) { &leaf{key: byteArray(573), value: []byte{0x01}}, } - hasher := NewHasher(false) + hasher := newHasher(false) defer hasher.returnToPool() for _, test := range tests { enc, err := hasher.encode(test) diff --git a/lib/trie/print.go b/lib/trie/print.go index 73df501d43..dd55906078 100644 --- a/lib/trie/print.go +++ b/lib/trie/print.go @@ -38,7 +38,7 @@ func (t *Trie) String() string { func (t *Trie) string(tree gotree.Tree, curr node, idx int) { switch c := curr.(type) { case *branch: - hasher := NewHasher(false) + hasher := newHasher(false) defer hasher.returnToPool() c.encoding, _ = hasher.encode(c) var bstr string @@ -54,7 +54,7 @@ func (t *Trie) string(tree gotree.Tree, curr node, idx int) { } } case *leaf: - hasher := NewHasher(false) + hasher := newHasher(false) defer hasher.returnToPool() c.encoding, _ = hasher.encode(c) var bstr string diff --git a/lib/trie/trie.go b/lib/trie/trie.go index aeedbeb984..bd33b73ac1 100644 --- a/lib/trie/trie.go +++ b/lib/trie/trie.go @@ -120,7 +120,7 @@ func (t *Trie) RootNode() node { //nolint // EncodeRoot returns the encoded root of the trie func (t *Trie) EncodeRoot() ([]byte, error) { - h := NewHasher(t.parallel) + h := newHasher(t.parallel) defer h.returnToPool() return h.encode(t.RootNode()) } diff --git a/lib/trie/trie_test.go b/lib/trie/trie_test.go index ce6b739d97..612c55228b 100644 --- a/lib/trie/trie_test.go +++ b/lib/trie/trie_test.go @@ -531,14 +531,6 @@ func TestTrieDiff(t *testing.T) { dbTrie := NewEmptyTrie() err = dbTrie.Load(storageDB, common.BytesToHash(newTrie.root.getHash())) require.NoError(t, err) - - enc, err := dbTrie.Encode() - require.NoError(t, err) - - newEnc, err := newTrie.Encode() - require.NoError(t, err) - - require.Equal(t, enc, newEnc) } func TestDelete(t *testing.T) {