diff --git a/libp2p/transports/tcptransport.nim b/libp2p/transports/tcptransport.nim index 371a34e92c..f5dcd6d2b7 100644 --- a/libp2p/transports/tcptransport.nim +++ b/libp2p/transports/tcptransport.nim @@ -49,6 +49,8 @@ type opened*: uint64 closed*: uint64 + TcpTransportError* = object of transport.TransportError + proc setupTcpTransportTracker(): TcpTransportTracker {.gcsafe, raises: [Defect].} proc getTcpTransportTracker(): TcpTransportTracker {.gcsafe.} = @@ -157,6 +159,17 @@ method start*( warn "TCP transport already running" return + proc getPort(ma: MultiAddress): seq[byte] = + return ma[1].get().protoArgument().get() + + proc isNotZeroPort(port: seq[byte]): bool = + return port != @[0.byte, 0] + + let supported = addrs.filterIt(self.handles(it)) + let nonZeroPorts = supported.mapIt(getPort(it)).filterIt(isNotZeroPort(it)) + if deduplicate(nonZeroPorts).len < nonZeroPorts.len: + raise newException(TcpTransportError, "Duplicate ports detected") + await procCall Transport(self).start(addrs) trace "Starting TCP transport" inc getTcpTransportTracker().opened @@ -166,8 +179,7 @@ method start*( trace "Invalid address detected, skipping!", address = ma continue - if self.networkReachability == NetworkReachability.NotReachable: - self.flags.incl(ServerFlags.ReusePort) + self.flags.incl(ServerFlags.ReusePort) let server = createStreamServer( ma = ma, flags = self.flags, diff --git a/tests/testtcptransport.nim b/tests/testtcptransport.nim index d0666e1d30..b78ab0a354 100644 --- a/tests/testtcptransport.nim +++ b/tests/testtcptransport.nim @@ -125,6 +125,25 @@ suite "TCP transport": server.close() await server.join() + asyncTest "Starting with duplicate ports must fail": + # Starting with duplicate addresses must fail + let ma = @[MultiAddress.init("/ip4/0.0.0.0/tcp/8080").tryGet(), + MultiAddress.init("/ip4/0.0.0.0/tcp/8080").tryGet()] + + let transport: TcpTransport = TcpTransport.new(upgrade = Upgrade()) + + expect TcpTransportError: + await transport.start(ma) + + asyncTest "Starting with duplicate but zero ports addresses must NOT fail": + let ma = @[MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet(), + MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()] + + let transport: TcpTransport = TcpTransport.new(upgrade = Upgrade()) + + await transport.start(ma) + await transport.stop() + proc transProvider(): Transport = TcpTransport.new(upgrade = Upgrade()) commonTransportTest(