Skip to content

Commit

Permalink
JWT: Add validate() method
Browse files Browse the repository at this point in the history
This allows callers to deserialize without a key and later validate the
parsed token without having to desearialize again from scratch.

Signed-off-by: Simo Sorce <simo@redhat.com>
  • Loading branch information
simo5 committed May 11, 2022
1 parent bc1fd83 commit 2f30d07
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 47 deletions.
100 changes: 61 additions & 39 deletions jwcrypto/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,22 +448,6 @@ def _check_provided_claims(self):
"Invalid '%s' value. Expected '%s' got '%s'" % (
name, value, claims[name]))

def _deserialize_token_with_keys(self, jwt, keys):
for k in keys:
try:
self.token.deserialize(jwt, k)
self.deserializelog.append("Success")
break
except Exception as e: # pylint: disable=broad-except
keyid = k.get('kid')
if keyid is None:
keyid = k.thumbprint()
self.deserializelog.append('Key [%s] failed: [%s]' % (
keyid, repr(e)))
continue
if "Success" not in self.deserializelog:
raise JWTMissingKey('No working key found in key set')

def norm_typ(self, val):
lc = val.lower()
if '/' in lc:
Expand Down Expand Up @@ -503,6 +487,63 @@ def make_encrypted_token(self, key):
t.add_recipient(key)
self.token = t

def validate(self, key):
"""Validate a JWT token that was deserialized w/o providing a key
:param key: A (:class:`jwcrypto.jwk.JWK`) verification or
decryption key, or a (:class:`jwcrypto.jwk.JWKSet`) that
contains a key indexed by the 'kid' header.
"""
self.deserializelog = []
if self.token is None:
raise ValueError("Token empty")

if isinstance(key, JWK):
try:
if isinstance(self.token, JWS):
self.token.verify(key)
elif isinstance(self.token, JWE):
self.token.decrypt(key)
else:
raise ValueError("Token format unrecognized")
self.deserializelog.append("Success")
except Exception as e: # pylint: disable=broad-except
self.deserializelog.append(
'Validation failed: [{}]'.format(repr(e)))
raise
elif isinstance(key, JWKSet):
keys = key
if 'kid' in self.token.jose_header:
kid_keys = key.get_keys(self.token.jose_header['kid'])
if not kid_keys:
raise JWTMissingKey('Key ID {} not in key set'.format(
self.token.jose_header['kid']))
keys = kid_keys

for k in keys:
try:
if isinstance(self.token, JWS):
self.token.verify(k)
elif isinstance(self.token, JWE):
self.token.decrypt(k)
else:
raise ValueError("Token format unrecognized")
self.deserializelog.append("Success")
break
except Exception as e: # pylint: disable=broad-except
keyid = k.get('kid', k.thumbprint())
self.deserializelog.append('Key [{}] failed: [{}]'.format(
keyid, repr(e)))
continue
if "Success" not in self.deserializelog:
raise JWTMissingKey('No working key found in key set')
else:
raise ValueError("Unrecognized key type")

self.header = self.token.jose_header
self.claims = self.token.payload.decode('utf-8')
self._check_provided_claims()

def deserialize(self, jwt, key=None):
"""Deserialize a JWT token.
Expand All @@ -526,31 +567,12 @@ def deserialize(self, jwt, key=None):
if self._algs:
self.token.allowed_algs = self._algs

self.deserializelog = []
self.deserializelog = None
# now deserialize and also decrypt/verify (or raise) if we
# have a key
if key is None:
self.token.deserialize(jwt, None)
elif isinstance(key, JWK):
self.token.deserialize(jwt, key)
self.deserializelog.append("Success")
elif isinstance(key, JWKSet):
self.token.deserialize(jwt, None)
if 'kid' in self.token.jose_header:
kid_keys = key.get_keys(self.token.jose_header['kid'])
if not kid_keys:
raise JWTMissingKey('Key ID %s not in key set'
% self.token.jose_header['kid'])
self._deserialize_token_with_keys(jwt, kid_keys)
else:
self._deserialize_token_with_keys(jwt, key)
else:
raise ValueError("Unrecognized Key Type")

if key is not None:
self.header = self.token.jose_header
self.claims = self.token.payload.decode('utf-8')
self._check_provided_claims()
self.token.deserialize(jwt, None)
if key:
self.validate(key)

def serialize(self, compact=True):
"""Serializes the object into a JWS token.
Expand Down
10 changes: 2 additions & 8 deletions jwcrypto/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2041,10 +2041,7 @@ def test_jwt_equality(self):

c = jwt.JWT.from_jose_token(a.serialize())
self.assertNotEqual(a, c)
# FIXME: replace once JWT.validate(key) is made available
c.token.verify(key)
c.header = c.token.jose_header
c.claims = c.token.payload.decode('utf-8')
c.validate(key)
self.assertEqual(a, c)

ea = jwt.JWT(header={"alg": "A256KW", "enc": "A256CBC-HS512"},
Expand All @@ -2063,10 +2060,7 @@ def test_jwt_equality(self):

ect = jwt.JWT.from_jose_token(ea.serialize())
self.assertNotEqual(ea, ect)
# FIXME: replace once JWT.validate(key) is made available
ect.token.decrypt(key)
ect.header = ect.token.jose_header
ect.claims = ect.token.payload.decode('utf-8')
ect.validate(key)
self.assertEqual(ea, ect)

def test_jwt_representations(self):
Expand Down

0 comments on commit 2f30d07

Please sign in to comment.