Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit a9f90fa

Browse files
authored
Type hints for RegistrationStore (#8615)
1 parent 2ac908f commit a9f90fa

File tree

4 files changed

+85
-74
lines changed

4 files changed

+85
-74
lines changed

changelog.d/8615.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Type hints for `RegistrationStore`.

mypy.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ files =
5757
synapse/spam_checker_api,
5858
synapse/state,
5959
synapse/storage/databases/main/events.py,
60+
synapse/storage/databases/main/registration.py,
6061
synapse/storage/databases/main/stream.py,
6162
synapse/storage/databases/main/ui_auth.py,
6263
synapse/storage/database.py,

synapse/storage/databases/main/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,6 @@ def __init__(self, database: DatabasePool, db_conn, hs):
146146
db_conn, "e2e_cross_signing_keys", "stream_id"
147147
)
148148

149-
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
150149
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
151150
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
152151
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")

synapse/storage/databases/main/registration.py

Lines changed: 83 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -16,29 +16,33 @@
1616
# limitations under the License.
1717
import logging
1818
import re
19-
from typing import Any, Dict, List, Optional, Tuple
19+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
2020

2121
from synapse.api.constants import UserTypes
2222
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
2323
from synapse.metrics.background_process_metrics import wrap_as_background_process
24-
from synapse.storage._base import SQLBaseStore
2524
from synapse.storage.database import DatabasePool
26-
from synapse.storage.types import Cursor
25+
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
26+
from synapse.storage.databases.main.stats import StatsStore
27+
from synapse.storage.types import Connection, Cursor
28+
from synapse.storage.util.id_generators import IdGenerator
2729
from synapse.storage.util.sequence import build_sequence_generator
2830
from synapse.types import UserID
2931
from synapse.util.caches.descriptors import cached
3032

33+
if TYPE_CHECKING:
34+
from synapse.server import HomeServer
35+
3136
THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
3237

3338
logger = logging.getLogger(__name__)
3439

3540

36-
class RegistrationWorkerStore(SQLBaseStore):
37-
def __init__(self, database: DatabasePool, db_conn, hs):
41+
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
42+
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
3843
super().__init__(database, db_conn, hs)
3944

4045
self.config = hs.config
41-
self.clock = hs.get_clock()
4246

4347
# Note: we don't check this sequence for consistency as we'd have to
4448
# call `find_max_generated_user_id_localpart` each time, which is
@@ -55,7 +59,7 @@ def __init__(self, database: DatabasePool, db_conn, hs):
5559

