Skip to content

Commit

Permalink
proxy middleware: reuse echo request context (#2537)
Browse files Browse the repository at this point in the history
  • Loading branch information
x1h0 authored Nov 5, 2023
1 parent 69a0de8 commit c7d6d43
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 0 deletions.
4 changes: 4 additions & 0 deletions middleware/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,10 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
c.Set("_error", nil)
}

// This is needed for ProxyConfig.ModifyResponse and/or ProxyConfig.Transport to be able to process the Request
// that Balancer may have replaced with c.SetRequest.
req = c.Request()

// Proxy
switch {
case c.IsWebSocket():
Expand Down
60 changes: 60 additions & 0 deletions middleware/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -747,3 +747,63 @@ func TestProxyBalancerWithNoTargets(t *testing.T) {
rrb := NewRoundRobinBalancer([]*ProxyTarget{})
assert.Nil(t, rrb.Next(nil))
}

type testContextKey string

type customBalancer struct {
target *ProxyTarget
}

func (b *customBalancer) AddTarget(target *ProxyTarget) bool {
return false
}

func (b *customBalancer) RemoveTarget(name string) bool {
return false
}

func (b *customBalancer) Next(c echo.Context) *ProxyTarget {
ctx := context.WithValue(c.Request().Context(), testContextKey("FROM_BALANCER"), "CUSTOM_BALANCER")
c.SetRequest(c.Request().WithContext(ctx))
return b.target
}

func TestModifyResponseUseContext(t *testing.T) {
server := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}),
)
defer server.Close()

targetURL, _ := url.Parse(server.URL)
e := echo.New()
e.Use(ProxyWithConfig(
ProxyConfig{
Balancer: &customBalancer{
target: &ProxyTarget{
Name: "tst",
URL: targetURL,
},
},
RetryCount: 1,
ModifyResponse: func(res *http.Response) error {
val := res.Request.Context().Value(testContextKey("FROM_BALANCER"))
if valStr, ok := val.(string); ok {
res.Header.Set("FROM_BALANCER", valStr)
}
return nil
},
},
))

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()

e.ServeHTTP(rec, req)

assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "OK", rec.Body.String())
assert.Equal(t, "CUSTOM_BALANCER", rec.Header().Get("FROM_BALANCER"))
}

0 comments on commit c7d6d43

Please sign in to comment.