From e17a6ca9af632d4292decbc3496254f460a92518 Mon Sep 17 00:00:00 2001 From: Dan Laine Date: Wed, 31 May 2023 09:53:44 -0400 Subject: [PATCH] `x/sync` -- use proto for sending Range Proofs (#1537) Co-authored-by: dboehm-avalabs Co-authored-by: Ron Kuris --- proto/pb/sync/sync.pb.go | 6 +- proto/sync/sync.proto | 2 +- utils/hashing/hashing.go | 7 +- x/merkledb/codec.go | 109 --------------- x/merkledb/codec_test.go | 149 +++----------------- x/merkledb/proof.go | 126 +++++++++++++++++ x/merkledb/proof_test.go | 256 +++++++++++++++++++++++----------- x/sync/client.go | 11 +- x/sync/client_test.go | 16 ++- x/sync/network_server.go | 3 +- x/sync/network_server_test.go | 28 ++-- 11 files changed, 368 insertions(+), 345 deletions(-) diff --git a/proto/pb/sync/sync.pb.go b/proto/pb/sync/sync.pb.go index eed82c41a71c..438b34edbdac 100644 --- a/proto/pb/sync/sync.pb.go +++ b/proto/pb/sync/sync.pb.go @@ -390,7 +390,7 @@ type SerializedPath struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - NibbleLength uint32 `protobuf:"varint,1,opt,name=nibble_length,json=nibbleLength,proto3" json:"nibble_length,omitempty"` + NibbleLength uint64 `protobuf:"varint,1,opt,name=nibble_length,json=nibbleLength,proto3" json:"nibble_length,omitempty"` Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"` } @@ -426,7 +426,7 @@ func (*SerializedPath) Descriptor() ([]byte, []int) { return file_sync_sync_proto_rawDescGZIP(), []int{5} } -func (x *SerializedPath) GetNibbleLength() uint32 { +func (x *SerializedPath) GetNibbleLength() uint64 { if x != nil { return x.NibbleLength } @@ -832,7 +832,7 @@ var file_sync_sync_proto_rawDesc = []byte{ 0x61, 0x79, 0x62, 0x65, 0x42, 0x79, 0x74, 0x65, 0x73, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x22, 0x4b, 0x0a, 0x0e, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x69, 0x7a, 0x65, 0x64, 0x50, 0x61, 0x74, 0x68, 0x12, 0x23, 0x0a, 0x0d, 0x6e, 0x69, 0x62, 0x62, 0x6c, 0x65, 0x5f, 0x6c, 0x65, 0x6e, - 0x67, 0x74, 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0c, 0x6e, 0x69, 0x62, 0x62, 0x6c, + 0x67, 0x74, 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x0c, 0x6e, 0x69, 0x62, 0x62, 0x6c, 0x65, 0x4c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x22, 0x41, 0x0a, 0x0a, 0x4d, 0x61, 0x79, 0x62, 0x65, 0x42, 0x79, 0x74, 0x65, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x76, diff --git a/proto/sync/sync.proto b/proto/sync/sync.proto index 6a2dfc031900..29369bc19658 100644 --- a/proto/sync/sync.proto +++ b/proto/sync/sync.proto @@ -47,7 +47,7 @@ message KeyChange { // SerializedPath is the serialized representation of a path. message SerializedPath { - uint32 nibble_length = 1; + uint64 nibble_length = 1; bytes value = 2; } diff --git a/utils/hashing/hashing.go b/utils/hashing/hashing.go index a74ef8fe062e..f2c79e235a64 100644 --- a/utils/hashing/hashing.go +++ b/utils/hashing/hashing.go @@ -5,6 +5,7 @@ package hashing import ( "crypto/sha256" + "errors" "fmt" "io" @@ -16,6 +17,8 @@ const ( AddrLen = ripemd160.Size ) +var ErrInvalidHashLen = errors.New("invalid hash length") + // Hash256 A 256 bit long hash value. 
type Hash256 = [HashLen]byte @@ -85,7 +88,7 @@ func Checksum(bytes []byte, length int) []byte { func ToHash256(bytes []byte) (Hash256, error) { hash := Hash256{} if bytesLen := len(bytes); bytesLen != HashLen { - return hash, fmt.Errorf("expected 32 bytes but got %d", bytesLen) + return hash, fmt.Errorf("%w: expected 32 bytes but got %d", ErrInvalidHashLen, bytesLen) } copy(hash[:], bytes) return hash, nil @@ -94,7 +97,7 @@ func ToHash256(bytes []byte) (Hash256, error) { func ToHash160(bytes []byte) (Hash160, error) { hash := Hash160{} if bytesLen := len(bytes); bytesLen != ripemd160.Size { - return hash, fmt.Errorf("expected 20 bytes but got %d", bytesLen) + return hash, fmt.Errorf("%w: expected 20 bytes but got %d", ErrInvalidHashLen, bytesLen) } copy(hash[:], bytes) return hash, nil diff --git a/x/merkledb/codec.go b/x/merkledb/codec.go index 3bc0730fade0..f8c44e263031 100644 --- a/x/merkledb/codec.go +++ b/x/merkledb/codec.go @@ -74,7 +74,6 @@ type EncoderDecoder interface { type Encoder interface { EncodeProof(version uint16, p *Proof) ([]byte, error) EncodeChangeProof(version uint16, p *ChangeProof) ([]byte, error) - EncodeRangeProof(version uint16, p *RangeProof) ([]byte, error) encodeDBNode(version uint16, n *dbNode) ([]byte, error) encodeHashValues(version uint16, hv *hashValues) ([]byte, error) @@ -83,7 +82,6 @@ type Encoder interface { type Decoder interface { DecodeProof(bytes []byte, p *Proof) (uint16, error) DecodeChangeProof(bytes []byte, p *ChangeProof) (uint16, error) - DecodeRangeProof(bytes []byte, p *RangeProof) (uint16, error) decodeDBNode(bytes []byte, n *dbNode) (uint16, error) } @@ -161,37 +159,6 @@ func (c *codecImpl) EncodeChangeProof(version uint16, proof *ChangeProof) ([]byt return buf.Bytes(), nil } -func (c *codecImpl) EncodeRangeProof(version uint16, proof *RangeProof) ([]byte, error) { - if proof == nil { - return nil, errEncodeNil - } - - if version != codecVersion { - return nil, fmt.Errorf("%w: %d", errUnknownVersion, version) - } - - buf := &bytes.Buffer{} - if err := c.encodeInt(buf, int(version)); err != nil { - return nil, err - } - if err := c.encodeProofPath(buf, proof.StartProof); err != nil { - return nil, err - } - if err := c.encodeProofPath(buf, proof.EndProof); err != nil { - return nil, err - } - if err := c.encodeInt(buf, len(proof.KeyValues)); err != nil { - return nil, err - } - for _, kv := range proof.KeyValues { - if err := c.encodeKeyValue(kv, buf); err != nil { - return nil, err - } - } - - return buf.Bytes(), nil -} - func (c *codecImpl) encodeDBNode(version uint16, n *dbNode) ([]byte, error) { if n == nil { return nil, errEncodeNil @@ -356,54 +323,6 @@ func (c *codecImpl) DecodeChangeProof(b []byte, proof *ChangeProof) (uint16, err return codecVersion, nil } -func (c *codecImpl) DecodeRangeProof(b []byte, proof *RangeProof) (uint16, error) { - if proof == nil { - return 0, errDecodeNil - } - if minRangeProofLen > len(b) { - return 0, io.ErrUnexpectedEOF - } - - var ( - src = bytes.NewReader(b) - err error - ) - gotCodecVersion, err := c.decodeInt(src) - if err != nil { - return 0, err - } - if codecVersion != gotCodecVersion { - return 0, fmt.Errorf("%w: %d", errInvalidCodecVersion, gotCodecVersion) - } - if proof.StartProof, err = c.decodeProofPath(src); err != nil { - return 0, err - } - if proof.EndProof, err = c.decodeProofPath(src); err != nil { - return 0, err - } - - numKeyValues, err := c.decodeInt(src) - if err != nil { - return 0, err - } - if numKeyValues < 0 { - return 0, errNegativeNumKeyValues - } - if numKeyValues > 
src.Len()/minKeyValueLen { - return 0, io.ErrUnexpectedEOF - } - proof.KeyValues = make([]KeyValue, numKeyValues) - for i := range proof.KeyValues { - if proof.KeyValues[i], err = c.decodeKeyValue(src); err != nil { - return 0, err - } - } - if src.Len() != 0 { - return 0, errExtraSpace - } - return codecVersion, nil -} - func (c *codecImpl) decodeDBNode(b []byte, n *dbNode) (uint16, error) { if n == nil { return 0, errDecodeNil @@ -491,24 +410,6 @@ func (c *codecImpl) decodeKeyChange(src *bytes.Reader) (KeyChange, error) { return result, nil } -func (c *codecImpl) decodeKeyValue(src *bytes.Reader) (KeyValue, error) { - if minKeyValueLen > src.Len() { - return KeyValue{}, io.ErrUnexpectedEOF - } - - var ( - result KeyValue - err error - ) - if result.Key, err = c.decodeByteSlice(src); err != nil { - return result, err - } - if result.Value, err = c.decodeByteSlice(src); err != nil { - return result, err - } - return result, nil -} - func (c *codecImpl) encodeKeyChange(kv KeyChange, dst io.Writer) error { if err := c.encodeByteSlice(dst, kv.Key); err != nil { return err @@ -519,16 +420,6 @@ func (c *codecImpl) encodeKeyChange(kv KeyChange, dst io.Writer) error { return nil } -func (c *codecImpl) encodeKeyValue(kv KeyValue, dst io.Writer) error { - if err := c.encodeByteSlice(dst, kv.Key); err != nil { - return err - } - if err := c.encodeByteSlice(dst, kv.Value); err != nil { - return err - } - return nil -} - func (*codecImpl) encodeBool(dst io.Writer, value bool) error { bytesValue := falseBytes if value { diff --git a/x/merkledb/codec_test.go b/x/merkledb/codec_test.go index 50790b87df4f..279db92a16b4 100644 --- a/x/merkledb/codec_test.go +++ b/x/merkledb/codec_test.go @@ -21,6 +21,8 @@ import ( func newRandomProofNode(r *rand.Rand) ProofNode { key := make([]byte, r.Intn(32)) // #nosec G404 _, _ = r.Read(key) // #nosec G404 + serializedKey := newPath(key).Serialize() + val := make([]byte, r.Intn(64)) // #nosec G404 _, _ = r.Read(val) // #nosec G404 @@ -32,22 +34,28 @@ func newRandomProofNode(r *rand.Rand) ProofNode { children[byte(j)] = childID } } - // use the hash instead when length is greater than the hash length - if len(val) >= HashLength { - val = hashing.ComputeHash256(val) - } else if len(val) == 0 { - // We do this because when we encode a value of []byte{} we will later - // decode it as nil. - // Doing this prevents inconsistency when comparing the encoded and - // decoded values. - // Calling nilEmptySlices doesn't set this because it is a private - // variable on the struct - val = nil + + hasValue := rand.Intn(2) == 1 // #nosec G404 + var valueOrHash Maybe[[]byte] + if hasValue { + // use the hash instead when length is greater than the hash length + if len(val) >= HashLength { + val = hashing.ComputeHash256(val) + } else if len(val) == 0 { + // We do this because when we encode a value of []byte{} we will later + // decode it as nil. + // Doing this prevents inconsistency when comparing the encoded and + // decoded values. 
+ // Calling nilEmptySlices doesn't set this because it is a private + // variable on the struct + val = nil + } + valueOrHash = Some(val) } return ProofNode{ - KeyPath: newPath(key).Serialize(), - ValueOrHash: Some(val), + KeyPath: serializedKey, + ValueOrHash: valueOrHash, Children: children, } } @@ -290,29 +298,6 @@ func FuzzCodecChangeProofCanonical(f *testing.F) { ) } -func FuzzCodecRangeProofCanonical(f *testing.F) { - f.Fuzz( - func( - t *testing.T, - b []byte, - ) { - require := require.New(t) - - codec := Codec.(*codecImpl) - proof := &RangeProof{} - got, err := codec.DecodeRangeProof(b, proof) - if err != nil { - return - } - - // Encoding [proof] should be the same as [b]. - buf, err := codec.EncodeRangeProof(got, proof) - require.NoError(err) - require.Equal(b, buf) - }, - ) -} - func FuzzCodecDBNodeCanonical(f *testing.F) { f.Fuzz( func( @@ -431,66 +416,6 @@ func FuzzCodecChangeProofDeterministic(f *testing.F) { ) } -func FuzzCodecRangeProofDeterministic(f *testing.F) { - f.Fuzz( - func( - t *testing.T, - randSeed int, - numStartProofNodes uint, - numEndProofNodes uint, - numKeyValues uint, - ) { - r := rand.New(rand.NewSource(int64(randSeed))) // #nosec G404 - - var rootID ids.ID - _, _ = r.Read(rootID[:]) // #nosec G404 - - startProofNodes := make([]ProofNode, numStartProofNodes) - for i := range startProofNodes { - startProofNodes[i] = newRandomProofNode(r) - } - - endProofNodes := make([]ProofNode, numEndProofNodes) - for i := range endProofNodes { - endProofNodes[i] = newRandomProofNode(r) - } - - keyValues := make([]KeyValue, numKeyValues) - for i := range keyValues { - key := make([]byte, r.Intn(32)) // #nosec G404 - _, _ = r.Read(key) // #nosec G404 - val := make([]byte, r.Intn(32)) // #nosec G404 - _, _ = r.Read(val) // #nosec G404 - keyValues[i] = KeyValue{ - Key: key, - Value: val, - } - } - - proof := RangeProof{ - StartProof: startProofNodes, - EndProof: endProofNodes, - KeyValues: keyValues, - } - - proofBytes, err := Codec.EncodeRangeProof(Version, &proof) - require.NoError(t, err) - - var gotProof RangeProof - _, err = Codec.DecodeRangeProof(proofBytes, &gotProof) - require.NoError(t, err) - - nilEmptySlices(&proof) - nilEmptySlices(&gotProof) - require.Equal(t, proof, gotProof) - - proofBytes2, err := Codec.EncodeRangeProof(Version, &gotProof) - require.NoError(t, err) - require.Equal(t, proofBytes, proofBytes2) - }, - ) -} - func FuzzCodecDBNodeDeterministic(f *testing.F) { f.Fuzz( func( @@ -604,38 +529,6 @@ func TestCodec_DecodeChangeProof(t *testing.T) { require.ErrorIs(err, errNegativeNumKeyValues) } -func TestCodec_DecodeRangeProof(t *testing.T) { - require := require.New(t) - - _, err := Codec.DecodeRangeProof([]byte{1}, nil) - require.ErrorIs(err, errDecodeNil) - - var ( - parsedProof RangeProof - tooShortBytes = make([]byte, minRangeProofLen-1) - ) - _, err = Codec.DecodeRangeProof(tooShortBytes, &parsedProof) - require.ErrorIs(err, io.ErrUnexpectedEOF) - - proof := RangeProof{ - StartProof: nil, - EndProof: nil, - KeyValues: nil, - } - - proofBytes, err := Codec.EncodeRangeProof(Version, &proof) - require.NoError(err) - - // Remove key-values length (0) from end - proofBytes = proofBytes[:len(proofBytes)-minVarIntLen] - proofBytesBuf := bytes.NewBuffer(proofBytes) - // Put key-value length (-1) at end - require.NoError(Codec.(*codecImpl).encodeInt(proofBytesBuf, -1)) - - _, err = Codec.DecodeRangeProof(proofBytesBuf.Bytes(), &parsedProof) - require.ErrorIs(err, errNegativeNumKeyValues) -} - func TestCodec_DecodeDBNode(t *testing.T) { require := 
require.New(t) diff --git a/x/merkledb/proof.go b/x/merkledb/proof.go index b81d38421f9e..90c5fd4985f9 100644 --- a/x/merkledb/proof.go +++ b/x/merkledb/proof.go @@ -14,6 +14,8 @@ import ( "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/trace" "github.com/ava-labs/avalanchego/utils/hashing" + + syncpb "github.com/ava-labs/avalanchego/proto/pb/sync" ) const verificationCacheSize = 2_000 @@ -34,6 +36,12 @@ var ( ErrProofNodeNotForKey = errors.New("the provided node has a key that is not a prefix of the specified key") ErrProofValueDoesntMatch = errors.New("the provided value does not match the proof node for the provided key's value") ErrProofNodeHasUnincludedValue = errors.New("the provided proof has a value for a key within the range that is not present in the provided key/values") + ErrInvalidMaybe = errors.New("maybe is nothing but has value") + ErrInvalidChildIndex = fmt.Errorf("child index must be less than %d", NodeBranchFactor) + ErrNilProofNode = errors.New("proof node is nil") + ErrNilValueOrHash = errors.New("proof node's valueOrHash field is nil") + ErrNilSerializedPath = errors.New("serialized path is nil") + ErrNilRangeProof = errors.New("range proof is nil") ) type ProofNode struct { @@ -45,6 +53,68 @@ type ProofNode struct { Children map[byte]ids.ID } +// Assumes [node.Key.KeyPath.NibbleLength] <= math.MaxUint64. +func (node *ProofNode) ToProto() *syncpb.ProofNode { + pbNode := &syncpb.ProofNode{ + Key: &syncpb.SerializedPath{ + NibbleLength: uint64(node.KeyPath.NibbleLength), + Value: node.KeyPath.Value, + }, + Children: make(map[uint32][]byte, len(node.Children)), + } + + for childIndex, childID := range node.Children { + childID := childID + pbNode.Children[uint32(childIndex)] = childID[:] + } + + if node.ValueOrHash.hasValue { + pbNode.ValueOrHash = &syncpb.MaybeBytes{ + Value: node.ValueOrHash.value, + } + } else { + pbNode.ValueOrHash = &syncpb.MaybeBytes{ + IsNothing: true, + } + } + + return pbNode +} + +func (node *ProofNode) UnmarshalProto(pbNode *syncpb.ProofNode) error { + switch { + case pbNode == nil: + return ErrNilProofNode + case pbNode.ValueOrHash == nil: + return ErrNilValueOrHash + case pbNode.ValueOrHash.IsNothing && len(pbNode.ValueOrHash.Value) != 0: + return ErrInvalidMaybe + case pbNode.Key == nil: + return ErrNilSerializedPath + } + + node.KeyPath.NibbleLength = int(pbNode.Key.NibbleLength) + node.KeyPath.Value = pbNode.Key.Value + + node.Children = make(map[byte]ids.ID, len(pbNode.Children)) + for childIndex, childIDBytes := range pbNode.Children { + if childIndex >= NodeBranchFactor { + return ErrInvalidChildIndex + } + childID, err := ids.ToID(childIDBytes) + if err != nil { + return err + } + node.Children[byte(childIndex)] = childID + } + + if !pbNode.ValueOrHash.IsNothing { + node.ValueOrHash = Some(pbNode.ValueOrHash.Value) + } + + return nil +} + // An inclusion/exclustion proof of a key. 
type Proof struct { // Nodes in the proof path from root --> target key @@ -249,6 +319,62 @@ func (proof *RangeProof) Verify( return nil } +func (proof *RangeProof) ToProto() *syncpb.RangeProof { + startProof := make([]*syncpb.ProofNode, len(proof.StartProof)) + for i, node := range proof.StartProof { + startProof[i] = node.ToProto() + } + + endProof := make([]*syncpb.ProofNode, len(proof.EndProof)) + for i, node := range proof.EndProof { + endProof[i] = node.ToProto() + } + + keyValues := make([]*syncpb.KeyValue, len(proof.KeyValues)) + for i, kv := range proof.KeyValues { + keyValues[i] = &syncpb.KeyValue{ + Key: kv.Key, + Value: kv.Value, + } + } + + return &syncpb.RangeProof{ + Start: startProof, + End: endProof, + KeyValues: keyValues, + } +} + +func (proof *RangeProof) UnmarshalProto(pbProof *syncpb.RangeProof) error { + if pbProof == nil { + return ErrNilRangeProof + } + + proof.StartProof = make([]ProofNode, len(pbProof.Start)) + for i, protoNode := range pbProof.Start { + if err := proof.StartProof[i].UnmarshalProto(protoNode); err != nil { + return err + } + } + + proof.EndProof = make([]ProofNode, len(pbProof.End)) + for i, protoNode := range pbProof.End { + if err := proof.EndProof[i].UnmarshalProto(protoNode); err != nil { + return err + } + } + + proof.KeyValues = make([]KeyValue, len(pbProof.KeyValues)) + for i, kv := range pbProof.KeyValues { + proof.KeyValues[i] = KeyValue{ + Key: kv.Key, + Value: kv.Value, + } + } + + return nil +} + // Verify that all non-intermediate nodes in [proof] which have keys // in [[start], [end]] have the value given for that key in [keysValues]. func verifyAllRangeProofKeyValuesPresent(proof []ProofNode, start, end path, keysValues map[path][]byte) error { diff --git a/x/merkledb/proof_test.go b/x/merkledb/proof_test.go index c65b4940cd82..61f1071d19f5 100644 --- a/x/merkledb/proof_test.go +++ b/x/merkledb/proof_test.go @@ -7,6 +7,7 @@ import ( "bytes" "context" "io" + "math/rand" "testing" "github.com/stretchr/testify/require" @@ -14,6 +15,8 @@ import ( "github.com/ava-labs/avalanchego/database/memdb" "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/utils/hashing" + + syncpb "github.com/ava-labs/avalanchego/proto/pb/sync" ) func getBasicDB() (*merkleDB, error) { @@ -683,85 +686,6 @@ func Test_RangeProof_EmptyValues(t *testing.T) { )) } -func Test_RangeProof_Marshal_Nil(t *testing.T) { - db, err := getBasicDB() - require.NoError(t, err) - writeBasicBatch(t, db) - - val, err := db.Get([]byte{1}) - require.NoError(t, err) - require.Equal(t, []byte{1}, val) - - proof, err := db.GetRangeProof(context.Background(), []byte("key1"), []byte("key35"), 10) - require.NoError(t, err) - require.NotNil(t, proof) - - proofBytes, err := Codec.EncodeRangeProof(Version, proof) - require.NoError(t, err) - - parsedProof := &RangeProof{} - _, err = Codec.DecodeRangeProof(proofBytes, parsedProof) - require.NoError(t, err) - - verifyPath(t, proof.StartProof, parsedProof.StartProof) - verifyPath(t, proof.EndProof, parsedProof.EndProof) - - for index, kv := range proof.KeyValues { - require.True(t, bytes.Equal(kv.Key, parsedProof.KeyValues[index].Key)) - require.True(t, bytes.Equal(kv.Value, parsedProof.KeyValues[index].Value)) - } -} - -func Test_RangeProof_Marshal(t *testing.T) { - db, err := getBasicDB() - require.NoError(t, err) - - writeBasicBatch(t, db) - - val, err := db.Get([]byte{1}) - require.NoError(t, err) - require.Equal(t, []byte{1}, val) - - proof, err := db.GetRangeProof(context.Background(), nil, nil, 10) - require.NoError(t, 
err) - require.NotNil(t, proof) - - proofBytes, err := Codec.EncodeRangeProof(Version, proof) - require.NoError(t, err) - - parsedProof := &RangeProof{} - _, err = Codec.DecodeRangeProof(proofBytes, parsedProof) - require.NoError(t, err) - - verifyPath(t, proof.StartProof, parsedProof.StartProof) - verifyPath(t, proof.EndProof, parsedProof.EndProof) - - for index, state := range proof.KeyValues { - require.True(t, bytes.Equal(state.Key, parsedProof.KeyValues[index].Key)) - require.True(t, bytes.Equal(state.Value, parsedProof.KeyValues[index].Value)) - } -} - -func Test_RangeProof_Marshal_Errors(t *testing.T) { - db, err := getBasicDB() - require.NoError(t, err) - writeBasicBatch(t, db) - - proof, err := db.GetRangeProof(context.Background(), nil, nil, 10) - require.NoError(t, err) - require.NotNil(t, proof) - - proofBytes, err := Codec.EncodeRangeProof(Version, proof) - require.NoError(t, err) - - for i := 1; i < len(proofBytes); i++ { - broken := proofBytes[:i] - parsedProof := &RangeProof{} - _, err = Codec.DecodeRangeProof(broken, parsedProof) - require.ErrorIs(t, err, io.ErrUnexpectedEOF) - } -} - func Test_ChangeProof_Marshal(t *testing.T) { db, err := getBasicDB() require.NoError(t, err) @@ -1462,3 +1386,177 @@ func TestVerifyProofPath(t *testing.T) { }) } } + +func TestProofNodeUnmarshalProtoInvalidMaybe(t *testing.T) { + rand := rand.New(rand.NewSource(1337)) // #nosec G404 + + node := newRandomProofNode(rand) + protoNode := node.ToProto() + + // It's invalid to have a value and be nothing. + protoNode.ValueOrHash = &syncpb.MaybeBytes{ + Value: []byte{1, 2, 3}, + IsNothing: true, + } + + var unmarshaledNode ProofNode + err := unmarshaledNode.UnmarshalProto(protoNode) + require.ErrorIs(t, err, ErrInvalidMaybe) +} + +func TestProofNodeUnmarshalProtoInvalidChildBytes(t *testing.T) { + rand := rand.New(rand.NewSource(1337)) // #nosec G404 + + node := newRandomProofNode(rand) + protoNode := node.ToProto() + + protoNode.Children = map[uint32][]byte{ + 1: []byte("not 32 bytes"), + } + + var unmarshaledNode ProofNode + err := unmarshaledNode.UnmarshalProto(protoNode) + require.ErrorIs(t, err, hashing.ErrInvalidHashLen) +} + +func TestProofNodeUnmarshalProtoInvalidChildIndex(t *testing.T) { + rand := rand.New(rand.NewSource(1337)) // #nosec G404 + + node := newRandomProofNode(rand) + protoNode := node.ToProto() + + childID := ids.GenerateTestID() + protoNode.Children[NodeBranchFactor] = childID[:] + + var unmarshaledNode ProofNode + err := unmarshaledNode.UnmarshalProto(protoNode) + require.ErrorIs(t, err, ErrInvalidChildIndex) +} + +func TestProofNodeUnmarshalProtoMissingFields(t *testing.T) { + rand := rand.New(rand.NewSource(1337)) // #nosec G404 + + type test struct { + name string + nodeFunc func() *syncpb.ProofNode + expectedErr error + } + + tests := []test{ + { + name: "nil node", + nodeFunc: func() *syncpb.ProofNode { + return nil + }, + expectedErr: ErrNilProofNode, + }, + { + name: "nil ValueOrHash", + nodeFunc: func() *syncpb.ProofNode { + node := newRandomProofNode(rand) + protoNode := node.ToProto() + protoNode.ValueOrHash = nil + return protoNode + }, + expectedErr: ErrNilValueOrHash, + }, + { + name: "nil key", + nodeFunc: func() *syncpb.ProofNode { + node := newRandomProofNode(rand) + protoNode := node.ToProto() + protoNode.Key = nil + return protoNode + }, + expectedErr: ErrNilSerializedPath, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var node ProofNode + err := node.UnmarshalProto(tt.nodeFunc()) + require.ErrorIs(t, err, 
tt.expectedErr) + }) + } +} + +func TestProofNodeProtoMarshalUnmarshal(t *testing.T) { + require := require.New(t) + rand := rand.New(rand.NewSource(1337)) // #nosec G404 + + for i := 0; i < 1_000; i++ { + node := newRandomProofNode(rand) + + // Marshal and unmarshal it. + // Assert the unmarshaled one is the same as the original. + protoNode := node.ToProto() + var unmarshaledNode ProofNode + require.NoError(unmarshaledNode.UnmarshalProto(protoNode)) + require.Equal(node, unmarshaledNode) + + // Marshaling again should yield same result. + protoUnmarshaledNode := unmarshaledNode.ToProto() + require.Equal(protoNode, protoUnmarshaledNode) + } +} + +func TestRangeProofUnmarshalProtoNil(t *testing.T) { + var proof RangeProof + err := proof.UnmarshalProto(nil) + require.ErrorIs(t, err, ErrNilRangeProof) +} + +func TestRangeProofProtoMarshalUnmarshal(t *testing.T) { + require := require.New(t) + rand := rand.New(rand.NewSource(1337)) // #nosec G404 + + for i := 0; i < 500; i++ { + // Make a random range proof. + startProofLen := rand.Intn(32) + startProof := make([]ProofNode, startProofLen) + for i := 0; i < startProofLen; i++ { + startProof[i] = newRandomProofNode(rand) + } + + endProofLen := rand.Intn(32) + endProof := make([]ProofNode, endProofLen) + for i := 0; i < endProofLen; i++ { + endProof[i] = newRandomProofNode(rand) + } + + numKeyValues := rand.Intn(128) + keyValues := make([]KeyValue, numKeyValues) + for i := 0; i < numKeyValues; i++ { + keyLen := rand.Intn(32) + key := make([]byte, keyLen) + _, _ = rand.Read(key) + + valueLen := rand.Intn(32) + value := make([]byte, valueLen) + _, _ = rand.Read(value) + + keyValues[i] = KeyValue{ + Key: key, + Value: value, + } + } + + proof := RangeProof{ + StartProof: startProof, + EndProof: endProof, + KeyValues: keyValues, + } + + // Marshal and unmarshal it. + // Assert the unmarshaled one is the same as the original. + var unmarshaledProof RangeProof + protoProof := proof.ToProto() + require.NoError(unmarshaledProof.UnmarshalProto(protoProof)) + require.Equal(proof, unmarshaledProof) + + // Marshaling again should yield same result. 
+ protoUnmarshaledProof := unmarshaledProof.ToProto() + require.Equal(protoProof, protoUnmarshaledProof) + } +} diff --git a/x/sync/client.go b/x/sync/client.go index 3a0bb6d447a8..0d4b2fdfdbf4 100644 --- a/x/sync/client.go +++ b/x/sync/client.go @@ -127,8 +127,13 @@ func (c *client) GetRangeProof(ctx context.Context, req *syncpb.RangeProofReques return nil, fmt.Errorf("%w: (%d) > %d)", errTooManyBytes, len(responseBytes), req.BytesLimit) } - rangeProof := &merkledb.RangeProof{} - if _, err := merkledb.Codec.DecodeRangeProof(responseBytes, rangeProof); err != nil { + var rangeProofProto syncpb.RangeProof + if err := proto.Unmarshal(responseBytes, &rangeProofProto); err != nil { + return nil, err + } + + var rangeProof merkledb.RangeProof + if err := rangeProof.UnmarshalProto(&rangeProofProto); err != nil { return nil, err } @@ -150,7 +155,7 @@ func (c *client) GetRangeProof(ctx context.Context, req *syncpb.RangeProofReques ); err != nil { return nil, fmt.Errorf("%s due to %w", errInvalidRangeProof, err) } - return rangeProof, nil + return &rangeProof, nil } reqBytes, err := proto.Marshal(&syncpb.Request{ diff --git a/x/sync/client_test.go b/x/sync/client_test.go index c85da41635d9..a10005f0cea9 100644 --- a/x/sync/client_test.go +++ b/x/sync/client_test.go @@ -14,6 +14,8 @@ import ( "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + "github.com/ava-labs/avalanchego/database/memdb" "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/snow/engine/common" @@ -87,17 +89,19 @@ func sendRangeRequest( ).DoAndReturn( func(_ context.Context, _ ids.NodeID, requestID uint32, responseBytes []byte) error { // deserialize the response so we can modify it if needed. - response := &merkledb.RangeProof{} - _, err := merkledb.Codec.DecodeRangeProof(responseBytes, response) - require.NoError(err) + var responseProto syncpb.RangeProof + require.NoError(proto.Unmarshal(responseBytes, &responseProto)) + + var response merkledb.RangeProof + require.NoError(response.UnmarshalProto(&responseProto)) // modify if needed if modifyResponse != nil { - modifyResponse(response) + modifyResponse(&response) } // reserialize the response and pass it to the client to complete the handling. 
- responseBytes, err = merkledb.Codec.EncodeRangeProof(merkledb.Version, response) + responseBytes, err := proto.Marshal(response.ToProto()) require.NoError(err) require.NoError(networkClient.AppResponse(context.Background(), serverNodeID, requestID, responseBytes)) return nil @@ -270,7 +274,7 @@ func TestGetRangeProof(t *testing.T) { if test.expectedResponseLen > 0 { require.Len(proof.KeyValues, test.expectedResponseLen) } - bytes, err := merkledb.Codec.EncodeRangeProof(merkledb.Version, proof) + bytes, err := proto.Marshal(proof.ToProto()) require.NoError(err) require.Less(len(bytes), int(test.request.BytesLimit)) }) diff --git a/x/sync/network_server.go b/x/sync/network_server.go index 5244bba625ee..0f6679bb0a88 100644 --- a/x/sync/network_server.go +++ b/x/sync/network_server.go @@ -267,10 +267,11 @@ func (s *NetworkServer) HandleRangeProofRequest( return err } - proofBytes, err := merkledb.Codec.EncodeRangeProof(merkledb.Version, rangeProof) + proofBytes, err := proto.Marshal(rangeProof.ToProto()) if err != nil { return err } + if len(proofBytes) < bytesLimit { return s.appSender.SendAppResponse(ctx, nodeID, requestID, proofBytes) } diff --git a/x/sync/network_server_test.go b/x/sync/network_server_test.go index b76bcd70f7fa..00ab062ca689 100644 --- a/x/sync/network_server_test.go +++ b/x/sync/network_server_test.go @@ -10,6 +10,8 @@ import ( "github.com/golang/mock/gomock" + "google.golang.org/protobuf/proto" + "github.com/stretchr/testify/require" "github.com/ava-labs/avalanchego/ids" @@ -95,7 +97,7 @@ func Test_Server_GetRangeProof(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() sender := common.NewMockSender(ctrl) - var proofResult *merkledb.RangeProof + var proof *merkledb.RangeProof sender.EXPECT().SendAppResponse( gomock.Any(), // ctx gomock.Any(), // nodeID @@ -105,31 +107,32 @@ func Test_Server_GetRangeProof(t *testing.T) { func(_ context.Context, _ ids.NodeID, requestID uint32, responseBytes []byte) error { // grab a copy of the proof so we can inspect it later if !test.proofNil { - var err error - proofResult = &merkledb.RangeProof{} - _, err = merkledb.Codec.DecodeRangeProof(responseBytes, proofResult) - require.NoError(err) + var proofProto syncpb.RangeProof + require.NoError(proto.Unmarshal(responseBytes, &proofProto)) + + var p merkledb.RangeProof + require.NoError(p.UnmarshalProto(&proofProto)) + proof = &p } return nil }, ).AnyTimes() handler := NewNetworkServer(sender, smallTrieDB, logging.NoLog{}) err := handler.HandleRangeProofRequest(context.Background(), test.nodeID, 0, test.request) + require.ErrorIs(err, test.expectedErr) if test.expectedErr != nil { - require.ErrorIs(err, test.expectedErr) return } - require.NoError(err) if test.proofNil { - require.Nil(proofResult) + require.Nil(proof) return } - require.NotNil(proofResult) + require.NotNil(proof) if test.expectedResponseLen > 0 { - require.LessOrEqual(len(proofResult.KeyValues), test.expectedResponseLen) + require.LessOrEqual(len(proof.KeyValues), test.expectedResponseLen) } - bytes, err := merkledb.Codec.EncodeRangeProof(merkledb.Version, proofResult) + bytes, err := proto.Marshal(proof.ToProto()) require.NoError(err) require.LessOrEqual(len(bytes), int(test.request.BytesLimit)) if test.expectedMaxResponseBytes > 0 { @@ -262,11 +265,10 @@ func Test_Server_GetChangeProof(t *testing.T) { ).AnyTimes() handler := NewNetworkServer(sender, trieDB, logging.NoLog{}) err := handler.HandleChangeProofRequest(context.Background(), test.nodeID, 0, test.request) + require.ErrorIs(err, 
test.expectedErr) if test.expectedErr != nil { - require.ErrorIs(err, test.expectedErr) return } - require.NoError(err) if test.proofNil { require.Nil(proofResult) return