diff --git a/node/jwt_handler.go b/node/jwt_handler.go index ffb8c2dbc855..7d41ce38678f 100644 --- a/node/jwt_handler.go +++ b/node/jwt_handler.go @@ -19,35 +19,19 @@ package node import ( "net/http" - "errors" + "fmt" + "github.com/ethereum/go-ethereum/log" "github.com/golang-jwt/jwt/v4" "strings" "time" ) -// customClaim implements claims.Claim. -type customClaim struct { - // the `iat` (Issued At) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.6 - IssuedAt int64 `json:"iat,omitempty"` -} - -// Valid implements claims.Claim, and checks that the iat is present and valid. -func (c customClaim) Valid() error { - if time.Now().Unix()-5 < c.IssuedAt { - return errors.New("token issuance (iat) is too old") - } - if time.Now().Unix()+5 > c.IssuedAt { - return errors.New("token issuance (iat) is too far in the future") - } - return nil -} - type jwtHandler struct { keyFunc func(token *jwt.Token) (interface{}, error) next http.Handler } -// MakeJWTValidator creates a validator for jwt tokens. +// newJWTHandler creates a http.Handler with jwt authentication support. func newJWTHandler(secret []byte, next http.Handler) http.Handler { return &jwtHandler{ keyFunc: func(token *jwt.Token) (interface{}, error) { @@ -57,6 +41,30 @@ func newJWTHandler(secret []byte, next http.Handler) http.Handler { } } +// customClaim is basically a standard RegisteredClaim, but we override the +// Valid method to be more lax in allowing some time skew. +type customClaim jwt.RegisteredClaims + +// Valid implements jwt.Claim. This method only validates the (optional) expiry-time. +func (c customClaim) Valid() error { + now := jwt.TimeFunc() + rc := jwt.RegisteredClaims(c) + if !rc.VerifyExpiresAt(now, false) { // optional + return fmt.Errorf("token is expired") + } + if c.IssuedAt == nil { + return fmt.Errorf("missing issued-at") + } + if time.Since(c.IssuedAt.Time) > 5*time.Second { + return fmt.Errorf("stale token") + } + if time.Until(c.IssuedAt.Time) > 5*time.Second { + return fmt.Errorf("future token") + } + return nil +} + +// ServeHTTP implements http.Handler func (handler *jwtHandler) ServeHTTP(out http.ResponseWriter, r *http.Request) { var token string if auth := r.Header.Get("Authorization"); strings.HasPrefix(auth, "Bearer ") { @@ -67,13 +75,13 @@ func (handler *jwtHandler) ServeHTTP(out http.ResponseWriter, r *http.Request) { return } var claims customClaim - t, err := jwt.ParseWithClaims(token, claims, handler.keyFunc, jwt.WithValidMethods([]string{"HS256"})) + t, err := jwt.ParseWithClaims(token, &claims, handler.keyFunc, jwt.WithValidMethods([]string{"HS256"})) if err != nil { + log.Info("Token parsing failed", "err", err) http.Error(out, err.Error(), http.StatusForbidden) return } if !t.Valid { - // This should not happen, but better safe than sorry if the implementation changes. http.Error(out, "invalid token", http.StatusForbidden) return } diff --git a/node/rpcstack.go b/node/rpcstack.go index 5673ee63cc8e..fd07814ab58a 100644 --- a/node/rpcstack.go +++ b/node/rpcstack.go @@ -373,6 +373,7 @@ func NewHTTPHandlerStack(srv http.Handler, cors []string, vhosts []string, jwtSe return newGzipHandler(handler) } +// NewWSHandlerStack returns a wrapped ws-related handler. func NewWSHandlerStack(srv http.Handler, jwtSecret []byte) http.Handler { if len(jwtSecret) != 0 { return newJWTHandler(jwtSecret, srv) diff --git a/node/rpcstack_test.go b/node/rpcstack_test.go index e4a06f8e816b..6ee6af076c1e 100644 --- a/node/rpcstack_test.go +++ b/node/rpcstack_test.go @@ -304,39 +304,79 @@ type tokenTest struct { expFail []string } -func TestJWT(t *testing.T) { +type testClaim map[string]interface{} + +func (testClaim) Valid() error { + return nil +} - makeToken := func() string { - mySigningKey := []byte("secret") - // Create the Claims - claims := &jwt.RegisteredClaims{ - IssuedAt: jwt.NewNumericDate(time.Now()), +func TestJWT(t *testing.T) { + var secret = []byte("secret") + issueToken := func(secret []byte, method jwt.SigningMethod, input map[string]interface{}) string { + if method == nil { + method = jwt.SigningMethodHS256 } - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - ss, _ := token.SignedString(mySigningKey) + ss, _ := jwt.NewWithClaims(method, testClaim(input)).SignedString(secret) return ss } - tests := []originTest{ + tests := []tokenTest{ { - //expFail: []string{"Bearer ", "Bearer: abc", "Baxonk hello there"}, expOk: []string{ - fmt.Sprintf("Bearer %v", makeToken()), + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})), + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix() + 4})), + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix() - 4})), + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{ + "iat": time.Now().Unix(), + "exp": time.Now().Unix() + 2, + })), + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{ + "iat": time.Now().Unix(), + "bar": "baz", + })), + }, + expFail: []string{ + // future + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix() + 6})), + // stale + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix() - 6})), + // wrong algo + fmt.Sprintf("Bearer %v", issueToken(secret, jwt.SigningMethodHS512, testClaim{"iat": time.Now().Unix() + 4})), + // expired + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix(), "exp": time.Now().Unix()})), + // missing mandatory iat + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{})), + // wrong secret + fmt.Sprintf("Bearer %v", issueToken([]byte("wrong"), nil, testClaim{"iat": time.Now().Unix()})), + fmt.Sprintf("Bearer %v", issueToken([]byte{}, nil, testClaim{"iat": time.Now().Unix()})), + fmt.Sprintf("Bearer %v", issueToken(nil, nil, testClaim{"iat": time.Now().Unix()})), + // Various malformed syntax + fmt.Sprintf("%v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})), + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})), + fmt.Sprintf("Bearer: %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})), + fmt.Sprintf("Bearer:%v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})), + fmt.Sprintf("Bearer:\t%v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})), }, }, } - for _, tc := range tests { srv := createAndStartServer(t, &httpConfig{jwtSecret: []byte("secret")}, true, &wsConfig{Origins: []string{"*"}, jwtSecret: []byte("secret")}) - url := fmt.Sprintf("ws://%v", srv.listenAddr()) + wsUrl := fmt.Sprintf("ws://%v", srv.listenAddr()) + htUrl := fmt.Sprintf("http://%v", srv.listenAddr()) for i, token := range tc.expOk { - if err := wsRequest(t, url, "Authorization", token); err != nil { - t.Errorf("test %d, token '%v': expected ok, got %v", i, token, err) + if err := wsRequest(t, wsUrl, "Authorization", token); err != nil { + t.Errorf("test %d-ws, token '%v': expected ok, got %v", i, token, err) + } + if resp := rpcRequest(t, htUrl, "Authorization", token); resp.StatusCode != 200 { + t.Errorf("test %d-http, token '%v': expected ok, got %v", i, token, resp.StatusCode) } } for i, token := range tc.expFail { - if err := wsRequest(t, url, "Authorization", token); err == nil { - t.Errorf("tc %d, token '%v': expected not to allow, got ok", i, token) + if err := wsRequest(t, wsUrl, "Authorization", token); err == nil { + t.Errorf("tc %d-ws, token '%v': expected not to allow, got ok", i, token) + } + if resp := rpcRequest(t, htUrl, "Authorization", token); resp.StatusCode != 403 { + t.Errorf("tc %d-http, token '%v': expected not to allow, got %v", i, token, resp.StatusCode) } } srv.stop()