From 71f933375ddaa2ec062f849cbb89b71972462e72 Mon Sep 17 00:00:00 2001 From: Victor Shyba Date: Tue, 31 Jan 2023 01:25:28 -0300 Subject: [PATCH] crypto/merkle: reuse hasher, reduce allocations adaptation/backport of #351 --- crypto/merkle/hash.go | 40 +++++++++++++++++++++++------------ crypto/merkle/proof.go | 29 ++++++++++++++++--------- crypto/merkle/proof_value.go | 7 +++--- crypto/merkle/rfc6962_test.go | 3 ++- crypto/merkle/tree.go | 21 +++++++++++------- 5 files changed, 64 insertions(+), 36 deletions(-) diff --git a/crypto/merkle/hash.go b/crypto/merkle/hash.go index d45130fe58..5381a8fdaf 100644 --- a/crypto/merkle/hash.go +++ b/crypto/merkle/hash.go @@ -1,26 +1,40 @@ package merkle import ( + "hash" + "github.com/tendermint/tendermint/crypto/tmhash" ) -// TODO: make these have a large predefined capacity -var ( - leafPrefix = []byte{0} - innerPrefix = []byte{1} +const ( + leafPrefix = 0 + innerPrefix = 1 ) -// returns tmhash() -func emptyHash() []byte { - return tmhash.Sum([]byte{}) +type merkleHasher struct { + state hash.Hash +} + +func newMerkleHasher() merkleHasher { + return merkleHasher{state: tmhash.New()} +} + +func (mh *merkleHasher) emptyHash() []byte { + mh.state.Reset() + return mh.state.Sum(nil) } -// returns tmhash(0x00 || leaf) -func leafHash(leaf []byte) []byte { - return tmhash.Sum(append(leafPrefix, leaf...)) +func (mh *merkleHasher) leafHash(leaf []byte) []byte { + mh.state.Reset() + mh.state.Write([]byte{leafPrefix}) + mh.state.Write(leaf) + return mh.state.Sum(nil) } -// returns tmhash(0x01 || left || right) -func innerHash(left []byte, right []byte) []byte { - return tmhash.Sum(append(innerPrefix, append(left, right...)...)) +func (mh *merkleHasher) innerHash(left []byte, right []byte) []byte { + mh.state.Reset() + mh.state.Write([]byte{innerPrefix}) + mh.state.Write(left) + mh.state.Write(right) + return mh.state.Sum(nil) } diff --git a/crypto/merkle/proof.go b/crypto/merkle/proof.go index ab43f30e76..1df21ffbd9 100644 --- a/crypto/merkle/proof.go +++ b/crypto/merkle/proof.go @@ -50,7 +50,8 @@ func ProofsFromByteSlices(items [][]byte) (rootHash []byte, proofs []*Proof) { // Verify that the Proof proves the root hash. // Check sp.Index/sp.Total manually if needed func (sp *Proof) Verify(rootHash []byte, leaf []byte) error { - leafHash := leafHash(leaf) + hasher := newMerkleHasher() + leafHash := hasher.leafHash(leaf) if sp.Total < 0 { return errors.New("proof total must be positive") } @@ -149,6 +150,10 @@ func ProofFromProto(pb *tmcrypto.Proof) (*Proof, error) { // If the length of the innerHashes slice isn't exactly correct, the result is nil. // Recursive impl. func computeHashFromAunts(index, total int64, leafHash []byte, innerHashes [][]byte) []byte { + return computeHashFromAuntsWithHasher(newMerkleHasher(), index, total, leafHash, innerHashes) +} + +func computeHashFromAuntsWithHasher(hasher merkleHasher, index, total int64, leafHash []byte, innerHashes [][]byte) []byte { if index >= total || index < 0 || total <= 0 { return nil } @@ -166,17 +171,17 @@ func computeHashFromAunts(index, total int64, leafHash []byte, innerHashes [][]b } numLeft := getSplitPoint(total) if index < numLeft { - leftHash := computeHashFromAunts(index, numLeft, leafHash, innerHashes[:len(innerHashes)-1]) + leftHash := computeHashFromAuntsWithHasher(hasher, index, numLeft, leafHash, innerHashes[:len(innerHashes)-1]) if leftHash == nil { return nil } - return innerHash(leftHash, innerHashes[len(innerHashes)-1]) + return hasher.innerHash(leftHash, innerHashes[len(innerHashes)-1]) } - rightHash := computeHashFromAunts(index-numLeft, total-numLeft, leafHash, innerHashes[:len(innerHashes)-1]) + rightHash := computeHashFromAuntsWithHasher(hasher, index-numLeft, total-numLeft, leafHash, innerHashes[:len(innerHashes)-1]) if rightHash == nil { return nil } - return innerHash(innerHashes[len(innerHashes)-1], rightHash) + return hasher.innerHash(innerHashes[len(innerHashes)-1], rightHash) } } @@ -214,18 +219,22 @@ func (spn *ProofNode) FlattenAunts() [][]byte { // trails[0].Hash is the leaf hash for items[0]. // trails[i].Parent.Parent....Parent == root for all i. func trailsFromByteSlices(items [][]byte) (trails []*ProofNode, root *ProofNode) { + return recursiveHasherTrails(newMerkleHasher(), items) +} + +func recursiveHasherTrails(hasher merkleHasher, items [][]byte) (trails []*ProofNode, root *ProofNode) { // Recursive impl. switch len(items) { case 0: - return []*ProofNode{}, &ProofNode{emptyHash(), nil, nil, nil} + return []*ProofNode{}, &ProofNode{hasher.emptyHash(), nil, nil, nil} case 1: - trail := &ProofNode{leafHash(items[0]), nil, nil, nil} + trail := &ProofNode{hasher.leafHash(items[0]), nil, nil, nil} return []*ProofNode{trail}, trail default: k := getSplitPoint(int64(len(items))) - lefts, leftRoot := trailsFromByteSlices(items[:k]) - rights, rightRoot := trailsFromByteSlices(items[k:]) - rootHash := innerHash(leftRoot.Hash, rightRoot.Hash) + lefts, leftRoot := recursiveHasherTrails(hasher, items[:k]) + rights, rightRoot := recursiveHasherTrails(hasher, items[k:]) + rootHash := hasher.innerHash(leftRoot.Hash, rightRoot.Hash) root := &ProofNode{rootHash, nil, nil, nil} leftRoot.Parent = root leftRoot.Right = rightRoot diff --git a/crypto/merkle/proof_value.go b/crypto/merkle/proof_value.go index 842dc82018..b2b819f1d2 100644 --- a/crypto/merkle/proof_value.go +++ b/crypto/merkle/proof_value.go @@ -79,15 +79,14 @@ func (op ValueOp) Run(args [][]byte) ([][]byte, error) { return nil, fmt.Errorf("expected 1 arg, got %v", len(args)) } value := args[0] - hasher := tmhash.New() - hasher.Write(value) - vhash := hasher.Sum(nil) + vhash := tmhash.Sum(value) bz := new(bytes.Buffer) // Wrap to hash the KVPair. encodeByteSlice(bz, op.key) //nolint: errcheck // does not error encodeByteSlice(bz, vhash) //nolint: errcheck // does not error - kvhash := leafHash(bz.Bytes()) + hasher := newMerkleHasher() + kvhash := hasher.leafHash(bz.Bytes()) if !bytes.Equal(kvhash, op.Proof.LeafHash) { return nil, fmt.Errorf("leaf hash mismatch: want %X got %X", op.Proof.LeafHash, kvhash) diff --git a/crypto/merkle/rfc6962_test.go b/crypto/merkle/rfc6962_test.go index c762cda56a..3d951f257f 100644 --- a/crypto/merkle/rfc6962_test.go +++ b/crypto/merkle/rfc6962_test.go @@ -30,6 +30,7 @@ func TestRFC6962Hasher(t *testing.T) { emptyLeafHash := leafHashTrail.Hash _, emptyHashTrail := trailsFromByteSlices([][]byte{}) emptyTreeHash := emptyHashTrail.Hash + hasher := newMerkleHasher() for _, tc := range []struct { desc string got []byte @@ -60,7 +61,7 @@ func TestRFC6962Hasher(t *testing.T) { { desc: "RFC6962 Node", want: "aa217fe888e47007fa15edab33c2b492a722cb106c64667fc2b044444de66bbb"[:tmhash.Size*2], - got: innerHash([]byte("N123"), []byte("N456")), + got: hasher.innerHash([]byte("N123"), []byte("N456")), }, } { tc := tc diff --git a/crypto/merkle/tree.go b/crypto/merkle/tree.go index 089c2f82ee..43d63e409e 100644 --- a/crypto/merkle/tree.go +++ b/crypto/merkle/tree.go @@ -7,16 +7,20 @@ import ( // HashFromByteSlices computes a Merkle tree where the leaves are the byte slice, // in the provided order. It follows RFC-6962. func HashFromByteSlices(items [][]byte) []byte { + return hashRecursive(newMerkleHasher(), items) +} + +func hashRecursive(hasher merkleHasher, items [][]byte) []byte { switch len(items) { case 0: - return emptyHash() + return hasher.emptyHash() case 1: - return leafHash(items[0]) + return hasher.leafHash(items[0]) default: k := getSplitPoint(int64(len(items))) - left := HashFromByteSlices(items[:k]) - right := HashFromByteSlices(items[k:]) - return innerHash(left, right) + left := hashRecursive(hasher, items[:k]) + right := hashRecursive(hasher, items[k:]) + return hasher.innerHash(left, right) } } @@ -61,16 +65,17 @@ func HashFromByteSlices(items [][]byte) []byte { // implementation for so little benefit. func HashFromByteSlicesIterative(input [][]byte) []byte { items := make([][]byte, len(input)) + hasher := newMerkleHasher() for i, leaf := range input { - items[i] = leafHash(leaf) + items[i] = hasher.leafHash(leaf) } size := len(items) for { switch size { case 0: - return emptyHash() + return hasher.emptyHash() case 1: return items[0] default: @@ -78,7 +83,7 @@ func HashFromByteSlicesIterative(input [][]byte) []byte { wp := 0 // write position for rp < size { if rp+1 < size { - items[wp] = innerHash(items[rp], items[rp+1]) + items[wp] = hasher.innerHash(items[rp], items[rp+1]) rp += 2 } else { items[wp] = items[rp]