Skip to content

Commit

Permalink
fix(dot/network): fix bugs in notifications protocol handlers; add me…
Browse files Browse the repository at this point in the history
…trics for inbound/outbound streams (#2010)
  • Loading branch information
noot authored Nov 18, 2021
1 parent 84ec792 commit 8c2993d
Show file tree
Hide file tree
Showing 17 changed files with 443 additions and 301 deletions.
2 changes: 1 addition & 1 deletion dot/network/block_announce_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ func TestValidateBlockAnnounceHandshake(t *testing.T) {
inboundHandshakeData: new(sync.Map),
}
testPeerID := peer.ID("noot")
nodeA.notificationsProtocols[BlockAnnounceMsgType].inboundHandshakeData.Store(testPeerID, handshakeData{})
nodeA.notificationsProtocols[BlockAnnounceMsgType].inboundHandshakeData.Store(testPeerID, &handshakeData{})

err := nodeA.validateBlockAnnounceHandshake(testPeerID, &BlockAnnounceHandshake{
BestBlockNumber: 100,
Expand Down
39 changes: 13 additions & 26 deletions dot/network/connmgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"github.com/libp2p/go-libp2p-core/connmgr"
"github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer"
"github.com/libp2p/go-libp2p-core/protocol"
ma "github.com/multiformats/go-multiaddr"

"github.com/ChainSafe/gossamer/dot/peerset"
Expand All @@ -23,11 +22,9 @@ type ConnManager struct {
sync.Mutex
host *host
min, max int
connectHandler func(peer.ID)
disconnectHandler func(peer.ID)

// closeHandlerMap contains close handler corresponding to a protocol.
closeHandlerMap map[protocol.ID]func(peerID peer.ID)

// protectedPeers contains a list of peers that are protected from pruning
// when we reach the maximum numbers of peers.
protectedPeers *sync.Map // map[peer.ID]struct{}
Expand All @@ -47,7 +44,6 @@ func newConnManager(min, max int, peerSetCfg *peerset.ConfigSet) (*ConnManager,
return &ConnManager{
min: min,
max: max,
closeHandlerMap: make(map[protocol.ID]func(peerID peer.ID)),
protectedPeers: new(sync.Map),
persistentPeers: new(sync.Map),
peerSetHandler: psh,
Expand All @@ -68,19 +64,19 @@ func (cm *ConnManager) Notifee() network.Notifiee {
return nb
}

// TagPeer peer
// TagPeer is unimplemented
func (*ConnManager) TagPeer(peer.ID, string, int) {}

// UntagPeer peer
// UntagPeer is unimplemented
func (*ConnManager) UntagPeer(peer.ID, string) {}

// UpsertTag peer
// UpsertTag is unimplemented
func (*ConnManager) UpsertTag(peer.ID, string, func(int) int) {}

// GetTagInfo peer
// GetTagInfo is unimplemented
func (*ConnManager) GetTagInfo(peer.ID) *connmgr.TagInfo { return &connmgr.TagInfo{} }

// TrimOpenConns peer
// TrimOpenConns is unimplemented
func (*ConnManager) TrimOpenConns(context.Context) {}

// Protect peer will add the given peer to the protectedPeerMap which will
Expand All @@ -97,7 +93,7 @@ func (cm *ConnManager) Unprotect(id peer.ID, _ string) bool {
return wasDeleted
}

// Close peer
// Close is unimplemented
func (*ConnManager) Close() error { return nil }

// IsProtected returns whether the given peer is protected from pruning or not.
Expand Down Expand Up @@ -134,6 +130,7 @@ func (cm *ConnManager) unprotectedPeers(peers []peer.ID) []peer.ID {
func (cm *ConnManager) Connected(n network.Network, c network.Conn) {
logger.Tracef(
"Host %s connected to peer %s", n.LocalPeer(), c.RemotePeer())
cm.connectHandler(c.RemotePeer())

cm.Lock()
defer cm.Unlock()
Expand All @@ -143,7 +140,9 @@ func (cm *ConnManager) Connected(n network.Network, c network.Conn) {
return
}

// TODO: peer scoring doesn't seem to prevent us from going over the max.
// if over the max peer count, disconnect from (total_peers - maximum) peers
// (#2039)
for i := 0; i < over; i++ {
unprotPeers := cm.unprotectedPeers(n.Peers())
if len(unprotPeers) == 0 {
Expand All @@ -170,31 +169,19 @@ func (cm *ConnManager) Disconnected(_ network.Network, c network.Conn) {
logger.Tracef("Host %s disconnected from peer %s", c.LocalPeer(), c.RemotePeer())

cm.Unprotect(c.RemotePeer(), "")
if cm.disconnectHandler != nil {
cm.disconnectHandler(c.RemotePeer())
}
cm.disconnectHandler(c.RemotePeer())
}

// OpenedStream is called when a stream opened
// OpenedStream is called when a stream is opened
func (cm *ConnManager) OpenedStream(_ network.Network, s network.Stream) {
logger.Tracef("Stream opened with peer %s using protocol %s",
s.Conn().RemotePeer(), s.Protocol())
}

func (cm *ConnManager) registerCloseHandler(protocolID protocol.ID, cb func(id peer.ID)) {
cm.closeHandlerMap[protocolID] = cb
}

// ClosedStream is called when a stream closed
// ClosedStream is called when a stream is closed
func (cm *ConnManager) ClosedStream(_ network.Network, s network.Stream) {
logger.Tracef("Stream closed with peer %s using protocol %s",
s.Conn().RemotePeer(), s.Protocol())

cm.Lock()
defer cm.Unlock()
if closeCB, ok := cm.closeHandlerMap[s.Protocol()]; ok {
closeCB(s.Conn().RemotePeer())
}
}

func (cm *ConnManager) isPersistent(p peer.ID) bool {
Expand Down
13 changes: 13 additions & 0 deletions dot/network/discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,21 @@ func (d *discovery) findPeers(ctx context.Context) {

logger.Tracef("found new peer %s via DHT", peer.ID)

// TODO: this isn't working on the devnet (#2026)
// can remove the code block below which directly connects
// once that's fixed
d.h.Peerstore().AddAddrs(peer.ID, peer.Addrs, peerstore.PermanentAddrTTL)
d.handler.AddPeer(0, peer.ID)

// found a peer, try to connect if we need more peers
if len(d.h.Network().Peers()) >= d.maxPeers {
d.h.Peerstore().AddAddrs(peer.ID, peer.Addrs, peerstore.PermanentAddrTTL)
return
}

if err = d.h.Connect(d.ctx, peer); err != nil {
logger.Tracef("failed to connect to discovered peer %s: %s", peer.ID, err)
}
}
}
}
Expand Down
17 changes: 17 additions & 0 deletions dot/network/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright 2021 ChainSafe Systems (ON)
// SPDX-License-Identifier: LGPL-3.0-only

package network

import (
"errors"
)

var (
errCannotValidateHandshake = errors.New("failed to validate handshake")
errMessageTypeNotValid = errors.New("message type is not valid")
errMessageIsNotHandshake = errors.New("failed to convert message to Handshake")
errMissingHandshakeMutex = errors.New("outboundHandshakeMutex does not exist")
errInvalidHandshakeForPeer = errors.New("peer previously sent invalid handshake")
errHandshakeTimeout = errors.New("handshake timeout reached")
)
2 changes: 1 addition & 1 deletion dot/network/host_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ func TestStreamCloseMetadataCleanup(t *testing.T) {
info := nodeA.notificationsProtocols[BlockAnnounceMsgType]

// Set handshake data to received
info.inboundHandshakeData.Store(nodeB.host.id(), handshakeData{
info.inboundHandshakeData.Store(nodeB.host.id(), &handshakeData{
received: true,
validated: true,
})
Expand Down
76 changes: 76 additions & 0 deletions dot/network/inbound.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// Copyright 2021 ChainSafe Systems (ON)
// SPDX-License-Identifier: LGPL-3.0-only

package network

import (
libp2pnetwork "github.com/libp2p/go-libp2p-core/network"
)

func (s *Service) readStream(stream libp2pnetwork.Stream, decoder messageDecoder, handler messageHandler) {
// we NEED to reset the stream if we ever return from this function, as if we return,
// the stream will never again be read by us, so we need to tell the remote side we're
// done with this stream, and they should also forget about it.
defer s.resetInboundStream(stream)
s.streamManager.logNewStream(stream)

peer := stream.Conn().RemotePeer()
msgBytes := s.bufPool.get()
defer s.bufPool.put(msgBytes)

for {
n, err := readStream(stream, msgBytes[:])
if err != nil {
logger.Tracef(
"failed to read from stream id %s of peer %s using protocol %s: %s",
stream.ID(), stream.Conn().RemotePeer(), stream.Protocol(), err)
return
}

s.streamManager.logMessageReceived(stream.ID())

// decode message based on message type
msg, err := decoder(msgBytes[:n], peer, isInbound(stream)) // stream should always be inbound if it passes through service.readStream
if err != nil {
logger.Tracef("failed to decode message from stream id %s using protocol %s: %s",
stream.ID(), stream.Protocol(), err)
continue
}

logger.Tracef(
"host %s received message from peer %s: %s",
s.host.id(), peer, msg)

if err = handler(stream, msg); err != nil {
logger.Tracef("failed to handle message %s from stream id %s: %s", msg, stream.ID(), err)
return
}

s.host.bwc.LogRecvMessage(int64(n))
}
}

func (s *Service) resetInboundStream(stream libp2pnetwork.Stream) {
protocolID := stream.Protocol()
peerID := stream.Conn().RemotePeer()

s.notificationsMu.Lock()
defer s.notificationsMu.Unlock()

for _, prtl := range s.notificationsProtocols {
if prtl.protocolID != protocolID {
continue
}

prtl.inboundHandshakeData.Delete(peerID)
break
}

logger.Debugf(
"cleaning up inbound handshake data for protocol=%s, peer=%s",
stream.Protocol(),
peerID,
)

_ = stream.Reset()
}
75 changes: 75 additions & 0 deletions dot/network/light.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,71 @@ import (
"github.com/ChainSafe/gossamer/dot/types"
"github.com/ChainSafe/gossamer/lib/common"
"github.com/ChainSafe/gossamer/pkg/scale"

libp2pnetwork "github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer"
)

// handleLightStream handles streams with the <protocol-id>/light/2 protocol ID
func (s *Service) handleLightStream(stream libp2pnetwork.Stream) {
s.readStream(stream, s.decodeLightMessage, s.handleLightMsg)
}

func (s *Service) decodeLightMessage(in []byte, peer peer.ID, _ bool) (Message, error) {
s.lightRequestMu.RLock()
defer s.lightRequestMu.RUnlock()

// check if we are the requester
if _, ok := s.lightRequest[peer]; ok {
// if we are, decode the bytes as a LightResponse
return newLightResponseFromBytes(in)
}

// otherwise, decode bytes as LightRequest
return newLightRequestFromBytes(in)
}

func (s *Service) handleLightMsg(stream libp2pnetwork.Stream, msg Message) (err error) {
defer func() {
_ = stream.Close()
}()

lr, ok := msg.(*LightRequest)
if !ok {
return nil
}

resp := NewLightResponse()
switch {
case lr.RemoteCallRequest != nil:
resp.RemoteCallResponse, err = remoteCallResp(lr.RemoteCallRequest)
case lr.RemoteHeaderRequest != nil:
resp.RemoteHeaderResponse, err = remoteHeaderResp(lr.RemoteHeaderRequest)
case lr.RemoteChangesRequest != nil:
resp.RemoteChangesResponse, err = remoteChangeResp(lr.RemoteChangesRequest)
case lr.RemoteReadRequest != nil:
resp.RemoteReadResponse, err = remoteReadResp(lr.RemoteReadRequest)
case lr.RemoteReadChildRequest != nil:
resp.RemoteReadResponse, err = remoteReadChildResp(lr.RemoteReadChildRequest)
default:
logger.Warn("ignoring LightRequest without request data")
return nil
}

if err != nil {
return err
}

// TODO(arijit): Remove once we implement the internal APIs. Added to increase code coverage. (#1856)
logger.Debugf("LightResponse message: %s", resp)

err = s.host.writeToStream(stream, resp)
if err != nil {
logger.Warnf("failed to send LightResponse message to peer %s: %s", stream.Conn().RemotePeer(), err)
}
return err
}

// Pair is a pair of arbitrary bytes.
type Pair struct {
First []byte
Expand Down Expand Up @@ -46,6 +109,12 @@ func NewLightRequest() *LightRequest {
}
}

func newLightRequestFromBytes(in []byte) (msg *LightRequest, err error) {
msg = NewLightRequest()
err = msg.Decode(in)
return msg, err
}

func newRequest() *request {
return &request{
RemoteCallRequest: *newRemoteCallRequest(),
Expand Down Expand Up @@ -122,6 +191,12 @@ func NewLightResponse() *LightResponse {
}
}

func newLightResponseFromBytes(in []byte) (msg *LightResponse, err error) {
msg = NewLightResponse()
err = msg.Decode(in)
return msg, err
}

func newResponse() *response {
return &response{
RemoteCallResponse: *newRemoteCallResponse(),
Expand Down
Loading

0 comments on commit 8c2993d

Please sign in to comment.