diff --git a/go.mod b/go.mod index 2618a1286e2..0147d868c1a 100644 --- a/go.mod +++ b/go.mod @@ -35,9 +35,9 @@ require ( github.com/u-root/u-root v7.0.0+incompatible go.etcd.io/bbolt v1.3.5 golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83 - golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 + golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect - golang.org/x/sys v0.0.0-20210309074719-68d13333faf2 + golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44 golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d // indirect golang.org/x/text v0.3.5 // indirect gopkg.in/natefinch/lumberjack.v2 v2.0.0 diff --git a/go.sum b/go.sum index 02c0a5399b7..ed7d6c27b63 100644 --- a/go.sum +++ b/go.sum @@ -516,8 +516,8 @@ golang.org/x/net v0.0.0-20201209123823-ac852fbbde11/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20201216054612-986b41b23924/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210119194325-5f4716e94777/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 h1:qWPm9rbaAMKs8Bq/9LRpbMqxWRVUAQwMI9fVrssnTfw= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 h1:4nGaVu0QrbjT/AK2PRLuQfQuh6DJve+pELhqTdAj3x0= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -582,8 +582,8 @@ golang.org/x/sys v0.0.0-20210110051926-789bb1bd4061/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210123111255-9b0068b26619/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210216163648-f7da38b97c65/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210309074719-68d13333faf2 h1:46ULzRKLh1CwgRq2dC5SlBzEqqNCi8rreOZnNrbqcIY= -golang.org/x/sys v0.0.0-20210309074719-68d13333faf2/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44 h1:Bli41pIlzTzf3KEY06n+xnzK/BESIg2ze4Pgfh/aI8c= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d h1:SZxvLBoTP5yHO3Frd4z4vrF+DBX9vMVanchswa69toE= diff --git a/internal/aghnet/addr.go b/internal/aghnet/addr.go index 559c9b46b2f..fb4672006a3 100644 --- a/internal/aghnet/addr.go +++ b/internal/aghnet/addr.go @@ -3,8 +3,10 @@ package aghnet import ( "fmt" "net" + "strings" "github.com/AdguardTeam/AdGuardHome/internal/agherr" + "golang.org/x/net/idna" ) // ValidateHardwareAddress returns an error if hwa is not a valid EUI-48, @@ -21,3 +23,79 @@ func ValidateHardwareAddress(hwa net.HardwareAddr) (err error) { return fmt.Errorf("bad len: %d", l) } } + +// maxDomainLabelLen is the maximum allowed length of a domain name label +// according to RFC 1035. +const maxDomainLabelLen = 63 + +// maxDomainNameLen is the maximum allowed length of a full domain name +// according to RFC 1035. +// +// See https://stackoverflow.com/a/32294443/1892060. +const maxDomainNameLen = 253 + +const invalidCharMsg = "invalid char %q at index %d in %q" + +// isValidHostFirstRune returns true if r is a valid first rune for a hostname +// label. +func isValidHostFirstRune(r rune) (ok bool) { + return (r >= 'a' && r <= 'z') || + (r >= 'A' && r <= 'Z') || + (r >= '0' && r <= '9') +} + +// isValidHostRune returns true if r is a valid rune for a hostname label. +func isValidHostRune(r rune) (ok bool) { + return r == '-' || isValidHostFirstRune(r) +} + +// ValidateDomainNameLabel returns an error if label is not a valid label of +// a domain name. +func ValidateDomainNameLabel(label string) (err error) { + if len(label) > maxDomainLabelLen { + return fmt.Errorf("%q is too long, max: %d", label, maxDomainLabelLen) + } else if len(label) == 0 { + return agherr.Error("label is empty") + } + + if r := label[0]; !isValidHostFirstRune(rune(r)) { + return fmt.Errorf(invalidCharMsg, r, 0, label) + } + + for i, r := range label[1:] { + if !isValidHostRune(r) { + return fmt.Errorf(invalidCharMsg, r, i+1, label) + } + } + + return nil +} + +// ValidateDomainName validates the domain name in accordance to RFC 952, RFC +// 1035, and with RFC-1123's inclusion of digits at the start of the host. It +// doesn't validate against two or more hyphens to allow punycode and +// internationalized domains. +// +// TODO(a.garipov): After making sure that this works correctly, port this into +// module golibs. +func ValidateDomainName(name string) (err error) { + name, err = idna.ToASCII(name) + if err != nil { + return err + } + + l := len(name) + if l == 0 || l > maxDomainNameLen { + return fmt.Errorf("%q is too long, max: %d", name, maxDomainNameLen) + } + + labels := strings.Split(name, ".") + for i, l := range labels { + err = ValidateDomainNameLabel(l) + if err != nil { + return fmt.Errorf("invalid domain name label at index %d: %w", i, err) + } + } + + return nil +} diff --git a/internal/aghnet/addr_test.go b/internal/aghnet/addr_test.go index 0b3eb48d772..513760599e1 100644 --- a/internal/aghnet/addr_test.go +++ b/internal/aghnet/addr_test.go @@ -2,6 +2,7 @@ package aghnet import ( "net" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -50,6 +51,81 @@ func TestValidateHardwareAddress(t *testing.T) { assert.NoError(t, err) } else { require.Error(t, err) + + assert.Equal(t, tc.wantErrMsg, err.Error()) + } + }) + } +} + +func repeatStr(b *strings.Builder, s string, n int) { + for i := 0; i < n; i++ { + _, _ = b.WriteString(s) + } +} + +func TestValidateDomainName(t *testing.T) { + b := &strings.Builder{} + repeatStr(b, "a", 255) + longDomainName := b.String() + + b.Reset() + repeatStr(b, "a", 64) + longLabel := b.String() + + _, _ = b.WriteString(".com") + longLabelDomainName := b.String() + + testCases := []struct { + name string + in string + wantErrMsg string + }{{ + name: "success", + in: "example.com", + wantErrMsg: "", + }, { + name: "success_idna", + in: "пример.рф", + wantErrMsg: "", + }, { + name: "bad_symbol", + in: "!!!", + wantErrMsg: `invalid domain name label at index 0: ` + + `invalid char '!' at index 0 in "!!!"`, + }, { + name: "bad_length", + in: longDomainName, + wantErrMsg: `"` + longDomainName + `" is too long, max: 253`, + }, { + name: "bad_label_length", + in: longLabelDomainName, + wantErrMsg: `invalid domain name label at index 0: "` + longLabel + + `" is too long, max: 63`, + }, { + name: "bad_label_empty", + in: "example..com", + wantErrMsg: `invalid domain name label at index 1: label is empty`, + }, { + name: "bad_label_first_symbol", + in: "example.-aa.com", + wantErrMsg: `invalid domain name label at index 1:` + + ` invalid char '-' at index 0 in "-aa"`, + }, { + name: "bad_label_symbol", + in: "example.a!!!.com", + wantErrMsg: `invalid domain name label at index 1:` + + ` invalid char '!' at index 1 in "a!!!"`, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := ValidateDomainName(tc.in) + if tc.wantErrMsg == "" { + assert.NoError(t, err) + } else { + require.Error(t, err) + assert.Equal(t, tc.wantErrMsg, err.Error()) } }) diff --git a/internal/dnsforward/clientid.go b/internal/dnsforward/clientid.go index 21dcac53d80..495b97a98fb 100644 --- a/internal/dnsforward/clientid.go +++ b/internal/dnsforward/clientid.go @@ -6,33 +6,14 @@ import ( "path" "strings" + "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/lucas-clemente/quic-go" ) -// maxDomainLabelLen is the maximum allowed length of a domain name label -// according to RFC 1035. -const maxDomainLabelLen = 63 - -// validateDomainNameLabel returns an error if label is not a valid label of -// a domain name. -func validateDomainNameLabel(label string) (err error) { - if len(label) > maxDomainLabelLen { - return fmt.Errorf("%q is too long, max: %d", label, maxDomainLabelLen) - } - - for i, r := range label { - if (r < 'a' || r > 'z') && (r < '0' || r > '9') && r != '-' { - return fmt.Errorf("invalid char %q at index %d in %q", r, i, label) - } - } - - return nil -} - // ValidateClientID returns an error if clientID is not a valid client ID. func ValidateClientID(clientID string) (err error) { - err = validateDomainNameLabel(clientID) + err = aghnet.ValidateDomainNameLabel(clientID) if err != nil { return fmt.Errorf("invalid client id: %w", err) } diff --git a/internal/dnsforward/clientid_test.go b/internal/dnsforward/clientid_test.go index 463e841eb5f..0d07bcf750e 100644 --- a/internal/dnsforward/clientid_test.go +++ b/internal/dnsforward/clientid_test.go @@ -238,8 +238,8 @@ func TestProcessClientID_https(t *testing.T) { name: "invalid_client_id", path: "/dns-query/!!!", wantClientID: "", - wantErrMsg: `client id check: invalid client id: invalid char '!'` + - ` at index 0 in "!!!"`, + wantErrMsg: `client id check: invalid client id: invalid char '!' ` + + `at index 0 in "!!!"`, wantRes: resultCodeError, }} diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index 5a6045b5518..3dfd1e37992 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -114,7 +114,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) { if p.AutohostTLD == "" { autohostSuffix = defaultAutohostSuffix } else { - err = validateDomainNameLabel(p.AutohostTLD) + err = aghnet.ValidateDomainNameLabel(p.AutohostTLD) if err != nil { return nil, fmt.Errorf("autohost tld: %w", err) } diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index 177e5a2a55c..91ecc158600 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -947,145 +947,6 @@ func publicKey(priv interface{}) interface{} { } } -func TestValidateUpstream(t *testing.T) { - testCases := []struct { - name string - upstream string - valid bool - wantDef bool - }{{ - name: "invalid", - upstream: "1.2.3.4.5", - valid: false, - }, { - name: "invalid", - upstream: "123.3.7m", - valid: false, - }, { - name: "invalid", - upstream: "htttps://google.com/dns-query", - valid: false, - }, { - name: "invalid", - upstream: "[/host.com]tls://dns.adguard.com", - valid: false, - }, { - name: "invalid", - upstream: "[host.ru]#", - valid: false, - }, { - name: "valid_default", - upstream: "1.1.1.1", - valid: true, - wantDef: true, - }, { - name: "valid_default", - upstream: "tls://1.1.1.1", - valid: true, - wantDef: true, - }, { - name: "valid_default", - upstream: "https://dns.adguard.com/dns-query", - valid: true, - wantDef: true, - }, { - name: "valid_default", - upstream: "sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", - valid: true, - wantDef: true, - }, { - name: "valid", - upstream: "[/host.com/]1.1.1.1", - valid: true, - wantDef: false, - }, { - name: "valid", - upstream: "[//]tls://1.1.1.1", - valid: true, - wantDef: false, - }, { - name: "valid", - upstream: "[/www.host.com/]#", - valid: true, - wantDef: false, - }, { - name: "valid", - upstream: "[/host.com/google.com/]8.8.8.8", - valid: true, - wantDef: false, - }, { - name: "valid", - upstream: "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", - valid: true, - wantDef: false, - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - defaultUpstream, err := validateUpstream(tc.upstream) - require.Equal(t, tc.valid, err == nil) - if tc.valid { - assert.Equal(t, tc.wantDef, defaultUpstream) - } - }) - } -} - -func TestValidateUpstreamsSet(t *testing.T) { - testCases := []struct { - name string - msg string - set []string - wantNil bool - }{{ - name: "empty", - msg: "empty upstreams array should be valid", - set: nil, - wantNil: true, - }, { - name: "comment", - msg: "comments should not be validated", - set: []string{"# comment"}, - wantNil: true, - }, { - name: "valid_no_default", - msg: "there is no default upstream", - set: []string{ - "[/host.com/]1.1.1.1", - "[//]tls://1.1.1.1", - "[/www.host.com/]#", - "[/host.com/google.com/]8.8.8.8", - "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", - }, - wantNil: false, - }, { - name: "valid_with_default", - msg: "upstreams set is valid, but doesn't pass through validation cause: %s", - set: []string{ - "[/host.com/]1.1.1.1", - "[//]tls://1.1.1.1", - "[/www.host.com/]#", - "[/host.com/google.com/]8.8.8.8", - "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", - "8.8.8.8", - }, - wantNil: true, - }, { - name: "invalid", - msg: "there is an invalid upstream in set, but it pass through validation", - set: []string{"dhcp://fake.dns"}, - wantNil: false, - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - err := ValidateUpstreams(tc.set) - - assert.Equalf(t, tc.wantNil, err == nil, tc.msg, err) - }) - } -} - func TestIPStringFromAddr(t *testing.T) { t.Run("not_nil", func(t *testing.T) { addr := net.UDPAddr{ diff --git a/internal/dnsforward/http.go b/internal/dnsforward/http.go index b2f5476fa52..effd3b0a02d 100644 --- a/internal/dnsforward/http.go +++ b/internal/dnsforward/http.go @@ -8,10 +8,11 @@ import ( "strconv" "strings" + "github.com/AdguardTeam/AdGuardHome/internal/agherr" + "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/log" - "github.com/AdguardTeam/golibs/utils" "github.com/miekg/dns" ) @@ -302,7 +303,7 @@ type upstreamJSON struct { } // ValidateUpstreams validates each upstream and returns an error if any upstream is invalid or if there are no default upstreams specified -func ValidateUpstreams(upstreams []string) error { +func ValidateUpstreams(upstreams []string) (err error) { // No need to validate comments upstreams = filterOutComments(upstreams) @@ -311,7 +312,7 @@ func ValidateUpstreams(upstreams []string) error { return nil } - _, err := proxy.ParseUpstreamsConfig( + _, err = proxy.ParseUpstreamsConfig( upstreams, upstream.Options{ Bootstrap: []string{}, @@ -345,56 +346,61 @@ func ValidateUpstreams(upstreams []string) error { var protocols = []string{"tls://", "https://", "tcp://", "sdns://", "quic://"} func validateUpstream(u string) (bool, error) { - // Check if user tries to specify upstream for domain - u, defaultUpstream, err := separateUpstream(u) + // Check if the user tries to specify upstream for domain. + u, useDefault, err := separateUpstream(u) if err != nil { - return defaultUpstream, err + return useDefault, err } // The special server address '#' means "use the default servers" - if u == "#" && !defaultUpstream { - return defaultUpstream, nil + if u == "#" && !useDefault { + return useDefault, nil } // Check if the upstream has a valid protocol prefix for _, proto := range protocols { if strings.HasPrefix(u, proto) { - return defaultUpstream, nil + return useDefault, nil } } // Return error if the upstream contains '://' without any valid protocol if strings.Contains(u, "://") { - return defaultUpstream, fmt.Errorf("wrong protocol") + return useDefault, fmt.Errorf("wrong protocol") } // Check if upstream is valid plain DNS - return defaultUpstream, checkPlainDNS(u) + return useDefault, checkPlainDNS(u) } -// separateUpstream returns upstream without specified domains and a bool flag that indicates if no domains were specified -// error will be returned if upstream per domain specification is invalid -func separateUpstream(upstream string) (string, bool, error) { - defaultUpstream := true - if strings.HasPrefix(upstream, "[/") { - defaultUpstream = false - // split domains and upstream string - domainsAndUpstream := strings.Split(strings.TrimPrefix(upstream, "[/"), "/]") - if len(domainsAndUpstream) != 2 { - return "", defaultUpstream, fmt.Errorf("wrong dns upstream per domain specification: %s", upstream) +// separateUpstream returns the upstream without the specified domains. +// useDefault is true when a default upstream must be used. +func separateUpstream(upstreamStr string) (upstream string, useDefault bool, err error) { + defer agherr.Annotate("bad upstream for domain spec %q: %w", &err, upstreamStr) + + if !strings.HasPrefix(upstreamStr, "[/") { + return upstreamStr, true, nil + } + + parts := strings.Split(upstreamStr[2:], "/]") + if len(parts) != 2 { + return "", false, agherr.Error("duplicated separator") + } + + domains := parts[0] + upstream = parts[1] + for i, host := range strings.Split(domains, "/") { + if host == "" { + continue } - // split domains list and validate each one - for _, host := range strings.Split(domainsAndUpstream[0], "/") { - if host != "" { - if err := utils.IsValidHostname(host); err != nil { - return "", defaultUpstream, err - } - } + err = aghnet.ValidateDomainName(host) + if err != nil { + return "", false, fmt.Errorf("domain at index %d: %w", i, err) } - upstream = domainsAndUpstream[1] } - return upstream, defaultUpstream, nil + + return upstream, false, nil } // checkPlainDNS checks if host is plain DNS @@ -462,13 +468,13 @@ func checkDNS(input string, bootstrap []string) error { } // separate upstream from domains list - input, defaultUpstream, err := separateUpstream(input) + input, useDefault, err := separateUpstream(input) if err != nil { return fmt.Errorf("wrong upstream format: %w", err) } // No need to check this DNS server - if !defaultUpstream { + if !useDefault { return nil } diff --git a/internal/dnsforward/http_test.go b/internal/dnsforward/http_test.go index b55f759a297..6c3acc10a6a 100644 --- a/internal/dnsforward/http_test.go +++ b/internal/dnsforward/http_test.go @@ -213,3 +213,158 @@ func TestDNSForwardHTTTP_handleSetConfig(t *testing.T) { }) } } + +// TODO(a.garipov): Rewrite to check the actual error messages. +func TestValidateUpstream(t *testing.T) { + testCases := []struct { + name string + upstream string + valid bool + wantDef bool + }{{ + name: "invalid", + upstream: "1.2.3.4.5", + valid: false, + wantDef: false, + }, { + name: "invalid", + upstream: "123.3.7m", + valid: false, + wantDef: false, + }, { + name: "invalid", + upstream: "htttps://google.com/dns-query", + valid: false, + wantDef: false, + }, { + name: "invalid", + upstream: "[/host.com]tls://dns.adguard.com", + valid: false, + wantDef: false, + }, { + name: "invalid", + upstream: "[host.ru]#", + valid: false, + wantDef: false, + }, { + name: "valid_default", + upstream: "1.1.1.1", + valid: true, + wantDef: true, + }, { + name: "valid_default", + upstream: "tls://1.1.1.1", + valid: true, + wantDef: true, + }, { + name: "valid_default", + upstream: "https://dns.adguard.com/dns-query", + valid: true, + wantDef: true, + }, { + name: "valid_default", + upstream: "sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", + valid: true, + wantDef: true, + }, { + name: "valid", + upstream: "[/host.com/]1.1.1.1", + valid: true, + wantDef: false, + }, { + name: "valid", + upstream: "[//]tls://1.1.1.1", + valid: true, + wantDef: false, + }, { + name: "valid", + upstream: "[/www.host.com/]#", + valid: true, + wantDef: false, + }, { + name: "valid", + upstream: "[/host.com/google.com/]8.8.8.8", + valid: true, + wantDef: false, + }, { + name: "valid", + upstream: "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", + valid: true, + wantDef: false, + }, { + name: "idna", + upstream: "[/пример.рф/]8.8.8.8", + valid: true, + wantDef: false, + }, { + name: "bad_domain", + upstream: "[/!/]8.8.8.8", + valid: false, + wantDef: false, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + defaultUpstream, err := validateUpstream(tc.upstream) + require.Equal(t, tc.valid, err == nil) + if tc.valid { + assert.Equal(t, tc.wantDef, defaultUpstream) + } + }) + } +} + +func TestValidateUpstreamsSet(t *testing.T) { + testCases := []struct { + name string + msg string + set []string + wantNil bool + }{{ + name: "empty", + msg: "empty upstreams array should be valid", + set: nil, + wantNil: true, + }, { + name: "comment", + msg: "comments should not be validated", + set: []string{"# comment"}, + wantNil: true, + }, { + name: "valid_no_default", + msg: "there is no default upstream", + set: []string{ + "[/host.com/]1.1.1.1", + "[//]tls://1.1.1.1", + "[/www.host.com/]#", + "[/host.com/google.com/]8.8.8.8", + "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", + }, + wantNil: false, + }, { + name: "valid_with_default", + msg: "upstreams set is valid, but doesn't pass through validation cause: %s", + set: []string{ + "[/host.com/]1.1.1.1", + "[//]tls://1.1.1.1", + "[/www.host.com/]#", + "[/host.com/google.com/]8.8.8.8", + "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", + "8.8.8.8", + }, + wantNil: true, + }, { + name: "invalid", + msg: "there is an invalid upstream in set, but it pass through validation", + set: []string{"dhcp://fake.dns"}, + wantNil: false, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := ValidateUpstreams(tc.set) + + assert.Equalf(t, tc.wantNil, err == nil, tc.msg, err) + }) + } +} diff --git a/internal/dnsforward/util.go b/internal/dnsforward/util.go index 4b57768ba6a..e1d0c4a878e 100644 --- a/internal/dnsforward/util.go +++ b/internal/dnsforward/util.go @@ -5,7 +5,7 @@ import ( "sort" "strings" - "github.com/AdguardTeam/golibs/utils" + "github.com/AdguardTeam/AdGuardHome/internal/aghnet" ) // IPFromAddr gets IP address from addr. @@ -58,9 +58,10 @@ func matchDomainWildcard(host, wildcard string) bool { // Return TRUE if client's SNI value matches DNS names from certificate func matchDNSName(dnsNames []string, sni string) bool { - if utils.IsValidHostname(sni) != nil { + if aghnet.ValidateDomainName(sni) != nil { return false } + if findSorted(dnsNames, sni) != -1 { return true } diff --git a/internal/home/clients.go b/internal/home/clients.go index ad119b002f1..825e28bd016 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -12,6 +12,7 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/internal/agherr" + "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/dhcpd" "github.com/AdguardTeam/AdGuardHome/internal/dnsfilter" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" @@ -20,7 +21,6 @@ import ( "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/log" - "github.com/AdguardTeam/golibs/utils" ) const clientsUpdatePeriod = 10 * time.Minute @@ -751,7 +751,7 @@ func (clients *clientsContainer) addFromSystemARP() { host := ln[:open] ip := ln[open+2 : close] - if utils.IsValidHostname(host) != nil || net.ParseIP(ip) == nil { + if aghnet.ValidateDomainName(host) != nil || net.ParseIP(ip) == nil { continue } diff --git a/internal/home/mobileconfig.go b/internal/home/mobileconfig.go index beb50b2289c..895512f14a2 100644 --- a/internal/home/mobileconfig.go +++ b/internal/home/mobileconfig.go @@ -123,18 +123,20 @@ func handleMobileConfig(w http.ResponseWriter, r *http.Request, dnsp string) { } clientID := q.Get("client_id") - err = dnsforward.ValidateClientID(clientID) - if err != nil { - w.WriteHeader(http.StatusBadRequest) - - err = json.NewEncoder(w).Encode(&jsonError{ - Message: err.Error(), - }) + if clientID != "" { + err = dnsforward.ValidateClientID(clientID) if err != nil { - log.Debug("writing 400 json response: %s", err) - } + w.WriteHeader(http.StatusBadRequest) - return + err = json.NewEncoder(w).Encode(&jsonError{ + Message: err.Error(), + }) + if err != nil { + log.Debug("writing 400 json response: %s", err) + } + + return + } } d := dnsSettings{