diff --git a/core/network/context.go b/core/network/context.go
index 7fabfb53e0..75db775932 100644
--- a/core/network/context.go
+++ b/core/network/context.go
@@ -13,12 +13,12 @@ var DialPeerTimeout = 60 * time.Second
 type noDialCtxKey struct{}
 type dialPeerTimeoutCtxKey struct{}
 type forceDirectDialCtxKey struct{}
-type useTransientCtxKey struct{}
+type allowLimitedConnCtxKey struct{}
 type simConnectCtxKey struct{ isClient bool }

 var noDial = noDialCtxKey{}
 var forceDirectDial = forceDirectDialCtxKey{}
-var useTransient = useTransientCtxKey{}
+var allowLimitedConn = allowLimitedConnCtxKey{}
 var simConnectIsServer = simConnectCtxKey{}
 var simConnectIsClient = simConnectCtxKey{isClient: true}

@@ -94,15 +94,35 @@ func WithDialPeerTimeout(ctx context.Context, timeout time.Duration) context.Con
 	return context.WithValue(ctx, dialPeerTimeoutCtxKey{}, timeout)
 }

+// WithAllowLimitedConn constructs a new context with an option that instructs
+// the network that it is acceptable to use a limited connection when opening a
+// new stream.
+func WithAllowLimitedConn(ctx context.Context, reason string) context.Context {
+	return context.WithValue(ctx, allowLimitedConn, reason)
+}
+
 // WithUseTransient constructs a new context with an option that instructs the network
 // that it is acceptable to use a transient connection when opening a new stream.
+//
+// Deprecated: Use WithAllowLimitedConn instead.
 func WithUseTransient(ctx context.Context, reason string) context.Context {
-	return context.WithValue(ctx, useTransient, reason)
+	return context.WithValue(ctx, allowLimitedConn, reason)
+}
+
+// GetAllowLimitedConn returns true if the allow limited conn option is set in the context.
+func GetAllowLimitedConn(ctx context.Context) (usesLimitedConn bool, reason string) {
+	v := ctx.Value(allowLimitedConn)
+	if v != nil {
+		return true, v.(string)
+	}
+	return false, ""
 }

 // GetUseTransient returns true if the use transient option is set in the context.
+//
+// Deprecated: Use GetAllowLimitedConn instead.
 func GetUseTransient(ctx context.Context) (usetransient bool, reason string) {
-	v := ctx.Value(useTransient)
+	v := ctx.Value(allowLimitedConn)
 	if v != nil {
 		return true, v.(string)
 	}
diff --git a/core/network/errors.go b/core/network/errors.go
index 03bb90c266..0f98cd5a28 100644
--- a/core/network/errors.go
+++ b/core/network/errors.go
@@ -22,7 +22,13 @@ var ErrNoConn = errors.New("no usable connection to peer")

 // ErrTransientConn is returned when attempting to open a stream to a peer with only a transient
 // connection, without specifying the UseTransient option.
-var ErrTransientConn = errors.New("transient connection to peer")
+//
+// Deprecated: Use ErrLimitedConn instead.
+var ErrTransientConn = ErrLimitedConn
+
+// ErrLimitedConn is returned when attempting to open a stream to a peer with only a limited
+// connection, without specifying the AllowLimitedConn option.
+var ErrLimitedConn = errors.New("limited connection to peer")

 // ErrResourceLimitExceeded is returned when attempting to perform an operation that would
 // exceed system resource limits.
diff --git a/core/network/network.go b/core/network/network.go
index 66b0a1cd34..22efbf235d 100644
--- a/core/network/network.go
+++ b/core/network/network.go
@@ -55,16 +55,23 @@ const (
 	// Connected means has an open, live connection to peer
 	Connected

+	// Deprecated: CanConnect is deprecated and will be removed in a future release.
+	//
 	// CanConnect means recently connected to peer, terminated gracefully
 	CanConnect

+	// Deprecated: CannotConnect is deprecated and will be removed in a future release.
+	//
 	// CannotConnect means recently attempted connecting but failed to connect.
 	// (should signal "made effort, failed")
 	CannotConnect
+
+	// Limited means we have a transient connection to the peer, but aren't fully connected.
+	Limited
 )

 func (c Connectedness) String() string {
-	str := [...]string{"NotConnected", "Connected", "CanConnect", "CannotConnect"}
+	str := [...]string{"NotConnected", "Connected", "CanConnect", "CannotConnect", "Limited"}
 	if c < 0 || int(c) >= len(str) {
 		return unrecognized
 	}
@@ -111,8 +118,10 @@ type Stats struct {
 	Direction Direction
 	// Opened is the timestamp when this connection was opened.
 	Opened time.Time
-	// Transient indicates that this connection is transient and may be closed soon.
-	Transient bool
+	// Limited indicates that this connection is limited. It may be limited by
+	// bytes or time. In practice, this is a connection formed over a circuit v2
+	// relay.
+	Limited bool
 	// Extra stores additional metadata about this connection.
 	Extra map[interface{}]interface{}
 }
diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go
index 367fca05f2..d80445e32e 100644
--- a/p2p/host/basic/basic_host.go
+++ b/p2p/host/basic/basic_host.go
@@ -724,8 +724,10 @@ func (h *BasicHost) Connect(ctx context.Context, pi peer.AddrInfo) error {
 	h.Peerstore().AddAddrs(pi.ID, pi.Addrs, peerstore.TempAddrTTL)

 	forceDirect, _ := network.GetForceDirectDial(ctx)
+	canUseLimitedConn, _ := network.GetAllowLimitedConn(ctx)
 	if !forceDirect {
-		if h.Network().Connectedness(pi.ID) == network.Connected {
+		connectedness := h.Network().Connectedness(pi.ID)
+		if connectedness == network.Connected || (canUseLimitedConn && connectedness == network.Limited) {
 			return nil
 		}
 	}
diff --git a/p2p/host/blank/blank.go b/p2p/host/blank/blank.go
index 0fdded30ff..0cf6642f69 100644
--- a/p2p/host/blank/blank.go
+++ b/p2p/host/blank/blank.go
@@ -63,9 +63,10 @@ func NewBlankHost(n network.Network, options ...Option) *BlankHost {
 	}

 	bh := &BlankHost{
-		n:    n,
-		cmgr: cfg.cmgr,
-		mux:  mstream.NewMultistreamMuxer[protocol.ID](),
+		n:        n,
+		cmgr:     cfg.cmgr,
+		mux:      mstream.NewMultistreamMuxer[protocol.ID](),
+		eventbus: cfg.eventBus,
 	}
 	if bh.eventbus == nil {
 		bh.eventbus = eventbus.NewBus(eventbus.WithMetricsTracer(eventbus.NewMetricsTracer()))
diff --git a/p2p/host/pstoremanager/pstoremanager.go b/p2p/host/pstoremanager/pstoremanager.go
index 2a22b2caee..93cc2a98d9 100644
--- a/p2p/host/pstoremanager/pstoremanager.go
+++ b/p2p/host/pstoremanager/pstoremanager.go
@@ -103,15 +103,16 @@ func (m *PeerstoreManager) background(ctx context.Context, sub event.Subscriptio
 			ev := e.(event.EvtPeerConnectednessChanged)
 			p := ev.Peer
 			switch ev.Connectedness {
-			case network.NotConnected:
+			case network.Connected, network.Limited:
+				// If we reconnect to the peer before we've cleared the information,
+				// keep it. This is an optimization to keep the disconnected map
+				// small. We still need to check that a peer is actually
+				// disconnected before removing it from the peer store.
+				delete(disconnected, p)
+			default:
 				if _, ok := disconnected[p]; !ok {
 					disconnected[p] = time.Now()
 				}
-			case network.Connected:
-				// If we reconnect to the peer before we've cleared the information, keep it.
-				// This is an optimization to keep the disconnected map small.
-				// We still need to check that a peer is actually disconnected before removing it from the peer store.
-				delete(disconnected, p)
 			}
 		case <-ticker.C:
 			now := time.Now()
diff --git a/p2p/host/routed/routed.go b/p2p/host/routed/routed.go
index eb8e58ee7f..8248e50f0e 100644
--- a/p2p/host/routed/routed.go
+++ b/p2p/host/routed/routed.go
@@ -48,8 +48,10 @@ func Wrap(h host.Host, r Routing) *RoutedHost {
 func (rh *RoutedHost) Connect(ctx context.Context, pi peer.AddrInfo) error {
 	// first, check if we're already connected unless force direct dial.
 	forceDirect, _ := network.GetForceDirectDial(ctx)
+	canUseLimitedConn, _ := network.GetAllowLimitedConn(ctx)
 	if !forceDirect {
-		if rh.Network().Connectedness(pi.ID) == network.Connected {
+		connectedness := rh.Network().Connectedness(pi.ID)
+		if connectedness == network.Connected || (canUseLimitedConn && connectedness == network.Limited) {
 			return nil
 		}
 	}
diff --git a/p2p/net/swarm/connectedness_event_emitter.go b/p2p/net/swarm/connectedness_event_emitter.go
new file mode 100644
index 0000000000..07db583fc9
--- /dev/null
+++ b/p2p/net/swarm/connectedness_event_emitter.go
@@ -0,0 +1,143 @@
+package swarm
+
+import (
+	"context"
+	"sync"
+
+	"github.com/libp2p/go-libp2p/core/event"
+	"github.com/libp2p/go-libp2p/core/network"
+	"github.com/libp2p/go-libp2p/core/peer"
+)
+
+// connectednessEventEmitter emits PeerConnectednessChanged events.
+// We ensure that for any peer we connected to we always send at least one NotConnected event
+// after the peer disconnects. This is because peers can observe a connection before they are
+// notified of the connection by a peer connectedness changed event.
+type connectednessEventEmitter struct {
+	mx sync.RWMutex
+	// newConns is the channel that holds the peerIDs we recently connected to
+	newConns      chan peer.ID
+	removeConnsMx sync.Mutex
+	// removeConns is a slice of peerIDs we have recently closed connections to
+	removeConns []peer.ID
+	// lastEvent is the last connectedness event sent for a particular peer.
+	lastEvent map[peer.ID]network.Connectedness
+	// connectedness is the function that gives the peer's current connectedness state
+	connectedness func(peer.ID) network.Connectedness
+	// emitter is the PeerConnectednessChanged event emitter
+	emitter         event.Emitter
+	wg              sync.WaitGroup
+	removeConnNotif chan struct{}
+	ctx             context.Context
+	cancel          context.CancelFunc
+}
+
+func newConnectednessEventEmitter(connectedness func(peer.ID) network.Connectedness, emitter event.Emitter) *connectednessEventEmitter {
+	ctx, cancel := context.WithCancel(context.Background())
+	c := &connectednessEventEmitter{
+		newConns:        make(chan peer.ID, 32),
+		lastEvent:       make(map[peer.ID]network.Connectedness),
+		removeConnNotif: make(chan struct{}, 1),
+		connectedness:   connectedness,
+		emitter:         emitter,
+		ctx:             ctx,
+		cancel:          cancel,
+	}
+	c.wg.Add(1)
+	go c.runEmitter()
+	return c
+}
+
+func (c *connectednessEventEmitter) AddConn(p peer.ID) {
+	c.mx.RLock()
+	defer c.mx.RUnlock()
+	if c.ctx.Err() != nil {
+		return
+	}
+
+	c.newConns <- p
+}
+
+func (c *connectednessEventEmitter) RemoveConn(p peer.ID) {
+	c.mx.RLock()
+	defer c.mx.RUnlock()
+	if c.ctx.Err() != nil {
+		return
+	}
+
+	c.removeConnsMx.Lock()
+	// This queue is roughly bounded by the total number of added connections we
+	// have. If consumers of connectedness events are slow, we apply
+	// backpressure to AddConn operations.
+	//
+	// We purposefully don't block/backpressure here to avoid deadlocks, since it's
+	// reasonable for a consumer of the event to want to remove a connection.
+	c.removeConns = append(c.removeConns, p)
+
+	c.removeConnsMx.Unlock()
+
+	select {
+	case c.removeConnNotif <- struct{}{}:
+	default:
+	}
+}
+
+func (c *connectednessEventEmitter) Close() {
+	c.cancel()
+	c.wg.Wait()
+}
+
+func (c *connectednessEventEmitter) runEmitter() {
+	defer c.wg.Done()
+	for {
+		select {
+		case p := <-c.newConns:
+			c.notifyPeer(p, true)
+		case <-c.removeConnNotif:
+			c.sendConnRemovedNotifications()
+		case <-c.ctx.Done():
+			c.mx.Lock() // Wait for all pending AddConn & RemoveConn operations to complete
+			defer c.mx.Unlock()
+			for {
+				select {
+				case p := <-c.newConns:
+					c.notifyPeer(p, true)
+				case <-c.removeConnNotif:
+					c.sendConnRemovedNotifications()
+				default:
+					return
+				}
+			}
+		}
+	}
+}
+
+// notifyPeer sends the peer connectedness event using the emitter.
+// Use forceNotConnectedEvent = true to send a NotConnected event even if
+// no Connected event was sent for this peer.
+// In case a peer is disconnected before we sent the Connected event, we still
+// send the Disconnected event because a connection to the peer can be observed
+// in such cases.
+func (c *connectednessEventEmitter) notifyPeer(p peer.ID, forceNotConnectedEvent bool) {
+	oldState := c.lastEvent[p]
+	c.lastEvent[p] = c.connectedness(p)
+	if c.lastEvent[p] == network.NotConnected {
+		delete(c.lastEvent, p)
+	}
+	if (forceNotConnectedEvent && c.lastEvent[p] == network.NotConnected) || c.lastEvent[p] != oldState {
+		c.emitter.Emit(event.EvtPeerConnectednessChanged{
+			Peer:          p,
+			Connectedness: c.lastEvent[p],
+		})
+	}
+}
+
+func (c *connectednessEventEmitter) sendConnRemovedNotifications() {
+	c.removeConnsMx.Lock()
+	removeConns := c.removeConns
+	c.removeConns = nil
+	c.removeConnsMx.Unlock()
+	for _, p := range removeConns {
+		c.notifyPeer(p, false)
+	}
+}
diff --git a/p2p/net/swarm/swarm.go b/p2p/net/swarm/swarm.go
index 0140c3f596..7897277cc7 100644
--- a/p2p/net/swarm/swarm.go
+++ b/p2p/net/swarm/swarm.go
@@ -203,9 +203,10 @@ type Swarm struct {

 	dialRanker network.DialRanker

-	udpBlackHoleConfig  blackHoleConfig
-	ipv6BlackHoleConfig blackHoleConfig
-	bhd                 *blackHoleDetector
+	udpBlackHoleConfig        blackHoleConfig
+	ipv6BlackHoleConfig       blackHoleConfig
+	bhd                       *blackHoleDetector
+	connectednessEventEmitter *connectednessEventEmitter
 }

 // NewSwarm constructs a Swarm.
@@ -238,6 +239,7 @@ func NewSwarm(local peer.ID, peers peerstore.Peerstore, eventBus event.Bus, opts
 	s.transports.m = make(map[int]transport.Transport)
 	s.notifs.m = make(map[network.Notifiee]struct{})
 	s.directConnNotifs.m = make(map[peer.ID][]chan struct{})
+	s.connectednessEventEmitter = newConnectednessEventEmitter(s.Connectedness, emitter)

 	for _, opt := range opts {
 		if err := opt(s); err != nil {
@@ -254,7 +256,6 @@ func NewSwarm(local peer.ID, peers peerstore.Peerstore, eventBus event.Bus, opts
 	s.backf.init(s.ctx)

 	s.bhd = newBlackHoleDetector(s.udpBlackHoleConfig, s.ipv6BlackHoleConfig, s.metricsTracer)
-
 	return s, nil
 }

@@ -271,8 +272,6 @@ func (s *Swarm) Done() <-chan struct{} {
 func (s *Swarm) close() {
 	s.ctxCancel()

-	s.emitter.Close()
-
 	// Prevents new connections and/or listeners from being added to the swarm.
 	s.listeners.Lock()
 	listeners := s.listeners.m
@@ -308,6 +307,8 @@ func (s *Swarm) close() {

 	// Wait for everything to finish.
 	s.refs.Wait()
+	s.connectednessEventEmitter.Close()
+	s.emitter.Close()

 	// Now close out any transports (if necessary). Do this after closing
 	// all connections/listeners.
@@ -350,6 +351,7 @@ func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn,
 	}
 	stat.Direction = dir
 	stat.Opened = time.Now()
+	isLimited := stat.Limited

 	// Wrap and register the connection.
 	c := &Conn{
@@ -390,21 +392,24 @@ func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn,
 	}

 	c.streams.m = make(map[*Stream]struct{})
-	isFirstConnection := len(s.conns.m[p]) == 0
 	s.conns.m[p] = append(s.conns.m[p], c)
-
 	// Add two swarm refs:
 	// * One will be decremented after the close notifications fire in Conn.doClose
 	// * The other will be decremented when Conn.start exits.
 	s.refs.Add(2)
-
 	// Take the notification lock before releasing the conns lock to block
 	// Disconnect notifications until after the Connect notifications done.
+	// This lock also ensures that swarm.refs.Wait() exits after we have
+	// enqueued the peer connectedness changed notification.
+	// TODO: Fix this fragility by taking a swarm ref for dial worker loop
 	c.notifyLk.Lock()
 	s.conns.Unlock()

-	// Notify goroutines waiting for a direct connection
-	if !c.Stat().Transient {
+	s.connectednessEventEmitter.AddConn(p)
+
+	if !isLimited {
+		// Notify goroutines waiting for a direct connection
+		//
 		// Go routines interested in waiting for direct connection first acquire this lock
 		// and then acquire s.conns.RLock. Do not acquire this lock before conns.Unlock to
 		// prevent deadlock.
@@ -415,16 +420,6 @@ func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn,
 		delete(s.directConnNotifs.m, p)
 		s.directConnNotifs.Unlock()
 	}
-
-	// Emit event after releasing `s.conns` lock so that a consumer can still
-	// use swarm methods that need the `s.conns` lock.
-	if isFirstConnection {
-		s.emitter.Emit(event.EvtPeerConnectednessChanged{
-			Peer:          p,
-			Connectedness: network.Connected,
-		})
-	}
-
 	s.notifyAll(func(f network.Notifiee) {
 		f.Connected(s, c)
 	})
@@ -455,14 +450,14 @@ func (s *Swarm) StreamHandler() network.StreamHandler {

 // NewStream creates a new stream on any available connection to peer, dialing
 // if necessary.
-// Use network.WithUseTransient to open a stream over a transient(relayed)
+// Use network.WithAllowLimitedConn to open a stream over a limited (relayed)
 // connection.
 func (s *Swarm) NewStream(ctx context.Context, p peer.ID) (network.Stream, error) {
 	log.Debugf("[%s] opening stream to peer [%s]", s.local, p)

 	// Algorithm:
 	// 1. Find the best connection, otherwise, dial.
-	// 2. If the best connection is transient, wait for a direct conn via conn
+	// 2. If the best connection is limited, wait for a direct conn via conn
 	//    reversal or hole punching.
 	// 3. Try opening a stream.
 	// 4. If the underlying connection is, in fact, closed, close the outer
@@ -491,8 +486,8 @@ func (s *Swarm) NewStream(ctx context.Context, p peer.ID) (network.Stream, error
 		}
 	}

-	useTransient, _ := network.GetUseTransient(ctx)
-	if !useTransient && c.Stat().Transient {
+	limitedAllowed, _ := network.GetAllowLimitedConn(ctx)
+	if !limitedAllowed && c.Stat().Limited {
 		var err error
 		c, err = s.waitForDirectConn(ctx, p)
 		if err != nil {
@@ -518,12 +513,12 @@ func (s *Swarm) waitForDirectConn(ctx context.Context, p peer.ID) (*Conn, error)
 	if c == nil {
 		s.directConnNotifs.Unlock()
 		return nil, network.ErrNoConn
-	} else if !c.Stat().Transient {
+	} else if !c.Stat().Limited {
 		s.directConnNotifs.Unlock()
 		return c, nil
 	}

-	// Wait for transient connection to upgrade to a direct connection either by
+	// Wait for limited connection to upgrade to a direct connection either by
 	// connection reversal or hole punching.
 	ch := make(chan struct{})
 	s.directConnNotifs.m[p] = append(s.directConnNotifs.m[p], ch)
@@ -555,8 +550,8 @@ func (s *Swarm) waitForDirectConn(ctx context.Context, p peer.ID) (*Conn, error)
 	if c == nil {
 		return nil, network.ErrNoConn
 	}
-	if c.Stat().Transient {
-		return nil, network.ErrTransientConn
+	if c.Stat().Limited {
+		return nil, network.ErrLimitedConn
 	}
 	return c, nil
 }
@@ -577,11 +572,11 @@ func (s *Swarm) ConnsToPeer(p peer.ID) []network.Conn {
 }

 func isBetterConn(a, b *Conn) bool {
-	// If one is transient and not the other, prefer the non-transient connection.
-	aTransient := a.Stat().Transient
-	bTransient := b.Stat().Transient
-	if aTransient != bTransient {
-		return !aTransient
+	// If one is limited and not the other, prefer the unlimited connection.
+	aLimited := a.Stat().Limited
+	bLimited := b.Stat().Limited
+	if aLimited != bLimited {
+		return !aLimited
 	}

 	// If one is direct and not the other, prefer the direct connection.
@@ -632,7 +627,7 @@ func (s *Swarm) bestConnToPeer(p peer.ID) *Conn {

 // bestAcceptableConnToPeer returns the best acceptable connection, considering the passed in ctx.
 // If network.WithForceDirectDial is used, it only returns a direct connections, ignoring
-// any transient (relayed) connections to the peer.
+// any limited (relayed) connections to the peer.
 func (s *Swarm) bestAcceptableConnToPeer(ctx context.Context, p peer.ID) *Conn {
 	conn := s.bestConnToPeer(p)

@@ -652,8 +647,28 @@ func isDirectConn(c *Conn) bool {
 // To check if we have an open connection, use `s.Connectedness(p) ==
 // network.Connected`.
 func (s *Swarm) Connectedness(p peer.ID) network.Connectedness {
-	if s.bestConnToPeer(p) != nil {
-		return network.Connected
+	s.conns.RLock()
+	defer s.conns.RUnlock()
+
+	return s.connectednessUnlocked(p)
+}
+
+// connectednessUnlocked returns the connectedness of a peer.
+func (s *Swarm) connectednessUnlocked(p peer.ID) network.Connectedness {
+	var haveLimited bool
+	for _, c := range s.conns.m[p] {
+		if c.IsClosed() {
+			// These will be garbage collected soon
+			continue
+		}
+		if c.Stat().Limited {
+			haveLimited = true
+		} else {
+			return network.Connected
+		}
+	}
+	if haveLimited {
+		return network.Limited
+	}
 	return network.NotConnected
 }

@@ -751,24 +766,7 @@ func (s *Swarm) removeConn(c *Conn) {
 	p := c.RemotePeer()

 	s.conns.Lock()
-	cs := s.conns.m[p]
-
-	if len(cs) == 1 {
-		delete(s.conns.m, p)
-		s.conns.Unlock()
-
-		// Emit event after releasing `s.conns` lock so that a consumer can still
-		// use swarm methods that need the `s.conns` lock.
-		s.emitter.Emit(event.EvtPeerConnectednessChanged{
-			Peer:          p,
-			Connectedness: network.NotConnected,
-		})
-		return
-	}
-
-	defer s.conns.Unlock()
-
 	for i, ci := range cs {
 		if ci == c {
 			// NOTE: We're intentionally preserving order.
@@ -780,6 +778,10 @@ func (s *Swarm) removeConn(c *Conn) {
 			break
 		}
 	}
+	if len(s.conns.m[p]) == 0 {
+		delete(s.conns.m, p)
+	}
+	s.conns.Unlock()
 }

 // String returns a string representation of Network.
diff --git a/p2p/net/swarm/swarm_conn.go b/p2p/net/swarm/swarm_conn.go
index 8c3ce7c5ac..38e942cce8 100644
--- a/p2p/net/swarm/swarm_conn.go
+++ b/p2p/net/swarm/swarm_conn.go
@@ -73,6 +73,11 @@ func (c *Conn) doClose() {

 		c.err = c.conn.Close()

+		// Send the connectedness event after closing the connection.
+		// This ensures that both remote connection close and local connection
+		// close events are sent after the underlying transport connection is closed.
+		c.swarm.connectednessEventEmitter.RemoveConn(c.RemotePeer())
+
 		// This is just for cleaning up state. The connection has already been closed.
 		// We *could* optimize this but it really isn't worth it.
 		for s := range streams {
@@ -85,10 +90,11 @@ func (c *Conn) doClose() {
 		c.notifyLk.Lock()
 		defer c.notifyLk.Unlock()

+		// Only notify for disconnection if we notified for connection
 		c.swarm.notifyAll(func(f network.Notifiee) {
 			f.Disconnected(c.swarm, c)
 		})
-		c.swarm.refs.Done() // taken in Swarm.addConn
+		c.swarm.refs.Done()
 	}()
 }

@@ -108,7 +114,6 @@ func (c *Conn) start() {
 	go func() {
 		defer c.swarm.refs.Done()
 		defer c.Close()
-
 		for {
 			ts, err := c.conn.AcceptStream()
 			if err != nil {
@@ -193,9 +198,9 @@ func (c *Conn) Stat() network.ConnStats {

 // NewStream returns a new Stream from this connection
 func (c *Conn) NewStream(ctx context.Context) (network.Stream, error) {
-	if c.Stat().Transient {
-		if useTransient, _ := network.GetUseTransient(ctx); !useTransient {
-			return nil, network.ErrTransientConn
+	if c.Stat().Limited {
+		if useLimited, _ := network.GetAllowLimitedConn(ctx); !useLimited {
+			return nil, network.ErrLimitedConn
 		}
 	}

diff --git a/p2p/net/swarm/swarm_event_test.go b/p2p/net/swarm/swarm_event_test.go
index 86d698d611..5010215fc2 100644
--- a/p2p/net/swarm/swarm_event_test.go
+++ b/p2p/net/swarm/swarm_event_test.go
@@ -2,6 +2,7 @@ package swarm_test

 import (
 	"context"
+	"sync"
 	"testing"
 	"time"

@@ -12,6 +13,7 @@ import (
 	swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing"

 	ma "github.com/multiformats/go-multiaddr"
+	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )

@@ -66,6 +68,10 @@ func TestConnectednessEventsSingleConn(t *testing.T) {
 }

 func TestNoDeadlockWhenConsumingConnectednessEvents(t *testing.T) {
+	ctx := context.Background()
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
+
 	dialerEventBus := eventbus.NewBus()
 	dialer := swarmt.GenSwarm(t, swarmt.OptDialOnly, swarmt.EventBus(dialerEventBus))
 	defer dialer.Close()
@@ -85,10 +91,6 @@ func TestNoDeadlockWhenConsumingConnectednessEvents(t *testing.T) {
 	sub, err := dialerEventBus.Subscribe(new(event.EvtPeerConnectednessChanged))
 	require.NoError(t, err)

-	ctx := context.Background()
-	ctx, cancel := context.WithCancel(ctx)
-	defer cancel()
-
 	// A slow consumer
 	go func() {
 		for {
@@ -113,3 +115,196 @@ func TestNoDeadlockWhenConsumingConnectednessEvents(t *testing.T) {

 	// The test should finish without deadlocking
 }
+
+func TestConnectednessEvents(t *testing.T) {
+	s1, sub1 := newSwarmWithSubscription(t)
+	const N = 100
+	peers := make([]*Swarm, N)
+	for i := 0; i < N; i++ {
+		peers[i] = swarmt.GenSwarm(t)
+	}
+
+	// First check all connected events
+	done := make(chan struct{})
+	go func() {
+		defer close(done)
+		for i := 0; i < N; i++ {
+			e := <-sub1.Out()
+			evt, ok := e.(event.EvtPeerConnectednessChanged)
+			if !ok {
+				t.Error("invalid event received", e)
+				return
+			}
+			if evt.Connectedness != network.Connected {
+				t.Errorf("invalid event received: expected: Connected, got: %s", evt)
+				return
+			}
+		}
+	}()
+	for i := 0; i < N; i++ {
+		s1.Peerstore().AddAddrs(peers[i].LocalPeer(), []ma.Multiaddr{peers[i].ListenAddresses()[0]}, time.Hour)
+		_, err := s1.DialPeer(context.Background(), peers[i].LocalPeer())
+		require.NoError(t, err)
+	}
+	select {
+	case <-done:
+	case <-time.After(10 * time.Second):
+		t.Fatal("expected all connectedness events to be completed")
+	}
+
+	// Disconnect some peers
+	done = make(chan struct{})
+	go func() {
+		defer close(done)
+		for i := 0; i < N/2; i++ {
+			e := <-sub1.Out()
+			evt, ok := e.(event.EvtPeerConnectednessChanged)
+			if !ok {
+				t.Error("invalid event received", e)
+				return
+			}
+			if evt.Connectedness != network.NotConnected {
+				t.Errorf("invalid event received: expected: NotConnected, got: %s", evt)
+				return
+			}
+		}
+	}()
+	for i := 0; i < N/2; i++ {
+		err := s1.ClosePeer(peers[i].LocalPeer())
+		require.NoError(t, err)
+	}
+	select {
+	case <-done:
+	case <-time.After(10 * time.Second):
+		t.Fatal("expected all disconnected events to be completed")
+	}
+
+	// Check for disconnected events on swarm close
+	done = make(chan struct{})
+	go func() {
+		defer close(done)
+		for i := N / 2; i < N; i++ {
+			e := <-sub1.Out()
+			evt, ok := e.(event.EvtPeerConnectednessChanged)
+			if !ok {
+				t.Error("invalid event received", e)
+				return
+			}
+			if evt.Connectedness != network.NotConnected {
+				t.Errorf("invalid event received: expected: NotConnected, got: %s", evt)
+				return
+			}
+		}
+	}()
+	s1.Close()
+	select {
+	case <-done:
+	case <-time.After(10 * time.Second):
+		t.Fatal("expected all disconnected events after swarm close to be completed")
+	}
+}
+
+func TestConnectednessEventDeadlock(t *testing.T) {
+	s1, sub1 := newSwarmWithSubscription(t)
+	const N = 100
+	peers := make([]*Swarm, N)
+	for i := 0; i < N; i++ {
+		peers[i] = swarmt.GenSwarm(t)
+	}
+
+	// First check all connected events
+	done := make(chan struct{})
+	go func() {
+		defer close(done)
+		count := 0
+		for count < N {
+			e := <-sub1.Out()
+			// sleep to simulate a slow consumer
+			evt, ok := e.(event.EvtPeerConnectednessChanged)
+			if !ok {
+				t.Error("invalid event received", e)
+				return
+			}
+			if evt.Connectedness != network.Connected {
+				continue
+			}
+			count++
+			s1.ClosePeer(evt.Peer)
+		}
+	}()
+	for i := 0; i < N; i++ {
+		s1.Peerstore().AddAddrs(peers[i].LocalPeer(), []ma.Multiaddr{peers[i].ListenAddresses()[0]}, time.Hour)
+		go func(i int) {
+			_, err := s1.DialPeer(context.Background(), peers[i].LocalPeer())
+			assert.NoError(t, err)
+		}(i)
+	}
+	select {
+	case <-done:
+	case <-time.After(100 * time.Second):
+		t.Fatal("expected all connectedness events to be completed")
+	}
+}
+
+func TestConnectednessEventDeadlockWithDial(t *testing.T) {
+	s1, sub1 := newSwarmWithSubscription(t)
+	const N = 200
+	peers := make([]*Swarm, N)
+	for i := 0; i < N; i++ {
+		peers[i] = swarmt.GenSwarm(t)
+	}
+	peers2 := make([]*Swarm, N)
+	for i := 0; i < N; i++ {
+		peers2[i] = swarmt.GenSwarm(t)
+	}
+
+	// First check all connected events
+	done := make(chan struct{})
+	var subWG sync.WaitGroup
+	subWG.Add(1)
+	go func() {
+		defer subWG.Done()
+		count := 0
+		for {
+			var e interface{}
+			select {
+			case e = <-sub1.Out():
+			case <-done:
+				return
+			}
+			// sleep to simulate a slow consumer
+			evt, ok := e.(event.EvtPeerConnectednessChanged)
+			if !ok {
+				t.Error("invalid event received", e)
+				return
+			}
+			if evt.Connectedness != network.Connected {
+				continue
+			}
+			if count < N {
+				time.Sleep(10 * time.Millisecond)
+				ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
+				s1.Peerstore().AddAddrs(peers2[count].LocalPeer(), []ma.Multiaddr{peers2[count].ListenAddresses()[0]}, time.Hour)
+				s1.DialPeer(ctx, peers2[count].LocalPeer())
+				count++
+				cancel()
+			}
+		}
+	}()
+	var wg sync.WaitGroup
+	wg.Add(N)
+	for i := 0; i < N; i++ {
+		s1.Peerstore().AddAddrs(peers[i].LocalPeer(), []ma.Multiaddr{peers[i].ListenAddresses()[0]}, time.Hour)
+		go func(i int) {
+			ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
+			s1.DialPeer(ctx, peers[i].LocalPeer())
+			cancel()
+			wg.Done()
+		}(i)
+	}
+	wg.Wait()
+	s1.Close()
+
+	close(done)
+	subWG.Wait()
+}
diff --git a/p2p/protocol/circuitv2/client/dial.go b/p2p/protocol/circuitv2/client/dial.go
index ecf5d3a51a..271652506a 100644
--- a/p2p/protocol/circuitv2/client/dial.go
+++ b/p2p/protocol/circuitv2/client/dial.go
@@ -179,7 +179,7 @@ func (c *Client) connect(s network.Stream, dest peer.AddrInfo) (*Conn, error) {
 	// relay connection and we mark the connection as transient.
 	var stat network.ConnStats
 	if limit := msg.GetLimit(); limit != nil {
-		stat.Transient = true
+		stat.Limited = true
 		stat.Extra = make(map[interface{}]interface{})
 		stat.Extra[StatLimitDuration] = time.Duration(limit.GetDuration()) * time.Second
 		stat.Extra[StatLimitData] = limit.GetData()
diff --git a/p2p/protocol/circuitv2/client/handlers.go b/p2p/protocol/circuitv2/client/handlers.go
index 6b5361b123..9c36de0e89 100644
--- a/p2p/protocol/circuitv2/client/handlers.go
+++ b/p2p/protocol/circuitv2/client/handlers.go
@@ -67,7 +67,7 @@ func (c *Client) handleStreamV2(s network.Stream) {
 	// relay connection and we mark the connection as transient.
 	var stat network.ConnStats
 	if limit := msg.GetLimit(); limit != nil {
-		stat.Transient = true
+		stat.Limited = true
 		stat.Extra = make(map[interface{}]interface{})
 		stat.Extra[StatLimitDuration] = time.Duration(limit.GetDuration()) * time.Second
 		stat.Extra[StatLimitData] = limit.GetData()
diff --git a/p2p/protocol/circuitv2/relay/relay_test.go b/p2p/protocol/circuitv2/relay/relay_test.go
index a229fe4aac..e5d32b0c96 100644
--- a/p2p/protocol/circuitv2/relay/relay_test.go
+++ b/p2p/protocol/circuitv2/relay/relay_test.go
@@ -10,6 +10,7 @@ import (
 	"time"

 	"github.com/libp2p/go-libp2p/core/crypto"
+	"github.com/libp2p/go-libp2p/core/event"
 	"github.com/libp2p/go-libp2p/core/host"
 	"github.com/libp2p/go-libp2p/core/metrics"
 	"github.com/libp2p/go-libp2p/core/network"
@@ -23,6 +24,7 @@ import (
 	"github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client"
 	"github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay"
 	"github.com/libp2p/go-libp2p/p2p/transport/tcp"
+	"github.com/stretchr/testify/require"

 	ma "github.com/multiformats/go-multiaddr"
 )
@@ -49,7 +51,8 @@ func getNetHosts(t *testing.T, ctx context.Context, n int) (hosts []host.Host, u
 		}

 		bwr := metrics.NewBandwidthCounter()
-		netw, err := swarm.NewSwarm(p, ps, eventbus.NewBus(), swarm.WithMetrics(bwr))
+		bus := eventbus.NewBus()
+		netw, err := swarm.NewSwarm(p, ps, bus, swarm.WithMetrics(bwr))
 		if err != nil {
 			t.Fatal(err)
 		}
@@ -70,7 +73,7 @@ func getNetHosts(t *testing.T, ctx context.Context, n int) (hosts []host.Host, u
 			t.Fatal(err)
 		}

-		h := bhost.NewBlankHost(netw)
+		h := bhost.NewBlankHost(netw, bhost.WithEventBus(bus))

 		hosts = append(hosts, h)
 	}
@@ -145,20 +148,41 @@ func TestBasicRelay(t *testing.T) {
 		t.Fatal(err)
 	}

+	sub, err := hosts[2].EventBus().Subscribe(new(event.EvtPeerConnectednessChanged))
+	require.NoError(t, err)
+
 	err = hosts[2].Connect(ctx, peer.AddrInfo{ID: hosts[0].ID(), Addrs: []ma.Multiaddr{raddr}})
 	if err != nil {
 		t.Fatal(err)
 	}
+	for {
+		var e interface{}
+		select {
+		case e = <-sub.Out():
+		case <-time.After(2 * time.Second):
+			t.Fatal("expected limited connectivity event")
+		}
+		evt, ok := e.(event.EvtPeerConnectednessChanged)
+		if !ok {
+			t.Fatalf("invalid event: %s", e)
+		}
+		if evt.Peer == hosts[0].ID() {
+			if evt.Connectedness != network.Limited {
+				t.Fatalf("expected limited connectivity %s", evt.Connectedness)
+			}
+			break
+		}
+	}

 	conns := hosts[2].Network().ConnsToPeer(hosts[0].ID())
 	if len(conns) != 1 {
 		t.Fatalf("expected 1 connection, but got %d", len(conns))
 	}
-	if !conns[0].Stat().Transient {
+	if !conns[0].Stat().Limited {
 		t.Fatal("expected transient connection")
 	}

-	s, err := hosts[2].NewStream(network.WithUseTransient(ctx, "test"), hosts[0].ID(), "test")
+	s, err := hosts[2].NewStream(network.WithAllowLimitedConn(ctx, "test"), hosts[0].ID(), "test")
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -229,11 +253,11 @@ func TestRelayLimitTime(t *testing.T) {
 	if len(conns) != 1 {
 		t.Fatalf("expected 1 connection, but got %d", len(conns))
 	}
-	if !conns[0].Stat().Transient {
+	if !conns[0].Stat().Limited {
 		t.Fatal("expected transient connection")
 	}

-	s, err := hosts[2].NewStream(network.WithUseTransient(ctx, "test"), hosts[0].ID(), "test")
+	s, err := hosts[2].NewStream(network.WithAllowLimitedConn(ctx, "test"), hosts[0].ID(), "test")
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -315,11 +339,11 @@ func TestRelayLimitData(t *testing.T) {
 	if len(conns) != 1 {
 		t.Fatalf("expected 1 connection, but got %d", len(conns))
 	}
-	if !conns[0].Stat().Transient {
+	if !conns[0].Stat().Limited {
 		t.Fatal("expected transient connection")
 	}

-	s, err := hosts[2].NewStream(network.WithUseTransient(ctx, "test"), hosts[0].ID(), "test")
+	s, err := hosts[2].NewStream(network.WithAllowLimitedConn(ctx, "test"), hosts[0].ID(), "test")
 	if err != nil {
 		t.Fatal(err)
 	}
diff --git a/p2p/protocol/holepunch/holepunch_test.go b/p2p/protocol/holepunch/holepunch_test.go
index 1f3d9263df..23593c7970 100644
--- a/p2p/protocol/holepunch/holepunch_test.go
+++ b/p2p/protocol/holepunch/holepunch_test.go
@@ -339,7 +339,7 @@ func TestFailuresOnResponder(t *testing.T) {
 	defer h2.Close()
 	defer relay.Close()

-	s, err := h2.NewStream(network.WithUseTransient(context.Background(), "holepunch"), h1.ID(), holepunch.Protocol)
+	s, err := h2.NewStream(network.WithAllowLimitedConn(context.Background(), "holepunch"), h1.ID(), holepunch.Protocol)
 	require.NoError(t, err)

 	go tc.initiator(s)
diff --git a/p2p/protocol/holepunch/holepuncher.go b/p2p/protocol/holepunch/holepuncher.go
index b651bd7822..479376ef09 100644
--- a/p2p/protocol/holepunch/holepuncher.go
+++ b/p2p/protocol/holepunch/holepuncher.go
@@ -174,7 +174,7 @@ func (hp *holePuncher) directConnect(rp peer.ID) error {
 // initiateHolePunch opens a new hole punching coordination stream,
 // exchanges the addresses and measures the RTT.
 func (hp *holePuncher) initiateHolePunch(rp peer.ID) ([]ma.Multiaddr, []ma.Multiaddr, time.Duration, error) {
-	hpCtx := network.WithUseTransient(hp.ctx, "hole-punch")
+	hpCtx := network.WithAllowLimitedConn(hp.ctx, "hole-punch")
 	sCtx := network.WithNoDial(hpCtx, "hole-punch")

 	str, err := hp.host.NewStream(sCtx, rp, Protocol)
diff --git a/p2p/protocol/identify/id.go b/p2p/protocol/identify/id.go
index 4bb4a59243..7ae4feb935 100644
--- a/p2p/protocol/identify/id.go
+++ b/p2p/protocol/identify/id.go
@@ -408,7 +408,7 @@ func (ids *idService) IdentifyWait(c network.Conn) <-chan struct{} {
 func (ids *idService) identifyConn(c network.Conn) error {
 	ctx, cancel := context.WithTimeout(context.Background(), Timeout)
 	defer cancel()
-	s, err := c.NewStream(network.WithUseTransient(ctx, "identify"))
+	s, err := c.NewStream(network.WithAllowLimitedConn(ctx, "identify"))
 	if err != nil {
 		log.Debugw("error opening identify stream", "peer", c.RemotePeer(), "error", err)
 		return err
@@ -752,7 +752,8 @@ func (ids *idService) consumeMessage(mes *pb.Identify, c network.Conn, isPush bo
 	// Taking the lock ensures that we don't concurrently process a disconnect.
 	ids.addrMu.Lock()
 	ttl := peerstore.RecentlyConnectedAddrTTL
-	if ids.Host.Network().Connectedness(p) == network.Connected {
+	switch ids.Host.Network().Connectedness(p) {
+	case network.Limited, network.Connected:
 		ttl = peerstore.ConnectedAddrTTL
 	}

@@ -980,13 +981,15 @@ func (nn *netNotifiee) Disconnected(_ network.Network, c network.Conn) {
 	delete(ids.conns, c)
 	ids.connsMu.Unlock()

-	if ids.Host.Network().Connectedness(c.RemotePeer()) != network.Connected {
-		// Last disconnect.
-		// Undo the setting of addresses to peer.ConnectedAddrTTL we did
-		ids.addrMu.Lock()
-		defer ids.addrMu.Unlock()
-		ids.Host.Peerstore().UpdateAddrs(c.RemotePeer(), peerstore.ConnectedAddrTTL, peerstore.RecentlyConnectedAddrTTL)
+	switch ids.Host.Network().Connectedness(c.RemotePeer()) {
+	case network.Connected, network.Limited:
+		return
 	}
+	// Last disconnect.
+	// Undo the setting of addresses to peer.ConnectedAddrTTL we did
+	ids.addrMu.Lock()
+	defer ids.addrMu.Unlock()
+	ids.Host.Peerstore().UpdateAddrs(c.RemotePeer(), peerstore.ConnectedAddrTTL, peerstore.RecentlyConnectedAddrTTL)
 }

 func (nn *netNotifiee) Listen(n network.Network, a ma.Multiaddr)      {}
diff --git a/p2p/protocol/ping/ping.go b/p2p/protocol/ping/ping.go
index 1c78229084..9a67715593 100644
--- a/p2p/protocol/ping/ping.go
+++ b/p2p/protocol/ping/ping.go
@@ -111,7 +111,7 @@ func pingError(err error) chan Result {
 // Ping pings the remote peer until the context is canceled, returning a stream
 // of RTTs or errors.
 func Ping(ctx context.Context, h host.Host, p peer.ID) <-chan Result {
-	s, err := h.NewStream(network.WithUseTransient(ctx, "ping"), p, ID)
+	s, err := h.NewStream(network.WithAllowLimitedConn(ctx, "ping"), p, ID)
 	if err != nil {
 		return pingError(err)
 	}
diff --git a/p2p/test/basichost/basic_host_test.go b/p2p/test/basichost/basic_host_test.go
index e6cd7ea9d9..9cd442dbf0 100644
--- a/p2p/test/basichost/basic_host_test.go
+++ b/p2p/test/basichost/basic_host_test.go
@@ -77,7 +77,7 @@ func TestNoStreamOverTransientConnection(t *testing.T) {

 	require.Error(t, err)

-	_, err = h1.NewStream(network.WithUseTransient(context.Background(), "test"), h2.ID(), "/testprotocol")
+	_, err = h1.NewStream(network.WithAllowLimitedConn(context.Background(), "test"), h2.ID(), "/testprotocol")
 	require.NoError(t, err)
 }

diff --git a/p2p/test/swarm/swarm_test.go b/p2p/test/swarm/swarm_test.go
index 8027cebe53..10298f5139 100644
--- a/p2p/test/swarm/swarm_test.go
+++ b/p2p/test/swarm/swarm_test.go
@@ -110,16 +110,16 @@ func TestNewStreamTransientConnection(t *testing.T) {

 	h1.Peerstore().AddAddr(h2.ID(), relayaddr, peerstore.TempAddrTTL)

-	// WithUseTransient should succeed
+	// WithAllowLimitedConn should succeed
 	ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
 	defer cancel()
-	ctx = network.WithUseTransient(ctx, "test")
+	ctx = network.WithAllowLimitedConn(ctx, "test")
 	s, err := h1.Network().NewStream(ctx, h2.ID())
 	require.NoError(t, err)
 	require.NotNil(t, s)
 	defer s.Close()

-	// Without WithUseTransient should fail with context deadline exceeded
+	// Without WithAllowLimitedConn should fail with context deadline exceeded
 	ctx, cancel = context.WithTimeout(context.Background(), 200*time.Millisecond)
 	defer cancel()
 	s, err = h1.Network().NewStream(ctx, h2.ID())
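
For API consumers, the patch reduces to a handful of renames: `network.WithUseTransient` → `network.WithAllowLimitedConn`, `Stat().Transient` → `Stat().Limited`, and `network.ErrTransientConn` → `network.ErrLimitedConn`, plus the new `network.Limited` connectedness state for peers reachable only over a relayed connection. A minimal migration sketch follows; the host `h`, peer `p`, and the protocol ID are illustrative placeholders, not part of this patch:

```go
package main

import (
	"context"
	"errors"
	"fmt"

	"github.com/libp2p/go-libp2p/core/host"
	"github.com/libp2p/go-libp2p/core/network"
	"github.com/libp2p/go-libp2p/core/peer"
)

// openStream opens a stream to p, opting in to a limited (relayed)
// connection when that is all we have.
func openStream(h host.Host, p peer.ID) (network.Stream, error) {
	ctx := context.Background()
	// Connectedness can now report network.Limited in addition to
	// network.Connected / network.NotConnected.
	if h.Network().Connectedness(p) == network.Limited {
		// was: network.WithUseTransient(ctx, "example")
		ctx = network.WithAllowLimitedConn(ctx, "example")
	}
	s, err := h.NewStream(ctx, p, "/my/proto/1.0.0")
	if errors.Is(err, network.ErrLimitedConn) { // was: network.ErrTransientConn
		return nil, fmt.Errorf("only a limited connection to %s: %w", p, err)
	}
	return s, err
}
```

The deprecated names (`WithUseTransient`, `GetUseTransient`, `ErrTransientConn`) remain as aliases of the new ones, so downstream code can migrate incrementally.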