Skip to content

Add "algorithm mismatch" error to improve jws #304

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions jose/backends/cryptography_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from cryptography.x509 import load_pem_x509_certificate

from ..constants import ALGORITHMS
from ..exceptions import JWEError, JWKError
from ..exceptions import JWEError, JWKError, JWKAlgMismatchError
from ..utils import base64_to_long, base64url_decode, base64url_encode, ensure_binary, long_to_base64
from .base import Key

Expand Down Expand Up @@ -52,7 +52,7 @@ class CryptographyECKey(Key):

def __init__(self, key, algorithm, cryptography_backend=default_backend):
if algorithm not in ALGORITHMS.EC:
raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm)
raise JWKAlgMismatchError("%s is not a valid EC algorithm" % algorithm)

self.hash_alg = {
ALGORITHMS.ES256: self.SHA256,
Expand Down Expand Up @@ -97,7 +97,7 @@ def __init__(self, key, algorithm, cryptography_backend=default_backend):

def _process_jwk(self, jwk_dict):
if not jwk_dict.get("kty") == "EC":
raise JWKError("Incorrect key type. Expected: 'EC', Received: %s" % jwk_dict.get("kty"))
raise JWKAlgMismatchError("Incorrect key type. Expected: 'EC', Received: %s" % jwk_dict.get("kty"))

if not all(k in jwk_dict for k in ["x", "y", "crv"]):
raise JWKError("Mandatory parameters are missing")
Expand Down Expand Up @@ -226,7 +226,7 @@ class CryptographyRSAKey(Key):

def __init__(self, key, algorithm, cryptography_backend=default_backend):
if algorithm not in ALGORITHMS.RSA:
raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm)
raise JWKAlgMismatchError("%s is not a valid RSA algorithm" % algorithm)

self.hash_alg = {
ALGORITHMS.RS256: self.SHA256,
Expand Down Expand Up @@ -273,7 +273,7 @@ def __init__(self, key, algorithm, cryptography_backend=default_backend):

def _process_jwk(self, jwk_dict):
if not jwk_dict.get("kty") == "RSA":
raise JWKError("Incorrect key type. Expected: 'RSA', Received: %s" % jwk_dict.get("kty"))
raise JWKAlgMismatchError("Incorrect key type. Expected: 'RSA', Received: %s" % jwk_dict.get("kty"))

e = base64_to_long(jwk_dict.get("e", 256))
n = base64_to_long(jwk_dict.get("n"))
Expand Down Expand Up @@ -441,9 +441,9 @@ class CryptographyAESKey(Key):

def __init__(self, key, algorithm):
if algorithm not in ALGORITHMS.AES:
raise JWKError("%s is not a valid AES algorithm" % algorithm)
raise JWKAlgMismatchError("%s is not a valid AES algorithm" % algorithm)
if algorithm not in ALGORITHMS.SUPPORTED.union(ALGORITHMS.AES_PSEUDO):
raise JWKError("%s is not a supported algorithm" % algorithm)
raise JWKAlgMismatchError("%s is not a supported algorithm" % algorithm)

self._algorithm = algorithm
self._mode = self.MODES.get(self._algorithm)
Expand Down Expand Up @@ -538,7 +538,7 @@ class CryptographyHMACKey(Key):

def __init__(self, key, algorithm):
if algorithm not in ALGORITHMS.HMAC:
raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm)
raise JWKAlgMismatchError("hash_alg: %s is not a valid hash algorithm" % algorithm)
self._algorithm = algorithm
self._hash_alg = self.ALG_MAP.get(algorithm)

Expand Down Expand Up @@ -569,7 +569,7 @@ def __init__(self, key, algorithm):

def _process_jwk(self, jwk_dict):
if not jwk_dict.get("kty") == "oct":
raise JWKError("Incorrect key type. Expected: 'oct', Received: %s" % jwk_dict.get("kty"))
raise JWKAlgMismatchError("Incorrect key type. Expected: 'oct', Received: %s" % jwk_dict.get("kty"))

k = jwk_dict.get("k")
k = k.encode("utf-8")
Expand Down
6 changes: 3 additions & 3 deletions jose/backends/ecdsa_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from jose.backends.base import Key
from jose.constants import ALGORITHMS
from jose.exceptions import JWKError
from jose.exceptions import JWKError, JWKAlgMismatchError
from jose.utils import base64_to_long, long_to_base64


Expand Down Expand Up @@ -35,7 +35,7 @@ class ECDSAECKey(Key):

def __init__(self, key, algorithm):
if algorithm not in ALGORITHMS.EC:
raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm)
raise JWKAlgMismatchError("%s is not a valid EC algorithm" % algorithm)

