Skip to content

Commit

Permalink
Merge pull request #3052 from onflow/khalil/1762-check-codec-sentinel…
Browse files Browse the repository at this point in the history
…s-topicval

Khalil/1762 check codec sentinels in topic validator
  • Loading branch information
kc1116 authored Aug 30, 2022
2 parents 24440c7 + 3db3060 commit ac1c5c7
Show file tree
Hide file tree
Showing 8 changed files with 77 additions and 33 deletions.
2 changes: 1 addition & 1 deletion network/p2p/dht_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func TestPubSubWithDHTDiscovery(t *testing.T) {
require.NoError(t, err)

for _, n := range nodes {
s, err := n.Subscribe(topic, codec, unittest.AllowAllPeerFilter())
s, err := n.Subscribe(topic, codec, unittest.AllowAllPeerFilter(), unittest.NetworkSlashingViolationsConsumer(unittest.Logger()))
require.NoError(t, err)

go func(s *pubsub.Subscription, nodeID peer.ID) {
Expand Down
6 changes: 4 additions & 2 deletions network/p2p/libp2pNode.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import (
pubsub "github.com/libp2p/go-libp2p-pubsub"
"github.com/rs/zerolog"

"github.com/onflow/flow-go/network/slashing"

flownet "github.com/onflow/flow-go/network"
"github.com/onflow/flow-go/network/channels"
"github.com/onflow/flow-go/network/p2p/unicast"
Expand Down Expand Up @@ -174,7 +176,7 @@ func (n *Node) ListPeers(topic string) []peer.ID {
// Subscribe subscribes the node to the given topic and returns the subscription
// Currently only one subscriber is allowed per topic.
// NOTE: A node will receive its own published messages.
func (n *Node) Subscribe(topic channels.Topic, codec flownet.Codec, peerFilter PeerFilter, validators ...validator.PubSubMessageValidator) (*pubsub.Subscription, error) {
func (n *Node) Subscribe(topic channels.Topic, codec flownet.Codec, peerFilter PeerFilter, slashingViolationsConsumer slashing.ViolationsConsumer, validators ...validator.PubSubMessageValidator) (*pubsub.Subscription, error) {
n.Lock()
defer n.Unlock()

Expand All @@ -183,7 +185,7 @@ func (n *Node) Subscribe(topic channels.Topic, codec flownet.Codec, peerFilter P
tp, found := n.topics[topic]
var err error
if !found {
topicValidator := flowpubsub.TopicValidator(n.logger, codec, peerFilter, validators...)
topicValidator := flowpubsub.TopicValidator(n.logger, codec, slashingViolationsConsumer, peerFilter, validators...)
if err := n.pubSub.RegisterTopicValidator(
topic.String(), topicValidator, pubsub.WithValidatorInline(true),
); err != nil {
Expand Down
2 changes: 1 addition & 1 deletion network/p2p/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ func (m *Middleware) Subscribe(channel channels.Channel) error {
peerFilter = m.isProtocolParticipant()
}

s, err := m.libP2PNode.Subscribe(topic, m.codec, peerFilter, validators...)
s, err := m.libP2PNode.Subscribe(topic, m.codec, peerFilter, m.slashingViolationsConsumer, validators...)
if err != nil {
return fmt.Errorf("could not subscribe to topic (%s): %w", topic, err)
}
Expand Down
6 changes: 3 additions & 3 deletions network/p2p/sporking_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,9 @@ func TestOneToKCrosstalkPrevention(t *testing.T) {
topicBeforeSpork := channels.TopicFromChannel(channels.TestNetworkChannel, previousSporkId)

// both nodes are initially on the same spork and subscribed to the same topic
_, err = node1.Subscribe(topicBeforeSpork, unittest.NetworkCodec(), unittest.AllowAllPeerFilter())
_, err = node1.Subscribe(topicBeforeSpork, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), unittest.NetworkSlashingViolationsConsumer(unittest.Logger()))
require.NoError(t, err)
sub2, err := node2.Subscribe(topicBeforeSpork, unittest.NetworkCodec(), unittest.AllowAllPeerFilter())
sub2, err := node2.Subscribe(topicBeforeSpork, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), unittest.NetworkSlashingViolationsConsumer(unittest.Logger()))
require.NoError(t, err)

// add node 2 as a peer of node 1
Expand All @@ -189,7 +189,7 @@ func TestOneToKCrosstalkPrevention(t *testing.T) {
// and keeping node2 subscribed to topic 'topicBeforeSpork'
err = node1.UnSubscribe(topicBeforeSpork)
require.NoError(t, err)
_, err = node1.Subscribe(topicAfterSpork, unittest.NetworkCodec(), unittest.AllowAllPeerFilter())
_, err = node1.Subscribe(topicAfterSpork, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), unittest.NetworkSlashingViolationsConsumer(unittest.Logger()))
require.NoError(t, err)

// assert that node 1 can no longer send a message to node 2 via PubSub
Expand Down
12 changes: 6 additions & 6 deletions network/p2p/subscription_filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@ func TestFilterSubscribe(t *testing.T) {

badTopic := channels.TopicFromChannel(channels.SyncCommittee, sporkId)

sub1, err := node1.Subscribe(badTopic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter())
sub1, err := node1.Subscribe(badTopic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), unittest.NetworkSlashingViolationsConsumer(unittest.Logger()))
require.NoError(t, err)

sub2, err := node2.Subscribe(badTopic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter())
sub2, err := node2.Subscribe(badTopic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), unittest.NetworkSlashingViolationsConsumer(unittest.Logger()))
require.NoError(t, err)

unstakedSub, err := unstakedNode.Subscribe(badTopic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter())
unstakedSub, err := unstakedNode.Subscribe(badTopic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), unittest.NetworkSlashingViolationsConsumer(unittest.Logger()))
require.NoError(t, err)

require.Eventually(t, func() bool {
Expand Down Expand Up @@ -112,7 +112,7 @@ func TestCanSubscribe(t *testing.T) {
}()

goodTopic := channels.TopicFromChannel(channels.ProvideCollections, sporkId)
_, err := collectionNode.Subscribe(goodTopic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter())
_, err := collectionNode.Subscribe(goodTopic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), unittest.NetworkSlashingViolationsConsumer(unittest.Logger()))
require.NoError(t, err)

var badTopic channels.Topic
Expand All @@ -126,11 +126,11 @@ func TestCanSubscribe(t *testing.T) {
break
}
}
_, err = collectionNode.Subscribe(badTopic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter())
_, err = collectionNode.Subscribe(badTopic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), unittest.NetworkSlashingViolationsConsumer(unittest.Logger()))
require.Error(t, err)

clusterTopic := channels.TopicFromChannel(channels.SyncCluster(flow.Emulator), sporkId)
_, err = collectionNode.Subscribe(clusterTopic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter())
_, err = collectionNode.Subscribe(clusterTopic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), unittest.NetworkSlashingViolationsConsumer(unittest.Logger()))
require.NoError(t, err)
}

Expand Down
36 changes: 21 additions & 15 deletions network/p2p/topic_validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,13 @@ func TestTopicValidator_Unstaked(t *testing.T) {
// sn1 <-> sn2
require.NoError(t, sn1.AddPeer(context.TODO(), *host.InfoFromHost(sn2.Host())))

slashingViolationsConsumer := unittest.NetworkSlashingViolationsConsumer(logger)
// sn1 will subscribe with is staked callback that should force the TopicValidator to drop the message received from sn2
sub1, err := sn1.Subscribe(topic, unittest.NetworkCodec(), isStaked)
sub1, err := sn1.Subscribe(topic, unittest.NetworkCodec(), isStaked, slashingViolationsConsumer)
require.NoError(t, err)

// sn2 will subscribe with an unauthenticated callback to allow it to send the unauthenticated message
_, err = sn2.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter())
_, err = sn2.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), slashingViolationsConsumer)
require.NoError(t, err)

// assert that the nodes are connected as expected
Expand Down Expand Up @@ -122,10 +123,11 @@ func TestTopicValidator_PublicChannel(t *testing.T) {
// sn1 <-> sn2
require.NoError(t, sn1.AddPeer(context.TODO(), *host.InfoFromHost(sn2.Host())))

slashingViolationsConsumer := unittest.NetworkSlashingViolationsConsumer(logger)
// sn1 & sn2 will subscribe with unauthenticated callback to allow it to send and receive unauthenticated messages
sub1, err := sn1.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter())
sub1, err := sn1.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), slashingViolationsConsumer)
require.NoError(t, err)
sub2, err := sn2.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter())
sub2, err := sn2.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), slashingViolationsConsumer)
require.NoError(t, err)

