Skip to content

Commit 9a28fb8

Browse files
committed
cors allow regex pattern
enable cors to use regex pattern for allowed origins implementation is similar to another popular cors middleware: https://github.com/astaxie/beego/blob/master/plugins/cors/cors.go#L196-L201
1 parent 8dd25c3 commit 9a28fb8

File tree

2 files changed

+168
-0
lines changed

2 files changed

+168
-0
lines changed

middleware/cors.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package middleware
22

33
import (
44
"net/http"
5+
"regexp"
56
"strconv"
67
"strings"
78

@@ -76,6 +77,15 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
7677
config.AllowMethods = DefaultCORSConfig.AllowMethods
7778
}
7879

80+
allowOriginPatterns := []string{}
81+
for _, origin := range config.AllowOrigins {
82+
pattern := regexp.QuoteMeta(origin)
83+
pattern = strings.Replace(pattern, "\\*", ".*", -1)
84+
pattern = strings.Replace(pattern, "\\?", ".", -1)
85+
pattern = "^" + pattern + "$"
86+
allowOriginPatterns = append(allowOriginPatterns, pattern)
87+
}
88+
7989
allowMethods := strings.Join(config.AllowMethods, ",")
8090
allowHeaders := strings.Join(config.AllowHeaders, ",")
8191
exposeHeaders := strings.Join(config.ExposeHeaders, ",")
@@ -108,6 +118,26 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
108118
}
109119
}
110120

121+
// Check allowed origin patterns
122+
for _, re := range allowOriginPatterns {
123+
if allowOrigin == "" {
124+
didx := strings.Index(origin, "://")
125+
if didx == -1 {
126+
continue
127+
}
128+
domAuth := origin[didx+3:]
129+
// to avoid regex cost by invalid long domain
130+
if len(domAuth) > 253 {
131+
break
132+
}
133+
134+
if match, _ := regexp.MatchString(re, origin); match {
135+
allowOrigin = origin
136+
break
137+
}
138+
}
139+
}
140+
111141
// Simple request
112142
if req.Method != http.MethodOptions {
113143
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)

middleware/cors_test.go

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,141 @@ func TestCORS(t *testing.T) {
8383
h(c)
8484
assert.Equal(t, "http://bbb.example.com", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
8585
}
86+
87+
func Test_allowOriginScheme(t *testing.T) {
88+
tests := []struct {
89+
domain, pattern string
90+
expected bool
91+
}{
92+
{
93+
domain: "http://example.com",
94+
pattern: "http://example.com",
95+
expected: true,
96+
},
97+
{
98+
domain: "https://example.com",
99+
pattern: "https://example.com",
100+
expected: true,
101+
},
102+
{
103+
domain: "http://example.com",
104+
pattern: "https://example.com",
105+
expected: false,
106+
},
107+
{
108+
domain: "https://example.com",
109+
pattern: "http://example.com",
110+
expected: false,
111+
},
112+
}
113+
114+
e := echo.New()
115+
for _, tt := range tests {
116+
req := httptest.NewRequest(http.MethodOptions, "/", nil)
117+
rec := httptest.NewRecorder()
118+
c := e.NewContext(req, rec)
119+
req.Header.Set(echo.HeaderOrigin, tt.domain)
120+
cors := CORSWithConfig(CORSConfig{
121+
AllowOrigins: []string{tt.pattern},
122+
})
123+
h := cors(echo.NotFoundHandler)
124+
h(c)
125+
126+
if tt.expected {
127+
assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
128+
} else {
129+
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
130+
}
131+
}
132+
}
133+
134+
func Test_allowOriginSubdomain(t *testing.T) {
135+
tests := []struct {
136+
domain, pattern string
137+
expected bool
138+
}{
139+
{
140+
domain: "http://aaa.example.com",
141+
pattern: "http://*.example.com",
142+
expected: true,
143+
},
144+
{
145+
domain: "http://bbb.aaa.example.com",
146+
pattern: "http://*.example.com",
147+
expected: true,
148+
},
149+
{
150+
domain: "http://bbb.aaa.example.com",
151+
pattern: "http://*.aaa.example.com",
152+
expected: true,
153+
},
154+
{
155+
domain: "http://aaa.example.com:8080",
156+
pattern: "http://*.example.com:8080",
157+
expected: true,
158+
},
159+
160+
{
161+
domain: "http://fuga.hoge.com",
162+
pattern: "http://*.example.com",
163+
expected: false,
164+
},
165+
{
166+
domain: "http://ccc.bbb.example.com",
167+
pattern: "http://*.aaa.example.com",
168+
expected: false,
169+
},
170+
{
171+
domain: `http://1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\
172+
.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\
173+
.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\
174+
.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.example.com`,
175+
pattern: "http://*.example.com",
176+
expected: false,
177+
},
178+
{
179+
domain: `http://1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.example.com`,
180+
pattern: "http://*.example.com",
181+
expected: false,
182+
},
183+
{
184+
domain: "http://ccc.bbb.example.com",
185+
pattern: "http://example.com",
186+
expected: false,
187+
},
188+
{
189+
domain: "https://prod-preview--aaa.bbb.com",
190+
pattern: "https://*--aaa.bbb.com",
191+
expected: true,
192+
},
193+
{
194+
domain: "http://ccc.bbb.example.com",
195+
pattern: "http://*.example.com",
196+
expected: true,
197+
},
198+
{
199+
domain: "http://ccc.bbb.example.com",
200+
pattern: "http://foo.[a-z]*.example.com",
201+
expected: false,
202+
},
203+
}
204+
205+
e := echo.New()
206+
for _, tt := range tests {
207+
req := httptest.NewRequest(http.MethodOptions, "/", nil)
208+
rec := httptest.NewRecorder()
209+
c := e.NewContext(req, rec)
210+
req.Header.Set(echo.HeaderOrigin, tt.domain)
211+
cors := CORSWithConfig(CORSConfig{
212+
AllowOrigins: []string{tt.pattern},
213+
})
214+
h := cors(echo.NotFoundHandler)
215+
h(c)
216+
217+
if tt.expected {
218+
assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
219+
} else {
220+
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
221+
}
222+
}
223+
}

0 commit comments

Comments
 (0)