Skip to content

Commit

Permalink
Drain response bodies
Browse files Browse the repository at this point in the history
The retry helpers and a few other methods weren't reading and closing
response bodies leading to connection leaks.
  • Loading branch information
jhendrixMSFT committed Feb 5, 2020
1 parent 81b386e commit 0521821
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 16 deletions.
5 changes: 5 additions & 0 deletions autorest/adal/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"math"
"net/http"
Expand Down Expand Up @@ -972,6 +973,10 @@ func retryForIMDS(sender Sender, req *http.Request, maxAttempts int) (resp *http
delay := time.Duration(0)

for attempt < maxAttempts {
if resp != nil && resp.Body != nil {
io.Copy(ioutil.Discard, resp.Body)
resp.Body.Close()
}
resp, err = sender.Do(req)
// we want to retry if err is not nil or the status code is in the list of retry codes
if err == nil && !responseHasStatusCode(resp, retries...) {
Expand Down
31 changes: 16 additions & 15 deletions autorest/authorization.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,20 +171,21 @@ func (bacb *BearerAuthorizerCallback) WithAuthorization() PrepareDecorator {
removeRequestBody(&rCopy)

resp, err := bacb.sender.Do(&rCopy)
if err == nil && resp.StatusCode == 401 {
defer resp.Body.Close()
if hasBearerChallenge(resp) {
bc, err := newBearerChallenge(resp)
if err != nil {
return r, err
}
DrainResponseBody(resp)
if resp.StatusCode == 401 && hasBearerChallenge(resp.Header) {
bc, err := newBearerChallenge(resp.Header)
if err != nil {
return r, err
}
if bacb.callback != nil {
ba, err := bacb.callback(bc.values[tenantID], bc.values["resource"])
if err != nil {
return r, err
}
if bacb.callback != nil {
ba, err := bacb.callback(bc.values[tenantID], bc.values["resource"])
if err != nil {
return r, err
}
return Prepare(r, ba.WithAuthorization())
}
return Prepare(r, ba.WithAuthorization())
}
}
}
Expand All @@ -194,8 +195,8 @@ func (bacb *BearerAuthorizerCallback) WithAuthorization() PrepareDecorator {
}

// returns true if the HTTP response contains a bearer challenge
func hasBearerChallenge(resp *http.Response) bool {
authHeader := resp.Header.Get(bearerChallengeHeader)
func hasBearerChallenge(header http.Header) bool {
authHeader := header.Get(bearerChallengeHeader)
if len(authHeader) == 0 || strings.Index(authHeader, bearer) < 0 {
return false
}
Expand All @@ -206,8 +207,8 @@ type bearerChallenge struct {
values map[string]string
}

func newBearerChallenge(resp *http.Response) (bc bearerChallenge, err error) {
challenge := strings.TrimSpace(resp.Header.Get(bearerChallengeHeader))
func newBearerChallenge(header http.Header) (bc bearerChallenge, err error) {
challenge := strings.TrimSpace(header.Get(bearerChallengeHeader))
trimmedChallenge := challenge[len(bearer)+1:]

// challenge is a set of key=value pairs that are comma delimited
Expand Down
12 changes: 11 additions & 1 deletion autorest/azure/async.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,17 @@ func (f Future) GetResult(sender autorest.Sender) (*http.Response, error) {
if err != nil {
return nil, err
}
return sender.Do(req)
resp, err := sender.Do(req)
if err == nil && resp.Body != nil {
// copy the body and close it so callers don't have to
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return resp, err
}
resp.Body = ioutil.NopCloser(bytes.NewReader(b))
}
return resp, err
}

type pollingTracker interface {
Expand Down
3 changes: 3 additions & 0 deletions autorest/sender.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ func DoRetryForAttempts(attempts int, backoff time.Duration) SendDecorator {
if err != nil {
return resp, err
}
DrainResponseBody(resp)
resp, err = s.Do(rr.Request())
if err == nil {
return resp, err
Expand Down Expand Up @@ -288,6 +289,7 @@ func doRetryForStatusCodesImpl(s Sender, r *http.Request, count429 bool, attempt
if err != nil {
return
}
DrainResponseBody(resp)
resp, err = s.Do(rr.Request())
// we want to retry if err is not nil (e.g. transient network failure). note that for failed authentication
// resp and err will both have a value, so in this case we don't want to retry as it will never succeed.
Expand Down Expand Up @@ -347,6 +349,7 @@ func DoRetryForDuration(d time.Duration, backoff time.Duration) SendDecorator {
if err != nil {
return resp, err
}
DrainResponseBody(resp)
resp, err = s.Do(rr.Request())
if err == nil {
return resp, err
Expand Down
11 changes: 11 additions & 0 deletions autorest/utility.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"encoding/xml"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
Expand Down Expand Up @@ -226,3 +227,13 @@ func IsTemporaryNetworkError(err error) bool {
}
return false
}

// DrainResponseBody reads the response body then closes it.
func DrainResponseBody(resp *http.Response) error {
if resp != nil && resp.Body != nil {
_, err := io.Copy(ioutil.Discard, resp.Body)
resp.Body.Close()
return err
}
return nil
}
39 changes: 39 additions & 0 deletions autorest/utility_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"encoding/xml"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"reflect"
Expand Down Expand Up @@ -469,3 +470,41 @@ func withErrorRespondDecorator(e *error) RespondDecorator {
})
}
}

type mockDrain struct {
read bool
closed bool
}

func (md *mockDrain) Read(b []byte) (int, error) {
md.read = true
b = append(b, 0xff)
return 1, io.EOF
}

func (md *mockDrain) Close() error {
md.closed = true
return nil
}

func TestDrainResponseBody(t *testing.T) {
err := DrainResponseBody(nil)
if err != nil {
t.Fatalf("expected nil error, got %v", err)
}
err = DrainResponseBody(&http.Response{})
if err != nil {
t.Fatalf("expected nil error, got %v", err)
}
md := &mockDrain{}
err = DrainResponseBody(&http.Response{Body: md})
if err != nil {
t.Fatalf("expected nil error, got %v", err)
}
if !md.closed {
t.Fatal("mockDrain wasn't closed")
}
if !md.read {
t.Fatal("mockDrain wasn't read")
}
}

0 comments on commit 0521821

Please sign in to comment.