diff --git a/middleware/jwt.go b/middleware/jwt.go index 6c8bcebb4..bce478743 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -1,6 +1,7 @@ package middleware import ( + "errors" "fmt" "net/http" "reflect" @@ -49,7 +50,8 @@ type ( // Optional. Default value "user". ContextKey string - // Claims are extendable claims data defining token content. + // Claims are extendable claims data defining token content. Used by default ParseTokenFunc implementation. + // Not used if custom ParseTokenFunc is set. // Optional. Default value jwt.MapClaims Claims jwt.Claims @@ -74,13 +76,20 @@ type ( // KeyFunc defines a user-defined function that supplies the public key for a token validation. // The function shall take care of verifying the signing algorithm and selecting the proper key. // A user-defined KeyFunc can be useful if tokens are issued by an external party. + // Used by default ParseTokenFunc implementation. // // When a user-defined KeyFunc is provided, SigningKey, SigningKeys, and SigningMethod are ignored. // This is one of the three options to provide a token validation key. // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. // Required if neither SigningKeys nor SigningKey is provided. + // Not used if custom ParseTokenFunc is set. // Default to an internal implementation verifying the signing algorithm and selecting the proper key. KeyFunc jwt.Keyfunc + + // ParseTokenFunc defines a user-defined function that parses token from given auth. Returns an error when token + // parsing fails or parsed token is invalid. + // Defaults to implementation using `github.com/dgrijalva/jwt-go` as JWT implementation library + ParseTokenFunc func(auth string, c echo.Context) (interface{}, error) } // JWTSuccessHandler defines a function which is executed for a valid token. @@ -140,7 +149,7 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { if config.Skipper == nil { config.Skipper = DefaultJWTConfig.Skipper } - if config.SigningKey == nil && len(config.SigningKeys) == 0 && config.KeyFunc == nil { + if config.SigningKey == nil && len(config.SigningKeys) == 0 && config.KeyFunc == nil && config.ParseTokenFunc == nil { panic("echo: jwt middleware requires signing key") } if config.SigningMethod == "" { @@ -161,6 +170,9 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { if config.KeyFunc == nil { config.KeyFunc = config.defaultKeyFunc } + if config.ParseTokenFunc == nil { + config.ParseTokenFunc = config.defaultParseToken + } // Initialize // Split sources @@ -214,16 +226,8 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { return err } - token := new(jwt.Token) - // Issue #647, #656 - if _, ok := config.Claims.(jwt.MapClaims); ok { - token, err = jwt.Parse(auth, config.KeyFunc) - } else { - t := reflect.ValueOf(config.Claims).Type().Elem() - claims := reflect.New(t).Interface().(jwt.Claims) - token, err = jwt.ParseWithClaims(auth, claims, config.KeyFunc) - } - if err == nil && token.Valid { + token, err := config.ParseTokenFunc(auth, c) + if err == nil { // Store user information from token into context. c.Set(config.ContextKey, token) if config.SuccessHandler != nil { @@ -246,6 +250,26 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { } } +func (config *JWTConfig) defaultParseToken(auth string, c echo.Context) (interface{}, error) { + token := new(jwt.Token) + var err error + // Issue #647, #656 + if _, ok := config.Claims.(jwt.MapClaims); ok { + token, err = jwt.Parse(auth, config.KeyFunc) + } else { + t := reflect.ValueOf(config.Claims).Type().Elem() + claims := reflect.New(t).Interface().(jwt.Claims) + token, err = jwt.ParseWithClaims(auth, claims, config.KeyFunc) + } + if err != nil { + return nil, err + } + if !token.Valid { + return nil, errors.New("invalid token") + } + return token, nil +} + // defaultKeyFunc returns a signing key of the given token. func (config *JWTConfig) defaultKeyFunc(t *jwt.Token) (interface{}, error) { // Check the signing method diff --git a/middleware/jwt_test.go b/middleware/jwt_test.go index 1a0265917..9af4c83d8 100644 --- a/middleware/jwt_test.go +++ b/middleware/jwt_test.go @@ -2,6 +2,7 @@ package middleware import ( "errors" + "fmt" "net/http" "net/http/httptest" "net/url" @@ -404,3 +405,194 @@ func TestJWTwithKID(t *testing.T) { } } } + +func TestJWTConfig_skipper(t *testing.T) { + e := echo.New() + + e.Use(JWTWithConfig(JWTConfig{ + Skipper: func(context echo.Context) bool { + return true // skip everything + }, + SigningKey: []byte("secret"), + })) + + isCalled := false + e.GET("/", func(c echo.Context) error { + isCalled = true + return c.String(http.StatusTeapot, "test") + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + res := httptest.NewRecorder() + e.ServeHTTP(res, req) + + assert.Equal(t, http.StatusTeapot, res.Code) + assert.True(t, isCalled) +} + +func TestJWTConfig_BeforeFunc(t *testing.T) { + e := echo.New() + e.GET("/", func(c echo.Context) error { + return c.String(http.StatusTeapot, "test") + }) + + isCalled := false + e.Use(JWTWithConfig(JWTConfig{ + BeforeFunc: func(context echo.Context) { + isCalled = true + }, + SigningKey: []byte("secret"), + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") + res := httptest.NewRecorder() + e.ServeHTTP(res, req) + + assert.Equal(t, http.StatusTeapot, res.Code) + assert.True(t, isCalled) +} + +func TestJWTConfig_extractorErrorHandling(t *testing.T) { + var testCases = []struct { + name string + given JWTConfig + expectStatusCode int + }{ + { + name: "ok, ErrorHandler is executed", + given: JWTConfig{ + SigningKey: []byte("secret"), + ErrorHandler: func(err error) error { + return echo.NewHTTPError(http.StatusTeapot, "custom_error") + }, + }, + expectStatusCode: http.StatusTeapot, + }, + { + name: "ok, ErrorHandlerWithContext is executed", + given: JWTConfig{ + SigningKey: []byte("secret"), + ErrorHandlerWithContext: func(err error, context echo.Context) error { + return echo.NewHTTPError(http.StatusTeapot, "custom_error") + }, + }, + expectStatusCode: http.StatusTeapot, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + e.GET("/", func(c echo.Context) error { + return c.String(http.StatusNotImplemented, "should not end up here") + }) + + e.Use(JWTWithConfig(tc.given)) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + res := httptest.NewRecorder() + e.ServeHTTP(res, req) + + assert.Equal(t, tc.expectStatusCode, res.Code) + }) + } +} + +func TestJWTConfig_parseTokenErrorHandling(t *testing.T) { + var testCases = []struct { + name string + given JWTConfig + expectErr string + }{ + { + name: "ok, ErrorHandler is executed", + given: JWTConfig{ + SigningKey: []byte("secret"), + ErrorHandler: func(err error) error { + return echo.NewHTTPError(http.StatusTeapot, "ErrorHandler: "+err.Error()) + }, + }, + expectErr: "{\"message\":\"ErrorHandler: parsing failed\"}\n", + }, + { + name: "ok, ErrorHandlerWithContext is executed", + given: JWTConfig{ + SigningKey: []byte("secret"), + ErrorHandlerWithContext: func(err error, context echo.Context) error { + return echo.NewHTTPError(http.StatusTeapot, "ErrorHandlerWithContext: "+err.Error()) + }, + }, + expectErr: "{\"message\":\"ErrorHandlerWithContext: parsing failed\"}\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + //e.Debug = true + e.GET("/", func(c echo.Context) error { + return c.String(http.StatusNotImplemented, "should not end up here") + }) + + config := tc.given + parseTokenCalled := false + config.ParseTokenFunc = func(auth string, c echo.Context) (interface{}, error) { + parseTokenCalled = true + return nil, errors.New("parsing failed") + } + e.Use(JWTWithConfig(config)) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") + res := httptest.NewRecorder() + + e.ServeHTTP(res, req) + + assert.Equal(t, http.StatusTeapot, res.Code) + assert.Equal(t, tc.expectErr, res.Body.String()) + assert.True(t, parseTokenCalled) + }) + } +} + +func TestJWTConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) { + e := echo.New() + e.GET("/", func(c echo.Context) error { + return c.String(http.StatusTeapot, "test") + }) + + // example of minimal custom ParseTokenFunc implementation. Allows you to use different versions of `github.com/dgrijalva/jwt-go` + // with current JWT middleware + signingKey := []byte("secret") + + config := JWTConfig{ + ParseTokenFunc: func(auth string, c echo.Context) (interface{}, error) { + keyFunc := func(t *jwt.Token) (interface{}, error) { + if t.Method.Alg() != "HS256" { + return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"]) + } + return signingKey, nil + } + + // claims are of type `jwt.MapClaims` when token is created with `jwt.Parse` + token, err := jwt.Parse(auth, keyFunc) + if err != nil { + return nil, err + } + if !token.Valid { + return nil, errors.New("invalid token") + } + return token, nil + }, + } + + e.Use(JWTWithConfig(config)) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") + res := httptest.NewRecorder() + e.ServeHTTP(res, req) + + assert.Equal(t, http.StatusTeapot, res.Code) +}