diff --git a/go.mod b/go.mod index f541bec1..7346ef66 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.18 require ( github.com/dsnet/golib/memfile v1.0.0 github.com/pion/dtls/v2 v2.1.5 + github.com/pion/udp v0.1.1 github.com/stretchr/testify v1.7.1 go.uber.org/atomic v1.9.0 golang.org/x/exp v0.0.0-20220613132600-b0d781184e0d @@ -17,7 +18,6 @@ require ( github.com/kr/pretty v0.1.0 // indirect github.com/pion/logging v0.2.2 // indirect github.com/pion/transport v0.13.1 // indirect - github.com/pion/udp v0.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d // indirect golang.org/x/sys v0.0.0-20220627191245-f75cf1eec38b // indirect diff --git a/net/connUDP.go b/net/connUDP.go index bef37da5..b04f0519 100644 --- a/net/connUDP.go +++ b/net/connUDP.go @@ -136,14 +136,14 @@ func IsIPv6(addr net.IP) bool { return false } -var defaultUDPConnOptions = udpConnOptions{ - errors: func(err error) { +var DefaultUDPConnConfig = UDPConnConfig{ + Errors: func(err error) { // don't log any error from fails for multicast requests }, } -type udpConnOptions struct { - errors func(err error) +type UDPConnConfig struct { + Errors func(err error) } func NewListenUDP(network, addr string, opts ...UDPOption) (*UDPConn, error) { @@ -160,9 +160,9 @@ func NewListenUDP(network, addr string, opts ...UDPOption) (*UDPConn, error) { // NewUDPConn creates connection over net.UDPConn. func NewUDPConn(network string, c *net.UDPConn, opts ...UDPOption) *UDPConn { - cfg := defaultUDPConnOptions + cfg := DefaultUDPConnConfig for _, o := range opts { - o.applyUDP(&cfg) + o.ApplyUDP(&cfg) } var pc packetConn @@ -176,7 +176,7 @@ func NewUDPConn(network string, c *net.UDPConn, opts ...UDPOption) *UDPConn { network: network, connection: c, packetConn: pc, - errors: cfg.errors, + errors: cfg.Errors, } } diff --git a/net/dtlslistener.go b/net/dtlslistener.go index e7c6c204..b50325c2 100644 --- a/net/dtlslistener.go +++ b/net/dtlslistener.go @@ -2,29 +2,82 @@ package net import ( "context" + "errors" "fmt" "net" + "sync" "time" dtls "github.com/pion/dtls/v2" + "github.com/pion/dtls/v2/pkg/protocol" + "github.com/pion/dtls/v2/pkg/protocol/recordlayer" + "github.com/pion/udp" "go.uber.org/atomic" ) +type GoPoolFunc = func(f func()) error + +var DefaultDTLSListenerConfig = DTLSListenerConfig{ + GoPool: func(f func()) error { + go f() + return nil + }, +} + +type DTLSListenerConfig struct { + GoPool GoPoolFunc +} + +type acceptedConn struct { + conn net.Conn + err error +} + // DTLSListener is a DTLS listener that provides accept with context. type DTLSListener struct { - listener net.Listener - closed atomic.Bool + listener net.Listener + config *dtls.Config + closed atomic.Bool + goPool GoPoolFunc + acceptedConnChan chan acceptedConn + wg sync.WaitGroup + done chan struct{} +} + +func tlsPacketFilter(packet []byte) bool { + pkts, err := recordlayer.UnpackDatagram(packet) + if err != nil || len(pkts) < 1 { + return false + } + h := &recordlayer.Header{} + if err := h.Unmarshal(pkts[0]); err != nil { + return false + } + return h.ContentType == protocol.ContentTypeHandshake } // NewDTLSListener creates dtls listener. // Known networks are "udp", "udp4" (IPv4-only), "udp6" (IPv6-only). -func NewDTLSListener(network string, addr string, dtlsCfg *dtls.Config) (*DTLSListener, error) { +func NewDTLSListener(network string, addr string, dtlsCfg *dtls.Config, opts ...DTLSListenerOption) (*DTLSListener, error) { a, err := net.ResolveUDPAddr(network, addr) if err != nil { return nil, fmt.Errorf("cannot resolve address: %w", err) } + cfg := DefaultDTLSListenerConfig + for _, o := range opts { + o.ApplyDTLS(&cfg) + } + + if cfg.GoPool == nil { + return nil, fmt.Errorf("empty go pool") + } - var l DTLSListener + l := DTLSListener{ + goPool: cfg.GoPool, + config: dtlsCfg, + acceptedConnChan: make(chan acceptedConn, 256), + done: make(chan struct{}), + } connectContextMaker := dtlsCfg.ConnectContextMaker if connectContextMaker == nil { connectContextMaker = func() (context.Context, func()) { @@ -39,14 +92,56 @@ func NewDTLSListener(network string, addr string, dtlsCfg *dtls.Config) (*DTLSLi return ctx, cancel } - listener, err := dtls.Listen(network, a, dtlsCfg) + lc := udp.ListenConfig{ + AcceptFilter: tlsPacketFilter, + } + l.listener, err = lc.Listen(network, a) if err != nil { - return nil, fmt.Errorf("cannot create new dtls listener: %w", err) + return nil, err } - l.listener = listener + l.wg.Add(1) + go l.run() return &l, nil } +func (l *DTLSListener) send(conn net.Conn, err error) { + select { + case <-l.done: + case l.acceptedConnChan <- acceptedConn{ + conn: conn, + err: err, + }: + } +} + +func (l *DTLSListener) accept() error { + c, err := l.listener.Accept() + if err != nil { + l.send(nil, err) + return err + } + err = l.goPool(func() { + l.send(dtls.Server(c, l.config)) + }) + if err != nil { + _ = c.Close() + } + return err +} + +func (l *DTLSListener) run() { + defer l.wg.Done() + for { + if l.closed.Load() { + return + } + err := l.accept() + if errors.Is(err, udp.ErrClosedListener) { + return + } + } +} + // AcceptWithContext waits with context for a generic Conn. func (l *DTLSListener) AcceptWithContext(ctx context.Context) (net.Conn, error) { select { @@ -57,14 +152,27 @@ func (l *DTLSListener) AcceptWithContext(ctx context.Context) (net.Conn, error) if l.closed.Load() { return nil, ErrListenerIsClosed } - c, err := l.listener.Accept() - if err != nil { - return nil, err - } - if c == nil { - return nil, nil + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-l.done: + return nil, ErrListenerIsClosed + case d := <-l.acceptedConnChan: + err := d.err + if errors.Is(err, context.DeadlineExceeded) { + // we don't want to report error handshake deadline exceeded + continue + } + if errors.Is(err, udp.ErrClosedListener) { + return nil, ErrListenerIsClosed + } + if err != nil { + return nil, err + } + return d.conn, nil + } } - return c, nil } // Accept waits for a generic Conn. @@ -77,6 +185,8 @@ func (l *DTLSListener) Close() error { if !l.closed.CAS(false, true) { return nil } + close(l.done) + defer l.wg.Wait() return l.listener.Close() } diff --git a/net/options.go b/net/options.go index 2cab9bce..c037ef9c 100644 --- a/net/options.go +++ b/net/options.go @@ -4,15 +4,15 @@ import "net" // A UDPOption sets options such as errors parameters, etc. type UDPOption interface { - applyUDP(*udpConnOptions) + ApplyUDP(*UDPConnConfig) } type ErrorsOpt struct { errors func(err error) } -func (h ErrorsOpt) applyUDP(o *udpConnOptions) { - o.errors = h.errors +func (h ErrorsOpt) ApplyUDP(o *UDPConnConfig) { + o.Errors = h.errors } func WithErrors(v func(err error)) ErrorsOpt { @@ -120,3 +120,24 @@ func (m MulticastInterfaceErrorOpt) applyMC(o *MulticastOptions) { func WithMulticastInterfaceError(interfaceError InterfaceError) MulticastOption { return &MulticastInterfaceErrorOpt{interfaceError: interfaceError} } + +// A DTLSListenerOption sets options such as gopool. +type DTLSListenerOption interface { + ApplyDTLS(*DTLSListenerConfig) +} + +// GoPoolOpt gopool option. +type GoPoolOpt struct { + goPool GoPoolFunc +} + +func (o GoPoolOpt) ApplyDTLS(cfg *DTLSListenerConfig) { + cfg.GoPool = o.goPool +} + +// WithGoPool sets function for managing spawning go routines +// for handling incoming request's. +// Eg: https://github.com/panjf2000/ants. +func WithGoPool(goPool GoPoolFunc) GoPoolOpt { + return GoPoolOpt{goPool: goPool} +}