Skip to content

Commit 17ef9f8

Browse files
Fix uncaught exception with JWK (jazzband#600)
* Fix uncaught exception with JWK * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Allow tests to run on older JWT versions Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 9b6f1de commit 17ef9f8

File tree

2 files changed

+43
-4
lines changed

2 files changed

+43
-4
lines changed

rest_framework_simplejwt/backends.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from .utils import format_lazy
1111

1212
try:
13-
from jwt import PyJWKClient
13+
from jwt import PyJWKClient, PyJWKClientError
1414

1515
JWK_CLIENT_AVAILABLE = True
1616
except ImportError:
@@ -96,7 +96,10 @@ def get_verifying_key(self, token):
9696
return self.signing_key
9797

9898
if self.jwks_client:
99-
return self.jwks_client.get_signing_key_from_jwt(token).key
99+
try:
100+
return self.jwks_client.get_signing_key_from_jwt(token).key
101+
except PyJWKClientError as ex:
102+
raise TokenBackendError(_("Token is invalid or expired")) from ex
100103

101104
return self.verifying_key
102105

@@ -145,5 +148,5 @@ def decode(self, token, verify=True):
145148
)
146149
except InvalidAlgorithmError as ex:
147150
raise TokenBackendError(_("Invalid algorithm specified")) from ex
148-
except InvalidTokenError:
149-
raise TokenBackendError(_("Token is invalid or expired"))
151+
except InvalidTokenError as ex:
152+
raise TokenBackendError(_("Token is invalid or expired")) from ex

tests/test_backends.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,42 @@ def test_decode_rsa_aud_iss_jwk_success(self):
292292

293293
self.assertEqual(jwk_token_backend.decode(token), self.payload)
294294

295+
@pytest.mark.skipif(
296+
not JWK_CLIENT_AVAILABLE,
297+
reason="PyJWT 1.7.1 doesn't have JWK client",
298+
)
299+
def test_decode_jwk_missing_key_raises_tokenbackenderror(self):
300+
self.payload["exp"] = aware_utcnow() + timedelta(days=1)
301+
self.payload["foo"] = "baz"
302+
self.payload["aud"] = AUDIENCE
303+
self.payload["iss"] = ISSUER
304+
305+
token = jwt.encode(
306+
self.payload,
307+
PRIVATE_KEY_2,
308+
algorithm="RS256",
309+
headers={"kid": "230498151c214b788dd97f22b85410a5"},
310+
)
311+
312+
mock_jwk_module = mock.MagicMock()
313+
with patch("rest_framework_simplejwt.backends.PyJWKClient") as mock_jwk_module:
314+
mock_jwk_client = mock.MagicMock()
315+
316+
mock_jwk_module.return_value = mock_jwk_client
317+
mock_jwk_client.get_signing_key_from_jwt.side_effect = jwt.PyJWKClientError(
318+
"Unable to find a signing key that matches"
319+
)
320+
321+
# Note the PRIV,PUB care is intentially the original pairing
322+
jwk_token_backend = TokenBackend(
323+
"RS256", PRIVATE_KEY, PUBLIC_KEY, AUDIENCE, ISSUER, JWK_URL
324+
)
325+
326+
with self.assertRaisesRegex(
327+
TokenBackendError, "Token is invalid or expired"
328+
):
329+
jwk_token_backend.decode(token)
330+
295331
def test_decode_when_algorithm_not_available(self):
296332
token = jwt.encode(self.payload, PRIVATE_KEY, algorithm="RS256")
297333
if IS_OLD_JWT:

0 commit comments

Comments
 (0)