Skip to content

Commit

Permalink
add tests for listener
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Oct 5, 2024
1 parent 2837917 commit 60cef4d
Show file tree
Hide file tree
Showing 5 changed files with 553 additions and 96 deletions.
4 changes: 2 additions & 2 deletions p2p/transport/tcp/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ var _ transport.Transport = &TcpTransport{}
var _ transport.DialUpdater = &TcpTransport{}

// NewTCPTransport creates a tcp transport object that tracks dialers and listeners
// created. It represents an entire TCP stack (though it might not necessarily be).
// created.
func NewTCPTransport(upgrader transport.Upgrader, rcmgr network.ResourceManager, opts ...Option) (*TcpTransport, error) {
if rcmgr == nil {
rcmgr = &network.NullResourceManager{}
Expand Down Expand Up @@ -269,7 +269,7 @@ func (t *TcpTransport) Listen(laddr ma.Multiaddr) (transport.Listener, error) {
if t.sharedTcp == nil {
list, err = t.unsharedMAListen(laddr)
} else {
list, err = t.sharedTcp.DemultiplexedListen(laddr, tcpreuse.MultistreamSelect)
list, err = t.sharedTcp.DemultiplexedListen(laddr, tcpreuse.DemultiplexedConnType_MultistreamSelect)
}
if err != nil {
return nil, err
Expand Down
48 changes: 27 additions & 21 deletions p2p/transport/tcpreuse/demultiplex.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,24 @@ type peekAble interface {

var _ peekAble = (*bufio.Reader)(nil)

// TODO: We can unexport this type and rely completely on the multiaddr passed in to
// DemultiplexedListen.
type DemultiplexedConnType int

const (
Unknown DemultiplexedConnType = iota
MultistreamSelect
HTTP
TLS
DemultiplexedConnType_Unknown DemultiplexedConnType = iota
DemultiplexedConnType_MultistreamSelect
DemultiplexedConnType_HTTP
DemultiplexedConnType_TLS
)

func (t DemultiplexedConnType) String() string {
switch t {
case MultistreamSelect:
case DemultiplexedConnType_MultistreamSelect:
return "MultistreamSelect"
case HTTP:
case DemultiplexedConnType_HTTP:
return "HTTP"
case TLS:
case DemultiplexedConnType_TLS:
return "TLS"
default:
return fmt.Sprintf("Unknown(%d)", int(t))
Expand All @@ -49,7 +51,7 @@ func (t DemultiplexedConnType) IsKnown() bool {
return t >= 1 || t <= 3
}

func ConnTypeFromConn(c net.Conn) (DemultiplexedConnType, manet.Conn, error) {
func getDemultiplexedConn(c net.Conn) (DemultiplexedConnType, manet.Conn, error) {
if err := c.SetReadDeadline(time.Now().Add(1 * time.Second)); err != nil {
closeErr := c.Close()
return 0, nil, errors.Join(err, closeErr)
Expand All @@ -67,20 +69,24 @@ func ConnTypeFromConn(c net.Conn) (DemultiplexedConnType, manet.Conn, error) {
}

if IsMultistreamSelect(s) {
return MultistreamSelect, sc, nil
return DemultiplexedConnType_MultistreamSelect, sc, nil
}
if IsTLS(s) {
return TLS, sc, nil
return DemultiplexedConnType_TLS, sc, nil
}
if IsHTTP(s) {
return HTTP, sc, nil
return DemultiplexedConnType_HTTP, sc, nil
}
return Unknown, sc, nil
return DemultiplexedConnType_Unknown, sc, nil
}

// ReadSampleFromConn read the sample and returns a reader which still include the sample, so it can be kept undamaged.
// If an error occurs it only return the error.
// ReadSampleFromConn reads a sample and returns a reader which still includes the sample, so it can be kept undamaged.
// If an error occurs it only returns the error.
func ReadSampleFromConn(c net.Conn) (Sample, manet.Conn, error) {
// TODO: Should we remove this? This is only implemented by bufio.Reader.
// This made sense for magiselect: https://github.com/libp2p/go-libp2p/pull/2737 as it deals with a wrapped
// ReadWriteCloser from multistream which does use a buffered reader underneath.
// For our present purpose, we have a net.Conn and no net.Conn implementation offers peeking.
if peekAble, ok := c.(peekAble); ok {
b, err := peekAble.Peek(len(Sample{}))
switch {
Expand All @@ -92,6 +98,7 @@ func ReadSampleFromConn(c net.Conn) (Sample, manet.Conn, error) {

return Sample(b), mac, nil
case errors.Is(err, bufio.ErrBufferFull):
// We can only peek < len(Sample{}) data.
// fallback to sampledConn
default:
return Sample{}, nil, err
Expand All @@ -118,13 +125,12 @@ func ReadSampleFromConn(c net.Conn) (Sample, manet.Conn, error) {
if err != nil {
return Sample{}, nil, err
}

return sc.s, sc, nil
}

// Try out best to mimic a TCPConn's functions
// Note: Skipping `SyscallConn() (syscall.RawConn, error)` since it can be misused given we've read a few bytes from the connection
// If this is an issue here we can revisit the options.
// tcpConnInterface is the interface for TCPConn's functions
// Note: Skipping `SyscallConn() (syscall.RawConn, error)` since it can be misused given we've read a few bytes from the connection.
// TODO: allow SyscallConn? Disallowing it breaks metrics tracking in TCP Transport.
type tcpConnInterface interface {
net.Conn

Expand Down Expand Up @@ -180,12 +186,12 @@ func (sc *sampledConn) Read(b []byte) (int, error) {
return sc.tcpConnInterface.Read(b)
}

// forward optimizations
// TODO: Do we need these?

func (sc *sampledConn) ReadFrom(r io.Reader) (int64, error) {
return io.Copy(sc.tcpConnInterface, r)
}

// forward optimizations
func (sc *sampledConn) WriteTo(w io.Writer) (total int64, err error) {
if int(sc.readFromSample) != len(sc.s) {
b := sc.s[sc.readFromSample:]
Expand All @@ -212,7 +218,7 @@ type Matcher interface {
Match(s Sample) bool
}

// Sample might evolve over time.
// Sample is the byte sequence we use to demultiplex.
type Sample [3]byte

// Matchers are implemented here instead of in the transports so we can easily fuzz them together.
Expand Down
Loading

0 comments on commit 60cef4d

Please sign in to comment.