Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fix for CVE-2022-3102 #299

Merged
merged 2 commits into from
Sep 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion docs/source/jwt.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@ Classes
:members:
:show-inheritance:

Variables
---------

.. autodata:: jwcrypto.jwt.JWTClaimsRegistry

.. autodata:: jwcrypto.jwt.JWT_expect_type

Examples
--------

Expand Down Expand Up @@ -42,7 +49,7 @@ Now decrypt and verify::
>>> k = {"k": "Wal4ZHCBsml0Al_Y8faoNTKsXCkw8eefKXYFuwTBOpA", "kty": "oct"}
>>> key = jwk.JWK(**k)
>>> e = 'eyJhbGciOiJBMjU2S1ciLCJlbmMiOiJBMjU2Q0JDLUhTNTEyIn0.ST5RmjqDLj696xo7YFTFuKUhcd3naCrm6yMjBM3cqWiFD6U8j2JIsbclsF7ryNg8Ktmt1kQJRKavV6DaTl1T840tP3sIs1qz.wSxVhZH5GyzbJnPBAUMdzQ.6uiVYwrRBzAm7Uge9rEUjExPWGbgerF177A7tMuQurJAqBhgk3_5vee5DRH84kHSapFOxcEuDdMBEQLI7V2E0F57-d01TFStHzwtgtSmeZRQ6JSIL5XlgJouwHfSxn9Z_TGl5xxq4TksORHED1vnRA.5jPyPWanJVqlOohApEbHmxi3JHp1MXbmvQe2_dVd8FI'
>>> ET = jwt.JWT(key=key, jwt=e)
>>> ET = jwt.JWT(key=key, jwt=e, expected_type="JWE")
>>> ST = jwt.JWT(key=key, jwt=ET.claims)
>>> ST.claims
'{"info":"I\'m a signed token"}'
78 changes: 70 additions & 8 deletions jwcrypto/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from jwcrypto.common import JWException, JWKeyNotFound
from jwcrypto.common import json_decode, json_encode
from jwcrypto.jwe import JWE
from jwcrypto.jwe import default_allowed_algs as jwe_algs
from jwcrypto.jws import JWS


Expand All @@ -21,6 +22,17 @@
'nbf': 'Not Before',
'iat': 'Issued At',
'jti': 'JWT ID'}
"""Registry of RFC 7519 defined claims"""


# do not use this unless you know about CVE-2022-3102
JWT_expect_type = True
"""This module parameter can disable the use of the expectation
feature that has been introduced to fix CVE-2022-3102. This knob
has been added as a workaround for applications that can't be
immediately refactored to deal with the change in behavior but it
is considered deprecated and will be removed in a future release.
"""


class JWTExpired(JWException):
Expand Down Expand Up @@ -153,7 +165,8 @@ class JWT:
"""

def __init__(self, header=None, claims=None, jwt=None, key=None,
algs=None, default_claims=None, check_claims=None):
algs=None, default_claims=None, check_claims=None,
expected_type=None):
"""Creates a JWT object.

