Skip to content

Commit

Permalink
feat: export token expiration error
Browse files Browse the repository at this point in the history
try to match style

follow naming convention

tests

revert import change

match format
  • Loading branch information
crenshaw-dev committed Jun 23, 2022
1 parent 26c5037 commit fb92e10
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 12 deletions.
21 changes: 20 additions & 1 deletion oidc/verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,25 @@ const (
issuerGoogleAccountsNoScheme = "accounts.google.com"
)

// TokenExpiredError indicates that Verify failed because the token was expired. This
// error does NOT indicate that the token is not also invalid for other reasons. Other
// checks might have failed if the expiration check had not failed.
type TokenExpiredError struct {
// Expiry is the time when the token expired.
Expiry time.Time
}

func (e *TokenExpiredError) Error() string {
return fmt.Sprintf("oidc: token is expired (Token Expiry: %v)", e.Expiry)
}

// Is returns true if the error was due to an expired token. It does not check that
// the expiry time is identical.
func (e *TokenExpiredError) Is(target error) bool {
_, ok := target.(*TokenExpiredError)
return ok
}

// KeySet is a set of publc JSON Web Keys that can be used to validate the signature
// of JSON web tokens. This is expected to be backed by a remote key set through
// provider metadata discovery or an in-memory set of keys delivered out-of-band.
Expand Down Expand Up @@ -260,7 +279,7 @@ func (v *IDTokenVerifier) Verify(ctx context.Context, rawIDToken string) (*IDTok
nowTime := now()

if t.Expiry.Before(nowTime) {
return nil, fmt.Errorf("oidc: token is expired (Token Expiry: %v)", t.Expiry)
return nil, &TokenExpiredError{Expiry: t.Expiry}
}

// If nbf claim is provided in token, ensure that it is indeed in the past.
Expand Down
31 changes: 20 additions & 11 deletions oidc/verify_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package oidc
import (
"context"
"crypto"
"errors"
"io"
"net/http"
"net/http/httptest"
Expand All @@ -12,6 +13,9 @@ import (
"time"
)

// anyError is a fake error to match any error returned in a test.
var anyError = errors.New("")

func TestVerify(t *testing.T) {
tests := []verificationTest{
{
Expand All @@ -32,7 +36,7 @@ func TestVerify(t *testing.T) {
SkipExpiryCheck: true,
},
signKey: newRSAKey(t),
wantErr: true,
wantErr: anyError,
},
{
name: "skip issuer check",
Expand All @@ -54,7 +58,7 @@ func TestVerify(t *testing.T) {
},
signKey: newRSAKey(t),
verificationKey: newRSAKey(t),
wantErr: true,
wantErr: anyError,
},
{
name: "google accounts without scheme",
Expand All @@ -73,7 +77,7 @@ func TestVerify(t *testing.T) {
SkipClientIDCheck: true,
},
signKey: newRSAKey(t),
wantErr: true,
wantErr: &TokenExpiredError{},
},
{
name: "unexpired token",
Expand Down Expand Up @@ -101,7 +105,7 @@ func TestVerify(t *testing.T) {
SkipClientIDCheck: true,
},
signKey: newRSAKey(t),
wantErr: true,
wantErr: anyError,
},
{
name: "nbf in past",
Expand Down Expand Up @@ -146,7 +150,7 @@ func TestVerifyAudience(t *testing.T) {
SkipExpiryCheck: true,
},
signKey: newRSAKey(t),
wantErr: true,
wantErr: anyError,
},
{
name: "multiple audiences, one matches",
Expand Down Expand Up @@ -182,7 +186,7 @@ func TestVerifySigningAlg(t *testing.T) {
SkipExpiryCheck: true,
},
signKey: newECDSAKey(t),
wantErr: true,
wantErr: anyError,
},
{
name: "ecdsa signing",
Expand Down Expand Up @@ -213,7 +217,7 @@ func TestVerifySigningAlg(t *testing.T) {
SkipExpiryCheck: true,
},
signKey: newECDSAKey(t),
wantErr: true,
wantErr: anyError,
},
}
for _, test := range tests {
Expand Down Expand Up @@ -531,7 +535,7 @@ type verificationTest struct {
verificationKey *signingKey

config Config
wantErr bool
wantErr error
}

func (v verificationTest) runGetToken(t *testing.T) (*IDToken, error) {
Expand All @@ -557,10 +561,15 @@ func (v verificationTest) runGetToken(t *testing.T) (*IDToken, error) {

func (v verificationTest) run(t *testing.T) {
_, err := v.runGetToken(t)
if err != nil && !v.wantErr {
t.Errorf("%v", err)
if err != nil {
if v.wantErr == nil {
t.Errorf("%v", err)
}
if !errors.Is(v.wantErr, anyError) && !errors.Is(err, v.wantErr) {
t.Errorf("expected error %q but got %q", v.wantErr, err)
}
}
if err == nil && v.wantErr {
if err == nil && v.wantErr != nil {
t.Errorf("expected error")
}
}

0 comments on commit fb92e10

Please sign in to comment.