Skip to content

Commit 4a1ccdf

Browse files
authored
JWT, KeyAuth, CSRF multivalue extractors (#2060)
* CSRF, JWT, KeyAuth middleware support for multivalue value extractors * Add flag to JWT and KeyAuth middleware to allow continuing execution `next(c)` when error handler decides to swallow the error (returns nil).
1 parent 9e9924d commit 4a1ccdf

File tree

10 files changed

+1562
-395
lines changed

10 files changed

+1562
-395
lines changed

echo.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,10 @@ type (
111111
}
112112

113113
// MiddlewareFunc defines a function to process middleware.
114-
MiddlewareFunc func(HandlerFunc) HandlerFunc
114+
MiddlewareFunc func(next HandlerFunc) HandlerFunc
115115

116116
// HandlerFunc defines a function to serve HTTP requests.
117-
HandlerFunc func(Context) error
117+
HandlerFunc func(c Context) error
118118

119119
// HTTPErrorHandler is a centralized HTTP error handler.
120120
HTTPErrorHandler func(error, Context)

middleware/csrf.go

Lines changed: 47 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@ package middleware
22

33
import (
44
"crypto/subtle"
5-
"errors"
65
"net/http"
7-
"strings"
86
"time"
97

108
"github.com/labstack/echo/v4"
@@ -21,13 +19,15 @@ type (
2119
TokenLength uint8 `yaml:"token_length"`
2220
// Optional. Default value 32.
2321

24-
// TokenLookup is a string in the form of "<source>:<key>" that is used
22+
// TokenLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used
2523
// to extract token from the request.
2624
// Optional. Default value "header:X-CSRF-Token".
2725
// Possible values:
28-
// - "header:<name>"
29-
// - "form:<name>"
26+
// - "header:<name>" or "header:<name>:<cut-prefix>"
3027
// - "query:<name>"
28+
// - "form:<name>"
29+
// Multiple sources example:
30+
// - "header:X-CSRF-Token,query:csrf"
3131
TokenLookup string `yaml:"token_lookup"`
3232

3333
// Context key to store generated CSRF token into context.
@@ -62,12 +62,11 @@ type (
6262
// Optional. Default value SameSiteDefaultMode.
6363
CookieSameSite http.SameSite `yaml:"cookie_same_site"`
6464
}
65-
66-
// csrfTokenExtractor defines a function that takes `echo.Context` and returns
67-
// either a token or an error.
68-
csrfTokenExtractor func(echo.Context) (string, error)
6965
)
7066

67+
// ErrCSRFInvalid is returned when CSRF check fails
68+
var ErrCSRFInvalid = echo.NewHTTPError(http.StatusForbidden, "invalid csrf token")
69+
7170
var (
7271
// DefaultCSRFConfig is the default CSRF middleware config.
7372
DefaultCSRFConfig = CSRFConfig{
@@ -114,14 +113,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
114113
config.CookieSecure = true
115114
}
116115

117-
// Initialize
118-
parts := strings.Split(config.TokenLookup, ":")
119-
extractor := csrfTokenFromHeader(parts[1])
120-
switch parts[0] {
121-
case "form":
122-
extractor = csrfTokenFromForm(parts[1])
123-
case "query":
124-
extractor = csrfTokenFromQuery(parts[1])
116+
extractors, err := createExtractors(config.TokenLookup, "")
117+
if err != nil {
118+
panic(err)
125119
}
126120

127121
return func(next echo.HandlerFunc) echo.HandlerFunc {
@@ -130,28 +124,50 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
130124
return next(c)
131125
}
132126

133-
req := c.Request()
134-
k, err := c.Cookie(config.CookieName)
135127
token := ""
136-
137-
// Generate token
138-
if err != nil {
139-
token = random.String(config.TokenLength)
128+
if k, err := c.Cookie(config.CookieName); err != nil {
129+
token = random.String(config.TokenLength) // Generate token
140130
} else {
141-
// Reuse token
142-
token = k.Value
131+
token = k.Value // Reuse token
143132
}
144133

145-
switch req.Method {
134+
switch c.Request().Method {
146135
case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
147136
default:
148137
// Validate token only for requests which are not defined as 'safe' by RFC7231
149-
clientToken, err := extractor(c)
150-
if err != nil {
151-
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
138+
var lastExtractorErr error
139+
var lastTokenErr error
140+
outer:
141+
for _, extractor := range extractors {
142+
clientTokens, err := extractor(c)
143+
if err != nil {
144+
lastExtractorErr = err
145+
continue
146+
}
147+
148+
for _, clientToken := range clientTokens {
149+
if validateCSRFToken(token, clientToken) {
150+
lastTokenErr = nil
151+
lastExtractorErr = nil
152+
break outer
153+
}
154+
lastTokenErr = ErrCSRFInvalid
155+
}
152156
}
153-
if !validateCSRFToken(token, clientToken) {
154-
return echo.NewHTTPError(http.StatusForbidden, "invalid csrf token")
157+
if lastTokenErr != nil {
158+
return lastTokenErr
159+
} else if lastExtractorErr != nil {
160+
// ugly part to preserve backwards compatible errors. someone could rely on them
161+
if lastExtractorErr == errQueryExtractorValueMissing {
162+
lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the query string")
163+
} else if lastExtractorErr == errFormExtractorValueMissing {
164+
lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the form parameter")
165+
} else if lastExtractorErr == errHeaderExtractorValueMissing {
166+
lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in request header")
167+
} else {
168+
lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, lastExtractorErr.Error())
169+
}
170+
return lastExtractorErr
155171
}
156172
}
157173

@@ -184,38 +200,6 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
184200
}
185201
}
186202

187-
// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
188-
// provided request header.
189-
func csrfTokenFromHeader(header string) csrfTokenExtractor {
190-
return func(c echo.Context) (string, error) {
191-
return c.Request().Header.Get(header), nil
192-
}
193-
}
194-
195-
// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
196-
// provided form parameter.
197-
func csrfTokenFromForm(param string) csrfTokenExtractor {
198-
return func(c echo.Context) (string, error) {
199-
token := c.FormValue(param)
200-
if token == "" {
201-
return "", errors.New("missing csrf token in the form parameter")
202-
}
203-
return token, nil
204-
}
205-
}
206-
207-
// csrfTokenFromQuery returns a `csrfTokenExtractor` that extracts token from the
208-
// provided query parameter.
209-
func csrfTokenFromQuery(param string) csrfTokenExtractor {
210-
return func(c echo.Context) (string, error) {
211-
token := c.QueryParam(param)
212-
if token == "" {
213-
return "", errors.New("missing csrf token in the query string")
214-
}
215-
return token, nil
216-
}
217-
}
218-
219203
func validateCSRFToken(token, clientToken string) bool {
220204
return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1
221205
}

0 commit comments

Comments
 (0)