Skip to content

Commit

Permalink
SimpleDNSClient: Use DNS servers in order, try until one works (#29)
Browse files Browse the repository at this point in the history
* initial implementation of in-order dns requests

* handle timeouts, add tests

* assert deadline is present
  • Loading branch information
fardog authored Jun 16, 2018
1 parent 3ea4d87 commit 0a72afc
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 34 deletions.
88 changes: 59 additions & 29 deletions dns_client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package secureoperator

import (
"context"
"errors"
"fmt"
"math/rand"
Expand All @@ -26,8 +27,14 @@ var ErrFailedParsingIP = errors.New("unable to parse IP from string")
// the port portion of the string was unable to be parsed
var ErrFailedParsingPort = errors.New("unable to parse port from string")

// ErrAllServersFailed is returned when we failed to reach all configured DNS
// servers
var ErrAllServersFailed = errors.New("unable to reach any of the configured servers")

// exchange is locally set to allow its mocking during testing
var exchange = dns.Exchange
var exchange = dns.ExchangeContext

const defaultDNSClientTimeout = 10 * time.Second

// ParseEndpoint parses a string into an Endpoint object, where the endpoint
// string is in the format of "ip:port". If a port is not present in the string,
Expand Down Expand Up @@ -113,15 +120,26 @@ func (d *dnsCache) Set(key string, rec dnsCacheRecord) {
d.records[key] = rec
}

type DNSClientOptions struct {
Timeout time.Duration
}

// NewSimpleDNSClient creates a SimpleDNSClient
func NewSimpleDNSClient(servers Endpoints) (*SimpleDNSClient, error) {
func NewSimpleDNSClient(servers Endpoints, opts *DNSClientOptions) (*SimpleDNSClient, error) {
if len(servers) < 1 {
return nil, fmt.Errorf("at least one endpoint server is required")
}
if opts == nil {
opts = &DNSClientOptions{}
}
if opts.Timeout == 0 {
opts.Timeout = defaultDNSClientTimeout
}

return &SimpleDNSClient{
servers: servers,
cache: newDNSCache(),
opts: opts,
}, nil
}

Expand All @@ -133,6 +151,7 @@ func NewSimpleDNSClient(servers Endpoints) (*SimpleDNSClient, error) {
type SimpleDNSClient struct {
servers Endpoints
cache *dnsCache
opts *DNSClientOptions
}

// LookupIP does a single lookup against the client's configured DNS servers,
Expand All @@ -146,41 +165,52 @@ func (c *SimpleDNSClient) LookupIP(host string) ([]net.IP, error) {
}

// we need to look it up
server := c.servers.Random()
msg := dns.Msg{}
msg.SetQuestion(dns.Fqdn(host), dns.TypeA)

log.Infof("simple dns lookup %v", host)
r, err := exchange(&msg, server.String())
if err != nil {
return []net.IP{}, err
}
for _, server := range c.servers {
msg := dns.Msg{}
msg.SetQuestion(dns.Fqdn(host), dns.TypeA)

ctx, cancel := context.WithTimeout(context.Background(), c.opts.Timeout)
defer cancel()

log.Infof("simple dns lookup %v", host)
r, err := exchange(ctx, &msg, server.String())
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
// was a timeout error; continue to the next server
continue
}
if err != nil {
return nil, err
}

rec := dnsCacheRecord{
msg: r,
}
rec := dnsCacheRecord{
msg: r,
}

var shortestTTL uint32
var shortestTTL uint32

for _, ans := range r.Answer {
h := ans.Header()
for _, ans := range r.Answer {
h := ans.Header()

if t, ok := ans.(*dns.A); ok {
rec.ips = append(rec.ips, t.A)
if t, ok := ans.(*dns.A); ok {
rec.ips = append(rec.ips, t.A)

// if the TTL of this record is the shortest or first seen, use it
// as the cache record TTL
if shortestTTL == 0 || h.Ttl < shortestTTL {
shortestTTL = h.Ttl
// if the TTL of this record is the shortest or first seen, use it
// as the cache record TTL
if shortestTTL == 0 || h.Ttl < shortestTTL {
shortestTTL = h.Ttl
}
}
}
}

// set the expiry
rec.expires = time.Now().Add(time.Second * time.Duration(shortestTTL))
// set the expiry
rec.expires = time.Now().Add(time.Second * time.Duration(shortestTTL))

// cache the record
c.cache.Set(host, rec)

// cache the record
c.cache.Set(host, rec)
return rec.ips, nil
}

return rec.ips, nil
// we didn't reach any server; return a known error
return nil, ErrAllServersFailed
}
105 changes: 101 additions & 4 deletions dns_client_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package secureoperator

import (
"context"
"errors"
"net"
"testing"
Expand Down Expand Up @@ -120,7 +121,7 @@ func TestSimpleDNSClient(t *testing.T) {
var callCount int

log.SetLevel(log.FatalLevel)
exchange = func(m *dns.Msg, a string) (*dns.Msg, error) {
exchange = func(ctx context.Context, m *dns.Msg, a string) (*dns.Msg, error) {
callCount++

if len(m.Question) != 1 {
Expand All @@ -147,7 +148,7 @@ func TestSimpleDNSClient(t *testing.T) {
// test first call, should hit resolver
client, err := NewSimpleDNSClient(Endpoints{
Endpoint{net.ParseIP("8.8.8.8"), 53},
})
}, nil)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
Expand Down Expand Up @@ -196,13 +197,13 @@ func TestSimpleDNSClientError(t *testing.T) {
}()

log.SetLevel(log.FatalLevel)
exchange = func(m *dns.Msg, a string) (*dns.Msg, error) {
exchange = func(ctx context.Context, m *dns.Msg, a string) (*dns.Msg, error) {
return nil, errors.New("whoopsie daisy")
}

client, err := NewSimpleDNSClient(Endpoints{
Endpoint{net.ParseIP("8.8.8.8"), 53},
})
}, nil)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
Expand All @@ -216,3 +217,99 @@ func TestSimpleDNSClientError(t *testing.T) {
t.Error("got unexpected error message")
}
}

func TestSimpleDNSClientTimeoutSingle(t *testing.T) {
exch := exchange
level := log.GetLevel()
defer func() {
exchange = exch
log.SetLevel(level)
}()

log.SetLevel(log.FatalLevel)

var callCount int
exchange = func(ctx context.Context, m *dns.Msg, a string) (*dns.Msg, error) {
callCount++
if _, ok := ctx.Deadline(); !ok {
t.Errorf("expected deadline")
}

if a == "8.8.8.8:53" {
return nil, &net.DNSError{IsTimeout: true}
}

if a != "8.8.4.4:53" {
t.Errorf("unexpected dns server in second call: %v", a)
}

r := dns.Msg{
Answer: []dns.RR{
&dns.A{
A: net.ParseIP("1.2.3.4"),
Hdr: dns.RR_Header{Ttl: 300},
},
},
}
r.SetReply(m)

return &r, nil
}

client, err := NewSimpleDNSClient(Endpoints{
Endpoint{net.ParseIP("8.8.8.8"), 53},
Endpoint{net.ParseIP("8.8.4.4"), 53},
}, nil)
if err != nil {
t.Errorf("unexpected error: %v", err)
}

_, err = client.LookupIP("who.wut")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

if callCount != 2 {
t.Errorf("expected two calls to exchange, got %v", callCount)
}
}

func TestSimpleDNSClientTimeoutMultiple(t *testing.T) {
exch := exchange
level := log.GetLevel()
defer func() {
exchange = exch
log.SetLevel(level)
}()

log.SetLevel(log.FatalLevel)

var callCount int
exchange = func(ctx context.Context, m *dns.Msg, a string) (*dns.Msg, error) {
callCount++
if callCount == 1 && a != "8.8.8.8:53" {
t.Errorf("expected first server to be 8.8.8.8, was %v", a)
} else if callCount == 2 && a != "8.8.4.4:53" {
t.Errorf("expected second server to be 8.8.4.4, was %v", a)
}

return nil, &net.DNSError{IsTimeout: true}
}

client, err := NewSimpleDNSClient(Endpoints{
Endpoint{net.ParseIP("8.8.8.8"), 53},
Endpoint{net.ParseIP("8.8.4.4"), 53},
}, nil)
if err != nil {
t.Errorf("unexpected error: %v", err)
}

_, err = client.LookupIP("who.wut")
if err != ErrAllServersFailed {
t.Fatalf("unexpected error: %v", err)
}

if callCount != 2 {
t.Errorf("expected two calls to exchange, got %v", callCount)
}
}
2 changes: 1 addition & 1 deletion provider_google.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ func NewGDNSProvider(endpoint string, opts *GDNSOptions) (*GDNSProvider, error)
}

if len(opts.DNSServers) > 0 {
d, err := NewSimpleDNSClient(opts.DNSServers)
d, err := NewSimpleDNSClient(opts.DNSServers, nil)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 0a72afc

Please sign in to comment.