Skip to content

Commit 932976d

Browse files
lammelpwli0755
andauthored
Support real regex rules for rewrite and proxy middleware (#1767)
Support real regex rules for rewrite and proxy middleware (use non-greedy matching by default) Co-authored-by: pwli <lipw0755@gmail.com>
1 parent 7c8592a commit 932976d

File tree

5 files changed

+182
-55
lines changed

5 files changed

+182
-55
lines changed

middleware/middleware.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ func rewriteRulesRegex(rewrite map[string]string) map[*regexp.Regexp]string {
3838
rulesRegex := map[*regexp.Regexp]string{}
3939
for k, v := range rewrite {
4040
k = regexp.QuoteMeta(k)
41-
k = strings.Replace(k, `\*`, "(.*)", -1)
41+
k = strings.Replace(k, `\*`, "(.*?)", -1)
4242
if strings.HasPrefix(k, `\^`) {
4343
k = strings.Replace(k, `\^`, "^", -1)
4444
}

middleware/proxy.go

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,13 @@ type (
3636
// "/users/*/orders/*": "/user/$1/order/$2",
3737
Rewrite map[string]string
3838

39+
// RegexRewrite defines rewrite rules using regexp.Rexexp with captures
40+
// Every capture group in the values can be retrieved by index e.g. $1, $2 and so on.
41+
// Example:
42+
// "^/old/[0.9]+/": "/new",
43+
// "^/api/.+?/(.*)": "/v2/$1",
44+
RegexRewrite map[*regexp.Regexp]string
45+
3946
// Context key to store selected ProxyTarget into context.
4047
// Optional. Default value "target".
4148
ContextKey string
@@ -46,8 +53,6 @@ type (
4653

4754
// ModifyResponse defines function to modify response from ProxyTarget.
4855
ModifyResponse func(*http.Response) error
49-
50-
rewriteRegex map[*regexp.Regexp]string
5156
}
5257

5358
// ProxyTarget defines the upstream target.
@@ -206,7 +211,14 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
206211
panic("echo: proxy middleware requires balancer")
207212
}
208213

209-
config.rewriteRegex = rewriteRulesRegex(config.Rewrite)
214+
if config.Rewrite != nil {
215+
if config.RegexRewrite == nil {
216+
config.RegexRewrite = make(map[*regexp.Regexp]string)
217+
}
218+
for k, v := range rewriteRulesRegex(config.Rewrite) {
219+
config.RegexRewrite[k] = v
220+
}
221+
}
210222

211223
return func(next echo.HandlerFunc) echo.HandlerFunc {
212224
return func(c echo.Context) (err error) {
@@ -220,7 +232,7 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
220232
c.Set(config.ContextKey, tgt)
221233

222234
// Set rewrite path and raw path
223-
rewritePath(config.rewriteRegex, req)
235+
rewritePath(config.RegexRewrite, req)
224236

225237
// Fix header
226238
// Basically it's not good practice to unconditionally pass incoming x-real-ip header to upstream.
@@ -251,5 +263,3 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
251263
}
252264
}
253265
}
254-
255-

middleware/proxy_test.go

Lines changed: 102 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"net/http"
99
"net/http/httptest"
1010
"net/url"
11+
"regexp"
1112
"testing"
1213

1314
"github.com/labstack/echo/v4"
@@ -83,46 +84,6 @@ func TestProxy(t *testing.T) {
8384
body = rec.Body.String()
8485
assert.Equal(t, "target 2", body)
8586

86-
// Rewrite
87-
e = echo.New()
88-
e.Use(ProxyWithConfig(ProxyConfig{
89-
Balancer: rrb,
90-
Rewrite: map[string]string{
91-
"/old": "/new",
92-
"/api/*": "/$1",
93-
"/js/*": "/public/javascripts/$1",
94-
"/users/*/orders/*": "/user/$1/order/$2",
95-
},
96-
}))
97-
req.URL, _ = url.Parse("/api/users")
98-
rec = httptest.NewRecorder()
99-
e.ServeHTTP(rec, req)
100-
assert.Equal(t, "/users", req.URL.EscapedPath())
101-
assert.Equal(t, http.StatusOK, rec.Code)
102-
req.URL, _ = url.Parse( "/js/main.js")
103-
rec = httptest.NewRecorder()
104-
e.ServeHTTP(rec, req)
105-
assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath())
106-
assert.Equal(t, http.StatusOK, rec.Code)
107-
req.URL, _ = url.Parse("/old")
108-
rec = httptest.NewRecorder()
109-
e.ServeHTTP(rec, req)
110-
assert.Equal(t, "/new", req.URL.EscapedPath())
111-
assert.Equal(t, http.StatusOK, rec.Code)
112-
req.URL, _ = url.Parse( "/users/jack/orders/1")
113-
rec = httptest.NewRecorder()
114-
e.ServeHTTP(rec, req)
115-
assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath())
116-
assert.Equal(t, http.StatusOK, rec.Code)
117-
req.URL, _ = url.Parse("/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F")
118-
rec = httptest.NewRecorder()
119-
e.ServeHTTP(rec, req)
120-
assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath())
121-
assert.Equal(t, http.StatusOK, rec.Code)
122-
req.URL, _ = url.Parse("/api/new users")
123-
rec = httptest.NewRecorder()
124-
e.ServeHTTP(rec, req)
125-
assert.Equal(t, "/new%20users", req.URL.EscapedPath())
12687
// ModifyResponse
12788
e = echo.New()
12889
e.Use(ProxyWithConfig(ProxyConfig{
@@ -196,3 +157,104 @@ func TestProxyRealIPHeader(t *testing.T) {
196157
assert.Equal(t, tt.extectedXRealIP, req.Header.Get(echo.HeaderXRealIP), "hasRealIPheader: %t / hasIPExtractor: %t", tt.hasRealIPheader, tt.hasIPExtractor)
197158
}
198159
}
160+
161+
func TestProxyRewrite(t *testing.T) {
162+
// Setup
163+
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
164+
defer upstream.Close()
165+
url, _ := url.Parse(upstream.URL)
166+
rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: url}})
167+
req := httptest.NewRequest(http.MethodGet, "/", nil)
168+
rec := httptest.NewRecorder()
169+
170+
// Rewrite
171+
e := echo.New()
172+
e.Use(ProxyWithConfig(ProxyConfig{
173+
Balancer: rrb,
174+
Rewrite: map[string]string{
175+
"/old": "/new",
176+
"/api/*": "/$1",
177+
"/js/*": "/public/javascripts/$1",
178+
"/users/*/orders/*": "/user/$1/order/$2",
179+
},
180+
}))
181+
req.URL, _ = url.Parse("/api/users")
182+
rec = httptest.NewRecorder()
183+
e.ServeHTTP(rec, req)
184+
assert.Equal(t, "/users", req.URL.EscapedPath())
185+
assert.Equal(t, http.StatusOK, rec.Code)
186+
req.URL, _ = url.Parse("/js/main.js")
187+
rec = httptest.NewRecorder()
188+
e.ServeHTTP(rec, req)
189+
assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath())
190+
assert.Equal(t, http.StatusOK, rec.Code)
191+
req.URL, _ = url.Parse("/old")
192+
rec = httptest.NewRecorder()
193+
e.ServeHTTP(rec, req)
194+
assert.Equal(t, "/new", req.URL.EscapedPath())
195+
assert.Equal(t, http.StatusOK, rec.Code)
196+
req.URL, _ = url.Parse("/users/jack/orders/1")
197+
rec = httptest.NewRecorder()
198+
e.ServeHTTP(rec, req)
199+
assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath())
200+
assert.Equal(t, http.StatusOK, rec.Code)
201+
req.URL, _ = url.Parse("/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F")
202+
rec = httptest.NewRecorder()
203+
e.ServeHTTP(rec, req)
204+
assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath())
205+
assert.Equal(t, http.StatusOK, rec.Code)
206+
req.URL, _ = url.Parse("/api/new users")
207+
rec = httptest.NewRecorder()
208+
e.ServeHTTP(rec, req)
209+
assert.Equal(t, "/new%20users", req.URL.EscapedPath())
210+
}
211+
212+
func TestProxyRewriteRegex(t *testing.T) {
213+
// Setup
214+
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
215+
defer upstream.Close()
216+
url, _ := url.Parse(upstream.URL)
217+
rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: url}})
218+
req := httptest.NewRequest(http.MethodGet, "/", nil)
219+
rec := httptest.NewRecorder()
220+
221+
// Rewrite
222+
e := echo.New()
223+
e.Use(ProxyWithConfig(ProxyConfig{
224+
Balancer: rrb,
225+
Rewrite: map[string]string{
226+
"^/a/*": "/v1/$1",
227+
"^/b/*/c/*": "/v2/$2/$1",
228+
"^/c/*/*": "/v3/$2",
229+
},
230+
RegexRewrite: map[*regexp.Regexp]string{
231+
regexp.MustCompile("^/x/.+?/(.*)"): "/v4/$1",
232+
regexp.MustCompile("^/y/(.+?)/(.*)"): "/v5/$2/$1",
233+
},
234+
}))
235+
236+
testCases := []struct {
237+
requestPath string
238+
statusCode int
239+
expectPath string
240+
}{
241+
{"/unmatched", http.StatusOK, "/unmatched"},
242+
{"/a/test", http.StatusOK, "/v1/test"},
243+
{"/b/foo/c/bar/baz", http.StatusOK, "/v2/bar/baz/foo"},
244+
{"/c/ignore/test", http.StatusOK, "/v3/test"},
245+
{"/c/ignore1/test/this", http.StatusOK, "/v3/test/this"},
246+
{"/x/ignore/test", http.StatusOK, "/v4/test"},
247+
{"/y/foo/bar", http.StatusOK, "/v5/bar/foo"},
248+
}
249+
250+
251+
for _, tc := range testCases {
252+
t.Run(tc.requestPath, func(t *testing.T) {
253+
req.URL, _ = url.Parse(tc.requestPath)
254+
rec = httptest.NewRecorder()
255+
e.ServeHTTP(rec, req)
256+
assert.Equal(t, tc.expectPath, req.URL.EscapedPath())
257+
assert.Equal(t, tc.statusCode, rec.Code)
258+
})
259+
}
260+
}

