Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,9 @@ If `notify_on_registration` is set then `notify_on_registration.url` will be cal
| `required_scopes` | Space separated string or a list of strings (optional) |
| `jwk_set` | [JWKSet](https://datatracker.ietf.org/doc/html/rfc7517#section-5) or [JWK](https://datatracker.ietf.org/doc/html/rfc7517#section-4) (optional) |
| `jwk_file` | String (optional) |
| `jwks_endpoint` | String (optional) |

Either `jwk_set` or `jwk_file` must be specified.
Either `jwk_set` or `jwk_file` or `jwks_endpoint` must be specified.


### IntrospectionValidationConfig
Expand Down
3 changes: 2 additions & 1 deletion synapse_token_authenticator/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class JwtValidationConfig:
required_scopes: str | List[str] | None = None
jwk_set: JWKSet | JWK | None = None
jwk_file: str | None = None
jwks_endpoint: str | None = None

def __post_init__(self):
if not isinstance(self.validator, Exist):
Expand All @@ -82,7 +83,7 @@ def __post_init__(self):
elif self.jwk_file:
with open(self.jwk_file) as f:
self.jwk_set = JWK.from_pem(f.read())
else:
elif not self.jwks_endpoint:
raise Exception("No JWK")

@dataclass
Expand Down
6 changes: 6 additions & 0 deletions synapse_token_authenticator/token_authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import synapse
from jwcrypto import jwk, jwt
from jwcrypto.jwk import JWKSet
from jwcrypto.common import JWException, json_decode
from synapse.api.errors import HttpResponseException
from synapse.module_api import ModuleApi
Expand Down Expand Up @@ -317,6 +318,11 @@ async def check_oauth(
check_claims: dict = {}
if config.jwt_validation.require_expiry:
check_claims["exp"] = None
if config.jwt_validation.jwks_endpoint:
jwks_json = await client.get_raw(
config.jwt_validation.jwks_endpoint,
)
config.jwt_validation.jwk_set = JWKSet.from_json(jwks_json)
try:
token = jwt.JWT(
jwt=token,
Expand Down
15 changes: 11 additions & 4 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,17 +135,24 @@ def default_config(self) -> dict[str, Any]:
return conf


def get_jwk(secret="foxies"):
def get_jwk(secret="foxies", id="123456"):
return jwk.JWK(
k=base64.urlsafe_b64encode(secret.encode("utf-8")).decode("utf-8"),
kty="oct",
kid=id,
)


def get_jwt_token(
username, exp_in=None, secret="foxies", algorithm="HS512", admin=None, claims=None
username,
exp_in=None,
secret="foxies",
algorithm="HS512",
admin=None,
claims=None,
id="123456",
):
key = get_jwk(secret)
key = get_jwk(secret, id)
if claims is None:
claims = {}
claims["sub"] = username
Expand All @@ -157,7 +164,7 @@ def get_jwt_token(
claims["exp"] = int(time.time()) + 120
else:
claims["exp"] = int(time.time()) + exp_in
token = jwt.JWT(header={"alg": algorithm}, claims=claims)
token = jwt.JWT(header={"alg": algorithm, "kid": id}, claims=claims)
token.make_signed_token(key)
return token.serialize()

Expand Down
57 changes: 40 additions & 17 deletions tests/test_oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from . import ModuleApiTestCase, get_jwt_token, get_jwk, mock_for_oauth
from copy import deepcopy
from jwcrypto.jwk import JWKSet

default_claims = {
"urn:messaging:matrix:localpart": "alice",
Expand Down Expand Up @@ -94,25 +95,25 @@ async def test_token_claims_username_mismatch(self):
)
self.assertEqual(result, None)

@synapsetest.override_config(
{
"modules": [
{
"module": "synapse_token_authenticator.TokenAuthenticator",
"config": {
"oauth": {
"jwt_validation": {
"validator": ["exist"],
"require_expiry": False,
"jwk_set": get_jwk(),
},
"username_type": "user_id",
config_for_jwt = {
"modules": [
{
"module": "synapse_token_authenticator.TokenAuthenticator",
"config": {
"oauth": {
"jwt_validation": {
"validator": ["exist"],
"require_expiry": False,
"jwk_set": get_jwk(),
},
"username_type": "user_id",
},
}
]
}
)
},
}
]
}

@synapsetest.override_config(config_for_jwt)
async def test_token_no_expiry_with_config(self, *args):
token = get_jwt_token("aliceid", exp_in=-1, claims=default_claims)
result = await self.hs.mockmod.check_oauth(
Expand Down Expand Up @@ -147,6 +148,28 @@ async def test_invalid_scope(self):
)
self.assertEqual(result, None)

config_for_jwt_jwks_url = deepcopy(config_for_jwt)
config_for_jwt_jwks_url["modules"][0]["config"]["oauth"]["jwt_validation"].pop(
"jwk_set"
)
config_for_jwt_jwks_url["modules"][0]["config"]["oauth"]["jwt_validation"][
"jwks_endpoint"
] = "https://my_idp.com/oauth/v2/keys"
jwks = JWKSet()
jwks.add(get_jwk())

@synapsetest.override_config(config_for_jwt_jwks_url)
@mock.patch(
"synapse.http.client.SimpleHttpClient.get_raw", return_value=jwks.export()
)
async def test_fetch_jwks(self, *args):
token = get_jwt_token("aliceid", claims=default_claims)
result = await self.hs.mockmod.check_oauth(
"alice", "com.famedly.login.token.oauth", {"token": token}
)
print(f"result:\n{result}")
self.assertEqual(result[0], "@alice:example.test")

config_for_introspection = {
"modules": [
{
Expand Down