Skip to content

Commit

Permalink
Pull request 259: 5874-fix-fallback
Browse files Browse the repository at this point in the history
Merge in GO/dnsproxy from 5874-fix-fallback to master

Updates AdguardTeam/AdGuardHome#5874.

Squashed commit of the following:

commit 6687dac
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Jun 8 16:10:06 2023 +0300

    upstream: close plain connections

commit 4162e21
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Jun 8 15:32:49 2023 +0300

    upstream: imp fallback test

commit 2614653
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Jun 8 15:16:40 2023 +0300

    upstream: imp logging, test

commit 7fa6a9d
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Jun 8 14:53:51 2023 +0300

    upstream: fix fallback to tcp
  • Loading branch information
EugeneOne1 committed Jun 8, 2023
1 parent 20fd473 commit 066dc9a
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 15 deletions.
12 changes: 7 additions & 5 deletions upstream/upstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,24 +278,26 @@ func addPort(u *url.URL, port int) {
}
}

// Write to log DNS request information that we are going to send
func logBegin(upstreamAddress string, req *dns.Msg) {
// logBegin logs the start of DNS request resolution. It should be called right
// before dialing the connection to the upstream. n is the [network] that will
// be used to send the request.
func logBegin(upstreamAddress string, n network, req *dns.Msg) {
qtype := ""
target := ""
if len(req.Question) != 0 {
qtype = dns.Type(req.Question[0].Qtype).String()
target = req.Question[0].Name
}
log.Debug("%s: sending request %s %s", upstreamAddress, qtype, target)
log.Debug("%s: sending request over %s: %s %s", upstreamAddress, n, qtype, target)
}

// Write to log about the result of DNS request
func logFinish(upstreamAddress string, err error) {
func logFinish(upstreamAddress string, n network, err error) {
status := "ok"
if err != nil {
status = err.Error()
}
log.Debug("%s: response: %s", upstreamAddress, status)
log.Debug("%s: response received over %s: %s", upstreamAddress, n, status)
}

// DialerInitializer returns the handler that it creates. All the subsequent
Expand Down
9 changes: 7 additions & 2 deletions upstream/upstream_doh.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,14 @@ func (p *dnsOverHTTPS) closeClient(client *http.Client) (err error) {
func (p *dnsOverHTTPS) exchangeHTTPS(client *http.Client, req *dns.Msg) (resp *dns.Msg, err error) {
addr := p.Address()

logBegin(addr, req)
n := networkTCP
if isHTTP3(client) {
n = networkUDP
}

logBegin(addr, n, req)
resp, err = p.exchangeHTTPSClient(client, req)
logFinish(addr, err)
logFinish(addr, n, err)

return resp, err
}
Expand Down
4 changes: 2 additions & 2 deletions upstream/upstream_dot.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,8 @@ func (p *dnsOverTLS) putBack(conn net.Conn) {
func (p *dnsOverTLS) exchangeWithConn(conn net.Conn, m *dns.Msg) (reply *dns.Msg, err error) {
addr := p.Address()

logBegin(addr, m)
defer func() { logFinish(addr, err) }()
logBegin(addr, networkTCP, m)
defer func() { logFinish(addr, networkTCP, err) }()

dnsConn := dns.Conn{Conn: conn}

Expand Down
18 changes: 15 additions & 3 deletions upstream/upstream_plain.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,21 +98,23 @@ func (p *plainDNS) dialExchange(
conn.UDPSize = dns.MinMsgSize
}

logBegin(addr, req)
defer func() { logFinish(addr, err) }()
logBegin(addr, network, req)
defer func() { logFinish(addr, network, err) }()

ctx := context.Background()
conn.Conn, err = dial(ctx, string(network), "")
if err != nil {
return nil, fmt.Errorf("dialing %s over %s: %w", p.addr.Host, network, err)
}
defer func(c net.Conn) { err = errors.WithDeferred(err, c.Close()) }(conn.Conn)

resp, _, err = client.ExchangeWithConn(req, conn)
if isExpectedConnErr(err) {
conn.Conn, err = dial(ctx, string(network), "")
if err != nil {
return nil, fmt.Errorf("dialing %s over %s again: %w", p.addr.Host, network, err)
}
defer func(c net.Conn) { err = errors.WithDeferred(err, c.Close()) }(conn.Conn)

resp, _, err = client.ExchangeWithConn(req, conn)
}
Expand Down Expand Up @@ -144,20 +146,30 @@ func (p *plainDNS) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {

resp, err = p.dialExchange(p.net, dial, req)
if p.net != networkUDP {
// The network is already TCP.
return resp, err
}

if resp == nil {
// There is likely an error with the upstream.
return resp, err
}

if errors.Is(err, errQuestion) {
// The upstream responds with malformed messages, so try TCP.
log.Debug("plain %s: %s, using tcp", addr, err)

return p.dialExchange(networkTCP, dial, req)
} else if resp.Truncated {
// Fallback to TCP on truncated responses.
log.Debug("plain %s: resp for %s is truncated, using tcp", &req.Question[0], addr)

return p.dialExchange(networkTCP, dial, req)
}

return p.dialExchange(networkTCP, dial, req)
// There is either no error or the error isn't related to the received
// message.
return resp, err
}

// Close implements the [Upstream] interface for *plainDNS.
Expand Down
24 changes: 21 additions & 3 deletions upstream/upstream_plain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"io"
"net"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -62,9 +63,8 @@ func TestUpstream_plainDNS_badID(t *testing.T) {
assert.Nil(t, resp)
}

func TestUpstream_plainDNS_fallback(t *testing.T) {
func TestUpstream_plainDNS_fallbackToTCP(t *testing.T) {
req := createTestMessage()

goodResp := respondToTestMessage(req)

truncResp := goodResp.Copy()
Expand All @@ -79,26 +79,41 @@ func TestUpstream_plainDNS_fallback(t *testing.T) {
testCases := []struct {
udpResp *dns.Msg
name string
wantUDP int
wantTCP int
}{{
udpResp: goodResp,
name: "all_right",
wantUDP: 1,
wantTCP: 0,
}, {
udpResp: truncResp,
name: "truncated_response",
wantUDP: 1,
wantTCP: 1,
}, {
udpResp: badQNameResp,
name: "bad_qname",
wantUDP: 1,
wantTCP: 1,
}, {
udpResp: badQTypeResp,
name: "bad_qtype",
wantUDP: 1,
wantTCP: 1,
}}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var udpReqNum, tcpReqNum atomic.Uint32
srv := startDNSServer(t, func(w dns.ResponseWriter, _ *dns.Msg) {
resp := goodResp
var resp *dns.Msg
if w.RemoteAddr().Network() == string(networkUDP) {
udpReqNum.Add(1)
resp = tc.udpResp
} else {
tcpReqNum.Add(1)
resp = goodResp
}

require.NoError(testutil.PanicT{}, w.WriteMsg(resp))
Expand All @@ -116,6 +131,9 @@ func TestUpstream_plainDNS_fallback(t *testing.T) {
resp, err := u.Exchange(req)
require.NoError(t, err)
requireResponse(t, req, resp)

assert.Equal(t, tc.wantUDP, int(udpReqNum.Load()))
assert.Equal(t, tc.wantTCP, int(tcpReqNum.Load()))
})
}
}
Expand Down

0 comments on commit 066dc9a

Please sign in to comment.