Skip to content

Commit 500ec28

Browse files
committed
node: fix tests
1 parent fa0f9b0 commit 500ec28

File tree

3 files changed

+87
-38
lines changed

3 files changed

+87
-38
lines changed

node/jwt_handler.go

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,35 +19,19 @@ package node
1919
import (
2020
"net/http"
2121

22-
"errors"
22+
"fmt"
23+
"github.com/ethereum/go-ethereum/log"
2324
"github.com/golang-jwt/jwt/v4"
2425
"strings"
2526
"time"
2627
)
2728

28-
// customClaim implements claims.Claim.
29-
type customClaim struct {
30-
// the `iat` (Issued At) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.6
31-
IssuedAt int64 `json:"iat,omitempty"`
32-
}
33-
34-
// Valid implements claims.Claim, and checks that the iat is present and valid.
35-
func (c customClaim) Valid() error {
36-
if time.Now().Unix()-5 < c.IssuedAt {
37-
return errors.New("token issuance (iat) is too old")
38-
}
39-
if time.Now().Unix()+5 > c.IssuedAt {
40-
return errors.New("token issuance (iat) is too far in the future")
41-
}
42-
return nil
43-
}
44-
4529
type jwtHandler struct {
4630
keyFunc func(token *jwt.Token) (interface{}, error)
4731
next http.Handler
4832
}
4933

