@@ -17,14 +17,13 @@ import (
1717 "github.com/ava-labs/avalanchego/ids"
1818 "github.com/ava-labs/avalanchego/network/p2p/mocks"
1919 "github.com/ava-labs/avalanchego/snow/engine/common"
20+ snowvalidators "github.com/ava-labs/avalanchego/snow/validators"
2021 "github.com/ava-labs/avalanchego/utils/logging"
2122 "github.com/ava-labs/avalanchego/utils/math"
2223 "github.com/ava-labs/avalanchego/utils/set"
2324 "github.com/ava-labs/avalanchego/version"
2425)
2526
26- var _ NodeSampler = (* testNodeSampler )(nil )
27-
2827func TestAppRequestResponse (t * testing.T ) {
2928 handlerID := uint64 (0x0 )
3029 request := []byte ("request" )
@@ -451,7 +450,7 @@ func TestPeersSample(t *testing.T) {
451450 sampleable .Union (tt .connected )
452451 sampleable .Difference (tt .disconnected )
453452
454- sampled := network .Sample (context .Background (), tt .limit )
453+ sampled := network .peers . Sample (context .Background (), tt .limit )
455454 require .Len (sampled , math .Min (tt .limit , len (sampleable )))
456455 require .Subset (sampleable , sampled )
457456 })
@@ -503,43 +502,93 @@ func TestAppRequestAnyNodeSelection(t *testing.T) {
503502}
504503
505504func TestNodeSamplerClientOption (t * testing.T ) {
506- require := require .New (t )
507-
508- nodeID := ids .GenerateTestNodeID ()
509- sent := make (chan struct {})
505+ nodeID0 := ids .GenerateTestNodeID ()
506+ nodeID1 := ids .GenerateTestNodeID ()
510507
511- sender := & common.SenderTest {
512- SendAppRequestF : func (_ context.Context , nodeIDs set.Set [ids.NodeID ], _ uint32 , _ []byte ) error {
513- require .Len (nodeIDs , 1 )
514- require .Contains (nodeIDs , nodeID )
508+ tests := []struct {
509+ name string
510+ peers []ids.NodeID
511+ option func (t * testing.T , n * Network ) ClientOption
512+ expected []ids.NodeID
513+ expectedErr error
514+ }{
515+ {
516+ name : "peers" ,
517+ peers : []ids.NodeID {nodeID0 },
518+ option : func (_ * testing.T , n * Network ) ClientOption {
519+ return WithPeerSampling (n )
520+ },
521+ expected : []ids.NodeID {nodeID0 },
522+ },
523+ {
524+ name : "validator connected" ,
525+ peers : []ids.NodeID {nodeID0 , nodeID1 },
526+ option : func (t * testing.T , n * Network ) ClientOption {
527+ state := & snowvalidators.TestState {
528+ GetCurrentHeightF : func (context.Context ) (uint64 , error ) {
529+ return 0 , nil
530+ },
531+ GetValidatorSetF : func (context.Context , uint64 , ids.ID ) (map [ids.NodeID ]* snowvalidators.GetValidatorOutput , error ) {
532+ return map [ids.NodeID ]* snowvalidators.GetValidatorOutput {
533+ nodeID1 : nil ,
534+ }, nil
535+ },
536+ }
515537
516- close (sent )
517- return nil
538+ validators := NewValidators (n , ids .Empty , state , 0 )
539+ return WithValidatorSampling (validators )
540+ },
541+ expected : []ids.NodeID {nodeID1 },
518542 },
519- }
520- network := NewNetwork (logging.NoLog {}, sender , prometheus .NewRegistry (), "" )
543+ {
544+ name : "validator disconnected" ,
545+ peers : []ids.NodeID {nodeID0 },
546+ option : func (t * testing.T , n * Network ) ClientOption {
547+ state := & snowvalidators.TestState {
548+ GetCurrentHeightF : func (context.Context ) (uint64 , error ) {
549+ return 0 , nil
550+ },
551+ GetValidatorSetF : func (context.Context , uint64 , ids.ID ) (map [ids.NodeID ]* snowvalidators.GetValidatorOutput , error ) {
552+ return map [ids.NodeID ]* snowvalidators.GetValidatorOutput {
553+ nodeID1 : nil ,
554+ }, nil
555+ },
556+ }
521557
522- nodeSampler := & testNodeSampler {
523- sampleF : func (context.Context , int ) []ids.NodeID {
524- return []ids.NodeID {nodeID }
558+ validators := NewValidators (n , ids .Empty , state , 0 )
559+ return WithValidatorSampling (validators )
560+ },
561+ expectedErr : ErrNoPeers ,
525562 },
526563 }
527564
528- client , err := network .RegisterAppProtocol (0x0 , nil , WithNodeSampler (nodeSampler ))
529- require .NoError (err )
565+ for _ , tt := range tests {
566+ t .Run (tt .name , func (t * testing.T ) {
567+ require := require .New (t )
530568
531- require .NoError (client .AppRequestAny (context .Background (), []byte ("request" ), nil ))
532- <- sent
533- }
569+ done := make (chan struct {})
570+ sender := & common.SenderTest {
571+ SendAppRequestF : func (_ context.Context , nodeIDs set.Set [ids.NodeID ], _ uint32 , _ []byte ) error {
572+ require .Equal (tt .expected , nodeIDs .List ())
573+ close (done )
574+ return nil
575+ },
576+ }
577+ network := NewNetwork (logging.NoLog {}, sender , prometheus .NewRegistry (), "" )
578+ ctx := context .Background ()
579+ for _ , peer := range tt .peers {
580+ require .NoError (network .Connected (ctx , peer , nil ))
581+ }
534582
535- type testNodeSampler struct {
536- sampleF func (ctx context.Context , limit int ) []ids.NodeID
537- }
583+ client , err := network .RegisterAppProtocol (0x0 , nil , tt .option (t , network ))
584+ require .NoError (err )
538585
539- func (t * testNodeSampler ) Sample (ctx context.Context , limit int ) []ids.NodeID {
540- if t .sampleF == nil {
541- return nil
542- }
586+ if err = client .AppRequestAny (ctx , []byte ("request" ), nil ); err != nil {
587+ close (done )
588+ }
543589
544- return t .sampleF (ctx , limit )
590+ require .ErrorIs (tt .expectedErr , err )
591+ <- done
592+ })
593+ }
545594}
0 commit comments