diff --git a/network/p2p/dht_test.go b/network/p2p/dht_test.go index 55e53088d4d..e4b7525a921 100644 --- a/network/p2p/dht_test.go +++ b/network/p2p/dht_test.go @@ -148,7 +148,7 @@ func TestPubSubWithDHTDiscovery(t *testing.T) { require.NoError(t, err) for _, n := range nodes { - s, err := n.Subscribe(topic, codec, unittest.AllowAllPeerFilter()) + s, err := n.Subscribe(topic, codec, unittest.AllowAllPeerFilter(), unittest.NetworkSlashingViolationsConsumer(unittest.Logger())) require.NoError(t, err) go func(s *pubsub.Subscription, nodeID peer.ID) { diff --git a/network/p2p/libp2pNode.go b/network/p2p/libp2pNode.go index 729dcc03beb..8c65cfc274c 100644 --- a/network/p2p/libp2pNode.go +++ b/network/p2p/libp2pNode.go @@ -18,6 +18,8 @@ import ( pubsub "github.com/libp2p/go-libp2p-pubsub" "github.com/rs/zerolog" + "github.com/onflow/flow-go/network/slashing" + flownet "github.com/onflow/flow-go/network" "github.com/onflow/flow-go/network/channels" "github.com/onflow/flow-go/network/p2p/unicast" @@ -174,7 +176,7 @@ func (n *Node) ListPeers(topic string) []peer.ID { // Subscribe subscribes the node to the given topic and returns the subscription // Currently only one subscriber is allowed per topic. // NOTE: A node will receive its own published messages. -func (n *Node) Subscribe(topic channels.Topic, codec flownet.Codec, peerFilter PeerFilter, validators ...validator.PubSubMessageValidator) (*pubsub.Subscription, error) { +func (n *Node) Subscribe(topic channels.Topic, codec flownet.Codec, peerFilter PeerFilter, slashingViolationsConsumer slashing.ViolationsConsumer, validators ...validator.PubSubMessageValidator) (*pubsub.Subscription, error) { n.Lock() defer n.Unlock() @@ -183,7 +185,7 @@ func (n *Node) Subscribe(topic channels.Topic, codec flownet.Codec, peerFilter P tp, found := n.topics[topic] var err error if !found { - topicValidator := flowpubsub.TopicValidator(n.logger, codec, peerFilter, validators...) + topicValidator := flowpubsub.TopicValidator(n.logger, codec, slashingViolationsConsumer, peerFilter, validators...) if err := n.pubSub.RegisterTopicValidator( topic.String(), topicValidator, pubsub.WithValidatorInline(true), ); err != nil { diff --git a/network/p2p/middleware.go b/network/p2p/middleware.go index 0df33b8db84..9de101ecd60 100644 --- a/network/p2p/middleware.go +++ b/network/p2p/middleware.go @@ -556,7 +556,7 @@ func (m *Middleware) Subscribe(channel channels.Channel) error { peerFilter = m.isProtocolParticipant() } - s, err := m.libP2PNode.Subscribe(topic, m.codec, peerFilter, validators...) + s, err := m.libP2PNode.Subscribe(topic, m.codec, peerFilter, m.slashingViolationsConsumer, validators...) if err != nil { return fmt.Errorf("could not subscribe to topic (%s): %w", topic, err) } diff --git a/network/p2p/sporking_test.go b/network/p2p/sporking_test.go index a789ed55346..b355491f3fe 100644 --- a/network/p2p/sporking_test.go +++ b/network/p2p/sporking_test.go @@ -163,9 +163,9 @@ func TestOneToKCrosstalkPrevention(t *testing.T) { topicBeforeSpork := channels.TopicFromChannel(channels.TestNetworkChannel, previousSporkId) // both nodes are initially on the same spork and subscribed to the same topic - _, err = node1.Subscribe(topicBeforeSpork, unittest.NetworkCodec(), unittest.AllowAllPeerFilter()) + _, err = node1.Subscribe(topicBeforeSpork, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), unittest.NetworkSlashingViolationsConsumer(unittest.Logger())) require.NoError(t, err) - sub2, err := node2.Subscribe(topicBeforeSpork, unittest.NetworkCodec(), unittest.AllowAllPeerFilter()) + sub2, err := node2.Subscribe(topicBeforeSpork, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), unittest.NetworkSlashingViolationsConsumer(unittest.Logger())) require.NoError(t, err) // add node 2 as a peer of node 1 @@ -189,7 +189,7 @@ func TestOneToKCrosstalkPrevention(t *testing.T) { // and keeping node2 subscribed to topic 'topicBeforeSpork' err = node1.UnSubscribe(topicBeforeSpork) require.NoError(t, err) - _, err = node1.Subscribe(topicAfterSpork, unittest.NetworkCodec(), unittest.AllowAllPeerFilter()) + _, err = node1.Subscribe(topicAfterSpork, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), unittest.NetworkSlashingViolationsConsumer(unittest.Logger())) require.NoError(t, err) // assert that node 1 can no longer send a message to node 2 via PubSub diff --git a/network/p2p/subscription_filter_test.go b/network/p2p/subscription_filter_test.go index 6fa33b60589..12c75d10e66 100644 --- a/network/p2p/subscription_filter_test.go +++ b/network/p2p/subscription_filter_test.go @@ -41,13 +41,13 @@ func TestFilterSubscribe(t *testing.T) { badTopic := channels.TopicFromChannel(channels.SyncCommittee, sporkId) - sub1, err := node1.Subscribe(badTopic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter()) + sub1, err := node1.Subscribe(badTopic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), unittest.NetworkSlashingViolationsConsumer(unittest.Logger())) require.NoError(t, err) - sub2, err := node2.Subscribe(badTopic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter()) + sub2, err := node2.Subscribe(badTopic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), unittest.NetworkSlashingViolationsConsumer(unittest.Logger())) require.NoError(t, err) - unstakedSub, err := unstakedNode.Subscribe(badTopic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter()) + unstakedSub, err := unstakedNode.Subscribe(badTopic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), unittest.NetworkSlashingViolationsConsumer(unittest.Logger())) require.NoError(t, err) require.Eventually(t, func() bool { @@ -112,7 +112,7 @@ func TestCanSubscribe(t *testing.T) { }() goodTopic := channels.TopicFromChannel(channels.ProvideCollections, sporkId) - _, err := collectionNode.Subscribe(goodTopic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter()) + _, err := collectionNode.Subscribe(goodTopic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), unittest.NetworkSlashingViolationsConsumer(unittest.Logger())) require.NoError(t, err) var badTopic channels.Topic @@ -126,11 +126,11 @@ func TestCanSubscribe(t *testing.T) { break } } - _, err = collectionNode.Subscribe(badTopic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter()) + _, err = collectionNode.Subscribe(badTopic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), unittest.NetworkSlashingViolationsConsumer(unittest.Logger())) require.Error(t, err) clusterTopic := channels.TopicFromChannel(channels.SyncCluster(flow.Emulator), sporkId) - _, err = collectionNode.Subscribe(clusterTopic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter()) + _, err = collectionNode.Subscribe(clusterTopic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), unittest.NetworkSlashingViolationsConsumer(unittest.Logger())) require.NoError(t, err) } diff --git a/network/p2p/topic_validator_test.go b/network/p2p/topic_validator_test.go index 806b8ab463f..b193e1c9203 100644 --- a/network/p2p/topic_validator_test.go +++ b/network/p2p/topic_validator_test.go @@ -64,12 +64,13 @@ func TestTopicValidator_Unstaked(t *testing.T) { // sn1 <-> sn2 require.NoError(t, sn1.AddPeer(context.TODO(), *host.InfoFromHost(sn2.Host()))) + slashingViolationsConsumer := unittest.NetworkSlashingViolationsConsumer(logger) // sn1 will subscribe with is staked callback that should force the TopicValidator to drop the message received from sn2 - sub1, err := sn1.Subscribe(topic, unittest.NetworkCodec(), isStaked) + sub1, err := sn1.Subscribe(topic, unittest.NetworkCodec(), isStaked, slashingViolationsConsumer) require.NoError(t, err) // sn2 will subscribe with an unauthenticated callback to allow it to send the unauthenticated message - _, err = sn2.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter()) + _, err = sn2.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), slashingViolationsConsumer) require.NoError(t, err) // assert that the nodes are connected as expected @@ -122,10 +123,11 @@ func TestTopicValidator_PublicChannel(t *testing.T) { // sn1 <-> sn2 require.NoError(t, sn1.AddPeer(context.TODO(), *host.InfoFromHost(sn2.Host()))) + slashingViolationsConsumer := unittest.NetworkSlashingViolationsConsumer(logger) // sn1 & sn2 will subscribe with unauthenticated callback to allow it to send and receive unauthenticated messages - sub1, err := sn1.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter()) + sub1, err := sn1.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), slashingViolationsConsumer) require.NoError(t, err) - sub2, err := sn2.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter()) + sub2, err := sn2.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), slashingViolationsConsumer) require.NoError(t, err) // assert that the nodes are connected as expected @@ -197,12 +199,13 @@ func TestAuthorizedSenderValidator_Unauthorized(t *testing.T) { require.NoError(t, sn1.AddPeer(context.TODO(), *host.InfoFromHost(sn2.Host()))) require.NoError(t, an1.AddPeer(context.TODO(), *host.InfoFromHost(sn1.Host()))) + slashingViolationsConsumer := unittest.NetworkSlashingViolationsConsumer(logger) // sn1 and sn2 subscribe to the topic with the topic validator - sub1, err := sn1.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), authorizedSenderValidator) + sub1, err := sn1.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), slashingViolationsConsumer, authorizedSenderValidator) require.NoError(t, err) - sub2, err := sn2.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), authorizedSenderValidator) + sub2, err := sn2.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), slashingViolationsConsumer, authorizedSenderValidator) require.NoError(t, err) - sub3, err := an1.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter()) + sub3, err := an1.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), slashingViolationsConsumer) require.NoError(t, err) // assert that the nodes are connected as expected @@ -300,10 +303,11 @@ func TestAuthorizedSenderValidator_InvalidMsg(t *testing.T) { // sn1 <-> sn2 require.NoError(t, sn1.AddPeer(context.TODO(), *host.InfoFromHost(sn2.Host()))) + slashingViolationsConsumer := unittest.NetworkSlashingViolationsConsumer(logger) // sn1 subscribe to the topic with the topic validator, while sn2 will subscribe without the topic validator to allow sn2 to publish unauthorized messages - sub1, err := sn1.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), authorizedSenderValidator) + sub1, err := sn1.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), slashingViolationsConsumer, authorizedSenderValidator) require.NoError(t, err) - _, err = sn2.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter()) + _, err = sn2.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), slashingViolationsConsumer) require.NoError(t, err) // assert that the nodes are connected as expected @@ -373,12 +377,13 @@ func TestAuthorizedSenderValidator_Ejected(t *testing.T) { require.NoError(t, sn1.AddPeer(context.TODO(), *host.InfoFromHost(sn2.Host()))) require.NoError(t, an1.AddPeer(context.TODO(), *host.InfoFromHost(sn1.Host()))) + slashingViolationsConsumer := unittest.NetworkSlashingViolationsConsumer(logger) // sn1 subscribe to the topic with the topic validator, while sn2 will subscribe without the topic validator to allow sn2 to publish unauthorized messages - sub1, err := sn1.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), authorizedSenderValidator) + sub1, err := sn1.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), slashingViolationsConsumer, authorizedSenderValidator) require.NoError(t, err) - sub2, err := sn2.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter()) + sub2, err := sn2.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), slashingViolationsConsumer) require.NoError(t, err) - sub3, err := an1.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter()) + sub3, err := an1.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), slashingViolationsConsumer) require.NoError(t, err) // assert that the nodes are connected as expected @@ -464,11 +469,12 @@ func TestAuthorizedSenderValidator_ClusterChannel(t *testing.T) { require.NoError(t, ln1.AddPeer(context.TODO(), *host.InfoFromHost(ln2.Host()))) require.NoError(t, ln3.AddPeer(context.TODO(), *host.InfoFromHost(ln1.Host()))) - sub1, err := ln1.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), authorizedSenderValidator) + slashingViolationsConsumer := unittest.NetworkSlashingViolationsConsumer(logger) + sub1, err := ln1.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), slashingViolationsConsumer, authorizedSenderValidator) require.NoError(t, err) - sub2, err := ln2.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), authorizedSenderValidator) + sub2, err := ln2.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), slashingViolationsConsumer, authorizedSenderValidator) require.NoError(t, err) - sub3, err := ln3.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), authorizedSenderValidator) + sub3, err := ln3.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), slashingViolationsConsumer, authorizedSenderValidator) require.NoError(t, err) // assert that the nodes are connected as expected diff --git a/network/validator/pubsub/topic_validator.go b/network/validator/pubsub/topic_validator.go index 5406698f2a8..2241d65e933 100644 --- a/network/validator/pubsub/topic_validator.go +++ b/network/validator/pubsub/topic_validator.go @@ -9,6 +9,10 @@ import ( pubsub "github.com/libp2p/go-libp2p-pubsub" "github.com/rs/zerolog" + "github.com/onflow/flow-go/network/channels" + "github.com/onflow/flow-go/network/codec" + "github.com/onflow/flow-go/network/slashing" + "github.com/onflow/flow-go/network" "github.com/onflow/flow-go/network/message" "github.com/onflow/flow-go/network/validator" @@ -63,7 +67,7 @@ type TopicValidatorData struct { // TopicValidator is the topic validator that is registered with libP2P whenever a flow libP2P node subscribes to a topic. // The TopicValidator will decode and perform validation on the raw pubsub message. -func TopicValidator(log zerolog.Logger, codec network.Codec, peerFilter func(peer.ID) error, validators ...validator.PubSubMessageValidator) pubsub.ValidatorEx { +func TopicValidator(log zerolog.Logger, c network.Codec, slashingViolationsConsumer slashing.ViolationsConsumer, peerFilter func(peer.ID) error, validators ...validator.PubSubMessageValidator) pubsub.ValidatorEx { log = log.With(). Str("component", "libp2p_node_topic_validator"). Logger() @@ -93,10 +97,24 @@ func TopicValidator(log zerolog.Logger, codec network.Codec, peerFilter func(pee } // Convert message payload to a known message type - decodedMsgPayload, err := codec.Decode(msg.Payload) - if err != nil { - log.Warn(). - Err(fmt.Errorf("could not decode message: %w", err)). + decodedMsgPayload, err := c.Decode(msg.Payload) + switch { + case err == nil: + break + case codec.IsErrUnknownMsgCode(err): + // slash peer if message contains unknown message code byte + slashingViolationsConsumer.OnUnknownMsgTypeError(violation(from, msg, err)) + return pubsub.ValidationReject + case codec.IsErrMsgUnmarshal(err): + // slash if peer sent a message that could not be marshalled into the message type denoted by the message code byte + slashingViolationsConsumer.OnInvalidMsgError(violation(from, msg, err)) + return pubsub.ValidationReject + default: + // unexpected error condition. this indicates there's a bug + // don't crash as a result of external inputs since that creates a DoS vector. + log. + Error(). + Err(fmt.Errorf("unexpected error while decoding message: %w", err)). Str("peer_id", from.String()). Hex("sender", msg.OriginID). Msg("rejecting message") @@ -122,3 +140,13 @@ func TopicValidator(log zerolog.Logger, codec network.Codec, peerFilter func(pee return result } } + +func violation(pid peer.ID, msg message.Message, err error) *slashing.Violation { + return &slashing.Violation{ + PeerID: pid.String(), + MsgType: msg.Type, + Channel: channels.Channel(msg.ChannelID), + IsUnicast: false, + Err: err, + } +} diff --git a/utils/unittest/unittest.go b/utils/unittest/unittest.go index a268cb9d934..694ce462e70 100644 --- a/utils/unittest/unittest.go +++ b/utils/unittest/unittest.go @@ -13,9 +13,12 @@ import ( "time" "github.com/dgraph-io/badger/v2" + "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/onflow/flow-go/network/slashing" + "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/module" "github.com/onflow/flow-go/module/util" @@ -427,3 +430,8 @@ func CrashTestWithExpectedStatus( outStr := string(outBytes) require.Contains(t, outStr, expectedErrorMsg) } + +// NetworkSlashingViolationsConsumer returns a slashing violations consumer for network middleware +func NetworkSlashingViolationsConsumer(logger zerolog.Logger) slashing.ViolationsConsumer { + return slashing.NewSlashingViolationsConsumer(logger) +}