Skip to content

Commit ec71bd4

Browse files
authored
fix: impose limit on origin length when CORS is enabled with a wildcard (#2085)
1 parent 5f0a0b8 commit ec71bd4

File tree

5 files changed

+321
-122
lines changed

5 files changed

+321
-122
lines changed

router/pkg/cors/config.go

Lines changed: 16 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
package cors
22

33
import (
4+
"maps"
45
"net/http"
5-
"strings"
6+
"slices"
67
)
78

89
type cors struct {
@@ -12,13 +13,13 @@ type cors struct {
1213
allowOrigins []string
1314
normalHeaders http.Header
1415
preflightHeaders http.Header
15-
wildcardOrigins [][]string
16+
wildcardOrigins []*WildcardPattern
1617
handler http.Handler
1718
}
1819

1920
var (
20-
maxRecursionDepth = 10 // Safeguard against deep recursion
21-
DefaultSchemas = []string{
21+
maxWildcardOriginLength = 4096 // Maximum length of an origin string for it to be eligible for wildcard matching
22+
DefaultSchemas = []string{
2223
"http://",
2324
"https://",
2425
}
@@ -55,7 +56,7 @@ func newCors(handler http.Handler, config Config) *cors {
5556
allowOrigins: normalize(config.AllowOrigins),
5657
normalHeaders: generateNormalHeaders(config),
5758
preflightHeaders: generatePreflightHeaders(config),
58-
wildcardOrigins: config.parseWildcardRules(),
59+
wildcardOrigins: config.parseNewWildcardRules(),
5960
handler: handler,
6061
}
6162
}
@@ -102,10 +103,8 @@ func (cors *cors) validateOrigin(origin string) bool {
102103
if cors.allowAllOrigins {
103104
return true
104105
}
105-
for _, value := range cors.allowOrigins {
106-
if value == origin {
107-
return true
108-
}
106+
if slices.Contains(cors.allowOrigins, origin) {
107+
return true
109108
}
110109
if len(cors.wildcardOrigins) > 0 && cors.validateWildcardOrigin(origin) {
111110
return true
@@ -117,67 +116,25 @@ func (cors *cors) validateOrigin(origin string) bool {
117116
}
118117

119118
func (cors *cors) validateWildcardOrigin(origin string) bool {
120-
for _, w := range cors.wildcardOrigins {
121-
if matchOriginWithRule(origin, w, 0, map[string]bool{}) {
122-
return true
123-
}
124-
}
125-
return false
126-
}
127-
128-
// Recursive helper function with depth limit and memoization
129-
func matchOriginWithRule(origin string, rule []string, depth int, memo map[string]bool) bool {
130-
if depth > maxRecursionDepth {
131-
return false // Exceeded recursion depth
132-
}
133-
134-
// Memoization key
135-
key := origin + "|" + strings.Join(rule, "|")
136-
if val, exists := memo[key]; exists {
137-
return val
138-
}
139-
140-
if len(rule) == 0 {
141-
// Successfully matched if origin is also fully consumed
142-
return origin == ""
143-
}
144-
145-
part := rule[0]
146-
147-
if part == "*" {
148-
// Try to match the remaining rule by advancing in origin
149-
for i := 0; i <= len(origin); i++ {
150-
if matchOriginWithRule(origin[i:], rule[1:], depth+1, memo) {
151-
memo[key] = true
152-
return true
153-
}
154-
}
155-
memo[key] = false
119+
// Origin is >4KB, avoid matching it for performance
120+
if len(origin) > maxWildcardOriginLength {
156121
return false
157122
}
158123

159-
// Check if the origin starts with the current part
160-
if strings.HasPrefix(origin, part) {
161-
// Recursively check the rest of the origin and rule
162-
result := matchOriginWithRule(origin[len(part):], rule[1:], depth+1, memo)
163-
memo[key] = result
164-
return result
124+
for _, w := range cors.wildcardOrigins {
125+
if w.Match(origin) {
126+
return true
127+
}
165128
}
166-
167-
memo[key] = false
168129
return false
169130
}
170131

171132
func (cors *cors) handlePreflight(w http.ResponseWriter) {
172133
header := w.Header()
173-
for key, value := range cors.preflightHeaders {
174-
header[key] = value
175-
}
134+
maps.Copy(header, cors.preflightHeaders)
176135
}
177136

178137
func (cors *cors) handleNormal(w http.ResponseWriter) {
179138
header := w.Header()
180-
for key, value := range cors.normalHeaders {
181-
header[key] = value
182-
}
139+
maps.Copy(header, cors.normalHeaders)
183140
}

router/pkg/cors/cors.go

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -108,35 +108,16 @@ func (c *Config) Validate() error {
108108
return nil
109109
}
110110

111-
func (c *Config) parseWildcardRules() [][]string {
112-
var wRules [][]string
111+
func (c *Config) parseNewWildcardRules() []*WildcardPattern {
112+
var wRules []*WildcardPattern
113113

114114
for _, o := range c.AllowOrigins {
115115
if !strings.Contains(o, "*") {
116116
continue
117117
}
118118

119-
// Split origin by wildcard (*)
120-
parts := strings.Split(o, "*")
121-
122-
// If there’s no wildcard, skip this origin
123-
if len(parts) == 1 {
124-
continue
125-
}
126-
127-
// Generate rules for origins with multiple wildcard segments
128-
var rule []string
129-
for i, part := range parts {
130-
if i > 0 {
131-
rule = append(rule, "*") // Add wildcard indicator between segments
132-
}
133-
if part != "" {
134-
rule = append(rule, part)
135-
}
136-
}
137-
138-
// Add parsed rule to wRules
139-
wRules = append(wRules, rule)
119+
wp := Compile(o)
120+
wRules = append(wRules, wp)
140121
}
141122

142123
return wRules

router/pkg/cors/cors_test.go

Lines changed: 59 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package cors
22

33
import (
44
"context"
5+
"fmt"
56
"net/http"
67
"net/http/httptest"
78
"strings"
@@ -220,6 +221,29 @@ func TestGeneratePreflightHeaders_MaxAge(t *testing.T) {
220221
assert.Len(t, header, 2)
221222
}
222223

224+
func TestExtremeLengthOriginKillswitch(t *testing.T) {
225+
cors := newCors(nil, Config{
226+
Enabled: true,
227+
AllowOrigins: []string{"https://*.google.com"},
228+
})
229+
230+
shortSubdomain := strings.Repeat("a", 10)
231+
longSubdomain := strings.Repeat("a", 500)
232+
tooLongSubdomain := strings.Repeat("a", 4096)
233+
234+
assert.True(t, cors.validateOrigin(fmt.Sprintf("https://%s.google.com", shortSubdomain)))
235+
assert.True(t, cors.validateOrigin(fmt.Sprintf("https://%s.google.com", longSubdomain)))
236+
assert.False(t, cors.validateOrigin(fmt.Sprintf("https://%s.google.com", tooLongSubdomain)))
237+
238+
// Should not affect strict origins
239+
cors = newCors(nil, Config{
240+
Enabled: true,
241+
AllowOrigins: []string{fmt.Sprintf("https://%s.google.com", tooLongSubdomain)},
242+
})
243+
244+
assert.True(t, cors.validateOrigin(fmt.Sprintf("https://%s.google.com", tooLongSubdomain)))
245+
}
246+
223247
func TestValidateOrigin(t *testing.T) {
224248
cors := newCors(nil, Config{
225249
Enabled: true,
@@ -519,29 +543,10 @@ func TestComplexWildcards(t *testing.T) {
519543
}
520544
for _, tc := range testCasesList {
521545
w := performRequest(router, "GET", tc.origin)
522-
assert.Equal(t, tc.expectedCode, w.Code)
546+
assert.Equalf(t, tc.expectedCode, w.Code, "expected %d for %s, got %d", tc.expectedCode, tc.origin, w.Code)
523547
}
524548
}
525549

526-
func TestMaxRecursionDepth(t *testing.T) {
527-
router := newTestRouter(Config{
528-
Enabled: true,
529-
AllowOrigins: []string{
530-
"https://*.example.*.*.com", // multiple sequential wildcards
531-
"https://*.*.*.*.com",
532-
},
533-
AllowMethods: []string{"GET"},
534-
})
535-
536-
maxRecursionDepth = 2
537-
w := performRequest(router, "GET", "https://subdomain.example.subdomain.example.com")
538-
assert.Equal(t, 403, w.Code)
539-
540-
maxRecursionDepth = 10
541-
w = performRequest(router, "GET", "https://subdomain.example.subdomain.example.com")
542-
assert.Equal(t, 200, w.Code)
543-
}
544-
545550
func TestDisabled(t *testing.T) {
546551
config := Config{
547552
Enabled: true,
@@ -561,35 +566,26 @@ func TestDisabled(t *testing.T) {
561566
assert.Equal(t, 200, w.Code)
562567
}
563568

564-
func BenchmarkCorsWithoutWildcards(b *testing.B) {
565-
b.ReportAllocs()
566-
b.ResetTimer()
567-
568-
b.Run("without wildcards", func(b *testing.B) {
569+
func BenchmarkCorsWithWildcards(b *testing.B) {
570+
b.Run("with wildcards", func(b *testing.B) {
569571
router := newTestRouter(Config{
570572
Enabled: true,
571573
AllowOrigins: []string{
572-
"https://*.wgexample.com",
573-
"https://wgexample.com",
574-
"https://*.wgexample.io:*",
575-
"https://*.wgexample.org",
576-
"https://*.d2grknavcceso7.amplifyapp.com",
577574
"https://*.example.*.*.com", // multiple sequential wildcards
578575
"https://*.*.*.*.com",
579576
},
580577
AllowMethods: []string{"GET"},
581578
})
582579

583-
w := performRequest(router, "GET", "https://wgexample.com")
584-
assert.Equal(b, 200, w.Code)
580+
b.ReportAllocs()
581+
b.ResetTimer()
582+
for i := 0; i < b.N; i++ {
583+
w := performRequest(router, "GET", "https://subdomain.test.example.subdomain.example.co.whatgoeshere.woohoo.com")
584+
assert.Equal(b, 200, w.Code)
585+
}
585586
})
586-
}
587-
588-
func BenchmarkCorsWithWildcards(b *testing.B) {
589-
b.ReportAllocs()
590-
b.ResetTimer()
591587

592-
b.Run("with wildcards", func(b *testing.B) {
588+
b.Run("with massive wildcards", func(b *testing.B) {
593589
router := newTestRouter(Config{
594590
Enabled: true,
595591
AllowOrigins: []string{
@@ -599,7 +595,30 @@ func BenchmarkCorsWithWildcards(b *testing.B) {
599595
AllowMethods: []string{"GET"},
600596
})
601597

602-
w := performRequest(router, "GET", "https://subdomain.test.example.subdomain.example.co.whatgoeshere.woohoo.com")
603-
assert.Equal(b, 200, w.Code)
598+
longString := strings.Repeat("a", 50000)
599+
600+
b.ReportAllocs()
601+
b.ResetTimer()
602+
for i := 0; i < b.N; i++ {
603+
w := performRequest(router, "GET", fmt.Sprintf("https://%[1]s.%[1]s.%[1]s.%[1]s.com", longString))
604+
assert.Equal(b, 200, w.Code)
605+
}
606+
})
607+
608+
b.Run("without wildcards", func(b *testing.B) {
609+
router := newTestRouter(Config{
610+
Enabled: true,
611+
AllowOrigins: []string{
612+
"https://wgexample.com",
613+
},
614+
AllowMethods: []string{"GET"},
615+
})
616+
617+
b.ReportAllocs()
618+
b.ResetTimer()
619+
for i := 0; i < b.N; i++ {
620+
w := performRequest(router, "GET", "https://wgexample.com")
621+
assert.Equal(b, 200, w.Code)
622+
}
604623
})
605624
}

0 commit comments

Comments
 (0)