Skip to content

Commit

Permalink
add fx option, move rcmgr and upgrader to sharedtcp
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Oct 6, 2024
1 parent ed7f901 commit dfc0fd1
Show file tree
Hide file tree
Showing 18 changed files with 166 additions and 47 deletions.
9 changes: 9 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay"
"github.com/libp2p/go-libp2p/p2p/protocol/holepunch"
"github.com/libp2p/go-libp2p/p2p/transport/quicreuse"
"github.com/libp2p/go-libp2p/p2p/transport/tcpreuse"
libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc"
"github.com/prometheus/client_golang/prometheus"

Expand Down Expand Up @@ -142,6 +143,8 @@ type Config struct {
CustomUDPBlackHoleSuccessCounter bool
IPv6BlackHoleSuccessCounter *swarm.BlackHoleSuccessCounter
CustomIPv6BlackHoleSuccessCounter bool

ShareTCPListener bool
}

func (cfg *Config) makeSwarm(eventBus event.Bus, enableMetrics bool) (*swarm.Swarm, error) {
Expand Down Expand Up @@ -286,6 +289,12 @@ func (cfg *Config) addTransports() ([]fx.Option, error) {
fx.Provide(func() connmgr.ConnectionGater { return cfg.ConnectionGater }),
fx.Provide(func() pnet.PSK { return cfg.PSK }),
fx.Provide(func() network.ResourceManager { return cfg.ResourceManager }),
fx.Provide(func(gater connmgr.ConnectionGater, rcmgr network.ResourceManager) *tcpreuse.ConnMgr {
if !cfg.ShareTCPListener {
return nil
}
return tcpreuse.NewConnMgr(tcpreuse.EnvReuseportVal, gater, rcmgr)
}),
fx.Provide(func() *madns.Resolver { return cfg.MultiaddrResolver }),
fx.Provide(func(cm *quicreuse.ConnManager, sw *swarm.Swarm) libp2pwebrtc.ListenUDPFn {
hasQuicAddrPortFor := func(network string, laddr *net.UDPAddr) bool {
Expand Down
16 changes: 15 additions & 1 deletion libp2p_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/libp2p/go-libp2p/p2p/transport/quicreuse"
"github.com/libp2p/go-libp2p/p2p/transport/tcp"
libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc"
"github.com/libp2p/go-libp2p/p2p/transport/websocket"
webtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport"
"go.uber.org/goleak"

Expand All @@ -52,7 +53,7 @@ func TestTransportConstructor(t *testing.T) {
_ connmgr.ConnectionGater,
upgrader transport.Upgrader,
) transport.Transport {
tpt, err := tcp.NewTCPTransport(upgrader, nil)
tpt, err := tcp.NewTCPTransport(upgrader, nil, nil)
require.NoError(t, err)
return tpt
}
Expand Down Expand Up @@ -655,3 +656,16 @@ func TestUseCorrectTransportForDialOut(t *testing.T) {
}
}
}

func TestSharedTCPAddr(t *testing.T) {
h, err := New(
ShareTCPListener(),
Transport(tcp.NewTCPTransport),
Transport(websocket.New),
ListenAddrStrings("/ip4/0.0.0.0/tcp/8888"),
ListenAddrStrings("/ip4/0.0.0.0/tcp/8888/ws"),
)
require.NoError(t, err)
fmt.Println(h.Addrs())
h.Close()
}
7 changes: 7 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -635,3 +635,10 @@ func IPv6BlackHoleSuccessCounter(f *swarm.BlackHoleSuccessCounter) Option {
return nil
}
}

func ShareTCPListener() Option {
return func(cfg *Config) error {
cfg.ShareTCPListener = true
return nil
}
}
2 changes: 1 addition & 1 deletion p2p/net/swarm/dial_worker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func makeSwarmWithNoListenAddrs(t *testing.T, opts ...Option) *Swarm {
upgrader := makeUpgrader(t, s)
var tcpOpts []tcp.Option
tcpOpts = append(tcpOpts, tcp.DisableReuseport())
tcpTransport, err := tcp.NewTCPTransport(upgrader, nil, tcpOpts...)
tcpTransport, err := tcp.NewTCPTransport(upgrader, nil, nil, tcpOpts...)
require.NoError(t, err)
if err := s.AddTransport(tcpTransport); err != nil {
t.Fatal(err)
Expand Down
2 changes: 1 addition & 1 deletion p2p/net/swarm/swarm_addr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func TestDialAddressSelection(t *testing.T) {
s, err := swarm.NewSwarm("local", nil, eventbus.NewBus())
require.NoError(t, err)

tcpTr, err := tcp.NewTCPTransport(nil, nil)
tcpTr, err := tcp.NewTCPTransport(nil, nil, nil)
require.NoError(t, err)
require.NoError(t, s.AddTransport(tcpTr))
reuse, err := quicreuse.NewConnManager(quic.StatelessResetKey{}, quic.TokenGeneratorKey{})
Expand Down
8 changes: 4 additions & 4 deletions p2p/net/swarm/swarm_dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func TestAddrsForDial(t *testing.T) {
ps.AddPrivKey(id, priv)
t.Cleanup(func() { ps.Close() })

tpt, err := websocket.New(nil, &network.NullResourceManager{})
tpt, err := websocket.New(nil, &network.NullResourceManager{}, nil)
require.NoError(t, err)
s, err := NewSwarm(id, ps, eventbus.NewBus(), WithMultiaddrResolver(resolver))
require.NoError(t, err)
Expand Down Expand Up @@ -100,7 +100,7 @@ func TestDedupAddrsForDial(t *testing.T) {
require.NoError(t, err)
defer s.Close()

tpt, err := tcp.NewTCPTransport(nil, &network.NullResourceManager{})
tpt, err := tcp.NewTCPTransport(nil, &network.NullResourceManager{}, nil)
require.NoError(t, err)
err = s.AddTransport(tpt)
require.NoError(t, err)
Expand Down Expand Up @@ -134,7 +134,7 @@ func newTestSwarmWithResolver(t *testing.T, resolver *madns.Resolver) *Swarm {
})

// Add a tcp transport so that we know we can dial a tcp multiaddr and we don't filter it out.
tpt, err := tcp.NewTCPTransport(nil, &network.NullResourceManager{})
tpt, err := tcp.NewTCPTransport(nil, &network.NullResourceManager{}, nil)
require.NoError(t, err)
err = s.AddTransport(tpt)
require.NoError(t, err)
Expand All @@ -151,7 +151,7 @@ func newTestSwarmWithResolver(t *testing.T, resolver *madns.Resolver) *Swarm {
err = s.AddTransport(wtTpt)
require.NoError(t, err)

wsTpt, err := websocket.New(nil, &network.NullResourceManager{})
wsTpt, err := websocket.New(nil, &network.NullResourceManager{}, nil)
require.NoError(t, err)
err = s.AddTransport(wsTpt)
require.NoError(t, err)
Expand Down
2 changes: 1 addition & 1 deletion p2p/net/swarm/testing/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ func GenSwarm(t testing.TB, opts ...Option) *swarm.Swarm {
if cfg.disableReuseport {
tcpOpts = append(tcpOpts, tcp.DisableReuseport())
}
tcpTransport, err := tcp.NewTCPTransport(upgrader, nil, tcpOpts...)
tcpTransport, err := tcp.NewTCPTransport(upgrader, nil, nil, tcpOpts...)
require.NoError(t, err)
if err := s.AddTransport(tcpTransport); err != nil {
t.Fatal(err)
Expand Down
2 changes: 1 addition & 1 deletion p2p/net/upgrader/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func (l *listener) handleIncoming() {
connScope = sc.Scope()
}

if connScope != nil {
if connScope == nil {
// gate the connection if applicable
if l.upgrader.connGater != nil && !l.upgrader.connGater.InterceptAccept(maconn) {
log.Debugf("gater blocked incoming connection on local addr %s from %s",
Expand Down
2 changes: 1 addition & 1 deletion p2p/protocol/circuitv2/relay/relay_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func getNetHosts(t *testing.T, ctx context.Context, n int) (hosts []host.Host, u
upgrader := swarmt.GenUpgrader(t, netw, nil)
upgraders = append(upgraders, upgrader)

tpt, err := tcp.NewTCPTransport(upgrader, nil)
tpt, err := tcp.NewTCPTransport(upgrader, nil, nil)
if err != nil {
t.Fatal(err)
}
Expand Down
25 changes: 24 additions & 1 deletion p2p/test/transport/gating_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package transport_integration

import (
"context"
"encoding/binary"
"net/netip"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -30,6 +32,23 @@ func stripCertHash(addr ma.Multiaddr) ma.Multiaddr {
return addr
}

func addrPort(addr ma.Multiaddr) netip.AddrPort {
a := netip.Addr{}
p := uint16(0)
ma.ForEach(addr, func(c ma.Component) bool {
if c.Protocol().Code == ma.P_IP4 || c.Protocol().Code == ma.P_IP6 {
a, _ = netip.AddrFromSlice(c.RawValue())
return false
}
if c.Protocol().Code == ma.P_UDP || c.Protocol().Code == ma.P_TCP {
p = binary.BigEndian.Uint16(c.RawValue())
return true
}
return false
})
return netip.AddrPortFrom(a, p)
}

func TestInterceptPeerDial(t *testing.T) {
if race.WithRace() {
t.Skip("The upgrader spawns a new Go routine, which leads to race conditions when using GoMock.")
Expand Down Expand Up @@ -173,10 +192,14 @@ func TestInterceptAccept(t *testing.T) {
// remove the certhash component from WebTransport addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr())
}).AnyTimes()
} else if strings.Contains(tc.Name, "WebSocket-Shared") {
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) {
require.Equal(t, addrPort(h2.Addrs()[0]), addrPort(addrs.LocalMultiaddr()))
})
} else {
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) {
// remove the certhash component from WebTransport addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr())
require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr(), "%s\n%s", h2.Addrs()[0], addrs.LocalMultiaddr())
})
}

Expand Down
32 changes: 32 additions & 0 deletions p2p/test/transport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,38 @@ var transportsToTest = []TransportTestCase{
return h
},
},
{
Name: "TCP-Shared / TLS / Yamux",
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
libp2pOpts := transformOpts(opts)
libp2pOpts = append(libp2pOpts, libp2p.ShareTCPListener())
libp2pOpts = append(libp2pOpts, libp2p.Security(tls.ID, tls.New))
libp2pOpts = append(libp2pOpts, libp2p.Muxer(yamux.ID, yamux.DefaultTransport))
if opts.NoListen {
libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs)
} else {
libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0"))
}
h, err := libp2p.New(libp2pOpts...)
require.NoError(t, err)
return h
},
},
{
Name: "WebSocket-Shared",
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
libp2pOpts := transformOpts(opts)
libp2pOpts = append(libp2pOpts, libp2p.ShareTCPListener())
if opts.NoListen {
libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs)
} else {
libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0/ws"))
}
h, err := libp2p.New(libp2pOpts...)
require.NoError(t, err)
return h
},
},
{
Name: "WebSocket",
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
Expand Down
3 changes: 2 additions & 1 deletion p2p/transport/tcp/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,14 +149,15 @@ var _ transport.DialUpdater = &TcpTransport{}

// NewTCPTransport creates a tcp transport object that tracks dialers and listeners
// created.
func NewTCPTransport(upgrader transport.Upgrader, rcmgr network.ResourceManager, opts ...Option) (*TcpTransport, error) {
func NewTCPTransport(upgrader transport.Upgrader, rcmgr network.ResourceManager, sharedTCP *tcpreuse.ConnMgr, opts ...Option) (*TcpTransport, error) {
if rcmgr == nil {
rcmgr = &network.NullResourceManager{}
}
tr := &TcpTransport{
upgrader: upgrader,
connectTimeout: defaultConnectTimeout, // can be set by using the WithConnectionTimeout option
rcmgr: rcmgr,
sharedTcp: sharedTCP,
}
for _, o := range opts {
if err := o(tr); err != nil {
Expand Down
20 changes: 10 additions & 10 deletions p2p/transport/tcp/tcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ func TestTcpTransport(t *testing.T) {

ua, err := tptu.New(ia, muxers, nil, nil, nil)
require.NoError(t, err)
ta, err := NewTCPTransport(ua, nil)
ta, err := NewTCPTransport(ua, nil, nil)
require.NoError(t, err)
ub, err := tptu.New(ib, muxers, nil, nil, nil)
require.NoError(t, err)
tb, err := NewTCPTransport(ub, nil)
tb, err := NewTCPTransport(ub, nil, nil)
require.NoError(t, err)

zero := "/ip4/127.0.0.1/tcp/0"
Expand All @@ -53,11 +53,11 @@ func TestTcpTransportWithMetrics(t *testing.T) {

ua, err := tptu.New(ia, muxers, nil, nil, nil)
require.NoError(t, err)
ta, err := NewTCPTransport(ua, nil, WithMetrics())
ta, err := NewTCPTransport(ua, nil, nil, WithMetrics())
require.NoError(t, err)
ub, err := tptu.New(ib, muxers, nil, nil, nil)
require.NoError(t, err)
tb, err := NewTCPTransport(ub, nil, WithMetrics())
tb, err := NewTCPTransport(ub, nil, nil, WithMetrics())
require.NoError(t, err)

zero := "/ip4/127.0.0.1/tcp/0"
Expand All @@ -73,7 +73,7 @@ func TestResourceManager(t *testing.T) {

ua, err := tptu.New(ia, muxers, nil, nil, nil)
require.NoError(t, err)
ta, err := NewTCPTransport(ua, nil)
ta, err := NewTCPTransport(ua, nil, nil)
require.NoError(t, err)
ln, err := ta.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0"))
require.NoError(t, err)
Expand All @@ -82,7 +82,7 @@ func TestResourceManager(t *testing.T) {
ub, err := tptu.New(ib, muxers, nil, nil, nil)
require.NoError(t, err)
rcmgr := mocknetwork.NewMockResourceManager(ctrl)
tb, err := NewTCPTransport(ub, rcmgr)
tb, err := NewTCPTransport(ub, rcmgr, nil)
require.NoError(t, err)

t.Run("success", func(t *testing.T) {
Expand Down Expand Up @@ -120,7 +120,7 @@ func TestTcpTransportCantDialDNS(t *testing.T) {
require.NoError(t, err)

var u transport.Upgrader
tpt, err := NewTCPTransport(u, nil)
tpt, err := NewTCPTransport(u, nil, nil)
require.NoError(t, err)

if tpt.CanDial(dnsa) {
Expand All @@ -138,7 +138,7 @@ func TestTcpTransportCantListenUtp(t *testing.T) {
require.NoError(t, err)

var u transport.Upgrader
tpt, err := NewTCPTransport(u, nil)
tpt, err := NewTCPTransport(u, nil, nil)
require.NoError(t, err)

_, err = tpt.Listen(utpa)
Expand All @@ -155,15 +155,15 @@ func TestDialWithUpdates(t *testing.T) {

ua, err := tptu.New(ia, muxers, nil, nil, nil)
require.NoError(t, err)
ta, err := NewTCPTransport(ua, nil)
ta, err := NewTCPTransport(ua, nil, nil)
require.NoError(t, err)
ln, err := ta.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0"))
require.NoError(t, err)
defer ln.Close()

ub, err := tptu.New(ib, muxers, nil, nil, nil)
require.NoError(t, err)
tb, err := NewTCPTransport(ub, nil)
tb, err := NewTCPTransport(ub, nil, nil)
require.NoError(t, err)

updCh := make(chan transport.DialUpdate, 1)
Expand Down
1 change: 0 additions & 1 deletion p2p/transport/tcpreuse/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,6 @@ func (m *multiplexedListener) run() error {
}
continue
}

connScope, err := m.rcmgr.OpenConnection(network.DirInbound, true, c.RemoteMultiaddr())
if err != nil {
log.Debugw("resource manager blocked accept of new connection", "error", err)
Expand Down
Loading

0 comments on commit dfc0fd1

Please sign in to comment.