middleware/rewrite.go

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
package middleware
22

33
import (
4-
"github.com/labstack/echo/v4"
54
"regexp"
5+
6+
"github.com/labstack/echo/v4"
67
)
78

89
type (
@@ -21,7 +22,12 @@ type (
2122
// Required.
2223
Rules map[string]string `yaml:"rules"`
2324

24-
rulesRegex map[*regexp.Regexp]string
25+
// RegexRules defines the URL path rewrite rules using regexp.Rexexp with captures
26+
// Every capture group in the values can be retrieved by index e.g. $1, $2 and so on.
27+
// Example:
28+
// "^/old/[0.9]+/": "/new",
29+
// "^/api/.+?/(.*)": "/v2/$1",
30+
RegexRules map[*regexp.Regexp]string `yaml:"regex_rules"`
2531
}
2632
)
2733

@@ -45,14 +51,20 @@ func Rewrite(rules map[string]string) echo.MiddlewareFunc {
4551
// See: `Rewrite()`.
4652
func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc {
4753
// Defaults
48-
if config.Rules == nil {
49-
panic("echo: rewrite middleware requires url path rewrite rules")
54+
if config.Rules == nil && config.RegexRules == nil {
55+
panic("echo: rewrite middleware requires url path rewrite rules or regex rules")
5056
}
57+
5158
if config.Skipper == nil {
5259
config.Skipper = DefaultBodyDumpConfig.Skipper
5360
}
5461

55-
config.rulesRegex = rewriteRulesRegex(config.Rules)
62+
if config.RegexRules == nil {
63+
config.RegexRules = make(map[*regexp.Regexp]string)
64+
}
65+
for k, v := range rewriteRulesRegex(config.Rules) {
66+
config.RegexRules[k] = v
67+
}
5668

5769
return func(next echo.HandlerFunc) echo.HandlerFunc {
5870
return func(c echo.Context) (err error) {
@@ -62,7 +74,7 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc {
6274

6375
req := c.Request()
6476
// Set rewrite path and raw path
65-
rewritePath(config.rulesRegex, req)
77+
rewritePath(config.RegexRules, req)
6678
return next(c)
6779
}
6880
}

middleware/rewrite_test.go

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"net/http"
66
"net/http/httptest"
77
"net/url"
8+
"regexp"
89
"testing"
910

1011
"github.com/labstack/echo/v4"
@@ -55,8 +56,8 @@ func TestEchoRewritePreMiddleware(t *testing.T) {
5556

5657
// Rewrite old url to new one
5758
e.Pre(Rewrite(map[string]string{
58-
"/old": "/new",
59-
},
59+
"/old": "/new",
60+
},
6061
))
6162

6263
// Route
@@ -129,3 +130,45 @@ func TestEchoRewriteWithCaret(t *testing.T) {
129130
e.ServeHTTP(rec, req)
130131
assert.Equal(t, "/v2/abc/test", req.URL.Path)
131132
}
133+
134+
// Verify regex used with rewrite
135+
func TestEchoRewriteWithRegexRules(t *testing.T) {
136+
e := echo.New()
137+
138+
e.Pre(RewriteWithConfig(RewriteConfig{
139+
Rules: map[string]string{
140+
"^/a/*": "/v1/$1",
141+
"^/b/*/c/*": "/v2/$2/$1",
142+
"^/c/*/*": "/v3/$2",
143+
},
144+
RegexRules: map[*regexp.Regexp]string{
145+
regexp.MustCompile("^/x/.+?/(.*)"): "/v4/$1",
146+
regexp.MustCompile("^/y/(.+?)/(.*)"): "/v5/$2/$1",
147+
},
148+
}))
149+
150+
var rec *httptest.ResponseRecorder
151+
var req *http.Request
152+
153+
testCases := []struct {
154+
requestPath string
155+
expectPath string
156+
}{
157+
{"/unmatched", "/unmatched"},
158+
{"/a/test", "/v1/test"},
159+
{"/b/foo/c/bar/baz", "/v2/bar/baz/foo"},
160+
{"/c/ignore/test", "/v3/test"},
161+
{"/c/ignore1/test/this", "/v3/test/this"},
162+
{"/x/ignore/test", "/v4/test"},
163+
{"/y/foo/bar", "/v5/bar/foo"},
164+
}
165+
166+
for _, tc := range testCases {
167+
t.Run(tc.requestPath, func(t *testing.T) {
168+
req = httptest.NewRequest(http.MethodGet, tc.requestPath, nil)
169+
rec = httptest.NewRecorder()
170+
e.ServeHTTP(rec, req)
171+
assert.Equal(t, tc.expectPath, req.URL.EscapedPath())
172+
})
173+
}
174+
}

0 commit comments

Comments
 (0)