5660
# Create a background job for culling expired 3PID validity tokens
5761
if hs.config.run_background_tasks:
58-
self.clock.looping_call(
62+
self._clock.looping_call(
5963
self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS
6064
)
6165

@@ -92,7 +96,7 @@ async def is_trial_user(self, user_id: str) -> bool:
9296
if not info:
9397
return False
9498

95-
now = self.clock.time_msec()
99+
now = self._clock.time_msec()
96100
trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000
97101
is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
98102
return is_trial
@@ -257,7 +261,7 @@ def select_users_txn(txn, now_ms, renew_at):
257261
return await self.db_pool.runInteraction(
258262
"get_users_expiring_soon",
259263
select_users_txn,
260-
self.clock.time_msec(),
264+
self._clock.time_msec(),
261265
self.config.account_validity.renew_at,
262266
)
263267

@@ -328,13 +332,17 @@ def set_server_admin_txn(txn):
328332
await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
329333

330334
def _query_for_auth(self, txn, token):
331-
sql = (
332-
"SELECT users.name, users.is_guest, users.shadow_banned, access_tokens.id as token_id,"
333-
" access_tokens.device_id, access_tokens.valid_until_ms"
334-
" FROM users"
335-
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
336-
" WHERE token = ?"
337-
)
335+
sql = """
336+
SELECT users.name,
337+
users.is_guest,
338+
users.shadow_banned,
339+
access_tokens.id as token_id,
340+
access_tokens.device_id,
341+
access_tokens.valid_until_ms
342+
FROM users
343+
INNER JOIN access_tokens on users.name = access_tokens.user_id
344+
WHERE token = ?
345+
"""
338346

339347
txn.execute(sql, (token,))
340348
rows = self.db_pool.cursor_to_dict(txn)
@@ -803,7 +811,7 @@ def cull_expired_threepid_validation_tokens_txn(txn, ts):
803811
await self.db_pool.runInteraction(
804812
"cull_expired_threepid_validation_tokens",
805813
cull_expired_threepid_validation_tokens_txn,
806-
self.clock.time_msec(),
814+
self._clock.time_msec(),
807815
)
808816

809817
@wrap_as_background_process("account_validity_set_expiration_dates")
@@ -890,10 +898,10 @@ async def del_user_pending_deactivation(self, user_id: str) -> None:
890898

891899

892900
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
893-
def __init__(self, database: DatabasePool, db_conn, hs):
901+
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
894902
super().__init__(database, db_conn, hs)
895903

896-
self.clock = hs.get_clock()
904+
self._clock = hs.get_clock()
897905
self.config = hs.config
898906

899907
self.db_pool.updates.register_background_index_update(
@@ -1016,13 +1024,56 @@ def _bg_user_threepids_grandfather_txn(txn):
10161024

10171025
return 1
10181026

1027+
async def set_user_deactivated_status(
1028+
self, user_id: str, deactivated: bool
1029+
) -> None:
1030+
"""Set the `deactivated` property for the provided user to the provided value.
1031+
1032+
Args:
1033+
user_id: The ID of the user to set the status for.
1034+
deactivated: The value to set for `deactivated`.
1035+
"""
1036+
1037+
await self.db_pool.runInteraction(
1038+
"set_user_deactivated_status",
1039+
self.set_user_deactivated_status_txn,
1040+
user_id,
1041+
deactivated,
1042+
)
1043+
1044+
def set_user_deactivated_status_txn(self, txn, user_id: str, deactivated: bool):
1045+
self.db_pool.simple_update_one_txn(
1046+
txn=txn,
1047+
table="users",
1048+
keyvalues={"name": user_id},
1049+
updatevalues={"deactivated": 1 if deactivated else 0},
1050+
)
1051+
self._invalidate_cache_and_stream(
1052+
txn, self.get_user_deactivated_status, (user_id,)
1053+
)
1054+
txn.call_after(self.is_guest.invalidate, (user_id,))
1055+
1056+
@cached()
1057+
async def is_guest(self, user_id: str) -> bool:
1058+
res = await self.db_pool.simple_select_one_onecol(
1059+
table="users",
1060+
keyvalues={"name": user_id},
1061+
retcol="is_guest",
1062+
allow_none=True,
1063+
desc="is_guest",
1064+
)
1065+
1066+
return res if res else False
1067+
10191068

1020-
class RegistrationStore(RegistrationBackgroundUpdateStore):
1021-
def __init__(self, database: DatabasePool, db_conn, hs):
1069+
class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
1070+
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
10221071
super().__init__(database, db_conn, hs)
10231072

10241073
self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
10251074

1075+
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
1076+
10261077
async def add_access_token_to_user(
10271078
self,
10281079
user_id: str,
@@ -1138,19 +1189,19 @@ async def register_user(
11381189
def _register_user(
11391190
self,
11401191
txn,
1141-
user_id,
1142-
password_hash,
1143-
was_guest,
1144-
make_guest,
1145-
appservice_id,
1146-
create_profile_with_displayname,
1147-
admin,
1148-
user_type,
1149-
shadow_banned,
1192+
user_id: str,
1193+
password_hash: Optional[str],
1194+
was_guest: bool,
1195+
make_guest: bool,
1196+
appservice_id: Optional[str],
1197+
create_profile_with_displayname: Optional[str],
1198+
admin: bool,
1199+
user_type: Optional[str],
1200+
shadow_banned: bool,
11501201
):
11511202
user_id_obj = UserID.from_string(user_id)
11521203

1153-
now = int(self.clock.time())
1204+
now = int(self._clock.time())
11541205

11551206
try:
11561207
if was_guest:
@@ -1374,18 +1425,6 @@ def f(txn):
13741425

13751426
await self.db_pool.runInteraction("delete_access_token", f)
13761427

1377-
@cached()
1378-
async def is_guest(self, user_id: str) -> bool:
1379-
res = await self.db_pool.simple_select_one_onecol(
1380-
table="users",
1381-
keyvalues={"name": user_id},
1382-
retcol="is_guest",
1383-
allow_none=True,
1384-
desc="is_guest",
1385-
)
1386-
1387-
return res if res else False
1388-
13891428
async def add_user_pending_deactivation(self, user_id: str) -> None:
13901429
"""
13911430
Adds a user to the table of users who need to be parted from all the rooms they're
@@ -1479,7 +1518,7 @@ def validate_threepid_session_txn(txn):
14791518
txn,
14801519
table="threepid_validation_session",
14811520
keyvalues={"session_id": session_id},
1482-
updatevalues={"validated_at": self.clock.time_msec()},
1521+
updatevalues={"validated_at": self._clock.time_msec()},
14831522
)
14841523

14851524
return next_link
@@ -1547,35 +1586,6 @@ def start_or_continue_validation_session_txn(txn):
15471586
start_or_continue_validation_session_txn,
15481587
)
15491588

1550-
async def set_user_deactivated_status(
1551-
self, user_id: str, deactivated: bool
1552-
) -> None:
1553-
"""Set the `deactivated` property for the provided user to the provided value.
1554-
1555-
Args:
1556-
user_id: The ID of the user to set the status for.
1557-
deactivated: The value to set for `deactivated`.
1558-
"""
1559-
1560-
await self.db_pool.runInteraction(
1561-
"set_user_deactivated_status",
1562-
self.set_user_deactivated_status_txn,
1563-
user_id,
1564-
deactivated,
1565-
)
1566-
1567-
def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
1568-
self.db_pool.simple_update_one_txn(
1569-
txn=txn,
1570-
table="users",
1571-
keyvalues={"name": user_id},
1572-
updatevalues={"deactivated": 1 if deactivated else 0},
1573-
)
1574-
self._invalidate_cache_and_stream(
1575-
txn, self.get_user_deactivated_status, (user_id,)
1576-
)
1577-
txn.call_after(self.is_guest.invalidate, (user_id,))
1578-
15791589

15801590
def find_max_generated_user_id_localpart(cur: Cursor) -> int:
15811591
"""

0 commit comments

Comments
 (0)