Skip to content

net/http/httputil: add a ModifyRequest method to ReverseProxy #44535

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 34 additions & 3 deletions src/net/http/httputil/reverseproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,22 @@ type ReverseProxy struct {
// back to the original client unmodified.
// Director must not access the provided Request
// after returning.
// Deprecated: New code should use the new ModifyRequest method,
// which allows returning an error to interrupt the flow.
Director func(*http.Request)

// ModifyRequest must be a function which modifies
// the request into a new request to be sent
// using Transport. Its response is then copied
// back to the original client unmodified.
// ModifyRequest must not access the provided Request
// after returning.
// If ModifyResponse returns an error, ErrorHandler is called
// with its error value. If ErrorHandler is nil, its default
// implementation is used.
// If ModifyRequest is nil, the old Director method is called.
ModifyRequest func(*http.Request) error

// The transport used to perform proxy requests.
// If nil, http.DefaultTransport is used.
Transport http.RoundTripper
Expand Down Expand Up @@ -141,7 +155,7 @@ func joinURLPath(a, b *url.URL) (path, rawpath string) {
// Director policy.
func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
targetQuery := target.RawQuery
director := func(req *http.Request) {
modifyRequest := func(req *http.Request) error {
req.URL.Scheme = target.Scheme
req.URL.Host = target.Host
req.URL.Path, req.URL.RawPath = joinURLPath(target, req.URL)
Expand All @@ -154,8 +168,9 @@ func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
// explicitly disable User-Agent so it's not set to default value
req.Header.Set("User-Agent", "")
}
return nil
}
return &ReverseProxy{Director: director}
return &ReverseProxy{ModifyRequest: modifyRequest}
}

func copyHeader(dst, src http.Header) {
Expand Down Expand Up @@ -195,6 +210,20 @@ func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request
return p.defaultErrorHandler
}

// modifyRequest conditionally runs the optional ModifyRequest hook or the old
// Director policy and reports whether the request should proceed.
func (p *ReverseProxy) modifyRequest(rw http.ResponseWriter, req *http.Request) bool {
if p.ModifyRequest == nil {
p.Director(req)
return true
}
if err := p.ModifyRequest(req); err != nil {
p.getErrorHandler()(rw, req, err)
return false
}
return true
}

// modifyResponse conditionally runs the optional ModifyResponse hook
// and reports whether the request should proceed.
func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool {
Expand Down Expand Up @@ -238,7 +267,9 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate
}

p.Director(outreq)
if !p.modifyRequest(rw, outreq) {
return
}
outreq.Close = false

reqUpType := upgradeType(outreq.Header)
Expand Down
77 changes: 63 additions & 14 deletions src/net/http/httputil/reverseproxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,10 +293,10 @@ func TestXForwardedFor_Omit(t *testing.T) {
frontend := httptest.NewServer(proxyHandler)
defer frontend.Close()

oldDirector := proxyHandler.Director
proxyHandler.Director = func(r *http.Request) {
oldModifyRequest := proxyHandler.ModifyRequest
proxyHandler.ModifyRequest = func(r *http.Request) error {
r.Header["X-Forwarded-For"] = nil
oldDirector(r)
return oldModifyRequest(r)
}

getReq, _ := http.NewRequest("GET", frontend.URL, nil)
Expand Down Expand Up @@ -694,8 +694,8 @@ func TestReverseProxy_NilBody(t *testing.T) {
// Issue 33142: always allocate the request headers
func TestReverseProxy_AllocatedHeader(t *testing.T) {
proxyHandler := new(ReverseProxy)
proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
proxyHandler.Director = func(*http.Request) {} // noop
proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
proxyHandler.ModifyRequest = func(*http.Request) error { return nil } // noop
proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
if req.Header == nil {
t.Error("Header == nil; want a non-nil Header")
Expand Down Expand Up @@ -898,8 +898,8 @@ func BenchmarkServeHTTP(b *testing.B) {
Body: io.NopCloser(strings.NewReader("")),
}
proxy := &ReverseProxy{
Director: func(*http.Request) {},
Transport: &staticTransport{res},
ModifyRequest: func(*http.Request) error { return nil },
Transport: &staticTransport{res},
}

w := httptest.NewRecorder()
Expand Down Expand Up @@ -957,20 +957,21 @@ func TestClonesRequestHeaders(t *testing.T) {
req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
req.RemoteAddr = "1.2.3.4:56789"
rp := &ReverseProxy{
Director: func(req *http.Request) {
req.Header.Set("From-Director", "1")
ModifyRequest: func(req *http.Request) error {
req.Header.Set("From-ModifyRequest", "1")
return nil
},
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
if v := req.Header.Get("From-Director"); v != "1" {
t.Errorf("From-Directory value = %q; want 1", v)
if v := req.Header.Get("From-ModifyRequest"); v != "1" {
t.Errorf("From-ModifyRequest value = %q; want 1", v)
}
return nil, io.EOF
}),
}
rp.ServeHTTP(httptest.NewRecorder(), req)

if req.Header.Get("From-Director") == "1" {
t.Error("Director header mutation modified caller's request")
if req.Header.Get("From-ModifyRequest") == "1" {
t.Error("ModifyRequest header mutation modified caller's request")
}
if req.Header.Get("X-Forwarded-For") != "" {
t.Error("X-Forward-For header mutation modified caller's request")
Expand All @@ -991,7 +992,7 @@ func TestModifyResponseClosesBody(t *testing.T) {
logBuf := new(bytes.Buffer)
outErr := errors.New("ModifyResponse error")
rp := &ReverseProxy{
Director: func(req *http.Request) {},
ModifyRequest: func(req *http.Request) error { return nil },
Transport: &staticTransport{&http.Response{
StatusCode: 200,
Body: closeCheck,
Expand Down Expand Up @@ -1418,3 +1419,51 @@ func TestJoinURLPath(t *testing.T) {
}
}
}

func TestModifyRequestPrecedence(t *testing.T) {
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
req.RemoteAddr = "1.2.3.4:56789"
rp := &ReverseProxy{
Director: func(req *http.Request) {
req.Header.Set("From-Director", "1")
},
ModifyRequest: func(req *http.Request) error {
req.Header.Set("From-ModifyRequest", "1")
return nil
},
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
if v := req.Header.Get("From-ModifyRequest"); v != "1" {
t.Errorf("From-ModifyRequest value = %q; want 1", v)
}
if v := req.Header.Get("From-Director"); v != "" {
t.Errorf("From-Director value = %q; want nothing", v)
}
return nil, io.EOF
}),
}
rp.ServeHTTP(httptest.NewRecorder(), req)
}

func TestModifyRequestDirectorFallback(t *testing.T) {
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
req.RemoteAddr = "1.2.3.4:56789"
rp := &ReverseProxy{
Director: func(req *http.Request) {
req.Header.Set("From-Director", "1")
},
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
if v := req.Header.Get("From-Director"); v != "1" {
t.Errorf("From-Director value = %q; want 1", v)
}
if v := req.Header.Get("From-ModifyRequest"); v != "" {
t.Errorf("From-ModifyRequest value = %q; want nothing", v)
}
return nil, io.EOF
}),
}
rp.ServeHTTP(httptest.NewRecorder(), req)
}