Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Replace pyjwt with authlib in org.matrix.login.jwt #13011

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/sample_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2184,7 +2184,7 @@ sso:
# The algorithm used to sign the JSON web token.
#
# Supported algorithms are listed at
# https://pyjwt.readthedocs.io/en/latest/algorithms.html
# https://docs.authlib.org/en/latest/specs/rfc7518.html
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
#
# Required if 'enabled' is true.
#
Expand Down
2 changes: 1 addition & 1 deletion docs/usage/configuration/config_documentation.md
Original file line number Diff line number Diff line change
Expand Up @@ -2947,7 +2947,7 @@ Additional sub-options for this setting include:
* `secret`: This is either the private shared secret or the public key used to
decode the contents of the JSON web token. Required if `enabled` is set to true.
* `algorithm`: The algorithm used to sign the JSON web token. Supported algorithms are listed at
https://pyjwt.readthedocs.io/en/latest/algorithms.html Required if `enabled` is set to true.
https://docs.authlib.org/en/latest/specs/rfc7518.html Required if `enabled` is set to true.
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
* `subject_claim`: Name of the claim containing a unique identifier for the user.
Optional, defaults to `sub`.
* `issuer`: The issuer to validate the "iss" claim against. Optional. If provided the
Expand Down
17 changes: 2 additions & 15 deletions synapse/config/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,7 @@

from synapse.types import JsonDict

from ._base import Config, ConfigError

MISSING_JWT = """Missing jwt library. This is required for jwt login.

Install by running:
pip install pyjwt
"""
from ._base import Config


class JWTConfig(Config):
Expand All @@ -41,13 +35,6 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
# that the claims exist on the JWT.
self.jwt_issuer = jwt_config.get("issuer")
self.jwt_audiences = jwt_config.get("audiences")

try:
import jwt

jwt # To stop unused lint.
except ImportError:
raise ConfigError(MISSING_JWT)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We want a replacement for this, to warn users who've configured JWT but haven't installed authlib.

check_dependencies("jwt") should suffice after the poetry-related changes are made. For an existing example elsewhere, see:

check_requirements("redis")

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is still outstanding.

else:
self.jwt_enabled = False
self.jwt_secret = None
Expand Down Expand Up @@ -89,7 +76,7 @@ def generate_config_section(self, **kwargs: Any) -> str:
# The algorithm used to sign the JSON web token.
#
# Supported algorithms are listed at
# https://pyjwt.readthedocs.io/en/latest/algorithms.html
# https://docs.authlib.org/en/latest/specs/rfc7518.html
#
# Required if 'enabled' is true.
#
Expand Down
44 changes: 36 additions & 8 deletions synapse/rest/client/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
Union,
)

from authlib.jose import JsonWebToken, JWTClaims
from authlib.jose.errors import BadSignatureError, InvalidClaimError, JoseError
from typing_extensions import TypedDict

from synapse.api.errors import Codes, LoginError, SynapseError
Expand Down Expand Up @@ -420,25 +422,51 @@ async def _do_jwt_login(
403, "Token field for JWT is missing", errcode=Codes.FORBIDDEN
)

import jwt
jwt = JsonWebToken([self.jwt_algorithm])
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
claim_options = {}
if self.jwt_issuer is not None:
claim_options["iss"] = {"value": self.jwt_issuer, "essential": True}
if self.jwt_audiences is not None:
claim_options["aud"] = {"values": self.jwt_audiences, "essential": True}
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved

try:
payload = jwt.decode(
claims = jwt.decode(
token,
self.jwt_secret,
algorithms=[self.jwt_algorithm],
issuer=self.jwt_issuer,
audience=self.jwt_audiences,
key=self.jwt_secret,
claims_cls=JWTClaims,
claims_options=claim_options,
)
except jwt.PyJWTError as e:
except BadSignatureError:
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
raise LoginError(
403,
"JWT validation failed: Signature verification failed",
errcode=Codes.FORBIDDEN,
)
except Exception as e:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sandhose pointed out that this is a pretty broad exception clause. Can you narrow this to a JOSE-specific exception?

Perhaps JoseError from here? https://github.com/lepture/authlib/blob/master/authlib/jose/errors.py

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(That will mean that any other exception gets properly flagged as an application error rather than causing a 403 to the requester.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed and pushed it

# A JWT error occurred, return some info back to the client.
raise LoginError(
403,
"JWT validation failed: %s" % (str(e),),
errcode=Codes.FORBIDDEN,
)
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved

user = payload.get(self.jwt_subject_claim, None)
try:
claims.validate(leeway=120) # allows 2 min of clock skew

# Enforce the old behavior which is rolled out in productive
# servers: if the JWT contains an 'aud' claim but none is
# configured, the login attempt will fail
if claims.get("aud") is not None:
if self.jwt_audiences is None or len(self.jwt_audiences) == 0:
raise InvalidClaimError("aud")
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
except JoseError as e:
raise LoginError(
403,
"JWT validation failed: %s" % (str(e),),
errcode=Codes.FORBIDDEN,
)
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved

user = claims.get(self.jwt_subject_claim, None)
if user is None:
raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)

Expand Down
26 changes: 12 additions & 14 deletions tests/rest/client/test_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import json
import time
import urllib.parse
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional
from unittest.mock import Mock
from urllib.parse import urlencode

Expand All @@ -41,7 +41,7 @@
from tests.unittest import HomeserverTestCase, override_config, skip_unless

try:
import jwt
from authlib.jose import jwt, jwk

HAS_JWT = True
except ImportError:
Expand Down Expand Up @@ -841,7 +841,7 @@ def test_deactivated_user(self) -> None:
self.assertIn(b"SSO account deactivated", channel.result["body"])


@skip_unless(HAS_JWT, "requires jwt")
@skip_unless(HAS_JWT, "requires authlib")
class JWTTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
Expand All @@ -866,11 +866,9 @@ def default_config(self) -> Dict[str, Any]:
return config

def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
result: Union[str, bytes] = jwt.encode(payload, secret, self.jwt_algorithm)
if isinstance(result, bytes):
return result.decode("ascii")
return result
header = {"alg": self.jwt_algorithm}
result: bytes = jwt.encode(header, payload, secret)
return result.decode("ascii")

def jwt_login(self, *args: Any) -> FakeChannel:
params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
Expand Down Expand Up @@ -1010,7 +1008,7 @@ def test_login_no_token(self) -> None:
# The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use
# RSS256, with a public key configured in synapse as "jwt_secret", and tokens
# signed by the private key.
@skip_unless(HAS_JWT, "requires jwt")
@skip_unless(HAS_JWT, "requires authlib")
class JWTPubKeyTestCase(unittest.HomeserverTestCase):
servlets = [
login.register_servlets,
Expand Down Expand Up @@ -1071,11 +1069,11 @@ def default_config(self) -> Dict[str, Any]:
return config

def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
result: Union[bytes, str] = jwt.encode(payload, secret, "RS256")
if isinstance(result, bytes):
return result.decode("ascii")
return result
header = {"alg": "RS256"}
if secret.startswith("-----BEGIN RSA PRIVATE KEY-----"):
secret = jwk.dumps(secret, kty="RSA")
result: bytes = jwt.encode(header, payload, secret)
return result.decode("ascii")

def jwt_login(self, *args: Any) -> FakeChannel:
params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
Expand Down