1616# limitations under the License.
1717import logging
1818import re
19- from typing import Any , Dict , List , Optional , Tuple
19+ from typing import TYPE_CHECKING , Any , Dict , List , Optional , Tuple
2020
2121from synapse .api .constants import UserTypes
2222from synapse .api .errors import Codes , StoreError , SynapseError , ThreepidValidationError
2323from synapse .metrics .background_process_metrics import wrap_as_background_process
24- from synapse .storage ._base import SQLBaseStore
2524from synapse .storage .database import DatabasePool
26- from synapse .storage .types import Cursor
25+ from synapse .storage .databases .main .cache import CacheInvalidationWorkerStore
26+ from synapse .storage .databases .main .stats import StatsStore
27+ from synapse .storage .types import Connection , Cursor
28+ from synapse .storage .util .id_generators import IdGenerator
2729from synapse .storage .util .sequence import build_sequence_generator
2830from synapse .types import UserID
2931from synapse .util .caches .descriptors import cached
3032
33+ if TYPE_CHECKING :
34+ from synapse .server import HomeServer
35+
3136THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
3237
3338logger = logging .getLogger (__name__ )
3439
3540
36- class RegistrationWorkerStore (SQLBaseStore ):
37- def __init__ (self , database : DatabasePool , db_conn , hs ):
41+ class RegistrationWorkerStore (CacheInvalidationWorkerStore ):
42+ def __init__ (self , database : DatabasePool , db_conn : Connection , hs : "HomeServer" ):
3843 super ().__init__ (database , db_conn , hs )
3944
4045 self .config = hs .config
41- self .clock = hs .get_clock ()
4246
4347 # Note: we don't check this sequence for consistency as we'd have to
4448 # call `find_max_generated_user_id_localpart` each time, which is
@@ -55,7 +59,7 @@ def __init__(self, database: DatabasePool, db_conn, hs):
5559
5660 # Create a background job for culling expired 3PID validity tokens
5761 if hs .config .run_background_tasks :
58- self .clock .looping_call (
62+ self ._clock .looping_call (
5963 self .cull_expired_threepid_validation_tokens , THIRTY_MINUTES_IN_MS
6064 )
6165
@@ -92,7 +96,7 @@ async def is_trial_user(self, user_id: str) -> bool:
9296 if not info :
9397 return False
9498
95- now = self .clock .time_msec ()
99+ now = self ._clock .time_msec ()
96100 trial_duration_ms = self .config .mau_trial_days * 24 * 60 * 60 * 1000
97101 is_trial = (now - info ["creation_ts" ] * 1000 ) < trial_duration_ms
98102 return is_trial
@@ -257,7 +261,7 @@ def select_users_txn(txn, now_ms, renew_at):
257261 return await self .db_pool .runInteraction (
258262 "get_users_expiring_soon" ,
259263 select_users_txn ,
260- self .clock .time_msec (),
264+ self ._clock .time_msec (),
261265 self .config .account_validity .renew_at ,
262266 )
263267
@@ -328,13 +332,17 @@ def set_server_admin_txn(txn):
328332 await self .db_pool .runInteraction ("set_server_admin" , set_server_admin_txn )
329333
330334 def _query_for_auth (self , txn , token ):
331- sql = (
332- "SELECT users.name, users.is_guest, users.shadow_banned, access_tokens.id as token_id,"
333- " access_tokens.device_id, access_tokens.valid_until_ms"
334- " FROM users"
335- " INNER JOIN access_tokens on users.name = access_tokens.user_id"
336- " WHERE token = ?"
337- )
335+ sql = """
336+ SELECT users.name,
337+ users.is_guest,
338+ users.shadow_banned,
339+ access_tokens.id as token_id,
340+ access_tokens.device_id,
341+ access_tokens.valid_until_ms
342+ FROM users
343+ INNER JOIN access_tokens on users.name = access_tokens.user_id
344+ WHERE token = ?
345+ """
338346
339347 txn .execute (sql , (token ,))
340348 rows = self .db_pool .cursor_to_dict (txn )
@@ -803,7 +811,7 @@ def cull_expired_threepid_validation_tokens_txn(txn, ts):
803811 await self .db_pool .runInteraction (
804812 "cull_expired_threepid_validation_tokens" ,
805813 cull_expired_threepid_validation_tokens_txn ,
806- self .clock .time_msec (),
814+ self ._clock .time_msec (),
807815 )
808816
809817 @wrap_as_background_process ("account_validity_set_expiration_dates" )
@@ -890,10 +898,10 @@ async def del_user_pending_deactivation(self, user_id: str) -> None:
890898
891899
892900class RegistrationBackgroundUpdateStore (RegistrationWorkerStore ):
893- def __init__ (self , database : DatabasePool , db_conn , hs ):
901+ def __init__ (self , database : DatabasePool , db_conn : Connection , hs : "HomeServer" ):
894902 super ().__init__ (database , db_conn , hs )
895903
896- self .clock = hs .get_clock ()
904+ self ._clock = hs .get_clock ()
897905 self .config = hs .config
898906
899907 self .db_pool .updates .register_background_index_update (
@@ -1016,13 +1024,56 @@ def _bg_user_threepids_grandfather_txn(txn):
10161024
10171025 return 1
10181026
1027+ async def set_user_deactivated_status (
1028+ self , user_id : str , deactivated : bool
1029+ ) -> None :
1030+ """Set the `deactivated` property for the provided user to the provided value.
1031+
1032+ Args:
1033+ user_id: The ID of the user to set the status for.
1034+ deactivated: The value to set for `deactivated`.
1035+ """
1036+
1037+ await self .db_pool .runInteraction (
1038+ "set_user_deactivated_status" ,
1039+ self .set_user_deactivated_status_txn ,
1040+ user_id ,
1041+ deactivated ,
1042+ )
1043+
1044+ def set_user_deactivated_status_txn (self , txn , user_id : str , deactivated : bool ):
1045+ self .db_pool .simple_update_one_txn (
1046+ txn = txn ,
1047+ table = "users" ,
1048+ keyvalues = {"name" : user_id },
1049+ updatevalues = {"deactivated" : 1 if deactivated else 0 },
1050+ )
1051+ self ._invalidate_cache_and_stream (
1052+ txn , self .get_user_deactivated_status , (user_id ,)
1053+ )
1054+ txn .call_after (self .is_guest .invalidate , (user_id ,))
1055+
1056+ @cached ()
1057+ async def is_guest (self , user_id : str ) -> bool :
1058+ res = await self .db_pool .simple_select_one_onecol (
1059+ table = "users" ,
1060+ keyvalues = {"name" : user_id },
1061+ retcol = "is_guest" ,
1062+ allow_none = True ,
1063+ desc = "is_guest" ,
1064+ )
1065+
1066+ return res if res else False
1067+
10191068
1020- class RegistrationStore (RegistrationBackgroundUpdateStore ):
1021- def __init__ (self , database : DatabasePool , db_conn , hs ):
1069+ class RegistrationStore (StatsStore , RegistrationBackgroundUpdateStore ):
1070+ def __init__ (self , database : DatabasePool , db_conn : Connection , hs : "HomeServer" ):
10221071 super ().__init__ (database , db_conn , hs )
10231072
10241073 self ._ignore_unknown_session_error = hs .config .request_token_inhibit_3pid_errors
10251074
1075+ self ._access_tokens_id_gen = IdGenerator (db_conn , "access_tokens" , "id" )
1076+
10261077 async def add_access_token_to_user (
10271078 self ,
10281079 user_id : str ,
@@ -1138,19 +1189,19 @@ async def register_user(
11381189 def _register_user (
11391190 self ,
11401191 txn ,
1141- user_id ,
1142- password_hash ,
1143- was_guest ,
1144- make_guest ,
1145- appservice_id ,
1146- create_profile_with_displayname ,
1147- admin ,
1148- user_type ,
1149- shadow_banned ,
1192+ user_id : str ,
1193+ password_hash : Optional [ str ] ,
1194+ was_guest : bool ,
1195+ make_guest : bool ,
1196+ appservice_id : Optional [ str ] ,
1197+ create_profile_with_displayname : Optional [ str ] ,
1198+ admin : bool ,
1199+ user_type : Optional [ str ] ,
1200+ shadow_banned : bool ,
11501201 ):
11511202 user_id_obj = UserID .from_string (user_id )
11521203
1153- now = int (self .clock .time ())
1204+ now = int (self ._clock .time ())
11541205
11551206 try :
11561207 if was_guest :
@@ -1374,18 +1425,6 @@ def f(txn):
13741425
13751426 await self .db_pool .runInteraction ("delete_access_token" , f )
13761427
1377- @cached ()
1378- async def is_guest (self , user_id : str ) -> bool :
1379- res = await self .db_pool .simple_select_one_onecol (
1380- table = "users" ,
1381- keyvalues = {"name" : user_id },
1382- retcol = "is_guest" ,
1383- allow_none = True ,
1384- desc = "is_guest" ,
1385- )
1386-
1387- return res if res else False
1388-
13891428 async def add_user_pending_deactivation (self , user_id : str ) -> None :
13901429 """
13911430 Adds a user to the table of users who need to be parted from all the rooms they're
@@ -1479,7 +1518,7 @@ def validate_threepid_session_txn(txn):
14791518 txn ,
14801519 table = "threepid_validation_session" ,
14811520 keyvalues = {"session_id" : session_id },
1482- updatevalues = {"validated_at" : self .clock .time_msec ()},
1521+ updatevalues = {"validated_at" : self ._clock .time_msec ()},
14831522 )
14841523
14851524 return next_link
@@ -1547,35 +1586,6 @@ def start_or_continue_validation_session_txn(txn):
15471586 start_or_continue_validation_session_txn ,
15481587 )
15491588
1550- async def set_user_deactivated_status (
1551- self , user_id : str , deactivated : bool
1552- ) -> None :
1553- """Set the `deactivated` property for the provided user to the provided value.
1554-
1555- Args:
1556- user_id: The ID of the user to set the status for.
1557- deactivated: The value to set for `deactivated`.
1558- """
1559-
1560- await self .db_pool .runInteraction (
1561- "set_user_deactivated_status" ,
1562- self .set_user_deactivated_status_txn ,
1563- user_id ,
1564- deactivated ,
1565- )
1566-
1567- def set_user_deactivated_status_txn (self , txn , user_id , deactivated ):
1568- self .db_pool .simple_update_one_txn (
1569- txn = txn ,
1570- table = "users" ,
1571- keyvalues = {"name" : user_id },
1572- updatevalues = {"deactivated" : 1 if deactivated else 0 },
1573- )
1574- self ._invalidate_cache_and_stream (
1575- txn , self .get_user_deactivated_status , (user_id ,)
1576- )
1577- txn .call_after (self .is_guest .invalidate , (user_id ,))
1578-
15791589
15801590def find_max_generated_user_id_localpart (cur : Cursor ) -> int :
15811591 """
0 commit comments