Skip to content

Commit 1f6cc36

Browse files
atsushi-ishibashivishr
authored andcommitted
Set subdomains to AllowOrigins with wildcard (#1301)
* Set subdomains to AllowOrigins with wildcard * Create IsSubDomain * Avoid panic when pattern length smaller than domain length * Change names, improve formula
1 parent 5434a53 commit 1f6cc36

File tree

4 files changed

+169
-0
lines changed

4 files changed

+169
-0
lines changed

middleware/cors.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,10 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
102102
allowOrigin = o
103103
break
104104
}
105+
if matchSubdomain(origin, o) {
106+
allowOrigin = origin
107+
break
108+
}
105109
}
106110

107111
// Simple request

middleware/cors_test.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,4 +66,20 @@ func TestCORS(t *testing.T) {
6666
assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods))
6767
assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials))
6868
assert.Equal(t, "3600", rec.Header().Get(echo.HeaderAccessControlMaxAge))
69+
70+
// Preflight request with `AllowOrigins` which allow all subdomains with *
71+
req = httptest.NewRequest(http.MethodOptions, "/", nil)
72+
rec = httptest.NewRecorder()
73+
c = e.NewContext(req, rec)
74+
req.Header.Set(echo.HeaderOrigin, "http://aaa.example.com")
75+
cors = CORSWithConfig(CORSConfig{
76+
AllowOrigins: []string{"http://*.example.com"},
77+
})
78+
h = cors(echo.NotFoundHandler)
79+
h(c)
80+
assert.Equal(t, "http://aaa.example.com", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
81+
82+
req.Header.Set(echo.HeaderOrigin, "http://bbb.example.com")
83+
h(c)
84+
assert.Equal(t, "http://bbb.example.com", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
6985
}

middleware/util.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package middleware
2+
3+
import (
4+
"strings"
5+
)
6+
7+
func matchScheme(domain, pattern string) bool {
8+
didx := strings.Index(domain, ":")
9+
pidx := strings.Index(pattern, ":")
10+
return didx != -1 && pidx != -1 && domain[:didx] == pattern[:pidx]
11+
}
12+
13+
// matchSubdomain compares authority with wildcard
14+
func matchSubdomain(domain, pattern string) bool {
15+
if !matchScheme(domain, pattern) {
16+
return false
17+
}
18+
didx := strings.Index(domain, "://")
19+
pidx := strings.Index(pattern, "://")
20+
if didx == -1 || pidx == -1 {
21+
return false
22+
}
23+
domAuth := domain[didx+3:]
24+
// to avoid long loop by invalid long domain
25+
if len(domAuth) > 253 {
26+
return false
27+
}
28+
patAuth := pattern[pidx+3:]
29+
30+
domComp := strings.Split(domAuth, ".")
31+
patComp := strings.Split(patAuth, ".")
32+
for i := len(domComp)/2 - 1; i >= 0; i-- {
33+
opp := len(domComp) - 1 - i
34+
domComp[i], domComp[opp] = domComp[opp], domComp[i]
35+
}
36+
for i := len(patComp)/2 - 1; i >= 0; i-- {
37+
opp := len(patComp) - 1 - i
38+
patComp[i], patComp[opp] = patComp[opp], patComp[i]
39+
}
40+
41+
for i, v := range domComp {
42+
if len(patComp) <= i {
43+
return false
44+
}
45+
p := patComp[i]
46+
if p == "*" {
47+
return true
48+
}
49+
if p != v {
50+
return false
51+
}
52+
}
53+
return false
54+
}

middleware/util_test.go

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
package middleware
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
)
8+
9+
func Test_matchScheme(t *testing.T) {
10+
tests := []struct {
11+
domain, pattern string
12+
expected bool
13+
}{
14+
{
15+
domain: "http://example.com",
16+
pattern: "http://example.com",
17+
expected: true,
18+
},
19+
{
20+
domain: "https://example.com",
21+
pattern: "https://example.com",
22+
expected: true,
23+
},
24+
{
25+
domain: "http://example.com",
26+
pattern: "https://example.com",
27+
expected: false,
28+
},
29+
{
30+
domain: "https://example.com",
31+
pattern: "http://example.com",
32+
expected: false,
33+
},
34+
}
35+
36+
for _, v := range tests {
37+
assert.Equal(t, v.expected, matchScheme(v.domain, v.pattern))
38+
}
39+
}
40+
41+
func Test_matchSubdomain(t *testing.T) {
42+
tests := []struct {
43+
domain, pattern string
44+
expected bool
45+
}{
46+
{
47+
domain: "http://aaa.example.com",
48+
pattern: "http://*.example.com",
49+
expected: true,
50+
},
51+
{
52+
domain: "http://bbb.aaa.example.com",
53+
pattern: "http://*.example.com",
54+
expected: true,
55+
},
56+
{
57+
domain: "http://bbb.aaa.example.com",
58+
pattern: "http://*.aaa.example.com",
59+
expected: true,
60+
},
61+
{
62+
domain: "http://aaa.example.com:8080",
63+
pattern: "http://*.example.com:8080",
64+
expected: true,
65+
},
66+
67+
{
68+
domain: "http://fuga.hoge.com",
69+
pattern: "http://*.example.com",
70+
expected: false,
71+
},
72+
{
73+
domain: "http://ccc.bbb.example.com",
74+
pattern: "http://*.aaa.example.com",
75+
expected: false,
76+
},
77+
{
78+
domain: `http://1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\
79+
.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\
80+
.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\
81+
.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.example.com`,
82+
pattern: "http://*.example.com",
83+
expected: false,
84+
},
85+
{
86+
domain: "http://ccc.bbb.example.com",
87+
pattern: "http://example.com",
88+
expected: false,
89+
},
90+
}
91+
92+
for _, v := range tests {
93+
assert.Equal(t, v.expected, matchSubdomain(v.domain, v.pattern))
94+
}
95+
}

0 commit comments

Comments
 (0)