50-
// MakeJWTValidator creates a validator for jwt tokens.
34+
// newJWTHandler creates a http.Handler with jwt authentication support.
5135
func newJWTHandler(secret []byte, next http.Handler) http.Handler {
5236
return &jwtHandler{
5337
keyFunc: func(token *jwt.Token) (interface{}, error) {
@@ -57,6 +41,30 @@ func newJWTHandler(secret []byte, next http.Handler) http.Handler {
5741
}
5842
}
5943

44+
// customClaim is basically a standard RegisteredClaim, but we override the
45+
// Valid method to be more lax in allowing some time skew.
46+
type customClaim jwt.RegisteredClaims
47+
48+
// Valid implements jwt.Claim. This method only validates the (optional) expiry-time.
49+
func (c customClaim) Valid() error {
50+
now := jwt.TimeFunc()
51+
rc := jwt.RegisteredClaims(c)
52+
if !rc.VerifyExpiresAt(now, false) { // optional
53+
return fmt.Errorf("token is expired")
54+
}
55+
if c.IssuedAt == nil {
56+
return fmt.Errorf("missing issued-at")
57+
}
58+
if time.Since(c.IssuedAt.Time) > 5*time.Second {
59+
return fmt.Errorf("stale token")
60+
}
61+
if time.Until(c.IssuedAt.Time) > 5*time.Second {
62+
return fmt.Errorf("future token")
63+
}
64+
return nil
65+
}
66+
67+
// ServeHTTP implements http.Handler
6068
func (handler *jwtHandler) ServeHTTP(out http.ResponseWriter, r *http.Request) {
6169
var token string
6270
if auth := r.Header.Get("Authorization"); strings.HasPrefix(auth, "Bearer ") {
@@ -67,13 +75,13 @@ func (handler *jwtHandler) ServeHTTP(out http.ResponseWriter, r *http.Request) {
6775
return
6876
}
6977
var claims customClaim
70-
t, err := jwt.ParseWithClaims(token, claims, handler.keyFunc, jwt.WithValidMethods([]string{"HS256"}))
78+
t, err := jwt.ParseWithClaims(token, &claims, handler.keyFunc, jwt.WithValidMethods([]string{"HS256"}))
7179
if err != nil {
80+
log.Info("Token parsing failed", "err", err)
7281
http.Error(out, err.Error(), http.StatusForbidden)
7382
return
7483
}
7584
if !t.Valid {
76-
// This should not happen, but better safe than sorry if the implementation changes.
7785
http.Error(out, "invalid token", http.StatusForbidden)
7886
return
7987
}

node/rpcstack.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,7 @@ func NewHTTPHandlerStack(srv http.Handler, cors []string, vhosts []string, jwtSe
373373
return newGzipHandler(handler)
374374
}
375375

376+
// NewWSHandlerStack returns a wrapped ws-related handler.
376377
func NewWSHandlerStack(srv http.Handler, jwtSecret []byte) http.Handler {
377378
if len(jwtSecret) != 0 {
378379
return newJWTHandler(jwtSecret, srv)

node/rpcstack_test.go

Lines changed: 57 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -304,39 +304,79 @@ type tokenTest struct {
304304
expFail []string
305305
}
306306

307-
func TestJWT(t *testing.T) {
307+
type testClaim map[string]interface{}
308+
309+
func (testClaim) Valid() error {
310+
return nil
311+
}
308312

309-
makeToken := func() string {
310-
mySigningKey := []byte("secret")
311-
// Create the Claims
312-
claims := &jwt.RegisteredClaims{
313-
IssuedAt: jwt.NewNumericDate(time.Now()),
313+
func TestJWT(t *testing.T) {
314+
var secret = []byte("secret")
315+
issueToken := func(secret []byte, method jwt.SigningMethod, input map[string]interface{}) string {
316+
if method == nil {
317+
method = jwt.SigningMethodHS256
314318
}
315-
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
316-
ss, _ := token.SignedString(mySigningKey)
319+
ss, _ := jwt.NewWithClaims(method, testClaim(input)).SignedString(secret)
317320
return ss
318321
}
319-
tests := []originTest{
322+
tests := []tokenTest{
320323
{
321-
//expFail: []string{"Bearer ", "Bearer: abc", "Baxonk hello there"},
322324
expOk: []string{
323-
fmt.Sprintf("Bearer %v", makeToken()),
325+
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})),
326+
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix() + 4})),
327+
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix() - 4})),
328+
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{
329+
"iat": time.Now().Unix(),
330+
"exp": time.Now().Unix() + 2,
331+
})),
332+
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{
333+
"iat": time.Now().Unix(),
334+
"bar": "baz",
335+
})),
336+
},
337+
expFail: []string{
338+
// future
339+
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix() + 6})),
340+
// stale
341+
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix() - 6})),
342+
// wrong algo
343+
fmt.Sprintf("Bearer %v", issueToken(secret, jwt.SigningMethodHS512, testClaim{"iat": time.Now().Unix() + 4})),
344+
// expired
345+
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix(), "exp": time.Now().Unix()})),
346+
// missing mandatory iat
347+
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{})),
348+
// wrong secret
349+
fmt.Sprintf("Bearer %v", issueToken([]byte("wrong"), nil, testClaim{"iat": time.Now().Unix()})),
350+
fmt.Sprintf("Bearer %v", issueToken([]byte{}, nil, testClaim{"iat": time.Now().Unix()})),
351+
fmt.Sprintf("Bearer %v", issueToken(nil, nil, testClaim{"iat": time.Now().Unix()})),
352+
// Various malformed syntax
353+
fmt.Sprintf("%v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})),
354+
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})),
355+
fmt.Sprintf("Bearer: %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})),
356+
fmt.Sprintf("Bearer:%v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})),
357+
fmt.Sprintf("Bearer:\t%v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})),
324358
},
325359
},
326360
}
327-
328361
for _, tc := range tests {
329362
srv := createAndStartServer(t, &httpConfig{jwtSecret: []byte("secret")},
330363
true, &wsConfig{Origins: []string{"*"}, jwtSecret: []byte("secret")})
331-
url := fmt.Sprintf("ws://%v", srv.listenAddr())
364+
wsUrl := fmt.Sprintf("ws://%v", srv.listenAddr())
365+
htUrl := fmt.Sprintf("http://%v", srv.listenAddr())
332366
for i, token := range tc.expOk {
333-
if err := wsRequest(t, url, "Authorization", token); err != nil {
334-
t.Errorf("test %d, token '%v': expected ok, got %v", i, token, err)
367+
if err := wsRequest(t, wsUrl, "Authorization", token); err != nil {
368+
t.Errorf("test %d-ws, token '%v': expected ok, got %v", i, token, err)
369+
}
370+
if resp := rpcRequest(t, htUrl, "Authorization", token); resp.StatusCode != 200 {
371+
t.Errorf("test %d-http, token '%v': expected ok, got %v", i, token, resp.StatusCode)
335372
}
336373
}
337374
for i, token := range tc.expFail {
338-
if err := wsRequest(t, url, "Authorization", token); err == nil {
339-
t.Errorf("tc %d, token '%v': expected not to allow, got ok", i, token)
375+
if err := wsRequest(t, wsUrl, "Authorization", token); err == nil {
376+
t.Errorf("tc %d-ws, token '%v': expected not to allow, got ok", i, token)
377+
}
378+
if resp := rpcRequest(t, htUrl, "Authorization", token); resp.StatusCode != 403 {
379+
t.Errorf("tc %d-http, token '%v': expected not to allow, got %v", i, token, resp.StatusCode)
340380
}
341381
}
342382
srv.stop()

0 commit comments

Comments
 (0)