Skip to content

Commit

Permalink
net/dtlslistener: use parallel handshakes for connections
Browse files Browse the repository at this point in the history
  • Loading branch information
jkralik committed Aug 15, 2022
1 parent d1c0c08 commit b6a4033
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 25 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go 1.18
require (
github.com/dsnet/golib/memfile v1.0.0
github.com/pion/dtls/v2 v2.1.5
github.com/pion/udp v0.1.1
github.com/stretchr/testify v1.7.1
go.uber.org/atomic v1.9.0
golang.org/x/exp v0.0.0-20220613132600-b0d781184e0d
Expand All @@ -17,7 +18,6 @@ require (
github.com/kr/pretty v0.1.0 // indirect
github.com/pion/logging v0.2.2 // indirect
github.com/pion/transport v0.13.1 // indirect
github.com/pion/udp v0.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d // indirect
golang.org/x/sys v0.0.0-20220627191245-f75cf1eec38b // indirect
Expand Down
14 changes: 7 additions & 7 deletions net/connUDP.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,14 +136,14 @@ func IsIPv6(addr net.IP) bool {
return false
}

var defaultUDPConnOptions = udpConnOptions{
errors: func(err error) {
var DefaultUDPConnConfig = UDPConnConfig{
Errors: func(err error) {
// don't log any error from fails for multicast requests
},
}

type udpConnOptions struct {
errors func(err error)
type UDPConnConfig struct {
Errors func(err error)
}

func NewListenUDP(network, addr string, opts ...UDPOption) (*UDPConn, error) {
Expand All @@ -160,9 +160,9 @@ func NewListenUDP(network, addr string, opts ...UDPOption) (*UDPConn, error) {

// NewUDPConn creates connection over net.UDPConn.
func NewUDPConn(network string, c *net.UDPConn, opts ...UDPOption) *UDPConn {
cfg := defaultUDPConnOptions
cfg := DefaultUDPConnConfig
for _, o := range opts {
o.applyUDP(&cfg)
o.ApplyUDP(&cfg)
}

var pc packetConn
Expand All @@ -176,7 +176,7 @@ func NewUDPConn(network string, c *net.UDPConn, opts ...UDPOption) *UDPConn {
network: network,
connection: c,
packetConn: pc,
errors: cfg.errors,
errors: cfg.Errors,
}
}

Expand Down
138 changes: 124 additions & 14 deletions net/dtlslistener.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,82 @@ package net

import (
"context"
"errors"
"fmt"
"net"
"sync"
"time"

dtls "github.com/pion/dtls/v2"
"github.com/pion/dtls/v2/pkg/protocol"
"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
"github.com/pion/udp"
"go.uber.org/atomic"
)

type GoPoolFunc = func(f func()) error

var DefaultDTLSListenerConfig = DTLSListenerConfig{
GoPool: func(f func()) error {
go f()
return nil
},
}

type DTLSListenerConfig struct {
GoPool GoPoolFunc
}

type acceptedConn struct {
conn net.Conn
err error
}

// DTLSListener is a DTLS listener that provides accept with context.
type DTLSListener struct {
listener net.Listener
closed atomic.Bool
listener net.Listener
config *dtls.Config
closed atomic.Bool
goPool GoPoolFunc
acceptedConnChan chan acceptedConn
wg sync.WaitGroup
done chan struct{}
}

func tlsPacketFilter(packet []byte) bool {
pkts, err := recordlayer.UnpackDatagram(packet)
if err != nil || len(pkts) < 1 {
return false
}
h := &recordlayer.Header{}
if err := h.Unmarshal(pkts[0]); err != nil {
return false
}
return h.ContentType == protocol.ContentTypeHandshake
}

// NewDTLSListener creates dtls listener.
// Known networks are "udp", "udp4" (IPv4-only), "udp6" (IPv6-only).
func NewDTLSListener(network string, addr string, dtlsCfg *dtls.Config) (*DTLSListener, error) {
func NewDTLSListener(network string, addr string, dtlsCfg *dtls.Config, opts ...DTLSListenerOption) (*DTLSListener, error) {
a, err := net.ResolveUDPAddr(network, addr)
if err != nil {
return nil, fmt.Errorf("cannot resolve address: %w", err)
}
cfg := DefaultDTLSListenerConfig
for _, o := range opts {
o.ApplyDTLS(&cfg)
}

if cfg.GoPool == nil {
return nil, fmt.Errorf("empty go pool")
}

var l DTLSListener
l := DTLSListener{
goPool: cfg.GoPool,
config: dtlsCfg,
acceptedConnChan: make(chan acceptedConn, 256),
done: make(chan struct{}),
}
connectContextMaker := dtlsCfg.ConnectContextMaker
if connectContextMaker == nil {
connectContextMaker = func() (context.Context, func()) {
Expand All @@ -39,14 +92,56 @@ func NewDTLSListener(network string, addr string, dtlsCfg *dtls.Config) (*DTLSLi
return ctx, cancel
}

listener, err := dtls.Listen(network, a, dtlsCfg)
lc := udp.ListenConfig{
AcceptFilter: tlsPacketFilter,
}
l.listener, err = lc.Listen(network, a)
if err != nil {
return nil, fmt.Errorf("cannot create new dtls listener: %w", err)
return nil, err
}
l.listener = listener
l.wg.Add(1)
go l.run()
return &l, nil
}

func (l *DTLSListener) send(conn net.Conn, err error) {
select {
case <-l.done:
case l.acceptedConnChan <- acceptedConn{
conn: conn,
err: err,
}:
}
}

func (l *DTLSListener) accept() error {
c, err := l.listener.Accept()
if err != nil {
l.send(nil, err)
return err
}
err = l.goPool(func() {
l.send(dtls.Server(c, l.config))
})
if err != nil {
_ = c.Close()
}
return err
}

func (l *DTLSListener) run() {
defer l.wg.Done()
for {
if l.closed.Load() {
return
}
err := l.accept()
if errors.Is(err, udp.ErrClosedListener) {
return
}
}
}

// AcceptWithContext waits with context for a generic Conn.
func (l *DTLSListener) AcceptWithContext(ctx context.Context) (net.Conn, error) {
select {
Expand All @@ -57,14 +152,27 @@ func (l *DTLSListener) AcceptWithContext(ctx context.Context) (net.Conn, error)
if l.closed.Load() {
return nil, ErrListenerIsClosed
}
c, err := l.listener.Accept()
if err != nil {
return nil, err
}
if c == nil {
return nil, nil
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-l.done:
return nil, ErrListenerIsClosed
case d := <-l.acceptedConnChan:
err := d.err
if errors.Is(err, context.DeadlineExceeded) {
// we don't want to report error handshake deadline exceeded
continue
}
if errors.Is(err, udp.ErrClosedListener) {
return nil, ErrListenerIsClosed
}
if err != nil {
return nil, err
}
return d.conn, nil
}
}
return c, nil
}

// Accept waits for a generic Conn.
Expand All @@ -77,6 +185,8 @@ func (l *DTLSListener) Close() error {
if !l.closed.CAS(false, true) {
return nil
}
close(l.done)
defer l.wg.Wait()
return l.listener.Close()
}

Expand Down
27 changes: 24 additions & 3 deletions net/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@ import "net"

// A UDPOption sets options such as errors parameters, etc.
type UDPOption interface {
applyUDP(*udpConnOptions)
ApplyUDP(*UDPConnConfig)
}

type ErrorsOpt struct {
errors func(err error)
}

func (h ErrorsOpt) applyUDP(o *udpConnOptions) {
o.errors = h.errors
func (h ErrorsOpt) ApplyUDP(o *UDPConnConfig) {
o.Errors = h.errors
}

func WithErrors(v func(err error)) ErrorsOpt {
Expand Down Expand Up @@ -120,3 +120,24 @@ func (m MulticastInterfaceErrorOpt) applyMC(o *MulticastOptions) {
func WithMulticastInterfaceError(interfaceError InterfaceError) MulticastOption {
return &MulticastInterfaceErrorOpt{interfaceError: interfaceError}
}

// A DTLSListenerOption sets options such as gopool.
type DTLSListenerOption interface {
ApplyDTLS(*DTLSListenerConfig)
}

// GoPoolOpt gopool option.
type GoPoolOpt struct {
goPool GoPoolFunc
}

func (o GoPoolOpt) ApplyDTLS(cfg *DTLSListenerConfig) {
cfg.GoPool = o.goPool
}

// WithGoPool sets function for managing spawning go routines
// for handling incoming request's.
// Eg: https://github.com/panjf2000/ants.
func WithGoPool(goPool GoPoolFunc) GoPoolOpt {
return GoPoolOpt{goPool: goPool}
}

0 comments on commit b6a4033

Please sign in to comment.