Skip to content

Commit

Permalink
crypto/merkle: reuse hasher, reduce allocations
Browse files Browse the repository at this point in the history
adaptation/backport of celestiaorg#351
  • Loading branch information
shyba committed Jan 31, 2023
1 parent e6f0c0f commit 71f9333
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 36 deletions.
40 changes: 27 additions & 13 deletions crypto/merkle/hash.go
Original file line number Diff line number Diff line change
@@ -1,26 +1,40 @@
package merkle

import (
"hash"

"github.com/tendermint/tendermint/crypto/tmhash"
)

// TODO: make these have a large predefined capacity
var (
leafPrefix = []byte{0}
innerPrefix = []byte{1}
// Prefix bytes that keep leaf digests distinct from inner-node digests: the
// hasher writes one of them ahead of the payload before summing.
const (
// leafPrefix is the single byte written before a leaf's data (0x00 || leaf).
leafPrefix = 0
// innerPrefix is the single byte written before the two child hashes
// (0x01 || left || right).
innerPrefix = 1
)

// returns tmhash(<empty>)
func emptyHash() []byte {
return tmhash.Sum([]byte{})
// merkleHasher holds one hash.Hash that its methods reset and reuse for every
// digest they compute, instead of building a new hash for each call.
type merkleHasher struct {
// state is the reusable hash; each method calls Reset on it before writing.
state hash.Hash
}

// newMerkleHasher returns a merkleHasher whose reusable state is a hash
// freshly obtained from tmhash.New.
func newMerkleHasher() merkleHasher {
	var mh merkleHasher
	mh.state = tmhash.New()
	return mh
}

// emptyHash returns the digest of the empty input: the shared state is reset
// and summed without anything being written to it.
func (mh *merkleHasher) emptyHash() []byte {
	h := mh.state
	h.Reset()
	// Sum(nil) appends the digest to a nil slice, i.e. returns a new slice.
	return h.Sum(nil)
}

// returns tmhash(0x00 || leaf)
func leafHash(leaf []byte) []byte {
return tmhash.Sum(append(leafPrefix, leaf...))
// leafHash returns the digest of 0x00 || leaf, computed on the hasher's
// reusable state.
func (mh *merkleHasher) leafHash(leaf []byte) []byte {
	h := mh.state
	// Start from a clean state so earlier digests cannot influence this one.
	h.Reset()
	// hash.Hash documents that Write never returns an error, so the results
	// are deliberately not checked.
	h.Write([]byte{leafPrefix})
	h.Write(leaf)
	return h.Sum(nil)
}

// returns tmhash(0x01 || left || right)
func innerHash(left []byte, right []byte) []byte {
return tmhash.Sum(append(innerPrefix, append(left, right...)...))
// innerHash returns the digest of 0x01 || left || right, computed on the
// hasher's reusable state.
func (mh *merkleHasher) innerHash(left []byte, right []byte) []byte {
	h := mh.state
	// Start from a clean state so earlier digests cannot influence this one.
	h.Reset()
	// hash.Hash documents that Write never returns an error, so the results
	// are deliberately not checked.
	h.Write([]byte{innerPrefix})
	h.Write(left)
	h.Write(right)
	return h.Sum(nil)
}
29 changes: 19 additions & 10 deletions crypto/merkle/proof.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ func ProofsFromByteSlices(items [][]byte) (rootHash []byte, proofs []*Proof) {
// Verify that the Proof proves the root hash.
// Check sp.Index/sp.Total manually if needed
func (sp *Proof) Verify(rootHash []byte, leaf []byte) error {
leafHash := leafHash(leaf)
hasher := newMerkleHasher()
leafHash := hasher.leafHash(leaf)
if sp.Total < 0 {
return errors.New("proof total must be positive")
}
Expand Down Expand Up @@ -149,6 +150,10 @@ func ProofFromProto(pb *tmcrypto.Proof) (*Proof, error) {
// If the length of the innerHashes slice isn't exactly correct, the result is nil.
// Recursive impl.
func computeHashFromAunts(index, total int64, leafHash []byte, innerHashes [][]byte) []byte {
return computeHashFromAuntsWithHasher(newMerkleHasher(), index, total, leafHash, innerHashes)
}

func computeHashFromAuntsWithHasher(hasher merkleHasher, index, total int64, leafHash []byte, innerHashes [][]byte) []byte {
if index >= total || index < 0 || total <= 0 {
return nil
}
Expand All @@ -166,17 +171,17 @@ func computeHashFromAunts(index, total int64, leafHash []byte, innerHashes [][]b
}
numLeft := getSplitPoint(total)
if index < numLeft {
leftHash := computeHashFromAunts(index, numLeft, leafHash, innerHashes[:len(innerHashes)-1])
leftHash := computeHashFromAuntsWithHasher(hasher, index, numLeft, leafHash, innerHashes[:len(innerHashes)-1])
if leftHash == nil {
return nil
}
return innerHash(leftHash, innerHashes[len(innerHashes)-1])
return hasher.innerHash(leftHash, innerHashes[len(innerHashes)-1])
}
rightHash := computeHashFromAunts(index-numLeft, total-numLeft, leafHash, innerHashes[:len(innerHashes)-1])
rightHash := computeHashFromAuntsWithHasher(hasher, index-numLeft, total-numLeft, leafHash, innerHashes[:len(innerHashes)-1])
if rightHash == nil {
return nil
}
return innerHash(innerHashes[len(innerHashes)-1], rightHash)
return hasher.innerHash(innerHashes[len(innerHashes)-1], rightHash)
}
}

Expand Down Expand Up @@ -214,18 +219,22 @@ func (spn *ProofNode) FlattenAunts() [][]byte {
// trails[0].Hash is the leaf hash for items[0].
// trails[i].Parent.Parent....Parent == root for all i.
func trailsFromByteSlices(items [][]byte) (trails []*ProofNode, root *ProofNode) {
	// Create one hasher here and hand it to the recursive builder, which
	// passes it on to each of its sub-calls.
	hasher := newMerkleHasher()
	return recursiveHasherTrails(hasher, items)
}

func recursiveHasherTrails(hasher merkleHasher, items [][]byte) (trails []*ProofNode, root *ProofNode) {
// Recursive impl.
switch len(items) {
case 0:
return []*ProofNode{}, &ProofNode{emptyHash(), nil, nil, nil}
return []*ProofNode{}, &ProofNode{hasher.emptyHash(), nil, nil, nil}
case 1:
trail := &ProofNode{leafHash(items[0]), nil, nil, nil}
trail := &ProofNode{hasher.leafHash(items[0]), nil, nil, nil}
return []*ProofNode{trail}, trail
default:
k := getSplitPoint(int64(len(items)))
lefts, leftRoot := trailsFromByteSlices(items[:k])
rights, rightRoot := trailsFromByteSlices(items[k:])
rootHash := innerHash(leftRoot.Hash, rightRoot.Hash)
lefts, leftRoot := recursiveHasherTrails(hasher, items[:k])
rights, rightRoot := recursiveHasherTrails(hasher, items[k:])
rootHash := hasher.innerHash(leftRoot.Hash, rightRoot.Hash)
root := &ProofNode{rootHash, nil, nil, nil}
leftRoot.Parent = root
leftRoot.Right = rightRoot
Expand Down
7 changes: 3 additions & 4 deletions crypto/merkle/proof_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,14 @@ func (op ValueOp) Run(args [][]byte) ([][]byte, error) {
return nil, fmt.Errorf("expected 1 arg, got %v", len(args))
}
value := args[0]
hasher := tmhash.New()
hasher.Write(value)
vhash := hasher.Sum(nil)
vhash := tmhash.Sum(value)

bz := new(bytes.Buffer)
// Wrap <op.Key, vhash> to hash the KVPair.
encodeByteSlice(bz, op.key) //nolint: errcheck // does not error
encodeByteSlice(bz, vhash) //nolint: errcheck // does not error
kvhash := leafHash(bz.Bytes())
hasher := newMerkleHasher()
kvhash := hasher.leafHash(bz.Bytes())

if !bytes.Equal(kvhash, op.Proof.LeafHash) {
return nil, fmt.Errorf("leaf hash mismatch: want %X got %X", op.Proof.LeafHash, kvhash)
Expand Down
3 changes: 2 additions & 1 deletion crypto/merkle/rfc6962_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ func TestRFC6962Hasher(t *testing.T) {
emptyLeafHash := leafHashTrail.Hash
_, emptyHashTrail := trailsFromByteSlices([][]byte{})
emptyTreeHash := emptyHashTrail.Hash
hasher := newMerkleHasher()
for _, tc := range []struct {
desc string
got []byte
Expand Down Expand Up @@ -60,7 +61,7 @@ func TestRFC6962Hasher(t *testing.T) {
{
desc: "RFC6962 Node",
want: "aa217fe888e47007fa15edab33c2b492a722cb106c64667fc2b044444de66bbb"[:tmhash.Size*2],
got: innerHash([]byte("N123"), []byte("N456")),
got: hasher.innerHash([]byte("N123"), []byte("N456")),
},
} {
tc := tc
Expand Down
21 changes: 13 additions & 8 deletions crypto/merkle/tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,20 @@ import (
// HashFromByteSlices computes a Merkle tree where the leaves are the byte slice,
// in the provided order. It follows RFC-6962.
func HashFromByteSlices(items [][]byte) []byte {
return hashRecursive(newMerkleHasher(), items)
}

// hashRecursive computes the Merkle hash of items using the supplied hasher:
// no items yield the empty hash, one item yields its leaf hash, and longer
// lists are split at getSplitPoint, hashed per side, and combined as an inner
// node.
//
// hasher is received by value; the copy carries the same hash.Hash interface
// value, so the recursion presumably keeps reusing one underlying state —
// confirm against what tmhash.New returns.
//
// NOTE(review): this listing appears to contain both the pre-change and the
// post-change lines of a diff whose +/- markers were lost. The statements that
// go through `hasher` / `hashRecursive` look like the post-change side; the
// bare emptyHash/leafHash/innerHash/HashFromByteSlices calls look like the
// removed side. Confirm against the repository before relying on this text.
func hashRecursive(hasher merkleHasher, items [][]byte) []byte {
switch len(items) {
case 0:
return emptyHash()
return hasher.emptyHash()
case 1:
return leafHash(items[0])
return hasher.leafHash(items[0])
default:
// k is the index at which the list is divided into a left and a right part.
k := getSplitPoint(int64(len(items)))
left := HashFromByteSlices(items[:k])
right := HashFromByteSlices(items[k:])
return innerHash(left, right)
left := hashRecursive(hasher, items[:k])
right := hashRecursive(hasher, items[k:])
return hasher.innerHash(left, right)
}
}

Expand Down Expand Up @@ -61,24 +65,25 @@ func HashFromByteSlices(items [][]byte) []byte {
// implementation for so little benefit.
func HashFromByteSlicesIterative(input [][]byte) []byte {
items := make([][]byte, len(input))
hasher := newMerkleHasher()

for i, leaf := range input {
items[i] = leafHash(leaf)
items[i] = hasher.leafHash(leaf)
}

size := len(items)
for {
switch size {
case 0:
return emptyHash()
return hasher.emptyHash()
case 1:
return items[0]
default:
rp := 0 // read position
wp := 0 // write position
for rp < size {
if rp+1 < size {
items[wp] = innerHash(items[rp], items[rp+1])
items[wp] = hasher.innerHash(items[rp], items[rp+1])
rp += 2
} else {
items[wp] = items[rp]
Expand Down

0 comments on commit 71f9333

Please sign in to comment.