-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathcors.go
270 lines (216 loc) · 7.66 KB
/
cors.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
package nano
// The Cross-Origin Resource Sharing (CORS) standard is used to enable cross-site HTTP requests for:
//   - invocations of the XMLHttpRequest or Fetch APIs in a cross-site manner;
//   - web fonts (cross-domain font usage in @font-face within CSS), so that servers can deploy
//     TrueType fonts that can only be loaded cross-site by web sites that are permitted to do so;
//   - WebGL textures;
//   - images/video frames drawn to a canvas using drawImage().
import (
"net/http"
"strings"
)
// CORSConfig holds the options accepted by CORSWithConfig.
//
// CORSWithConfig replaces any empty list with a default before building
// the middleware.
type CORSConfig struct {
	// AllowedOrigins lists the origins permitted to make cross-origin
	// requests; the "*" entry permits every origin.
	AllowedOrigins []string
	// AllowedMethods lists the HTTP methods permitted in preflighted requests.
	AllowedMethods []string
	// AllowedHeaders lists the request headers permitted in preflighted
	// requests; the "*" entry permits every header.
	AllowedHeaders []string
}
// CORS is the state behind the CORS middleware: the origin, method and
// header allow-lists consulted on every request. Build one with
// CORSWithConfig, or fill it through the SetAllowed*/AddAllowed* methods.
type CORS struct {
	// allowedOrigins may contain "*" to accept any origin.
	allowedOrigins []string
	// allowedMethods is rendered into Access-Control-Allow-Methods.
	allowedMethods []string
	// allowedHeaders may contain "*" to accept any requested header.
	allowedHeaders []string
}
// parseRequestHeader splits the comma-separated value of an
// Access-Control-Request-Headers request header into individual header names.
//
// Every element is trimmed of surrounding whitespace (HTTP allows optional
// spaces/tabs around list elements), including when only one header was
// requested. Empty elements, such as those produced by "a,,b" or a trailing
// comma, are dropped so they cannot cause a spurious "header not allowed"
// rejection. An empty or all-whitespace value yields an empty, non-nil slice.
func parseRequestHeader(header string) []string {
	parts := strings.Split(header, ",")
	result := make([]string, 0, len(parts))
	for _, part := range parts {
		if name := strings.TrimSpace(part); name != "" {
			result = append(result, name)
		}
	}
	return result
}
// SetAllowedOrigins replaces the whole allowed-origin list. A "*" entry
// allows every origin.
//
// The list is copied so that later changes to the caller's slice do not
// silently alter the middleware, and later AddAllowedOrigin calls can never
// write into the caller's backing array.
func (cors *CORS) SetAllowedOrigins(origins []string) {
	cors.allowedOrigins = append([]string(nil), origins...)
}
// SetAllowedMethods replaces the whole allowed-method list.
//
// The list is copied so that later changes to the caller's slice do not
// silently alter the middleware, and later AddAllowedMethod calls can never
// write into the caller's backing array.
func (cors *CORS) SetAllowedMethods(methods []string) {
	cors.allowedMethods = append([]string(nil), methods...)
}
// SetAllowedHeaders replaces the whole allowed-header list. A "*" entry
// allows every requested header.
//
// The list is copied so that later changes to the caller's slice do not
// silently alter the middleware, and later AddAllowedHeader calls can never
// write into the caller's backing array.
func (cors *CORS) SetAllowedHeaders(headers []string) {
	cors.allowedHeaders = append([]string(nil), headers...)
}
// AddAllowedHeader appends header to the list of allowed request headers.
func (cors *CORS) AddAllowedHeader(header string) {
	// The full slice expression caps capacity at the current length, so append
	// always allocates a fresh array. Otherwise it could overwrite spare
	// capacity belonging to a slice the caller passed to SetAllowedHeaders.
	headers := cors.allowedHeaders
	cors.allowedHeaders = append(headers[:len(headers):len(headers)], header)
}
// AddAllowedMethod appends method to the list of allowed HTTP methods.
func (cors *CORS) AddAllowedMethod(method string) {
	// The full slice expression caps capacity at the current length, so append
	// always allocates a fresh array. Otherwise it could overwrite spare
	// capacity belonging to a slice the caller passed to SetAllowedMethods.
	methods := cors.allowedMethods
	cors.allowedMethods = append(methods[:len(methods):len(methods)], method)
}
// AddAllowedOrigin appends origin to the list of allowed origins.
func (cors *CORS) AddAllowedOrigin(origin string) {
	// The full slice expression caps capacity at the current length, so append
	// always allocates a fresh array. Otherwise it could overwrite spare
	// capacity belonging to a slice the caller passed to SetAllowedOrigins.
	origins := cors.allowedOrigins
	cors.allowedOrigins = append(origins[:len(origins):len(origins)], origin)
}
// isAllowAllOrigin reports whether the allowed-origin list contains the "*"
// wildcard entry.
func (cors *CORS) isAllowAllOrigin() bool {
	for i := range cors.allowedOrigins {
		if cors.allowedOrigins[i] != "*" {
			continue
		}
		return true
	}
	return false
}
// isOriginAllowed reports whether requestOrigin may access the resource.
// Either the origin appears verbatim in the allowed list, or the list
// contains the "*" wildcard.
func (cors *CORS) isOriginAllowed(requestOrigin string) bool {
	allowed := false
	for _, candidate := range cors.allowedOrigins {
		if candidate == "*" || candidate == requestOrigin {
			allowed = true
			break
		}
	}
	return allowed
}
// isMethodAllowed reports whether requestMethod may be used in a
// cross-origin request.
//
// The method must match an entry of the allowed list exactly, since HTTP
// method names are case-sensitive. Alternatively, the list may contain the
// "*" wildcard: mergeMethods already advertises "*" in
// Access-Control-Allow-Methods, so the check must honour it too. Without
// that, a "*" configuration would reject every preflight.
func (cors *CORS) isMethodAllowed(requestMethod string) bool {
	for _, method := range cors.allowedMethods {
		if method == "*" || method == requestMethod {
			return true
		}
	}
	return false
}
// mergeMethods renders the allowed-method list as the value of the
// Access-Control-Allow-Methods header. If the list contains the "*" wildcard,
// "*" alone is returned. Otherwise the methods are joined with ", ".
func (cors *CORS) mergeMethods() string {
	hasWildcard := false
	for _, m := range cors.allowedMethods {
		if m == "*" {
			hasWildcard = true
			break
		}
	}
	if hasWildcard {
		return "*"
	}
	return strings.Join(cors.allowedMethods, ", ")
}
// isAllHeaderAllowed reports whether the allowed-header list contains the
// "*" wildcard entry, meaning any requested header is accepted.
func (cors *CORS) isAllHeaderAllowed() bool {
	found := false
	for i := 0; i < len(cors.allowedHeaders) && !found; i++ {
		found = cors.allowedHeaders[i] == "*"
	}
	return found
}
// areHeadersAllowed reports whether every name in requestedHeaders is
// permitted. It returns true immediately when the allowed list contains the
// "*" wildcard.
//
// Names are compared case-insensitively. HTTP field names are
// case-insensitive, and browsers send Access-Control-Request-Headers
// lowercased, so a configured "Content-Type" must match a requested
// "content-type".
func (cors *CORS) areHeadersAllowed(requestedHeaders []string) bool {
	if cors.isAllHeaderAllowed() {
		return true
	}
	for _, requested := range requestedHeaders {
		allowed := false
		for _, candidate := range cors.allowedHeaders {
			if strings.EqualFold(candidate, requested) {
				allowed = true
				break
			}
		}
		if !allowed {
			return false
		}
	}
	return true
}
// handlePrefilghtRequest answers a CORS preflight request.
//
// It returns without setting any CORS response headers when any of these hold:
//   - the request has no Origin, or the origin is not allowed;
//   - the method named by HeaderAccessControlRequestMethod is not allowed;
//   - any header listed in HeaderAccessControlRequestHeader is not allowed.
// In that case the browser sees a failed preflight.
//
// Otherwise it sets these response headers:
//   - Vary;
//   - Access-Control-Allow-Origin: "*" when every origin is allowed, else the
//     request's Origin;
//   - Access-Control-Allow-Methods;
//   - Access-Control-Allow-Headers, echoing the requested list when one was sent.
func (cors *CORS) handlePrefilghtRequest(c *Context) {
	if c.Origin == "" || !cors.isOriginAllowed(c.Origin) {
		return
	}
	requestedMethod := c.GetRequestHeader(HeaderAccessControlRequestMethod)
	if !cors.isMethodAllowed(requestedMethod) {
		return
	}
	requestedHeader := c.GetRequestHeader(HeaderAccessControlRequestHeader)
	requestedHeaders := parseRequestHeader(requestedHeader)
	if len(requestedHeaders) > 0 && !cors.areHeadersAllowed(requestedHeaders) {
		return
	}
	// Caches must key the preflight response on every request header that
	// influenced it. These must be the real header names: singular
	// "Access-Control-Request-Method", plural "Access-Control-Request-Headers".
	c.SetHeader(HeaderVary, "Origin, Access-Control-Request-Method, Access-Control-Request-Headers")
	if cors.isAllowAllOrigin() {
		c.SetHeader(HeaderAccessControlAllowOrigin, "*")
	} else {
		c.SetHeader(HeaderAccessControlAllowOrigin, c.Origin)
	}
	c.SetHeader(HeaderAccessControlAllowMethods, cors.mergeMethods())
	if len(requestedHeader) > 0 {
		c.SetHeader(HeaderAccessControlAllowHeader, requestedHeader)
	}
}
// handleSimpleRequest decorates a cross-origin request that was not
// preflighted.
//
// Requests without an Origin, or from an origin that is not allowed, are left
// untouched. Otherwise it sets Vary: Origin and Access-Control-Allow-Origin.
// The latter is "*" when the wildcard is configured, or the request's own Origin.
func (cors *CORS) handleSimpleRequest(c *Context) {
	origin := c.Origin
	if origin == "" || !cors.isOriginAllowed(origin) {
		return
	}
	allowOrigin := origin
	if cors.isAllowAllOrigin() {
		allowOrigin = "*"
	}
	// The response depends on the Origin request header, so Vary must say so.
	c.SetHeader(HeaderVary, HeaderOrigin)
	c.SetHeader(HeaderAccessControlAllowOrigin, allowOrigin)
}
// Handle is the CORS middleware entry point.
//
// CORS lets a server use HTTP headers to tell browsers which origins may read
// its responses. For requests that may have side effects, browsers first send
// a "preflight" OPTIONS request to ask whether the actual request is allowed.
//
// An OPTIONS request carrying HeaderAccessControlRequestMethod is treated as a
// preflight. It is answered by handlePrefilghtRequest, and the handler chain
// stops there (c.Next is not called). Every other request goes through
// handleSimpleRequest and then continues down the chain via c.Next.
func (cors *CORS) Handle(c *Context) {
	isPreflight := c.Method == http.MethodOptions &&
		c.GetRequestHeader(HeaderAccessControlRequestMethod) != ""
	if !isPreflight {
		cors.handleSimpleRequest(c)
		c.Next()
		return
	}
	cors.handlePrefilghtRequest(c)
}
// CORSWithConfig builds the CORS middleware from config.
//
// Empty lists are replaced with these defaults:
//   - AllowedMethods: POST, PUT, DELETE, GET;
//   - AllowedOrigins: "*" (any origin);
//   - AllowedHeaders: "*" (any header).
func CORSWithConfig(config CORSConfig) HandlerFunc {
	methods := config.AllowedMethods
	if len(methods) == 0 {
		methods = []string{http.MethodPost, http.MethodPut, http.MethodDelete, http.MethodGet}
	}
	origins := config.AllowedOrigins
	if len(origins) == 0 {
		origins = []string{"*"}
	}
	headers := config.AllowedHeaders
	if len(headers) == 0 {
		headers = []string{"*"}
	}

	cors := &CORS{}
	cors.SetAllowedMethods(methods)
	cors.SetAllowedOrigins(origins)
	cors.SetAllowedHeaders(headers)
	return cors.Handle
}