// assert that the nodes are connected as expected
Expand Down Expand Up @@ -197,12 +199,13 @@ func TestAuthorizedSenderValidator_Unauthorized(t *testing.T) {
require.NoError(t, sn1.AddPeer(context.TODO(), *host.InfoFromHost(sn2.Host())))
require.NoError(t, an1.AddPeer(context.TODO(), *host.InfoFromHost(sn1.Host())))

slashingViolationsConsumer := unittest.NetworkSlashingViolationsConsumer(logger)
// sn1 and sn2 subscribe to the topic with the topic validator
sub1, err := sn1.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), authorizedSenderValidator)
sub1, err := sn1.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), slashingViolationsConsumer, authorizedSenderValidator)
require.NoError(t, err)
sub2, err := sn2.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), authorizedSenderValidator)
sub2, err := sn2.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), slashingViolationsConsumer, authorizedSenderValidator)
require.NoError(t, err)
sub3, err := an1.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter())
sub3, err := an1.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), slashingViolationsConsumer)
require.NoError(t, err)

// assert that the nodes are connected as expected
Expand Down Expand Up @@ -300,10 +303,11 @@ func TestAuthorizedSenderValidator_InvalidMsg(t *testing.T) {
// sn1 <-> sn2
require.NoError(t, sn1.AddPeer(context.TODO(), *host.InfoFromHost(sn2.Host())))

slashingViolationsConsumer := unittest.NetworkSlashingViolationsConsumer(logger)
// sn1 subscribe to the topic with the topic validator, while sn2 will subscribe without the topic validator to allow sn2 to publish unauthorized messages
sub1, err := sn1.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), authorizedSenderValidator)
sub1, err := sn1.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), slashingViolationsConsumer, authorizedSenderValidator)
require.NoError(t, err)
_, err = sn2.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter())
_, err = sn2.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), slashingViolationsConsumer)
require.NoError(t, err)

