@@ -5,7 +5,11 @@ import (
5
5
"strings"
6
6
)
7
7
8
- func allowCORS (h http.Handler , allowOrigins []string ) http.Handler {
8
+ var (
9
+ defaultAllowHeaders = []string {"Content-Type" , "Accept" , "Authorization" , "Origin" }
10
+ )
11
+
12
+ func allowCORS (h http.Handler , allowOrigins []string , extraAllowHeaders []string ) http.Handler {
9
13
return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
10
14
if origin := r .Header .Get ("Origin" ); origin != "" {
11
15
if len (allowOrigins ) > 0 {
@@ -20,16 +24,34 @@ func allowCORS(h http.Handler, allowOrigins []string) http.Handler {
20
24
}
21
25
22
26
if r .Method == "OPTIONS" && r .Header .Get ("Access-Control-Request-Method" ) != "" {
23
- preflightHandler (w , r )
27
+ preflightHandler (w , r , extraAllowHeaders )
24
28
return
25
29
}
26
30
}
27
31
h .ServeHTTP (w , r )
28
32
})
29
33
}
30
34
31
- func preflightHandler (w http.ResponseWriter , r * http.Request ) {
32
- headers := []string {"Content-Type" , "Accept" , "Authorization" , "Origin" }
35
+ func evaluateExtraAllowHeaders (allowHeaders []string ) []string {
36
+ m := map [string ]bool {}
37
+ for _ , h := range defaultAllowHeaders {
38
+ m [h ] = true
39
+ }
40
+
41
+ extraAllowHeaders := []string {}
42
+ for _ , h := range allowHeaders {
43
+ if m [h ] == false {
44
+ extraAllowHeaders = append (extraAllowHeaders , h )
45
+ }
46
+ }
47
+ return extraAllowHeaders
48
+ }
49
+
50
+ func preflightHandler (w http.ResponseWriter , r * http.Request , extraAllowHeaders []string ) {
51
+ headers := defaultAllowHeaders
52
+ if len (extraAllowHeaders ) > 0 {
53
+ headers = append (headers , extraAllowHeaders ... )
54
+ }
33
55
w .Header ().Set ("Access-Control-Allow-Headers" , strings .Join (headers , "," ))
34
56
methods := []string {"GET" , "HEAD" , "POST" , "PUT" , "DELETE" }
35
57
w .Header ().Set ("Access-Control-Allow-Methods" , strings .Join (methods , "," ))
0 commit comments