Skip to content

Commit

Permalink
Refactored the RemoteClient to support connection pooling & drop none…
Browse files Browse the repository at this point in the history
… answer (#206)

* Refactored the RemoteClient to support connection pooling

Part of RemoteClient were split into several Resolver, which will be shared across all RemoteClient and RemoteClientBundle, in the resolver the pool was implemented.

* Fix dispatcher_test.go

* Remove too verbose debugging log output

* Add option IdleTimeout

* Wait until the answer presents in response (Fixes #181)

* Revert accidental change & better code format

* Support PoolMaxCapacity config

* Fix timeout setting

* Add tests for Resolvers
  • Loading branch information
NyaMisty authored Feb 19, 2020
1 parent b3655b8 commit 14a8364
Show file tree
Hide file tree
Showing 19 changed files with 483 additions and 172 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ Configuration file is "config.json" by default:
"OnlyPrimaryDNS": false,
"IPv6UseAlternativeDNS": false,
"AlternativeDNSConcurrent": false,
"PoolIdleTimeout": 15,
"PoolMaxCapacity": 15,
"WhenPrimaryDNSAnswerNoneUse": "PrimaryDNS",
"IPNetworkFile": {
"Primary": "./ip_network_primary_sample",
Expand Down Expand Up @@ -206,6 +208,8 @@ IPv6). Overture will handle both TCP and UDP requests. Literal IPv6 addresses ar
+ OnlyPrimaryDNS: Disable dispatcher feature, use primary DNS only.
+ IPv6UseAlternativeDNS: Redirect IPv6 DNS queries to alternative DNS servers.
+ AlternativeDNSConcurrent: Query the PrimaryDNS and AlternativeDNS at the same time
+ PoolIdleTimeout: Specify idle timeout for connection in pool
+ PoolMaxCapacity: Specify max capacity for connection pool
+ WhenPrimaryDNSAnswerNoneUse: If the response of PrimaryDNS exists and there is no `ANSWER SECTION` in it, the final DNS should be defined. (There is no `AAAA` record for most domains right now)
+ File: Absolute path like `/path/to/file` is allowed. For Windows users, please use properly escaped path like
`C:\\path\\to\\file.txt` in the configuration.
Expand Down
2 changes: 2 additions & 0 deletions config.sample.json
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
"OnlyPrimaryDNS": false,
"IPv6UseAlternativeDNS": false,
"AlternativeDNSConcurrent": false,
"PoolIdleTimeout": 15,
"PoolMaxCapacity": 15,
"WhenPrimaryDNSAnswerNoneUse": "PrimaryDNS",
"IPNetworkFile": {
"Primary": "./ip_network_primary_sample",
Expand Down
2 changes: 2 additions & 0 deletions config.test.json
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
"OnlyPrimaryDNS": false,
"IPv6UseAlternativeDNS": false,
"AlternativeDNSConcurrent": false,
"PoolIdleTimeout": 15,
"PoolMaxCapacity": 15,
"WhenPrimaryDNSAnswerNoneUse": "PrimaryDNS",
"IPNetworkFile": {
"Primary": "./ip_network_primary_sample",
Expand Down
2 changes: 2 additions & 0 deletions core/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ type Config struct {
OnlyPrimaryDNS bool
IPv6UseAlternativeDNS bool
AlternativeDNSConcurrent bool
PoolIdleTimeout int
PoolMaxCapacity int
IPNetworkFile struct {
Primary string
Alternative string
Expand Down
3 changes: 3 additions & 0 deletions core/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,15 @@ func InitServer(configFilePath string) {

RedirectIPv6Record: conf.IPv6UseAlternativeDNS,
AlternativeDNSConcurrent: conf.AlternativeDNSConcurrent,
PoolIdleTimeout: conf.PoolIdleTimeout,
PoolMaxCapacity: conf.PoolMaxCapacity,
MinimumTTL: conf.MinimumTTL,
DomainTTLMap: conf.DomainTTLMap,

Hosts: conf.Hosts,
Cache: conf.Cache,
}
dispatcher.Init()

s := inbound.NewServer(conf.BindAddress, conf.DebugHTTPAddress, dispatcher, conf.RejectQType)

Expand Down
169 changes: 7 additions & 162 deletions core/outbound/clients/remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,13 @@
package clients

import (
"bytes"
"crypto/tls"
"io/ioutil"
"net"
"net/http"
"time"

"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.org/x/net/proxy"
"net"

"github.com/shawn1m/overture/core/cache"
"github.com/shawn1m/overture/core/common"
"github.com/shawn1m/overture/core/outbound/clients/resolver"
)

type RemoteClient struct {
Expand All @@ -30,12 +24,13 @@ type RemoteClient struct {
dnsUpstream *common.DNSUpstream
ednsClientSubnetIP string
inboundIP string
dnsResolver resolver.Resolver

cache *cache.Cache
}

func NewClient(q *dns.Msg, u *common.DNSUpstream, ip string, cache *cache.Cache) *RemoteClient {
c := &RemoteClient{questionMessage: q.Copy(), dnsUpstream: u, inboundIP: ip, cache: cache}
func NewClient(q *dns.Msg, u *common.DNSUpstream, resolver resolver.Resolver, ip string, cache *cache.Cache) *RemoteClient {
c := &RemoteClient{questionMessage: q.Copy(), dnsUpstream: u, dnsResolver: resolver, inboundIP: ip, cache: cache}
c.getEDNSClientSubnetIP()

return c
Expand Down Expand Up @@ -77,25 +72,9 @@ func (c *RemoteClient) Exchange(isLog bool) *dns.Msg {
return c.responseMessage
}

var conn net.Conn = nil
var err error
if c.dnsUpstream.SOCKS5Address != "" {
if conn, err = c.createSocks5Conn(); err != nil {
return nil
}
}

var temp *dns.Msg
switch c.dnsUpstream.Protocol {
case "udp":
temp, err = c.ExchangeByUDP(conn)
case "tcp":
temp, err = c.ExchangeByTCP(conn)
case "tcp-tls":
temp, err = c.ExchangeByTLS(conn)
case "https":
temp, err = c.ExchangeByHTTPS(conn)
}
var err error
temp, err = c.dnsResolver.Exchange(c.questionMessage)

if err != nil {
log.Debugf("%s Fail: %s", c.dnsUpstream.Name, err)
Expand Down Expand Up @@ -127,137 +106,3 @@ func (c *RemoteClient) logAnswer(indicator string) {
log.Debugf("Answer from %s: %s", name, a.String())
}
}

func (c *RemoteClient) createSocks5Conn() (conn net.Conn, err error) {
socksAddress, err := ExtractSocksAddress(c.dnsUpstream.SOCKS5Address)
if err != nil {
return nil, err
}
network := ToNetwork(c.dnsUpstream.Protocol)
s, err := proxy.SOCKS5(network, socksAddress, nil, proxy.Direct)
if err != nil {
log.Warnf("Failed to connect to SOCKS5 proxy: %s", err)
return nil, err
}
host, port, err := ExtractDNSAddress(c.dnsUpstream.Address, c.dnsUpstream.Protocol)
if err != nil {
return nil, err
}
address := net.JoinHostPort(host, port)
conn, err = s.Dial(network, address)
if err != nil {
log.Warnf("Failed to connect to upstream via SOCKS5 proxy: %s", err)
return nil, err
}
return conn, err
}

func (c *RemoteClient) exchangeByDNSClient(conn net.Conn) (msg *dns.Msg, err error) {
if conn == nil {
network := ToNetwork(c.dnsUpstream.Protocol)
host, port, err := ExtractDNSAddress(c.dnsUpstream.Address, c.dnsUpstream.Protocol)
if err != nil {
return nil, err
}
address := net.JoinHostPort(host, port)
if conn, err = net.Dial(network, address); err != nil {
log.Warnf("Failed to connect to DNS upstream: %s", err)
return nil, err
}
}
c.setTimeout(conn)
dc := &dns.Conn{Conn: conn, UDPSize: 65535}
defer dc.Close()
err = dc.WriteMsg(c.questionMessage)
if err != nil {
log.Warnf("%s Fail: Send question message failed", c.dnsUpstream.Name)
return nil, err
}
return dc.ReadMsg()
}

// ExchangeByUDP send dns record by udp protocol
func (c *RemoteClient) ExchangeByUDP(conn net.Conn) (*dns.Msg, error) {
return c.exchangeByDNSClient(conn)
}

// ExchangeByTCP send dns record by tcp protocol
func (c *RemoteClient) ExchangeByTCP(conn net.Conn) (*dns.Msg, error) {
return c.exchangeByDNSClient(conn)
}

// ExchangeByTLS send dns record by tcp-tls protocol
func (c *RemoteClient) ExchangeByTLS(conn net.Conn) (msg *dns.Msg, err error) {
host, port, ip := ExtractTLSDNSAddress(c.dnsUpstream.Address)
var address string
if len(ip) > 0 {
address = net.JoinHostPort(ip, port)
} else {
address = net.JoinHostPort(host, port)
}

conf := &tls.Config{
InsecureSkipVerify: false,
ServerName: host,
}
if conn != nil {
// crate tls client use the existing connection
conn = tls.Client(conn, conf)
} else {
if conn, err = tls.Dial("tcp", address, conf); err != nil {
log.Warnf("Failed to connect to DNS-over-TLS upstream: %s", err)
return nil, err
}
}
c.setTimeout(conn)
return c.exchangeByDNSClient(conn)
}

// ExchangeByHTTPS send dns record by https protocol
func (c *RemoteClient) ExchangeByHTTPS(conn net.Conn) (*dns.Msg, error) {
if conn == nil {
host, port, err := ExtractHTTPSAddress(c.dnsUpstream.Address)
if err != nil {
return nil, err
}
address := net.JoinHostPort(host, port)
conn, err = net.Dial("tcp", address)
if err != nil {
log.Warnf("Fail connect to dns server %s", address)
return nil, err
}
}
c.setTimeout(conn)
client := http.Client{
Transport: &http.Transport{
Dial: func(network, addr string) (net.Conn, error) {
return conn, nil
},
},
}
defer client.CloseIdleConnections()
request, err := c.questionMessage.Pack()
resp, err := client.Post(c.dnsUpstream.Address, "application/dns-message",
bytes.NewBuffer(request))
if err != nil {
return nil, err
}
defer resp.Body.Close()
data, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}
msg := new(dns.Msg)
err = msg.Unpack(data)
if err != nil {
return nil, err
}
return msg, nil
}

func (c *RemoteClient) setTimeout(conn net.Conn) {
dnsTimeout := time.Duration(c.dnsUpstream.Timeout) * time.Second / 3
conn.SetDeadline(time.Now().Add(dnsTimeout))
conn.SetReadDeadline(time.Now().Add(dnsTimeout))
conn.SetWriteDeadline(time.Now().Add(dnsTimeout))
}
18 changes: 12 additions & 6 deletions core/outbound/clients/remote_bundle.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ package clients

import (
"github.com/miekg/dns"
"github.com/shawn1m/overture/core/outbound/clients/resolver"
log "github.com/sirupsen/logrus"

"github.com/shawn1m/overture/core/cache"
"github.com/shawn1m/overture/core/common"
Expand All @@ -26,14 +28,15 @@ type RemoteClientBundle struct {

cache *cache.Cache
Name string
}

func NewClientBundle(q *dns.Msg, ul []*common.DNSUpstream, ip string, minimumTTL int, cache *cache.Cache, name string, domainTTLMap map[string]uint32) *RemoteClientBundle {
cb := &RemoteClientBundle{questionMessage: q.Copy(), dnsUpstreams: ul, inboundIP: ip, minimumTTL: minimumTTL, cache: cache, Name: name, domainTTLMap: domainTTLMap}
dnsResolvers []resolver.Resolver
}

for _, u := range ul {
func NewClientBundle(q *dns.Msg, ul []*common.DNSUpstream, resolvers []resolver.Resolver, ip string, minimumTTL int, cache *cache.Cache, name string, domainTTLMap map[string]uint32) *RemoteClientBundle {
cb := &RemoteClientBundle{questionMessage: q.Copy(), dnsUpstreams: ul, dnsResolvers: resolvers, inboundIP: ip, minimumTTL: minimumTTL, cache: cache, Name: name, domainTTLMap: domainTTLMap}

c := NewClient(cb.questionMessage, u, cb.inboundIP, cb.cache)
for i, u := range ul {
c := NewClient(cb.questionMessage, u, cb.dnsResolvers[i], cb.inboundIP, cb.cache)
cb.clients = append(cb.clients, c)
}

Expand All @@ -56,7 +59,10 @@ func (cb *RemoteClientBundle) Exchange(isCache bool, isLog bool) *dns.Msg {
c := <-ch
if c != nil {
ec = c
break
if ec.responseMessage != nil && ec.responseMessage.Answer != nil {
break
}
log.Debugf("DNSUpstream %s returned None answer, dropping it and wait the next one", ec.dnsUpstream.Address)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
*/

// Package outbound implements multiple dns client and dispatcher for outbound connection.
package clients
package resolver

import (
"errors"
Expand Down Expand Up @@ -118,7 +118,13 @@ func ExtractDNSAddress(rawAddress string, protocol string) (host string, port st
case "https":
host, port, err = ExtractHTTPSAddress(rawAddress)
case "tcp-tls":
_, port, host = ExtractTLSDNSAddress(rawAddress)
_host, _port, _ip := ExtractTLSDNSAddress(rawAddress)
if len(_ip) > 0 {
host = _ip
} else {
host = _host
}
port = _port
default:
host, port, err = ExtractNormalDNSAddress(rawAddress, protocol)
}
Expand Down
Loading

0 comments on commit 14a8364

Please sign in to comment.