Skip to content

Commit

Permalink
Fix tcptls and udp bug; Improve code
Browse files Browse the repository at this point in the history
  • Loading branch information
shawn1m committed Feb 20, 2020
1 parent 4dc8303 commit adccb00
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 54 deletions.
37 changes: 22 additions & 15 deletions core/outbound/clients/resolver/base_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,26 @@ type BaseResolver struct {
dnsUpstream *common.DNSUpstream
}

func (r *BaseResolver) Exchange(*dns.Msg) (*dns.Msg, error) {
return nil, nil
}

func (r *BaseResolver) ExchangeByBaseConn(q *dns.Msg) (*dns.Msg, error) {
func (r *BaseResolver) Exchange(q *dns.Msg) (*dns.Msg, error) {
conn, err := r.CreateBaseConn()
defer conn.Close()
if err != nil {
return nil, err
}
return r.exchangeByConnWithoutClose(q, conn)
}

func (r *BaseResolver) exchangeByConnWithoutClose(q *dns.Msg, conn net.Conn) (msg *dns.Msg, err error) {
if conn == nil {
log.Fatal("Conn not initialized for exchangeByDNSClient")
return nil, err
}

r.setTimeout(conn)
dc := &dns.Conn{Conn: conn, UDPSize: 65535}
defer dc.Close()
err = dc.WriteMsg(q)
if err != nil {
log.Warnf("%s Fail: Send question message failed", r.dnsUpstream.Name)
return nil, err
}
return dc.ReadMsg()
Expand Down Expand Up @@ -147,17 +153,18 @@ func (r *BaseResolver) createConnectionPool(connCreate func() (interface{}, erro
return pool.NewChannelPool(poolConfig)
}

func (r *BaseResolver) exchangeByDNSClient(q *dns.Msg, conn net.Conn) (msg *dns.Msg, err error) {
if conn == nil {
log.Fatal("Conn not initialized for exchangeByDNSClient")
func (r *BaseResolver) exchangeByPool(q *dns.Msg, poolConn pool.Pool) (msg *dns.Msg, err error) {
_conn, err := poolConn.Get()
if err != nil {
return nil, err
}

dc := &dns.Conn{Conn: conn, UDPSize: 65535}
err = dc.WriteMsg(q)
conn := _conn.(net.Conn)
ret, err := r.exchangeByConnWithoutClose(q, conn)
if err != nil {
log.Warnf("%s Fail: Send question message failed", r.dnsUpstream.Name)
return nil, err
poolConn.Close(conn)
} else {
r.setIdleTimeout(conn)
poolConn.Put(conn)
}
return dc.ReadMsg()
return ret, err
}
19 changes: 3 additions & 16 deletions core/outbound/clients/resolver/tcp_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,11 @@ type TCPResolver struct {
}

func (r *TCPResolver) Exchange(q *dns.Msg) (*dns.Msg, error) {
if !r.dnsUpstream.TCPPoolConfig.Enable {
return r.ExchangeByBaseConn(q)
}

_conn, err := r.poolConn.Get()
if err != nil {
return nil, err
}
conn := _conn.(net.Conn)
r.setTimeout(conn)
ret, err := r.exchangeByDNSClient(q, conn)
if err != nil {
r.poolConn.Close(conn)
if r.dnsUpstream.TCPPoolConfig.Enable {
return r.BaseResolver.exchangeByPool(q, r.poolConn)
} else {
r.setIdleTimeout(conn)
r.poolConn.Put(conn)
return r.BaseResolver.Exchange(q)
}
return ret, err
}

func (r *TCPResolver) Init() error {
Expand Down
23 changes: 8 additions & 15 deletions core/outbound/clients/resolver/tcptls_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,16 @@ type TCPTLSResolver struct {
}

func (r *TCPTLSResolver) Exchange(q *dns.Msg) (*dns.Msg, error) {
if !r.dnsUpstream.TCPPoolConfig.Enable {
return r.ExchangeByBaseConn(q)
}
_conn, err := r.poolConn.Get()
if err != nil {
return nil, err
}
conn := _conn.(net.Conn)
r.setTimeout(conn)
ret, err := r.exchangeByDNSClient(q, conn)
if err != nil {
r.poolConn.Close(conn)
if r.dnsUpstream.TCPPoolConfig.Enable {
return r.BaseResolver.exchangeByPool(q, r.poolConn)
} else {
r.setIdleTimeout(conn)
r.poolConn.Put(conn)
conn, err := r.createTlsConn()
if err != nil {
return nil, err
}
defer conn.Close()
return r.exchangeByConnWithoutClose(q, conn)
}
return ret, err
}

func (r *TCPTLSResolver) createTlsConn() (conn net.Conn, err error) {
Expand Down
9 changes: 1 addition & 8 deletions core/outbound/clients/resolver/udp_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,7 @@ type UDPResolver struct {
}

func (r *UDPResolver) Exchange(q *dns.Msg) (*dns.Msg, error) {
conn, err := r.CreateBaseConn()
if err != nil {
return nil, err
}
defer conn.Close()
r.setTimeout(conn)
ret, err := r.exchangeByDNSClient(q, conn)
return ret, err
return r.BaseResolver.Exchange(q)
}

func (r *UDPResolver) Init() error {
Expand Down

0 comments on commit adccb00

Please sign in to comment.