Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 27 additions & 15 deletions auth_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ package jwt
import (
"context"
"crypto/rsa"
"encoding/json"
"errors"
"io/ioutil"
"net/http"
Expand Down Expand Up @@ -181,6 +182,9 @@ type HertzJWTMiddleware struct {

// CookieSameSite allow use protocol.CookieSameSite cookie param
CookieSameSite protocol.CookieSameSite

// ParseOptions allow to modify jwt's parser methods
ParseOptions []jwt.ParserOption
}

var (
Expand Down Expand Up @@ -447,19 +451,27 @@ func (mw *HertzJWTMiddleware) middlewareImpl(ctx context.Context, c *app.Request
return
}

if claims["exp"] == nil {
switch v := claims["exp"].(type) {
case nil:
mw.unauthorized(ctx, c, http.StatusBadRequest, mw.HTTPStatusMessageFunc(ErrMissingExpField, ctx, c))
return
}

if _, ok := claims["exp"].(float64); !ok {
mw.unauthorized(ctx, c, http.StatusBadRequest, mw.HTTPStatusMessageFunc(ErrWrongFormatOfExp, ctx, c))
return
}

if int64(claims["exp"].(float64)) < mw.TimeFunc().Unix() {
mw.unauthorized(ctx, c, http.StatusUnauthorized, mw.HTTPStatusMessageFunc(ErrExpiredToken, ctx, c))
return
case float64:
if int64(v) < mw.TimeFunc().Unix() {
mw.unauthorized(ctx, c, http.StatusUnauthorized, mw.HTTPStatusMessageFunc(ErrExpiredToken, ctx, c))
return
}
case json.Number:
n, err := v.Int64()
if err != nil {
mw.unauthorized(ctx, c, http.StatusBadRequest, mw.HTTPStatusMessageFunc(ErrWrongFormatOfExp, ctx, c))
return
}
if n < mw.TimeFunc().Unix() {
mw.unauthorized(ctx, c, http.StatusUnauthorized, mw.HTTPStatusMessageFunc(ErrExpiredToken, ctx, c))
return
}
default:
mw.Unauthorized(ctx, c, http.StatusBadRequest, mw.HTTPStatusMessageFunc(ErrWrongFormatOfExp, ctx, c))
}

c.Set("JWT_PAYLOAD", claims)
Expand Down Expand Up @@ -728,7 +740,7 @@ func (mw *HertzJWTMiddleware) ParseToken(ctx context.Context, c *app.RequestCont
}

if mw.KeyFunc != nil {
return jwt.Parse(token, mw.KeyFunc)
return jwt.Parse(token, mw.KeyFunc, mw.ParseOptions...)
}

return jwt.Parse(token, func(t *jwt.Token) (interface{}, error) {
Expand All @@ -743,13 +755,13 @@ func (mw *HertzJWTMiddleware) ParseToken(ctx context.Context, c *app.RequestCont
c.Set("JWT_TOKEN", token)

return mw.Key, nil
})
}, mw.ParseOptions...)
}

// ParseTokenString parse jwt token string
func (mw *HertzJWTMiddleware) ParseTokenString(token string) (*jwt.Token, error) {
if mw.KeyFunc != nil {
return jwt.Parse(token, mw.KeyFunc)
return jwt.Parse(token, mw.KeyFunc, mw.ParseOptions...)
}

return jwt.Parse(token, func(t *jwt.Token) (interface{}, error) {
Expand All @@ -761,7 +773,7 @@ func (mw *HertzJWTMiddleware) ParseTokenString(token string) (*jwt.Token, error)
}

return mw.Key, nil
})
}, mw.ParseOptions...)
}

func (mw *HertzJWTMiddleware) unauthorized(ctx context.Context, c *app.RequestContext, code int, message string) {
Expand Down
49 changes: 49 additions & 0 deletions auth_jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ package jwt
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
Expand Down Expand Up @@ -90,6 +91,28 @@ func makeTokenString(SigningAlgorithm, username string) string {
return tokenString
}

func makeTokenStringWithUserID(SigningAlgorithm string, userID int64) string {
if SigningAlgorithm == "" {
SigningAlgorithm = "HS256"
}

token := jwt.New(jwt.GetSigningMethod(SigningAlgorithm))
claims := token.Claims.(jwt.MapClaims)
claims["identity"] = userID
claims["exp"] = time.Now().Add(time.Hour).Unix()
claims["orig_iat"] = time.Now().Unix()
var tokenString string
if SigningAlgorithm == "RS256" {
keyData, _ := ioutil.ReadFile("testdata/jwtRS256.key")
signKey, _ := jwt.ParseRSAPrivateKeyFromPEM(keyData)
tokenString, _ = token.SignedString(signKey)
} else {
tokenString, _ = token.SignedString(key)
}

return tokenString
}

func keyFunc(token *jwt.Token) (interface{}, error) {
cert, err := ioutil.ReadFile("testdata/jwtRS256.key.pub")
if err != nil {
Expand Down Expand Up @@ -533,6 +556,32 @@ func TestAuthorizator(t *testing.T) {
assert.DeepEqual(t, http.StatusOK, w.Code)
}

func TestParseTokenWithJsonNumber(t *testing.T) {
var userID int64 = 64
authMiddleware, _ := New(&HertzJWTMiddleware{
Realm: "test zone",
Key: key,
Timeout: time.Hour,
MaxRefresh: time.Hour * 24,
IdentityHandler: func(ctx context.Context, c *app.RequestContext) interface{} {
claims := ExtractClaims(ctx, c)
testNum, err := claims["identity"].(json.Number).Int64()
assert.Nil(t, err)
assert.DeepEqual(t, userID, testNum)
return testNum
},
Unauthorized: func(ctx context.Context, c *app.RequestContext, code int, message string) {
c.String(code, message)
},
ParseOptions: []jwt.ParserOption{jwt.WithJSONNumber()},
})

handler := hertzHandler(authMiddleware)

w := ut.PerformRequest(handler, http.MethodGet, "/auth/hello", nil, ut.Header{Key: "Authorization", Value: "Bearer " + makeTokenStringWithUserID("HS256", userID)})
assert.DeepEqual(t, http.StatusOK, w.Code)
}

func TestClaimsDuringAuthorization(t *testing.T) {
// the middleware to test
authMiddleware, _ := New(&HertzJWTMiddleware{
Expand Down