@@ -181,6 +181,7 @@ async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
181181 "user_type" ,
182182 "deactivated" ,
183183 "shadow_banned" ,
184+ "approved" ,
184185 ],
185186 allow_none = True ,
186187 desc = "get_user_by_id" ,
@@ -1778,6 +1779,26 @@ async def is_guest(self, user_id: str) -> bool:
17781779
17791780 return res if res else False
17801781
1782+ @cached ()
1783+ async def is_user_approved (self , user_id : str ) -> bool :
1784+ """Checks if a user is approved and therefore can be allowed to log in.
1785+
1786+ Args:
1787+ user_id: the user to check the approval status of.
1788+
1789+ Returns:
1790+ A boolean that is True if the user is approved, False otherwise.
1791+ """
1792+ ret = await self .db_pool .simple_select_one_onecol (
1793+ table = "users" ,
1794+ keyvalues = {"name" : user_id },
1795+ retcol = "approved" ,
1796+ allow_none = True ,
1797+ desc = "is_user_pending_approval" ,
1798+ )
1799+
1800+ return bool (ret )
1801+
17811802
17821803class RegistrationBackgroundUpdateStore (RegistrationWorkerStore ):
17831804 def __init__ (
@@ -1817,6 +1838,10 @@ def __init__(
18171838 unique = False ,
18181839 )
18191840
1841+ self .db_pool .updates .register_background_update_handler (
1842+ "users_set_approved_flag" , self ._background_update_set_approved_flag
1843+ )
1844+
18201845 async def _background_update_set_deactivated_flag (
18211846 self , progress : JsonDict , batch_size : int
18221847 ) -> int :
@@ -1915,6 +1940,75 @@ def set_user_deactivated_status_txn(
19151940 self ._invalidate_cache_and_stream (txn , self .get_user_by_id , (user_id ,))
19161941 txn .call_after (self .is_guest .invalidate , (user_id ,))
19171942
1943+ async def _background_update_set_approved_flag (
1944+ self , progress : JsonDict , batch_size : int
1945+ ) -> int :
1946+ """Set the 'approved' flag for all already existing users. We want to set it to
1947+ true systematically because we don't want to suddenly prevent already existing
1948+ users from logging in if the option to block registration on approval is turned
1949+ on.
1950+ """
1951+ last_user = progress .get ("user_id" , "" )
1952+
1953+ def _background_update_set_approved_flag_txn (txn : LoggingTransaction ) -> int :
1954+ txn .execute (
1955+ """
1956+ SELECT name
1957+ FROM users
1958+ WHERE
1959+ approved IS NULL
1960+ AND name > ?
1961+ ORDER BY name ASC
1962+ LIMIT ?
1963+ """ ,
1964+ (last_user , batch_size ),
1965+ )
1966+ rows = self .db_pool .cursor_to_dict (txn )
1967+
1968+ if len (rows ) == 0 :
1969+ return 0
1970+
1971+ for user in rows :
1972+ self .update_user_approval_status_txn (txn , user ["name" ], True )
1973+
1974+ self .db_pool .updates ._background_update_progress_txn (
1975+ txn , "users_set_approved_flag" , {"user_id" : rows [- 1 ]["name" ]}
1976+ )
1977+
1978+ return len (rows )
1979+
1980+ nb_processed = await self .db_pool .runInteraction (
1981+ "users_set_approved_flag" , _background_update_set_approved_flag_txn
1982+ )
1983+
1984+ if nb_processed < batch_size :
1985+ await self .db_pool .updates ._end_background_update ("users_set_approved_flag" )
1986+
1987+ return nb_processed
1988+
1989+ def update_user_approval_status_txn (
1990+ self , txn : LoggingTransaction , user_id : str , approved : bool
1991+ ) -> None :
1992+ """Set the user's 'approved' flag to the given value.
1993+
1994+ The boolean is turned into an int because the column is a smallint.
1995+
1996+ Args:
1997+ txn: the current database transaction.
1998+ user_id: the user to update the flag for.
1999+ approved: the value to set the flag to.
2000+ """
2001+ self .db_pool .simple_update_one_txn (
2002+ txn = txn ,
2003+ table = "users" ,
2004+ keyvalues = {"name" : user_id },
2005+ updatevalues = {"approved" : int (approved )},
2006+ )
2007+
2008+ # Invalidate the caches of methods that read the value of the 'approved' flag.
2009+ self ._invalidate_cache_and_stream (txn , self .get_user_by_id , (user_id ,))
2010+ self ._invalidate_cache_and_stream (txn , self .is_user_approved , (user_id ,))
2011+
19182012
19192013class RegistrationStore (StatsStore , RegistrationBackgroundUpdateStore ):
19202014 def __init__ (
@@ -1932,6 +2026,13 @@ def __init__(
19322026 self ._access_tokens_id_gen = IdGenerator (db_conn , "access_tokens" , "id" )
19332027 self ._refresh_tokens_id_gen = IdGenerator (db_conn , "refresh_tokens" , "id" )
19342028
2029+ # If support for MSC3866 is enabled and configured to require approval for new
2030+ # account, we will create new users with an 'approved' flag set to false.
2031+ self ._require_approval = (
2032+ hs .config .experimental .msc3866 .enabled
2033+ and hs .config .experimental .msc3866 .require_approval_for_new_accounts
2034+ )
2035+
19352036 async def add_access_token_to_user (
19362037 self ,
19372038 user_id : str ,
@@ -2064,6 +2165,7 @@ async def register_user(
20642165 admin : bool = False ,
20652166 user_type : Optional [str ] = None ,
20662167 shadow_banned : bool = False ,
2168+ approved : bool = False ,
20672169 ) -> None :
20682170 """Attempts to register an account.
20692171
@@ -2082,6 +2184,8 @@ async def register_user(
20822184 or None for a normal user.
20832185 shadow_banned: Whether the user is shadow-banned, i.e. they may be
20842186 told their requests succeeded but we ignore them.
2187+ approved: Whether to consider the user has already been approved by an
2188+ administrator.
20852189
20862190 Raises:
20872191 StoreError if the user_id could not be registered.
@@ -2098,6 +2202,7 @@ async def register_user(
20982202 admin ,
20992203 user_type ,
21002204 shadow_banned ,
2205+ approved ,
21012206 )
21022207
21032208 def _register_user (
@@ -2112,11 +2217,14 @@ def _register_user(
21122217 admin : bool ,
21132218 user_type : Optional [str ],
21142219 shadow_banned : bool ,
2220+ approved : bool ,
21152221 ) -> None :
21162222 user_id_obj = UserID .from_string (user_id )
21172223
21182224 now = int (self ._clock .time ())
21192225
2226+ pending_approval = self ._require_approval and not approved
2227+
21202228 try :
21212229 if was_guest :
21222230 # Ensure that the guest user actually exists
@@ -2142,6 +2250,7 @@ def _register_user(
21422250 "admin" : 1 if admin else 0 ,
21432251 "user_type" : user_type ,
21442252 "shadow_banned" : shadow_banned ,
2253+ "approved" : 0 if pending_approval else 1 ,
21452254 },
21462255 )
21472256 else :
@@ -2157,6 +2266,7 @@ def _register_user(
21572266 "admin" : 1 if admin else 0 ,
21582267 "user_type" : user_type ,
21592268 "shadow_banned" : shadow_banned ,
2269+ "approved" : 0 if pending_approval else 1 ,
21602270 },
21612271 )
21622272
@@ -2499,6 +2609,25 @@ def start_or_continue_validation_session_txn(txn: LoggingTransaction) -> None:
24992609 start_or_continue_validation_session_txn ,
25002610 )
25012611
2612+ async def update_user_approval_status (
2613+ self , user_id : UserID , approved : bool
2614+ ) -> None :
2615+ """Set the user's 'approved' flag to the given value.
2616+
2617+ The boolean will be turned into an int (in update_user_approval_status_txn)
2618+ because the column is a smallint.
2619+
2620+ Args:
2621+ user_id: the user to update the flag for.
2622+ approved: the value to set the flag to.
2623+ """
2624+ await self .db_pool .runInteraction (
2625+ "update_user_approval_status" ,
2626+ self .update_user_approval_status_txn ,
2627+ user_id .to_string (),
2628+ approved ,
2629+ )
2630+
25022631
25032632def find_max_generated_user_id_localpart (cur : Cursor ) -> int :
25042633 """
0 commit comments