From 52e512c4014bb7165715f211fd6f3ac4cf19c529 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan-Otto=20Kr=C3=B6pke?= Date: Wed, 28 Feb 2024 17:56:07 +0100 Subject: [PATCH] http_config: Add host (#549) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * http_config: Add host --------- Signed-off-by: Jan-Otto Kröpke Signed-off-by: Jan-Otto Kröpke Co-authored-by: Ben Kochie --- config/http_config.go | 38 ++++++++++++++++++++++++++++++++++++++ config/http_config_test.go | 28 +++++++++++++++++++++++++++- 2 files changed, 65 insertions(+), 1 deletion(-) diff --git a/config/http_config.go b/config/http_config.go index f295e917..d8750bc2 100644 --- a/config/http_config.go +++ b/config/http_config.go @@ -309,6 +309,9 @@ type HTTPClientConfig struct { // The omitempty flag is not set, because it would be hidden from the // marshalled configuration when set to false. EnableHTTP2 bool `yaml:"enable_http2" json:"enable_http2"` + // Host optionally overrides the Host header to send. + // If empty, the host from the URL will be used. + Host string `yaml:"host,omitempty" json:"host,omitempty"` // Proxy configuration. ProxyConfig `yaml:",inline"` } @@ -427,6 +430,7 @@ type httpClientOptions struct { http2Enabled bool idleConnTimeout time.Duration userAgent string + host string } // HTTPClientOption defines an option that can be applied to the HTTP client. @@ -467,6 +471,13 @@ func WithUserAgent(ua string) HTTPClientOption { } } +// WithHost allows setting the host header. +func WithHost(host string) HTTPClientOption { + return func(opts *httpClientOptions) { + opts.host = host + } +} + // NewClient returns a http.Client using the specified http.RoundTripper. func newClient(rt http.RoundTripper) *http.Client { return &http.Client{Transport: rt} @@ -568,6 +579,10 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT rt = NewUserAgentRoundTripper(opts.userAgent, rt) } + if opts.host != "" { + rt = NewHostRoundTripper(opts.host, rt) + } + // Return a new configured RoundTripper. return rt, nil } @@ -1164,11 +1179,21 @@ type userAgentRoundTripper struct { rt http.RoundTripper } +type hostRoundTripper struct { + host string + rt http.RoundTripper +} + // NewUserAgentRoundTripper adds the user agent every request header. func NewUserAgentRoundTripper(userAgent string, rt http.RoundTripper) http.RoundTripper { return &userAgentRoundTripper{userAgent, rt} } +// NewHostRoundTripper sets the [http.Request.Host] of every request. +func NewHostRoundTripper(host string, rt http.RoundTripper) http.RoundTripper { + return &hostRoundTripper{host, rt} +} + func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { req = cloneRequest(req) req.Header.Set("User-Agent", rt.userAgent) @@ -1181,6 +1206,19 @@ func (rt *userAgentRoundTripper) CloseIdleConnections() { } } +func (rt *hostRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + req = cloneRequest(req) + req.Host = rt.host + req.Header.Set("Host", rt.host) + return rt.rt.RoundTrip(req) +} + +func (rt *hostRoundTripper) CloseIdleConnections() { + if ci, ok := rt.rt.(closeIdler); ok { + ci.CloseIdleConnections() + } +} + func (c HTTPClientConfig) String() string { b, err := yaml.Marshal(c) if err != nil { diff --git a/config/http_config_test.go b/config/http_config_test.go index b0d3939f..cd13a188 100644 --- a/config/http_config_test.go +++ b/config/http_config_test.go @@ -35,7 +35,7 @@ import ( "testing" "time" - yaml "gopkg.in/yaml.v2" + "gopkg.in/yaml.v2" ) const ( @@ -1671,6 +1671,32 @@ func TestOAuth2UserAgent(t *testing.T) { } } +func TestHost(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Host != "localhost.localdomain" { + t.Fatalf("Expected Host header in request to be 'localhost.localdomain', got '%s'", r.Host) + } + + w.Header().Add("Content-Type", "application/json") + })) + defer ts.Close() + + config := DefaultHTTPClientConfig + + rt, err := NewRoundTripperFromConfig(config, "test_host", WithHost("localhost.localdomain")) + if err != nil { + t.Fatal(err) + } + + client := http.Client{ + Transport: rt, + } + _, err = client.Get(ts.URL) + if err != nil { + t.Fatal(err) + } +} + func TestOAuth2WithFile(t *testing.T) { var expectedAuth string ts := newTestOAuthServer(t, &expectedAuth)