diff --git a/peers/app_request_network.go b/peers/app_request_network.go index df4e20e0..1c0f57aa 100644 --- a/peers/app_request_network.go +++ b/peers/app_request_network.go @@ -7,6 +7,7 @@ package peers import ( "context" + "encoding/hex" "os" "sync" "time" @@ -230,30 +231,27 @@ func (n *appRequestNetwork) ConnectToCanonicalValidators(subnetID ids.ID) (*Conn if err != nil { return nil, err } + // We make queries to node IDs, not unique validators as represented by a BLS pubkey, so we need this map to track // responses from nodes and populate the signatureMap with the corresponding validator signature // This maps node IDs to the index in the canonical validator set nodeValidatorIndexMap := make(map[ids.NodeID]int) + nodeIDs := set.NewSet[ids.NodeID](len(nodeValidatorIndexMap)) for i, vdr := range validatorSet { for _, node := range vdr.NodeIDs { nodeValidatorIndexMap[node] = i + nodeIDs.Add(node) } } // Manually connect to all peers in the validator set // If new peers are connected, AppRequests may fail while the handshake is in progress. // In that case, AppRequests to those nodes will be retried in the next iteration of the retry loop. - nodeIDs := set.NewSet[ids.NodeID](len(nodeValidatorIndexMap)) - for node := range nodeValidatorIndexMap { - nodeIDs.Add(node) - } connectedNodes := n.ConnectPeers(nodeIDs) - // Check if we've connected to a stake threshold of nodes - connectedWeight := uint64(0) - for node := range connectedNodes { - connectedWeight += validatorSet[nodeValidatorIndexMap[node]].Weight - } + // Calculate the total weight of connected validators. + connectedWeight := calculateConnectedWeight(validatorSet, nodeValidatorIndexMap, connectedNodes) + return &ConnectedCanonicalValidators{ ConnectedWeight: connectedWeight, TotalValidatorWeight: totalValidatorWeight, @@ -292,3 +290,24 @@ func (n *appRequestNetwork) setInfoAPICallLatencyMS(latency float64) { func (n *appRequestNetwork) setPChainAPICallLatencyMS(latency float64) { n.metrics.pChainAPICallLatencyMS.Observe(latency) } + +// Non-receiver util functions + +func calculateConnectedWeight( + validatorSet []*warp.Validator, + nodeValidatorIndexMap map[ids.NodeID]int, + connectedNodes set.Set[ids.NodeID], +) uint64 { + connectedBLSPubKeys := set.NewSet[string](len(validatorSet)) + connectedWeight := uint64(0) + for node := range connectedNodes { + vdr := validatorSet[nodeValidatorIndexMap[node]] + blsPubKey := hex.EncodeToString(vdr.PublicKeyBytes) + if connectedBLSPubKeys.Contains(blsPubKey) { + continue + } + connectedWeight += vdr.Weight + connectedBLSPubKeys.Add(blsPubKey) + } + return connectedWeight +} diff --git a/peers/app_request_network_test.go b/peers/app_request_network_test.go new file mode 100644 index 00000000..acd690d4 --- /dev/null +++ b/peers/app_request_network_test.go @@ -0,0 +1,66 @@ +// Copyright (C) 2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package peers + +import ( + "testing" + + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/utils/crypto/bls" + "github.com/ava-labs/avalanchego/utils/set" + "github.com/ava-labs/avalanchego/vms/platformvm/warp" + "github.com/stretchr/testify/require" +) + +func TestCalculateConnectedWeight(t *testing.T) { + vdr1 := makeValidator(t, 10, 1) + vdr2 := makeValidator(t, 20, 1) + vdr3 := makeValidator(t, 30, 2) + vdrs := []*warp.Validator{&vdr1, &vdr2, &vdr3} + nodeValidatorIndexMap := map[ids.NodeID]int{ + vdr1.NodeIDs[0]: 0, + vdr2.NodeIDs[0]: 1, + vdr3.NodeIDs[0]: 2, + vdr3.NodeIDs[1]: 2, + } + var connectedNodes set.Set[ids.NodeID] + connectedNodes.Add(vdr1.NodeIDs[0]) + connectedNodes.Add(vdr2.NodeIDs[0]) + + // vdr1 and vdr2 are connected, so their weight should be added + require.Equal(t, 2, connectedNodes.Len()) + connectedWeight := calculateConnectedWeight(vdrs, nodeValidatorIndexMap, connectedNodes) + require.Equal(t, uint64(30), connectedWeight) + + // Add one of the vdr3's nodeIDs to the connected nodes + // and confirm that it adds vdr3's weight + connectedNodes.Add(vdr3.NodeIDs[0]) + require.Equal(t, 3, connectedNodes.Len()) + connectedWeight2 := calculateConnectedWeight(vdrs, nodeValidatorIndexMap, connectedNodes) + require.Equal(t, uint64(60), connectedWeight2) + + // Add another of vdr3's nodeIDs to the connected nodes + // and confirm that it's weight isn't double counted + connectedNodes.Add(vdr3.NodeIDs[1]) + require.Equal(t, 4, connectedNodes.Len()) + connectedWeight3 := calculateConnectedWeight(vdrs, nodeValidatorIndexMap, connectedNodes) + require.Equal(t, uint64(60), connectedWeight3) +} + +func makeValidator(t *testing.T, weight uint64, numNodeIDs int) warp.Validator { + sk, err := bls.NewSecretKey() + require.NoError(t, err) + pk := bls.PublicFromSecretKey(sk) + + nodeIDs := make([]ids.NodeID, numNodeIDs) + for i := 0; i < numNodeIDs; i++ { + nodeIDs[i] = ids.GenerateTestNodeID() + } + return warp.Validator{ + PublicKey: pk, + PublicKeyBytes: bls.PublicKeyToUncompressedBytes(pk), + Weight: weight, + NodeIDs: nodeIDs, + } +}