Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 7c82da2

Browse files
authored
Add type hints to synapse/storage/databases/main (#11984)
1 parent 99f6d79 commit 7c82da2

File tree

7 files changed

+79
-53
lines changed

7 files changed

+79
-53
lines changed

changelog.d/11984.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add missing type hints to storage classes.

mypy.ini

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,11 @@ exclude = (?x)
3131
|synapse/storage/databases/main/group_server.py
3232
|synapse/storage/databases/main/metrics.py
3333
|synapse/storage/databases/main/monthly_active_users.py
34-
|synapse/storage/databases/main/presence.py
35-
|synapse/storage/databases/main/purge_events.py
3634
|synapse/storage/databases/main/push_rule.py
3735
|synapse/storage/databases/main/receipts.py
3836
|synapse/storage/databases/main/roommember.py
3937
|synapse/storage/databases/main/search.py
4038
|synapse/storage/databases/main/state.py
41-
|synapse/storage/databases/main/user_directory.py
4239
|synapse/storage/schema/
4340

4441
|tests/api/test_auth.py

synapse/handlers/presence.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -204,25 +204,27 @@ async def current_state_for_users(
204204
Returns:
205205
dict: `user_id` -> `UserPresenceState`
206206
"""
207-
states = {
208-
user_id: self.user_to_current_state.get(user_id, None)
209-
for user_id in user_ids
210-
}
207+
states = {}
208+
missing = []
209+
for user_id in user_ids:
210+
state = self.user_to_current_state.get(user_id, None)
211+
if state:
212+
states[user_id] = state
213+
else:
214+
missing.append(user_id)
211215

212-
missing = [user_id for user_id, state in states.items() if not state]
213216
if missing:
214217
# There are things not in our in memory cache. Lets pull them out of
215218
# the database.
216219
res = await self.store.get_presence_for_users(missing)
217220
states.update(res)
218221

219-
missing = [user_id for user_id, state in states.items() if not state]
220-
if missing:
221-
new = {
222-
user_id: UserPresenceState.default(user_id) for user_id in missing
223-
}
224-
states.update(new)
225-
self.user_to_current_state.update(new)
222+
for user_id in missing:
223+
# if user has no state in database, create the state
224+
if not res.get(user_id, None):
225+
new_state = UserPresenceState.default(user_id)
226+
states[user_id] = new_state
227+
self.user_to_current_state[user_id] = new_state
226228

227229
return states
228230

synapse/storage/databases/main/presence.py

Lines changed: 41 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,23 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
15+
from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple, cast
1616

1717
from synapse.api.presence import PresenceState, UserPresenceState
1818
from synapse.replication.tcp.streams import PresenceStream
1919
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
20-
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
20+
from synapse.storage.database import (
21+
DatabasePool,
22+
LoggingDatabaseConnection,
23+
LoggingTransaction,
24+
)
2125
from synapse.storage.engines import PostgresEngine
2226
from synapse.storage.types import Connection
23-
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
27+
from synapse.storage.util.id_generators import (
28+
AbstractStreamIdGenerator,
29+
MultiWriterIdGenerator,
30+
StreamIdGenerator,
31+
)
2432
from synapse.util.caches.descriptors import cached, cachedList
2533
from synapse.util.caches.stream_change_cache import StreamChangeCache
2634
from synapse.util.iterutils import batch_iter
@@ -35,7 +43,7 @@ def __init__(
3543
database: DatabasePool,
3644
db_conn: LoggingDatabaseConnection,
3745
hs: "HomeServer",
38-
):
46+
) -> None:
3947
super().__init__(database, db_conn, hs)
4048

4149
# Used by `PresenceStore._get_active_presence()`
@@ -54,11 +62,14 @@ def __init__(
5462
database: DatabasePool,
5563
db_conn: LoggingDatabaseConnection,
5664
hs: "HomeServer",
57-
):
65+
) -> None:
5866
super().__init__(database, db_conn, hs)
5967

68+
self._instance_name = hs.get_instance_name()
69+
self._presence_id_gen: AbstractStreamIdGenerator
70+
6071
self._can_persist_presence = (
61-
hs.get_instance_name() in hs.config.worker.writers.presence
72+
self._instance_name in hs.config.worker.writers.presence
6273
)
6374

6475
if isinstance(database.engine, PostgresEngine):
@@ -109,7 +120,9 @@ async def update_presence(self, presence_states) -> Tuple[int, int]:
109120

110121
return stream_orderings[-1], self._presence_id_gen.get_current_token()
111122

112-
def _update_presence_txn(self, txn, stream_orderings, presence_states):
123+
def _update_presence_txn(
124+
self, txn: LoggingTransaction, stream_orderings, presence_states
125+
) -> None:
113126
for stream_id, state in zip(stream_orderings, presence_states):
114127
txn.call_after(
115128
self.presence_stream_cache.entity_has_changed, state.user_id, stream_id
@@ -183,19 +196,23 @@ async def get_all_presence_updates(
183196
if last_id == current_id:
184197
return [], current_id, False
185198

186-
def get_all_presence_updates_txn(txn):
199+
def get_all_presence_updates_txn(
200+
txn: LoggingTransaction,
201+
) -> Tuple[List[Tuple[int, list]], int, bool]:
187202
sql = """
188203
SELECT stream_id, user_id, state, last_active_ts,
189204
last_federation_update_ts, last_user_sync_ts,
190-
status_msg,
191-
currently_active
205+
status_msg, currently_active
192206
FROM presence_stream
193207
WHERE ? < stream_id AND stream_id <= ?
194208
ORDER BY stream_id ASC
195209
LIMIT ?
196210
"""
197211
txn.execute(sql, (last_id, current_id, limit))
198-
updates = [(row[0], row[1:]) for row in txn]
212+
updates = cast(
213+
List[Tuple[int, list]],
214+
[(row[0], row[1:]) for row in txn],
215+
)
199216

200217
upper_bound = current_id
201218
limited = False
@@ -210,15 +227,17 @@ def get_all_presence_updates_txn(txn):
210227
)
211228

212229
@cached()
213-
def _get_presence_for_user(self, user_id):
230+
def _get_presence_for_user(self, user_id: str) -> None:
214231
raise NotImplementedError()
215232

216233
@cachedList(
217234
cached_method_name="_get_presence_for_user",
218235
list_name="user_ids",
219236
num_args=1,
220237
)
221-
async def get_presence_for_users(self, user_ids):
238+
async def get_presence_for_users(
239+
self, user_ids: Iterable[str]
240+
) -> Dict[str, UserPresenceState]:
222241
rows = await self.db_pool.simple_select_many_batch(
223242
table="presence_stream",
224243
column="user_id",
@@ -257,7 +276,9 @@ async def should_user_receive_full_presence_with_token(
257276
True if the user should have full presence sent to them, False otherwise.
258277
"""
259278

260-
def _should_user_receive_full_presence_with_token_txn(txn):
279+
def _should_user_receive_full_presence_with_token_txn(
280+
txn: LoggingTransaction,
281+
) -> bool:
261282
sql = """
262283
SELECT 1 FROM users_to_send_full_presence_to
263284
WHERE user_id = ?
@@ -271,7 +292,7 @@ def _should_user_receive_full_presence_with_token_txn(txn):
271292
_should_user_receive_full_presence_with_token_txn,
272293
)
273294

274-
async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]):
295+
async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]) -> None:
275296
"""Adds to the list of users who should receive a full snapshot of presence
276297
upon their next sync.
277298
@@ -353,10 +374,10 @@ async def get_presence_for_all_users(
353374

354375
return users_to_state
355376

356-
def get_current_presence_token(self):
377+
def get_current_presence_token(self) -> int:
357378
return self._presence_id_gen.get_current_token()
358379

359-
def _get_active_presence(self, db_conn: Connection):
380+
def _get_active_presence(self, db_conn: Connection) -> List[UserPresenceState]:
360381
"""Fetch non-offline presence from the database so that we can register
361382
the appropriate time outs.
362383
"""
@@ -379,12 +400,12 @@ def _get_active_presence(self, db_conn: Connection):
379400

380401
return [UserPresenceState(**row) for row in rows]
381402

382-
def take_presence_startup_info(self):
403+
def take_presence_startup_info(self) -> List[UserPresenceState]:
383404
active_on_startup = self._presence_on_startup
384-
self._presence_on_startup = None
405+
self._presence_on_startup = []
385406
return active_on_startup
386407

387-
def process_replication_rows(self, stream_name, instance_name, token, rows):
408+
def process_replication_rows(self, stream_name, instance_name, token, rows) -> None:
388409
if stream_name == PresenceStream.NAME:
389410
self._presence_id_gen.advance(instance_name, token)
390411
for row in rows:

synapse/storage/databases/main/purge_events.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@
1313
# limitations under the License.
1414

1515
import logging
16-
from typing import Any, List, Set, Tuple
16+
from typing import Any, List, Set, Tuple, cast
1717

1818
from synapse.api.errors import SynapseError
19+
from synapse.storage.database import LoggingTransaction
1920
from synapse.storage.databases.main import CacheInvalidationWorkerStore
2021
from synapse.storage.databases.main.state import StateGroupWorkerStore
2122
from synapse.types import RoomStreamToken
@@ -55,7 +56,11 @@ async def purge_history(
5556
)
5657

5758
def _purge_history_txn(
58-
self, txn, room_id: str, token: RoomStreamToken, delete_local_events: bool
59+
self,
60+
txn: LoggingTransaction,
61+
room_id: str,
62+
token: RoomStreamToken,
63+
delete_local_events: bool,
5964
) -> Set[int]:
6065
# Tables that should be pruned:
6166
# event_auth
@@ -273,7 +278,7 @@ def _purge_history_txn(
273278
""",
274279
(room_id,),
275280
)
276-
(min_depth,) = txn.fetchone()
281+
(min_depth,) = cast(Tuple[int], txn.fetchone())
277282

278283
logger.info("[purge] updating room_depth to %d", min_depth)
279284

@@ -318,7 +323,7 @@ async def purge_room(self, room_id: str) -> List[int]:
318323
"purge_room", self._purge_room_txn, room_id
319324
)
320325

321-
def _purge_room_txn(self, txn, room_id: str) -> List[int]:
326+
def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]:
322327
# First we fetch all the state groups that should be deleted, before
323328
# we delete that information.
324329
txn.execute(

synapse/storage/databases/main/user_directory.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def __init__(
5858
database: DatabasePool,
5959
db_conn: LoggingDatabaseConnection,
6060
hs: "HomeServer",
61-
):
61+
) -> None:
6262
super().__init__(database, db_conn, hs)
6363

6464
self.server_name = hs.hostname
@@ -234,10 +234,10 @@ def _get_next_batch(
234234
processed_event_count = 0
235235

236236
for room_id, event_count in rooms_to_work_on:
237-
is_in_room = await self.is_host_joined(room_id, self.server_name)
237+
is_in_room = await self.is_host_joined(room_id, self.server_name) # type: ignore[attr-defined]
238238

239239
if is_in_room:
240-
users_with_profile = await self.get_users_in_room_with_profiles(room_id)
240+
users_with_profile = await self.get_users_in_room_with_profiles(room_id) # type: ignore[attr-defined]
241241
# Throw away users excluded from the directory.
242242
users_with_profile = {
243243
user_id: profile
@@ -368,7 +368,7 @@ def _get_next_batch(txn: LoggingTransaction) -> Optional[List[str]]:
368368

369369
for user_id in users_to_work_on:
370370
if await self.should_include_local_user_in_dir(user_id):
371-
profile = await self.get_profileinfo(get_localpart_from_id(user_id))
371+
profile = await self.get_profileinfo(get_localpart_from_id(user_id)) # type: ignore[attr-defined]
372372
await self.update_profile_in_user_dir(
373373
user_id, profile.display_name, profile.avatar_url
374374
)
@@ -397,25 +397,25 @@ async def should_include_local_user_in_dir(self, user: str) -> bool:
397397
# technically it could be DM-able. In the future, this could potentially
398398
# be configurable per-appservice whether the appservice sender can be
399399
# contacted.
400-
if self.get_app_service_by_user_id(user) is not None:
400+
if self.get_app_service_by_user_id(user) is not None: # type: ignore[attr-defined]
401401
return False
402402

403403
# We're opting to exclude appservice users (anyone matching the user
404404
# namespace regex in the appservice registration) even though technically
405405
# they could be DM-able. In the future, this could potentially
406406
# be configurable per-appservice whether the appservice users can be
407407
# contacted.
408-
if self.get_if_app_services_interested_in_user(user):
408+
if self.get_if_app_services_interested_in_user(user): # type: ignore[attr-defined]
409409
# TODO we might want to make this configurable for each app service
410410
return False
411411

412412
# Support users are for diagnostics and should not appear in the user directory.
413-
if await self.is_support_user(user):
413+
if await self.is_support_user(user): # type: ignore[attr-defined]
414414
return False
415415

416416
# Deactivated users aren't contactable, so should not appear in the user directory.
417417
try:
418-
if await self.get_user_deactivated_status(user):
418+
if await self.get_user_deactivated_status(user): # type: ignore[attr-defined]
419419
return False
420420
except StoreError:
421421
# No such user in the users table. No need to do this when calling
@@ -433,20 +433,20 @@ async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> boo
433433
(EventTypes.RoomHistoryVisibility, ""),
434434
)
435435

436-
current_state_ids = await self.get_filtered_current_state_ids(
436+
current_state_ids = await self.get_filtered_current_state_ids( # type: ignore[attr-defined]
437437
room_id, StateFilter.from_types(types_to_filter)
438438
)
439439

440440
join_rules_id = current_state_ids.get((EventTypes.JoinRules, ""))
441441
if join_rules_id:
442-
join_rule_ev = await self.get_event(join_rules_id, allow_none=True)
442+
join_rule_ev = await self.get_event(join_rules_id, allow_none=True) # type: ignore[attr-defined]
443443
if join_rule_ev:
444444
if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC:
445445
return True
446446

447447
hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, ""))
448448
if hist_vis_id:
449-
hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True)
449+
hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True) # type: ignore[attr-defined]
450450
if hist_vis_ev:
451451
if (
452452
hist_vis_ev.content.get("history_visibility")

synapse/types.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151

5252
if TYPE_CHECKING:
5353
from synapse.appservice.api import ApplicationService
54-
from synapse.storage.databases.main import DataStore
54+
from synapse.storage.databases.main import DataStore, PurgeEventsStore
5555

5656
# Define a state map type from type/state_key to T (usually an event ID or
5757
# event)
@@ -485,7 +485,7 @@ def __attrs_post_init__(self) -> None:
485485
)
486486

487487
@classmethod
488-
async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken":
488+
async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken":
489489
try:
490490
if string[0] == "s":
491491
return cls(topological=None, stream=int(string[1:]))
@@ -502,7 +502,7 @@ async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken":
502502
instance_id = int(key)
503503
pos = int(value)
504504

505-
instance_name = await store.get_name_from_instance_id(instance_id)
505+
instance_name = await store.get_name_from_instance_id(instance_id) # type: ignore[attr-defined]
506506
instance_map[instance_name] = pos
507507

508508
return cls(

0 commit comments

Comments
 (0)