Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 1eaff08

Browse files
committed
Add storage support for checking and updating a user's approval status
1 parent f7c9743 commit 1eaff08

File tree

3 files changed

+255
-1
lines changed

3 files changed

+255
-1
lines changed

synapse/storage/databases/main/registration.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
181181
"user_type",
182182
"deactivated",
183183
"shadow_banned",
184+
"approved",
184185
],
185186
allow_none=True,
186187
desc="get_user_by_id",
@@ -1778,6 +1779,26 @@ async def is_guest(self, user_id: str) -> bool:
17781779

17791780
return res if res else False
17801781

1782+
@cached()
1783+
async def is_user_approved(self, user_id: str) -> bool:
1784+
"""Checks if a user is approved and therefore can be allowed to log in.
1785+
1786+
Args:
1787+
user_id: the user to check the approval status of.
1788+
1789+
Returns:
1790+
A boolean that is True if the user is approved, False otherwise.
1791+
"""
1792+
ret = await self.db_pool.simple_select_one_onecol(
1793+
table="users",
1794+
keyvalues={"name": user_id},
1795+
retcol="approved",
1796+
allow_none=True,
1797+
desc="is_user_pending_approval",
1798+
)
1799+
1800+
return bool(ret)
1801+
17811802

17821803
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
17831804
def __init__(
@@ -1817,6 +1838,10 @@ def __init__(
18171838
unique=False,
18181839
)
18191840

1841+
self.db_pool.updates.register_background_update_handler(
1842+
"users_set_approved_flag", self._background_update_set_approved_flag
1843+
)
1844+
18201845
async def _background_update_set_deactivated_flag(
18211846
self, progress: JsonDict, batch_size: int
18221847
) -> int:
@@ -1915,6 +1940,75 @@ def set_user_deactivated_status_txn(
19151940
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
19161941
txn.call_after(self.is_guest.invalidate, (user_id,))
19171942

1943+
async def _background_update_set_approved_flag(
1944+
self, progress: JsonDict, batch_size: int
1945+
) -> int:
1946+
"""Set the 'approved' flag for all already existing users. We want to set it to
1947+
true systematically because we don't want to suddenly prevent already existing
1948+
users from logging in if the option to block registration on approval is turned
1949+
on.
1950+
"""
1951+
last_user = progress.get("user_id", "")
1952+
1953+
def _background_update_set_approved_flag_txn(txn: LoggingTransaction) -> int:
1954+
txn.execute(
1955+
"""
1956+
SELECT name
1957+
FROM users
1958+
WHERE
1959+
approved IS NULL
1960+
AND name > ?
1961+
ORDER BY name ASC
1962+
LIMIT ?
1963+
""",
1964+
(last_user, batch_size),
1965+
)
1966+
rows = self.db_pool.cursor_to_dict(txn)
1967+
1968+
if len(rows) == 0:
1969+
return 0
1970+
1971+
for user in rows:
1972+
self.update_user_approval_status_txn(txn, user["name"], True)
1973+
1974+
self.db_pool.updates._background_update_progress_txn(
1975+
txn, "users_set_approved_flag", {"user_id": rows[-1]["name"]}
1976+
)
1977+
1978+
return len(rows)
1979+
1980+
nb_processed = await self.db_pool.runInteraction(
1981+
"users_set_approved_flag", _background_update_set_approved_flag_txn
1982+
)
1983+
1984+
if nb_processed < batch_size:
1985+
await self.db_pool.updates._end_background_update("users_set_approved_flag")
1986+
1987+
return nb_processed
1988+
1989+
def update_user_approval_status_txn(
1990+
self, txn: LoggingTransaction, user_id: str, approved: bool
1991+
) -> None:
1992+
"""Set the user's 'approved' flag to the given value.
1993+
1994+
The boolean is turned into an int because the column is a smallint.
1995+
1996+
Args:
1997+
txn: the current database transaction.
1998+
user_id: the user to update the flag for.
1999+
approved: the value to set the flag to.
2000+
"""
2001+
self.db_pool.simple_update_one_txn(
2002+
txn=txn,
2003+
table="users",
2004+
keyvalues={"name": user_id},
2005+
updatevalues={"approved": int(approved)},
2006+
)
2007+
2008+
# Invalidate the caches of methods that read the value of the 'approved' flag.
2009+
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
2010+
self._invalidate_cache_and_stream(txn, self.is_user_approved, (user_id,))
2011+
19182012

19192013
class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
19202014
def __init__(
@@ -1932,6 +2026,13 @@ def __init__(
19322026
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
19332027
self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
19342028

2029+
# If support for MSC3866 is enabled and configured to require approval for new
2030+
# account, we will create new users with an 'approved' flag set to false.
2031+
self._require_approval = (
2032+
hs.config.experimental.msc3866.enabled
2033+
and hs.config.experimental.msc3866.require_approval_for_new_accounts
2034+
)
2035+
19352036
async def add_access_token_to_user(
19362037
self,
19372038
user_id: str,
@@ -2064,6 +2165,7 @@ async def register_user(
20642165
admin: bool = False,
20652166
user_type: Optional[str] = None,
20662167
shadow_banned: bool = False,
2168+
approved: bool = False,
20672169
) -> None:
20682170
"""Attempts to register an account.
20692171
@@ -2082,6 +2184,8 @@ async def register_user(
20822184
or None for a normal user.
20832185
shadow_banned: Whether the user is shadow-banned, i.e. they may be
20842186
told their requests succeeded but we ignore them.
2187+
approved: Whether to consider the user has already been approved by an
2188+
administrator.
20852189
20862190
Raises:
20872191
StoreError if the user_id could not be registered.
@@ -2098,6 +2202,7 @@ async def register_user(
20982202
admin,
20992203
user_type,
21002204
shadow_banned,
2205+
approved,
21012206
)
21022207

21032208
def _register_user(
@@ -2112,11 +2217,14 @@ def _register_user(
21122217
admin: bool,
21132218
user_type: Optional[str],
21142219
shadow_banned: bool,
2220+
approved: bool,
21152221
) -> None:
21162222
user_id_obj = UserID.from_string(user_id)
21172223

21182224
now = int(self._clock.time())
21192225

2226+
pending_approval = self._require_approval and not approved
2227+
21202228
try:
21212229
if was_guest:
21222230
# Ensure that the guest user actually exists
@@ -2142,6 +2250,7 @@ def _register_user(
21422250
"admin": 1 if admin else 0,
21432251
"user_type": user_type,
21442252
"shadow_banned": shadow_banned,
2253+
"approved": 0 if pending_approval else 1,
21452254
},
21462255
)
21472256
else:
@@ -2157,6 +2266,7 @@ def _register_user(
21572266
"admin": 1 if admin else 0,
21582267
"user_type": user_type,
21592268
"shadow_banned": shadow_banned,
2269+
"approved": 0 if pending_approval else 1,
21602270
},
21612271
)
21622272

@@ -2499,6 +2609,25 @@ def start_or_continue_validation_session_txn(txn: LoggingTransaction) -> None:
24992609
start_or_continue_validation_session_txn,
25002610
)
25012611

2612+
async def update_user_approval_status(
2613+
self, user_id: UserID, approved: bool
2614+
) -> None:
2615+
"""Set the user's 'approved' flag to the given value.
2616+
2617+
The boolean will be turned into an int (in update_user_approval_status_txn)
2618+
because the column is a smallint.
2619+
2620+
Args:
2621+
user_id: the user to update the flag for.
2622+
approved: the value to set the flag to.
2623+
"""
2624+
await self.db_pool.runInteraction(
2625+
"update_user_approval_status",
2626+
self.update_user_approval_status_txn,
2627+
user_id.to_string(),
2628+
approved,
2629+
)
2630+
25022631

25032632
def find_max_generated_user_id_localpart(cur: Cursor) -> int:
25042633
"""
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/* Copyright 2022 The Matrix.org Foundation C.I.C
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
-- Add a column to the users table to track whether the user needs to be approved by an
17+
-- administrator.
18+
ALTER TABLE users ADD COLUMN approved SMALLINT;
19+
20+
-- Run a background update to set the approved flag on already existing users.
21+
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
22+
(7204, 'users_set_approved_flag', '{}');

tests/storage/test_registration.py

Lines changed: 104 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,15 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from twisted.test.proto_helpers import MemoryReactor
1415

1516
from synapse.api.constants import UserTypes
1617
from synapse.api.errors import ThreepidValidationError
18+
from synapse.server import HomeServer
19+
from synapse.types import JsonDict, UserID
20+
from synapse.util import Clock
1721

18-
from tests.unittest import HomeserverTestCase
22+
from tests.unittest import HomeserverTestCase, override_config
1923

2024

2125
class RegistrationStoreTestCase(HomeserverTestCase):
@@ -44,6 +48,7 @@ def test_register(self):
4448
"user_type": None,
4549
"deactivated": 0,
4650
"shadow_banned": 0,
51+
"approved": 1,
4752
},
4853
(self.get_success(self.store.get_user_by_id(self.user_id))),
4954
)
@@ -147,3 +152,101 @@ def test_3pid_inhibit_invalid_validation_session_error(self):
147152
ThreepidValidationError,
148153
)
149154
self.assertEqual(e.value.msg, "Validation token not found or has expired", e)
155+
156+
157+
class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
158+
def default_config(self) -> JsonDict:
159+
config = super().default_config()
160+
161+
# If there's already some config for this feature in the default config, it
162+
# means we're overriding it with @override_config. In this case we don't want
163+
# to do anything more with it.
164+
msc3866_config = config.get("experimental_features", {}).get("msc3866")
165+
if msc3866_config is not None:
166+
return config
167+
168+
# Require approval for all new accounts.
169+
config["experimental_features"] = {
170+
"msc3866": {
171+
"enabled": True,
172+
"require_approval_for_new_accounts": True,
173+
}
174+
}
175+
return config
176+
177+
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
178+
self.store = hs.get_datastores().main
179+
self.user_id = "@my-user:test"
180+
self.pwhash = "{xx1}123456789"
181+
182+
@override_config(
183+
{
184+
"experimental_features": {
185+
"msc3866": {
186+
"enabled": True,
187+
"require_approval_for_new_accounts": False,
188+
}
189+
}
190+
}
191+
)
192+
def test_approval_not_required(self) -> None:
193+
"""Tests that if we don't require approval for new accounts, newly created
194+
accounts are automatically marked as approved.
195+
"""
196+
self.get_success(self.store.register_user(self.user_id, self.pwhash))
197+
198+
user = self.get_success(self.store.get_user_by_id(self.user_id))
199+
assert user is not None
200+
self.assertEqual(user["approved"], 1)
201+
202+
approved = self.get_success(self.store.is_user_approved(self.user_id))
203+
self.assertTrue(approved)
204+
205+
def test_approval_required(self) -> None:
206+
"""Tests that if we require approval for new accounts, newly created accounts
207+
are not automatically marked as approved.
208+
"""
209+
self.get_success(self.store.register_user(self.user_id, self.pwhash))
210+
211+
user = self.get_success(self.store.get_user_by_id(self.user_id))
212+
assert user is not None
213+
self.assertEqual(user["approved"], 0)
214+
215+
approved = self.get_success(self.store.is_user_approved(self.user_id))
216+
self.assertFalse(approved)
217+
218+
def test_override(self) -> None:
219+
"""Tests that if we require approval for new accounts, but we explicitly say the
220+
new user should be considered approved, they're marked as approved.
221+
"""
222+
self.get_success(
223+
self.store.register_user(
224+
self.user_id,
225+
self.pwhash,
226+
approved=True,
227+
)
228+
)
229+
230+
user = self.get_success(self.store.get_user_by_id(self.user_id))
231+
self.assertIsNotNone(user)
232+
assert user is not None
233+
self.assertEqual(user["approved"], 1)
234+
235+
approved = self.get_success(self.store.is_user_approved(self.user_id))
236+
self.assertTrue(approved)
237+
238+
def test_approve_user(self) -> None:
239+
"""Tests that approving the user updates their approval status."""
240+
self.get_success(self.store.register_user(self.user_id, self.pwhash))
241+
242+
approved = self.get_success(self.store.is_user_approved(self.user_id))
243+
self.assertFalse(approved)
244+
245+
self.get_success(
246+
self.store.update_user_approval_status(
247+
UserID.from_string(self.user_id), True
248+
)
249+
)
250+
251+
approved = self.get_success(self.store.is_user_approved(self.user_id))
252+
self.assertTrue(approved)

0 commit comments

Comments
 (0)