@@ -2,9 +2,7 @@ package middleware
2
2
3
3
import (
4
4
"crypto/subtle"
5
- "errors"
6
5
"net/http"
7
- "strings"
8
6
"time"
9
7
10
8
"github.com/labstack/echo/v4"
@@ -21,13 +19,15 @@ type (
21
19
TokenLength uint8 `yaml:"token_length"`
22
20
// Optional. Default value 32.
23
21
24
- // TokenLookup is a string in the form of "<source>:<key >" that is used
22
+ // TokenLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name >" that is used
25
23
// to extract token from the request.
26
24
// Optional. Default value "header:X-CSRF-Token".
27
25
// Possible values:
28
- // - "header:<name>"
29
- // - "form:<name>"
26
+ // - "header:<name>" or "header:<name>:<cut-prefix>"
30
27
// - "query:<name>"
28
+ // - "form:<name>"
29
+ // Multiple sources example:
30
+ // - "header:X-CSRF-Token,query:csrf"
31
31
TokenLookup string `yaml:"token_lookup"`
32
32
33
33
// Context key to store generated CSRF token into context.
@@ -62,12 +62,11 @@ type (
62
62
// Optional. Default value SameSiteDefaultMode.
63
63
CookieSameSite http.SameSite `yaml:"cookie_same_site"`
64
64
}
65
-
66
- // csrfTokenExtractor defines a function that takes `echo.Context` and returns
67
- // either a token or an error.
68
- csrfTokenExtractor func (echo.Context ) (string , error )
69
65
)
70
66
67
+ // ErrCSRFInvalid is returned when CSRF check fails
68
+ var ErrCSRFInvalid = echo .NewHTTPError (http .StatusForbidden , "invalid csrf token" )
69
+
71
70
var (
72
71
// DefaultCSRFConfig is the default CSRF middleware config.
73
72
DefaultCSRFConfig = CSRFConfig {
@@ -114,14 +113,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
114
113
config .CookieSecure = true
115
114
}
116
115
117
- // Initialize
118
- parts := strings .Split (config .TokenLookup , ":" )
119
- extractor := csrfTokenFromHeader (parts [1 ])
120
- switch parts [0 ] {
121
- case "form" :
122
- extractor = csrfTokenFromForm (parts [1 ])
123
- case "query" :
124
- extractor = csrfTokenFromQuery (parts [1 ])
116
+ extractors , err := createExtractors (config .TokenLookup , "" )
117
+ if err != nil {
118
+ panic (err )
125
119
}
126
120
127
121
return func (next echo.HandlerFunc ) echo.HandlerFunc {
@@ -130,28 +124,50 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
130
124
return next (c )
131
125
}
132
126
133
- req := c .Request ()
134
- k , err := c .Cookie (config .CookieName )
135
127
token := ""
136
-
137
- // Generate token
138
- if err != nil {
139
- token = random .String (config .TokenLength )
128
+ if k , err := c .Cookie (config .CookieName ); err != nil {
129
+ token = random .String (config .TokenLength ) // Generate token
140
130
} else {
141
- // Reuse token
142
- token = k .Value
131
+ token = k .Value // Reuse token
143
132
}
144
133
145
- switch req .Method {
134
+ switch c . Request () .Method {
146
135
case http .MethodGet , http .MethodHead , http .MethodOptions , http .MethodTrace :
147
136
default :
148
137
// Validate token only for requests which are not defined as 'safe' by RFC7231
149
- clientToken , err := extractor (c )
150
- if err != nil {
151
- return echo .NewHTTPError (http .StatusBadRequest , err .Error ())
138
+ var lastExtractorErr error
139
+ var lastTokenErr error
140
+ outer:
141
+ for _ , extractor := range extractors {
142
+ clientTokens , err := extractor (c )
143
+ if err != nil {
144
+ lastExtractorErr = err
145
+ continue
146
+ }
147
+
148
+ for _ , clientToken := range clientTokens {
149
+ if validateCSRFToken (token , clientToken ) {
150
+ lastTokenErr = nil
151
+ lastExtractorErr = nil
152
+ break outer
153
+ }
154
+ lastTokenErr = ErrCSRFInvalid
155
+ }
152
156
}
153
- if ! validateCSRFToken (token , clientToken ) {
154
- return echo .NewHTTPError (http .StatusForbidden , "invalid csrf token" )
157
+ if lastTokenErr != nil {
158
+ return lastTokenErr
159
+ } else if lastExtractorErr != nil {
160
+ // ugly part to preserve backwards compatible errors. someone could rely on them
161
+ if lastExtractorErr == errQueryExtractorValueMissing {
162
+ lastExtractorErr = echo .NewHTTPError (http .StatusBadRequest , "missing csrf token in the query string" )
163
+ } else if lastExtractorErr == errFormExtractorValueMissing {
164
+ lastExtractorErr = echo .NewHTTPError (http .StatusBadRequest , "missing csrf token in the form parameter" )
165
+ } else if lastExtractorErr == errHeaderExtractorValueMissing {
166
+ lastExtractorErr = echo .NewHTTPError (http .StatusBadRequest , "missing csrf token in request header" )
167
+ } else {
168
+ lastExtractorErr = echo .NewHTTPError (http .StatusBadRequest , lastExtractorErr .Error ())
169
+ }
170
+ return lastExtractorErr
155
171
}
156
172
}
157
173
@@ -184,38 +200,6 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
184
200
}
185
201
}
186
202
187
- // csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
188
- // provided request header.
189
- func csrfTokenFromHeader (header string ) csrfTokenExtractor {
190
- return func (c echo.Context ) (string , error ) {
191
- return c .Request ().Header .Get (header ), nil
192
- }
193
- }
194
-
195
- // csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
196
- // provided form parameter.
197
- func csrfTokenFromForm (param string ) csrfTokenExtractor {
198
- return func (c echo.Context ) (string , error ) {
199
- token := c .FormValue (param )
200
- if token == "" {
201
- return "" , errors .New ("missing csrf token in the form parameter" )
202
- }
203
- return token , nil
204
- }
205
- }
206
-
207
- // csrfTokenFromQuery returns a `csrfTokenExtractor` that extracts token from the
208
- // provided query parameter.
209
- func csrfTokenFromQuery (param string ) csrfTokenExtractor {
210
- return func (c echo.Context ) (string , error ) {
211
- token := c .QueryParam (param )
212
- if token == "" {
213
- return "" , errors .New ("missing csrf token in the query string" )
214
- }
215
- return token , nil
216
- }
217
- }
218
-
219
203
func validateCSRFToken (token , clientToken string ) bool {
220
204
return subtle .ConstantTimeCompare ([]byte (token ), []byte (clientToken )) == 1
221
205
}
0 commit comments