From 73b4791af84597a8064c85b45d37ce9b588345ba Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 18 Aug 2022 17:07:27 +0100 Subject: [PATCH] Block login if a user requires approval and the server is configured to do so --- synapse/handlers/auth.py | 11 +++++++++++ synapse/rest/client/login.py | 15 ++++++++++++++ tests/rest/client/test_login.py | 35 +++++++++++++++++++++++++++++++++ 3 files changed, 61 insertions(+) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index bfa553504442..942f221bb8bf 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -1010,6 +1010,17 @@ async def check_user_exists(self, user_id: str) -> Optional[str]: return res[0] return None + async def is_user_approved(self, user_id: str) -> bool: + """Checks if a user is approved and therefore can be allowed to log in. + + Args: + user_id: the user to check the approval status of. + + Returns: + A boolean that is True if the user is approved, False otherwise. + """ + return await self.store.is_user_approved(user_id) + async def _find_user_id_and_pwd_hash( self, user_id: str ) -> Optional[Tuple[str, str]]: diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index 0437c87d8d6d..b90015ecc578 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -92,6 +92,12 @@ def __init__(self, hs: "HomeServer"): hs.config.registration.refreshable_access_token_lifetime is not None ) + # Whether we need to check if the user has been approved or not. + self._require_approval = ( + hs.config.experimental.msc3866.enabled + and hs.config.experimental.msc3866.require_approval_for_new_accounts + ) + self.auth = hs.get_auth() self.clock = hs.get_clock() @@ -220,6 +226,15 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, LoginResponse]: except KeyError: raise SynapseError(400, "Missing JSON keys.") + if self._require_approval: + approved = await self.auth_handler.is_user_approved(result["user_id"]) + if not approved: + raise SynapseError( + code=403, + errcode=Codes.USER_AWAITING_APPROVAL, + msg="This account is pending approval by a server administrator.", + ) + well_known_data = self._well_known_builder.get_well_known() if well_known_data: result["well_known"] = well_known_data diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index e2a4d982755a..8a65d6638b3f 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -23,6 +23,8 @@ from twisted.web.resource import Resource import synapse.rest.admin +from synapse.api.constants import LoginType +from synapse.api.errors import Codes from synapse.appservice import ApplicationService from synapse.rest.client import devices, login, logout, register from synapse.rest.client.account import WhoamiRestServlet @@ -94,6 +96,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): logout.register_servlets, devices.register_servlets, lambda hs, http_server: WhoamiRestServlet(hs).register(http_server), + register.register_servlets, ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: @@ -406,6 +409,38 @@ def test_login_with_overly_long_device_id_fails(self) -> None: self.assertEqual(channel.code, 400) self.assertEqual(channel.json_body["errcode"], "M_INVALID_PARAM") + @override_config( + { + "experimental_features": { + "msc3866": { + "enabled": True, + "require_approval_for_new_accounts": True, + } + } + } + ) + def test_require_approval(self) -> None: + channel = self.make_request( + "POST", + "register", + { + "username": "kermit", + "password": "monkey", + "auth": {"type": LoginType.DUMMY}, + }, + ) + self.assertEqual(403, channel.code, channel.result) + self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"]) + + params = { + "type": LoginType.PASSWORD, + "identifier": {"type": "m.id.user", "user": "kermit"}, + "password": "monkey", + } + channel = self.make_request("POST", LOGIN_URL, params) + self.assertEqual(403, channel.code, channel.result) + self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"]) + @skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC") class MultiSSOTestCase(unittest.HomeserverTestCase):