From 9e4407103eaf2716f49182f826b459b25ea87fb3 Mon Sep 17 00:00:00 2001 From: sukun Date: Sun, 7 May 2023 17:52:29 +0530 Subject: [PATCH] implement smart dialing --- core/network/network.go | 10 + p2p/net/swarm/clock.go | 49 +++ p2p/net/swarm/dial_ranker.go | 131 ++++++++ p2p/net/swarm/dial_ranker_test.go | 270 +++++++++++++++++ p2p/net/swarm/dial_worker.go | 233 +++++++++++---- p2p/net/swarm/dial_worker_test.go | 477 +++++++++++++++++++++++++++++- p2p/net/swarm/swarm.go | 20 ++ p2p/net/swarm/swarm_dial.go | 9 +- 8 files changed, 1117 insertions(+), 82 deletions(-) create mode 100644 p2p/net/swarm/clock.go create mode 100644 p2p/net/swarm/dial_ranker.go create mode 100644 p2p/net/swarm/dial_ranker_test.go diff --git a/core/network/network.go b/core/network/network.go index 0beaac0f71..215b5373b3 100644 --- a/core/network/network.go +++ b/core/network/network.go @@ -184,3 +184,13 @@ type Dialer interface { Notify(Notifiee) StopNotify(Notifiee) } + +// AddrDelay provides an address along with the delay after which the address +// should be dialed +type AddrDelay struct { + Addr ma.Multiaddr + Delay time.Duration +} + +// DialRanker provides a schedule of dialing the provided addresses +type DialRanker func([]ma.Multiaddr) []AddrDelay diff --git a/p2p/net/swarm/clock.go b/p2p/net/swarm/clock.go new file mode 100644 index 0000000000..6b63ac9c87 --- /dev/null +++ b/p2p/net/swarm/clock.go @@ -0,0 +1,49 @@ +package swarm + +import "time" + +// InstantTimer is a timer that triggers at some instant rather than some duration +type InstantTimer interface { + Reset(d time.Time) bool + Stop() bool + Ch() <-chan time.Time +} + +// Clock is a clock that can create timers that trigger at some +// instant rather than some duration +type Clock interface { + Now() time.Time + Since(t time.Time) time.Duration + InstantTimer(when time.Time) InstantTimer +} + +type RealTimer struct{ t *time.Timer } + +var _ InstantTimer = (*RealTimer)(nil) + +func (t RealTimer) Ch() <-chan time.Time { + return t.t.C +} + +func (t RealTimer) Reset(d time.Time) bool { + return t.t.Reset(time.Until(d)) +} + +func (t RealTimer) Stop() bool { + return t.t.Stop() +} + +type RealClock struct{} + +var _ Clock = RealClock{} + +func (RealClock) Now() time.Time { + return time.Now() +} +func (RealClock) Since(t time.Time) time.Duration { + return time.Since(t) +} +func (RealClock) InstantTimer(when time.Time) InstantTimer { + t := time.NewTimer(time.Until(when)) + return &RealTimer{t} +} diff --git a/p2p/net/swarm/dial_ranker.go b/p2p/net/swarm/dial_ranker.go new file mode 100644 index 0000000000..321781fdf8 --- /dev/null +++ b/p2p/net/swarm/dial_ranker.go @@ -0,0 +1,131 @@ +package swarm + +import ( + "time" + + "github.com/libp2p/go-libp2p/core/network" + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" +) + +const ( + publicTCPDelay = 300 * time.Millisecond + privateTCPDelay = 30 * time.Millisecond + relayDelay = 500 * time.Millisecond +) + +func noDelayRanker(addrs []ma.Multiaddr) []network.AddrDelay { + res := make([]network.AddrDelay, len(addrs)) + for i, a := range addrs { + res[i] = network.AddrDelay{Addr: a, Delay: 0} + } + return res +} + +// defaultDialRanker is the default ranking logic. +// +// we consider private, public ip4, public ip6, relay addresses separately. +// +// In each group, if a quic address is present, we delay tcp addresses. +// +// private: 30 ms delay. +// public ip4: 300 ms delay. +// public ip6: 300 ms delay. 
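+// relay: 500 ms delay (relayDelay) added on top of the per-transport delays whenever direct (public ip4/ip6) addresses are also present.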
+// +// If a quic-v1 address is present we don't dial quic or webtransport address on the same (ip,port) combination. +// If a tcp address is present we don't dial ws or wss address on the same (ip, port) combination. +// If direct addresses are present we delay all relay addresses by 500 millisecond +func defaultDialRanker(addrs []ma.Multiaddr) []network.AddrDelay { + ip4 := make([]ma.Multiaddr, 0, len(addrs)) + ip6 := make([]ma.Multiaddr, 0, len(addrs)) + pvt := make([]ma.Multiaddr, 0, len(addrs)) + relay := make([]ma.Multiaddr, 0, len(addrs)) + + res := make([]network.AddrDelay, 0, len(addrs)) + for _, a := range addrs { + switch { + case !manet.IsPublicAddr(a): + pvt = append(pvt, a) + case isRelayAddr(a): + relay = append(relay, a) + case isProtocolAddr(a, ma.P_IP4): + ip4 = append(ip4, a) + case isProtocolAddr(a, ma.P_IP6): + ip6 = append(ip6, a) + default: + res = append(res, network.AddrDelay{Addr: a, Delay: 0}) + } + } + var roffset time.Duration = 0 + if len(ip4) > 0 || len(ip6) > 0 { + roffset = relayDelay + } + + res = append(res, getAddrDelay(pvt, privateTCPDelay, 0)...) + res = append(res, getAddrDelay(ip4, publicTCPDelay, 0)...) + res = append(res, getAddrDelay(ip6, publicTCPDelay, 0)...) + res = append(res, getAddrDelay(relay, publicTCPDelay, roffset)...) + return res +} + +func getAddrDelay(addrs []ma.Multiaddr, tcpDelay time.Duration, offset time.Duration) []network.AddrDelay { + var hasQuic, hasQuicV1 bool + quicV1Addr := make(map[string]struct{}) + tcpAddr := make(map[string]struct{}) + for _, a := range addrs { + switch { + case isProtocolAddr(a, ma.P_WEBTRANSPORT): + case isProtocolAddr(a, ma.P_QUIC): + hasQuic = true + case isProtocolAddr(a, ma.P_QUIC_V1): + hasQuicV1 = true + quicV1Addr[addrPort(a, ma.P_UDP)] = struct{}{} + case isProtocolAddr(a, ma.P_WS) || isProtocolAddr(a, ma.P_WSS): + case isProtocolAddr(a, ma.P_TCP): + tcpAddr[addrPort(a, ma.P_TCP)] = struct{}{} + } + } + + res := make([]network.AddrDelay, 0, len(addrs)) + for _, a := range addrs { + delay := offset + switch { + case isProtocolAddr(a, ma.P_WEBTRANSPORT): + if hasQuicV1 { + if _, ok := quicV1Addr[addrPort(a, ma.P_UDP)]; ok { + continue + } + } + case isProtocolAddr(a, ma.P_QUIC): + if hasQuicV1 { + if _, ok := quicV1Addr[addrPort(a, ma.P_UDP)]; ok { + continue + } + } + case isProtocolAddr(a, ma.P_WS) || isProtocolAddr(a, ma.P_WSS): + if _, ok := tcpAddr[addrPort(a, ma.P_TCP)]; ok { + continue + } + if hasQuic || hasQuicV1 { + delay += tcpDelay + } + case isProtocolAddr(a, ma.P_TCP): + if hasQuic || hasQuicV1 { + delay += tcpDelay + } + } + res = append(res, network.AddrDelay{Addr: a, Delay: delay}) + } + return res +} + +func addrPort(a ma.Multiaddr, p int) string { + c, _ := ma.SplitFirst(a) + port, _ := a.ValueForProtocol(p) + return c.Value() + ":" + port +} + +func isProtocolAddr(a ma.Multiaddr, p int) bool { + _, err := a.ValueForProtocol(p) + return err == nil +} diff --git a/p2p/net/swarm/dial_ranker_test.go b/p2p/net/swarm/dial_ranker_test.go new file mode 100644 index 0000000000..8edf8c2216 --- /dev/null +++ b/p2p/net/swarm/dial_ranker_test.go @@ -0,0 +1,270 @@ +package swarm + +import ( + "fmt" + "sort" + "testing" + + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/test" + ma "github.com/multiformats/go-multiaddr" +) + +func TestNoDelayRanker(t *testing.T) { + addrs := []ma.Multiaddr{ + ma.StringCast("/ip4/1.2.3.4/tcp/1"), + ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1"), + } + addrDelays := noDelayRanker(addrs) + if len(addrs) != len(addrDelays) { + 
t.Errorf("addrDelay should have the same number of elements as addr") + } + + for _, a := range addrs { + for _, ad := range addrDelays { + if a.Equal(ad.Addr) { + if ad.Delay != 0 { + t.Errorf("expected 0 delay, got %s", ad.Delay) + } + } + } + } +} + +func TestDelayRankerTCPDelay(t *testing.T) { + pquicv1 := ma.StringCast("/ip4/192.168.0.100/udp/1/quic-v1") + ptcp := ma.StringCast("/ip4/192.168.0.100/tcp/1/") + + quic := ma.StringCast("/ip4/1.2.3.4/udp/1/quic") + quicv1 := ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1") + tcp := ma.StringCast("/ip4/1.2.3.5/tcp/1/") + + tcp6 := ma.StringCast("/ip6/1::1/tcp/1") + quicv16 := ma.StringCast("/ip6/1::2/udp/1/quic-v1") + + testCase := []struct { + name string + addrs []ma.Multiaddr + output []network.AddrDelay + }{ + { + name: "quic prioritised over tcp", + addrs: []ma.Multiaddr{quic, tcp}, + output: []network.AddrDelay{ + {Addr: quic, Delay: 0}, + {Addr: tcp, Delay: publicTCPDelay}, + }, + }, + { + name: "quic-v1 prioritised over tcp", + addrs: []ma.Multiaddr{quicv1, tcp}, + output: []network.AddrDelay{ + {Addr: quicv1, Delay: 0}, + {Addr: tcp, Delay: publicTCPDelay}, + }, + }, + { + name: "ip6 treated separately", + addrs: []ma.Multiaddr{quicv16, tcp6, quic}, + output: []network.AddrDelay{ + {Addr: quicv16, Delay: 0}, + {Addr: quic, Delay: 0}, + {Addr: tcp6, Delay: publicTCPDelay}, + }, + }, + { + name: "private addrs treated separately", + addrs: []ma.Multiaddr{pquicv1, ptcp}, + output: []network.AddrDelay{ + {Addr: pquicv1, Delay: 0}, + {Addr: ptcp, Delay: privateTCPDelay}, + }, + }, + } + for _, tc := range testCase { + t.Run(tc.name, func(t *testing.T) { + res := defaultDialRanker(tc.addrs) + if len(res) != len(tc.output) { + for _, a := range res { + log.Errorf("%v", a) + } + for _, a := range tc.output { + log.Errorf("%v", a) + } + t.Errorf("expected elems: %d got: %d", len(tc.output), len(res)) + } + sort.Slice(res, func(i, j int) bool { + if res[i].Delay == res[j].Delay { + return res[i].Addr.String() < res[j].Addr.String() + } + return res[i].Delay < res[j].Delay + }) + sort.Slice(tc.output, func(i, j int) bool { + if tc.output[i].Delay == tc.output[j].Delay { + return tc.output[i].Addr.String() < tc.output[j].Addr.String() + } + return tc.output[i].Delay < tc.output[j].Delay + }) + for i := 0; i < len(tc.output); i++ { + if !tc.output[i].Addr.Equal(res[i].Addr) || tc.output[i].Delay != res[i].Delay { + t.Errorf("expected %+v got %+v", tc.output[i], res[i]) + } + } + }) + } +} + +func TestDelayRankerAddrDropped(t *testing.T) { + pquic := ma.StringCast("/ip4/192.168.0.100/udp/1/quic") + pquicv1 := ma.StringCast("/ip4/192.168.0.100/udp/1/quic-v1") + + quicAddr := ma.StringCast("/ip4/1.2.3.4/udp/1/quic") + quicAddr2 := ma.StringCast("/ip4/1.2.3.4/udp/2/quic") + quicv1Addr := ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1") + wt := ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1/webtransport/") + wt2 := ma.StringCast("/ip4/1.2.3.4/udp/2/quic-v1/webtransport/") + + quic6 := ma.StringCast("/ip6/1::1/udp/1/quic") + quicv16 := ma.StringCast("/ip6/1::1/udp/1/quic-v1") + + tcp := ma.StringCast("/ip4/1.2.3.5/tcp/1/") + ws := ma.StringCast("/ip4/1.2.3.5/tcp/1/ws") + ws2 := ma.StringCast("/ip4/1.2.3.4/tcp/1/ws") + wss := ma.StringCast("/ip4/1.2.3.5/tcp/1/wss") + + testCase := []struct { + name string + addrs []ma.Multiaddr + output []network.AddrDelay + }{ + { + name: "quic dropped when quic-v1 present", + addrs: []ma.Multiaddr{quicAddr, quicv1Addr, quicAddr2}, + output: []network.AddrDelay{ + {Addr: quicv1Addr, Delay: 0}, + {Addr: quicAddr2, Delay: 0}, + }, + 
}, + { + name: "webtransport dropped when quicv1 present", + addrs: []ma.Multiaddr{quicv1Addr, wt, wt2, quicAddr}, + output: []network.AddrDelay{ + {Addr: quicv1Addr, Delay: 0}, + {Addr: wt2, Delay: 0}, + }, + }, + { + name: "ip6 quic dropped when quicv1 present", + addrs: []ma.Multiaddr{quicv16, quic6}, + output: []network.AddrDelay{ + {Addr: quicv16, Delay: 0}, + }, + }, + { + name: "web socket removed when tcp present", + addrs: []ma.Multiaddr{quicAddr, tcp, ws, wss, ws2}, + output: []network.AddrDelay{ + {Addr: quicAddr, Delay: 0}, + {Addr: tcp, Delay: publicTCPDelay}, + {Addr: ws2, Delay: publicTCPDelay}, + }, + }, + { + name: "private quic dropped when quiv1 present", + addrs: []ma.Multiaddr{pquic, pquicv1}, + output: []network.AddrDelay{ + {Addr: pquicv1, Delay: 0}, + }, + }, + } + for _, tc := range testCase { + t.Run(tc.name, func(t *testing.T) { + res := defaultDialRanker(tc.addrs) + if len(res) != len(tc.output) { + for _, a := range res { + log.Errorf("%v", a) + } + for _, a := range tc.output { + log.Errorf("%v", a) + } + t.Errorf("expected elems: %d got: %d", len(tc.output), len(res)) + } + sort.Slice(res, func(i, j int) bool { + if res[i].Delay == res[j].Delay { + return res[i].Addr.String() < res[j].Addr.String() + } + return res[i].Delay < res[j].Delay + }) + sort.Slice(tc.output, func(i, j int) bool { + if tc.output[i].Delay == tc.output[j].Delay { + return tc.output[i].Addr.String() < tc.output[j].Addr.String() + } + return tc.output[i].Delay < tc.output[j].Delay + }) + + for i := 0; i < len(tc.output); i++ { + if !tc.output[i].Addr.Equal(res[i].Addr) || tc.output[i].Delay != res[i].Delay { + t.Errorf("expected %+v got %+v", tc.output[i], res[i]) + } + } + }) + } +} + +func TestDelayRankerRelay(t *testing.T) { + quicAddr := ma.StringCast("/ip4/1.2.3.4/udp/1/quic") + quicAddr2 := ma.StringCast("/ip4/1.2.3.4/udp/2/quic") + + pid := test.RandPeerIDFatal(t) + r1 := ma.StringCast(fmt.Sprintf("/ip4/1.2.3.4/tcp/1/p2p-circuit/p2p/%s", pid)) + r2 := ma.StringCast(fmt.Sprintf("/ip4/1.2.3.4/udp/1/quic/p2p-circuit/p2p/%s", pid)) + + testCase := []struct { + name string + addrs []ma.Multiaddr + output []network.AddrDelay + }{ + { + name: "relay address delayed", + addrs: []ma.Multiaddr{quicAddr, quicAddr2, r1, r2}, + output: []network.AddrDelay{ + {Addr: quicAddr, Delay: 0}, + {Addr: quicAddr2, Delay: 0}, + {Addr: r2, Delay: relayDelay}, + {Addr: r1, Delay: publicTCPDelay + relayDelay}, + }, + }, + } + for _, tc := range testCase { + t.Run(tc.name, func(t *testing.T) { + res := defaultDialRanker(tc.addrs) + if len(res) != len(tc.output) { + for _, a := range res { + log.Errorf("%v", a) + } + for _, a := range tc.output { + log.Errorf("%v", a) + } + t.Errorf("expected elems: %d got: %d", len(tc.output), len(res)) + } + sort.Slice(res, func(i, j int) bool { + if res[i].Delay == res[j].Delay { + return res[i].Addr.String() < res[j].Addr.String() + } + return res[i].Delay < res[j].Delay + }) + sort.Slice(tc.output, func(i, j int) bool { + if tc.output[i].Delay == tc.output[j].Delay { + return tc.output[i].Addr.String() < tc.output[j].Addr.String() + } + return tc.output[i].Delay < tc.output[j].Delay + }) + + for i := 0; i < len(tc.output); i++ { + if !tc.output[i].Addr.Equal(res[i].Addr) || tc.output[i].Delay != res[i].Delay { + t.Errorf("expected %+v got %+v", tc.output[i], res[i]) + } + } + }) + } +} diff --git a/p2p/net/swarm/dial_worker.go b/p2p/net/swarm/dial_worker.go index f805371cc6..cdc0013427 100644 --- a/p2p/net/swarm/dial_worker.go +++ b/p2p/net/swarm/dial_worker.go @@ -2,13 
+2,14 @@ package swarm import ( "context" + "math" "sync" + "time" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" ma "github.com/multiformats/go-multiaddr" - manet "github.com/multiformats/go-multiaddr/net" ) // ///////////////////////////////////////////////////////////////////////////////// @@ -38,6 +39,7 @@ type addrDial struct { conn *Conn err error requests []int + dialed bool } type dialWorker struct { @@ -51,16 +53,15 @@ type dialWorker struct { connected bool // true when a connection has been successfully established - nextDial []ma.Multiaddr - - // ready when we have more addresses to dial (nextDial is not empty) - triggerDial <-chan struct{} - // for testing wg sync.WaitGroup + cl Clock } -func newDialWorker(s *Swarm, p peer.ID, reqch <-chan dialRequest) *dialWorker { +func newDialWorker(s *Swarm, p peer.ID, reqch <-chan dialRequest, cl Clock) *dialWorker { + if cl == nil { + cl = RealClock{} + } return &dialWorker{ s: s, peer: p, @@ -68,6 +69,7 @@ func newDialWorker(s *Swarm, p peer.ID, reqch <-chan dialRequest) *dialWorker { requests: make(map[int]*pendRequest), pending: make(map[ma.Multiaddr]*addrDial), resch: make(chan dialResult), + cl: cl, } } @@ -79,7 +81,25 @@ func (w *dialWorker) loop() { // used to signal readiness to dial and completion of the dial ready := make(chan struct{}) close(ready) - + dq := newDialQueue() + currDials := 0 + st := w.cl.Now() + timer := w.cl.InstantTimer(st.Add(math.MaxInt64)) + timerRunning := true + scheduleNext := func() { + if timerRunning && !timer.Stop() { + <-timer.Ch() + } + timerRunning = false + if dq.len() > 0 { + if currDials == 0 { + timer.Reset(st) + } else { + timer.Reset(st.Add(dq.top().Delay)) + } + timerRunning = true + } + } loop: for { select { @@ -102,7 +122,11 @@ loop: // at this point, len(addrs) > 0 or else it would be error from addrsForDial // ranke them to process in order - addrs = w.rankAddrs(addrs) + // at this point, len(addrs) > 0 or else it would be error from addrsForDial + // ranke them to process in order + simConnect, _, _ := network.GetSimultaneousConnect(req.ctx) + addrRanking := w.rankAddrs(addrs, simConnect) + addrDelay := make(map[ma.Multiaddr]time.Duration) // create the pending request object pr := &pendRequest{ @@ -110,8 +134,9 @@ loop: err: &DialError{Peer: w.peer}, addrs: make(map[ma.Multiaddr]struct{}), } - for _, a := range addrs { - pr.addrs[a] = struct{}{} + for _, adelay := range addrRanking { + pr.addrs[adelay.Addr] = struct{}{} + addrDelay[adelay.Addr] = adelay.Delay } // check if any of the addrs has been successfully dialed and accumulate @@ -154,9 +179,12 @@ loop: w.requests[w.reqno] = pr for _, ad := range tojoin { - if simConnect, isClient, reason := network.GetSimultaneousConnect(req.ctx); simConnect { - if simConnect, _, _ := network.GetSimultaneousConnect(ad.ctx); !simConnect { - ad.ctx = network.WithSimultaneousConnect(ad.ctx, isClient, reason) + if !ad.dialed { + if simConnect, isClient, reason := network.GetSimultaneousConnect(req.ctx); simConnect { + if simConnect, _, _ := network.GetSimultaneousConnect(ad.ctx); !simConnect { + ad.ctx = network.WithSimultaneousConnect(ad.ctx, isClient, reason) + dq.add(network.AddrDelay{Addr: ad.addr, Delay: addrDelay[ad.addr]}) + } } } ad.requests = append(ad.requests, w.reqno) @@ -165,33 +193,31 @@ loop: if len(todial) > 0 { for _, a := range todial { w.pending[a] = &addrDial{addr: a, ctx: req.ctx, requests: []int{w.reqno}} + dq.add(network.AddrDelay{Addr: a, Delay: addrDelay[a]}) } - - w.nextDial = 
append(w.nextDial, todial...) - w.nextDial = w.rankAddrs(w.nextDial) - - // trigger a new dial now to account for the new addrs we added - w.triggerDial = ready } + scheduleNext() - case <-w.triggerDial: - for _, addr := range w.nextDial { + case <-timer.Ch(): + for _, adelay := range dq.nextBatch() { // spawn the dial - ad := w.pending[addr] - err := w.s.dialNextAddr(ad.ctx, w.peer, addr, w.resch) + ad := w.pending[adelay.Addr] + ad.dialed = true + err := w.s.dialNextAddr(ad.ctx, w.peer, ad.addr, w.resch) if err != nil { w.dispatchError(ad, err) + } else { + currDials++ } } - - w.nextDial = nil - w.triggerDial = nil + timerRunning = false + scheduleNext() case res := <-w.resch: if res.Conn != nil { w.connected = true } - + currDials-- ad := w.pending[res.Addr] if res.Conn != nil { @@ -228,8 +254,8 @@ loop: // for consistency with the old dialer behavior. w.s.backf.AddBackoff(w.peer, res.Addr) } - w.dispatchError(ad, res.Err) + scheduleNext() } } } @@ -275,39 +301,130 @@ func (w *dialWorker) dispatchError(ad *addrDial, err error) { } } -// ranks addresses in descending order of preference for dialing, with the following rules: -// NonRelay > Relay -// NonWS > WS -// Private > Public -// UDP > TCP -func (w *dialWorker) rankAddrs(addrs []ma.Multiaddr) []ma.Multiaddr { - addrTier := func(a ma.Multiaddr) (tier int) { - if isRelayAddr(a) { - tier |= 0b1000 - } - if isExpensiveAddr(a) { - tier |= 0b0100 - } - if !manet.IsPrivateAddr(a) { - tier |= 0b0010 - } - if isFdConsumingAddr(a) { - tier |= 0b0001 - } - - return tier +func (w *dialWorker) rankAddrs(addrs []ma.Multiaddr, isSimConnect bool) []network.AddrDelay { + if isSimConnect { + return noDelayRanker(addrs) } + return w.s.dialRanker(addrs) +} + +type dialQueue struct { + q []network.AddrDelay + pos map[ma.Multiaddr]int +} + +func newDialQueue() *dialQueue { + return &dialQueue{pos: make(map[ma.Multiaddr]int)} +} + +func (dq *dialQueue) add(adelay network.AddrDelay) { + dq.remove(adelay.Addr) + dq.q = append(dq.q, adelay) + dq.pos[adelay.Addr] = len(dq.q) - 1 + dq.heapify(len(dq.q) - 1) +} + +func (dq *dialQueue) swap(i, j int) { + dq.pos[dq.q[i].Addr] = j + dq.pos[dq.q[j].Addr] = i + dq.q[i], dq.q[j] = dq.q[j], dq.q[i] +} - tiers := make([][]ma.Multiaddr, 16) - for _, a := range addrs { - tier := addrTier(a) - tiers[tier] = append(tiers[tier], a) +func (dq *dialQueue) len() int { + return len(dq.q) +} + +func (dq *dialQueue) top() network.AddrDelay { + return dq.q[0] +} + +func (dq *dialQueue) pop() network.AddrDelay { + v := dq.q[0] + dq.remove(v.Addr) + return v +} + +func (dq *dialQueue) remove(a ma.Multiaddr) { + pos, ok := dq.pos[a] + if !ok { + return } + dq.swap(pos, len(dq.q)-1) + dq.q = dq.q[:len(dq.q)-1] + delete(dq.pos, a) + dq.heapify(pos) +} - result := make([]ma.Multiaddr, 0, len(addrs)) - for _, tier := range tiers { - result = append(result, tier...) 
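+// heapify restores the min-heap invariant (ordered by Delay) around index i
+// after an insert or removal, sifting the element up towards the root or down
+// towards the leaves as needed.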
+func (dq *dialQueue) heapify(i int) { + if dq.len() == 0 { + return + } + for { + v := dq.q[i].Delay + l, r := 2*i+1, 2*i+2 + if l >= dq.len() && r >= dq.len() { + if i == 0 { + return + } + i = (i - 1) / 2 + continue + } + lv := dq.q[l].Delay + if v <= lv { + if r < dq.len() { + rv := dq.q[r].Delay + if v <= rv { + if i == 0 { + return + } + i = (i - 1) / 2 + continue + } else { + dq.swap(i, r) + i = r + continue + } + } else { + if i == 0 { + return + } + i = (i - 1) / 2 + continue + } + } else { + if r < dq.len() { + rv := dq.q[r].Delay + if lv <= rv { + dq.swap(i, l) + i = l + continue + } else { + dq.swap(i, r) + i = r + continue + } + } else { + dq.swap(i, l) + i = l + continue + } + } } +} - return result +func (dq *dialQueue) nextBatch() []network.AddrDelay { + if dq.len() == 0 { + return nil + } + res := make([]network.AddrDelay, 0) + top := dq.top() + for len(dq.q) > 0 { + v := dq.pop() + if v.Delay != top.Delay { + dq.add(v) + break + } + res = append(res, v) + } + return res } diff --git a/p2p/net/swarm/dial_worker_test.go b/p2p/net/swarm/dial_worker_test.go index 2c441106b1..cb8fb1b833 100644 --- a/p2p/net/swarm/dial_worker_test.go +++ b/p2p/net/swarm/dial_worker_test.go @@ -5,15 +5,20 @@ import ( "crypto/rand" "errors" "fmt" + "math" + mrand "math/rand" + "sort" "sync" "testing" "time" "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peerstore" "github.com/libp2p/go-libp2p/core/sec" "github.com/libp2p/go-libp2p/core/sec/insecure" + "github.com/libp2p/go-libp2p/core/test" "github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/p2p/host/eventbus" "github.com/libp2p/go-libp2p/p2p/host/peerstore/pstoremem" @@ -24,9 +29,22 @@ import ( "github.com/libp2p/go-libp2p/p2p/transport/tcp" ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" "github.com/stretchr/testify/require" ) +type mockClock struct { + *test.MockClock +} + +func (m *mockClock) InstantTimer(when time.Time) InstantTimer { + return m.MockClock.InstantTimer(when) +} + +func newMockClock() *mockClock { + return &mockClock{test.NewMockClock()} +} + func newPeer(t *testing.T) (crypto.PrivKey, peer.ID) { priv, _, err := crypto.GenerateEd25519Key(rand.Reader) require.NoError(t, err) @@ -36,6 +54,19 @@ func newPeer(t *testing.T) (crypto.PrivKey, peer.ID) { } func makeSwarm(t *testing.T) *Swarm { + s := makeSwarmWithNoListenAddrs(t, WithDialTimeout(1*time.Second)) + if err := s.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0")); err != nil { + t.Fatal(err) + } + + if err := s.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic")); err != nil { + t.Fatal(err) + } + + return s +} + +func makeSwarmWithNoListenAddrs(t *testing.T, opts ...Option) *Swarm { priv, id := newPeer(t) ps, err := pstoremem.NewPeerstore() @@ -44,11 +75,10 @@ func makeSwarm(t *testing.T) *Swarm { ps.AddPrivKey(id, priv) t.Cleanup(func() { ps.Close() }) - s, err := NewSwarm(id, ps, eventbus.NewBus(), WithDialTimeout(time.Second)) + s, err := NewSwarm(id, ps, eventbus.NewBus(), opts...) require.NoError(t, err) upgrader := makeUpgrader(t, s) - var tcpOpts []tcp.Option tcpOpts = append(tcpOpts, tcp.DisableReuseport()) tcpTransport, err := tcp.NewTCPTransport(upgrader, nil, tcpOpts...) 
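The Clock and InstantTimer indirection lets these tests drive the dial worker deterministically instead of sleeping. A rough usage sketch (the mock's behavior of firing any timer whose deadline has been passed is inferred from how the tests below use it):

	cl := newMockClock()
	worker := newDialWorker(s1, s2.LocalPeer(), reqch, cl)
	go worker.loop()
	// Advance virtual time: any InstantTimer due within the next 300 ms fires.
	cl.AdvanceBy(publicTCPDelay)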
@@ -56,10 +86,6 @@ func makeSwarm(t *testing.T) *Swarm { if err := s.AddTransport(tcpTransport); err != nil { t.Fatal(err) } - if err := s.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0")); err != nil { - t.Fatal(err) - } - reuse, err := quicreuse.NewConnManager([32]byte{}) if err != nil { t.Fatal(err) @@ -71,10 +97,6 @@ func makeSwarm(t *testing.T) *Swarm { if err := s.AddTransport(quicTransport); err != nil { t.Fatal(err) } - if err := s.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic")); err != nil { - t.Fatal(err) - } - return s } @@ -88,6 +110,47 @@ func makeUpgrader(t *testing.T, n *Swarm) transport.Upgrader { return u } +func acceptAndIgnoreTCP(t *testing.T, a ma.Multiaddr) manet.Listener { + t.Helper() + list, err := manet.Listen(a) + if err != nil { + t.Error(err) + } + go func() { + for { + _, err := list.Accept() + if err != nil { + break + } + } + }() + return list +} + +func makeTCPListener(t *testing.T, a ma.Multiaddr) (manet.Listener, chan struct{}) { + t.Helper() + list, err := manet.Listen(a) + if err != nil { + t.Error(err) + } + ch := make(chan struct{}) + go func() { + for { + c, err := list.Accept() + if err != nil { + break + } + <-ch + err = c.Close() + if err != nil { + t.Error(err) + } + + } + }() + return list, ch +} + func TestDialWorkerLoopBasic(t *testing.T) { s1 := makeSwarm(t) s2 := makeSwarm(t) @@ -99,7 +162,7 @@ func TestDialWorkerLoopBasic(t *testing.T) { reqch := make(chan dialRequest) resch := make(chan dialResponse) - worker := newDialWorker(s1, s2.LocalPeer(), reqch) + worker := newDialWorker(s1, s2.LocalPeer(), reqch, nil) go worker.loop() var conn *Conn @@ -144,7 +207,7 @@ func TestDialWorkerLoopConcurrent(t *testing.T) { s1.Peerstore().AddAddrs(s2.LocalPeer(), s2.ListenAddresses(), peerstore.PermanentAddrTTL) reqch := make(chan dialRequest) - worker := newDialWorker(s1, s2.LocalPeer(), reqch) + worker := newDialWorker(s1, s2.LocalPeer(), reqch, nil) go worker.loop() const dials = 100 @@ -187,7 +250,7 @@ func TestDialWorkerLoopFailure(t *testing.T) { reqch := make(chan dialRequest) resch := make(chan dialResponse) - worker := newDialWorker(s1, p2, reqch) + worker := newDialWorker(s1, p2, reqch, nil) go worker.loop() reqch <- dialRequest{ctx: context.Background(), resch: resch} @@ -211,7 +274,7 @@ func TestDialWorkerLoopConcurrentFailure(t *testing.T) { s1.Peerstore().AddAddrs(p2, []ma.Multiaddr{ma.StringCast("/ip4/11.0.0.1/tcp/1234"), ma.StringCast("/ip4/11.0.0.1/udp/1234/quic")}, peerstore.PermanentAddrTTL) reqch := make(chan dialRequest) - worker := newDialWorker(s1, p2, reqch) + worker := newDialWorker(s1, p2, reqch, nil) go worker.loop() const dials = 100 @@ -259,7 +322,7 @@ func TestDialWorkerLoopConcurrentMix(t *testing.T) { s1.Peerstore().AddAddrs(s2.LocalPeer(), []ma.Multiaddr{ma.StringCast("/ip4/11.0.0.1/tcp/1234"), ma.StringCast("/ip4/11.0.0.1/udp/1234/quic")}, peerstore.PermanentAddrTTL) reqch := make(chan dialRequest) - worker := newDialWorker(s1, s2.LocalPeer(), reqch) + worker := newDialWorker(s1, s2.LocalPeer(), reqch, nil) go worker.loop() const dials = 100 @@ -305,7 +368,7 @@ func TestDialWorkerLoopConcurrentFailureStress(t *testing.T) { s1.Peerstore().AddAddrs(p2, addrs, peerstore.PermanentAddrTTL) reqch := make(chan dialRequest) - worker := newDialWorker(s1, p2, reqch) + worker := newDialWorker(s1, p2, reqch, nil) go worker.loop() const dials = 100 @@ -342,3 +405,385 @@ func TestDialWorkerLoopConcurrentFailureStress(t *testing.T) { close(reqch) worker.wg.Wait() } + +// To support failafter. 
+// need to increment clock slowly +// need to track what's handled and what's new +// need to move slowly through the tests +// the bigger issue is dialing too slowly. +// this interface is nice, but the bigger problem is the dialqueue. We need to maintain a queue +// + +type schedTest struct { + addr ma.Multiaddr + delay time.Duration + success bool + failAfter time.Duration +} + +type dialState struct { + ch chan struct{} + at time.Time + addr ma.Multiaddr + delay time.Duration + success bool + failAfter time.Duration +} + +func TestDialWorkerLoopRanking3(t *testing.T) { + addrs := make([]ma.Multiaddr, 0) + ports := make(map[int]struct{}) + for i := 0; i < 10; i++ { + for { + p := 5000 + mrand.Intn(10000) + if _, ok := ports[p]; ok { + continue + } + ports[p] = struct{}{} + addrs = append(addrs, ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", p))) + break + } + } + + makeRanker := func(tc []schedTest) network.DialRanker { + return func(addrs []ma.Multiaddr) []network.AddrDelay { + res := make([]network.AddrDelay, len(tc)) + for i := 0; i < len(tc); i++ { + res[i] = network.AddrDelay{Addr: tc[i].addr, Delay: tc[i].delay} + } + return res + } + } + + testcases := []struct { + name string + input []schedTest + maxTime time.Duration + }{ + { + name: "first success", + input: []schedTest{ + { + addr: addrs[1], + delay: 0, + success: true, + }, + { + addr: addrs[0], + delay: 100 * time.Millisecond, + success: false, + failAfter: 50 * time.Millisecond, + }, + }, + maxTime: 20 * time.Millisecond, + }, + { + name: "delayed dials", + input: []schedTest{ + { + addr: addrs[0], + delay: 0, + success: false, + failAfter: 5 * time.Millisecond, + }, + { + addr: addrs[1], + delay: 100 * time.Millisecond, + success: false, + failAfter: 105 * time.Millisecond, + }, + { + addr: addrs[2], + delay: 1 * time.Second, + success: false, + failAfter: 10 * time.Millisecond, + }, + { + addr: addrs[3], + delay: 2 * time.Second, + success: true, + }, + { + addr: addrs[4], + delay: 2*time.Second + 1*time.Millisecond, + success: false, + failAfter: 10 * time.Millisecond, + }, + }, + maxTime: 200 * time.Millisecond, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + s1 := makeSwarmWithNoListenAddrs(t) + defer s1.Close() + s2 := makeSwarmWithNoListenAddrs(t) + defer s2.Close() + + failDials := make(map[ma.Multiaddr]dialState) + allDials := make(map[ma.Multiaddr]dialState) + addrs := make([]ma.Multiaddr, 0) + for _, inp := range tc.input { + var failCh chan struct{} + if inp.success { + err := s2.AddListenAddr(inp.addr) + if err != nil { + t.Errorf("failed to listen on %s %s", inp.addr, err) + } + } else { + l, ch := makeTCPListener(t, inp.addr) + f := func() { + err := l.Close() + if err != nil { + t.Error(err) + } + } + failCh = ch + t.Cleanup(f) + } + addrs = append(addrs, inp.addr) + allDials[inp.addr] = dialState{ + ch: failCh, + addr: inp.addr, + delay: inp.delay, + success: inp.success, + failAfter: inp.failAfter, + } + } + s1.Peerstore().AddAddrs(s2.LocalPeer(), addrs, peerstore.PermanentAddrTTL) + s1.dialRanker = makeRanker(tc.input) + + reqch := make(chan dialRequest) + resch := make(chan dialResponse) + cl := newMockClock() + st := cl.Now() + worker1 := newDialWorker(s1, s2.LocalPeer(), reqch, cl) + go worker1.loop() + defer worker1.wg.Wait() + reqch <- dialRequest{ctx: context.Background(), resch: resch} + loop: + for { + for a, p := range failDials { + if p.at.Before(cl.Now()) { + p.ch <- struct{}{} + delete(failDials, a) + } + } + trigger := len(failDials) == 0 + mi := 
time.Duration(math.MaxInt64) + for _, ds := range allDials { + if ds.delay < mi { + mi = ds.delay + } + } + for a, ds := range allDials { + if (trigger && mi == ds.delay) || cl.Now().After(st.Add(ds.delay)) { + delete(allDials, a) + if ds.success { + select { + case r := <-resch: + if r.conn == nil { + t.Error("expected a connection") + } + case <-time.After(1 * time.Second): + t.Error("expected a connection") + } + break loop + } else { + select { + case <-resch: + t.Error("didn't expect a connection") + case <-time.After(100 * time.Millisecond): + } + failDials[a] = dialState{ + ch: ds.ch, + at: cl.Now().Add(ds.failAfter), + addr: a, + delay: ds.delay, + } + } + } + } + cl.AdvanceBy(10 * time.Millisecond) + if len(failDials) == 0 && len(allDials) == 0 { + break + } + } + select { + case <-resch: + t.Error("didn't expect a connection") + case <-time.After(100 * time.Millisecond): + } + if cl.Now().Sub(st) > tc.maxTime { + t.Errorf("expected test to finish early: expected %d, took: %d", tc.maxTime, cl.Now().Sub(st)) + } + close(reqch) + }) + } + +} + +func TestDialQueuePriority(t *testing.T) { + addrs := make([]ma.Multiaddr, 0) + ports := make(map[int]struct{}) + for i := 0; i < 10; i++ { + for { + p := 1 + mrand.Intn(10000) + if _, ok := ports[p]; ok { + continue + } + ports[p] = struct{}{} + addrs = append(addrs, ma.StringCast(fmt.Sprintf("/ip4/1.2.3.4/tcp/%d", p))) + break + } + } + testcase := []struct { + name string + input []network.AddrDelay + output []ma.Multiaddr + }{ + { + name: "priority queue property", + input: []network.AddrDelay{ + {Addr: addrs[0], Delay: 100}, + {Addr: addrs[1], Delay: 200}, + {Addr: addrs[2], Delay: 20}, + }, + output: []ma.Multiaddr{ + addrs[2], addrs[0], addrs[1], + }, + }, + { + name: "priority queue property 2", + input: []network.AddrDelay{ + {Addr: addrs[0], Delay: 200}, + {Addr: addrs[1], Delay: 100}, + {Addr: addrs[2], Delay: 20}, + }, + output: []ma.Multiaddr{ + addrs[2], addrs[1], addrs[0], + }, + }, + { + name: "updates", + input: []network.AddrDelay{ + {Addr: addrs[0], Delay: 200}, + {Addr: addrs[1], Delay: 100}, + {Addr: addrs[2], Delay: 20}, + {Addr: addrs[0], Delay: 0}, + {Addr: addrs[1], Delay: 100}, + }, + output: []ma.Multiaddr{ + addrs[0], addrs[2], addrs[1], + }, + }, + } + for _, tc := range testcase { + t.Run(tc.name, func(t *testing.T) { + q := newDialQueue() + for i := 0; i < len(tc.input); i++ { + q.add(tc.input[i]) + } + for i := 0; i < len(tc.output); i++ { + v := q.pop() + if !tc.output[i].Equal(v.Addr) { + t.Errorf("failed priority queue property: expected: %s got: %s", tc.output[i], v.Addr) + } + } + if q.len() != 0 { + t.Errorf("expected queue to be empty at end. 
got: %d", q.len()) + } + }) + } +} + +func TestDialQueueNextBatch(t *testing.T) { + addrs := make([]ma.Multiaddr, 0) + for i := 0; i < 10; i++ { + addrs = append(addrs, ma.StringCast(fmt.Sprintf("/ip4/1.2.3.4/tcp/%d", i))) + } + testcase := []struct { + name string + input []network.AddrDelay + output [][]ma.Multiaddr + }{ + { + name: "next batch", + input: []network.AddrDelay{ + {Addr: addrs[0], Delay: 100}, + {Addr: addrs[1], Delay: 100}, + {Addr: addrs[2], Delay: 20}, + {Addr: addrs[3], Delay: 20}, + }, + output: [][]ma.Multiaddr{ + {addrs[2], addrs[3]}, + {addrs[0], addrs[1]}, + }, + }, + { + name: "priority queue property 2", + input: []network.AddrDelay{ + {Addr: addrs[0], Delay: 500}, + {Addr: addrs[1], Delay: 500}, + {Addr: addrs[2], Delay: 20}, + {Addr: addrs[3], Delay: 20}, + {Addr: addrs[4], Delay: 100}, + }, + + output: [][]ma.Multiaddr{ + {addrs[2], addrs[3]}, + {addrs[4]}, + {addrs[0], addrs[1]}, + }, + }, + { + name: "updates", + input: []network.AddrDelay{ + {Addr: addrs[0], Delay: 200}, + {Addr: addrs[1], Delay: 100}, + {Addr: addrs[2], Delay: 20}, + {Addr: addrs[0], Delay: 0}, + {Addr: addrs[1], Delay: 200}, + {Addr: addrs[3], Delay: 200}, + }, + output: [][]ma.Multiaddr{ + {addrs[0]}, + {addrs[2]}, + {addrs[1], addrs[3]}, + }, + }, + { + name: "null input", + input: []network.AddrDelay{}, + output: [][]ma.Multiaddr{ + {}, + {}, + }, + }, + } + for _, tc := range testcase { + t.Run(tc.name, func(t *testing.T) { + q := newDialQueue() + for i := 0; i < len(tc.input); i++ { + q.add(tc.input[i]) + } + for _, batch := range tc.output { + b := q.nextBatch() + if len(batch) != len(b) { + t.Errorf("expected %d elements got %d", len(batch), len(b)) + } + sort.Slice(b, func(i, j int) bool { return b[i].Addr.String() < b[j].Addr.String() }) + sort.Slice(batch, func(i, j int) bool { return batch[i].String() < batch[j].String() }) + for i := 0; i < len(b); i++ { + if !b[i].Addr.Equal(batch[i]) { + log.Errorf("expected %s got %s", batch[i], b[i].Addr) + } + } + } + if q.len() != 0 { + t.Errorf("expected queue to be empty at end. got: %d", q.len()) + } + }) + } +} diff --git a/p2p/net/swarm/swarm.go b/p2p/net/swarm/swarm.go index cd19e726ed..3f4444f356 100644 --- a/p2p/net/swarm/swarm.go +++ b/p2p/net/swarm/swarm.go @@ -100,6 +100,23 @@ func WithResourceManager(m network.ResourceManager) Option { } } +// WithNoDialDelay configures swarm to dial all addresses for a peer without +// any delay +func WithNoDialDelay() Option { + return func(s *Swarm) error { + s.dialRanker = noDelayRanker + return nil + } +} + +// WithDialRanker configures swarm to use d as the DialRanker +func WithDialRanker(d network.DialRanker) Option { + return func(s *Swarm) error { + s.dialRanker = d + return nil + } +} + // Swarm is a connection muxer, allowing connections to other peers to // be opened and closed, while still using the same Chan for all // communication. The Chan sends/receives Messages, which note the @@ -163,6 +180,8 @@ type Swarm struct { bwc metrics.Reporter metricsTracer MetricsTracer + + dialRanker network.DialRanker } // NewSwarm constructs a Swarm. 
@@ -181,6 +200,7 @@ func NewSwarm(local peer.ID, peers peerstore.Peerstore, eventBus event.Bus, opts dialTimeout: defaultDialTimeout, dialTimeoutLocal: defaultDialTimeoutLocal, maResolver: madns.DefaultResolver, + dialRanker: defaultDialRanker, } s.conns.m = make(map[peer.ID][]*Conn) diff --git a/p2p/net/swarm/swarm_dial.go b/p2p/net/swarm/swarm_dial.go index 5423a199b7..ab195b0072 100644 --- a/p2p/net/swarm/swarm_dial.go +++ b/p2p/net/swarm/swarm_dial.go @@ -295,7 +295,7 @@ func (s *Swarm) dialPeer(ctx context.Context, p peer.ID) (*Conn, error) { // dialWorkerLoop synchronizes and executes concurrent dials to a single peer func (s *Swarm) dialWorkerLoop(p peer.ID, reqch <-chan dialRequest) { - w := newDialWorker(s, p, reqch) + w := newDialWorker(s, p, reqch, nil) w.loop() } @@ -542,13 +542,6 @@ func isFdConsumingAddr(addr ma.Multiaddr) bool { return err1 == nil || err2 == nil } -func isExpensiveAddr(addr ma.Multiaddr) bool { - _, wsErr := addr.ValueForProtocol(ma.P_WS) - _, wssErr := addr.ValueForProtocol(ma.P_WSS) - _, wtErr := addr.ValueForProtocol(ma.P_WEBTRANSPORT) - return wsErr == nil || wssErr == nil || wtErr == nil -} - func isRelayAddr(addr ma.Multiaddr) bool { _, err := addr.ValueForProtocol(ma.P_CIRCUIT) return err == nil
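}

To make the worker's new pacing concrete, here is a hedged sketch of the dialQueue contract, using only names introduced in this patch (the multiaddrs are placeholders):

	q := newDialQueue()
	q.add(network.AddrDelay{Addr: ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1"), Delay: 0})
	q.add(network.AddrDelay{Addr: ma.StringCast("/ip4/1.2.3.4/tcp/1"), Delay: publicTCPDelay})

	// nextBatch pops every address sharing the smallest remaining Delay,
	// so the first call returns just the quic-v1 address.
	batch := q.nextBatch()
	_ = batch

	// The worker loop then re-arms its timer for st.Add(q.top().Delay),
	// i.e. 300 ms after dialing started, before dialing the TCP address.

Because add() first remove()s any existing entry for the same address, re-adding an address updates its delay in place; this is how a later dial request (for example one carrying a simultaneous-connect context) can pull an already-queued address forward.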