self.hash_alg = {
ALGORITHMS.ES256: self.SHA256,
Expand Down Expand Up @@ -75,7 +75,7 @@ def __init__(self, key, algorithm):

def _process_jwk(self, jwk_dict):
if not jwk_dict.get("kty") == "EC":
raise JWKError("Incorrect key type. Expected: 'EC', Received: %s" % jwk_dict.get("kty"))
raise JWKAlgMismatchError("Incorrect key type. Expected: 'EC', Received: %s" % jwk_dict.get("kty"))

if not all(k in jwk_dict for k in ["x", "y", "crv"]):
raise JWKError("Mandatory parameters are missing")
Expand Down
6 changes: 3 additions & 3 deletions jose/backends/native.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from jose.backends.base import Key
from jose.constants import ALGORITHMS
from jose.exceptions import JWKError
from jose.exceptions import JWKError, JWKAlgMismatchError
from jose.utils import base64url_decode, base64url_encode


Expand All @@ -22,7 +22,7 @@ class HMACKey(Key):

def __init__(self, key, algorithm):
if algorithm not in ALGORITHMS.HMAC:
raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm)
raise JWKAlgMismatchError("hash_alg: %s is not a valid hash algorithm" % algorithm)
self._algorithm = algorithm
self._hash_alg = self.HASHES.get(algorithm)

Expand Down Expand Up @@ -53,7 +53,7 @@ def __init__(self, key, algorithm):

def _process_jwk(self, jwk_dict):
if not jwk_dict.get("kty") == "oct":
raise JWKError("Incorrect key type. Expected: 'oct', Received: %s" % jwk_dict.get("kty"))
raise JWKAlgMismatchError("Incorrect key type. Expected: 'oct', Received: %s" % jwk_dict.get("kty"))

k = jwk_dict.get("k")
k = k.encode("utf-8")
Expand Down
6 changes: 3 additions & 3 deletions jose/backends/rsa_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
)
from jose.backends.base import Key
from jose.constants import ALGORITHMS
from jose.exceptions import JWEError, JWKError
from jose.exceptions import JWEError, JWKError, JWKAlgMismatchError
from jose.utils import base64_to_long, long_to_base64

ALGORITHMS.SUPPORTED.remove(ALGORITHMS.RSA_OAEP) # RSA OAEP not supported
Expand Down Expand Up @@ -124,7 +124,7 @@ class RSAKey(Key):

def __init__(self, key, algorithm):
if algorithm not in ALGORITHMS.RSA:
raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm)
raise JWKAlgMismatchError("%s is not a valid RSA algorithm" % algorithm)

if algorithm in ALGORITHMS.RSA_KW and algorithm != ALGORITHMS.RSA1_5:
raise JWKError("alg: %s is not supported by the RSA backend" % algorithm)
Expand Down Expand Up @@ -174,7 +174,7 @@ def __init__(self, key, algorithm):

def _process_jwk(self, jwk_dict):
if not jwk_dict.get("kty") == "RSA":
raise JWKError("Incorrect key type. Expected: 'RSA', Received: %s" % jwk_dict.get("kty"))
raise JWKAlgMismatchError("Incorrect key type. Expected: 'RSA', Received: %s" % jwk_dict.get("kty"))

e = base64_to_long(jwk_dict.get("e"))
n = base64_to_long(jwk_dict.get("n"))
Expand Down
5 changes: 5 additions & 0 deletions jose/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ class JWKError(JOSEError):
pass


class JWKAlgMismatchError(JWKError):
'''JWK Key type doesn't support the given algorithm.'''
pass


class JWEError(JOSEError):
"""Base error for all JWE errors"""

Expand Down
7 changes: 5 additions & 2 deletions jose/jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from jose import jwk
from jose.backends.base import Key
from jose.constants import ALGORITHMS
from jose.exceptions import JWSError, JWSSignatureError
from jose.exceptions import JWSError, JWSSignatureError, JWKAlgMismatchError
from jose.utils import base64url_decode, base64url_encode


Expand Down Expand Up @@ -205,7 +205,10 @@ def _load(jwt):
def _sig_matches_keys(keys, signing_input, signature, alg):
for key in keys:
if not isinstance(key, Key):
key = jwk.construct(key, alg)
try:
key = jwk.construct(key, alg)
except JWKAlgMismatchError:
continue
try:
if key.verify(signing_input, signature):
return True
Expand Down
12 changes: 11 additions & 1 deletion tests/test_jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from jose import jwk, jws
from jose.backends import RSAKey
from jose.constants import ALGORITHMS
from jose.exceptions import JWSError
from jose.exceptions import JWSError, JWKAlgMismatchError

try:
from jose.backends.cryptography_backend import CryptographyRSAKey
Expand All @@ -25,6 +25,16 @@ def test_unicode_token(self):
token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhIjoiYiJ9.jiMyrsmD8AoHWeQgmxZ5yq8z0lXS67_QGs52AzC8Ru8"
jws.verify(token, "secret", ["HS256"])

def test_hetero_keys(self):
class BadKey(jwk.Key):
def __init__(self, key, algorithm):
if key != "xyzw":
raise JWKAlgMismatchError("%s is not a valid XYZW algorithm" % algorithm)

token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhIjoiYiJ9.jiMyrsmD8AoHWeQgmxZ5yq8z0lXS67_QGs52AzC8Ru8"
jwk.register_key("XYZW", BadKey)
jws.verify(token, {"keys": [{"alg": "XYZW"}, "secret"]}, ["XYZW", "HS256"])

def test_multiple_keys(self):
old_jwk_verify = jwk.HMACKey.verify
try:
Expand Down