diff --git a/CHANGES b/CHANGES index 7420881..1b76522 100644 --- a/CHANGES +++ b/CHANGES @@ -1,6 +1,13 @@ CHANGES ======= +0.3.0 (2015-04-10) +------------------ +- Fixed critical JWT vulnerability (patch contributed by yuriikonovaliuk) + +Important: Only unencrypted tokens are vulnerable. This fix lead to backward +incompatible change to `verify` function signature. + 0.2.2 (2015-01-07) ------------------ - RFC compliance fixes (patch contributed by jaimeperez) diff --git a/CONTRIB b/CONTRIB index 4072377..13d730e 100644 --- a/CONTRIB +++ b/CONTRIB @@ -5,3 +5,4 @@ Demian Brecht (demianbrecht) Nick Murtagh (nmurtagh) Jakub Warmuz (kuba) Jaime PĂ©rez (jaimeperez) +Yurii Konovaliuk (yuriikonovaliuk) diff --git a/jose.py b/jose.py index 192fd40..fa2859c 100644 --- a/jose.py +++ b/jose.py @@ -144,7 +144,7 @@ def encrypt(claims, jwk, adata='', add_header=None, alg='RSA-OAEP', # promote the temp key to the header assert _TEMP_VER_KEY not in header - header[_TEMP_VER_KEY] = claims[_TEMP_VER_KEY] + header[_TEMP_VER_KEY] = claims[_TEMP_VER_KEY] plaintext = json_encode(claims) @@ -216,7 +216,7 @@ def decrypt(jwe, jwk, adata='', validate_claims=True, expiry_seconds=None): else: plaintext = decipher(ciphertext, encryption_key[:-mod.digest_size], iv) hash = hash_fn(_jwe_hash_str(plaintext, iv, adata, True), - encryption_key[-mod.digest_size:], mod=mod) + encryption_key[-mod.digest_size:], mod=mod) if not const_compare(auth_tag(hash), tag): raise Error('Mismatched authentication tags') @@ -265,12 +265,13 @@ def sign(claims, jwk, add_header=None, alg='HS256'): return JWS(header, payload, sig) -def verify(jws, jwk, validate_claims=True, expiry_seconds=None): +def verify(jws, jwk, alg, validate_claims=True, expiry_seconds=None): """ Verifies the given :class:`~jose.JWS` :param jws: The :class:`~jose.JWS` to be verified. :param jwk: A `dict` representing the JWK to use for verification. This parameter is algorithm-specific. + :param alg: The algorithm to verify the signature with. :param validate_claims: A `bool` indicating whether or not the `exp`, `iat` and `nbf` claims should be validated. Defaults to `True`. @@ -284,6 +285,9 @@ def verify(jws, jwk, validate_claims=True, expiry_seconds=None): """ header, payload, sig = map(b64decode_url, jws) header = json_decode(header) + if alg != header['alg']: + raise Error('Invalid algorithm') + (_, verify_fn), mod = JWA[header['alg']] if not verify_fn(_jws_hash_str(jws.header, jws.payload), diff --git a/setup.py b/setup.py index 1b9be93..68d8773 100644 --- a/setup.py +++ b/setup.py @@ -43,7 +43,7 @@ def finalize_package_data(self): pkg_name = '-'.join((pyver.replace('.', ''), pkg_name)) setup(name=pkg_name, - version='0.2.2', + version='0.3.0', author='Demian Brecht', author_email='dbrecht@demonware.net', py_modules=['jose'], diff --git a/tests.py b/tests.py index 6b0e5fc..d502698 100644 --- a/tests.py +++ b/tests.py @@ -295,7 +295,7 @@ def test_jws_sym(self): for alg in algs: st = jose.serialize_compact(jose.sign(claims, jwk, alg=alg)) - jwt = jose.verify(jose.deserialize_compact(st), jwk) + jwt = jose.verify(jose.deserialize_compact(st), jwk, alg) self.assertEqual(jwt.claims, claims) @@ -305,17 +305,29 @@ def test_jws_asym(self): for alg in algs: st = jose.serialize_compact(jose.sign(claims, rsa_priv_key, alg=alg)) - jwt = jose.verify(jose.deserialize_compact(st), rsa_pub_key) + jwt = jose.verify(jose.deserialize_compact(st), rsa_pub_key, alg) self.assertEqual(jwt.claims, claims) def test_jws_signature_mismatch_error(self): + alg = 'HS256' jwk = {'k': 'password'} - jws = jose.sign(claims, jwk) + jws = jose.sign(claims, jwk, alg=alg) try: - jose.verify(jose.JWS(jws.header, jws.payload, 'asd'), jwk) + jose.verify(jose.JWS(jws.header, jws.payload, 'asd'), jwk, alg) except jose.Error as e: self.assertEqual(e.message, 'Mismatched signatures') + def test_jws_invalid_algorithm_error(self): + sign_alg = 'HS256' + verify_alg = 'RS256' + jwk = {'k': 'password'} + jws = jose.sign(claims, jwk, alg=sign_alg) + try: + jose.verify(jose.JWS(jws.header, jws.payload, 'asd'), jwk, + verify_alg) + except jose.Error as e: + self.assertEqual(e.message, 'Invalid algorithm') + class TestUtils(unittest.TestCase): def test_b64encode_url_utf8(self):