diff --git a/core/outbound/clients/resolver/base_resolver.go b/core/outbound/clients/resolver/base_resolver.go index 39a832b..ae33049 100644 --- a/core/outbound/clients/resolver/base_resolver.go +++ b/core/outbound/clients/resolver/base_resolver.go @@ -25,20 +25,26 @@ type BaseResolver struct { dnsUpstream *common.DNSUpstream } -func (r *BaseResolver) Exchange(*dns.Msg) (*dns.Msg, error) { - return nil, nil -} - -func (r *BaseResolver) ExchangeByBaseConn(q *dns.Msg) (*dns.Msg, error) { +func (r *BaseResolver) Exchange(q *dns.Msg) (*dns.Msg, error) { conn, err := r.CreateBaseConn() + defer conn.Close() if err != nil { return nil, err } + return r.exchangeByConnWithoutClose(q, conn) +} + +func (r *BaseResolver) exchangeByConnWithoutClose(q *dns.Msg, conn net.Conn) (msg *dns.Msg, err error) { + if conn == nil { + log.Fatal("Conn not initialized for exchangeByDNSClient") + return nil, err + } + r.setTimeout(conn) dc := &dns.Conn{Conn: conn, UDPSize: 65535} - defer dc.Close() err = dc.WriteMsg(q) if err != nil { + log.Warnf("%s Fail: Send question message failed", r.dnsUpstream.Name) return nil, err } return dc.ReadMsg() @@ -147,17 +153,18 @@ func (r *BaseResolver) createConnectionPool(connCreate func() (interface{}, erro return pool.NewChannelPool(poolConfig) } -func (r *BaseResolver) exchangeByDNSClient(q *dns.Msg, conn net.Conn) (msg *dns.Msg, err error) { - if conn == nil { - log.Fatal("Conn not initialized for exchangeByDNSClient") +func (r *BaseResolver) exchangeByPool(q *dns.Msg, poolConn pool.Pool) (msg *dns.Msg, err error) { + _conn, err := poolConn.Get() + if err != nil { return nil, err } - - dc := &dns.Conn{Conn: conn, UDPSize: 65535} - err = dc.WriteMsg(q) + conn := _conn.(net.Conn) + ret, err := r.exchangeByConnWithoutClose(q, conn) if err != nil { - log.Warnf("%s Fail: Send question message failed", r.dnsUpstream.Name) - return nil, err + poolConn.Close(conn) + } else { + r.setIdleTimeout(conn) + poolConn.Put(conn) } - return dc.ReadMsg() + return ret, err } diff --git a/core/outbound/clients/resolver/tcp_resolver.go b/core/outbound/clients/resolver/tcp_resolver.go index 15513f0..3122b49 100644 --- a/core/outbound/clients/resolver/tcp_resolver.go +++ b/core/outbound/clients/resolver/tcp_resolver.go @@ -14,24 +14,11 @@ type TCPResolver struct { } func (r *TCPResolver) Exchange(q *dns.Msg) (*dns.Msg, error) { - if !r.dnsUpstream.TCPPoolConfig.Enable { - return r.ExchangeByBaseConn(q) - } - - _conn, err := r.poolConn.Get() - if err != nil { - return nil, err - } - conn := _conn.(net.Conn) - r.setTimeout(conn) - ret, err := r.exchangeByDNSClient(q, conn) - if err != nil { - r.poolConn.Close(conn) + if r.dnsUpstream.TCPPoolConfig.Enable { + return r.BaseResolver.exchangeByPool(q, r.poolConn) } else { - r.setIdleTimeout(conn) - r.poolConn.Put(conn) + return r.BaseResolver.Exchange(q) } - return ret, err } func (r *TCPResolver) Init() error { diff --git a/core/outbound/clients/resolver/tcptls_resolver.go b/core/outbound/clients/resolver/tcptls_resolver.go index 62b0e26..ad2bca5 100644 --- a/core/outbound/clients/resolver/tcptls_resolver.go +++ b/core/outbound/clients/resolver/tcptls_resolver.go @@ -15,23 +15,16 @@ type TCPTLSResolver struct { } func (r *TCPTLSResolver) Exchange(q *dns.Msg) (*dns.Msg, error) { - if !r.dnsUpstream.TCPPoolConfig.Enable { - return r.ExchangeByBaseConn(q) - } - _conn, err := r.poolConn.Get() - if err != nil { - return nil, err - } - conn := _conn.(net.Conn) - r.setTimeout(conn) - ret, err := r.exchangeByDNSClient(q, conn) - if err != nil { - r.poolConn.Close(conn) + if r.dnsUpstream.TCPPoolConfig.Enable { + return r.BaseResolver.exchangeByPool(q, r.poolConn) } else { - r.setIdleTimeout(conn) - r.poolConn.Put(conn) + conn, err := r.createTlsConn() + if err != nil { + return nil, err + } + defer conn.Close() + return r.exchangeByConnWithoutClose(q, conn) } - return ret, err } func (r *TCPTLSResolver) createTlsConn() (conn net.Conn, err error) { diff --git a/core/outbound/clients/resolver/udp_resolver.go b/core/outbound/clients/resolver/udp_resolver.go index b3d092d..b9dfac7 100644 --- a/core/outbound/clients/resolver/udp_resolver.go +++ b/core/outbound/clients/resolver/udp_resolver.go @@ -9,14 +9,7 @@ type UDPResolver struct { } func (r *UDPResolver) Exchange(q *dns.Msg) (*dns.Msg, error) { - conn, err := r.CreateBaseConn() - if err != nil { - return nil, err - } - defer conn.Close() - r.setTimeout(conn) - ret, err := r.exchangeByDNSClient(q, conn) - return ret, err + return r.BaseResolver.Exchange(q) } func (r *UDPResolver) Init() error {