Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Drain response bodies #432

Merged
merged 1 commit into from
Feb 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions autorest/adal/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"math"
"net/http"
Expand Down Expand Up @@ -972,6 +973,10 @@ func retryForIMDS(sender Sender, req *http.Request, maxAttempts int) (resp *http
delay := time.Duration(0)

for attempt < maxAttempts {
if resp != nil && resp.Body != nil {
jhendrixMSFT marked this conversation as resolved.
Show resolved Hide resolved
io.Copy(ioutil.Discard, resp.Body)
resp.Body.Close()
}
resp, err = sender.Do(req)
// we want to retry if err is not nil or the status code is in the list of retry codes
if err == nil && !responseHasStatusCode(resp, retries...) {
Expand Down
31 changes: 16 additions & 15 deletions autorest/authorization.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,20 +171,21 @@ func (bacb *BearerAuthorizerCallback) WithAuthorization() PrepareDecorator {
removeRequestBody(&rCopy)

resp, err := bacb.sender.Do(&rCopy)
if err == nil && resp.StatusCode == 401 {
defer resp.Body.Close()
if hasBearerChallenge(resp) {
bc, err := newBearerChallenge(resp)
if err != nil {
return r, err
}
DrainResponseBody(resp)
if resp.StatusCode == 401 && hasBearerChallenge(resp.Header) {
bc, err := newBearerChallenge(resp.Header)
if err != nil {
return r, err
}
if bacb.callback != nil {
ba, err := bacb.callback(bc.values[tenantID], bc.values["resource"])
if err != nil {
return r, err
}
if bacb.callback != nil {
ba, err := bacb.callback(bc.values[tenantID], bc.values["resource"])
if err != nil {
return r, err
}
return Prepare(r, ba.WithAuthorization())
}
return Prepare(r, ba.WithAuthorization())
}
}
}
Expand All @@ -194,8 +195,8 @@ func (bacb *BearerAuthorizerCallback) WithAuthorization() PrepareDecorator {
}

// returns true if the HTTP response contains a bearer challenge
func hasBearerChallenge(resp *http.Response) bool {
authHeader := resp.Header.Get(bearerChallengeHeader)
func hasBearerChallenge(header http.Header) bool {
authHeader := header.Get(bearerChallengeHeader)
if len(authHeader) == 0 || strings.Index(authHeader, bearer) < 0 {
return false
}
Expand All @@ -206,8 +207,8 @@ type bearerChallenge struct {
values map[string]string
}

func newBearerChallenge(resp *http.Response) (bc bearerChallenge, err error) {
challenge := strings.TrimSpace(resp.Header.Get(bearerChallengeHeader))
func newBearerChallenge(header http.Header) (bc bearerChallenge, err error) {
challenge := strings.TrimSpace(header.Get(bearerChallengeHeader))
trimmedChallenge := challenge[len(bearer)+1:]

// challenge is a set of key=value pairs that are comma delimited
Expand Down
12 changes: 11 additions & 1 deletion autorest/azure/async.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,17 @@ func (f Future) GetResult(sender autorest.Sender) (*http.Response, error) {
if err != nil {
return nil, err
}
return sender.Do(req)
resp, err := sender.Do(req)
if err == nil && resp.Body != nil {
// copy the body and close it so callers don't have to
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return resp, err
}
resp.Body = ioutil.NopCloser(bytes.NewReader(b))
}
return resp, err
}

type pollingTracker interface {
Expand Down
3 changes: 3 additions & 0 deletions autorest/sender.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ func DoRetryForAttempts(attempts int, backoff time.Duration) SendDecorator {
if err != nil {
return resp, err
}
DrainResponseBody(resp)
resp, err = s.Do(rr.Request())
if err == nil {
return resp, err
Expand Down Expand Up @@ -288,6 +289,7 @@ func doRetryForStatusCodesImpl(s Sender, r *http.Request, count429 bool, attempt
if err != nil {
return
}
DrainResponseBody(resp)
resp, err = s.Do(rr.Request())
// we want to retry if err is not nil (e.g. transient network failure). note that for failed authentication
// resp and err will both have a value, so in this case we don't want to retry as it will never succeed.
Expand Down Expand Up @@ -347,6 +349,7 @@ func DoRetryForDuration(d time.Duration, backoff time.Duration) SendDecorator {
if err != nil {
return resp, err
}
DrainResponseBody(resp)
resp, err = s.Do(rr.Request())
if err == nil {
return resp, err
Expand Down
11 changes: 11 additions & 0 deletions autorest/utility.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"encoding/xml"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
Expand Down Expand Up @@ -226,3 +227,13 @@ func IsTemporaryNetworkError(err error) bool {
}
return false
}

// DrainResponseBody reads the response body then closes it.
func DrainResponseBody(resp *http.Response) error {
if resp != nil && resp.Body != nil {
_, err := io.Copy(ioutil.Discard, resp.Body)
resp.Body.Close()
return err
}
return nil
}
39 changes: 39 additions & 0 deletions autorest/utility_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"encoding/xml"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"reflect"
Expand Down Expand Up @@ -469,3 +470,41 @@ func withErrorRespondDecorator(e *error) RespondDecorator {
})
}
}

type mockDrain struct {
read bool
closed bool
}

func (md *mockDrain) Read(b []byte) (int, error) {
md.read = true
b = append(b, 0xff)
return 1, io.EOF
}

func (md *mockDrain) Close() error {
md.closed = true
return nil
}

func TestDrainResponseBody(t *testing.T) {
err := DrainResponseBody(nil)
if err != nil {
t.Fatalf("expected nil error, got %v", err)
}
err = DrainResponseBody(&http.Response{})
if err != nil {
t.Fatalf("expected nil error, got %v", err)
}
md := &mockDrain{}
err = DrainResponseBody(&http.Response{Body: md})
if err != nil {
t.Fatalf("expected nil error, got %v", err)
}
if !md.closed {
t.Fatal("mockDrain wasn't closed")
}
if !md.read {
t.Fatal("mockDrain wasn't read")
}
}