Skip to content

Commit

Permalink
Make JWT require to know what to expect
Browse files Browse the repository at this point in the history
This is needed to address CVE-2022-3102.
Thanks to Tom tervoort from Secura for finding and reporting this issue.

Also test that "unepxected" token types are not validated

Signed-off-by: Simo Sorce <simo@redhat.com>
  • Loading branch information
simo5 committed Sep 13, 2022
1 parent 5a13cfc commit 444acd1
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 10 deletions.
67 changes: 59 additions & 8 deletions jwcrypto/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from jwcrypto.common import JWException, JWKeyNotFound
from jwcrypto.common import json_decode, json_encode
from jwcrypto.jwe import JWE
from jwcrypto.jwe import default_allowed_algs as jwe_algs
from jwcrypto.jws import JWS


Expand Down Expand Up @@ -153,7 +154,8 @@ class JWT:
"""

def __init__(self, header=None, claims=None, jwt=None, key=None,
algs=None, default_claims=None, check_claims=None):
algs=None, default_claims=None, check_claims=None,
expected_type=None):
"""Creates a JWT object.
:param header: A dict or a JSON string with the JWT Header data.
Expand All @@ -169,6 +171,12 @@ def __init__(self, header=None, claims=None, jwt=None, key=None,
:param check_claims: An optional dict of claims that must be
present in the token, if the value is not None the claim must
match exactly.
:param expected_type: An optional string that defines what kind
of token to expect when validating a deserialized token.
Supported values: "JWS" or "JWE"
If left to None the code will try to detect what the expected
type is based on other parameters like 'algs' and will default
to JWS if no hints are found. It has no effect on token creation.
Note: either the header,claims or jwt,key parameters should be
provided as a deserialization operation (which occurs if the jwt
Expand All @@ -190,6 +198,7 @@ def __init__(self, header=None, claims=None, jwt=None, key=None,
self._leeway = 60 # 1 minute clock skew allowed
self._validity = 600 # 10 minutes validity (up to 11 with leeway)
self.deserializelog = None
self._expected_type = expected_type

if header:
self.header = header
Expand Down Expand Up @@ -276,6 +285,33 @@ def validity(self):
def validity(self, v):
self._validity = int(v)

@property
def expected_type(self):
if self._expected_type is not None:
return self._expected_type

# If no expected type is set we default to accept only JWSs,
# however to improve backwards compatibility we try some
# heuristic to see if there has been strong indication of
# what the expected token type is.
if self._expected_type is None and self._algs:
if set(self._algs).issubset(jwe_algs + ['RSA1_5']):
self._expected_type = "JWE"
if self._expected_type is None and self._header:
if "enc" in json_decode(self._header):
self._expected_type = "JWE"
if self._expected_type is None:
self._expected_type = "JWS"

return self._expected_type

@expected_type.setter
def expected_type(self, v):
if v in ["JWS", "JWE"]:
self._expected_type = v
else:
raise ValueError("Invalid value, must be 'JWS' or 'JWE'")

def _add_optional_claim(self, name, claims):
if name in claims:
return
Expand Down Expand Up @@ -472,6 +508,7 @@ def make_signed_token(self, key):
t.allowed_algs = self._algs
t.add_signature(key, protected=self.header)
self.token = t
self._expected_type = "JWS"

def make_encrypted_token(self, key):
"""Encrypts the payload.
Expand All @@ -488,6 +525,7 @@ def make_encrypted_token(self, key):
t.allowed_algs = self._algs
t.add_recipient(key)
self.token = t
self._expected_type = "JWE"

def validate(self, key):
"""Validate a JWT token that was deserialized w/o providing a key
Expand All @@ -500,13 +538,23 @@ def validate(self, key):
if self.token is None:
raise ValueError("Token empty")

et = self.expected_type
validate_fn = None

if isinstance(self.token, JWS):
if et != "JWS":
raise TypeError("Expected {}, got JWS".format(et))
validate_fn = self.token.verify
elif isinstance(self.token, JWE):
if et != "JWE":
print("algs: {}".format(self._algs))
raise TypeError("Expected {}, got JWE".format(et))
validate_fn = self.token.decrypt
else:
raise ValueError("Token format unrecognized")

try:
if isinstance(self.token, JWS):
self.token.verify(key)
elif isinstance(self.token, JWE):
self.token.decrypt(key)
else:
raise ValueError("Token format unrecognized")
validate_fn(key)
self.deserializelog.append("Success")
except Exception as e: # pylint: disable=broad-except
if isinstance(self.token, JWS):
Expand All @@ -520,7 +568,10 @@ def validate(self, key):
raise

self.header = self.token.jose_header
self.claims = self.token.payload.decode('utf-8')
payload = self.token.payload
if isinstance(payload, bytes):
payload = payload.decode('utf-8')
self.claims = payload
self._check_provided_claims()

def deserialize(self, jwt, key=None):
Expand Down
43 changes: 41 additions & 2 deletions jwcrypto/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from jwcrypto.common import json_decode, json_encode

jwe_algs_and_rsa1_5 = jwe.default_allowed_algs + ['RSA1_5']
jws_algs_and_rsa1_5 = jws.default_allowed_algs + ['RSA1_5']

# RFC 7517 - A.1
PublicKeys = {"keys": [
Expand Down Expand Up @@ -1531,9 +1530,11 @@ def test_A2(self):
tinner = jwt.JWT(jwt=touter.claims, key=sigkey, check_claims=False)
self.assertEqual(A1_claims, json_decode(tinner.claims))

# Test Exception throwing when token is encrypted with
# algorithms not in the allowed set
with self.assertRaises(jwe.InvalidJWEData):
jwt.JWT(jwt=A2_token, key=E_A2_ex['key'],
algs=jws_algs_and_rsa1_5)
algs=['A192KW', 'A192CBC-HS384', 'RSA1_5'])

def test_decrypt_keyset(self):
key = jwk.JWK(kid='testkey', **E_A2_key)
Expand Down Expand Up @@ -1738,6 +1739,43 @@ def test_Issue_277(self):
jwt=sertok, check_claims={"aud": ["nomatch",
"failmatch"]})

def test_unexpected(self):
key = jwk.JWK(generate='oct', size=256)
claims = {"testclaim": "test"}
token = jwt.JWT(header={"alg": "HS256"}, claims=claims)
token.make_signed_token(key)
sertok = token.serialize()

token.validate(key)
token.expected_type = "JWS"
token.validate(key)
token.expected_type = "JWE"
with self.assertRaises(TypeError):
token.validate(key)

jwt.JWT(jwt=sertok, key=key)
jwt.JWT(jwt=sertok, key=key, expected_type='JWS')
with self.assertRaises(TypeError):
jwt.JWT(jwt=sertok, key=key, expected_type='JWE')

token = jwt.JWT(header={"alg": "A256KW", "enc": "A256GCM"},
claims=claims)
token.make_encrypted_token(key)
enctok = token.serialize()

token.validate(key)
token.expected_type = "JWE"
token.validate(key)
token.expected_type = "JWS"
with self.assertRaises(TypeError):
token.validate(key)

jwt.JWT(jwt=enctok, key=key, expected_type='JWE')
with self.assertRaises(TypeError):
jwt.JWT(jwt=enctok, key=key)
with self.assertRaises(TypeError):
jwt.JWT(jwt=enctok, key=key, expected_type='JWS')


class ConformanceTests(unittest.TestCase):

Expand Down Expand Up @@ -2107,6 +2145,7 @@ def test_jwt_equality(self):

ect = jwt.JWT.from_jose_token(ea.serialize())
self.assertNotEqual(ea, ect)
ect.expected_type = "JWE"
ect.validate(key)
self.assertEqual(ea, ect)

Expand Down

0 comments on commit 444acd1

Please sign in to comment.