Skip to content

Commit

Permalink
Abstracts the RoundTripper interface and provides a default implement (
Browse files Browse the repository at this point in the history
…#1602)

* Abstracts the RoundTripper interface and provides a default implementation for enhanced extensibility (#1601)

* test: Add custom transport test case (#1601)

* Make default RoundTripper implmention none public

Co-authored-by: Erik Dubbelboer <erik@dubbelboer.com>

---------

Co-authored-by: Erik Dubbelboer <erik@dubbelboer.com>
  • Loading branch information
timandy and erikdubbelboer authored Aug 10, 2023
1 parent e181af1 commit 54fdc7a
Show file tree
Hide file tree
Showing 2 changed files with 227 additions and 129 deletions.
242 changes: 129 additions & 113 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -628,8 +628,10 @@ type DialFunc func(addr string) (net.Conn, error)
// Request argument passed to RetryIfFunc, if there are any request errors.
type RetryIfFunc func(request *Request) bool

// TransportFunc wraps every request/response.
type TransportFunc func(*Request, *Response) error
// RoundTripper wraps every request/response.
type RoundTripper interface {
RoundTrip(hc *HostClient, req *Request, resp *Response) (retry bool, err error)
}

// ConnPoolStrategyType define strategy of connection pool enqueue/dequeue
type ConnPoolStrategyType int
Expand Down Expand Up @@ -791,7 +793,7 @@ type HostClient struct {
RetryIf RetryIfFunc

// Transport defines a transport-like mechanism that wraps every request/response.
Transport TransportFunc
Transport RoundTripper

// Connection pool strategy. Can be either LIFO or FIFO (default).
ConnPoolStrategy ConnPoolStrategyType
Expand Down Expand Up @@ -1343,119 +1345,15 @@ func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error)
req.Header.userAgent = append(req.Header.userAgent[:], userAgent...)
}
}
if c.Transport != nil {
err := c.Transport(req, resp)
return err == nil, err
}

var deadline time.Time
if req.timeout > 0 {
deadline = time.Now().Add(req.timeout)
}

cc, err := c.acquireConn(req.timeout, req.ConnectionClose())
if err != nil {
return false, err
}
conn := cc.c

resp.parseNetConn(conn)

writeDeadline := deadline
if c.WriteTimeout > 0 {
tmpWriteDeadline := time.Now().Add(c.WriteTimeout)
if writeDeadline.IsZero() || tmpWriteDeadline.Before(writeDeadline) {
writeDeadline = tmpWriteDeadline
}
}

if err = conn.SetWriteDeadline(writeDeadline); err != nil {
c.closeConn(cc)
return true, err
}

resetConnection := false
if c.MaxConnDuration > 0 && time.Since(cc.createdTime) > c.MaxConnDuration && !req.ConnectionClose() {
req.SetConnectionClose()
resetConnection = true
}

bw := c.acquireWriter(conn)
err = req.Write(bw)

if resetConnection {
req.Header.ResetConnectionClose()
}

if err == nil {
err = bw.Flush()
}
c.releaseWriter(bw)

// Return ErrTimeout on any timeout.
if x, ok := err.(interface{ Timeout() bool }); ok && x.Timeout() {
err = ErrTimeout
}

isConnRST := isConnectionReset(err)
if err != nil && !isConnRST {
c.closeConn(cc)
return true, err
}

readDeadline := deadline
if c.ReadTimeout > 0 {
tmpReadDeadline := time.Now().Add(c.ReadTimeout)
if readDeadline.IsZero() || tmpReadDeadline.Before(readDeadline) {
readDeadline = tmpReadDeadline
}
}

if err = conn.SetReadDeadline(readDeadline); err != nil {
c.closeConn(cc)
return true, err
}

if customSkipBody || req.Header.IsHead() {
resp.SkipBody = true
}
if c.DisableHeaderNamesNormalizing {
resp.Header.DisableNormalizing()
}

br := c.acquireReader(conn)
err = resp.ReadLimitBody(br, c.MaxResponseBodySize)
c.releaseReader(br)
if err != nil {
c.closeConn(cc)
// Don't retry in case of ErrBodyTooLarge since we will just get the same again.
retry := err != ErrBodyTooLarge
return retry, err
}

closeConn := resetConnection || req.ConnectionClose() || resp.ConnectionClose() || isConnRST
if customStreamBody && resp.bodyStream != nil {
rbs := resp.bodyStream
resp.bodyStream = newCloseReader(rbs, func() error {
if r, ok := rbs.(*requestStream); ok {
releaseRequestStream(r)
}
if closeConn {
c.closeConn(cc)
} else {
c.releaseConn(cc)
}
return nil
})
return false, nil
}
return c.transport().RoundTrip(c, req, resp)
}

