From dfc0fd10e34abf20a41ecf3ad980fd9a964af837 Mon Sep 17 00:00:00 2001 From: sukun Date: Sun, 6 Oct 2024 23:09:32 +0530 Subject: [PATCH] add fx option, move rcmgr and upgrader to sharedtcp --- config/config.go | 9 ++++++ libp2p_test.go | 16 ++++++++++- options.go | 7 +++++ p2p/net/swarm/dial_worker_test.go | 2 +- p2p/net/swarm/swarm_addr_test.go | 2 +- p2p/net/swarm/swarm_dial_test.go | 8 +++--- p2p/net/swarm/testing/testing.go | 2 +- p2p/net/upgrader/listener.go | 2 +- p2p/protocol/circuitv2/relay/relay_test.go | 2 +- p2p/test/transport/gating_test.go | 25 +++++++++++++++- p2p/test/transport/transport_test.go | 32 +++++++++++++++++++++ p2p/transport/tcp/tcp.go | 3 +- p2p/transport/tcp/tcp_test.go | 20 ++++++------- p2p/transport/tcpreuse/listener.go | 1 - p2p/transport/websocket/conn.go | 33 ++++++++++++++++++++++ p2p/transport/websocket/listener.go | 18 ++++++------ p2p/transport/websocket/websocket.go | 3 +- p2p/transport/websocket/websocket_test.go | 28 +++++++++--------- 18 files changed, 166 insertions(+), 47 deletions(-) diff --git a/config/config.go b/config/config.go index d0a71664f7..abc0586a1b 100644 --- a/config/config.go +++ b/config/config.go @@ -36,6 +36,7 @@ import ( relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" "github.com/libp2p/go-libp2p/p2p/protocol/holepunch" "github.com/libp2p/go-libp2p/p2p/transport/quicreuse" + "github.com/libp2p/go-libp2p/p2p/transport/tcpreuse" libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" "github.com/prometheus/client_golang/prometheus" @@ -142,6 +143,8 @@ type Config struct { CustomUDPBlackHoleSuccessCounter bool IPv6BlackHoleSuccessCounter *swarm.BlackHoleSuccessCounter CustomIPv6BlackHoleSuccessCounter bool + + ShareTCPListener bool } func (cfg *Config) makeSwarm(eventBus event.Bus, enableMetrics bool) (*swarm.Swarm, error) { @@ -286,6 +289,12 @@ func (cfg *Config) addTransports() ([]fx.Option, error) { fx.Provide(func() connmgr.ConnectionGater { return cfg.ConnectionGater }), fx.Provide(func() pnet.PSK { return cfg.PSK }), fx.Provide(func() network.ResourceManager { return cfg.ResourceManager }), + fx.Provide(func(gater connmgr.ConnectionGater, rcmgr network.ResourceManager) *tcpreuse.ConnMgr { + if !cfg.ShareTCPListener { + return nil + } + return tcpreuse.NewConnMgr(tcpreuse.EnvReuseportVal, gater, rcmgr) + }), fx.Provide(func() *madns.Resolver { return cfg.MultiaddrResolver }), fx.Provide(func(cm *quicreuse.ConnManager, sw *swarm.Swarm) libp2pwebrtc.ListenUDPFn { hasQuicAddrPortFor := func(network string, laddr *net.UDPAddr) bool { diff --git a/libp2p_test.go b/libp2p_test.go index a5803add4d..ff6097c1dc 100644 --- a/libp2p_test.go +++ b/libp2p_test.go @@ -31,6 +31,7 @@ import ( "github.com/libp2p/go-libp2p/p2p/transport/quicreuse" "github.com/libp2p/go-libp2p/p2p/transport/tcp" libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" + "github.com/libp2p/go-libp2p/p2p/transport/websocket" webtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport" "go.uber.org/goleak" @@ -52,7 +53,7 @@ func TestTransportConstructor(t *testing.T) { _ connmgr.ConnectionGater, upgrader transport.Upgrader, ) transport.Transport { - tpt, err := tcp.NewTCPTransport(upgrader, nil) + tpt, err := tcp.NewTCPTransport(upgrader, nil, nil) require.NoError(t, err) return tpt } @@ -655,3 +656,16 @@ func TestUseCorrectTransportForDialOut(t *testing.T) { } } } + +func TestSharedTCPAddr(t *testing.T) { + h, err := New( + ShareTCPListener(), + Transport(tcp.NewTCPTransport), + Transport(websocket.New), + ListenAddrStrings("/ip4/0.0.0.0/tcp/8888"), + ListenAddrStrings("/ip4/0.0.0.0/tcp/8888/ws"), + ) + require.NoError(t, err) + fmt.Println(h.Addrs()) + h.Close() +} diff --git a/options.go b/options.go index 1a8bc5dd55..e4083590a9 100644 --- a/options.go +++ b/options.go @@ -635,3 +635,10 @@ func IPv6BlackHoleSuccessCounter(f *swarm.BlackHoleSuccessCounter) Option { return nil } } + +func ShareTCPListener() Option { + return func(cfg *Config) error { + cfg.ShareTCPListener = true + return nil + } +} diff --git a/p2p/net/swarm/dial_worker_test.go b/p2p/net/swarm/dial_worker_test.go index ed4f00ff58..d264fd1230 100644 --- a/p2p/net/swarm/dial_worker_test.go +++ b/p2p/net/swarm/dial_worker_test.go @@ -84,7 +84,7 @@ func makeSwarmWithNoListenAddrs(t *testing.T, opts ...Option) *Swarm { upgrader := makeUpgrader(t, s) var tcpOpts []tcp.Option tcpOpts = append(tcpOpts, tcp.DisableReuseport()) - tcpTransport, err := tcp.NewTCPTransport(upgrader, nil, tcpOpts...) + tcpTransport, err := tcp.NewTCPTransport(upgrader, nil, nil, tcpOpts...) require.NoError(t, err) if err := s.AddTransport(tcpTransport); err != nil { t.Fatal(err) diff --git a/p2p/net/swarm/swarm_addr_test.go b/p2p/net/swarm/swarm_addr_test.go index 435866e920..43e76716e5 100644 --- a/p2p/net/swarm/swarm_addr_test.go +++ b/p2p/net/swarm/swarm_addr_test.go @@ -79,7 +79,7 @@ func TestDialAddressSelection(t *testing.T) { s, err := swarm.NewSwarm("local", nil, eventbus.NewBus()) require.NoError(t, err) - tcpTr, err := tcp.NewTCPTransport(nil, nil) + tcpTr, err := tcp.NewTCPTransport(nil, nil, nil) require.NoError(t, err) require.NoError(t, s.AddTransport(tcpTr)) reuse, err := quicreuse.NewConnManager(quic.StatelessResetKey{}, quic.TokenGeneratorKey{}) diff --git a/p2p/net/swarm/swarm_dial_test.go b/p2p/net/swarm/swarm_dial_test.go index f4c33170a9..059cc41cca 100644 --- a/p2p/net/swarm/swarm_dial_test.go +++ b/p2p/net/swarm/swarm_dial_test.go @@ -53,7 +53,7 @@ func TestAddrsForDial(t *testing.T) { ps.AddPrivKey(id, priv) t.Cleanup(func() { ps.Close() }) - tpt, err := websocket.New(nil, &network.NullResourceManager{}) + tpt, err := websocket.New(nil, &network.NullResourceManager{}, nil) require.NoError(t, err) s, err := NewSwarm(id, ps, eventbus.NewBus(), WithMultiaddrResolver(resolver)) require.NoError(t, err) @@ -100,7 +100,7 @@ func TestDedupAddrsForDial(t *testing.T) { require.NoError(t, err) defer s.Close() - tpt, err := tcp.NewTCPTransport(nil, &network.NullResourceManager{}) + tpt, err := tcp.NewTCPTransport(nil, &network.NullResourceManager{}, nil) require.NoError(t, err) err = s.AddTransport(tpt) require.NoError(t, err) @@ -134,7 +134,7 @@ func newTestSwarmWithResolver(t *testing.T, resolver *madns.Resolver) *Swarm { }) // Add a tcp transport so that we know we can dial a tcp multiaddr and we don't filter it out. - tpt, err := tcp.NewTCPTransport(nil, &network.NullResourceManager{}) + tpt, err := tcp.NewTCPTransport(nil, &network.NullResourceManager{}, nil) require.NoError(t, err) err = s.AddTransport(tpt) require.NoError(t, err) @@ -151,7 +151,7 @@ func newTestSwarmWithResolver(t *testing.T, resolver *madns.Resolver) *Swarm { err = s.AddTransport(wtTpt) require.NoError(t, err) - wsTpt, err := websocket.New(nil, &network.NullResourceManager{}) + wsTpt, err := websocket.New(nil, &network.NullResourceManager{}, nil) require.NoError(t, err) err = s.AddTransport(wsTpt) require.NoError(t, err) diff --git a/p2p/net/swarm/testing/testing.go b/p2p/net/swarm/testing/testing.go index 2bbe8b27a5..773314a1b8 100644 --- a/p2p/net/swarm/testing/testing.go +++ b/p2p/net/swarm/testing/testing.go @@ -164,7 +164,7 @@ func GenSwarm(t testing.TB, opts ...Option) *swarm.Swarm { if cfg.disableReuseport { tcpOpts = append(tcpOpts, tcp.DisableReuseport()) } - tcpTransport, err := tcp.NewTCPTransport(upgrader, nil, tcpOpts...) + tcpTransport, err := tcp.NewTCPTransport(upgrader, nil, nil, tcpOpts...) require.NoError(t, err) if err := s.AddTransport(tcpTransport); err != nil { t.Fatal(err) diff --git a/p2p/net/upgrader/listener.go b/p2p/net/upgrader/listener.go index 0530bde292..65da2bec6c 100644 --- a/p2p/net/upgrader/listener.go +++ b/p2p/net/upgrader/listener.go @@ -91,7 +91,7 @@ func (l *listener) handleIncoming() { connScope = sc.Scope() } - if connScope != nil { + if connScope == nil { // gate the connection if applicable if l.upgrader.connGater != nil && !l.upgrader.connGater.InterceptAccept(maconn) { log.Debugf("gater blocked incoming connection on local addr %s from %s", diff --git a/p2p/protocol/circuitv2/relay/relay_test.go b/p2p/protocol/circuitv2/relay/relay_test.go index e5d32b0c96..f6b63e32de 100644 --- a/p2p/protocol/circuitv2/relay/relay_test.go +++ b/p2p/protocol/circuitv2/relay/relay_test.go @@ -60,7 +60,7 @@ func getNetHosts(t *testing.T, ctx context.Context, n int) (hosts []host.Host, u upgrader := swarmt.GenUpgrader(t, netw, nil) upgraders = append(upgraders, upgrader) - tpt, err := tcp.NewTCPTransport(upgrader, nil) + tpt, err := tcp.NewTCPTransport(upgrader, nil, nil) if err != nil { t.Fatal(err) } diff --git a/p2p/test/transport/gating_test.go b/p2p/test/transport/gating_test.go index df53da6eeb..99ce67b521 100644 --- a/p2p/test/transport/gating_test.go +++ b/p2p/test/transport/gating_test.go @@ -2,6 +2,8 @@ package transport_integration import ( "context" + "encoding/binary" + "net/netip" "strings" "testing" "time" @@ -30,6 +32,23 @@ func stripCertHash(addr ma.Multiaddr) ma.Multiaddr { return addr } +func addrPort(addr ma.Multiaddr) netip.AddrPort { + a := netip.Addr{} + p := uint16(0) + ma.ForEach(addr, func(c ma.Component) bool { + if c.Protocol().Code == ma.P_IP4 || c.Protocol().Code == ma.P_IP6 { + a, _ = netip.AddrFromSlice(c.RawValue()) + return false + } + if c.Protocol().Code == ma.P_UDP || c.Protocol().Code == ma.P_TCP { + p = binary.BigEndian.Uint16(c.RawValue()) + return true + } + return false + }) + return netip.AddrPortFrom(a, p) +} + func TestInterceptPeerDial(t *testing.T) { if race.WithRace() { t.Skip("The upgrader spawns a new Go routine, which leads to race conditions when using GoMock.") @@ -173,10 +192,14 @@ func TestInterceptAccept(t *testing.T) { // remove the certhash component from WebTransport addresses require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr()) }).AnyTimes() + } else if strings.Contains(tc.Name, "WebSocket-Shared") { + connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) { + require.Equal(t, addrPort(h2.Addrs()[0]), addrPort(addrs.LocalMultiaddr())) + }) } else { connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) { // remove the certhash component from WebTransport addresses - require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr()) + require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr(), "%s\n%s", h2.Addrs()[0], addrs.LocalMultiaddr()) }) } diff --git a/p2p/test/transport/transport_test.go b/p2p/test/transport/transport_test.go index 7cfab5f3ca..60f8ca0c06 100644 --- a/p2p/test/transport/transport_test.go +++ b/p2p/test/transport/transport_test.go @@ -99,6 +99,38 @@ var transportsToTest = []TransportTestCase{ return h }, }, + { + Name: "TCP-Shared / TLS / Yamux", + HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { + libp2pOpts := transformOpts(opts) + libp2pOpts = append(libp2pOpts, libp2p.ShareTCPListener()) + libp2pOpts = append(libp2pOpts, libp2p.Security(tls.ID, tls.New)) + libp2pOpts = append(libp2pOpts, libp2p.Muxer(yamux.ID, yamux.DefaultTransport)) + if opts.NoListen { + libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs) + } else { + libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0")) + } + h, err := libp2p.New(libp2pOpts...) + require.NoError(t, err) + return h + }, + }, + { + Name: "WebSocket-Shared", + HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { + libp2pOpts := transformOpts(opts) + libp2pOpts = append(libp2pOpts, libp2p.ShareTCPListener()) + if opts.NoListen { + libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs) + } else { + libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0/ws")) + } + h, err := libp2p.New(libp2pOpts...) + require.NoError(t, err) + return h + }, + }, { Name: "WebSocket", HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { diff --git a/p2p/transport/tcp/tcp.go b/p2p/transport/tcp/tcp.go index 5883e43f6a..1b145c2b45 100644 --- a/p2p/transport/tcp/tcp.go +++ b/p2p/transport/tcp/tcp.go @@ -149,7 +149,7 @@ var _ transport.DialUpdater = &TcpTransport{} // NewTCPTransport creates a tcp transport object that tracks dialers and listeners // created. -func NewTCPTransport(upgrader transport.Upgrader, rcmgr network.ResourceManager, opts ...Option) (*TcpTransport, error) { +func NewTCPTransport(upgrader transport.Upgrader, rcmgr network.ResourceManager, sharedTCP *tcpreuse.ConnMgr, opts ...Option) (*TcpTransport, error) { if rcmgr == nil { rcmgr = &network.NullResourceManager{} } @@ -157,6 +157,7 @@ func NewTCPTransport(upgrader transport.Upgrader, rcmgr network.ResourceManager, upgrader: upgrader, connectTimeout: defaultConnectTimeout, // can be set by using the WithConnectionTimeout option rcmgr: rcmgr, + sharedTcp: sharedTCP, } for _, o := range opts { if err := o(tr); err != nil { diff --git a/p2p/transport/tcp/tcp_test.go b/p2p/transport/tcp/tcp_test.go index 4c692fbf4c..1f939d92be 100644 --- a/p2p/transport/tcp/tcp_test.go +++ b/p2p/transport/tcp/tcp_test.go @@ -32,11 +32,11 @@ func TestTcpTransport(t *testing.T) { ua, err := tptu.New(ia, muxers, nil, nil, nil) require.NoError(t, err) - ta, err := NewTCPTransport(ua, nil) + ta, err := NewTCPTransport(ua, nil, nil) require.NoError(t, err) ub, err := tptu.New(ib, muxers, nil, nil, nil) require.NoError(t, err) - tb, err := NewTCPTransport(ub, nil) + tb, err := NewTCPTransport(ub, nil, nil) require.NoError(t, err) zero := "/ip4/127.0.0.1/tcp/0" @@ -53,11 +53,11 @@ func TestTcpTransportWithMetrics(t *testing.T) { ua, err := tptu.New(ia, muxers, nil, nil, nil) require.NoError(t, err) - ta, err := NewTCPTransport(ua, nil, WithMetrics()) + ta, err := NewTCPTransport(ua, nil, nil, WithMetrics()) require.NoError(t, err) ub, err := tptu.New(ib, muxers, nil, nil, nil) require.NoError(t, err) - tb, err := NewTCPTransport(ub, nil, WithMetrics()) + tb, err := NewTCPTransport(ub, nil, nil, WithMetrics()) require.NoError(t, err) zero := "/ip4/127.0.0.1/tcp/0" @@ -73,7 +73,7 @@ func TestResourceManager(t *testing.T) { ua, err := tptu.New(ia, muxers, nil, nil, nil) require.NoError(t, err) - ta, err := NewTCPTransport(ua, nil) + ta, err := NewTCPTransport(ua, nil, nil) require.NoError(t, err) ln, err := ta.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0")) require.NoError(t, err) @@ -82,7 +82,7 @@ func TestResourceManager(t *testing.T) { ub, err := tptu.New(ib, muxers, nil, nil, nil) require.NoError(t, err) rcmgr := mocknetwork.NewMockResourceManager(ctrl) - tb, err := NewTCPTransport(ub, rcmgr) + tb, err := NewTCPTransport(ub, rcmgr, nil) require.NoError(t, err) t.Run("success", func(t *testing.T) { @@ -120,7 +120,7 @@ func TestTcpTransportCantDialDNS(t *testing.T) { require.NoError(t, err) var u transport.Upgrader - tpt, err := NewTCPTransport(u, nil) + tpt, err := NewTCPTransport(u, nil, nil) require.NoError(t, err) if tpt.CanDial(dnsa) { @@ -138,7 +138,7 @@ func TestTcpTransportCantListenUtp(t *testing.T) { require.NoError(t, err) var u transport.Upgrader - tpt, err := NewTCPTransport(u, nil) + tpt, err := NewTCPTransport(u, nil, nil) require.NoError(t, err) _, err = tpt.Listen(utpa) @@ -155,7 +155,7 @@ func TestDialWithUpdates(t *testing.T) { ua, err := tptu.New(ia, muxers, nil, nil, nil) require.NoError(t, err) - ta, err := NewTCPTransport(ua, nil) + ta, err := NewTCPTransport(ua, nil, nil) require.NoError(t, err) ln, err := ta.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0")) require.NoError(t, err) @@ -163,7 +163,7 @@ func TestDialWithUpdates(t *testing.T) { ub, err := tptu.New(ib, muxers, nil, nil, nil) require.NoError(t, err) - tb, err := NewTCPTransport(ub, nil) + tb, err := NewTCPTransport(ub, nil, nil) require.NoError(t, err) updCh := make(chan transport.DialUpdate, 1) diff --git a/p2p/transport/tcpreuse/listener.go b/p2p/transport/tcpreuse/listener.go index 4a5bfa119b..e0bfe8eef2 100644 --- a/p2p/transport/tcpreuse/listener.go +++ b/p2p/transport/tcpreuse/listener.go @@ -191,7 +191,6 @@ func (m *multiplexedListener) run() error { } continue } - connScope, err := m.rcmgr.OpenConnection(network.DirInbound, true, c.RemoteMultiaddr()) if err != nil { log.Debugw("resource manager blocked accept of new connection", "error", err) diff --git a/p2p/transport/websocket/conn.go b/p2p/transport/websocket/conn.go index 19d4e46ec5..df97189d90 100644 --- a/p2p/transport/websocket/conn.go +++ b/p2p/transport/websocket/conn.go @@ -8,6 +8,8 @@ import ( "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/transport" + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" ws "github.com/gorilla/websocket" ) @@ -23,21 +25,52 @@ type Conn struct { DefaultMessageType int reader io.Reader closeOnce sync.Once + laddr ma.Multiaddr + raddr ma.Multiaddr readLock, writeLock sync.Mutex } var _ net.Conn = (*Conn)(nil) +var _ manet.Conn = (*Conn)(nil) // NewConn creates a Conn given a regular gorilla/websocket Conn. +// +// Deprecated: There's no reason to use this method externally. It'll be unexported in a future release. func NewConn(raw *ws.Conn, secure bool) *Conn { + lna := NewAddrWithScheme(raw.LocalAddr().String(), secure) + laddr, err := manet.FromNetAddr(lna) + if err != nil { + log.Errorf("BUG: invalid localaddr on websocket conn", raw.LocalAddr()) + return nil + } + + rna := NewAddrWithScheme(raw.RemoteAddr().String(), secure) + raddr, err := manet.FromNetAddr(rna) + if err != nil { + log.Errorf("BUG: invalid remoteaddr on websocket conn", raw.RemoteAddr()) + return nil + } + return &Conn{ Conn: raw, secure: secure, DefaultMessageType: ws.BinaryMessage, + laddr: laddr, + raddr: raddr, } } +// LocalMultiaddr implements manet.Conn. +func (c *Conn) LocalMultiaddr() ma.Multiaddr { + return c.laddr +} + +// RemoteMultiaddr implements manet.Conn. +func (c *Conn) RemoteMultiaddr() ma.Multiaddr { + return c.raddr +} + func (c *Conn) Read(b []byte) (int, error) { c.readLock.Lock() defer c.readLock.Unlock() diff --git a/p2p/transport/websocket/listener.go b/p2p/transport/websocket/listener.go index 40b290e212..dd399aa079 100644 --- a/p2p/transport/websocket/listener.go +++ b/p2p/transport/websocket/listener.go @@ -4,11 +4,12 @@ import ( "crypto/tls" "errors" "fmt" - "go.uber.org/zap" "net" "net/http" "sync" + "go.uber.org/zap" + logging "github.com/ipfs/go-log/v2" "github.com/libp2p/go-libp2p/core/transport" @@ -129,7 +130,12 @@ func (l *listener) ServeHTTP(w http.ResponseWriter, r *http.Request) { // The upgrader writes a response for us. return } - + nc := NewConn(c, l.isWss) + if nc == nil { + c.Close() + w.WriteHeader(500) + return + } select { case l.incoming <- NewConn(c, l.isWss): case <-l.closed: @@ -144,13 +150,7 @@ func (l *listener) Accept() (manet.Conn, error) { if !ok { return nil, transport.ErrListenerClosed } - - mnc, err := manet.WrapNetConn(c) - if err != nil { - c.Close() - return nil, err - } - return mnc, nil + return c, nil case <-l.closed: return nil, transport.ErrListenerClosed } diff --git a/p2p/transport/websocket/websocket.go b/p2p/transport/websocket/websocket.go index 68ac5e77a4..d181eb8dc2 100644 --- a/p2p/transport/websocket/websocket.go +++ b/p2p/transport/websocket/websocket.go @@ -101,7 +101,7 @@ type WebsocketTransport struct { var _ transport.Transport = (*WebsocketTransport)(nil) -func New(u transport.Upgrader, rcmgr network.ResourceManager, opts ...Option) (*WebsocketTransport, error) { +func New(u transport.Upgrader, rcmgr network.ResourceManager, sharedTCP *tcpreuse.ConnMgr, opts ...Option) (*WebsocketTransport, error) { if rcmgr == nil { rcmgr = &network.NullResourceManager{} } @@ -109,6 +109,7 @@ func New(u transport.Upgrader, rcmgr network.ResourceManager, opts ...Option) (* upgrader: u, rcmgr: rcmgr, tlsClientConf: &tls.Config{}, + sharedTcp: sharedTCP, } for _, opt := range opts { if err := opt(t); err != nil { diff --git a/p2p/transport/websocket/websocket_test.go b/p2p/transport/websocket/websocket_test.go index 8f912c4138..9ca03775a2 100644 --- a/p2p/transport/websocket/websocket_test.go +++ b/p2p/transport/websocket/websocket_test.go @@ -154,7 +154,7 @@ func testWSSServer(t *testing.T, listenAddr ma.Multiaddr) (ma.Multiaddr, peer.ID } id, u := newSecureUpgrader(t) - tpt, err := New(u, &network.NullResourceManager{}, WithTLSConfig(tlsConf)) + tpt, err := New(u, &network.NullResourceManager{}, nil, WithTLSConfig(tlsConf)) if err != nil { t.Fatal(err) } @@ -237,7 +237,7 @@ func TestHostHeaderWss(t *testing.T) { tlsConfig := &tls.Config{InsecureSkipVerify: true} // Our test server doesn't have a cert signed by a CA _, u := newSecureUpgrader(t) - tpt, err := New(u, &network.NullResourceManager{}, WithTLSClientConfig(tlsConfig)) + tpt, err := New(u, &network.NullResourceManager{}, nil, WithTLSClientConfig(tlsConfig)) require.NoError(t, err) masToDial, err := tpt.Resolve(context.Background(), serverMA) @@ -256,7 +256,7 @@ func TestDialWss(t *testing.T) { tlsConfig := &tls.Config{InsecureSkipVerify: true} // Our test server doesn't have a cert signed by a CA _, u := newSecureUpgrader(t) - tpt, err := New(u, &network.NullResourceManager{}, WithTLSClientConfig(tlsConfig)) + tpt, err := New(u, &network.NullResourceManager{}, nil, WithTLSClientConfig(tlsConfig)) require.NoError(t, err) masToDial, err := tpt.Resolve(context.Background(), serverMA) @@ -279,7 +279,7 @@ func TestDialWssNoClientCert(t *testing.T) { require.Contains(t, serverMA.String(), "tls") _, u := newSecureUpgrader(t) - tpt, err := New(u, &network.NullResourceManager{}) + tpt, err := New(u, &network.NullResourceManager{}, nil) require.NoError(t, err) masToDial, err := tpt.Resolve(context.Background(), serverMA) @@ -294,12 +294,12 @@ func TestDialWssNoClientCert(t *testing.T) { func TestWebsocketTransport(t *testing.T) { peerA, ua := newUpgrader(t) - ta, err := New(ua, nil) + ta, err := New(ua, nil, nil) if err != nil { t.Fatal(err) } _, ub := newUpgrader(t) - tb, err := New(ub, nil) + tb, err := New(ub, nil, nil) if err != nil { t.Fatal(err) } @@ -325,7 +325,7 @@ func connectAndExchangeData(t *testing.T, laddr ma.Multiaddr, secure bool) { opts = append(opts, WithTLSConfig(tlsConf)) } server, u := newUpgrader(t) - tpt, err := New(u, &network.NullResourceManager{}, opts...) + tpt, err := New(u, &network.NullResourceManager{}, nil, opts...) require.NoError(t, err) l, err := tpt.Listen(laddr) require.NoError(t, err) @@ -344,7 +344,7 @@ func connectAndExchangeData(t *testing.T, laddr ma.Multiaddr, secure bool) { opts = append(opts, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true})) } _, u := newUpgrader(t) - tpt, err := New(u, &network.NullResourceManager{}, opts...) + tpt, err := New(u, &network.NullResourceManager{}, nil, opts...) require.NoError(t, err) c, err := tpt.Dial(context.Background(), l.Multiaddr(), server) require.NoError(t, err) @@ -382,7 +382,7 @@ func TestWebsocketConnection(t *testing.T) { func TestWebsocketListenSecureFailWithoutTLSConfig(t *testing.T) { _, u := newUpgrader(t) - tpt, err := New(u, &network.NullResourceManager{}) + tpt, err := New(u, &network.NullResourceManager{}, nil) require.NoError(t, err) addr := ma.StringCast("/ip4/127.0.0.1/tcp/0/wss") _, err = tpt.Listen(addr) @@ -391,7 +391,7 @@ func TestWebsocketListenSecureFailWithoutTLSConfig(t *testing.T) { func TestWebsocketListenSecureAndInsecure(t *testing.T) { serverID, serverUpgrader := newUpgrader(t) - server, err := New(serverUpgrader, &network.NullResourceManager{}, WithTLSConfig(generateTLSConfig(t))) + server, err := New(serverUpgrader, &network.NullResourceManager{}, nil, WithTLSConfig(generateTLSConfig(t))) require.NoError(t, err) lnInsecure, err := server.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws")) @@ -401,7 +401,7 @@ func TestWebsocketListenSecureAndInsecure(t *testing.T) { t.Run("insecure", func(t *testing.T) { _, clientUpgrader := newUpgrader(t) - client, err := New(clientUpgrader, &network.NullResourceManager{}, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true})) + client, err := New(clientUpgrader, &network.NullResourceManager{}, nil, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true})) require.NoError(t, err) // dialing the insecure address should succeed @@ -418,7 +418,7 @@ func TestWebsocketListenSecureAndInsecure(t *testing.T) { t.Run("secure", func(t *testing.T) { _, clientUpgrader := newUpgrader(t) - client, err := New(clientUpgrader, &network.NullResourceManager{}, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true})) + client, err := New(clientUpgrader, &network.NullResourceManager{}, nil, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true})) require.NoError(t, err) // dialing the insecure address should succeed @@ -436,7 +436,7 @@ func TestWebsocketListenSecureAndInsecure(t *testing.T) { func TestConcurrentClose(t *testing.T) { _, u := newUpgrader(t) - tpt, err := New(u, &network.NullResourceManager{}) + tpt, err := New(u, &network.NullResourceManager{}, nil) require.NoError(t, err) l, err := tpt.maListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws")) if err != nil { @@ -474,7 +474,7 @@ func TestConcurrentClose(t *testing.T) { func TestWriteZero(t *testing.T) { _, u := newUpgrader(t) - tpt, err := New(u, &network.NullResourceManager{}) + tpt, err := New(u, &network.NullResourceManager{}, nil) if err != nil { t.Fatal(err) }