From 8e27201b2a84615d600a671184d95f715f7f7884 Mon Sep 17 00:00:00 2001 From: shawn1m Date: Sun, 8 Mar 2020 16:09:12 +0800 Subject: [PATCH] Using url.Parse to improve address parsing --- core/outbound/clients/resolver/address.go | 119 ++++++------------ .../outbound/clients/resolver/address_test.go | 54 +++----- .../clients/resolver/base_resolver.go | 2 +- .../clients/resolver/tcptls_resolver.go | 2 +- 4 files changed, 60 insertions(+), 117 deletions(-) diff --git a/core/outbound/clients/resolver/address.go b/core/outbound/clients/resolver/address.go index d91c508..655b75d 100644 --- a/core/outbound/clients/resolver/address.go +++ b/core/outbound/clients/resolver/address.go @@ -42,76 +42,61 @@ func ToNetwork(protocol string) string { } } -// ExtractSocksAddress parse socks5 address, -// support two formats: socks5://127.0.0.1:1080 and 127.0.0.1:1080 -func ExtractSocksAddress(rawAddress string) (string, error) { +// support two formats: scheme://127.0.0.1:1080 or 127.0.0.1:1080 +func extractUrl(rawAddress string, protocol string) (host string, port string, err error) { + + if !strings.Contains(rawAddress, "://") { + rawAddress = protocol + "://" + rawAddress + } + uri, err := url.Parse(rawAddress) if err != nil { - // socks5 address format is 127.0.0.1:1080 - _, _, err = net.SplitHostPort(rawAddress) - isJustIP := isJustIP(rawAddress) - if err != nil && !isJustIP { - log.Warnf("socks5 address %s is invalid", rawAddress) - return "", errors.New("socks5 address is invalid") - } - if isJustIP { - rawAddress = rawAddress + ":" + getDefaultPort("socks5") - } - return rawAddress, nil + log.Warnf("url %s is invalid", rawAddress) + return "", "", errors.New("url is invalid") } - // socks5://127.0.0.1:1080 - if len(uri.Scheme) == 0 || uri.Scheme != "socks5" { - return "", errors.New("socks5 address is invalid") + host = uri.Hostname() + + if len(uri.Scheme) == 0 || uri.Scheme != protocol { + return "", "", errors.New("url is invalid") } - port := uri.Port() + + port = uri.Port() if len(port) == 0 { - port = "1080" + port = getDefaultPort(protocol) } - address := net.JoinHostPort(uri.Hostname(), port) - return address, nil + return } -// ExtractTLSDNSAddress parse tcp-tls format: dns.google:853@8.8.8.8 -func ExtractTLSDNSAddress(rawAddress string) (host string, port string, ip string, err error) { +func ExtractFullUrl(rawAddress string, protocol string) (string, error) { + host, port, err := extractUrl(rawAddress, protocol) + return net.JoinHostPort(host, port), err +} + +func extractTLSDNSAddress(rawAddress string, protocol string) (host string, port string, err error) { + rawAddress = protocol + "://" + rawAddress s := strings.Split(rawAddress, "@") - host, port, err = net.SplitHostPort(s[0]) - isJustHost := len(rawAddress) > 0 - if err != nil && !isJustHost { - log.Warnf("dns server address %s is invalid", rawAddress) - return "", "", "", errors.New("dns up server address is invalid") - } - if err != nil && isJustHost { - host = s[0] - if isJustIP(host) { - host = generateLiteralIPv6AddressIfNecessary(host) - } - port = getDefaultPort("tcp-tls") - } - ip = s[1] - if isJustIP(ip) { - ip = generateLiteralIPv6AddressIfNecessary(ip) - } else { - log.Warnf("dns server address %s is invalid", rawAddress) - return "", "", "", errors.New("dns up server address is invalid") + host, port, err = extractUrl(s[0], protocol) + + if err != nil { + return "", "", nil } - return host, port, ip, nil -} -// extractNormalDNSAddress parse normal format: 8.8.8.8:53 -func extractNormalDNSAddress(rawAddress string, protocol string) (host string, port string, err error) { - host, port, err = net.SplitHostPort(rawAddress) - isJustIP := isJustIP(rawAddress) - if err != nil && !isJustIP { + if len(s) == 2 && isJustIP(s[1]) { + host = generateLiteralIPv6AddressIfNecessary(s[1]) + } else { log.Warnf("dns server address %s is invalid", rawAddress) return "", "", errors.New("dns up server address is invalid") } - if isJustIP { - host = generateLiteralIPv6AddressIfNecessary(rawAddress) - port = getDefaultPort(protocol) - } return host, port, nil +} +func ExtractTLSDNSHostName(rawAddress string) (host string, err error) { + rawAddress = "tcp-tls" + "://" + rawAddress + s := strings.Split(rawAddress, "@") + + host, _, err = extractUrl(s[0], "tcp-tls") + return host, err } func isJustIP(rawAddress string) bool { @@ -128,37 +113,13 @@ func generateLiteralIPv6AddressIfNecessary(rawAddress string) string { return rawAddress } -// extractHTTPSAddress parse https format: https://dns.google/dns-query -func extractHTTPSAddress(rawAddress string) (host string, port string, err error) { - uri, err := url.Parse(rawAddress) - if err != nil { - return "", "", err - } - host = uri.Hostname() - port = uri.Port() - if len(port) == 0 { - port = getDefaultPort("https") - } - return host, port, nil - -} - // ExtractDNSAddress parse all format, return literal IPv6 address func ExtractDNSAddress(rawAddress string, protocol string) (host string, port string, err error) { switch protocol { - case "https": - host, port, err = extractHTTPSAddress(rawAddress) case "tcp-tls": - _host, _port, _ip, _err := ExtractTLSDNSAddress(rawAddress) - if len(_ip) > 0 { - host = _ip - } else { - host = _host - } - port = _port - err = _err + host, port, err = extractTLSDNSAddress(rawAddress, protocol) default: - host, port, err = extractNormalDNSAddress(rawAddress, protocol) + host, port, err = extractUrl(rawAddress, protocol) } return host, port, err } diff --git a/core/outbound/clients/resolver/address_test.go b/core/outbound/clients/resolver/address_test.go index 206a423..a5fd89f 100644 --- a/core/outbound/clients/resolver/address_test.go +++ b/core/outbound/clients/resolver/address_test.go @@ -25,6 +25,7 @@ func TestExtractDNSAddress(t *testing.T) { {ipv4Address, "udp", ipv4Address, "53", nil}, {ipv6Address, "udp", literalIpa6Address, "53", nil}, {"https://dns.google/dns-query", "https", "dns.google", "443", nil}, + {"dns.google/dns-query", "https", "dns.google", "443", nil}, {"https://dns.google:888/dns-query", "https", "dns.google", "888", nil}, } for _, tt := range tests { @@ -37,46 +38,27 @@ func TestExtractDNSAddress(t *testing.T) { } } -func TestExtractSocksAddress(t *testing.T) { +func TestExtractFullUrl(t *testing.T) { var tests = []struct { - in string - out string + url string + protocol string + out string }{ - {"socks5://" + ipv4Address + ":80", ipv4Address + ":80"}, - {"socks5://" + ipv6Address + ":80", ipv6Address + ":80"}, - {"socks5://" + ipv6Address, ipv6Address + ":1080"}, - {"" + ipv4Address + ":80", ipv4Address + ":80"}, - {"" + ipv6Address + ":80", ipv6Address + ":80"}, - {"" + ipv6Address, ipv6Address + ":1080"}, + {"socks5://" + ipv4Address + ":80", "socks5", ipv4Address + ":80"}, + {ipv4Address + ":80", "socks5", ipv4Address + ":80"}, + {ipv6Address + ":80", "socks5", ipv6Address + ":80"}, + {ipv6Address, "socks5", ipv6Address + ":1080"}, + {ipv6Address, "https", ipv6Address + ":443"}, + {"tcp-tls://" + ipv6Address, "tcp-tls", ipv6Address + ":853"}, + {"" + ipv4Address + ":80", "socks5", ipv4Address + ":80"}, + {"" + ipv6Address + ":80", "socks5", ipv6Address + ":80"}, + {"" + ipv6Address, "socks5", ipv6Address + ":1080"}, + {"abc.com", "socks5", "abc.com:1080"}, } for _, tt := range tests { - t.Run(tt.in, func(t *testing.T) { - addr, err := ExtractSocksAddress(tt.in) - testEqual(t, addr, tt.out) - testErr(t, err) - }) - } -} - -func TestExtractTLSDNSAddress(t *testing.T) { - - var tests = []struct { - in string - host string - port string - ip string - err error - }{ - {"dns.google:853@" + ipv6Address, "dns.google", "853", literalIpa6Address, nil}, - {"dns.google@" + ipv6Address, "dns.google", "853", literalIpa6Address, nil}, - {"dns.google:853@" + ipv4Address, "dns.google", "853", ipv4Address, nil}, - } - for _, tt := range tests { - t.Run(tt.in, func(t *testing.T) { - host, port, ip, err := ExtractTLSDNSAddress(tt.in) - testEqual(t, host, tt.host) - testEqual(t, port, tt.port) - testEqual(t, ip, tt.ip) + t.Run(tt.url, func(t *testing.T) { + url, err := ExtractFullUrl(tt.url, tt.protocol) + testEqual(t, url, tt.out) testErr(t, err) }) } diff --git a/core/outbound/clients/resolver/base_resolver.go b/core/outbound/clients/resolver/base_resolver.go index 06134c4..dd74989 100644 --- a/core/outbound/clients/resolver/base_resolver.go +++ b/core/outbound/clients/resolver/base_resolver.go @@ -94,7 +94,7 @@ func (r *BaseResolver) CreateBaseConn() (net.Conn, error) { dialer := net.Dialer{Timeout: r.getDialTimeout()} dialerFunc := dialer.Dial if r.dnsUpstream.SOCKS5Address != "" { - socksAddress, err := ExtractSocksAddress(r.dnsUpstream.SOCKS5Address) + socksAddress, err := ExtractFullUrl(r.dnsUpstream.SOCKS5Address, "socks5") if err != nil { return nil, err } diff --git a/core/outbound/clients/resolver/tcptls_resolver.go b/core/outbound/clients/resolver/tcptls_resolver.go index 8e1e781..f4645b2 100644 --- a/core/outbound/clients/resolver/tcptls_resolver.go +++ b/core/outbound/clients/resolver/tcptls_resolver.go @@ -33,7 +33,7 @@ func (r *TCPTLSResolver) createTlsConn() (conn net.Conn, err error) { if err != nil { return nil, err } - host, _, _, err := ExtractTLSDNSAddress(r.dnsUpstream.Address) + host, err := ExtractTLSDNSHostName(r.dnsUpstream.Address) if err != nil { return nil, err }