diff --git a/request.go b/request.go index 95c26088d..ce5fb83a8 100644 --- a/request.go +++ b/request.go @@ -31,11 +31,20 @@ func (r ClientRequest) Do(ctx context.Context, model interface{}) error { return err } + // If the caller provided a response header hook then we'll call it + // once we have a response. + respHeaderHook := contextResponseHeaderHook(ctx) + // Add the context to the request. reqWithCxt := r.retryableRequest.WithContext(ctx) // Execute the request and check the response. resp, err := r.http.Do(reqWithCxt) + if resp != nil { + // We call the callback whenever there's any sort of response, + // even if it's returned in conjunction with an error. + respHeaderHook(resp.StatusCode, resp.Header) + } if err != nil { // If we got an error, and the context has been canceled, // the context's error is probably more useful. diff --git a/request_hooks.go b/request_hooks.go new file mode 100644 index 000000000..b1fe71893 --- /dev/null +++ b/request_hooks.go @@ -0,0 +1,63 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package tfe + +import ( + "context" + "fmt" + "net/http" +) + +// ContextWithResponseHeaderHook returns a context that will, if passed to +// [ClientRequest.Do] or to any of the wrapper methods that call it, arrange +// for the given callback to be called with the headers from the raw HTTP +// response. +// +// This is intended for allowing callers to respond to out-of-band metadata +// such as cache-control-related headers, rate limiting headers, etc. Hooks +// must not modify the given [http.Header] or otherwise attempt to change how +// the response is handled by [ClientRequest.Do]. +// +// If the given context already has a response header hook then the returned +// context will call both the existing hook and the newly-provided one, with +// the newer being called first. +func ContextWithResponseHeaderHook(parentCtx context.Context, cb func(status int, header http.Header)) context.Context { + // If the given context already has a notification callback then we'll + // arrange to notify both the previous and the new one. This is not + // a super efficient way to achieve that but we expect it to be rare + // for there to be more than one or two hooks associated with a particular + // request, so it's not warranted to optimize this further. + existingI := parentCtx.Value(contextResponseHeaderHookKey) + finalCb := cb + if existingI != nil { + existing, ok := existingI.(func(int, http.Header)) + // This explicit check-and-panic is redundant but required by our linter. + if !ok { + panic(fmt.Sprintf("context has response header hook of invalid type %T", existingI)) + } + finalCb = func(status int, header http.Header) { + cb(status, header) + existing(status, header) + } + } + return context.WithValue(parentCtx, contextResponseHeaderHookKey, finalCb) +} + +func contextResponseHeaderHook(ctx context.Context) func(int, http.Header) { + cbI := ctx.Value(contextResponseHeaderHookKey) + if cbI == nil { + // Stub callback that does absolutely nothing, then. + return func(int, http.Header) {} + } + return cbI.(func(int, http.Header)) +} + +// contextResponseHeaderHookKey is the type of the internal key used to store +// the callback for [ContextWithResponseHeaderHook] inside a [context.Context] +// object. +type contextResponseHeaderHookKeyType struct{} + +// contextResponseHeaderHookKey is the internal key used to store the callback +// for [ContextWithResponseHeaderHook] inside a [context.Context] object. +var contextResponseHeaderHookKey contextResponseHeaderHookKeyType diff --git a/request_hooks_test.go b/request_hooks_test.go new file mode 100644 index 000000000..e5c573814 --- /dev/null +++ b/request_hooks_test.go @@ -0,0 +1,55 @@ +package tfe + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" +) + +func TestContextWithResponseHeaderHook(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("x-thingy", "boop") + w.WriteHeader(http.StatusNoContent) + })) + defer server.Close() + + cfg := &Config{ + Address: server.URL, + BasePath: "/anything", + Token: "placeholder", + } + client, err := NewClient(cfg) + if err != nil { + t.Fatal(err) + } + + called := false + var gotStatus int + var gotHeader http.Header + ctx := ContextWithResponseHeaderHook(context.Background(), func(status int, header http.Header) { + called = true + gotStatus = status + gotHeader = header + }) + + req, err := client.NewRequest("GET", "boop", nil) + if err != nil { + t.Fatal(err) + } + + err = req.Do(ctx, nil) + if err != nil { + t.Fatal(err) + } + + if !called { + t.Fatal("hook was not called") + } + if got, want := gotStatus, http.StatusNoContent; got != want { + t.Fatalf("wrong response status: got %d, want %d", got, want) + } + if got, want := gotHeader.Get("x-thingy"), "boop"; got != want { + t.Fatalf("wrong value for x-thingy field: got %q, want %q", got, want) + } +}