Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JWT: Add validate() method #284

Merged
merged 1 commit into from
May 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 61 additions & 39 deletions jwcrypto/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,22 +448,6 @@ def _check_provided_claims(self):
"Invalid '%s' value. Expected '%s' got '%s'" % (
name, value, claims[name]))

def _deserialize_token_with_keys(self, jwt, keys):
for k in keys:
try:
self.token.deserialize(jwt, k)
self.deserializelog.append("Success")
break
except Exception as e: # pylint: disable=broad-except
keyid = k.get('kid')
if keyid is None:
keyid = k.thumbprint()
self.deserializelog.append('Key [%s] failed: [%s]' % (
keyid, repr(e)))
continue
if "Success" not in self.deserializelog:
raise JWTMissingKey('No working key found in key set')

def norm_typ(self, val):
lc = val.lower()
if '/' in lc:
Expand Down Expand Up @@ -503,6 +487,63 @@ def make_encrypted_token(self, key):
t.add_recipient(key)
self.token = t

def validate(self, key):
"""Validate a JWT token that was deserialized w/o providing a key

:param key: A (:class:`jwcrypto.jwk.JWK`) verification or
decryption key, or a (:class:`jwcrypto.jwk.JWKSet`) that
contains a key indexed by the 'kid' header.
"""
self.deserializelog = []
if self.token is None:
raise ValueError("Token empty")

if isinstance(key, JWK):
try:
if isinstance(self.token, JWS):
self.token.verify(key)
elif isinstance(self.token, JWE):
self.token.decrypt(key)
else:
raise ValueError("Token format unrecognized")
self.deserializelog.append("Success")
except Exception as e: # pylint: disable=broad-except
self.deserializelog.append(
'Validation failed: [{}]'.format(repr(e)))
raise
elif isinstance(key, JWKSet):
keys = key
if 'kid' in self.token.jose_header:
kid_keys = key.get_keys(self.token.jose_header['kid'])
if not kid_keys:
raise JWTMissingKey('Key ID {} not in key set'.format(
self.token.jose_header['kid']))
keys = kid_keys

for k in keys:
try:
if isinstance(self.token, JWS):
self.token.verify(k)
elif isinstance(self.token, JWE):
self.token.decrypt(k)
else:
raise ValueError("Token format unrecognized")
self.deserializelog.append("Success")
break
except Exception as e: # pylint: disable=broad-except
keyid = k.get('kid', k.thumbprint())
self.deserializelog.append('Key [{}] failed: [{}]'.format(
keyid, repr(e)))
continue
if "Success" not in self.deserializelog:
raise JWTMissingKey('No working key found in key set')
else:
raise ValueError("Unrecognized key type")

self.header = self.token.jose_header
self.claims = self.token.payload.decode('utf-8')
self._check_provided_claims()

def deserialize(self, jwt, key=None):
"""Deserialize a JWT token.

Expand All @@ -526,31 +567,12 @@ def deserialize(self, jwt, key=None):
if self._algs:
self.token.allowed_algs = self._algs

self.deserializelog = []
self.deserializelog = None
# now deserialize and also decrypt/verify (or raise) if we
# have a key
if key is None:
self.token.deserialize(jwt, None)
elif isinstance(key, JWK):
self.token.deserialize(jwt, key)
self.deserializelog.append("Success")
elif isinstance(key, JWKSet):
self.token.deserialize(jwt, None)
if 'kid' in self.token.jose_header:
kid_keys = key.get_keys(self.token.jose_header['kid'])
if not kid_keys:
raise JWTMissingKey('Key ID %s not in key set'
% self.token.jose_header['kid'])
self._deserialize_token_with_keys(jwt, kid_keys)
else:
self._deserialize_token_with_keys(jwt, key)
else:
raise ValueError("Unrecognized Key Type")

if key is not None:
self.header = self.token.jose_header
self.claims = self.token.payload.decode('utf-8')
self._check_provided_claims()
self.token.deserialize(jwt, None)
if key:
self.validate(key)

def serialize(self, compact=True):
"""Serializes the object into a JWS token.
Expand Down
10 changes: 2 additions & 8 deletions jwcrypto/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2041,10 +2041,7 @@ def test_jwt_equality(self):

c = jwt.JWT.from_jose_token(a.serialize())
self.assertNotEqual(a, c)
# FIXME: replace once JWT.validate(key) is made available
c.token.verify(key)
c.header = c.token.jose_header
c.claims = c.token.payload.decode('utf-8')
c.validate(key)
self.assertEqual(a, c)

ea = jwt.JWT(header={"alg": "A256KW", "enc": "A256CBC-HS512"},
Expand All @@ -2063,10 +2060,7 @@ def test_jwt_equality(self):

ect = jwt.JWT.from_jose_token(ea.serialize())
self.assertNotEqual(ea, ect)
# FIXME: replace once JWT.validate(key) is made available
ect.token.decrypt(key)
ect.header = ect.token.jose_header
ect.claims = ect.token.payload.decode('utf-8')
ect.validate(key)
self.assertEqual(ea, ect)

def test_jwt_representations(self):
Expand Down