diff --git a/Makefile b/Makefile index 14625ddf649..09a2b1b9456 100644 --- a/Makefile +++ b/Makefile @@ -167,7 +167,7 @@ generate-mocks: install-mock-generators rm -rf ./fvm/environment/mock mockery --name '.*' --dir=fvm/environment --case=underscore --output="./fvm/environment/mock" --outpkg="mock" mockery --name '.*' --dir=ledger --case=underscore --output="./ledger/mock" --outpkg="mock" - mockery --name 'ViolationsConsumer' --dir=network/slashing --case=underscore --output="./network/mocknetwork" --outpkg="mocknetwork" + mockery --name 'ViolationsConsumer' --dir=network --case=underscore --output="./network/mocknetwork" --outpkg="mocknetwork" mockery --name '.*' --dir=network/p2p/ --case=underscore --output="./network/p2p/mock" --outpkg="mockp2p" mockery --name '.*' --dir=network/alsp --case=underscore --output="./network/alsp/mock" --outpkg="mockalsp" mockery --name 'Vertex' --dir="./module/forest" --case=underscore --output="./module/forest/mock" --outpkg="mock" diff --git a/cmd/access/node_builder/access_node_builder.go b/cmd/access/node_builder/access_node_builder.go index bf7a52047b4..390cad33d43 100644 --- a/cmd/access/node_builder/access_node_builder.go +++ b/cmd/access/node_builder/access_node_builder.go @@ -77,7 +77,6 @@ import ( "github.com/onflow/flow-go/network/p2p/translator" "github.com/onflow/flow-go/network/p2p/unicast/protocols" relaynet "github.com/onflow/flow-go/network/relay" - "github.com/onflow/flow-go/network/slashing" "github.com/onflow/flow-go/network/topology" "github.com/onflow/flow-go/network/validator" "github.com/onflow/flow-go/state/protocol" @@ -1261,15 +1260,14 @@ func (builder *FlowAccessNodeBuilder) initMiddleware(nodeID flow.Identifier, ) network.Middleware { logger := builder.Logger.With().Bool("staked", false).Logger() mw := middleware.NewMiddleware(&middleware.Config{ - Logger: logger, - Libp2pNode: libp2pNode, - FlowId: nodeID, - BitSwapMetrics: builder.Metrics.Bitswap, - RootBlockID: builder.SporkID, - UnicastMessageTimeout: middleware.DefaultUnicastTimeout, - IdTranslator: builder.IDTranslator, - Codec: builder.CodecFactory(), - SlashingViolationsConsumer: slashing.NewSlashingViolationsConsumer(logger, networkMetrics), + Logger: logger, + Libp2pNode: libp2pNode, + FlowId: nodeID, + BitSwapMetrics: builder.Metrics.Bitswap, + RootBlockID: builder.SporkID, + UnicastMessageTimeout: middleware.DefaultUnicastTimeout, + IdTranslator: builder.IDTranslator, + Codec: builder.CodecFactory(), }, middleware.WithMessageValidators(validators...), // use default identifier provider ) diff --git a/cmd/observer/node_builder/observer_builder.go b/cmd/observer/node_builder/observer_builder.go index be518249714..6f825421278 100644 --- a/cmd/observer/node_builder/observer_builder.go +++ b/cmd/observer/node_builder/observer_builder.go @@ -62,7 +62,6 @@ import ( "github.com/onflow/flow-go/network/p2p/translator" "github.com/onflow/flow-go/network/p2p/unicast/protocols" "github.com/onflow/flow-go/network/p2p/utils" - "github.com/onflow/flow-go/network/slashing" "github.com/onflow/flow-go/network/validator" stateprotocol "github.com/onflow/flow-go/state/protocol" badgerState "github.com/onflow/flow-go/state/protocol/badger" @@ -905,17 +904,15 @@ func (builder *ObserverServiceBuilder) initMiddleware(nodeID flow.Identifier, libp2pNode p2p.LibP2PNode, validators ...network.MessageValidator, ) network.Middleware { - slashingViolationsConsumer := slashing.NewSlashingViolationsConsumer(builder.Logger, builder.Metrics.Network) mw := middleware.NewMiddleware(&middleware.Config{ - Logger: builder.Logger, - Libp2pNode: libp2pNode, - FlowId: nodeID, - BitSwapMetrics: builder.Metrics.Bitswap, - RootBlockID: builder.SporkID, - UnicastMessageTimeout: middleware.DefaultUnicastTimeout, - IdTranslator: builder.IDTranslator, - Codec: builder.CodecFactory(), - SlashingViolationsConsumer: slashingViolationsConsumer, + Logger: builder.Logger, + Libp2pNode: libp2pNode, + FlowId: nodeID, + BitSwapMetrics: builder.Metrics.Bitswap, + RootBlockID: builder.SporkID, + UnicastMessageTimeout: middleware.DefaultUnicastTimeout, + IdTranslator: builder.IDTranslator, + Codec: builder.CodecFactory(), }, middleware.WithMessageValidators(validators...), // use default identifier provider ) diff --git a/cmd/scaffold.go b/cmd/scaffold.go index 32114adb25e..cc332aac604 100644 --- a/cmd/scaffold.go +++ b/cmd/scaffold.go @@ -60,7 +60,6 @@ import ( "github.com/onflow/flow-go/network/p2p/unicast/protocols" "github.com/onflow/flow-go/network/p2p/unicast/ratelimit" "github.com/onflow/flow-go/network/p2p/utils/ratelimiter" - "github.com/onflow/flow-go/network/slashing" "github.com/onflow/flow-go/network/topology" "github.com/onflow/flow-go/state/protocol" badgerState "github.com/onflow/flow-go/state/protocol/badger" @@ -437,17 +436,15 @@ func (fnb *FlowNodeBuilder) InitFlowNetworkWithConduitFactory( mwOpts = append(mwOpts, middleware.WithPeerManagerFilters(peerManagerFilters)) } - slashingViolationsConsumer := slashing.NewSlashingViolationsConsumer(fnb.Logger, fnb.Metrics.Network) mw := middleware.NewMiddleware(&middleware.Config{ - Logger: fnb.Logger, - Libp2pNode: fnb.LibP2PNode, - FlowId: fnb.Me.NodeID(), - BitSwapMetrics: fnb.Metrics.Bitswap, - RootBlockID: fnb.SporkID, - UnicastMessageTimeout: fnb.BaseConfig.FlowConfig.NetworkConfig.UnicastMessageTimeout, - IdTranslator: fnb.IDTranslator, - Codec: fnb.CodecFactory(), - SlashingViolationsConsumer: slashingViolationsConsumer, + Logger: fnb.Logger, + Libp2pNode: fnb.LibP2PNode, + FlowId: fnb.Me.NodeID(), + BitSwapMetrics: fnb.Metrics.Bitswap, + RootBlockID: fnb.SporkID, + UnicastMessageTimeout: fnb.FlowConfig.NetworkConfig.UnicastMessageTimeout, + IdTranslator: fnb.IDTranslator, + Codec: fnb.CodecFactory(), }, mwOpts...) diff --git a/follower/follower_builder.go b/follower/follower_builder.go index 36486907d1c..b8f3eaa5bc4 100644 --- a/follower/follower_builder.go +++ b/follower/follower_builder.go @@ -56,7 +56,6 @@ import ( "github.com/onflow/flow-go/network/p2p/translator" "github.com/onflow/flow-go/network/p2p/unicast/protocols" "github.com/onflow/flow-go/network/p2p/utils" - "github.com/onflow/flow-go/network/slashing" "github.com/onflow/flow-go/network/validator" "github.com/onflow/flow-go/state/protocol" badgerState "github.com/onflow/flow-go/state/protocol/badger" @@ -748,15 +747,14 @@ func (builder *FollowerServiceBuilder) initMiddleware(nodeID flow.Identifier, validators ...network.MessageValidator, ) network.Middleware { mw := middleware.NewMiddleware(&middleware.Config{ - Logger: builder.Logger, - Libp2pNode: libp2pNode, - FlowId: nodeID, - BitSwapMetrics: builder.Metrics.Bitswap, - RootBlockID: builder.SporkID, - UnicastMessageTimeout: middleware.DefaultUnicastTimeout, - IdTranslator: builder.IDTranslator, - Codec: builder.CodecFactory(), - SlashingViolationsConsumer: slashing.NewSlashingViolationsConsumer(builder.Logger, builder.Metrics.Network), + Logger: builder.Logger, + Libp2pNode: libp2pNode, + FlowId: nodeID, + BitSwapMetrics: builder.Metrics.Bitswap, + RootBlockID: builder.SporkID, + UnicastMessageTimeout: middleware.DefaultUnicastTimeout, + IdTranslator: builder.IDTranslator, + Codec: builder.CodecFactory(), }, middleware.WithMessageValidators(validators...), ) diff --git a/module/metrics.go b/module/metrics.go index 338f87c1ecc..2b889b98c44 100644 --- a/module/metrics.go +++ b/module/metrics.go @@ -42,6 +42,10 @@ type NetworkSecurityMetrics interface { // OnRateLimitedPeer tracks the number of rate limited unicast messages seen on the network. OnRateLimitedPeer(pid peer.ID, role, msgType, topic, reason string) + + // OnViolationReportSkipped tracks the number of slashing violations consumer violations that were not + // reported for misbehavior when the identity of the sender not known. + OnViolationReportSkipped() } // GossipSubRouterMetrics encapsulates the metrics collectors for GossipSubRouter module of the networking layer. @@ -182,6 +186,8 @@ type NetworkInboundQueueMetrics interface { type NetworkCoreMetrics interface { NetworkInboundQueueMetrics AlspMetrics + NetworkSecurityMetrics + // OutboundMessageSent collects metrics related to a message sent by the node. OutboundMessageSent(sizeBytes int, topic string, protocol string, messageType string) // InboundMessageReceived collects metrics related to a message received by the node. @@ -223,7 +229,6 @@ type AlspMetrics interface { // NetworkMetrics is the blanket abstraction that encapsulates the metrics collectors for the networking layer. type NetworkMetrics interface { LibP2PMetrics - NetworkSecurityMetrics NetworkCoreMetrics } diff --git a/module/metrics/namespaces.go b/module/metrics/namespaces.go index 31995538992..6fd0f2db82f 100644 --- a/module/metrics/namespaces.go +++ b/module/metrics/namespaces.go @@ -28,6 +28,7 @@ const ( subsystemAuth = "authorization" subsystemRateLimiting = "ratelimit" subsystemAlsp = "alsp" + subsystemSecurity = "security" ) // Storage subsystems represent the various components of the storage layer. diff --git a/module/metrics/network.go b/module/metrics/network.go index f064ca10f6e..311dbba9f15 100644 --- a/module/metrics/network.go +++ b/module/metrics/network.go @@ -45,9 +45,10 @@ type NetworkCollector struct { dnsLookupRequestDroppedCount prometheus.Counter routingTableSize prometheus.Gauge - // authorization, rate limiting metrics + // security metrics unAuthorizedMessagesCount *prometheus.CounterVec rateLimitedUnicastMessagesCount *prometheus.CounterVec + violationReportSkippedCount prometheus.Counter prefix string } @@ -245,6 +246,15 @@ func NewNetworkCollector(logger zerolog.Logger, opts ...NetworkCollectorOpt) *Ne }, []string{LabelNodeRole, LabelMessage, LabelChannel, LabelRateLimitReason}, ) + nc.violationReportSkippedCount = promauto.NewCounter( + prometheus.CounterOpts{ + Namespace: namespaceNetwork, + Subsystem: subsystemSecurity, + Name: nc.prefix + "slashing_violation_reports_skipped_count", + Help: "number of slashing violations consumer violations that were not reported for misbehavior because the identity of the sender not known", + }, + ) + return nc } @@ -358,3 +368,9 @@ func (nc *NetworkCollector) OnRateLimitedPeer(peerID peer.ID, role, msgType, top Msg("unicast peer rate limited") nc.rateLimitedUnicastMessagesCount.WithLabelValues(role, msgType, topic, reason).Inc() } + +// OnViolationReportSkipped tracks the number of slashing violations consumer violations that were not +// reported for misbehavior when the identity of the sender not known. +func (nc *NetworkCollector) OnViolationReportSkipped() { + nc.violationReportSkippedCount.Inc() +} diff --git a/module/metrics/noop.go b/module/metrics/noop.go index 2dd33d133cc..226af7f1f26 100644 --- a/module/metrics/noop.go +++ b/module/metrics/noop.go @@ -303,3 +303,4 @@ func (nc *NoopCollector) AsyncProcessingStarted(string) func (nc *NoopCollector) AsyncProcessingFinished(string, time.Duration) {} func (nc *NoopCollector) OnMisbehaviorReported(string, string) {} +func (nc *NoopCollector) OnViolationReportSkipped() {} diff --git a/module/mock/network_core_metrics.go b/module/mock/network_core_metrics.go index 63c849fbf27..d78c3355449 100644 --- a/module/mock/network_core_metrics.go +++ b/module/mock/network_core_metrics.go @@ -5,6 +5,8 @@ package mock import ( mock "github.com/stretchr/testify/mock" + peer "github.com/libp2p/go-libp2p/core/peer" + time "time" ) @@ -48,6 +50,21 @@ func (_m *NetworkCoreMetrics) OnMisbehaviorReported(channel string, misbehaviorT _m.Called(channel, misbehaviorType) } +// OnRateLimitedPeer provides a mock function with given fields: pid, role, msgType, topic, reason +func (_m *NetworkCoreMetrics) OnRateLimitedPeer(pid peer.ID, role string, msgType string, topic string, reason string) { + _m.Called(pid, role, msgType, topic, reason) +} + +// OnUnauthorizedMessage provides a mock function with given fields: role, msgType, topic, offense +func (_m *NetworkCoreMetrics) OnUnauthorizedMessage(role string, msgType string, topic string, offense string) { + _m.Called(role, msgType, topic, offense) +} + +// OnViolationReportSkipped provides a mock function with given fields: +func (_m *NetworkCoreMetrics) OnViolationReportSkipped() { + _m.Called() +} + // OutboundMessageSent provides a mock function with given fields: sizeBytes, topic, protocol, messageType func (_m *NetworkCoreMetrics) OutboundMessageSent(sizeBytes int, topic string, protocol string, messageType string) { _m.Called(sizeBytes, topic, protocol, messageType) diff --git a/module/mock/network_metrics.go b/module/mock/network_metrics.go index 851565d5724..2909f7d677f 100644 --- a/module/mock/network_metrics.go +++ b/module/mock/network_metrics.go @@ -300,6 +300,11 @@ func (_m *NetworkMetrics) OnUnauthorizedMessage(role string, msgType string, top _m.Called(role, msgType, topic, offense) } +// OnViolationReportSkipped provides a mock function with given fields: +func (_m *NetworkMetrics) OnViolationReportSkipped() { + _m.Called() +} + // OutboundConnections provides a mock function with given fields: connectionCount func (_m *NetworkMetrics) OutboundConnections(connectionCount uint) { _m.Called(connectionCount) diff --git a/module/mock/network_security_metrics.go b/module/mock/network_security_metrics.go index 51d045c2a12..a48a693c0ab 100644 --- a/module/mock/network_security_metrics.go +++ b/module/mock/network_security_metrics.go @@ -23,6 +23,11 @@ func (_m *NetworkSecurityMetrics) OnUnauthorizedMessage(role string, msgType str _m.Called(role, msgType, topic, offense) } +// OnViolationReportSkipped provides a mock function with given fields: +func (_m *NetworkSecurityMetrics) OnViolationReportSkipped() { + _m.Called() +} + type mockConstructorTestingTNewNetworkSecurityMetrics interface { mock.TestingT Cleanup(func()) diff --git a/network/alsp/manager/manager_test.go b/network/alsp/manager/manager_test.go index 03f012bb206..9b067e1ede8 100644 --- a/network/alsp/manager/manager_test.go +++ b/network/alsp/manager/manager_test.go @@ -54,7 +54,7 @@ func TestNetworkPassesReportedMisbehavior(t *testing.T) { misbehaviorReportManger.On("Ready").Return(readyDoneChan).Once() misbehaviorReportManger.On("Done").Return(readyDoneChan).Once() ids, nodes, _ := testutils.LibP2PNodeForMiddlewareFixture(t, 1) - mws, _ := testutils.MiddlewareFixtures(t, ids, nodes, testutils.MiddlewareConfigFixture(t)) + mws, _ := testutils.MiddlewareFixtures(t, ids, nodes, testutils.MiddlewareConfigFixture(t), mocknetwork.NewViolationsConsumer(t)) networkCfg := testutils.NetworkConfigFixture(t, *ids[0], ids, mws[0]) net, err := p2p.NewNetwork(networkCfg, p2p.WithAlspManager(misbehaviorReportManger)) @@ -111,7 +111,7 @@ func TestHandleReportedMisbehavior_Cache_Integration(t *testing.T) { }), } ids, nodes, _ := testutils.LibP2PNodeForMiddlewareFixture(t, 1) - mws, _ := testutils.MiddlewareFixtures(t, ids, nodes, testutils.MiddlewareConfigFixture(t)) + mws, _ := testutils.MiddlewareFixtures(t, ids, nodes, testutils.MiddlewareConfigFixture(t), mocknetwork.NewViolationsConsumer(t)) networkCfg := testutils.NetworkConfigFixture(t, *ids[0], ids, mws[0], p2p.WithAlspConfig(cfg)) net, err := p2p.NewNetwork(networkCfg) require.NoError(t, err) @@ -206,7 +206,7 @@ func TestHandleReportedMisbehavior_And_DisallowListing_Integration(t *testing.T) ids, nodes, _ := testutils.LibP2PNodeForMiddlewareFixture(t, 3, p2ptest.WithPeerManagerEnabled(p2ptest.PeerManagerConfigFixture(), nil)) - mws, _ := testutils.MiddlewareFixtures(t, ids, nodes, testutils.MiddlewareConfigFixture(t)) + mws, _ := testutils.MiddlewareFixtures(t, ids, nodes, testutils.MiddlewareConfigFixture(t), mocknetwork.NewViolationsConsumer(t)) networkCfg := testutils.NetworkConfigFixture(t, *ids[0], ids, mws[0], p2p.WithAlspConfig(cfg)) victimNetwork, err := p2p.NewNetwork(networkCfg) require.NoError(t, err) @@ -279,7 +279,7 @@ func TestMisbehaviorReportMetrics(t *testing.T) { cfg.AlspMetrics = alspMetrics ids, nodes, _ := testutils.LibP2PNodeForMiddlewareFixture(t, 1) - mws, _ := testutils.MiddlewareFixtures(t, ids, nodes, testutils.MiddlewareConfigFixture(t)) + mws, _ := testutils.MiddlewareFixtures(t, ids, nodes, testutils.MiddlewareConfigFixture(t), mocknetwork.NewViolationsConsumer(t)) networkCfg := testutils.NetworkConfigFixture(t, *ids[0], ids, mws[0], p2p.WithAlspConfig(cfg)) net, err := p2p.NewNetwork(networkCfg) require.NoError(t, err) diff --git a/network/alsp/misbehavior.go b/network/alsp/misbehavior.go index 326b113cd8b..af4921cd06a 100644 --- a/network/alsp/misbehavior.go +++ b/network/alsp/misbehavior.go @@ -24,6 +24,25 @@ const ( // the message is not valid according to the engine's validation logic. The decision to consider a message invalid // is up to the engine. InvalidMessage network.Misbehavior = "misbehavior-invalid-message" + + // UnExpectedValidationError is a misbehavior that is reported when a validation error is encountered during message validation before the message + // is processed by an engine. + UnExpectedValidationError network.Misbehavior = "unexpected-validation-error" + + // UnknownMsgType is a misbehavior that is reported when a message of unknown type is received from a peer. + UnknownMsgType network.Misbehavior = "unknown-message-type" + + // SenderEjected is a misbehavior that is reported when a message is received from an ejected peer. + SenderEjected network.Misbehavior = "sender-ejected" + + // UnauthorizedUnicastOnChannel is a misbehavior that is reported when a message not authorized to be sent via unicast is received via unicast. + UnauthorizedUnicastOnChannel network.Misbehavior = "unauthorized-unicast-on-channel" + + // UnAuthorizedSender is a misbehavior that is reported when a message is sent by an unauthorized role. + UnAuthorizedSender network.Misbehavior = "unauthorized-sender" + + // UnauthorizedPublishOnChannel is a misbehavior that is reported when a message not authorized to be sent via pubsub is received via pubsub. + UnauthorizedPublishOnChannel network.Misbehavior = "unauthorized-pubsub-on-channel" ) func AllMisbehaviorTypes() []network.Misbehavior { @@ -33,5 +52,11 @@ func AllMisbehaviorTypes() []network.Misbehavior { RedundantMessage, UnsolicitedMessage, InvalidMessage, + UnExpectedValidationError, + UnknownMsgType, + SenderEjected, + UnauthorizedUnicastOnChannel, + UnauthorizedPublishOnChannel, + UnAuthorizedSender, } } diff --git a/network/internal/testutils/testUtil.go b/network/internal/testutils/testUtil.go index 2457c7c6af7..95707ee9e3c 100644 --- a/network/internal/testutils/testUtil.go +++ b/network/internal/testutils/testUtil.go @@ -28,7 +28,6 @@ import ( netcache "github.com/onflow/flow-go/network/cache" "github.com/onflow/flow-go/network/channels" "github.com/onflow/flow-go/network/codec/cbor" - "github.com/onflow/flow-go/network/mocknetwork" "github.com/onflow/flow-go/network/netconf" "github.com/onflow/flow-go/network/p2p" "github.com/onflow/flow-go/network/p2p/conduit" @@ -167,12 +166,11 @@ func LibP2PNodeForMiddlewareFixture(t *testing.T, n int, opts ...p2ptest.NodeFix // - a middleware config. func MiddlewareConfigFixture(t *testing.T) *middleware.Config { return &middleware.Config{ - Logger: unittest.Logger(), - BitSwapMetrics: metrics.NewNoopCollector(), - RootBlockID: sporkID, - UnicastMessageTimeout: middleware.DefaultUnicastTimeout, - Codec: unittest.NetworkCodec(), - SlashingViolationsConsumer: mocknetwork.NewViolationsConsumer(t), + Logger: unittest.Logger(), + BitSwapMetrics: metrics.NewNoopCollector(), + RootBlockID: sporkID, + UnicastMessageTimeout: middleware.DefaultUnicastTimeout, + Codec: unittest.NetworkCodec(), } } @@ -187,7 +185,7 @@ func MiddlewareConfigFixture(t *testing.T) *middleware.Config { // Returns: // - a list of middlewares - one for each identity. // - a list of UpdatableIDProvider - one for each identity. -func MiddlewareFixtures(t *testing.T, identities flow.IdentityList, libP2PNodes []p2p.LibP2PNode, cfg *middleware.Config, opts ...middleware.OptionFn) ([]network.Middleware, []*unittest.UpdatableIDProvider) { +func MiddlewareFixtures(t *testing.T, identities flow.IdentityList, libP2PNodes []p2p.LibP2PNode, cfg *middleware.Config, consumer network.ViolationsConsumer, opts ...middleware.OptionFn) ([]network.Middleware, []*unittest.UpdatableIDProvider) { require.Equal(t, len(identities), len(libP2PNodes)) mws := make([]network.Middleware, len(identities)) @@ -199,8 +197,8 @@ func MiddlewareFixtures(t *testing.T, identities flow.IdentityList, libP2PNodes cfg.FlowId = identities[i].NodeID idProviders[i] = unittest.NewUpdatableIDProvider(identities) cfg.IdTranslator = translator.NewIdentityProviderIDTranslator(idProviders[i]) - mws[i] = middleware.NewMiddleware(cfg, opts...) + mws[i].SetSlashingViolationsConsumer(consumer) } return mws, idProviders } diff --git a/network/middleware.go b/network/middleware.go index c2eeef98905..d8e14ee82c1 100644 --- a/network/middleware.go +++ b/network/middleware.go @@ -22,6 +22,9 @@ type Middleware interface { // SetOverlay sets the overlay used by the middleware. This must be called before the middleware can be Started. SetOverlay(Overlay) + // SetSlashingViolationsConsumer sets the slashing violations consumer. + SetSlashingViolationsConsumer(ViolationsConsumer) + // SendDirect sends msg on a 1-1 direct connection to the target ID. It models a guaranteed delivery asynchronous // direct one-to-one connection on the underlying network. No intermediate node on the overlay is utilized // as the router. diff --git a/network/mocknetwork/adapter.go b/network/mocknetwork/adapter.go index 364ec1027ce..2700f6eb0cc 100644 --- a/network/mocknetwork/adapter.go +++ b/network/mocknetwork/adapter.go @@ -58,9 +58,9 @@ func (_m *Adapter) PublishOnChannel(_a0 channels.Channel, _a1 interface{}, _a2 . return r0 } -// ReportMisbehaviorOnChannel provides a mock function with given fields: _a0, _a1 -func (_m *Adapter) ReportMisbehaviorOnChannel(_a0 channels.Channel, _a1 network.MisbehaviorReport) { - _m.Called(_a0, _a1) +// ReportMisbehaviorOnChannel provides a mock function with given fields: channel, report +func (_m *Adapter) ReportMisbehaviorOnChannel(channel channels.Channel, report network.MisbehaviorReport) { + _m.Called(channel, report) } // UnRegisterChannel provides a mock function with given fields: channel diff --git a/network/mocknetwork/middleware.go b/network/mocknetwork/middleware.go index 64167ce9ed8..18cdaed21b0 100644 --- a/network/mocknetwork/middleware.go +++ b/network/mocknetwork/middleware.go @@ -160,6 +160,11 @@ func (_m *Middleware) SetOverlay(_a0 network.Overlay) { _m.Called(_a0) } +// SetSlashingViolationsConsumer provides a mock function with given fields: _a0 +func (_m *Middleware) SetSlashingViolationsConsumer(_a0 network.ViolationsConsumer) { + _m.Called(_a0) +} + // Start provides a mock function with given fields: _a0 func (_m *Middleware) Start(_a0 irrecoverable.SignalerContext) { _m.Called(_a0) diff --git a/network/mocknetwork/misbehavior_report_consumer.go b/network/mocknetwork/misbehavior_report_consumer.go new file mode 100644 index 00000000000..8731a6ae8fe --- /dev/null +++ b/network/mocknetwork/misbehavior_report_consumer.go @@ -0,0 +1,35 @@ +// Code generated by mockery v2.21.4. DO NOT EDIT. + +package mocknetwork + +import ( + channels "github.com/onflow/flow-go/network/channels" + mock "github.com/stretchr/testify/mock" + + network "github.com/onflow/flow-go/network" +) + +// MisbehaviorReportConsumer is an autogenerated mock type for the MisbehaviorReportConsumer type +type MisbehaviorReportConsumer struct { + mock.Mock +} + +// ReportMisbehaviorOnChannel provides a mock function with given fields: channel, report +func (_m *MisbehaviorReportConsumer) ReportMisbehaviorOnChannel(channel channels.Channel, report network.MisbehaviorReport) { + _m.Called(channel, report) +} + +type mockConstructorTestingTNewMisbehaviorReportConsumer interface { + mock.TestingT + Cleanup(func()) +} + +// NewMisbehaviorReportConsumer creates a new instance of MisbehaviorReportConsumer. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewMisbehaviorReportConsumer(t mockConstructorTestingTNewMisbehaviorReportConsumer) *MisbehaviorReportConsumer { + mock := &MisbehaviorReportConsumer{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/network/mocknetwork/violations_consumer.go b/network/mocknetwork/violations_consumer.go index 9c6f252b095..2af1bf2b80f 100644 --- a/network/mocknetwork/violations_consumer.go +++ b/network/mocknetwork/violations_consumer.go @@ -3,7 +3,7 @@ package mocknetwork import ( - slashing "github.com/onflow/flow-go/network/slashing" + network "github.com/onflow/flow-go/network" mock "github.com/stretchr/testify/mock" ) @@ -13,32 +13,37 @@ type ViolationsConsumer struct { } // OnInvalidMsgError provides a mock function with given fields: violation -func (_m *ViolationsConsumer) OnInvalidMsgError(violation *slashing.Violation) { +func (_m *ViolationsConsumer) OnInvalidMsgError(violation *network.Violation) { _m.Called(violation) } // OnSenderEjectedError provides a mock function with given fields: violation -func (_m *ViolationsConsumer) OnSenderEjectedError(violation *slashing.Violation) { +func (_m *ViolationsConsumer) OnSenderEjectedError(violation *network.Violation) { _m.Called(violation) } // OnUnAuthorizedSenderError provides a mock function with given fields: violation -func (_m *ViolationsConsumer) OnUnAuthorizedSenderError(violation *slashing.Violation) { +func (_m *ViolationsConsumer) OnUnAuthorizedSenderError(violation *network.Violation) { + _m.Called(violation) +} + +// OnUnauthorizedPublishOnChannel provides a mock function with given fields: violation +func (_m *ViolationsConsumer) OnUnauthorizedPublishOnChannel(violation *network.Violation) { _m.Called(violation) } // OnUnauthorizedUnicastOnChannel provides a mock function with given fields: violation -func (_m *ViolationsConsumer) OnUnauthorizedUnicastOnChannel(violation *slashing.Violation) { +func (_m *ViolationsConsumer) OnUnauthorizedUnicastOnChannel(violation *network.Violation) { _m.Called(violation) } // OnUnexpectedError provides a mock function with given fields: violation -func (_m *ViolationsConsumer) OnUnexpectedError(violation *slashing.Violation) { +func (_m *ViolationsConsumer) OnUnexpectedError(violation *network.Violation) { _m.Called(violation) } // OnUnknownMsgTypeError provides a mock function with given fields: violation -func (_m *ViolationsConsumer) OnUnknownMsgTypeError(violation *slashing.Violation) { +func (_m *ViolationsConsumer) OnUnknownMsgTypeError(violation *network.Violation) { _m.Called(violation) } diff --git a/network/network.go b/network/network.go index 703c5e627c8..4f77892b666 100644 --- a/network/network.go +++ b/network/network.go @@ -47,6 +47,7 @@ type Network interface { // Adapter is meant to be utilized by the Conduit interface to send messages to the Network layer to be // delivered to the remote targets. type Adapter interface { + MisbehaviorReportConsumer // UnicastOnChannel sends the message in a reliable way to the given recipient. UnicastOnChannel(channels.Channel, interface{}, flow.Identifier) error @@ -60,7 +61,10 @@ type Adapter interface { // UnRegisterChannel unregisters the engine for the specified channel. The engine will no longer be able to send or // receive messages from that channel. UnRegisterChannel(channel channels.Channel) error +} +// MisbehaviorReportConsumer set of funcs used to handle MisbehaviorReport disseminated from misbehavior reporters. +type MisbehaviorReportConsumer interface { // ReportMisbehaviorOnChannel reports the misbehavior of a node on sending a message to the current node that appears // valid based on the networking layer but is considered invalid by the current node based on the Flow protocol. // The misbehavior report is sent to the current node's networking layer on the given channel to be processed. @@ -69,5 +73,5 @@ type Adapter interface { // - report: The misbehavior report to be sent. // Returns: // none - ReportMisbehaviorOnChannel(channels.Channel, MisbehaviorReport) + ReportMisbehaviorOnChannel(channel channels.Channel, report MisbehaviorReport) } diff --git a/network/p2p/middleware/middleware.go b/network/p2p/middleware/middleware.go index 169737eabc5..c908e8d7f18 100644 --- a/network/p2p/middleware/middleware.go +++ b/network/p2p/middleware/middleware.go @@ -35,7 +35,6 @@ import ( "github.com/onflow/flow-go/network/p2p/unicast/protocols" "github.com/onflow/flow-go/network/p2p/unicast/ratelimit" "github.com/onflow/flow-go/network/p2p/utils" - "github.com/onflow/flow-go/network/slashing" "github.com/onflow/flow-go/network/validator" flowpubsub "github.com/onflow/flow-go/network/validator/pubsub" _ "github.com/onflow/flow-go/utils/binstat" @@ -62,13 +61,6 @@ const ( // LargeMsgUnicastTimeout is the maximum time to wait for a unicast request to complete for large message size LargeMsgUnicastTimeout = 1000 * time.Second - - // DisallowListCacheSize is the maximum number of peers that can be disallow-listed at a time. The recommended - // size is 100 * number of staked nodes. Note that when the cache is full, there is no eviction policy and - // disallow-listing a new peer will fail. Hence, the cache size should be set to a value that is large enough - // to accommodate all the peers that can be disallow-listed at a time. Also, note that this cache is only taking - // the staked (authorized) peers. Hence, Sybil attacks are not possible. - DisallowListCacheSize = 100 * 1000 ) var ( @@ -104,7 +96,7 @@ type Middleware struct { idTranslator p2p.IDTranslator previousProtocolStatePeers []peer.AddrInfo codec network.Codec - slashingViolationsConsumer slashing.ViolationsConsumer + slashingViolationsConsumer network.ViolationsConsumer unicastRateLimiters *ratelimit.RateLimiters authorizedSenderValidator *validator.AuthorizedSenderValidator } @@ -140,15 +132,14 @@ func WithUnicastRateLimiters(rateLimiters *ratelimit.RateLimiters) OptionFn { // Config is the configuration for the middleware. type Config struct { - Logger zerolog.Logger - Libp2pNode p2p.LibP2PNode - FlowId flow.Identifier // This node's Flow ID - BitSwapMetrics module.BitswapMetrics - RootBlockID flow.Identifier - UnicastMessageTimeout time.Duration - IdTranslator p2p.IDTranslator - Codec network.Codec - SlashingViolationsConsumer slashing.ViolationsConsumer + Logger zerolog.Logger + Libp2pNode p2p.LibP2PNode + FlowId flow.Identifier // This node's Flow ID + BitSwapMetrics module.BitswapMetrics + RootBlockID flow.Identifier + UnicastMessageTimeout time.Duration + IdTranslator p2p.IDTranslator + Codec network.Codec } // Validate validates the configuration, and sets default values for any missing fields. @@ -172,16 +163,15 @@ func NewMiddleware(cfg *Config, opts ...OptionFn) *Middleware { // create the node entity and inject dependencies & config mw := &Middleware{ - log: cfg.Logger, - libP2PNode: cfg.Libp2pNode, - bitswapMetrics: cfg.BitSwapMetrics, - rootBlockID: cfg.RootBlockID, - validators: DefaultValidators(cfg.Logger, cfg.FlowId), - unicastMessageTimeout: cfg.UnicastMessageTimeout, - idTranslator: cfg.IdTranslator, - codec: cfg.Codec, - slashingViolationsConsumer: cfg.SlashingViolationsConsumer, - unicastRateLimiters: ratelimit.NoopRateLimiters(), + log: cfg.Logger, + libP2PNode: cfg.Libp2pNode, + bitswapMetrics: cfg.BitSwapMetrics, + rootBlockID: cfg.RootBlockID, + validators: DefaultValidators(cfg.Logger, cfg.FlowId), + unicastMessageTimeout: cfg.UnicastMessageTimeout, + idTranslator: cfg.IdTranslator, + codec: cfg.Codec, + unicastRateLimiters: ratelimit.NoopRateLimiters(), } for _, opt := range opts { @@ -304,6 +294,11 @@ func (m *Middleware) SetOverlay(ov network.Overlay) { m.ov = ov } +// SetSlashingViolationsConsumer sets the slashing violations consumer. +func (m *Middleware) SetSlashingViolationsConsumer(consumer network.ViolationsConsumer) { + m.slashingViolationsConsumer = consumer +} + // authorizedPeers is a peer manager callback used by the underlying libp2p node that updates who can connect to this node (as // well as who this node can connect to). // and who is not allowed to connect to this node. This function is called by the peer manager and connection gater components @@ -518,7 +513,7 @@ func (m *Middleware) handleIncomingStream(s libp2pnetwork.Stream) { // ignore messages if node does not have subscription to topic if !m.libP2PNode.HasSubscription(topic) { - violation := &slashing.Violation{ + violation := &network.Violation{ Identity: nil, PeerID: remotePeer.String(), Channel: channel, Protocol: message.ProtocolTypeUnicast, } @@ -651,7 +646,7 @@ func (m *Middleware) processUnicastStreamMessage(remotePeer peer.ID, msg *messag // we can remove this check maxSize, err := UnicastMaxMsgSizeByCode(msg.Payload) if err != nil { - m.slashingViolationsConsumer.OnUnknownMsgTypeError(&slashing.Violation{ + m.slashingViolationsConsumer.OnUnknownMsgTypeError(&network.Violation{ Identity: nil, PeerID: remotePeer.String(), MsgType: "", Channel: channel, Protocol: message.ProtocolTypeUnicast, Err: err, }) return @@ -705,14 +700,14 @@ func (m *Middleware) processAuthenticatedMessage(msg *message.Message, peerID pe switch { case codec.IsErrUnknownMsgCode(err): // slash peer if message contains unknown message code byte - violation := &slashing.Violation{ + violation := &network.Violation{ PeerID: peerID.String(), OriginID: originId, Channel: channel, Protocol: protocol, Err: err, } m.slashingViolationsConsumer.OnUnknownMsgTypeError(violation) return case codec.IsErrMsgUnmarshal(err) || codec.IsErrInvalidEncoding(err): // slash if peer sent a message that could not be marshalled into the message type denoted by the message code byte - violation := &slashing.Violation{ + violation := &network.Violation{ PeerID: peerID.String(), OriginID: originId, Channel: channel, Protocol: protocol, Err: err, } m.slashingViolationsConsumer.OnInvalidMsgError(violation) @@ -722,7 +717,7 @@ func (m *Middleware) processAuthenticatedMessage(msg *message.Message, peerID pe // don't crash as a result of external inputs since that creates a DoS vector // collect slashing data because this could potentially lead to slashing err = fmt.Errorf("unexpected error during message validation: %w", err) - violation := &slashing.Violation{ + violation := &network.Violation{ PeerID: peerID.String(), OriginID: originId, Channel: channel, Protocol: protocol, Err: err, } m.slashingViolationsConsumer.OnUnexpectedError(violation) @@ -744,7 +739,6 @@ func (m *Middleware) processAuthenticatedMessage(msg *message.Message, peerID pe // processMessage processes a message and eventually passes it to the overlay func (m *Middleware) processMessage(scope *network.IncomingMessageScope) { - logger := m.log.With(). Str("channel", scope.Channel().String()). Str("type", scope.Protocol().String()). diff --git a/network/p2p/network.go b/network/p2p/network.go index 384fad3ab59..5b9d0b4df45 100644 --- a/network/p2p/network.go +++ b/network/p2p/network.go @@ -22,6 +22,7 @@ import ( "github.com/onflow/flow-go/network/channels" "github.com/onflow/flow-go/network/message" "github.com/onflow/flow-go/network/queue" + "github.com/onflow/flow-go/network/slashing" _ "github.com/onflow/flow-go/utils/binstat" "github.com/onflow/flow-go/utils/logging" ) @@ -53,6 +54,7 @@ type Network struct { registerEngineRequests chan *registerEngineRequest registerBlobServiceRequests chan *registerBlobServiceRequest misbehaviorReportManager network.MisbehaviorReportManager + slashingViolationsConsumer network.ViolationsConsumer } var _ network.Network = &Network{} @@ -171,6 +173,8 @@ func NewNetwork(param *NetworkConfig, opts ...NetworkOption) (*Network, error) { opt(n) } + n.slashingViolationsConsumer = slashing.NewSlashingViolationsConsumer(param.Logger, param.Metrics, n) + n.mw.SetSlashingViolationsConsumer(n.slashingViolationsConsumer) n.mw.SetOverlay(n) if err := n.conduitFactory.RegisterAdapter(n); err != nil { diff --git a/network/p2p/test/topic_validator_test.go b/network/p2p/test/topic_validator_test.go index b6f0dfe7ba5..5a7e402b141 100644 --- a/network/p2p/test/topic_validator_test.go +++ b/network/p2p/test/topic_validator_test.go @@ -10,14 +10,19 @@ import ( "github.com/libp2p/go-libp2p/core/peer" "github.com/stretchr/testify/require" + "github.com/stretchr/testify/mock" + "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/model/messages" "github.com/onflow/flow-go/module/irrecoverable" "github.com/onflow/flow-go/module/metrics" mockmodule "github.com/onflow/flow-go/module/mock" + "github.com/onflow/flow-go/network" + "github.com/onflow/flow-go/network/alsp" "github.com/onflow/flow-go/network/channels" "github.com/onflow/flow-go/network/internal/p2pfixtures" "github.com/onflow/flow-go/network/message" + "github.com/onflow/flow-go/network/mocknetwork" "github.com/onflow/flow-go/network/p2p" p2ptest "github.com/onflow/flow-go/network/p2p/test" "github.com/onflow/flow-go/network/p2p/translator" @@ -51,12 +56,12 @@ func TestTopicValidator_Unstaked(t *testing.T) { //NOTE: identity2 is not in the ids list simulating an un-staked node ids := flow.IdentityList{&identity1} - translator, err := translator.NewFixedTableIdentityTranslator(ids) + translatorFixture, err := translator.NewFixedTableIdentityTranslator(ids) require.NoError(t, err) // peer filter used by the topic validator to check if node is staked isStaked := func(pid peer.ID) error { - fid, err := translator.GetFlowID(pid) + fid, err := translatorFixture.GetFlowID(pid) if err != nil { return fmt.Errorf("could not translate the peer_id %s to a Flow identifier: %w", pid.String(), err) } @@ -272,8 +277,7 @@ func TestAuthorizedSenderValidator_Unauthorized(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) signalerCtx := irrecoverable.NewMockSignalerContext(t, ctx) idProvider := mockmodule.NewIdentityProvider(t) - // create a hooked logger - logger, hook := unittest.HookedLogger() + logger := unittest.Logger() sporkId := unittest.IdentifierFixture() @@ -292,12 +296,22 @@ func TestAuthorizedSenderValidator_Unauthorized(t *testing.T) { ids := flow.IdentityList{&identity1, &identity2, &identity3} - translator, err := translator.NewFixedTableIdentityTranslator(ids) + translatorFixture, err := translator.NewFixedTableIdentityTranslator(ids) require.NoError(t, err) - violationsConsumer := slashing.NewSlashingViolationsConsumer(logger, metrics.NewNoopCollector()) + violation := &network.Violation{ + Identity: &identity3, + PeerID: an1.Host().ID().String(), + OriginID: identity3.NodeID, + MsgType: "*messages.BlockProposal", + Channel: channel, + Protocol: message.ProtocolTypePubSub, + Err: message.ErrUnauthorizedRole, + } + violationsConsumer := mocknetwork.NewViolationsConsumer(t) + violationsConsumer.On("OnUnAuthorizedSenderError", violation).Once().Return(nil) getIdentity := func(pid peer.ID) (*flow.Identity, bool) { - fid, err := translator.GetFlowID(pid) + fid, err := translatorFixture.GetFlowID(pid) if err != nil { return &flow.Identity{}, false } @@ -373,9 +387,6 @@ func TestAuthorizedSenderValidator_Unauthorized(t *testing.T) { p2pfixtures.SubMustNeverReceiveAnyMessage(t, timedCtx, sub2) unittest.RequireReturnsBefore(t, wg.Wait, 5*time.Second, "could not receive message on time") - - // ensure the correct error is contained in the logged error - require.Contains(t, hook.Logs(), message.ErrUnauthorizedRole.Error()) } // TestAuthorizedSenderValidator_Authorized tests that the authorized sender validator rejects messages being sent on the wrong channel @@ -401,12 +412,16 @@ func TestAuthorizedSenderValidator_InvalidMsg(t *testing.T) { topic := channels.TopicFromChannel(channel, sporkId) ids := flow.IdentityList{&identity1, &identity2} - translator, err := translator.NewFixedTableIdentityTranslator(ids) + translatorFixture, err := translator.NewFixedTableIdentityTranslator(ids) require.NoError(t, err) - violationsConsumer := slashing.NewSlashingViolationsConsumer(logger, metrics.NewNoopCollector()) + expectedMisbehaviorReport, err := alsp.NewMisbehaviorReport(identity2.NodeID, alsp.UnAuthorizedSender) + require.NoError(t, err) + misbehaviorReportConsumer := mocknetwork.NewMisbehaviorReportConsumer(t) + misbehaviorReportConsumer.On("ReportMisbehaviorOnChannel", channel, expectedMisbehaviorReport).Once() + violationsConsumer := slashing.NewSlashingViolationsConsumer(logger, metrics.NewNoopCollector(), misbehaviorReportConsumer) getIdentity := func(pid peer.ID) (*flow.Identity, bool) { - fid, err := translator.GetFlowID(pid) + fid, err := translatorFixture.GetFlowID(pid) if err != nil { return &flow.Identity{}, false } @@ -474,12 +489,16 @@ func TestAuthorizedSenderValidator_Ejected(t *testing.T) { topic := channels.TopicFromChannel(channel, sporkId) ids := flow.IdentityList{&identity1, &identity2, &identity3} - translator, err := translator.NewFixedTableIdentityTranslator(ids) + translatorFixture, err := translator.NewFixedTableIdentityTranslator(ids) require.NoError(t, err) - violationsConsumer := slashing.NewSlashingViolationsConsumer(logger, metrics.NewNoopCollector()) + expectedMisbehaviorReport, err := alsp.NewMisbehaviorReport(identity2.NodeID, alsp.SenderEjected) + require.NoError(t, err) + misbehaviorReportConsumer := mocknetwork.NewMisbehaviorReportConsumer(t) + misbehaviorReportConsumer.On("ReportMisbehaviorOnChannel", channel, expectedMisbehaviorReport).Once() + violationsConsumer := slashing.NewSlashingViolationsConsumer(logger, metrics.NewNoopCollector(), misbehaviorReportConsumer) getIdentity := func(pid peer.ID) (*flow.Identity, bool) { - fid, err := translator.GetFlowID(pid) + fid, err := translatorFixture.GetFlowID(pid) if err != nil { return &flow.Identity{}, false } @@ -568,13 +587,15 @@ func TestAuthorizedSenderValidator_ClusterChannel(t *testing.T) { topic := channels.TopicFromChannel(channel, sporkId) ids := flow.IdentityList{&identity1, &identity2, &identity3} - translator, err := translator.NewFixedTableIdentityTranslator(ids) + translatorFixture, err := translator.NewFixedTableIdentityTranslator(ids) require.NoError(t, err) logger := unittest.Logger() - violationsConsumer := slashing.NewSlashingViolationsConsumer(logger, metrics.NewNoopCollector()) + misbehaviorReportConsumer := mocknetwork.NewMisbehaviorReportConsumer(t) + defer misbehaviorReportConsumer.AssertNotCalled(t, "ReportMisbehaviorOnChannel", mock.AnythingOfType("channels.Channel"), mock.AnythingOfType("*alsp.MisbehaviorReport")) + violationsConsumer := slashing.NewSlashingViolationsConsumer(logger, metrics.NewNoopCollector(), misbehaviorReportConsumer) getIdentity := func(pid peer.ID) (*flow.Identity, bool) { - fid, err := translator.GetFlowID(pid) + fid, err := translatorFixture.GetFlowID(pid) if err != nil { return &flow.Identity{}, false } diff --git a/network/slashing/consumer.go b/network/slashing/consumer.go index aaac28fccc5..b7f09bb12b7 100644 --- a/network/slashing/consumer.go +++ b/network/slashing/consumer.go @@ -7,35 +7,34 @@ import ( "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/module" + "github.com/onflow/flow-go/network" + "github.com/onflow/flow-go/network/alsp" "github.com/onflow/flow-go/utils/logging" ) const ( - unknown = "unknown" - unExpectedValidationError = "unexpected_validation_error" - unAuthorizedSenderViolation = "unauthorized_sender" - unknownMsgTypeViolation = "unknown_message_type" - invalidMsgViolation = "invalid_message" - senderEjectedViolation = "sender_ejected" - unauthorizedUnicastOnChannel = "unauthorized_unicast_on_channel" + unknown = "unknown" ) // Consumer is a struct that logs a message for any slashable offenses. // This struct will be updated in the future when slashing is implemented. type Consumer struct { - log zerolog.Logger - metrics module.NetworkSecurityMetrics + log zerolog.Logger + metrics module.NetworkSecurityMetrics + misbehaviorReportConsumer network.MisbehaviorReportConsumer } // NewSlashingViolationsConsumer returns a new Consumer. -func NewSlashingViolationsConsumer(log zerolog.Logger, metrics module.NetworkSecurityMetrics) *Consumer { +func NewSlashingViolationsConsumer(log zerolog.Logger, metrics module.NetworkSecurityMetrics, misbehaviorReportConsumer network.MisbehaviorReportConsumer) *Consumer { return &Consumer{ - log: log.With().Str("module", "network_slashing_consumer").Logger(), - metrics: metrics, + log: log.With().Str("module", "network_slashing_consumer").Logger(), + metrics: metrics, + misbehaviorReportConsumer: misbehaviorReportConsumer, } } -func (c *Consumer) logOffense(networkOffense string, violation *Violation) { +// logOffense logs the slashing violation with details. +func (c *Consumer) logOffense(misbehavior network.Misbehavior, violation *network.Violation) { // if violation fails before the message is decoded the violation.MsgType will be unknown if len(violation.MsgType) == 0 { violation.MsgType = unknown @@ -51,7 +50,7 @@ func (c *Consumer) logOffense(networkOffense string, violation *Violation) { e := c.log.Error(). Str("peer_id", violation.PeerID). - Str("networking_offense", networkOffense). + Str("misbehavior", misbehavior.String()). Str("message_type", violation.MsgType). Str("channel", violation.Channel.String()). Str("protocol", violation.Protocol.String()). @@ -62,37 +61,78 @@ func (c *Consumer) logOffense(networkOffense string, violation *Violation) { e.Msg(fmt.Sprintf("potential slashable offense: %s", violation.Err)) // capture unauthorized message count metric - c.metrics.OnUnauthorizedMessage(role, violation.MsgType, violation.Channel.String(), networkOffense) + c.metrics.OnUnauthorizedMessage(role, violation.MsgType, violation.Channel.String(), misbehavior.String()) } -// OnUnAuthorizedSenderError logs an error for unauthorized sender error. -func (c *Consumer) OnUnAuthorizedSenderError(violation *Violation) { - c.logOffense(unAuthorizedSenderViolation, violation) +// reportMisbehavior reports the slashing violation to the alsp misbehavior report manager. When violation identity +// is nil this indicates the misbehavior occurred either on a public network and the identity of the sender is unknown +// we can skip reporting the misbehavior. +// Args: +// - misbehavior: the network misbehavior. +// - violation: the slashing violation. +// Any error encountered while creating the misbehavior report is considered irrecoverable and will result in a fatal log. +func (c *Consumer) reportMisbehavior(misbehavior network.Misbehavior, violation *network.Violation) { + if violation.Identity == nil { + c.log.Debug(). + Bool(logging.KeySuspicious, true). + Str("peerID", violation.PeerID). + Msg("violation identity unknown (or public) skipping misbehavior reporting") + c.metrics.OnViolationReportSkipped() + return + } + report, err := alsp.NewMisbehaviorReport(violation.Identity.NodeID, misbehavior) + if err != nil { + // failing to create the misbehavior report is unlikely. If an error is encountered while + // creating the misbehavior report it indicates a bug and processing can not proceed. + c.log.Fatal(). + Err(err). + Str("peerID", violation.PeerID). + Msg("failed to create misbehavior report") + + } + c.misbehaviorReportConsumer.ReportMisbehaviorOnChannel(violation.Channel, report) } -// OnUnknownMsgTypeError logs an error for unknown message type error. -func (c *Consumer) OnUnknownMsgTypeError(violation *Violation) { - c.logOffense(unknownMsgTypeViolation, violation) +// OnUnAuthorizedSenderError logs an error for unauthorized sender error and reports a misbehavior to alsp misbehavior report manager. +func (c *Consumer) OnUnAuthorizedSenderError(violation *network.Violation) { + c.logOffense(alsp.UnAuthorizedSender, violation) + c.reportMisbehavior(alsp.UnAuthorizedSender, violation) +} + +// OnUnknownMsgTypeError logs an error for unknown message type error and reports a misbehavior to alsp misbehavior report manager. +func (c *Consumer) OnUnknownMsgTypeError(violation *network.Violation) { + c.logOffense(alsp.UnknownMsgType, violation) + c.reportMisbehavior(alsp.UnknownMsgType, violation) } // OnInvalidMsgError logs an error for messages that contained payloads that could not -// be unmarshalled into the message type denoted by message code byte. -func (c *Consumer) OnInvalidMsgError(violation *Violation) { - c.logOffense(invalidMsgViolation, violation) +// be unmarshalled into the message type denoted by message code byte and reports a misbehavior to alsp misbehavior report manager. +func (c *Consumer) OnInvalidMsgError(violation *network.Violation) { + c.logOffense(alsp.InvalidMessage, violation) + c.reportMisbehavior(alsp.InvalidMessage, violation) +} + +// OnSenderEjectedError logs an error for sender ejected error and reports a misbehavior to alsp misbehavior report manager. +func (c *Consumer) OnSenderEjectedError(violation *network.Violation) { + c.logOffense(alsp.SenderEjected, violation) + c.reportMisbehavior(alsp.SenderEjected, violation) } -// OnSenderEjectedError logs an error for sender ejected error. -func (c *Consumer) OnSenderEjectedError(violation *Violation) { - c.logOffense(senderEjectedViolation, violation) +// OnUnauthorizedUnicastOnChannel logs an error for messages unauthorized to be sent via unicast and reports a misbehavior to alsp misbehavior report manager. +func (c *Consumer) OnUnauthorizedUnicastOnChannel(violation *network.Violation) { + c.logOffense(alsp.UnauthorizedUnicastOnChannel, violation) + c.reportMisbehavior(alsp.UnauthorizedUnicastOnChannel, violation) } -// OnUnauthorizedUnicastOnChannel logs an error for messages unauthorized to be sent via unicast. -func (c *Consumer) OnUnauthorizedUnicastOnChannel(violation *Violation) { - c.logOffense(unauthorizedUnicastOnChannel, violation) +// OnUnauthorizedPublishOnChannel logs an error for messages unauthorized to be sent via pubsub. +func (c *Consumer) OnUnauthorizedPublishOnChannel(violation *network.Violation) { + c.logOffense(alsp.UnauthorizedPublishOnChannel, violation) + c.reportMisbehavior(alsp.UnauthorizedPublishOnChannel, violation) } // OnUnexpectedError logs an error for unexpected errors. This indicates message validation -// has failed for an unknown reason and could potentially be n slashable offense. -func (c *Consumer) OnUnexpectedError(violation *Violation) { - c.logOffense(unExpectedValidationError, violation) +// has failed for an unknown reason and could potentially be n slashable offense and reports a misbehavior to alsp misbehavior report manager. +func (c *Consumer) OnUnexpectedError(violation *network.Violation) { + c.logOffense(alsp.UnExpectedValidationError, violation) + c.reportMisbehavior(alsp.UnExpectedValidationError, violation) } diff --git a/network/stub/network.go b/network/stub/network.go index fc93cf9b588..9e91b386922 100644 --- a/network/stub/network.go +++ b/network/stub/network.go @@ -306,6 +306,6 @@ func (n *Network) StopConDev() { close(n.qCD) } -func (n *Network) ReportMisbehaviorOnChannel(channel channels.Channel, report network.MisbehaviorReport) { +func (n *Network) ReportMisbehaviorOnChannel(_ channels.Channel, _ network.MisbehaviorReport) { // no-op for stub network. } diff --git a/network/test/blob_service_test.go b/network/test/blob_service_test.go index 40c052111d7..c0979244ad8 100644 --- a/network/test/blob_service_test.go +++ b/network/test/blob_service_test.go @@ -13,6 +13,7 @@ import ( "github.com/stretchr/testify/suite" "go.uber.org/atomic" + "github.com/onflow/flow-go/network/mocknetwork" "github.com/onflow/flow-go/network/p2p/connection" "github.com/onflow/flow-go/network/p2p/dht" p2pconfig "github.com/onflow/flow-go/network/p2p/p2pbuilder/config" @@ -89,7 +90,7 @@ func (suite *BlobServiceTestSuite) SetupTest() { ConnectionPruning: true, ConnectorFactory: connection.DefaultLibp2pBackoffConnectorFactory(), }, nil)) - mws, _ := testutils.MiddlewareFixtures(suite.T(), ids, nodes, testutils.MiddlewareConfigFixture(suite.T())) + mws, _ := testutils.MiddlewareFixtures(suite.T(), ids, nodes, testutils.MiddlewareConfigFixture(suite.T()), mocknetwork.NewViolationsConsumer(suite.T())) suite.networks = testutils.NetworksFixture(suite.T(), ids, mws) testutils.StartNodesAndNetworks(signalerCtx, suite.T(), nodes, suite.networks, 100*time.Millisecond) diff --git a/network/test/echoengine_test.go b/network/test/echoengine_test.go index eb170cbf266..55732b64d17 100644 --- a/network/test/echoengine_test.go +++ b/network/test/echoengine_test.go @@ -19,6 +19,7 @@ import ( "github.com/onflow/flow-go/network" "github.com/onflow/flow-go/network/channels" "github.com/onflow/flow-go/network/internal/testutils" + "github.com/onflow/flow-go/network/mocknetwork" "github.com/onflow/flow-go/network/p2p" "github.com/onflow/flow-go/utils/unittest" ) @@ -54,7 +55,7 @@ func (suite *EchoEngineTestSuite) SetupTest() { // both nodes should be of the same role to get connected on epidemic dissemination var nodes []p2p.LibP2PNode suite.ids, nodes, _ = testutils.LibP2PNodeForMiddlewareFixture(suite.T(), count) - suite.mws, _ = testutils.MiddlewareFixtures(suite.T(), suite.ids, nodes, testutils.MiddlewareConfigFixture(suite.T())) + suite.mws, _ = testutils.MiddlewareFixtures(suite.T(), suite.ids, nodes, testutils.MiddlewareConfigFixture(suite.T()), mocknetwork.NewViolationsConsumer(suite.T())) suite.nets = testutils.NetworksFixture(suite.T(), suite.ids, suite.mws) testutils.StartNodesAndNetworks(signalerCtx, suite.T(), nodes, suite.nets, 100*time.Millisecond) } diff --git a/network/test/epochtransition_test.go b/network/test/epochtransition_test.go index e471b1d8f48..4e1eeabf717 100644 --- a/network/test/epochtransition_test.go +++ b/network/test/epochtransition_test.go @@ -24,6 +24,7 @@ import ( "github.com/onflow/flow-go/module/irrecoverable" "github.com/onflow/flow-go/network" "github.com/onflow/flow-go/network/internal/testutils" + "github.com/onflow/flow-go/network/mocknetwork" mockprotocol "github.com/onflow/flow-go/state/protocol/mock" "github.com/onflow/flow-go/utils/unittest" ) @@ -181,7 +182,7 @@ func (suite *MutableIdentityTableSuite) addNodes(count int) { // create the ids, middlewares and networks ids, nodes, _ := testutils.LibP2PNodeForMiddlewareFixture(suite.T(), count) - mws, _ := testutils.MiddlewareFixtures(suite.T(), ids, nodes, testutils.MiddlewareConfigFixture(suite.T())) + mws, _ := testutils.MiddlewareFixtures(suite.T(), ids, nodes, testutils.MiddlewareConfigFixture(suite.T()), mocknetwork.NewViolationsConsumer(suite.T())) nets := testutils.NetworksFixture(suite.T(), ids, mws) suite.cancels = append(suite.cancels, cancel) diff --git a/network/test/meshengine_test.go b/network/test/meshengine_test.go index 612d7679796..55a95994d45 100644 --- a/network/test/meshengine_test.go +++ b/network/test/meshengine_test.go @@ -11,8 +11,6 @@ import ( "testing" "time" - "github.com/onflow/flow-go/network/p2p" - "github.com/ipfs/go-log" pubsub "github.com/libp2p/go-libp2p-pubsub" "github.com/rs/zerolog" @@ -20,9 +18,6 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" - "github.com/onflow/flow-go/network/p2p/middleware" - "github.com/onflow/flow-go/network/p2p/p2pnode" - "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/model/flow/filter" "github.com/onflow/flow-go/model/libp2p/message" @@ -31,6 +26,10 @@ import ( "github.com/onflow/flow-go/network" "github.com/onflow/flow-go/network/channels" "github.com/onflow/flow-go/network/internal/testutils" + "github.com/onflow/flow-go/network/mocknetwork" + "github.com/onflow/flow-go/network/p2p" + "github.com/onflow/flow-go/network/p2p/middleware" + "github.com/onflow/flow-go/network/p2p/p2pnode" "github.com/onflow/flow-go/utils/unittest" ) @@ -74,7 +73,7 @@ func (suite *MeshEngineTestSuite) SetupTest() { var nodes []p2p.LibP2PNode suite.ids, nodes, obs = testutils.LibP2PNodeForMiddlewareFixture(suite.T(), count) - suite.mws, _ = testutils.MiddlewareFixtures(suite.T(), suite.ids, nodes, testutils.MiddlewareConfigFixture(suite.T())) + suite.mws, _ = testutils.MiddlewareFixtures(suite.T(), suite.ids, nodes, testutils.MiddlewareConfigFixture(suite.T()), mocknetwork.NewViolationsConsumer(suite.T())) suite.nets = testutils.NetworksFixture(suite.T(), suite.ids, suite.mws) testutils.StartNodesAndNetworks(signalerCtx, suite.T(), nodes, suite.nets, 100*time.Millisecond) diff --git a/network/test/middleware_test.go b/network/test/middleware_test.go index 8c0c1adb4f0..1b42df088aa 100644 --- a/network/test/middleware_test.go +++ b/network/test/middleware_test.go @@ -81,8 +81,9 @@ type MiddlewareTestSuite struct { logger zerolog.Logger providers []*unittest.UpdatableIDProvider - mwCancel context.CancelFunc - mwCtx irrecoverable.SignalerContext + mwCancel context.CancelFunc + mwCtx irrecoverable.SignalerContext + slashingViolationsConsumer network.ViolationsConsumer } // TestMiddlewareTestSuit runs all the test methods in this test suit @@ -106,8 +107,9 @@ func (m *MiddlewareTestSuite) SetupTest() { log: m.logger, } + m.slashingViolationsConsumer = mocknetwork.NewViolationsConsumer(m.T()) m.ids, m.nodes, obs = testutils.LibP2PNodeForMiddlewareFixture(m.T(), m.size) - m.mws, m.providers = testutils.MiddlewareFixtures(m.T(), m.ids, m.nodes, testutils.MiddlewareConfigFixture(m.T())) + m.mws, m.providers = testutils.MiddlewareFixtures(m.T(), m.ids, m.nodes, testutils.MiddlewareConfigFixture(m.T()), m.slashingViolationsConsumer) for _, observableConnMgr := range obs { observableConnMgr.Subscribe(&ob) } @@ -158,7 +160,7 @@ func (m *MiddlewareTestSuite) TestUpdateNodeAddresses() { // create a new staked identity ids, libP2PNodes, _ := testutils.LibP2PNodeForMiddlewareFixture(m.T(), 1) - mws, providers := testutils.MiddlewareFixtures(m.T(), ids, libP2PNodes, testutils.MiddlewareConfigFixture(m.T())) + mws, providers := testutils.MiddlewareFixtures(m.T(), ids, libP2PNodes, testutils.MiddlewareConfigFixture(m.T()), m.slashingViolationsConsumer) require.Len(m.T(), ids, 1) require.Len(m.T(), providers, 1) require.Len(m.T(), mws, 1) @@ -256,6 +258,7 @@ func (m *MiddlewareTestSuite) TestUnicastRateLimit_Messages() { ids, libP2PNodes, testutils.MiddlewareConfigFixture(m.T()), + m.slashingViolationsConsumer, middleware.WithUnicastRateLimiters(rateLimiters), middleware.WithPeerManagerFilters([]p2p.PeerFilter{testutils.IsRateLimitedPeerFilter(messageRateLimiter)})) @@ -409,6 +412,7 @@ func (m *MiddlewareTestSuite) TestUnicastRateLimit_Bandwidth() { ids, libP2PNodes, testutils.MiddlewareConfigFixture(m.T()), + m.slashingViolationsConsumer, middleware.WithUnicastRateLimiters(rateLimiters), middleware.WithPeerManagerFilters([]p2p.PeerFilter{testutils.IsRateLimitedPeerFilter(bandwidthRateLimiter)})) require.Len(m.T(), ids, 1) diff --git a/network/test/unicast_authorization_test.go b/network/test/unicast_authorization_test.go index f4a4171944d..197c5f4a5a2 100644 --- a/network/test/unicast_authorization_test.go +++ b/network/test/unicast_authorization_test.go @@ -24,7 +24,6 @@ import ( "github.com/onflow/flow-go/network/mocknetwork" "github.com/onflow/flow-go/network/p2p" "github.com/onflow/flow-go/network/p2p/middleware" - "github.com/onflow/flow-go/network/slashing" "github.com/onflow/flow-go/network/validator" "github.com/onflow/flow-go/utils/unittest" ) @@ -73,11 +72,10 @@ func (u *UnicastAuthorizationTestSuite) TearDownTest() { } // setupMiddlewaresAndProviders will setup 2 middlewares that will be used as a sender and receiver in each suite test. -func (u *UnicastAuthorizationTestSuite) setupMiddlewaresAndProviders(slashingViolationsConsumer slashing.ViolationsConsumer) { +func (u *UnicastAuthorizationTestSuite) setupMiddlewaresAndProviders(slashingViolationsConsumer network.ViolationsConsumer) { ids, libP2PNodes, _ := testutils.LibP2PNodeForMiddlewareFixture(u.T(), 2) cfg := testutils.MiddlewareConfigFixture(u.T()) - cfg.SlashingViolationsConsumer = slashingViolationsConsumer - mws, providers := testutils.MiddlewareFixtures(u.T(), ids, libP2PNodes, cfg) + mws, providers := testutils.MiddlewareFixtures(u.T(), ids, libP2PNodes, cfg, slashingViolationsConsumer) require.Len(u.T(), ids, 2) require.Len(u.T(), providers, 2) require.Len(u.T(), mws, 2) @@ -125,7 +123,7 @@ func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_UnstakedPeer() require.NoError(u.T(), err) var nilID *flow.Identity - expectedViolation := &slashing.Violation{ + expectedViolation := &network.Violation{ Identity: nilID, // because the peer will be unverified this identity will be nil PeerID: expectedSenderPeerID.String(), MsgType: "", // message will not be decoded before OnSenderEjectedError is logged, we won't log message type @@ -136,7 +134,7 @@ func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_UnstakedPeer() slashingViolationsConsumer.On( "OnUnAuthorizedSenderError", expectedViolation, - ).Once().Run(func(args mockery.Arguments) { + ).Return(nil).Once().Run(func(args mockery.Arguments) { close(u.waitCh) }) @@ -187,8 +185,9 @@ func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_EjectedPeer() { expectedSenderPeerID, err := unittest.PeerIDFromFlowID(u.senderID) require.NoError(u.T(), err) - expectedViolation := &slashing.Violation{ + expectedViolation := &network.Violation{ Identity: u.senderID, // we expect this method to be called with the ejected identity + OriginID: u.senderID.NodeID, PeerID: expectedSenderPeerID.String(), MsgType: "", // message will not be decoded before OnSenderEjectedError is logged, we won't log message type Channel: channels.TestNetworkChannel, // message will not be decoded before OnSenderEjectedError is logged, we won't log peer ID @@ -198,7 +197,7 @@ func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_EjectedPeer() { slashingViolationsConsumer.On( "OnSenderEjectedError", expectedViolation, - ).Once().Run(func(args mockery.Arguments) { + ).Return(nil).Once().Run(func(args mockery.Arguments) { close(u.waitCh) }) @@ -246,8 +245,9 @@ func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_UnauthorizedPee expectedSenderPeerID, err := unittest.PeerIDFromFlowID(u.senderID) require.NoError(u.T(), err) - expectedViolation := &slashing.Violation{ + expectedViolation := &network.Violation{ Identity: u.senderID, + OriginID: u.senderID.NodeID, PeerID: expectedSenderPeerID.String(), MsgType: "*message.TestMessage", Channel: channels.ConsensusCommittee, @@ -258,7 +258,7 @@ func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_UnauthorizedPee slashingViolationsConsumer.On( "OnUnAuthorizedSenderError", expectedViolation, - ).Once().Run(func(args mockery.Arguments) { + ).Return(nil).Once().Run(func(args mockery.Arguments) { close(u.waitCh) }) @@ -309,7 +309,7 @@ func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_UnknownMsgCode( invalidMessageCode := codec.MessageCode(byte('X')) var nilID *flow.Identity - expectedViolation := &slashing.Violation{ + expectedViolation := &network.Violation{ Identity: nilID, PeerID: expectedSenderPeerID.String(), MsgType: "", @@ -321,7 +321,7 @@ func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_UnknownMsgCode( slashingViolationsConsumer.On( "OnUnknownMsgTypeError", expectedViolation, - ).Once().Run(func(args mockery.Arguments) { + ).Return(nil).Once().Run(func(args mockery.Arguments) { close(u.waitCh) }) @@ -378,8 +378,9 @@ func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_WrongMsgCode() modifiedMessageCode := codec.CodeDKGMessage - expectedViolation := &slashing.Violation{ + expectedViolation := &network.Violation{ Identity: u.senderID, + OriginID: u.senderID.NodeID, PeerID: expectedSenderPeerID.String(), MsgType: "*messages.DKGMessage", Channel: channels.TestNetworkChannel, @@ -390,7 +391,7 @@ func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_WrongMsgCode() slashingViolationsConsumer.On( "OnUnAuthorizedSenderError", expectedViolation, - ).Once().Run(func(args mockery.Arguments) { + ).Return(nil).Once().Run(func(args mockery.Arguments) { close(u.waitCh) }) @@ -503,8 +504,9 @@ func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_UnauthorizedUni expectedSenderPeerID, err := unittest.PeerIDFromFlowID(u.senderID) require.NoError(u.T(), err) - expectedViolation := &slashing.Violation{ + expectedViolation := &network.Violation{ Identity: u.senderID, + OriginID: u.senderID.NodeID, PeerID: expectedSenderPeerID.String(), MsgType: "*messages.BlockProposal", Channel: channels.ConsensusCommittee, @@ -515,7 +517,7 @@ func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_UnauthorizedUni slashingViolationsConsumer.On( "OnUnauthorizedUnicastOnChannel", expectedViolation, - ).Once().Run(func(args mockery.Arguments) { + ).Return(nil).Return(nil).Once().Run(func(args mockery.Arguments) { close(u.waitCh) }) @@ -566,7 +568,7 @@ func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_ReceiverHasNoSu expectedSenderPeerID, err := unittest.PeerIDFromFlowID(u.senderID) require.NoError(u.T(), err) - expectedViolation := &slashing.Violation{ + expectedViolation := &network.Violation{ Identity: nil, PeerID: expectedSenderPeerID.String(), MsgType: "*message.TestMessage", @@ -578,7 +580,7 @@ func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_ReceiverHasNoSu slashingViolationsConsumer.On( "OnUnauthorizedUnicastOnChannel", expectedViolation, - ).Once().Run(func(args mockery.Arguments) { + ).Return(nil).Return(nil).Once().Run(func(args mockery.Arguments) { close(u.waitCh) }) diff --git a/network/validator/authorized_sender_validator.go b/network/validator/authorized_sender_validator.go index 0af21b45e39..6841d69a9e6 100644 --- a/network/validator/authorized_sender_validator.go +++ b/network/validator/authorized_sender_validator.go @@ -8,11 +8,11 @@ import ( "github.com/rs/zerolog" "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flow-go/network" "github.com/onflow/flow-go/network/channels" "github.com/onflow/flow-go/network/codec" "github.com/onflow/flow-go/network/message" "github.com/onflow/flow-go/network/p2p" - "github.com/onflow/flow-go/network/slashing" ) var ( @@ -25,12 +25,12 @@ type GetIdentityFunc func(peer.ID) (*flow.Identity, bool) // AuthorizedSenderValidator performs message authorization validation. type AuthorizedSenderValidator struct { log zerolog.Logger - slashingViolationsConsumer slashing.ViolationsConsumer + slashingViolationsConsumer network.ViolationsConsumer getIdentity GetIdentityFunc } // NewAuthorizedSenderValidator returns a new AuthorizedSenderValidator -func NewAuthorizedSenderValidator(log zerolog.Logger, slashingViolationsConsumer slashing.ViolationsConsumer, getIdentity GetIdentityFunc) *AuthorizedSenderValidator { +func NewAuthorizedSenderValidator(log zerolog.Logger, slashingViolationsConsumer network.ViolationsConsumer, getIdentity GetIdentityFunc) *AuthorizedSenderValidator { return &AuthorizedSenderValidator{ log: log.With().Str("component", "authorized_sender_validator").Logger(), slashingViolationsConsumer: slashingViolationsConsumer, @@ -61,14 +61,14 @@ func (av *AuthorizedSenderValidator) Validate(from peer.ID, payload []byte, chan // something terrible went wrong. identity, ok := av.getIdentity(from) if !ok { - violation := &slashing.Violation{Identity: identity, PeerID: from.String(), Channel: channel, Protocol: protocol, Err: ErrIdentityUnverified} + violation := &network.Violation{PeerID: from.String(), Channel: channel, Protocol: protocol, Err: ErrIdentityUnverified} av.slashingViolationsConsumer.OnUnAuthorizedSenderError(violation) return "", ErrIdentityUnverified } msgCode, err := codec.MessageCodeFromPayload(payload) if err != nil { - violation := &slashing.Violation{Identity: identity, PeerID: from.String(), Channel: channel, Protocol: protocol, Err: err} + violation := &network.Violation{OriginID: identity.NodeID, Identity: identity, PeerID: from.String(), Channel: channel, Protocol: protocol, Err: err} av.slashingViolationsConsumer.OnUnknownMsgTypeError(violation) return "", err } @@ -77,28 +77,32 @@ func (av *AuthorizedSenderValidator) Validate(from peer.ID, payload []byte, chan switch { case err == nil: return msgType, nil - case message.IsUnknownMsgTypeErr(err): - violation := &slashing.Violation{Identity: identity, PeerID: from.String(), MsgType: msgType, Channel: channel, Protocol: protocol, Err: err} + case message.IsUnknownMsgTypeErr(err) || codec.IsErrUnknownMsgCode(err): + violation := &network.Violation{OriginID: identity.NodeID, Identity: identity, PeerID: from.String(), MsgType: msgType, Channel: channel, Protocol: protocol, Err: err} av.slashingViolationsConsumer.OnUnknownMsgTypeError(violation) return msgType, err case errors.Is(err, message.ErrUnauthorizedMessageOnChannel) || errors.Is(err, message.ErrUnauthorizedRole): - violation := &slashing.Violation{Identity: identity, PeerID: from.String(), MsgType: msgType, Channel: channel, Protocol: protocol, Err: err} + violation := &network.Violation{OriginID: identity.NodeID, Identity: identity, PeerID: from.String(), MsgType: msgType, Channel: channel, Protocol: protocol, Err: err} av.slashingViolationsConsumer.OnUnAuthorizedSenderError(violation) return msgType, err case errors.Is(err, ErrSenderEjected): - violation := &slashing.Violation{Identity: identity, PeerID: from.String(), MsgType: msgType, Channel: channel, Protocol: protocol, Err: err} + violation := &network.Violation{OriginID: identity.NodeID, Identity: identity, PeerID: from.String(), MsgType: msgType, Channel: channel, Protocol: protocol, Err: err} av.slashingViolationsConsumer.OnSenderEjectedError(violation) return msgType, err case errors.Is(err, message.ErrUnauthorizedUnicastOnChannel): - violation := &slashing.Violation{Identity: identity, PeerID: from.String(), MsgType: msgType, Channel: channel, Protocol: protocol, Err: err} + violation := &network.Violation{OriginID: identity.NodeID, Identity: identity, PeerID: from.String(), MsgType: msgType, Channel: channel, Protocol: protocol, Err: err} av.slashingViolationsConsumer.OnUnauthorizedUnicastOnChannel(violation) return msgType, err + case errors.Is(err, message.ErrUnauthorizedPublishOnChannel): + violation := &network.Violation{OriginID: identity.NodeID, Identity: identity, PeerID: from.String(), MsgType: msgType, Channel: channel, Protocol: protocol, Err: err} + av.slashingViolationsConsumer.OnUnauthorizedPublishOnChannel(violation) + return msgType, err default: // this condition should never happen and indicates there's a bug // don't crash as a result of external inputs since that creates a DoS vector // collect slashing data because this could potentially lead to slashing err = fmt.Errorf("unexpected error during message validation: %w", err) - violation := &slashing.Violation{Identity: identity, PeerID: from.String(), MsgType: msgType, Channel: channel, Protocol: protocol, Err: err} + violation := &network.Violation{OriginID: identity.NodeID, Identity: identity, PeerID: from.String(), MsgType: msgType, Channel: channel, Protocol: protocol, Err: err} av.slashingViolationsConsumer.OnUnexpectedError(violation) return msgType, err } diff --git a/network/validator/authorized_sender_validator_test.go b/network/validator/authorized_sender_validator_test.go index 966ae5ba127..8a9cd138cbb 100644 --- a/network/validator/authorized_sender_validator_test.go +++ b/network/validator/authorized_sender_validator_test.go @@ -6,16 +6,20 @@ import ( "github.com/libp2p/go-libp2p/core/peer" "github.com/rs/zerolog" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/onflow/flow-go/model/flow" + libp2pmessage "github.com/onflow/flow-go/model/libp2p/message" "github.com/onflow/flow-go/model/messages" "github.com/onflow/flow-go/module/metrics" "github.com/onflow/flow-go/network" + "github.com/onflow/flow-go/network/alsp" "github.com/onflow/flow-go/network/channels" "github.com/onflow/flow-go/network/codec" "github.com/onflow/flow-go/network/message" + "github.com/onflow/flow-go/network/mocknetwork" "github.com/onflow/flow-go/network/p2p" "github.com/onflow/flow-go/network/slashing" "github.com/onflow/flow-go/utils/unittest" @@ -43,7 +47,7 @@ type TestAuthorizedSenderValidatorSuite struct { unauthorizedUnicastOnChannel []TestCase authorizedUnicastOnChannel []TestCase log zerolog.Logger - slashingViolationsConsumer slashing.ViolationsConsumer + slashingViolationsConsumer network.ViolationsConsumer allMsgConfigs []message.MsgAuthConfig codec network.Codec } @@ -54,7 +58,6 @@ func (s *TestAuthorizedSenderValidatorSuite) SetupTest() { s.initializeInvalidMessageOnChannelTestCases() s.initializeUnicastOnChannelTestCases() s.log = unittest.Logger() - s.slashingViolationsConsumer = slashing.NewSlashingViolationsConsumer(s.log, metrics.NewNoopCollector()) s.codec = unittest.NetworkCodec() } @@ -64,37 +67,64 @@ func (s *TestAuthorizedSenderValidatorSuite) TestValidatorCallback_AuthorizedSen for _, c := range s.authorizedSenderTestCases { str := fmt.Sprintf("role (%s) should be authorized to send message type (%s) on channel (%s)", c.Identity.Role, c.MessageStr, c.Channel) s.Run(str, func() { - authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, s.slashingViolationsConsumer, c.GetIdentity) - + misbehaviorReportConsumer := mocknetwork.NewMisbehaviorReportConsumer(s.T()) + defer misbehaviorReportConsumer.AssertNotCalled(s.T(), "ReportMisbehaviorOnChannel", mock.AnythingOfType("channels.Channel"), mock.AnythingOfType("*alsp.MisbehaviorReport")) + violationsConsumer := slashing.NewSlashingViolationsConsumer(s.log, metrics.NewNoopCollector(), misbehaviorReportConsumer) + authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, violationsConsumer, c.GetIdentity) + validateUnicast := authorizedSenderValidator.Validate + validatePubsub := authorizedSenderValidator.PubSubMessageValidator(c.Channel) pid, err := unittest.PeerIDFromFlowID(c.Identity) require.NoError(s.T(), err) - + switch { // ensure according to the message auth config, if a message is authorized to be sent via unicast it - // is accepted or rejected. - msgType, err := authorizedSenderValidator.Validate(pid, []byte{c.MessageCode.Uint8()}, c.Channel, message.ProtocolTypeUnicast) - if c.Protocols.Contains(message.ProtocolTypeUnicast) { + // is accepted. + case c.Protocols.Contains(message.ProtocolTypeUnicast): + msgType, err := validateUnicast(pid, []byte{c.MessageCode.Uint8()}, c.Channel, message.ProtocolTypeUnicast) + if c.Protocols.Contains(message.ProtocolTypeUnicast) { + require.NoError(s.T(), err) + require.Equal(s.T(), c.MessageStr, msgType) + } + // ensure according to the message auth config, if a message is authorized to be sent via pubsub it + // is accepted. + case c.Protocols.Contains(message.ProtocolTypePubSub): + payload, err := s.codec.Encode(c.Message) require.NoError(s.T(), err) - require.Equal(s.T(), c.MessageStr, msgType) - } else { - require.ErrorIs(s.T(), err, message.ErrUnauthorizedUnicastOnChannel) - require.Equal(s.T(), c.MessageStr, msgType) - } - - payload, err := s.codec.Encode(c.Message) - require.NoError(s.T(), err) - m := &message.Message{ - ChannelID: c.Channel.String(), - Payload: payload, - } - validatePubsub := authorizedSenderValidator.PubSubMessageValidator(c.Channel) - pubsubResult := validatePubsub(pid, m) - if !c.Protocols.Contains(message.ProtocolTypePubSub) { - require.Equal(s.T(), p2p.ValidationReject, pubsubResult) - } else { + m := &message.Message{ + ChannelID: c.Channel.String(), + Payload: payload, + } + pubsubResult := validatePubsub(pid, m) require.Equal(s.T(), p2p.ValidationAccept, pubsubResult) + default: + s.T().Fatal("authconfig does not contain any protocols") } }) } + + s.Run("test messages should be allowed to be sent via both protocols unicast/pubsub on test channel", func() { + identity, _ := unittest.IdentityWithNetworkingKeyFixture(unittest.WithRole(flow.RoleCollection)) + misbehaviorReportConsumer := mocknetwork.NewMisbehaviorReportConsumer(s.T()) + defer misbehaviorReportConsumer.AssertNotCalled(s.T(), "ReportMisbehaviorOnChannel", mock.AnythingOfType("channels.Channel"), mock.AnythingOfType("*alsp.MisbehaviorReport")) + violationsConsumer := slashing.NewSlashingViolationsConsumer(s.log, metrics.NewNoopCollector(), misbehaviorReportConsumer) + getIdentityFunc := s.getIdentity(identity) + pid, err := unittest.PeerIDFromFlowID(identity) + require.NoError(s.T(), err) + authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, violationsConsumer, getIdentityFunc) + + msgType, err := authorizedSenderValidator.Validate(pid, []byte{codec.CodeEcho.Uint8()}, channels.TestNetworkChannel, message.ProtocolTypeUnicast) + require.NoError(s.T(), err) + require.Equal(s.T(), "*message.TestMessage", msgType) + + payload, err := s.codec.Encode(&libp2pmessage.TestMessage{}) + require.NoError(s.T(), err) + m := &message.Message{ + ChannelID: channels.TestNetworkChannel.String(), + Payload: payload, + } + validatePubsub := authorizedSenderValidator.PubSubMessageValidator(channels.TestNetworkChannel) + pubsubResult := validatePubsub(pid, m) + require.Equal(s.T(), p2p.ValidationAccept, pubsubResult) + }) } // TestValidatorCallback_UnAuthorizedSender checks that AuthorizedSenderValidator.Validate return's p2p.ValidationReject @@ -105,8 +135,12 @@ func (s *TestAuthorizedSenderValidatorSuite) TestValidatorCallback_UnAuthorizedS s.Run(str, func() { pid, err := unittest.PeerIDFromFlowID(c.Identity) require.NoError(s.T(), err) - - authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, s.slashingViolationsConsumer, c.GetIdentity) + expectedMisbehaviorReport, err := alsp.NewMisbehaviorReport(c.Identity.NodeID, alsp.UnAuthorizedSender) + require.NoError(s.T(), err) + misbehaviorReportConsumer := mocknetwork.NewMisbehaviorReportConsumer(s.T()) + misbehaviorReportConsumer.On("ReportMisbehaviorOnChannel", c.Channel, expectedMisbehaviorReport).Once() + violationsConsumer := slashing.NewSlashingViolationsConsumer(s.log, metrics.NewNoopCollector(), misbehaviorReportConsumer) + authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, violationsConsumer, c.GetIdentity) payload, err := s.codec.Encode(c.Message) require.NoError(s.T(), err) @@ -129,8 +163,10 @@ func (s *TestAuthorizedSenderValidatorSuite) TestValidatorCallback_AuthorizedUni s.Run(str, func() { pid, err := unittest.PeerIDFromFlowID(c.Identity) require.NoError(s.T(), err) - - authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, s.slashingViolationsConsumer, c.GetIdentity) + misbehaviorReportConsumer := mocknetwork.NewMisbehaviorReportConsumer(s.T()) + defer misbehaviorReportConsumer.AssertNotCalled(s.T(), "ReportMisbehaviorOnChannel", mock.AnythingOfType("channels.Channel"), mock.AnythingOfType("*alsp.MisbehaviorReport")) + violationsConsumer := slashing.NewSlashingViolationsConsumer(s.log, metrics.NewNoopCollector(), misbehaviorReportConsumer) + authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, violationsConsumer, c.GetIdentity) msgType, err := authorizedSenderValidator.Validate(pid, []byte{c.MessageCode.Uint8()}, c.Channel, message.ProtocolTypeUnicast) require.NoError(s.T(), err) @@ -147,8 +183,12 @@ func (s *TestAuthorizedSenderValidatorSuite) TestValidatorCallback_UnAuthorizedU s.Run(str, func() { pid, err := unittest.PeerIDFromFlowID(c.Identity) require.NoError(s.T(), err) - - authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, s.slashingViolationsConsumer, c.GetIdentity) + expectedMisbehaviorReport, err := alsp.NewMisbehaviorReport(c.Identity.NodeID, alsp.UnauthorizedUnicastOnChannel) + require.NoError(s.T(), err) + misbehaviorReportConsumer := mocknetwork.NewMisbehaviorReportConsumer(s.T()) + misbehaviorReportConsumer.On("ReportMisbehaviorOnChannel", c.Channel, expectedMisbehaviorReport).Once() + violationsConsumer := slashing.NewSlashingViolationsConsumer(s.log, metrics.NewNoopCollector(), misbehaviorReportConsumer) + authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, violationsConsumer, c.GetIdentity) msgType, err := authorizedSenderValidator.Validate(pid, []byte{c.MessageCode.Uint8()}, c.Channel, message.ProtocolTypeUnicast) require.ErrorIs(s.T(), err, message.ErrUnauthorizedUnicastOnChannel) @@ -165,8 +205,12 @@ func (s *TestAuthorizedSenderValidatorSuite) TestValidatorCallback_UnAuthorizedM s.Run(str, func() { pid, err := unittest.PeerIDFromFlowID(c.Identity) require.NoError(s.T(), err) - - authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, s.slashingViolationsConsumer, c.GetIdentity) + expectedMisbehaviorReport, err := alsp.NewMisbehaviorReport(c.Identity.NodeID, alsp.UnAuthorizedSender) + require.NoError(s.T(), err) + misbehaviorReportConsumer := mocknetwork.NewMisbehaviorReportConsumer(s.T()) + misbehaviorReportConsumer.On("ReportMisbehaviorOnChannel", c.Channel, expectedMisbehaviorReport).Twice() + violationsConsumer := slashing.NewSlashingViolationsConsumer(s.log, metrics.NewNoopCollector(), misbehaviorReportConsumer) + authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, violationsConsumer, c.GetIdentity) msgType, err := authorizedSenderValidator.Validate(pid, []byte{c.MessageCode.Uint8()}, c.Channel, message.ProtocolTypeUnicast) require.ErrorIs(s.T(), err, message.ErrUnauthorizedMessageOnChannel) @@ -195,10 +239,22 @@ func (s *TestAuthorizedSenderValidatorSuite) TestValidatorCallback_ClusterPrefix pid, err := unittest.PeerIDFromFlowID(identity) require.NoError(s.T(), err) - authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, s.slashingViolationsConsumer, getIdentityFunc) + expectedMisbehaviorReport, err := alsp.NewMisbehaviorReport(identity.NodeID, alsp.UnauthorizedUnicastOnChannel) + require.NoError(s.T(), err) + misbehaviorReportConsumer := mocknetwork.NewMisbehaviorReportConsumer(s.T()) + misbehaviorReportConsumer.On("ReportMisbehaviorOnChannel", channels.SyncCluster(clusterID), expectedMisbehaviorReport).Once() + misbehaviorReportConsumer.On("ReportMisbehaviorOnChannel", channels.ConsensusCluster(clusterID), expectedMisbehaviorReport).Once() + + violationsConsumer := slashing.NewSlashingViolationsConsumer(s.log, metrics.NewNoopCollector(), misbehaviorReportConsumer) + authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, violationsConsumer, getIdentityFunc) + + // validate collection sync cluster SyncRequest is not allowed to be sent on channel via unicast + msgType, err := authorizedSenderValidator.Validate(pid, []byte{codec.CodeSyncRequest.Uint8()}, channels.SyncCluster(clusterID), message.ProtocolTypeUnicast) + require.ErrorIs(s.T(), err, message.ErrUnauthorizedUnicastOnChannel) + require.Equal(s.T(), "*messages.SyncRequest", msgType) // ensure ClusterBlockProposal not allowed to be sent on channel via unicast - msgType, err := authorizedSenderValidator.Validate(pid, []byte{codec.CodeClusterBlockProposal.Uint8()}, channels.ConsensusCluster(clusterID), message.ProtocolTypeUnicast) + msgType, err = authorizedSenderValidator.Validate(pid, []byte{codec.CodeClusterBlockProposal.Uint8()}, channels.ConsensusCluster(clusterID), message.ProtocolTypeUnicast) require.ErrorIs(s.T(), err, message.ErrUnauthorizedUnicastOnChannel) require.Equal(s.T(), "*messages.ClusterBlockProposal", msgType) @@ -213,11 +269,6 @@ func (s *TestAuthorizedSenderValidatorSuite) TestValidatorCallback_ClusterPrefix pubsubResult := validateCollConsensusPubsub(pid, m) require.Equal(s.T(), p2p.ValidationAccept, pubsubResult) - // validate collection sync cluster SyncRequest is not allowed to be sent on channel via unicast - msgType, err = authorizedSenderValidator.Validate(pid, []byte{codec.CodeSyncRequest.Uint8()}, channels.SyncCluster(clusterID), message.ProtocolTypeUnicast) - require.ErrorIs(s.T(), err, message.ErrUnauthorizedUnicastOnChannel) - require.Equal(s.T(), "*messages.SyncRequest", msgType) - // ensure SyncRequest is allowed to be sent via pubsub by authorized sender payload, err = s.codec.Encode(&messages.SyncRequest{}) require.NoError(s.T(), err) @@ -239,7 +290,12 @@ func (s *TestAuthorizedSenderValidatorSuite) TestValidatorCallback_ValidationFai pid, err := unittest.PeerIDFromFlowID(identity) require.NoError(s.T(), err) - authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, s.slashingViolationsConsumer, getIdentityFunc) + expectedMisbehaviorReport, err := alsp.NewMisbehaviorReport(identity.NodeID, alsp.SenderEjected) + require.NoError(s.T(), err) + misbehaviorReportConsumer := mocknetwork.NewMisbehaviorReportConsumer(s.T()) + misbehaviorReportConsumer.On("ReportMisbehaviorOnChannel", channels.SyncCommittee, expectedMisbehaviorReport).Twice() + violationsConsumer := slashing.NewSlashingViolationsConsumer(s.log, metrics.NewNoopCollector(), misbehaviorReportConsumer) + authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, violationsConsumer, getIdentityFunc) msgType, err := authorizedSenderValidator.Validate(pid, []byte{codec.CodeSyncRequest.Uint8()}, channels.SyncCommittee, message.ProtocolTypeUnicast) require.ErrorIs(s.T(), err, ErrSenderEjected) @@ -263,7 +319,12 @@ func (s *TestAuthorizedSenderValidatorSuite) TestValidatorCallback_ValidationFai pid, err := unittest.PeerIDFromFlowID(identity) require.NoError(s.T(), err) - authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, s.slashingViolationsConsumer, getIdentityFunc) + expectedMisbehaviorReport, err := alsp.NewMisbehaviorReport(identity.NodeID, alsp.UnknownMsgType) + require.NoError(s.T(), err) + misbehaviorReportConsumer := mocknetwork.NewMisbehaviorReportConsumer(s.T()) + misbehaviorReportConsumer.On("ReportMisbehaviorOnChannel", channels.ConsensusCommittee, expectedMisbehaviorReport).Twice() + violationsConsumer := slashing.NewSlashingViolationsConsumer(s.log, metrics.NewNoopCollector(), misbehaviorReportConsumer) + authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, violationsConsumer, getIdentityFunc) validatePubsub := authorizedSenderValidator.PubSubMessageValidator(channels.ConsensusCommittee) // unknown message types are rejected @@ -291,7 +352,11 @@ func (s *TestAuthorizedSenderValidatorSuite) TestValidatorCallback_ValidationFai pid, err := unittest.PeerIDFromFlowID(identity) require.NoError(s.T(), err) - authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, s.slashingViolationsConsumer, getIdentityFunc) + misbehaviorReportConsumer := mocknetwork.NewMisbehaviorReportConsumer(s.T()) + // we cannot penalize a peer if identity is not known, in this case we don't expect any misbehavior reports to be reported + defer misbehaviorReportConsumer.AssertNotCalled(s.T(), "ReportMisbehaviorOnChannel", mock.AnythingOfType("channels.Channel"), mock.AnythingOfType("*alsp.MisbehaviorReport")) + violationsConsumer := slashing.NewSlashingViolationsConsumer(s.log, metrics.NewNoopCollector(), misbehaviorReportConsumer) + authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, violationsConsumer, getIdentityFunc) msgType, err := authorizedSenderValidator.Validate(pid, []byte{codec.CodeSyncRequest.Uint8()}, channels.SyncCommittee, message.ProtocolTypeUnicast) require.ErrorIs(s.T(), err, ErrIdentityUnverified) @@ -314,17 +379,21 @@ func (s *TestAuthorizedSenderValidatorSuite) TestValidatorCallback_UnauthorizedP for _, c := range s.authorizedUnicastOnChannel { str := fmt.Sprintf("message type (%s) is not authorized to be sent via libp2p publish", c.MessageStr) s.Run(str, func() { + // skip test message check + if c.MessageStr == "*message.TestMessage" { + return + } pid, err := unittest.PeerIDFromFlowID(c.Identity) require.NoError(s.T(), err) - - authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, s.slashingViolationsConsumer, c.GetIdentity) + expectedMisbehaviorReport, err := alsp.NewMisbehaviorReport(c.Identity.NodeID, alsp.UnauthorizedPublishOnChannel) + require.NoError(s.T(), err) + misbehaviorReportConsumer := mocknetwork.NewMisbehaviorReportConsumer(s.T()) + misbehaviorReportConsumer.On("ReportMisbehaviorOnChannel", c.Channel, expectedMisbehaviorReport).Once() + violationsConsumer := slashing.NewSlashingViolationsConsumer(s.log, metrics.NewNoopCollector(), misbehaviorReportConsumer) + authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, violationsConsumer, c.GetIdentity) msgType, err := authorizedSenderValidator.Validate(pid, []byte{c.MessageCode.Uint8()}, c.Channel, message.ProtocolTypePubSub) - if c.MessageStr == "*message.TestMessage" { - require.NoError(s.T(), err) - } else { - require.ErrorIs(s.T(), err, message.ErrUnauthorizedPublishOnChannel) - require.Equal(s.T(), c.MessageStr, msgType) - } + require.ErrorIs(s.T(), err, message.ErrUnauthorizedPublishOnChannel) + require.Equal(s.T(), c.MessageStr, msgType) }) } } diff --git a/network/slashing/violations_consumer.go b/network/violations_consumer.go similarity index 61% rename from network/slashing/violations_consumer.go rename to network/violations_consumer.go index cf1f8ea7d85..6c3de412c77 100644 --- a/network/slashing/violations_consumer.go +++ b/network/violations_consumer.go @@ -1,4 +1,4 @@ -package slashing +package network import ( "github.com/onflow/flow-go/model/flow" @@ -6,24 +6,30 @@ import ( "github.com/onflow/flow-go/network/message" ) +// ViolationsConsumer logs reported slashing violation errors and reports those violations as misbehavior's to the ALSP +// misbehavior report manager. Any errors encountered while reporting the misbehavior are considered irrecoverable and +// will result in a fatal level log. type ViolationsConsumer interface { - // OnUnAuthorizedSenderError logs an error for unauthorized sender error + // OnUnAuthorizedSenderError logs an error for unauthorized sender error. OnUnAuthorizedSenderError(violation *Violation) - // OnUnknownMsgTypeError logs an error for unknown message type error + // OnUnknownMsgTypeError logs an error for unknown message type error. OnUnknownMsgTypeError(violation *Violation) // OnInvalidMsgError logs an error for messages that contained payloads that could not // be unmarshalled into the message type denoted by message code byte. OnInvalidMsgError(violation *Violation) - // OnSenderEjectedError logs an error for sender ejected error + // OnSenderEjectedError logs an error for sender ejected error. OnSenderEjectedError(violation *Violation) - // OnUnauthorizedUnicastOnChannel logs an error for messages unauthorized to be sent via unicast + // OnUnauthorizedUnicastOnChannel logs an error for messages unauthorized to be sent via unicast. OnUnauthorizedUnicastOnChannel(violation *Violation) - // OnUnexpectedError logs an error for unknown errors + // OnUnauthorizedPublishOnChannel logs an error for messages unauthorized to be sent via pubsub. + OnUnauthorizedPublishOnChannel(violation *Violation) + + // OnUnexpectedError logs an error for unknown errors. OnUnexpectedError(violation *Violation) } diff --git a/utils/unittest/unittest.go b/utils/unittest/unittest.go index 459a4db0e16..0d9949ffc2d 100644 --- a/utils/unittest/unittest.go +++ b/utils/unittest/unittest.go @@ -21,6 +21,7 @@ import ( "github.com/onflow/flow-go/module" "github.com/onflow/flow-go/module/util" "github.com/onflow/flow-go/network" + "github.com/onflow/flow-go/network/channels" cborcodec "github.com/onflow/flow-go/network/codec/cbor" "github.com/onflow/flow-go/network/slashing" "github.com/onflow/flow-go/network/topology" @@ -438,6 +439,18 @@ func GenerateRandomStringWithLen(commentLen uint) string { } // NetworkSlashingViolationsConsumer returns a slashing violations consumer for network middleware -func NetworkSlashingViolationsConsumer(logger zerolog.Logger, metrics module.NetworkSecurityMetrics) slashing.ViolationsConsumer { - return slashing.NewSlashingViolationsConsumer(logger, metrics) +func NetworkSlashingViolationsConsumer(logger zerolog.Logger, metrics module.NetworkSecurityMetrics, consumer network.MisbehaviorReportConsumer) network.ViolationsConsumer { + return slashing.NewSlashingViolationsConsumer(logger, metrics, consumer) +} + +type MisbehaviorReportConsumerFixture struct { + network.MisbehaviorReportManager +} + +func (c *MisbehaviorReportConsumerFixture) ReportMisbehaviorOnChannel(channel channels.Channel, report network.MisbehaviorReport) { + c.HandleMisbehaviorReport(channel, report) +} + +func NewMisbehaviorReportConsumerFixture(manager network.MisbehaviorReportManager) *MisbehaviorReportConsumerFixture { + return &MisbehaviorReportConsumerFixture{manager} }