Skip to content

Commit 1a907bc

Browse files
authored
auth: clone the client request body before roundtripping (#597)
RoundTrippers may read and close the body, so be careful to clone before roundtripping during client oauth, as the request may be issued multiple times. Fixes #590
1 parent f01e7fa commit 1a907bc

File tree

2 files changed

+83
-1
lines changed

2 files changed

+83
-1
lines changed

auth/client.go

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77
package auth
88

99
import (
10+
"bytes"
1011
"context"
1112
"errors"
13+
"io"
1214
"net/http"
1315
"sync"
1416

@@ -67,6 +69,28 @@ func (t *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
6769
base := t.opts.Base
6870
t.mu.Unlock()
6971

72+
var (
73+
// If haveBody is set, the request has a nontrivial body, and we need avoid
74+
// reading (or closing) it multiple times. In that case, bodyBytes is its
75+
// content.
76+
haveBody bool
77+
bodyBytes []byte
78+
)
79+
if req.Body != nil && req.Body != http.NoBody {
80+
// if we're setting Body, we must mutate first.
81+
req = req.Clone(req.Context())
82+
haveBody = true
83+
var err error
84+
bodyBytes, err = io.ReadAll(req.Body)
85+
if err != nil {
86+
return nil, err
87+
}
88+
// Now that we've read the request body, http.RoundTripper requires that we
89+
// close it.
90+
req.Body.Close() // ignore error
91+
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
92+
}
93+
7094
resp, err := base.RoundTrip(req)
7195
if err != nil {
7296
return nil, err
@@ -97,7 +121,15 @@ func (t *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
97121
}
98122
t.opts.Base = &oauth2.Transport{Base: t.opts.Base, Source: ts}
99123
}
100-
return t.opts.Base.RoundTrip(req.Clone(req.Context()))
124+
125+
// If we don't have a body, the request is reusable, though it will be cloned
126+
// by the base. However, if we've had to read the body, we must clone.
127+
if haveBody {
128+
req = req.Clone(req.Context())
129+
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
130+
}
131+
132+
return t.opts.Base.RoundTrip(req)
101133
}
102134

103135
func extractResourceMetadataURL(authHeaders []string) string {

auth/client_test.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,24 @@ import (
1010
"context"
1111
"errors"
1212
"fmt"
13+
"io"
1314
"net/http"
1415
"net/http/httptest"
16+
"strings"
1517
"testing"
1618

1719
"golang.org/x/oauth2"
1820
)
1921

22+
// A basicReader is an io.Reader to be used as a non-rereadable request body.
23+
//
24+
// net/http has special handling for strings.Reader that we want to avoid.
25+
type basicReader struct {
26+
r *strings.Reader
27+
}
28+
29+
func (r *basicReader) Read(p []byte) (n int, err error) { return r.r.Read(p) }
30+
2031
// TestHTTPTransport validates the OAuth HTTPTransport.
2132
func TestHTTPTransport(t *testing.T) {
2233
const testToken = "test-token-123"
@@ -27,6 +38,20 @@ func TestHTTPTransport(t *testing.T) {
2738

2839
// authServer simulates a resource that requires OAuth.
2940
authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
41+
if r.Method == http.MethodPost {
42+
// Ensure that the body was properly cloned, by reading it completely.
43+
// If the body is not cloned, reading it the second time may yield no
44+
// bytes.
45+
body, err := io.ReadAll(r.Body)
46+
if err != nil {
47+
http.Error(w, err.Error(), http.StatusInternalServerError)
48+
return
49+
}
50+
if len(body) == 0 {
51+
http.Error(w, "empty body", http.StatusBadRequest)
52+
return
53+
}
54+
}
3055
authHeader := r.Header.Get("Authorization")
3156
if authHeader == fmt.Sprintf("Bearer %s", testToken) {
3257
w.WriteHeader(http.StatusOK)
@@ -82,6 +107,31 @@ func TestHTTPTransport(t *testing.T) {
82107
}
83108
})
84109

110+
t.Run("request body is cloned", func(t *testing.T) {
111+
handler := func(ctx context.Context, args OAuthHandlerArgs) (oauth2.TokenSource, error) {
112+
if args.ResourceMetadataURL != "http://metadata.example.com" {
113+
t.Errorf("handler got metadata URL %q, want %q", args.ResourceMetadataURL, "http://metadata.example.com")
114+
}
115+
return fakeTokenSource, nil
116+
}
117+
118+
transport, err := NewHTTPTransport(handler, nil)
119+
if err != nil {
120+
t.Fatalf("NewHTTPTransport() failed: %v", err)
121+
}
122+
client := &http.Client{Transport: transport}
123+
124+
resp, err := client.Post(authServer.URL, "application/json", &basicReader{strings.NewReader("{}")})
125+
if err != nil {
126+
t.Fatalf("client.Post() failed: %v", err)
127+
}
128+
defer resp.Body.Close()
129+
130+
if resp.StatusCode != http.StatusOK {
131+
t.Errorf("got status %d, want %d", resp.StatusCode, http.StatusOK)
132+
}
133+
})
134+
85135
t.Run("handler returns error", func(t *testing.T) {
86136
handlerErr := errors.New("user rejected auth")
87137
handler := func(ctx context.Context, args OAuthHandlerArgs) (oauth2.TokenSource, error) {

0 commit comments

Comments
 (0)