// assert that the nodes are connected as expected
Expand Down Expand Up @@ -373,12 +377,13 @@ func TestAuthorizedSenderValidator_Ejected(t *testing.T) {
require.NoError(t, sn1.AddPeer(context.TODO(), *host.InfoFromHost(sn2.Host())))
require.NoError(t, an1.AddPeer(context.TODO(), *host.InfoFromHost(sn1.Host())))

slashingViolationsConsumer := unittest.NetworkSlashingViolationsConsumer(logger)
// sn1 subscribe to the topic with the topic validator, while sn2 will subscribe without the topic validator to allow sn2 to publish unauthorized messages
sub1, err := sn1.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), authorizedSenderValidator)
sub1, err := sn1.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), slashingViolationsConsumer, authorizedSenderValidator)
require.NoError(t, err)
sub2, err := sn2.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter())
sub2, err := sn2.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), slashingViolationsConsumer)
require.NoError(t, err)
sub3, err := an1.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter())
sub3, err := an1.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), slashingViolationsConsumer)
require.NoError(t, err)

// assert that the nodes are connected as expected
Expand Down Expand Up @@ -464,11 +469,12 @@ func TestAuthorizedSenderValidator_ClusterChannel(t *testing.T) {
require.NoError(t, ln1.AddPeer(context.TODO(), *host.InfoFromHost(ln2.Host())))
require.NoError(t, ln3.AddPeer(context.TODO(), *host.InfoFromHost(ln1.Host())))

sub1, err := ln1.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), authorizedSenderValidator)
slashingViolationsConsumer := unittest.NetworkSlashingViolationsConsumer(logger)
sub1, err := ln1.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), slashingViolationsConsumer, authorizedSenderValidator)
require.NoError(t, err)
sub2, err := ln2.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), authorizedSenderValidator)
sub2, err := ln2.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), slashingViolationsConsumer, authorizedSenderValidator)
require.NoError(t, err)
sub3, err := ln3.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), authorizedSenderValidator)
sub3, err := ln3.Subscribe(topic, unittest.NetworkCodec(), unittest.AllowAllPeerFilter(), slashingViolationsConsumer, authorizedSenderValidator)
require.NoError(t, err)

