diff --git a/middleware/proxy.go b/middleware/proxy.go index e4f98d9ed..16b00d645 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -359,6 +359,10 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { c.Set("_error", nil) } + // This is needed for ProxyConfig.ModifyResponse and/or ProxyConfig.Transport to be able to process the Request + // that Balancer may have replaced with c.SetRequest. + req = c.Request() + // Proxy switch { case c.IsWebSocket(): diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index 415d68e77..1c93ba031 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -747,3 +747,63 @@ func TestProxyBalancerWithNoTargets(t *testing.T) { rrb := NewRoundRobinBalancer([]*ProxyTarget{}) assert.Nil(t, rrb.Next(nil)) } + +type testContextKey string + +type customBalancer struct { + target *ProxyTarget +} + +func (b *customBalancer) AddTarget(target *ProxyTarget) bool { + return false +} + +func (b *customBalancer) RemoveTarget(name string) bool { + return false +} + +func (b *customBalancer) Next(c echo.Context) *ProxyTarget { + ctx := context.WithValue(c.Request().Context(), testContextKey("FROM_BALANCER"), "CUSTOM_BALANCER") + c.SetRequest(c.Request().WithContext(ctx)) + return b.target +} + +func TestModifyResponseUseContext(t *testing.T) { + server := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + }), + ) + defer server.Close() + + targetURL, _ := url.Parse(server.URL) + e := echo.New() + e.Use(ProxyWithConfig( + ProxyConfig{ + Balancer: &customBalancer{ + target: &ProxyTarget{ + Name: "tst", + URL: targetURL, + }, + }, + RetryCount: 1, + ModifyResponse: func(res *http.Response) error { + val := res.Request.Context().Value(testContextKey("FROM_BALANCER")) + if valStr, ok := val.(string); ok { + res.Header.Set("FROM_BALANCER", valStr) + } + return nil + }, + }, + )) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "OK", rec.Body.String()) + assert.Equal(t, "CUSTOM_BALANCER", rec.Header().Get("FROM_BALANCER")) +}