:param header: A dict or a JSON string with the JWT Header data.
Expand All @@ -169,6 +182,12 @@ def __init__(self, header=None, claims=None, jwt=None, key=None,
:param check_claims: An optional dict of claims that must be
present in the token, if the value is not None the claim must
match exactly.
:param expected_type: An optional string that defines what kind
of token to expect when validating a deserialized token.
Supported values: "JWS" or "JWE"
If left to None the code will try to detect what the expected
type is based on other parameters like 'algs' and will default
to JWS if no hints are found. It has no effect on token creation.

Note: either the header,claims or jwt,key parameters should be
provided as a deserialization operation (which occurs if the jwt
Expand All @@ -190,6 +209,7 @@ def __init__(self, header=None, claims=None, jwt=None, key=None,
self._leeway = 60 # 1 minute clock skew allowed
self._validity = 600 # 10 minutes validity (up to 11 with leeway)
self.deserializelog = None
self._expected_type = expected_type

if header:
self.header = header
Expand Down Expand Up @@ -276,6 +296,33 @@ def validity(self):
def validity(self, v):
self._validity = int(v)

@property
def expected_type(self):
if self._expected_type is not None:
return self._expected_type

# If no expected type is set we default to accept only JWSs,
# however to improve backwards compatibility we try some
# heuristic to see if there has been strong indication of
# what the expected token type is.
if self._expected_type is None and self._algs:
if set(self._algs).issubset(jwe_algs + ['RSA1_5']):
self._expected_type = "JWE"
if self._expected_type is None and self._header:
if "enc" in json_decode(self._header):
self._expected_type = "JWE"
if self._expected_type is None:
self._expected_type = "JWS"

return self._expected_type

@expected_type.setter
def expected_type(self, v):
if v in ["JWS", "JWE"]:
self._expected_type = v
else:
raise ValueError("Invalid value, must be 'JWS' or 'JWE'")

def _add_optional_claim(self, name, claims):
if name in claims:
return
Expand Down Expand Up @@ -472,6 +519,7 @@ def make_signed_token(self, key):
t.allowed_algs = self._algs
t.add_signature(key, protected=self.header)
self.token = t
self._expected_type = "JWS"

def make_encrypted_token(self, key):
"""Encrypts the payload.
Expand All @@ -488,6 +536,7 @@ def make_encrypted_token(self, key):
t.allowed_algs = self._algs
t.add_recipient(key)
self.token = t
self._expected_type = "JWE"

def validate(self, key):
"""Validate a JWT token that was deserialized w/o providing a key
Expand All @@ -500,13 +549,23 @@ def validate(self, key):
if self.token is None:
raise ValueError("Token empty")

et = self.expected_type
validate_fn = None

if isinstance(self.token, JWS):
if et != "JWS" and JWT_expect_type:
raise TypeError("Expected {}, got JWS".format(et))
validate_fn = self.token.verify
elif isinstance(self.token, JWE):
if et != "JWE" and JWT_expect_type:
print("algs: {}".format(self._algs))
raise TypeError("Expected {}, got JWE".format(et))
validate_fn = self.token.decrypt
else:
raise ValueError("Token format unrecognized")

try:
if isinstance(self.token, JWS):
self.token.verify(key)
elif isinstance(self.token, JWE):
self.token.decrypt(key)
else:
raise ValueError("Token format unrecognized")
validate_fn(key)
self.deserializelog.append("Success")
except Exception as e: # pylint: disable=broad-except
if isinstance(self.token, JWS):
Expand All @@ -520,7 +579,10 @@ def validate(self, key):
raise

self.header = self.token.jose_header
self.claims = self.token.payload.decode('utf-8')
payload = self.token.payload
if isinstance(payload, bytes):
payload = payload.decode('utf-8')
self.claims = payload
self._check_provided_claims()

def deserialize(self, jwt, key=None):
Expand Down
48 changes: 46 additions & 2 deletions jwcrypto/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from jwcrypto.common import json_decode, json_encode

jwe_algs_and_rsa1_5 = jwe.default_allowed_algs + ['RSA1_5']
jws_algs_and_rsa1_5 = jws.default_allowed_algs + ['RSA1_5']

# RFC 7517 - A.1
PublicKeys = {"keys": [
Expand Down Expand Up @@ -1531,9 +1530,11 @@ def test_A2(self):
tinner = jwt.JWT(jwt=touter.claims, key=sigkey, check_claims=False)
self.assertEqual(A1_claims, json_decode(tinner.claims))

# Test Exception throwing when token is encrypted with
# algorithms not in the allowed set
with self.assertRaises(jwe.InvalidJWEData):
jwt.JWT(jwt=A2_token, key=E_A2_ex['key'],
algs=jws_algs_and_rsa1_5)
algs=['A192KW', 'A192CBC-HS384', 'RSA1_5'])

def test_decrypt_keyset(self):
key = jwk.JWK(kid='testkey', **E_A2_key)
Expand Down Expand Up @@ -1738,6 +1739,48 @@ def test_Issue_277(self):
jwt=sertok, check_claims={"aud": ["nomatch",
"failmatch"]})

def test_unexpected(self):
key = jwk.JWK(generate='oct', size=256)
claims = {"testclaim": "test"}
token = jwt.JWT(header={"alg": "HS256"}, claims=claims)
token.make_signed_token(key)
sertok = token.serialize()

token.validate(key)
token.expected_type = "JWS"
token.validate(key)
token.expected_type = "JWE"
with self.assertRaises(TypeError):
token.validate(key)

jwt.JWT(jwt=sertok, key=key)
jwt.JWT(jwt=sertok, key=key, expected_type='JWS')
with self.assertRaises(TypeError):
jwt.JWT(jwt=sertok, key=key, expected_type='JWE')

token = jwt.JWT(header={"alg": "A256KW", "enc": "A256GCM"},
claims=claims)
token.make_encrypted_token(key)
enctok = token.serialize()

# test workaroud for older applications
jwt.JWT_expect_type = False
jwt.JWT(jwt=enctok, key=key)
jwt.JWT_expect_type = True

token.validate(key)
token.expected_type = "JWE"
token.validate(key)
token.expected_type = "JWS"
with self.assertRaises(TypeError):
token.validate(key)

jwt.JWT(jwt=enctok, key=key, expected_type='JWE')
with self.assertRaises(TypeError):
jwt.JWT(jwt=enctok, key=key)
with self.assertRaises(TypeError):
jwt.JWT(jwt=enctok, key=key, expected_type='JWS')


class ConformanceTests(unittest.TestCase):

Expand Down Expand Up @@ -2107,6 +2150,7 @@ def test_jwt_equality(self):

ect = jwt.JWT.from_jose_token(ea.serialize())
self.assertNotEqual(ea, ect)
ect.expected_type = "JWE"
ect.validate(key)
self.assertEqual(ea, ect)

Expand Down