Skip to content

Commit

Permalink
Using url.Parse to improve address parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
shawn1m committed Mar 8, 2020
1 parent 2f2d29f commit 8e27201
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 117 deletions.
119 changes: 40 additions & 79 deletions core/outbound/clients/resolver/address.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,76 +42,61 @@ func ToNetwork(protocol string) string {
}
}

// ExtractSocksAddress parse socks5 address,
// support two formats: socks5://127.0.0.1:1080 and 127.0.0.1:1080
func ExtractSocksAddress(rawAddress string) (string, error) {
// support two formats: scheme://127.0.0.1:1080 or 127.0.0.1:1080
func extractUrl(rawAddress string, protocol string) (host string, port string, err error) {

if !strings.Contains(rawAddress, "://") {
rawAddress = protocol + "://" + rawAddress
}

uri, err := url.Parse(rawAddress)
if err != nil {
// socks5 address format is 127.0.0.1:1080
_, _, err = net.SplitHostPort(rawAddress)
isJustIP := isJustIP(rawAddress)
if err != nil && !isJustIP {
log.Warnf("socks5 address %s is invalid", rawAddress)
return "", errors.New("socks5 address is invalid")
}
if isJustIP {
rawAddress = rawAddress + ":" + getDefaultPort("socks5")
}
return rawAddress, nil
log.Warnf("url %s is invalid", rawAddress)
return "", "", errors.New("url is invalid")
}
// socks5://127.0.0.1:1080
if len(uri.Scheme) == 0 || uri.Scheme != "socks5" {
return "", errors.New("socks5 address is invalid")
host = uri.Hostname()

if len(uri.Scheme) == 0 || uri.Scheme != protocol {
return "", "", errors.New("url is invalid")
}
port := uri.Port()

port = uri.Port()
if len(port) == 0 {
port = "1080"
port = getDefaultPort(protocol)
}
address := net.JoinHostPort(uri.Hostname(), port)
return address, nil
return
}

// ExtractTLSDNSAddress parse tcp-tls format: dns.google:853@8.8.8.8
func ExtractTLSDNSAddress(rawAddress string) (host string, port string, ip string, err error) {
func ExtractFullUrl(rawAddress string, protocol string) (string, error) {
host, port, err := extractUrl(rawAddress, protocol)
return net.JoinHostPort(host, port), err
}

func extractTLSDNSAddress(rawAddress string, protocol string) (host string, port string, err error) {
rawAddress = protocol + "://" + rawAddress
s := strings.Split(rawAddress, "@")
host, port, err = net.SplitHostPort(s[0])
isJustHost := len(rawAddress) > 0
if err != nil && !isJustHost {
log.Warnf("dns server address %s is invalid", rawAddress)
return "", "", "", errors.New("dns up server address is invalid")
}
if err != nil && isJustHost {
host = s[0]
if isJustIP(host) {
host = generateLiteralIPv6AddressIfNecessary(host)
}
port = getDefaultPort("tcp-tls")
}

ip = s[1]
if isJustIP(ip) {
ip = generateLiteralIPv6AddressIfNecessary(ip)
} else {
log.Warnf("dns server address %s is invalid", rawAddress)
return "", "", "", errors.New("dns up server address is invalid")
host, port, err = extractUrl(s[0], protocol)

if err != nil {
return "", "", nil
}
return host, port, ip, nil
}

// extractNormalDNSAddress parse normal format: 8.8.8.8:53
func extractNormalDNSAddress(rawAddress string, protocol string) (host string, port string, err error) {
host, port, err = net.SplitHostPort(rawAddress)
isJustIP := isJustIP(rawAddress)
if err != nil && !isJustIP {
if len(s) == 2 && isJustIP(s[1]) {
host = generateLiteralIPv6AddressIfNecessary(s[1])
} else {
log.Warnf("dns server address %s is invalid", rawAddress)
return "", "", errors.New("dns up server address is invalid")
}
if isJustIP {
host = generateLiteralIPv6AddressIfNecessary(rawAddress)
port = getDefaultPort(protocol)
}
return host, port, nil
}

func ExtractTLSDNSHostName(rawAddress string) (host string, err error) {
rawAddress = "tcp-tls" + "://" + rawAddress
s := strings.Split(rawAddress, "@")

host, _, err = extractUrl(s[0], "tcp-tls")
return host, err
}

func isJustIP(rawAddress string) bool {
Expand All @@ -128,37 +113,13 @@ func generateLiteralIPv6AddressIfNecessary(rawAddress string) string {
return rawAddress
}

// extractHTTPSAddress parse https format: https://dns.google/dns-query
func extractHTTPSAddress(rawAddress string) (host string, port string, err error) {
uri, err := url.Parse(rawAddress)
if err != nil {
return "", "", err
}
host = uri.Hostname()
port = uri.Port()
if len(port) == 0 {
port = getDefaultPort("https")
}
return host, port, nil

}

// ExtractDNSAddress parse all format, return literal IPv6 address
func ExtractDNSAddress(rawAddress string, protocol string) (host string, port string, err error) {
switch protocol {
case "https":
host, port, err = extractHTTPSAddress(rawAddress)
case "tcp-tls":
_host, _port, _ip, _err := ExtractTLSDNSAddress(rawAddress)
if len(_ip) > 0 {
host = _ip
} else {
host = _host
}
port = _port
err = _err
host, port, err = extractTLSDNSAddress(rawAddress, protocol)
default:
host, port, err = extractNormalDNSAddress(rawAddress, protocol)
host, port, err = extractUrl(rawAddress, protocol)
}
return host, port, err
}
54 changes: 18 additions & 36 deletions core/outbound/clients/resolver/address_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ func TestExtractDNSAddress(t *testing.T) {
{ipv4Address, "udp", ipv4Address, "53", nil},
{ipv6Address, "udp", literalIpa6Address, "53", nil},
{"https://dns.google/dns-query", "https", "dns.google", "443", nil},
{"dns.google/dns-query", "https", "dns.google", "443", nil},
{"https://dns.google:888/dns-query", "https", "dns.google", "888", nil},
}
for _, tt := range tests {
Expand All @@ -37,46 +38,27 @@ func TestExtractDNSAddress(t *testing.T) {
}
}

func TestExtractSocksAddress(t *testing.T) {
func TestExtractFullUrl(t *testing.T) {
var tests = []struct {
in string
out string
url string
protocol string
out string
}{
{"socks5://" + ipv4Address + ":80", ipv4Address + ":80"},
{"socks5://" + ipv6Address + ":80", ipv6Address + ":80"},
{"socks5://" + ipv6Address, ipv6Address + ":1080"},
{"" + ipv4Address + ":80", ipv4Address + ":80"},
{"" + ipv6Address + ":80", ipv6Address + ":80"},
{"" + ipv6Address, ipv6Address + ":1080"},
{"socks5://" + ipv4Address + ":80", "socks5", ipv4Address + ":80"},
{ipv4Address + ":80", "socks5", ipv4Address + ":80"},
{ipv6Address + ":80", "socks5", ipv6Address + ":80"},
{ipv6Address, "socks5", ipv6Address + ":1080"},
{ipv6Address, "https", ipv6Address + ":443"},
{"tcp-tls://" + ipv6Address, "tcp-tls", ipv6Address + ":853"},
{"" + ipv4Address + ":80", "socks5", ipv4Address + ":80"},
{"" + ipv6Address + ":80", "socks5", ipv6Address + ":80"},
{"" + ipv6Address, "socks5", ipv6Address + ":1080"},
{"abc.com", "socks5", "abc.com:1080"},
}
for _, tt := range tests {
t.Run(tt.in, func(t *testing.T) {
addr, err := ExtractSocksAddress(tt.in)
testEqual(t, addr, tt.out)
testErr(t, err)
})
}
}

func TestExtractTLSDNSAddress(t *testing.T) {

var tests = []struct {
in string
host string
port string
ip string
err error
}{
{"dns.google:853@" + ipv6Address, "dns.google", "853", literalIpa6Address, nil},
{"dns.google@" + ipv6Address, "dns.google", "853", literalIpa6Address, nil},
{"dns.google:853@" + ipv4Address, "dns.google", "853", ipv4Address, nil},
}
for _, tt := range tests {
t.Run(tt.in, func(t *testing.T) {
host, port, ip, err := ExtractTLSDNSAddress(tt.in)
testEqual(t, host, tt.host)
testEqual(t, port, tt.port)
testEqual(t, ip, tt.ip)
t.Run(tt.url, func(t *testing.T) {
url, err := ExtractFullUrl(tt.url, tt.protocol)
testEqual(t, url, tt.out)
testErr(t, err)
})
}
Expand Down
2 changes: 1 addition & 1 deletion core/outbound/clients/resolver/base_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func (r *BaseResolver) CreateBaseConn() (net.Conn, error) {
dialer := net.Dialer{Timeout: r.getDialTimeout()}
dialerFunc := dialer.Dial
if r.dnsUpstream.SOCKS5Address != "" {
socksAddress, err := ExtractSocksAddress(r.dnsUpstream.SOCKS5Address)
socksAddress, err := ExtractFullUrl(r.dnsUpstream.SOCKS5Address, "socks5")
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion core/outbound/clients/resolver/tcptls_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func (r *TCPTLSResolver) createTlsConn() (conn net.Conn, err error) {
if err != nil {
return nil, err
}
host, _, _, err := ExtractTLSDNSAddress(r.dnsUpstream.Address)
host, err := ExtractTLSDNSHostName(r.dnsUpstream.Address)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 8e27201

Please sign in to comment.