Skip to content

Commit a0cbeaa

Browse files
committed
feat: massively improve cors wildcard performance
1 parent d3608af commit a0cbeaa

File tree

5 files changed

+304
-131
lines changed

5 files changed

+304
-131
lines changed

router/pkg/cors/config.go

Lines changed: 22 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,27 @@
11
package cors
22

33
import (
4+
"maps"
45
"net/http"
56
"slices"
6-
"strings"
77
)
88

99
type cors struct {
10-
allowAllOrigins bool
11-
allowCredentials bool
12-
allowOriginFunc func(string) bool
13-
allowOrigins []string
14-
normalHeaders http.Header
15-
preflightHeaders http.Header
16-
wildcardOrigins [][]string
17-
handler http.Handler
10+
allowAllOrigins bool
11+
allowCredentials bool
12+
allowOriginFunc func(string) bool
13+
allowOrigins []string
14+
normalHeaders http.Header
15+
preflightHeaders http.Header
16+
wildcardOrigins []*WildcardPattern
17+
handler http.Handler
18+
disableMaxOriginLength bool
1819
}
1920

2021
var (
21-
maxRecursionDepth = 10 // Safeguard against deep recursion
22-
maxOriginLength = 1024 // Maximum length of an origin string
23-
DefaultSchemas = []string{
22+
maxRecursionDepth = 10 // Safeguard against deep recursion
23+
maxWildcardOriginLength = 4096 // Maximum length of an origin string for it to be eligible for wildcard matching
24+
DefaultSchemas = []string{
2425
"http://",
2526
"https://",
2627
}
@@ -57,7 +58,7 @@ func newCors(handler http.Handler, config Config) *cors {
5758
allowOrigins: normalize(config.AllowOrigins),
5859
normalHeaders: generateNormalHeaders(config),
5960
preflightHeaders: generatePreflightHeaders(config),
60-
wildcardOrigins: config.parseWildcardRules(),
61+
wildcardOrigins: config.parseNewWildcardRules(),
6162
handler: handler,
6263
}
6364
}
@@ -117,71 +118,25 @@ func (cors *cors) validateOrigin(origin string) bool {
117118
}
118119

119120
func (cors *cors) validateWildcardOrigin(origin string) bool {
120-
for _, w := range cors.wildcardOrigins {
121-
if matchOriginWithRule(origin, w, 0, map[string]bool{}) {
122-
return true
123-
}
124-
}
125-
return false
126-
}
127-
128-
// Recursive helper function with depth limit and memoization
129-
func matchOriginWithRule(origin string, rule []string, depth int, memo map[string]bool) bool {
130-
if depth > maxRecursionDepth {
131-
return false // Exceeded recursion depth
132-
}
133-
134-
// Memoization key
135-
key := origin + "|" + strings.Join(rule, "|")
136-
if val, exists := memo[key]; exists {
137-
return val
138-
}
139-
140-
if len(rule) == 0 {
141-
// Successfully matched if origin is also fully consumed
142-
return origin == ""
143-
}
144-
145-
part := rule[0]
146-
147-
if part == "*" {
148-
if len(origin) > maxOriginLength {
149-
return false
150-
}
151-
152-
// Try to match the remaining rule by advancing in origin
153-
for i := 0; i <= len(origin); i++ {
154-
if matchOriginWithRule(origin[i:], rule[1:], depth+1, memo) {
155-
memo[key] = true
156-
return true
157-
}
158-
}
159-
memo[key] = false
121+
// Origin is >4KB, avoid matching it for performance
122+
if len(origin) > maxWildcardOriginLength {
160123
return false
161124
}
162125

163-
// Check if the origin starts with the current part
164-
if strings.HasPrefix(origin, part) {
165-
// Recursively check the rest of the origin and rule
166-
result := matchOriginWithRule(origin[len(part):], rule[1:], depth+1, memo)
167-
memo[key] = result
168-
return result
126+
for _, w := range cors.wildcardOrigins {
127+
if w.Match(origin) {
128+
return true
129+
}
169130
}
170-
171-
memo[key] = false
172131
return false
173132
}
174133

175134
func (cors *cors) handlePreflight(w http.ResponseWriter) {
176135
header := w.Header()
177-
for key, value := range cors.preflightHeaders {
178-
header[key] = value
179-
}
136+
maps.Copy(header, cors.preflightHeaders)
180137
}
181138

182139
func (cors *cors) handleNormal(w http.ResponseWriter) {
183140
header := w.Header()
184-
for key, value := range cors.normalHeaders {
185-
header[key] = value
186-
}
141+
maps.Copy(header, cors.normalHeaders)
187142
}

router/pkg/cors/cors.go

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -108,35 +108,16 @@ func (c *Config) Validate() error {
108108
return nil
109109
}
110110

111-
func (c *Config) parseWildcardRules() [][]string {
112-
var wRules [][]string
111+
func (c *Config) parseNewWildcardRules() []*WildcardPattern {
112+
var wRules []*WildcardPattern
113113

114114
for _, o := range c.AllowOrigins {
115115
if !strings.Contains(o, "*") {
116116
continue
117117
}
118118

119-
// Split origin by wildcard (*)
120-
parts := strings.Split(o, "*")
121-
122-
// If there’s no wildcard, skip this origin
123-
if len(parts) == 1 {
124-
continue
125-
}
126-
127-
// Generate rules for origins with multiple wildcard segments
128-
var rule []string
129-
for i, part := range parts {
130-
if i > 0 {
131-
rule = append(rule, "*") // Add wildcard indicator between segments
132-
}
133-
if part != "" {
134-
rule = append(rule, part)
135-
}
136-
}
137-
138-
// Add parsed rule to wRules
139-
wRules = append(wRules, rule)
119+
wp := Compile(o)
120+
wRules = append(wRules, wp)
140121
}
141122

142123
return wRules

router/pkg/cors/cors_test.go

Lines changed: 36 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ func TestExtremeLengthOriginKillswitch(t *testing.T) {
229229

230230
shortSubdomain := strings.Repeat("a", 10)
231231
longSubdomain := strings.Repeat("a", 500)
232-
tooLongSubdomain := strings.Repeat("a", 2000)
232+
tooLongSubdomain := strings.Repeat("a", 4096)
233233

234234
assert.True(t, cors.validateOrigin(fmt.Sprintf("https://%s.google.com", shortSubdomain)))
235235
assert.True(t, cors.validateOrigin(fmt.Sprintf("https://%s.google.com", longSubdomain)))
@@ -543,29 +543,10 @@ func TestComplexWildcards(t *testing.T) {
543543
}
544544
for _, tc := range testCasesList {
545545
w := performRequest(router, "GET", tc.origin)
546-
assert.Equal(t, tc.expectedCode, w.Code)
546+
assert.Equalf(t, tc.expectedCode, w.Code, "expected %d for %s, got %d", tc.expectedCode, tc.origin, w.Code)
547547
}
548548
}
549549

550-
func TestMaxRecursionDepth(t *testing.T) {
551-
router := newTestRouter(Config{
552-
Enabled: true,
553-
AllowOrigins: []string{
554-
"https://*.example.*.*.com", // multiple sequential wildcards
555-
"https://*.*.*.*.com",
556-
},
557-
AllowMethods: []string{"GET"},
558-
})
559-
560-
maxRecursionDepth = 2
561-
w := performRequest(router, "GET", "https://subdomain.example.subdomain.example.com")
562-
assert.Equal(t, 403, w.Code)
563-
564-
maxRecursionDepth = 10
565-
w = performRequest(router, "GET", "https://subdomain.example.subdomain.example.com")
566-
assert.Equal(t, 200, w.Code)
567-
}
568-
569550
func TestDisabled(t *testing.T) {
570551
config := Config{
571552
Enabled: true,
@@ -585,35 +566,26 @@ func TestDisabled(t *testing.T) {
585566
assert.Equal(t, 200, w.Code)
586567
}
587568

588-
func BenchmarkCorsWithoutWildcards(b *testing.B) {
589-
b.ReportAllocs()
590-
b.ResetTimer()
591-
592-
b.Run("without wildcards", func(b *testing.B) {
569+
func BenchmarkCorsWithWildcards(b *testing.B) {
570+
b.Run("with wildcards", func(b *testing.B) {
593571
router := newTestRouter(Config{
594572
Enabled: true,
595573
AllowOrigins: []string{
596-
"https://*.wgexample.com",
597-
"https://wgexample.com",
598-
"https://*.wgexample.io:*",
599-
"https://*.wgexample.org",
600-
"https://*.d2grknavcceso7.amplifyapp.com",
601574
"https://*.example.*.*.com", // multiple sequential wildcards
602575
"https://*.*.*.*.com",
603576
},
604577
AllowMethods: []string{"GET"},
605578
})
606579

607-
w := performRequest(router, "GET", "https://wgexample.com")
608-
assert.Equal(b, 200, w.Code)
580+
b.ReportAllocs()
581+
b.ResetTimer()
582+
for i := 0; i < b.N; i++ {
583+
w := performRequest(router, "GET", "https://subdomain.test.example.subdomain.example.co.whatgoeshere.woohoo.com")
584+
assert.Equal(b, 200, w.Code)
585+
}
609586
})
610-
}
611587

612-
func BenchmarkCorsWithWildcards(b *testing.B) {
613-
b.ReportAllocs()
614-
b.ResetTimer()
615-
616-
b.Run("with wildcards", func(b *testing.B) {
588+
b.Run("with massive wildcards", func(b *testing.B) {
617589
router := newTestRouter(Config{
618590
Enabled: true,
619591
AllowOrigins: []string{
@@ -623,7 +595,30 @@ func BenchmarkCorsWithWildcards(b *testing.B) {
623595
AllowMethods: []string{"GET"},
624596
})
625597

626-
w := performRequest(router, "GET", "https://subdomain.test.example.subdomain.example.co.whatgoeshere.woohoo.com")
627-
assert.Equal(b, 200, w.Code)
598+
longString := strings.Repeat("a", 50000)
599+
600+
b.ReportAllocs()
601+
b.ResetTimer()
602+
for i := 0; i < b.N; i++ {
603+
w := performRequest(router, "GET", fmt.Sprintf("https://%[1]s.%[1]s.%[1]s.%[1]s.com", longString))
604+
assert.Equal(b, 200, w.Code)
605+
}
606+
})
607+
608+
b.Run("without wildcards", func(b *testing.B) {
609+
router := newTestRouter(Config{
610+
Enabled: true,
611+
AllowOrigins: []string{
612+
"https://wgexample.com",
613+
},
614+
AllowMethods: []string{"GET"},
615+
})
616+
617+
b.ReportAllocs()
618+
b.ResetTimer()
619+
for i := 0; i < b.N; i++ {
620+
w := performRequest(router, "GET", "https://wgexample.com")
621+
assert.Equal(b, 200, w.Code)
622+
}
628623
})
629624
}

0 commit comments

Comments
 (0)