Skip to content

Commit

Permalink
chore: Integrate scale package into lib/trie (#1804)
Browse files Browse the repository at this point in the history
* remove unused Encode/Decode

* update scale integration

* newHasher
  • Loading branch information
timwu20 authored Sep 28, 2021
1 parent 9b2f41e commit a51e5ab
Show file tree
Hide file tree
Showing 9 changed files with 32 additions and 240 deletions.
136 changes: 0 additions & 136 deletions lib/trie/encode.go

This file was deleted.

59 changes: 0 additions & 59 deletions lib/trie/encode_test.go

This file was deleted.

37 changes: 16 additions & 21 deletions lib/trie/hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ import (
"sync"

"github.com/ChainSafe/gossamer/lib/common"
"github.com/ChainSafe/gossamer/lib/scale"
"github.com/ChainSafe/gossamer/pkg/scale"
"golang.org/x/crypto/blake2b"
"golang.org/x/sync/errgroup"
)

// Hasher is a wrapper around a hash function
type Hasher struct {
type hasher struct {
hash hash.Hash
tmp bytes.Buffer
parallel bool // Whether to use parallel threads when hashing
Expand All @@ -43,28 +43,28 @@ var hasherPool = sync.Pool{
// This allocation will be helpful for encoding keys. This is the min buffer size.
buf.Grow(700)

return &Hasher{
return &hasher{
tmp: buf,
hash: h,
}
},
}

// NewHasher create new Hasher instance
func NewHasher(parallel bool) *Hasher {
h := hasherPool.Get().(*Hasher)
func newHasher(parallel bool) *hasher {
h := hasherPool.Get().(*hasher)
h.parallel = parallel
return h
}

func (h *Hasher) returnToPool() {
func (h *hasher) returnToPool() {
h.tmp.Reset()
h.hash.Reset()
hasherPool.Put(h)
}

// Hash encodes the node and then hashes it if its encoded length is > 32 bytes
func (h *Hasher) Hash(n node) (res []byte, err error) {
func (h *hasher) Hash(n node) (res []byte, err error) {
encNode, err := h.encode(n)
if err != nil {
return nil, err
Expand All @@ -88,7 +88,7 @@ func (h *Hasher) Hash(n node) (res []byte, err error) {
// encode is the high-level function wrapping the encoding for different node types
// encoding has the following format:
// NodeHeader | Extra partial key length | Partial Key | Value
func (h *Hasher) encode(n node) ([]byte, error) {
func (h *hasher) encode(n node) ([]byte, error) {
switch n := n.(type) {
case *branch:
return h.encodeBranch(n)
Expand All @@ -102,23 +102,23 @@ func (h *Hasher) encode(n node) ([]byte, error) {
}

func encodeAndHash(n node) ([]byte, error) {
h := NewHasher(false)
h := newHasher(false)
defer h.returnToPool()

encChild, err := h.Hash(n)
if err != nil {
return nil, err
}

scEncChild, err := scale.Encode(encChild)
scEncChild, err := scale.Marshal(encChild)
if err != nil {
return nil, err
}
return scEncChild, nil
}

// encodeBranch encodes a branch with the encoding specified at the top of this package
func (h *Hasher) encodeBranch(b *branch) ([]byte, error) {
func (h *hasher) encodeBranch(b *branch) ([]byte, error) {
if !b.dirty && b.encoding != nil {
return b.encoding, nil
}
Expand All @@ -134,13 +134,11 @@ func (h *Hasher) encodeBranch(b *branch) ([]byte, error) {
h.tmp.Write(common.Uint16ToBytes(b.childrenBitmap()))

if b.value != nil {
buffer := bytes.Buffer{}
se := scale.Encoder{Writer: &buffer}
_, err = se.Encode(b.value)
bytes, err := scale.Marshal(b.value)
if err != nil {
return nil, err
}
h.tmp.Write(buffer.Bytes())
h.tmp.Write(bytes)
}

if h.parallel {
Expand Down Expand Up @@ -188,7 +186,7 @@ func (h *Hasher) encodeBranch(b *branch) ([]byte, error) {
}

// encodeLeaf encodes a leaf with the encoding specified at the top of this package
func (h *Hasher) encodeLeaf(l *leaf) ([]byte, error) {
func (h *hasher) encodeLeaf(l *leaf) ([]byte, error) {
if !l.dirty && l.encoding != nil {
return l.encoding, nil
}
Expand All @@ -203,15 +201,12 @@ func (h *Hasher) encodeLeaf(l *leaf) ([]byte, error) {

h.tmp.Write(nibblesToKeyLE(l.key))

buffer := bytes.Buffer{}
se := scale.Encoder{Writer: &buffer}

_, err = se.Encode(l.value)
bytes, err := scale.Marshal(l.value)
if err != nil {
return nil, err
}

h.tmp.Write(buffer.Bytes())
h.tmp.Write(bytes)
l.encoding = h.tmp.Bytes()
return h.tmp.Bytes(), nil
}
8 changes: 4 additions & 4 deletions lib/trie/hash_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func generateRand(size int) [][]byte {
}

func TestNewHasher(t *testing.T) {
hasher := NewHasher(false)
hasher := newHasher(false)
defer hasher.returnToPool()

_, err := hasher.hash.Write([]byte("noot"))
Expand All @@ -58,7 +58,7 @@ func TestNewHasher(t *testing.T) {
}

func TestHashLeaf(t *testing.T) {
hasher := NewHasher(false)
hasher := newHasher(false)
defer hasher.returnToPool()

n := &leaf{key: generateRandBytes(380), value: generateRandBytes(64)}
Expand All @@ -71,7 +71,7 @@ func TestHashLeaf(t *testing.T) {
}

func TestHashBranch(t *testing.T) {
hasher := NewHasher(false)
hasher := newHasher(false)
defer hasher.returnToPool()

n := &branch{key: generateRandBytes(380), value: generateRandBytes(380)}
Expand All @@ -85,7 +85,7 @@ func TestHashBranch(t *testing.T) {
}

func TestHashShort(t *testing.T) {
hasher := NewHasher(false)
hasher := newHasher(false)
defer hasher.returnToPool()

n := &leaf{key: generateRandBytes(2), value: generateRandBytes(3)}
Expand Down
4 changes: 2 additions & 2 deletions lib/trie/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ func (b *branch) encodeAndHash() ([]byte, []byte, error) {
return b.encoding, b.hash, nil
}

hasher := NewHasher(false)
hasher := newHasher(false)
enc, err := hasher.encodeBranch(b)
if err != nil {
return nil, nil, err
Expand All @@ -255,7 +255,7 @@ func (l *leaf) encodeAndHash() ([]byte, []byte, error) {
if !l.isDirty() && l.encoding != nil && l.hash != nil {
return l.encoding, l.hash, nil
}
hasher := NewHasher(false)
hasher := newHasher(false)
enc, err := hasher.encodeLeaf(l)

if err != nil {
Expand Down
Loading

0 comments on commit a51e5ab

Please sign in to comment.