12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
- from typing import TYPE_CHECKING , Dict , Iterable , List , Tuple
15
+ from typing import TYPE_CHECKING , Dict , Iterable , List , Tuple , cast
16
16
17
17
from synapse .api .presence import PresenceState , UserPresenceState
18
18
from synapse .replication .tcp .streams import PresenceStream
19
19
from synapse .storage ._base import SQLBaseStore , make_in_list_sql_clause
20
- from synapse .storage .database import DatabasePool , LoggingDatabaseConnection
20
+ from synapse .storage .database import (
21
+ DatabasePool ,
22
+ LoggingDatabaseConnection ,
23
+ LoggingTransaction ,
24
+ )
21
25
from synapse .storage .engines import PostgresEngine
22
26
from synapse .storage .types import Connection
23
- from synapse .storage .util .id_generators import MultiWriterIdGenerator , StreamIdGenerator
27
+ from synapse .storage .util .id_generators import (
28
+ AbstractStreamIdGenerator ,
29
+ MultiWriterIdGenerator ,
30
+ StreamIdGenerator ,
31
+ )
24
32
from synapse .util .caches .descriptors import cached , cachedList
25
33
from synapse .util .caches .stream_change_cache import StreamChangeCache
26
34
from synapse .util .iterutils import batch_iter
@@ -35,7 +43,7 @@ def __init__(
35
43
database : DatabasePool ,
36
44
db_conn : LoggingDatabaseConnection ,
37
45
hs : "HomeServer" ,
38
- ):
46
+ ) -> None :
39
47
super ().__init__ (database , db_conn , hs )
40
48
41
49
# Used by `PresenceStore._get_active_presence()`
@@ -54,11 +62,14 @@ def __init__(
54
62
database : DatabasePool ,
55
63
db_conn : LoggingDatabaseConnection ,
56
64
hs : "HomeServer" ,
57
- ):
65
+ ) -> None :
58
66
super ().__init__ (database , db_conn , hs )
59
67
68
+ self ._instance_name = hs .get_instance_name ()
69
+ self ._presence_id_gen : AbstractStreamIdGenerator
70
+
60
71
self ._can_persist_presence = (
61
- hs . get_instance_name () in hs .config .worker .writers .presence
72
+ self . _instance_name in hs .config .worker .writers .presence
62
73
)
63
74
64
75
if isinstance (database .engine , PostgresEngine ):
@@ -109,7 +120,9 @@ async def update_presence(self, presence_states) -> Tuple[int, int]:
109
120
110
121
return stream_orderings [- 1 ], self ._presence_id_gen .get_current_token ()
111
122
112
- def _update_presence_txn (self , txn , stream_orderings , presence_states ):
123
+ def _update_presence_txn (
124
+ self , txn : LoggingTransaction , stream_orderings , presence_states
125
+ ) -> None :
113
126
for stream_id , state in zip (stream_orderings , presence_states ):
114
127
txn .call_after (
115
128
self .presence_stream_cache .entity_has_changed , state .user_id , stream_id
@@ -183,19 +196,23 @@ async def get_all_presence_updates(
183
196
if last_id == current_id :
184
197
return [], current_id , False
185
198
186
- def get_all_presence_updates_txn (txn ):
199
+ def get_all_presence_updates_txn (
200
+ txn : LoggingTransaction ,
201
+ ) -> Tuple [List [Tuple [int , list ]], int , bool ]:
187
202
sql = """
188
203
SELECT stream_id, user_id, state, last_active_ts,
189
204
last_federation_update_ts, last_user_sync_ts,
190
- status_msg,
191
- currently_active
205
+ status_msg, currently_active
192
206
FROM presence_stream
193
207
WHERE ? < stream_id AND stream_id <= ?
194
208
ORDER BY stream_id ASC
195
209
LIMIT ?
196
210
"""
197
211
txn .execute (sql , (last_id , current_id , limit ))
198
- updates = [(row [0 ], row [1 :]) for row in txn ]
212
+ updates = cast (
213
+ List [Tuple [int , list ]],
214
+ [(row [0 ], row [1 :]) for row in txn ],
215
+ )
199
216
200
217
upper_bound = current_id
201
218
limited = False
@@ -210,15 +227,17 @@ def get_all_presence_updates_txn(txn):
210
227
)
211
228
212
229
@cached ()
213
- def _get_presence_for_user (self , user_id ) :
230
+ def _get_presence_for_user (self , user_id : str ) -> None :
214
231
raise NotImplementedError ()
215
232
216
233
@cachedList (
217
234
cached_method_name = "_get_presence_for_user" ,
218
235
list_name = "user_ids" ,
219
236
num_args = 1 ,
220
237
)
221
- async def get_presence_for_users (self , user_ids ):
238
+ async def get_presence_for_users (
239
+ self , user_ids : Iterable [str ]
240
+ ) -> Dict [str , UserPresenceState ]:
222
241
rows = await self .db_pool .simple_select_many_batch (
223
242
table = "presence_stream" ,
224
243
column = "user_id" ,
@@ -257,7 +276,9 @@ async def should_user_receive_full_presence_with_token(
257
276
True if the user should have full presence sent to them, False otherwise.
258
277
"""
259
278
260
- def _should_user_receive_full_presence_with_token_txn (txn ):
279
+ def _should_user_receive_full_presence_with_token_txn (
280
+ txn : LoggingTransaction ,
281
+ ) -> bool :
261
282
sql = """
262
283
SELECT 1 FROM users_to_send_full_presence_to
263
284
WHERE user_id = ?
@@ -271,7 +292,7 @@ def _should_user_receive_full_presence_with_token_txn(txn):
271
292
_should_user_receive_full_presence_with_token_txn ,
272
293
)
273
294
274
- async def add_users_to_send_full_presence_to (self , user_ids : Iterable [str ]):
295
+ async def add_users_to_send_full_presence_to (self , user_ids : Iterable [str ]) -> None :
275
296
"""Adds to the list of users who should receive a full snapshot of presence
276
297
upon their next sync.
277
298
@@ -353,10 +374,10 @@ async def get_presence_for_all_users(
353
374
354
375
return users_to_state
355
376
356
- def get_current_presence_token (self ):
377
+ def get_current_presence_token (self ) -> int :
357
378
return self ._presence_id_gen .get_current_token ()
358
379
359
- def _get_active_presence (self , db_conn : Connection ):
380
+ def _get_active_presence (self , db_conn : Connection ) -> List [ UserPresenceState ] :
360
381
"""Fetch non-offline presence from the database so that we can register
361
382
the appropriate time outs.
362
383
"""
@@ -379,12 +400,12 @@ def _get_active_presence(self, db_conn: Connection):
379
400
380
401
return [UserPresenceState (** row ) for row in rows ]
381
402
382
- def take_presence_startup_info (self ):
403
+ def take_presence_startup_info (self ) -> List [ UserPresenceState ] :
383
404
active_on_startup = self ._presence_on_startup
384
- self ._presence_on_startup = None
405
+ self ._presence_on_startup = []
385
406
return active_on_startup
386
407
387
- def process_replication_rows (self , stream_name , instance_name , token , rows ):
408
+ def process_replication_rows (self , stream_name , instance_name , token , rows ) -> None :
388
409
if stream_name == PresenceStream .NAME :
389
410
self ._presence_id_gen .advance (instance_name , token )
390
411
for row in rows :
0 commit comments