diff --git a/config/config.go b/config/config.go index bbb3914e09..1111ea33b9 100644 --- a/config/config.go +++ b/config/config.go @@ -22,7 +22,7 @@ import ( "github.com/libp2p/go-libp2p/p2p/host/relay" routed "github.com/libp2p/go-libp2p/p2p/host/routed" circuitv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client" - holepunch "github.com/libp2p/go-libp2p/p2p/protocol/holepunch" + "github.com/libp2p/go-libp2p/p2p/protocol/holepunch" autonat "github.com/libp2p/go-libp2p-autonat" blankhost "github.com/libp2p/go-libp2p-blankhost" @@ -201,7 +201,6 @@ func (cfg *Config) NewNode(ctx context.Context) (host.Host, error) { EnableHolePunching: cfg.EnableHolePunching, HolePunchingOptions: cfg.HolePunchingOptions, }) - if err != nil { swrm.Close() return nil, err @@ -251,6 +250,7 @@ func (cfg *Config) NewNode(ctx context.Context) (host.Host, error) { // Note: h.AddrsFactory may be changed by AutoRelay, but non-relay version is // used by AutoNAT below. + var autorelay *relay.AutoRelay addrF := h.AddrsFactory if cfg.EnableAutoRelay { if !cfg.Relay { @@ -259,13 +259,12 @@ func (cfg *Config) NewNode(ctx context.Context) (host.Host, error) { } if len(cfg.StaticRelays) > 0 { - _ = relay.NewAutoRelay(ctx, h, nil, router, cfg.StaticRelays) + autorelay = relay.NewAutoRelay(h, nil, router, cfg.StaticRelays) } else { if router == nil { h.Close() return nil, fmt.Errorf("cannot enable autorelay; no routing for discovery") } - crouter, ok := router.(routing.ContentRouting) if !ok { h.Close() @@ -273,8 +272,7 @@ func (cfg *Config) NewNode(ctx context.Context) (host.Host, error) { } discovery := discovery.NewRoutingDiscovery(crouter) - - _ = relay.NewAutoRelay(ctx, h, discovery, router, cfg.StaticRelays) + autorelay = relay.NewAutoRelay(h, discovery, router, cfg.StaticRelays) } } @@ -341,10 +339,25 @@ func (cfg *Config) NewNode(ctx context.Context) (host.Host, error) { // start the host background tasks h.Start() + var ho host.Host + ho = h if router != nil { - return routed.Wrap(h, router), nil + ho = routed.Wrap(h, router) + } + if autorelay != nil { + return &autoRelayHost{Host: ho, autoRelay: autorelay}, nil } - return h, nil + return ho, nil +} + +type autoRelayHost struct { + host.Host + autoRelay *relay.AutoRelay +} + +func (h *autoRelayHost) Close() error { + _ = h.autoRelay.Close() + return h.Host.Close() } // Option is a libp2p config option that can be given to the libp2p constructor diff --git a/p2p/host/relay/autorelay.go b/p2p/host/relay/autorelay.go index 11c1a3c9db..9bd9ffbb5c 100644 --- a/p2p/host/relay/autorelay.go +++ b/p2p/host/relay/autorelay.go @@ -49,6 +49,9 @@ type AutoRelay struct { static []peer.AddrInfo + refCount sync.WaitGroup + ctxCancel context.CancelFunc + disconnect chan struct{} mx sync.Mutex @@ -59,8 +62,10 @@ type AutoRelay struct { cachedAddrsExpiry time.Time } -func NewAutoRelay(ctx context.Context, bhost *basic.BasicHost, discover discovery.Discoverer, router routing.PeerRouting, static []peer.AddrInfo) *AutoRelay { +func NewAutoRelay(bhost *basic.BasicHost, discover discovery.Discoverer, router routing.PeerRouting, static []peer.AddrInfo) *AutoRelay { + ctx, cancel := context.WithCancel(context.Background()) ar := &AutoRelay{ + ctxCancel: cancel, host: bhost, discover: discover, router: router, @@ -72,6 +77,7 @@ func NewAutoRelay(ctx context.Context, bhost *basic.BasicHost, discover discover } bhost.AddrsFactory = ar.hostAddrs bhost.Network().Notify(ar) + ar.refCount.Add(1) go ar.background(ctx) return ar } @@ -81,6 +87,8 @@ func (ar *AutoRelay) hostAddrs(addrs []ma.Multiaddr) []ma.Multiaddr { } func (ar *AutoRelay) background(ctx context.Context) { + defer ar.refCount.Done() + subReachability, _ := ar.host.EventBus().Subscribe(new(event.EvtLocalReachabilityChanged)) defer subReachability.Close() @@ -318,6 +326,12 @@ func (ar *AutoRelay) relayAddrs(addrs []ma.Multiaddr) []ma.Multiaddr { return raddrs } +func (ar *AutoRelay) Close() error { + ar.ctxCancel() + ar.refCount.Wait() + return nil +} + func shuffleRelays(pis []peer.AddrInfo) { for i := range pis { j := rand.Intn(i + 1)