if closeConn {
c.closeConn(cc)
} else {
c.releaseConn(cc)
func (c *HostClient) transport() RoundTripper {
if c.Transport == nil {
return DefaultTransport
}
return false, nil
return c.Transport
}

var (
Expand Down Expand Up @@ -2909,3 +2807,121 @@ func (c *pipelineConnClient) PendingRequests() int {
}

var errPipelineConnStopped = errors.New("pipeline connection has been stopped")

var DefaultTransport RoundTripper = &transport{}

type transport struct{}

func (t *transport) RoundTrip(hc *HostClient, req *Request, resp *Response) (retry bool, err error) {
customSkipBody := resp.SkipBody
customStreamBody := resp.StreamBody

var deadline time.Time
if req.timeout > 0 {
deadline = time.Now().Add(req.timeout)
}

cc, err := hc.acquireConn(req.timeout, req.ConnectionClose())
if err != nil {
return false, err
}
conn := cc.c

resp.parseNetConn(conn)

writeDeadline := deadline
if hc.WriteTimeout > 0 {
tmpWriteDeadline := time.Now().Add(hc.WriteTimeout)
if writeDeadline.IsZero() || tmpWriteDeadline.Before(writeDeadline) {
writeDeadline = tmpWriteDeadline
}
}

if err = conn.SetWriteDeadline(writeDeadline); err != nil {
hc.closeConn(cc)
return true, err
}

resetConnection := false
if hc.MaxConnDuration > 0 && time.Since(cc.createdTime) > hc.MaxConnDuration && !req.ConnectionClose() {
req.SetConnectionClose()
resetConnection = true
}

bw := hc.acquireWriter(conn)
err = req.Write(bw)

if resetConnection {
req.Header.ResetConnectionClose()
}

if err == nil {
err = bw.Flush()
}
hc.releaseWriter(bw)

// Return ErrTimeout on any timeout.
if x, ok := err.(interface{ Timeout() bool }); ok && x.Timeout() {
err = ErrTimeout
}

isConnRST := isConnectionReset(err)
if err != nil && !isConnRST {
hc.closeConn(cc)
return true, err
}

readDeadline := deadline
if hc.ReadTimeout > 0 {
tmpReadDeadline := time.Now().Add(hc.ReadTimeout)
if readDeadline.IsZero() || tmpReadDeadline.Before(readDeadline) {
readDeadline = tmpReadDeadline
}
}

if err = conn.SetReadDeadline(readDeadline); err != nil {
hc.closeConn(cc)
return true, err
}

if customSkipBody || req.Header.IsHead() {
resp.SkipBody = true
}
if hc.DisableHeaderNamesNormalizing {
resp.Header.DisableNormalizing()
}

br := hc.acquireReader(conn)
err = resp.ReadLimitBody(br, hc.MaxResponseBodySize)
hc.releaseReader(br)
if err != nil {
hc.closeConn(cc)
// Don't retry in case of ErrBodyTooLarge since we will just get the same again.
needRetry := err != ErrBodyTooLarge
return needRetry, err
}

closeConn := resetConnection || req.ConnectionClose() || resp.ConnectionClose() || isConnRST
if customStreamBody && resp.bodyStream != nil {
rbs := resp.bodyStream
resp.bodyStream = newCloseReader(rbs, func() error {
if r, ok := rbs.(*requestStream); ok {
releaseRequestStream(r)
}
if closeConn {
hc.closeConn(cc)
} else {
hc.releaseConn(cc)
}
return nil
})
return false, nil
}

if closeConn {
hc.closeConn(cc)
} else {
hc.releaseConn(cc)
}
return false, nil
}
Loading

0 comments on commit 54fdc7a

Please sign in to comment.