From 9835b2bec15e709f5503a3af3838bff8f22e8e1b Mon Sep 17 00:00:00 2001 From: Marko Date: Tue, 1 Jun 2021 06:24:53 +0000 Subject: [PATCH] crypto/merkle: optimize merkle tree hashing (#6513) Upstream https://github.com/lazyledger/lazyledger-core/pull/351 to optimize merkle tree hashing ``` benchmark old ns/op new ns/op delta BenchmarkHashAlternatives/recursive-8 22914 21949 -4.21% BenchmarkHashAlternatives/iterative-8 21634 21939 +1.41% benchmark old allocs new allocs delta BenchmarkHashAlternatives/recursive-8 398 200 -49.75% BenchmarkHashAlternatives/iterative-8 399 301 -24.56% benchmark old bytes new bytes delta BenchmarkHashAlternatives/recursive-8 19088 6496 -65.97% BenchmarkHashAlternatives/iterative-8 21776 13984 -35.78% ``` cc @odeke-em @cuonglm --- crypto/merkle/hash.go | 18 ++++++++++++++++++ crypto/merkle/proof.go | 2 +- crypto/merkle/proof_key_path_test.go | 1 + crypto/merkle/proof_test.go | 2 +- crypto/merkle/tree.go | 18 ++++++++++++------ 5 files changed, 33 insertions(+), 8 deletions(-) diff --git a/crypto/merkle/hash.go b/crypto/merkle/hash.go index d45130fe58c..7ac391b64f8 100644 --- a/crypto/merkle/hash.go +++ b/crypto/merkle/hash.go @@ -1,6 +1,8 @@ package merkle import ( + "hash" + "github.com/tendermint/tendermint/crypto/tmhash" ) @@ -20,7 +22,23 @@ func leafHash(leaf []byte) []byte { return tmhash.Sum(append(leafPrefix, leaf...)) } +// returns tmhash(0x00 || leaf) +func leafHashOpt(s hash.Hash, leaf []byte) []byte { + s.Reset() + s.Write(leafPrefix) + s.Write(leaf) + return s.Sum(nil) +} + // returns tmhash(0x01 || left || right) func innerHash(left []byte, right []byte) []byte { return tmhash.Sum(append(innerPrefix, append(left, right...)...)) } + +func innerHashOpt(s hash.Hash, left []byte, right []byte) []byte { + s.Reset() + s.Write(innerPrefix) + s.Write(left) + s.Write(right) + return s.Sum(nil) +} diff --git a/crypto/merkle/proof.go b/crypto/merkle/proof.go index ab43f30e76d..2994e804871 100644 --- a/crypto/merkle/proof.go +++ b/crypto/merkle/proof.go @@ -50,13 +50,13 @@ func ProofsFromByteSlices(items [][]byte) (rootHash []byte, proofs []*Proof) { // Verify that the Proof proves the root hash. // Check sp.Index/sp.Total manually if needed func (sp *Proof) Verify(rootHash []byte, leaf []byte) error { - leafHash := leafHash(leaf) if sp.Total < 0 { return errors.New("proof total must be positive") } if sp.Index < 0 { return errors.New("proof index cannot be negative") } + leafHash := leafHash(leaf) if !bytes.Equal(sp.LeafHash, leafHash) { return fmt.Errorf("invalid leaf hash: wanted %X got %X", leafHash, sp.LeafHash) } diff --git a/crypto/merkle/proof_key_path_test.go b/crypto/merkle/proof_key_path_test.go index 22e3e21ca3d..0cc947643f5 100644 --- a/crypto/merkle/proof_key_path_test.go +++ b/crypto/merkle/proof_key_path_test.go @@ -35,6 +35,7 @@ func TestKeyPath(t *testing.T) { res, err := KeyPathToKeys(path.String()) require.Nil(t, err) + require.Equal(t, len(keys), len(res)) for i, key := range keys { require.Equal(t, key, res[i]) diff --git a/crypto/merkle/proof_test.go b/crypto/merkle/proof_test.go index 22ab900f06d..f0d2f868969 100644 --- a/crypto/merkle/proof_test.go +++ b/crypto/merkle/proof_test.go @@ -171,12 +171,12 @@ func TestProofValidateBasic(t *testing.T) { } } func TestVoteProtobuf(t *testing.T) { - _, proofs := ProofsFromByteSlices([][]byte{ []byte("apple"), []byte("watermelon"), []byte("kiwi"), }) + testCases := []struct { testName string v1 *Proof diff --git a/crypto/merkle/tree.go b/crypto/merkle/tree.go index 089c2f82ee0..896b67c5952 100644 --- a/crypto/merkle/tree.go +++ b/crypto/merkle/tree.go @@ -1,22 +1,28 @@ package merkle import ( + "crypto/sha256" + "hash" "math/bits" ) // HashFromByteSlices computes a Merkle tree where the leaves are the byte slice, // in the provided order. It follows RFC-6962. func HashFromByteSlices(items [][]byte) []byte { + return hashFromByteSlices(sha256.New(), items) +} + +func hashFromByteSlices(sha hash.Hash, items [][]byte) []byte { switch len(items) { case 0: return emptyHash() case 1: - return leafHash(items[0]) + return leafHashOpt(sha, items[0]) default: k := getSplitPoint(int64(len(items))) - left := HashFromByteSlices(items[:k]) - right := HashFromByteSlices(items[k:]) - return innerHash(left, right) + left := hashFromByteSlices(sha, items[:k]) + right := hashFromByteSlices(sha, items[k:]) + return innerHashOpt(sha, left, right) } } @@ -61,7 +67,7 @@ func HashFromByteSlices(items [][]byte) []byte { // implementation for so little benefit. func HashFromByteSlicesIterative(input [][]byte) []byte { items := make([][]byte, len(input)) - + sha := sha256.New() for i, leaf := range input { items[i] = leafHash(leaf) } @@ -78,7 +84,7 @@ func HashFromByteSlicesIterative(input [][]byte) []byte { wp := 0 // write position for rp < size { if rp+1 < size { - items[wp] = innerHash(items[rp], items[rp+1]) + items[wp] = innerHashOpt(sha, items[rp], items[rp+1]) rp += 2 } else { items[wp] = items[rp]