Skip to content

Commit

Permalink
Support generic net.PacketConn's for the Server (#1174)
Browse files Browse the repository at this point in the history
* Support generic net.PacketConn's for the Server

This commit adds support for listening on generic net.PacketConn's for
UDP DNS requests, previously *net.UDPConn was the only supported type.

In the event of a future v2 of this module, this should be streamlined.

* Eliminate wrapper functions around RunLocalXServerWithFinChan

* Eliminate RunLocalTCPServerWithTsig function

* Replace RunLocalTLSServer with a wrapper around RunLocalTCPServer

This reduces code duplication.

* Add net.PacketConn server tests

This provides coverage over nearly all of the newly added code (with
the unfortunate exception of (*response).RemoteAddr).

* Fix broken client_test.go tests

a433fbe was merged into master between this PR being opened and
being merged. This broke the CI tests in rather strange ways as the
code was being merged into master in a way that wasn't at all clear.
This commit fixes the two broken lines.
  • Loading branch information
tmthrgd authored Oct 24, 2020
1 parent a3ad444 commit 0e1c4e6
Show file tree
Hide file tree
Showing 5 changed files with 222 additions and 173 deletions.
2 changes: 1 addition & 1 deletion acceptfunc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (

func TestAcceptNotify(t *testing.T) {
HandleFunc("example.org.", handleNotify)
s, addrstr, err := RunLocalUDPServer(":0")
s, addrstr, _, err := RunLocalUDPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
Expand Down
24 changes: 12 additions & 12 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func TestDialUDP(t *testing.T) {
HandleFunc("miek.nl.", HelloServer)
defer HandleRemove("miek.nl.")

s, addrstr, err := RunLocalUDPServer(":0")
s, addrstr, _, err := RunLocalUDPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
Expand All @@ -39,7 +39,7 @@ func TestClientSync(t *testing.T) {
HandleFunc("miek.nl.", HelloServer)
defer HandleRemove("miek.nl.")

s, addrstr, err := RunLocalUDPServer(":0")
s, addrstr, _, err := RunLocalUDPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
Expand Down Expand Up @@ -73,7 +73,7 @@ func TestClientLocalAddress(t *testing.T) {
HandleFunc("miek.nl.", HelloServerEchoAddrPort)
defer HandleRemove("miek.nl.")

s, addrstr, err := RunLocalUDPServer(":0")
s, addrstr, _, err := RunLocalUDPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
Expand Down Expand Up @@ -117,7 +117,7 @@ func TestClientTLSSyncV4(t *testing.T) {
Certificates: []tls.Certificate{cert},
}

s, addrstr, err := RunLocalTLSServer(":0", &config)
s, addrstr, _, err := RunLocalTLSServer(":0", &config)
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
Expand Down Expand Up @@ -173,7 +173,7 @@ func TestClientSyncBadID(t *testing.T) {
HandleFunc("miek.nl.", HelloServerBadID)
defer HandleRemove("miek.nl.")

s, addrstr, err := RunLocalUDPServer(":0")
s, addrstr, _, err := RunLocalUDPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
Expand All @@ -198,7 +198,7 @@ func TestClientSyncBadThenGoodID(t *testing.T) {
HandleFunc("miek.nl.", HelloServerBadThenGoodID)
defer HandleRemove("miek.nl.")

s, addrstr, err := RunLocalUDPServer(":0")
s, addrstr, _, err := RunLocalUDPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
Expand Down Expand Up @@ -229,7 +229,7 @@ func TestClientSyncTCPBadID(t *testing.T) {
HandleFunc("miek.nl.", HelloServerBadID)
defer HandleRemove("miek.nl.")

s, addrstr, err := RunLocalTCPServer(":0")
s, addrstr, _, err := RunLocalTCPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
Expand All @@ -250,7 +250,7 @@ func TestClientEDNS0(t *testing.T) {
HandleFunc("miek.nl.", HelloServer)
defer HandleRemove("miek.nl.")

s, addrstr, err := RunLocalUDPServer(":0")
s, addrstr, _, err := RunLocalUDPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
Expand Down Expand Up @@ -297,7 +297,7 @@ func TestClientEDNS0Local(t *testing.T) {
HandleFunc("miek.nl.", handler)
defer HandleRemove("miek.nl.")

s, addrstr, err := RunLocalUDPServer(":0")
s, addrstr, _, err := RunLocalUDPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %s", err)
}
Expand Down Expand Up @@ -347,7 +347,7 @@ func TestClientConn(t *testing.T) {
defer HandleRemove("miek.nl.")

// This uses TCP just to make it slightly different than TestClientSync
s, addrstr, err := RunLocalTCPServer(":0")
s, addrstr, _, err := RunLocalTCPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
Expand Down Expand Up @@ -594,7 +594,7 @@ func TestConcurrentExchanges(t *testing.T) {
HandleFunc("miek.nl.", handler)
defer HandleRemove("miek.nl.")

s, addrstr, err := RunLocalUDPServer(":0")
s, addrstr, _, err := RunLocalUDPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %s", err)
}
Expand Down Expand Up @@ -631,7 +631,7 @@ func TestExchangeWithConn(t *testing.T) {
HandleFunc("miek.nl.", HelloServer)
defer HandleRemove("miek.nl.")

s, addrstr, err := RunLocalUDPServer(":0")
s, addrstr, _, err := RunLocalUDPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
Expand Down
102 changes: 79 additions & 23 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,10 @@ type response struct {
tsigStatus error
tsigRequestMAC string
tsigSecret map[string]string // the tsig secrets
udp *net.UDPConn // i/o connection if UDP was used
udp net.PacketConn // i/o connection if UDP was used
tcp net.Conn // i/o connection if TCP was used
udpSession *SessionUDP // oob data to get egress interface right
pcSession net.Addr // address to use when writing to a generic net.PacketConn
writer Writer // writer to output the raw DNS bits
}

Expand Down Expand Up @@ -147,12 +148,24 @@ type Reader interface {
ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error)
}

// defaultReader is an adapter for the Server struct that implements the Reader interface
// using the readTCP and readUDP func of the embedded Server.
// PacketConnReader is an optional interface that Readers can implement to support using generic net.PacketConns.
type PacketConnReader interface {
Reader

// ReadPacketConn reads a raw message from a generic net.PacketConn UDP connection. Implementations may
// alter connection properties, for example the read-deadline.
ReadPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error)
}

// defaultReader is an adapter for the Server struct that implements the Reader and
// PacketConnReader interfaces using the readTCP, readUDP and readPacketConn funcs
// of the embedded Server.
type defaultReader struct {
*Server
}

var _ PacketConnReader = defaultReader{}

func (dr defaultReader) ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
return dr.readTCP(conn, timeout)
}
Expand All @@ -161,8 +174,14 @@ func (dr defaultReader) ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byt
return dr.readUDP(conn, timeout)
}

func (dr defaultReader) ReadPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error) {
return dr.readPacketConn(conn, timeout)
}

// DecorateReader is a decorator hook for extending or supplanting the functionality of a Reader.
// Implementations should never return a nil Reader.
// Readers should also implement the optional ReaderPacketConn interface.
// ReaderPacketConn is required to use a generic net.PacketConn.
type DecorateReader func(Reader) Reader

// DecorateWriter is a decorator hook for extending or supplanting the functionality of a Writer.
Expand Down Expand Up @@ -325,24 +344,22 @@ func (srv *Server) ActivateAndServe() error {

srv.init()

pConn := srv.PacketConn
l := srv.Listener
if pConn != nil {
if srv.PacketConn != nil {
// Check PacketConn interface's type is valid and value
// is not nil
if t, ok := pConn.(*net.UDPConn); ok && t != nil {
if t, ok := srv.PacketConn.(*net.UDPConn); ok && t != nil {
if e := setUDPSocketOptions(t); e != nil {
return e
}
srv.started = true
unlock()
return srv.serveUDP(t)
}
srv.started = true
unlock()
return srv.serveUDP(srv.PacketConn)
}
if l != nil {
if srv.Listener != nil {
srv.started = true
unlock()
return srv.serveTCP(l)
return srv.serveTCP(srv.Listener)
}
return &Error{err: "bad listeners"}
}
Expand Down Expand Up @@ -446,18 +463,24 @@ func (srv *Server) serveTCP(l net.Listener) error {
}

// serveUDP starts a UDP listener for the server.
func (srv *Server) serveUDP(l *net.UDPConn) error {
func (srv *Server) serveUDP(l net.PacketConn) error {
defer l.Close()

if srv.NotifyStartedFunc != nil {
srv.NotifyStartedFunc()
}

reader := Reader(defaultReader{srv})
if srv.DecorateReader != nil {
reader = srv.DecorateReader(reader)
}

lUDP, isUDP := l.(*net.UDPConn)
readerPC, canPacketConn := reader.(PacketConnReader)
if !isUDP && !canPacketConn {
return &Error{err: "PacketConnReader was not implemented on Reader returned from DecorateReader but is required for net.PacketConn"}
}

if srv.NotifyStartedFunc != nil {
srv.NotifyStartedFunc()
}

var wg sync.WaitGroup
defer func() {
wg.Wait()
Expand All @@ -467,7 +490,17 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
rtimeout := srv.getReadTimeout()
// deadline is not used here
for srv.isStarted() {
m, s, err := reader.ReadUDP(l, rtimeout)
var (
m []byte
sPC net.Addr
sUDP *SessionUDP
err error
)
if isUDP {
m, sUDP, err = reader.ReadUDP(lUDP, rtimeout)
} else {
m, sPC, err = readerPC.ReadPacketConn(l, rtimeout)
}
if err != nil {
if !srv.isStarted() {
return nil
Expand All @@ -484,7 +517,7 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
continue
}
wg.Add(1)
go srv.serveUDPPacket(&wg, m, l, s)
go srv.serveUDPPacket(&wg, m, l, sUDP, sPC)
}

return nil
Expand Down Expand Up @@ -546,8 +579,8 @@ func (srv *Server) serveTCPConn(wg *sync.WaitGroup, rw net.Conn) {
}

// Serve a new UDP request.
func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u *net.UDPConn, s *SessionUDP) {
w := &response{tsigSecret: srv.TsigSecret, udp: u, udpSession: s}
func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u net.PacketConn, udpSession *SessionUDP, pcSession net.Addr) {
w := &response{tsigSecret: srv.TsigSecret, udp: u, udpSession: udpSession, pcSession: pcSession}
if srv.DecorateWriter != nil {
w.writer = srv.DecorateWriter(w)
} else {
Expand Down Expand Up @@ -659,6 +692,24 @@ func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *S
return m, s, nil
}

func (srv *Server) readPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error) {
srv.lock.RLock()
if srv.started {
// See the comment in readTCP above.
conn.SetReadDeadline(time.Now().Add(timeout))
}
srv.lock.RUnlock()

m := srv.udpPool.Get().([]byte)
n, addr, err := conn.ReadFrom(m)
if err != nil {
srv.udpPool.Put(m)
return nil, nil, err
}
m = m[:n]
return m, addr, nil
}

// WriteMsg implements the ResponseWriter.WriteMsg method.
func (w *response) WriteMsg(m *Msg) (err error) {
if w.closed {
Expand Down Expand Up @@ -692,7 +743,10 @@ func (w *response) Write(m []byte) (int, error) {

switch {
case w.udp != nil:
return WriteToSessionUDP(w.udp, m, w.udpSession)
if u, ok := w.udp.(*net.UDPConn); ok {
return WriteToSessionUDP(u, m, w.udpSession)
}
return w.udp.WriteTo(m, w.pcSession)
case w.tcp != nil:
if len(m) > MaxMsgSize {
return 0, &Error{err: "message too large"}
Expand Down Expand Up @@ -725,10 +779,12 @@ func (w *response) RemoteAddr() net.Addr {
switch {
case w.udpSession != nil:
return w.udpSession.RemoteAddr()
case w.pcSession != nil:
return w.pcSession
case w.tcp != nil:
return w.tcp.RemoteAddr()
default:
panic("dns: internal error: udpSession and tcp both nil")
panic("dns: internal error: udpSession, pcSession and tcp are all nil")
}
}

Expand Down
Loading

0 comments on commit 0e1c4e6

Please sign in to comment.