// assert that the nodes are connected as expected
Expand Down
38 changes: 33 additions & 5 deletions network/validator/pubsub/topic_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ import (
pubsub "github.com/libp2p/go-libp2p-pubsub"
"github.com/rs/zerolog"

"github.com/onflow/flow-go/network/channels"
"github.com/onflow/flow-go/network/codec"
"github.com/onflow/flow-go/network/slashing"

"github.com/onflow/flow-go/network"
"github.com/onflow/flow-go/network/message"
"github.com/onflow/flow-go/network/validator"
Expand Down Expand Up @@ -63,7 +67,7 @@ type TopicValidatorData struct {

// TopicValidator is the topic validator that is registered with libP2P whenever a flow libP2P node subscribes to a topic.
// The TopicValidator will decode and perform validation on the raw pubsub message.
func TopicValidator(log zerolog.Logger, codec network.Codec, peerFilter func(peer.ID) error, validators ...validator.PubSubMessageValidator) pubsub.ValidatorEx {
func TopicValidator(log zerolog.Logger, c network.Codec, slashingViolationsConsumer slashing.ViolationsConsumer, peerFilter func(peer.ID) error, validators ...validator.PubSubMessageValidator) pubsub.ValidatorEx {
log = log.With().
Str("component", "libp2p_node_topic_validator").
Logger()
Expand Down Expand Up @@ -93,10 +97,24 @@ func TopicValidator(log zerolog.Logger, codec network.Codec, peerFilter func(pee
}

// Convert message payload to a known message type
decodedMsgPayload, err := codec.Decode(msg.Payload)
if err != nil {
log.Warn().
Err(fmt.Errorf("could not decode message: %w", err)).
decodedMsgPayload, err := c.Decode(msg.Payload)
switch {
case err == nil:
break
case codec.IsErrUnknownMsgCode(err):
// slash peer if message contains unknown message code byte
slashingViolationsConsumer.OnUnknownMsgTypeError(violation(from, msg, err))
return pubsub.ValidationReject
case codec.IsErrMsgUnmarshal(err):
// slash if peer sent a message that could not be marshalled into the message type denoted by the message code byte
slashingViolationsConsumer.OnInvalidMsgError(violation(from, msg, err))
return pubsub.ValidationReject
default:
// unexpected error condition. this indicates there's a bug
// don't crash as a result of external inputs since that creates a DoS vector.
log.
Error().
Err(fmt.Errorf("unexpected error while decoding message: %w", err)).
Str("peer_id", from.String()).
Hex("sender", msg.OriginID).
Msg("rejecting message")
Expand All @@ -122,3 +140,13 @@ func TopicValidator(log zerolog.Logger, codec network.Codec, peerFilter func(pee
return result
}
}

func violation(pid peer.ID, msg message.Message, err error) *slashing.Violation {
return &slashing.Violation{
PeerID: pid.String(),
MsgType: msg.Type,
Channel: channels.Channel(msg.ChannelID),
IsUnicast: false,
Err: err,
}
}
8 changes: 8 additions & 0 deletions utils/unittest/unittest.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@ import (
"time"

"github.com/dgraph-io/badger/v2"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/onflow/flow-go/network/slashing"

"github.com/onflow/flow-go/model/flow"
"github.com/onflow/flow-go/module"
"github.com/onflow/flow-go/module/util"
Expand Down Expand Up @@ -427,3 +430,8 @@ func CrashTestWithExpectedStatus(
outStr := string(outBytes)
require.Contains(t, outStr, expectedErrorMsg)
}

// NetworkSlashingViolationsConsumer returns a slashing violations consumer for network middleware
func NetworkSlashingViolationsConsumer(logger zerolog.Logger) slashing.ViolationsConsumer {
return slashing.NewSlashingViolationsConsumer(logger)
}

0 comments on commit ac1c5c7

Please sign in to comment.