Skip to content

Commit

Permalink
Merge pull request #6 from iden3/fix_panic_and_refactor
Browse files Browse the repository at this point in the history
Fix panic and refactor
  • Loading branch information
olomix authored Apr 21, 2022
2 parents 31dadad + 9b973f4 commit 1a61c1c
Show file tree
Hide file tree
Showing 6 changed files with 219 additions and 136 deletions.
128 changes: 94 additions & 34 deletions db/test/test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ import (

var debug = false

func newTestingMerkle(t *testing.T, sto merkletree.Storage, numLevels int) *merkletree.MerkleTree {
func newTestingMerkle(t *testing.T, sto merkletree.Storage,
numLevels int) *merkletree.MerkleTree {
mt, err := merkletree.NewMerkleTree(context.Background(), sto, numLevels)
require.NoError(t, err)
return mt
Expand Down Expand Up @@ -156,7 +157,9 @@ func TestStorageWithPrefix(t *testing.T, sto merkletree.Storage) {

ctx := context.Background()

node := merkletree.NewNodeLeaf(&merkletree.Hash{1, 2, 3}, &merkletree.Hash{4, 5, 6})
node := merkletree.NewNodeLeaf(
&merkletree.Hash{1, 2, 3},
&merkletree.Hash{4, 5, 6})
k, err := node.Key()
require.NoError(t, err)
err = sto1.Put(ctx, k[:], node)
Expand Down Expand Up @@ -390,15 +393,21 @@ func TestNewTree(t *testing.T, sto merkletree.Storage) {
// test vectors generated using https://github.com/iden3/circomlib smt.js
err = mt.Add(ctx, big.NewInt(1), big.NewInt(2))
assert.Nil(t, err)
assert.Equal(t, "13578938674299138072471463694055224830892726234048532520316387704878000008795", mt.Root().BigInt().String()) //nolint:lll
assert.Equal(t,
"13578938674299138072471463694055224830892726234048532520316387704878000008795",
mt.Root().BigInt().String())

err = mt.Add(ctx, big.NewInt(33), big.NewInt(44))
assert.Nil(t, err)
assert.Equal(t, "5412393676474193513566895793055462193090331607895808993925969873307089394741", mt.Root().BigInt().String()) //nolint:lll
assert.Equal(t,
"5412393676474193513566895793055462193090331607895808993925969873307089394741",
mt.Root().BigInt().String())

err = mt.Add(ctx, big.NewInt(1234), big.NewInt(9876))
assert.Nil(t, err)
assert.Equal(t, "14204494359367183802864593755198662203838502594566452929175967972147978322084", mt.Root().BigInt().String()) //nolint:lll
assert.Equal(t,
"14204494359367183802864593755198662203838502594566452929175967972147978322084",
mt.Root().BigInt().String())

dbRoot, err := sto.GetRoot(ctx)
require.Nil(t, err)
Expand All @@ -408,11 +417,14 @@ func TestNewTree(t *testing.T, sto merkletree.Storage) {
assert.Nil(t, err)
assert.Equal(t, big.NewInt(44), v)

assert.True(t, merkletree.VerifyProof(mt.Root(), proof, big.NewInt(33), big.NewInt(44)))
assert.True(t, !merkletree.VerifyProof(mt.Root(), proof, big.NewInt(33), big.NewInt(45)))
assert.True(t, merkletree.VerifyProof(
mt.Root(), proof, big.NewInt(33), big.NewInt(44)))
assert.True(t, !merkletree.VerifyProof(
mt.Root(), proof, big.NewInt(33), big.NewInt(45)))
}

func TestAddDifferentOrder(t *testing.T, sto merkletree.Storage, sto2 merkletree.Storage) {
func TestAddDifferentOrder(t *testing.T, sto merkletree.Storage,
sto2 merkletree.Storage) {
ctx := context.Background()

mt1 := newTestingMerkle(t, sto, 140)
Expand All @@ -434,7 +446,9 @@ func TestAddDifferentOrder(t *testing.T, sto merkletree.Storage, sto2 merkletree
}

assert.Equal(t, mt1.Root().Hex(), mt2.Root().Hex())
assert.Equal(t, "3b89100bec24da9275c87bc188740389e1d5accfc7d88ba5688d7fa96a00d82f", mt1.Root().Hex()) //nolint:lll
assert.Equal(t,
"3b89100bec24da9275c87bc188740389e1d5accfc7d88ba5688d7fa96a00d82f",
mt1.Root().Hex())
}

func TestAddRepeatedIndex(t *testing.T, sto merkletree.Storage) {
Expand Down Expand Up @@ -629,8 +643,11 @@ func TestVerifyProofCases(t *testing.T, sto merkletree.Storage) {
t.Fatal(err)
}
assert.Equal(t, proof.Existence, true)
assert.True(t, merkletree.VerifyProof(mt.Root(), proof, big.NewInt(4), big.NewInt(0)))
assert.Equal(t, "0003000000000000000000000000000000000000000000000000000000000007529cbedbda2bdd25fd6455551e55245fa6dc11a9d0c27dc0cd38fca44c17e40344ad686a18ba78b502c0b6f285c5c8393bde2f7a3e2abe586515e4d84533e3037b062539bde2d80749746986cf8f0001fd2cdbf9a89fcbf981a769daef49df06", hex.EncodeToString(proof.Bytes())) //nolint:lll
assert.True(t,
merkletree.VerifyProof(mt.Root(), proof, big.NewInt(4), big.NewInt(0)))
assert.Equal(t,
"0003000000000000000000000000000000000000000000000000000000000007529cbedbda2bdd25fd6455551e55245fa6dc11a9d0c27dc0cd38fca44c17e40344ad686a18ba78b502c0b6f285c5c8393bde2f7a3e2abe586515e4d84533e3037b062539bde2d80749746986cf8f0001fd2cdbf9a89fcbf981a769daef49df06", //nolint:lll
hex.EncodeToString(proof.Bytes()))

for i := 8; i < 32; i++ {
proof, _, err = mt.GenerateProof(ctx, big.NewInt(int64(i)), nil)
Expand All @@ -646,8 +663,11 @@ func TestVerifyProofCases(t *testing.T, sto merkletree.Storage) {
}
assert.Equal(t, proof.Existence, false)
// assert.True(t, proof.nodeAux == nil)
assert.True(t, merkletree.VerifyProof(mt.Root(), proof, big.NewInt(12), big.NewInt(0)))
assert.Equal(t, "0303000000000000000000000000000000000000000000000000000000000007529cbedbda2bdd25fd6455551e55245fa6dc11a9d0c27dc0cd38fca44c17e40344ad686a18ba78b502c0b6f285c5c8393bde2f7a3e2abe586515e4d84533e3037b062539bde2d80749746986cf8f0001fd2cdbf9a89fcbf981a769daef49df0604000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", hex.EncodeToString(proof.Bytes())) //nolint:lll
assert.True(t,
merkletree.VerifyProof(mt.Root(), proof, big.NewInt(12), big.NewInt(0)))
assert.Equal(t,
"0303000000000000000000000000000000000000000000000000000000000007529cbedbda2bdd25fd6455551e55245fa6dc11a9d0c27dc0cd38fca44c17e40344ad686a18ba78b502c0b6f285c5c8393bde2f7a3e2abe586515e4d84533e3037b062539bde2d80749746986cf8f0001fd2cdbf9a89fcbf981a769daef49df0604000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", //nolint:lll
hex.EncodeToString(proof.Bytes()))

// Non-existence proof, diff. node aux
proof, _, err = mt.GenerateProof(ctx, big.NewInt(10), nil)
Expand All @@ -656,8 +676,11 @@ func TestVerifyProofCases(t *testing.T, sto merkletree.Storage) {
}
assert.Equal(t, proof.Existence, false)
assert.True(t, proof.NodeAux != nil)
assert.True(t, merkletree.VerifyProof(mt.Root(), proof, big.NewInt(10), big.NewInt(0)))
assert.Equal(t, "0303000000000000000000000000000000000000000000000000000000000007529cbedbda2bdd25fd6455551e55245fa6dc11a9d0c27dc0cd38fca44c17e4030acfcdd2617df9eb5aef744c5f2e03eb8c92c61f679007dc1f2707fd908ea41a9433745b469c101edca814c498e7f388100d497b24f1d2ac935bced3572f591d02000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", hex.EncodeToString(proof.Bytes())) //nolint:lll
assert.True(t,
merkletree.VerifyProof(mt.Root(), proof, big.NewInt(10), big.NewInt(0)))
assert.Equal(t,
"0303000000000000000000000000000000000000000000000000000000000007529cbedbda2bdd25fd6455551e55245fa6dc11a9d0c27dc0cd38fca44c17e4030acfcdd2617df9eb5aef744c5f2e03eb8c92c61f679007dc1f2707fd908ea41a9433745b469c101edca814c498e7f388100d497b24f1d2ac935bced3572f591d02000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", //nolint:lll
hex.EncodeToString(proof.Bytes()))
}

func TestVerifyProofFalse(t *testing.T, sto merkletree.Storage) {
Expand Down Expand Up @@ -689,9 +712,20 @@ func TestVerifyProofFalse(t *testing.T, sto merkletree.Storage) {
// Now we change the proof from existence to non-existence, and add e's
// data as auxiliary node.
proof.Existence = false
proof.NodeAux = &merkletree.NodeAux{Key: merkletree.NewHashFromBigInt(big.NewInt(int64(4))),
Value: merkletree.NewHashFromBigInt(big.NewInt(4))}
assert.True(t, !merkletree.VerifyProof(mt.Root(), proof, big.NewInt(int64(4)), big.NewInt(0)))
proof.NodeAux = &merkletree.NodeAux{
Key: hashFromInt(big.NewInt(int64(4))),
Value: hashFromInt(big.NewInt(4))}
assert.True(t,
!merkletree.VerifyProof(mt.Root(), proof, big.NewInt(int64(4)),
big.NewInt(0)))
}

func hashFromInt(in *big.Int) *merkletree.Hash {
h, err := merkletree.NewHashFromBigInt(in)
if err != nil {
panic(err)
}
return h
}

func TestGraphViz(t *testing.T, sto merkletree.Storage) {
Expand Down Expand Up @@ -744,22 +778,30 @@ func TestDelete(t *testing.T, sto merkletree.Storage) {
// test vectors generated using https://github.com/iden3/circomlib smt.js
err = mt.Add(ctx, big.NewInt(1), big.NewInt(2))
assert.Nil(t, err)
assert.Equal(t, "13578938674299138072471463694055224830892726234048532520316387704878000008795", mt.Root().BigInt().String()) //nolint:lll
assert.Equal(t,
"13578938674299138072471463694055224830892726234048532520316387704878000008795",
mt.Root().BigInt().String())

err = mt.Add(ctx, big.NewInt(33), big.NewInt(44))
assert.Nil(t, err)
assert.Equal(t, "5412393676474193513566895793055462193090331607895808993925969873307089394741", mt.Root().BigInt().String()) //nolint:lll
assert.Equal(t,
"5412393676474193513566895793055462193090331607895808993925969873307089394741",
mt.Root().BigInt().String())

err = mt.Add(ctx, big.NewInt(1234), big.NewInt(9876))
assert.Nil(t, err)
assert.Equal(t, "14204494359367183802864593755198662203838502594566452929175967972147978322084", mt.Root().BigInt().String()) //nolint:lll
assert.Equal(t,
"14204494359367183802864593755198662203838502594566452929175967972147978322084",
mt.Root().BigInt().String())

// mt.PrintGraphViz(nil)

err = mt.Delete(ctx, big.NewInt(33))
// mt.PrintGraphViz(nil)
assert.Nil(t, err)
assert.Equal(t, "15550352095346187559699212771793131433118240951738528922418613687814377955591", mt.Root().BigInt().String()) //nolint:lll
assert.Equal(t,
"15550352095346187559699212771793131433118240951738528922418613687814377955591",
mt.Root().BigInt().String())

err = mt.Delete(ctx, big.NewInt(1234))
assert.Nil(t, err)
Expand All @@ -772,7 +814,8 @@ func TestDelete(t *testing.T, sto merkletree.Storage) {
assert.Equal(t, mt.Root(), dbRoot)
}

func TestDelete2(t *testing.T, sto merkletree.Storage, sto2 merkletree.Storage) {
func TestDelete2(t *testing.T, sto merkletree.Storage,
sto2 merkletree.Storage) {
ctx := context.Background()
mt := newTestingMerkle(t, sto, 140)
for i := 0; i < 8; i++ {
Expand Down Expand Up @@ -805,7 +848,8 @@ func TestDelete2(t *testing.T, sto merkletree.Storage, sto2 merkletree.Storage)
assert.Equal(t, mt2.Root(), mt.Root())
}

func TestDelete3(t *testing.T, sto merkletree.Storage, sto2 merkletree.Storage) {
func TestDelete3(t *testing.T, sto merkletree.Storage,
sto2 merkletree.Storage) {
mt := newTestingMerkle(t, sto, 140)

ctx := context.Background()
Expand All @@ -815,18 +859,23 @@ func TestDelete3(t *testing.T, sto merkletree.Storage, sto2 merkletree.Storage)
err = mt.Add(ctx, big.NewInt(2), big.NewInt(2))
assert.Nil(t, err)

assert.Equal(t, "19060075022714027595905950662613111880864833370144986660188929919683258088314", mt.Root().BigInt().String()) //nolint:lll
assert.Equal(t,
"19060075022714027595905950662613111880864833370144986660188929919683258088314",
mt.Root().BigInt().String())
err = mt.Delete(ctx, big.NewInt(1))
assert.Nil(t, err)
assert.Equal(t, "849831128489032619062850458217693666094013083866167024127442191257793527951", mt.Root().BigInt().String()) //nolint:lll
assert.Equal(t,
"849831128489032619062850458217693666094013083866167024127442191257793527951",
mt.Root().BigInt().String())

mt2 := newTestingMerkle(t, sto2, 140)
err = mt2.Add(ctx, big.NewInt(2), big.NewInt(2))
assert.Nil(t, err)
assert.Equal(t, mt2.Root(), mt.Root())
}

func TestDelete4(t *testing.T, sto merkletree.Storage, sto2 merkletree.Storage) {
func TestDelete4(t *testing.T, sto merkletree.Storage,
sto2 merkletree.Storage) {
mt := newTestingMerkle(t, sto, 140)

ctx := context.Background()
Expand All @@ -839,10 +888,14 @@ func TestDelete4(t *testing.T, sto merkletree.Storage, sto2 merkletree.Storage)
err = mt.Add(ctx, big.NewInt(3), big.NewInt(3))
assert.Nil(t, err)

assert.Equal(t, "14109632483797541575275728657193822866549917334388996328141438956557066918117", mt.Root().BigInt().String()) //nolint:lll
assert.Equal(t,
"14109632483797541575275728657193822866549917334388996328141438956557066918117",
mt.Root().BigInt().String())
err = mt.Delete(ctx, big.NewInt(1))
assert.Nil(t, err)
assert.Equal(t, "159935162486187606489815340465698714590556679404589449576549073038844694972", mt.Root().BigInt().String()) //nolint:lll
assert.Equal(t,
"159935162486187606489815340465698714590556679404589449576549073038844694972",
mt.Root().BigInt().String())

mt2 := newTestingMerkle(t, sto2, 140)
err = mt2.Add(ctx, big.NewInt(2), big.NewInt(2))
Expand All @@ -852,7 +905,8 @@ func TestDelete4(t *testing.T, sto merkletree.Storage, sto2 merkletree.Storage)
assert.Equal(t, mt2.Root(), mt.Root())
}

func TestDelete5(t *testing.T, sto merkletree.Storage, sto2 merkletree.Storage) {
func TestDelete5(t *testing.T, sto merkletree.Storage,
sto2 merkletree.Storage) {
ctx := context.Background()
mt, err := merkletree.NewMerkleTree(ctx, sto, 10)
assert.Nil(t, err)
Expand All @@ -861,11 +915,15 @@ func TestDelete5(t *testing.T, sto merkletree.Storage, sto2 merkletree.Storage)
assert.Nil(t, err)
err = mt.Add(ctx, big.NewInt(33), big.NewInt(44))
assert.Nil(t, err)
assert.Equal(t, "5412393676474193513566895793055462193090331607895808993925969873307089394741", mt.Root().BigInt().String()) //nolint:lll
assert.Equal(t,
"5412393676474193513566895793055462193090331607895808993925969873307089394741",
mt.Root().BigInt().String())

err = mt.Delete(ctx, big.NewInt(1))
assert.Nil(t, err)
assert.Equal(t, "18869260084287237667925661423624848342947598951870765316380602291081195309822", mt.Root().BigInt().String()) //nolint:lll
assert.Equal(t,
"18869260084287237667925661423624848342947598951870765316380602291081195309822",
mt.Root().BigInt().String())

mt2 := newTestingMerkle(t, sto2, 140)
err = mt2.Add(ctx, big.NewInt(33), big.NewInt(44))
Expand Down Expand Up @@ -897,7 +955,8 @@ func TestDeleteNonExistingKeys(t *testing.T, sto merkletree.Storage) {
assert.Equal(t, merkletree.ErrKeyNotFound, err)
}

func TestDumpLeafsImportLeafs(t *testing.T, sto merkletree.Storage, sto2 merkletree.Storage) {
func TestDumpLeafsImportLeafs(t *testing.T, sto merkletree.Storage,
sto2 merkletree.Storage) {
ctx := context.Background()
mt, err := merkletree.NewMerkleTree(ctx, sto, 140)
require.Nil(t, err)
Expand Down Expand Up @@ -967,7 +1026,8 @@ func TestAddAndGetCircomProof(t *testing.T, sto merkletree.Storage) {
assert.Equal(t, "55", cpp.NewKey.String())
assert.Equal(t, "66", cpp.NewValue.String())
assert.Equal(t, true, cpp.IsOld0)
assert.Equal(t, "[0 21312042... 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
assert.Equal(t, "[0 21312042... 0 0 0 0 0 0 0 0 0]",
fmt.Sprintf("%v", cpp.Siblings))
assert.Equal(t, mt.MaxLevels()+1, len(cpp.Siblings))
}

Expand Down
44 changes: 18 additions & 26 deletions hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package merkletree
import (
"bytes"
"encoding/hex"
"errors"
"fmt"
"math/big"
"strings"
Expand All @@ -27,8 +28,11 @@ func (h Hash) MarshalText() ([]byte, error) {
// UnmarshalText implements the unmarshaler for the Hash type
func (h *Hash) UnmarshalText(b []byte) error {
ha, err := NewHashFromString(string(b))
if err != nil {
return err
}
copy(h[:], ha[:])
return err
return nil
}

// String returns decimal representation in string format of the Hash
Expand Down Expand Up @@ -58,15 +62,6 @@ func (h *Hash) BigInt() *big.Int {
return new(big.Int).SetBytes(SwapEndianness(h[:]))
}

// Bytes returns the []byte representation of the *Hash, which always is 32
// bytes length.
func (h *Hash) Bytes() []byte {
bi := new(big.Int).SetBytes(h[:]).Bytes()
b := [32]byte{}
copy(b[:], SwapEndianness(bi[:]))
return b[:]
}

func (h *Hash) Equals(h2 *Hash) bool {
return bytes.Equal(h[:], h2[:])
}
Expand All @@ -87,22 +82,14 @@ func NewBigIntFromHashBytes(b []byte) (*big.Int, error) {
}

// NewHashFromBigInt returns a *Hash representation of the given *big.Int
func NewHashFromBigInt(b *big.Int) *Hash {
func NewHashFromBigInt(b *big.Int) (*Hash, error) {
if !cryptoUtils.CheckBigIntInField(b) {
return nil, errors.New(
"NewHashFromBigInt: Value not inside the Finite Field")
}
r := &Hash{}
copy(r[:], SwapEndianness(b.Bytes()))
return r
}

// NewHashFromBytes returns a *Hash from a byte array, swapping the endianness
// in the process. This is the intended method to get a *Hash from a byte array
// that previously has ben generated by the Hash.Bytes() method.
func NewHashFromBytes(b []byte) (*Hash, error) {
if len(b) != ElemBytesLen {
return nil, fmt.Errorf("Expected 32 bytes, found %d bytes", len(b))
}
var h Hash
copy(h[:], SwapEndianness(b))
return &h, nil
return r, nil
}

// NewHashFromHex returns a *Hash representation of the given hex string
Expand All @@ -112,7 +99,12 @@ func NewHashFromHex(h string) (*Hash, error) {
if err != nil {
return nil, err
}
return NewHashFromBytes(SwapEndianness(b[:]))
var hash Hash
if len(b) != len(hash) {
return nil, errors.New("invalid hash length")
}
copy(hash[:], b)
return &hash, nil
}

// NewHashFromString returns a *Hash representation of the given decimal string
Expand All @@ -121,5 +113,5 @@ func NewHashFromString(s string) (*Hash, error) {
if !ok {
return nil, fmt.Errorf("Can not parse string to Hash")
}
return NewHashFromBigInt(bi), nil
return NewHashFromBigInt(bi)
}
Loading

0 comments on commit 1a61c1c

Please sign in to comment.