diff --git a/jwcrypto/jwt.py b/jwcrypto/jwt.py index ff8c899..12beb04 100644 --- a/jwcrypto/jwt.py +++ b/jwcrypto/jwt.py @@ -448,22 +448,6 @@ def _check_provided_claims(self): "Invalid '%s' value. Expected '%s' got '%s'" % ( name, value, claims[name])) - def _deserialize_token_with_keys(self, jwt, keys): - for k in keys: - try: - self.token.deserialize(jwt, k) - self.deserializelog.append("Success") - break - except Exception as e: # pylint: disable=broad-except - keyid = k.get('kid') - if keyid is None: - keyid = k.thumbprint() - self.deserializelog.append('Key [%s] failed: [%s]' % ( - keyid, repr(e))) - continue - if "Success" not in self.deserializelog: - raise JWTMissingKey('No working key found in key set') - def norm_typ(self, val): lc = val.lower() if '/' in lc: @@ -503,6 +487,63 @@ def make_encrypted_token(self, key): t.add_recipient(key) self.token = t + def validate(self, key): + """Validate a JWT token that was deserialized w/o providing a key + + :param key: A (:class:`jwcrypto.jwk.JWK`) verification or + decryption key, or a (:class:`jwcrypto.jwk.JWKSet`) that + contains a key indexed by the 'kid' header. + """ + self.deserializelog = [] + if self.token is None: + raise ValueError("Token empty") + + if isinstance(key, JWK): + try: + if isinstance(self.token, JWS): + self.token.verify(key) + elif isinstance(self.token, JWE): + self.token.decrypt(key) + else: + raise ValueError("Token format unrecognized") + self.deserializelog.append("Success") + except Exception as e: # pylint: disable=broad-except + self.deserializelog.append( + 'Validation failed: [{}]'.format(repr(e))) + raise + elif isinstance(key, JWKSet): + keys = key + if 'kid' in self.token.jose_header: + kid_keys = key.get_keys(self.token.jose_header['kid']) + if not kid_keys: + raise JWTMissingKey('Key ID {} not in key set'.format( + self.token.jose_header['kid'])) + keys = kid_keys + + for k in keys: + try: + if isinstance(self.token, JWS): + self.token.verify(k) + elif isinstance(self.token, JWE): + self.token.decrypt(k) + else: + raise ValueError("Token format unrecognized") + self.deserializelog.append("Success") + break + except Exception as e: # pylint: disable=broad-except + keyid = k.get('kid', k.thumbprint()) + self.deserializelog.append('Key [{}] failed: [{}]'.format( + keyid, repr(e))) + continue + if "Success" not in self.deserializelog: + raise JWTMissingKey('No working key found in key set') + else: + raise ValueError("Unrecognized key type") + + self.header = self.token.jose_header + self.claims = self.token.payload.decode('utf-8') + self._check_provided_claims() + def deserialize(self, jwt, key=None): """Deserialize a JWT token. @@ -526,31 +567,12 @@ def deserialize(self, jwt, key=None): if self._algs: self.token.allowed_algs = self._algs - self.deserializelog = [] + self.deserializelog = None # now deserialize and also decrypt/verify (or raise) if we # have a key - if key is None: - self.token.deserialize(jwt, None) - elif isinstance(key, JWK): - self.token.deserialize(jwt, key) - self.deserializelog.append("Success") - elif isinstance(key, JWKSet): - self.token.deserialize(jwt, None) - if 'kid' in self.token.jose_header: - kid_keys = key.get_keys(self.token.jose_header['kid']) - if not kid_keys: - raise JWTMissingKey('Key ID %s not in key set' - % self.token.jose_header['kid']) - self._deserialize_token_with_keys(jwt, kid_keys) - else: - self._deserialize_token_with_keys(jwt, key) - else: - raise ValueError("Unrecognized Key Type") - - if key is not None: - self.header = self.token.jose_header - self.claims = self.token.payload.decode('utf-8') - self._check_provided_claims() + self.token.deserialize(jwt, None) + if key: + self.validate(key) def serialize(self, compact=True): """Serializes the object into a JWS token. diff --git a/jwcrypto/tests.py b/jwcrypto/tests.py index 5feffc8..51c6f10 100644 --- a/jwcrypto/tests.py +++ b/jwcrypto/tests.py @@ -2041,10 +2041,7 @@ def test_jwt_equality(self): c = jwt.JWT.from_jose_token(a.serialize()) self.assertNotEqual(a, c) - # FIXME: replace once JWT.validate(key) is made available - c.token.verify(key) - c.header = c.token.jose_header - c.claims = c.token.payload.decode('utf-8') + c.validate(key) self.assertEqual(a, c) ea = jwt.JWT(header={"alg": "A256KW", "enc": "A256CBC-HS512"}, @@ -2063,10 +2060,7 @@ def test_jwt_equality(self): ect = jwt.JWT.from_jose_token(ea.serialize()) self.assertNotEqual(ea, ect) - # FIXME: replace once JWT.validate(key) is made available - ect.token.decrypt(key) - ect.header = ect.token.jose_header - ect.claims = ect.token.payload.decode('utf-8') + ect.validate(key) self.assertEqual(ea, ect) def test_jwt_representations(self):