Skip to content

Commit

Permalink
feat(auth/httptransport): add ability to customize transport (#10023)
Browse files Browse the repository at this point in the history
This was a known limitation of the current implementation that had been forgotten to be implemented. See removal of todo in related PR.

Updates: #9812
Updates: #9814
Related: googleapis/google-api-go-client#2541
  • Loading branch information
codyoss authored Apr 22, 2024
1 parent 50994e7 commit 72c7f6b
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 1 deletion.
9 changes: 8 additions & 1 deletion auth/httptransport/httptransport.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ type Options struct {
// Headers are extra HTTP headers that will be appended to every outgoing
// request.
Headers http.Header
// BaseRoundTripper overrides the base transport used for serving requests.
// If specified ClientCertProvider is ignored.
BaseRoundTripper http.RoundTripper
// Endpoint overrides the default endpoint to be used for a service.
Endpoint string
// APIKey specifies an API key to be used as the basis for authentication.
Expand Down Expand Up @@ -181,7 +184,11 @@ func NewClient(opts *Options) (*http.Client, error) {
if err != nil {
return nil, err
}
trans, err := newTransport(defaultBaseTransport(clientCertProvider, dialTLSContext), opts)
baseRoundTripper := opts.BaseRoundTripper
if baseRoundTripper == nil {
baseRoundTripper = defaultBaseTransport(clientCertProvider, dialTLSContext)
}
trans, err := newTransport(baseRoundTripper, opts)
if err != nil {
return nil, err
}
Expand Down
36 changes: 36 additions & 0 deletions auth/httptransport/httptransport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,10 +350,46 @@ func TestNewClient_APIKey(t *testing.T) {
}
}

func TestNewClient_BaseRoundTripper(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
got := r.Header.Get("Foo")
if want := "foo"; got != want {
t.Errorf("got %q, want %q", got, want)
}
got = r.Header.Get("Bar")
if want := "bar"; got != want {
t.Errorf("got %q, want %q", got, want)
}
}))
defer ts.Close()
client, err := NewClient(&Options{
BaseRoundTripper: &rt{key: "Bar", value: "bar"},
Headers: http.Header{"Foo": []string{"foo"}},
APIKey: "key",
})
if err != nil {
t.Fatalf("NewClient() = %v", err)
}
if _, err := client.Get(ts.URL); err != nil {
t.Fatalf("client.Get() = %v", err)
}
}

type staticTP string

func (tp staticTP) Token(context.Context) (*auth.Token, error) {
return &auth.Token{
Value: string(tp),
}, nil
}

type rt struct {
key string
value string
}

func (r *rt) RoundTrip(req *http.Request) (*http.Response, error) {
req2 := req.Clone(req.Context())
req2.Header.Add(r.key, r.value)
return http.DefaultTransport.RoundTrip(req2)
}

0 comments on commit 72c7f6b

Please sign in to comment.