3030)
3131
3232import attr
33+
3334from twisted .internet .interfaces import IDelayedCall
3435
3536from synapse .api .constants import EventTypes
3839from synapse .logging .opentracing import set_tag
3940from synapse .metrics .background_process_metrics import run_as_background_process
4041from synapse .storage .databases .main .delayed_events import (
41- EventType ,
4242 Delay ,
4343 DelayID ,
44+ EventType ,
4445 StateKey ,
4546 Timestamp ,
4647 UserLocalpart ,
4748)
48- from synapse .types import JsonDict , Requester , RoomID , StateMap , UserID , create_requester
49+ from synapse .types import (
50+ JsonDict ,
51+ Requester ,
52+ RoomID ,
53+ StateMap ,
54+ UserID ,
55+ create_requester ,
56+ )
4957from synapse .util .async_helpers import Linearizer , ReadWriteLock
5058from synapse .util .stringutils import random_string
5159
@@ -89,7 +97,9 @@ def __init__(self, hs: "HomeServer"):
8997
9098 async def _schedule_db_events () -> None :
9199 # TODO: Sync all state first, so that affected delayed state events will be cancelled
92- events , remaining_timeout_delays = await self .store .process_all_delays (self ._get_current_ts ())
100+ events , remaining_timeout_delays = await self .store .process_all_delays (
101+ self ._get_current_ts ()
102+ )
93103 for args in events :
94104 await self ._send_event (* args )
95105
@@ -104,7 +114,9 @@ async def _schedule_db_events() -> None:
104114 "_schedule_db_events" , _schedule_db_events
105115 )
106116
107- async def on_new_event (self , event : EventBase , _state_events : StateMap [EventBase ]) -> None :
117+ async def on_new_event (
118+ self , event : EventBase , _state_events : StateMap [EventBase ]
119+ ) -> None :
108120 """
109121 Checks if a received event is a state event, and if so,
110122 cancels any delayed events that target the same state.
@@ -209,7 +221,7 @@ async def update(self, requester: Requester, delay_id: str, action: str) -> None
209221 except ValueError :
210222 raise SynapseError (
211223 HTTPStatus .BAD_REQUEST ,
212- f"'action' is not one of { ', ' .join (map ( lambda m : m .value , _UpdateDelayedEventAction ) )} " ,
224+ f"'action' is not one of { ', ' .join (m .value for m in _UpdateDelayedEventAction )} " ,
213225 Codes .INVALID_PARAM ,
214226 )
215227
@@ -220,7 +232,9 @@ async def update(self, requester: Requester, delay_id: str, action: str) -> None
220232
221233 async with self ._get_delay_context (delay_id , user_localpart ):
222234 if enum_action == _UpdateDelayedEventAction .CANCEL :
223- for removed_timeout_delay_id in await self .store .remove (delay_id , user_localpart ):
235+ for removed_timeout_delay_id in await self .store .remove (
236+ delay_id , user_localpart
237+ ):
224238 self ._unschedule (removed_timeout_delay_id , user_localpart )
225239
226240 elif enum_action == _UpdateDelayedEventAction .RESTART :
@@ -234,22 +248,29 @@ async def update(self, requester: Requester, delay_id: str, action: str) -> None
234248 self ._schedule (delay_id , user_localpart , delay )
235249
236250 elif enum_action == _UpdateDelayedEventAction .SEND :
237- args , removed_timeout_delay_ids = await self .store .pop_event (delay_id , user_localpart )
251+ args , removed_timeout_delay_ids = await self .store .pop_event (
252+ delay_id , user_localpart
253+ )
238254
239255 for timeout_delay_id in removed_timeout_delay_ids :
240256 self ._unschedule (timeout_delay_id , user_localpart )
241257 await self ._send_event (user_localpart , * args )
242258
243- async def _send_on_timeout (self , delay_id : DelayID , user_localpart : UserLocalpart ) -> None :
259+ async def _send_on_timeout (
260+ self , delay_id : DelayID , user_localpart : UserLocalpart
261+ ) -> None :
244262 del self ._delayed_calls [_DelayedCallKey (delay_id , user_localpart )]
245263
246264 async with self ._get_delay_context (delay_id , user_localpart ):
247265 try :
248- args , removed_timeout_delay_ids = await self .store .pop_event (delay_id , user_localpart )
266+ args , removed_timeout_delay_ids = await self .store .pop_event (
267+ delay_id , user_localpart
268+ )
249269 except NotFoundError :
250270 logger .debug (
251271 "delay_id %s for local user %s was removed after it timed out, but before it was sent on timeout" ,
252- delay_id , user_localpart ,
272+ delay_id ,
273+ user_localpart ,
253274 )
254275 return
255276
@@ -268,27 +289,36 @@ def _schedule(
268289 delay_sec = delay / 1000
269290
270291 logger .info (
271- "Scheduling delayed event %s for local user %s to be sent in %.3fs" , delay_id , user_localpart , delay_sec
272- )
273-
274- self ._delayed_calls [_DelayedCallKey (delay_id , user_localpart )] = self .clock .call_later (
275- delay_sec ,
276- run_as_background_process ,
277- "_send_on_timeout" ,
278- self ._send_on_timeout ,
292+ "Scheduling delayed event %s for local user %s to be sent in %.3fs" ,
279293 delay_id ,
280294 user_localpart ,
295+ delay_sec ,
296+ )
297+
298+ self ._delayed_calls [_DelayedCallKey (delay_id , user_localpart )] = (
299+ self .clock .call_later (
300+ delay_sec ,
301+ run_as_background_process ,
302+ "_send_on_timeout" ,
303+ self ._send_on_timeout ,
304+ delay_id ,
305+ user_localpart ,
306+ )
281307 )
282308
283309 def _unschedule (self , delay_id : DelayID , user_localpart : UserLocalpart ) -> None :
284- delayed_call = self ._delayed_calls .pop (_DelayedCallKey (delay_id , user_localpart ))
310+ delayed_call = self ._delayed_calls .pop (
311+ _DelayedCallKey (delay_id , user_localpart )
312+ )
285313 self .clock .cancel_call_later (delayed_call )
286314
287315 async def get_all_for_user (self , requester : Requester ) -> List [JsonDict ]:
288316 """Return all pending delayed events requested by the given user."""
289317 await self .request_ratelimiter .ratelimit (requester )
290318 await self ._initialized_from_db
291- return await self .store .get_all_for_user (UserLocalpart (requester .user .localpart ))
319+ return await self .store .get_all_for_user (
320+ UserLocalpart (requester .user .localpart )
321+ )
292322
293323 async def _send_event (
294324 self ,
@@ -350,13 +380,14 @@ def _get_current_ts(self) -> Timestamp:
350380 return Timestamp (self .clock .time_msec ())
351381
352382 @asynccontextmanager
353- async def _get_delay_context (self , delay_id : DelayID , user_localpart : UserLocalpart ) -> AsyncIterator [None ]:
383+ async def _get_delay_context (
384+ self , delay_id : DelayID , user_localpart : UserLocalpart
385+ ) -> AsyncIterator [None ]:
354386 await self ._initialized_from_db
355387 # TODO: Use parenthesized context manager once the minimum supported Python version is 3.10
356- async with \
357- self ._state_lock .read (_STATE_LOCK_KEY ),\
358- self ._linearizer .queue (_DelayedCallKey (delay_id , user_localpart ))\
359- :
388+ async with self ._state_lock .read (_STATE_LOCK_KEY ), self ._linearizer .queue (
389+ _DelayedCallKey (delay_id , user_localpart )
390+ ):
360391 yield
361392
362393 def _get_state_context (self ) -> AsyncContextManager :
0 commit comments