Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit b7672ff

Browse files
committed
Merge commit 'aec708517' into anoa/dinsic_release_1_21_x
* commit 'aec708517': Convert state and stream stores and related code to async (#8194) Ensure that the OpenID Connect remote ID is a string. (#8190)
2 parents 41ac123 + aec7085 commit b7672ff

File tree

10 files changed

+94
-47
lines changed

10 files changed

+94
-47
lines changed

changelog.d/8190.bugfix

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix logging in via OpenID Connect with a provider that uses integer user IDs.

changelog.d/8194.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Convert various parts of the codebase to async/await.

synapse/handlers/oidc_handler.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -869,6 +869,9 @@ async def _map_userinfo_to_user(
869869
raise MappingException(
870870
"Failed to extract subject from OIDC response: %s" % (e,)
871871
)
872+
# Some OIDC providers use integer IDs, but Synapse expects external IDs
873+
# to be strings.
874+
remote_user_id = str(remote_user_id)
872875

873876
logger.info(
874877
"Looking for existing mapping for user %s:%s",

synapse/handlers/room.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ async def clone_existing_room(
463463
old_room_member_state_events = await self.store.get_events(
464464
old_room_member_state_ids.values()
465465
)
466-
for k, old_event in old_room_member_state_events.items():
466+
for old_event in old_room_member_state_events.values():
467467
# Only transfer ban events
468468
if (
469469
"membership" in old_event.content

synapse/storage/databases/main/state.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from synapse.storage.databases.main.events_worker import EventsWorkerStore
2828
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
2929
from synapse.storage.state import StateFilter
30+
from synapse.types import StateMap
3031
from synapse.util.caches import intern_string
3132
from synapse.util.caches.descriptors import cached, cachedList
3233

@@ -163,15 +164,15 @@ async def get_create_event_for_room(self, room_id: str) -> EventBase:
163164
return create_event
164165

165166
@cached(max_entries=100000, iterable=True)
166-
def get_current_state_ids(self, room_id):
167+
async def get_current_state_ids(self, room_id: str) -> StateMap[str]:
167168
"""Get the current state event ids for a room based on the
168169
current_state_events table.
169170
170171
Args:
171-
room_id (str)
172+
room_id: The room to get the state IDs of.
172173
173174
Returns:
174-
deferred: dict of (type, state_key) -> event_id
175+
The current state of the room.
175176
"""
176177

177178
def _get_current_state_ids_txn(txn):
@@ -184,14 +185,14 @@ def _get_current_state_ids_txn(txn):
184185

185186
return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn}
186187

187-
return self.db_pool.runInteraction(
188+
return await self.db_pool.runInteraction(
188189
"get_current_state_ids", _get_current_state_ids_txn
189190
)
190191

191192
# FIXME: how should this be cached?
192-
def get_filtered_current_state_ids(
193+
async def get_filtered_current_state_ids(
193194
self, room_id: str, state_filter: StateFilter = StateFilter.all()
194-
):
195+
) -> StateMap[str]:
195196
"""Get the current state event of a given type for a room based on the
196197
current_state_events table. This may not be as up-to-date as the result
197198
of doing a fresh state resolution as per state_handler.get_current_state
@@ -202,14 +203,14 @@ def get_filtered_current_state_ids(
202203
from the database.
203204
204205
Returns:
205-
defer.Deferred[StateMap[str]]: Map from type/state_key to event ID.
206+
Map from type/state_key to event ID.
206207
"""
207208

208209
where_clause, where_args = state_filter.make_sql_filter_clause()
209210

210211
if not where_clause:
211212
# We delegate to the cached version
212-
return self.get_current_state_ids(room_id)
213+
return await self.get_current_state_ids(room_id)
213214

214215
def _get_filtered_current_state_ids_txn(txn):
215216
results = {}
@@ -231,7 +232,7 @@ def _get_filtered_current_state_ids_txn(txn):
231232

232233
return results
233234

234-
return self.db_pool.runInteraction(
235+
return await self.db_pool.runInteraction(
235236
"get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
236237
)
237238

synapse/storage/databases/main/state_deltas.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,17 @@
1414
# limitations under the License.
1515

1616
import logging
17-
18-
from twisted.internet import defer
17+
from typing import Any, Dict, List, Tuple
1918

2019
from synapse.storage._base import SQLBaseStore
2120

2221
logger = logging.getLogger(__name__)
2322

2423

2524
class StateDeltasStore(SQLBaseStore):
26-
def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int):
25+
async def get_current_state_deltas(
26+
self, prev_stream_id: int, max_stream_id: int
27+
) -> Tuple[int, List[Dict[str, Any]]]:
2728
"""Fetch a list of room state changes since the given stream id
2829
2930
Each entry in the result contains the following fields:
@@ -37,12 +38,12 @@ def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int):
3738
if it's new state.
3839
3940
Args:
40-
prev_stream_id (int): point to get changes since (exclusive)
41-
max_stream_id (int): the point that we know has been correctly persisted
41+
prev_stream_id: point to get changes since (exclusive)
42+
max_stream_id: the point that we know has been correctly persisted
4243
- ie, an upper limit to return changes from.
4344
4445
Returns:
45-
Deferred[tuple[int, list[dict]]: A tuple consisting of:
46+
A tuple consisting of:
4647
- the stream id which these results go up to
4748
- list of current_state_delta_stream rows. If it is empty, we are
4849
up to date.
@@ -58,7 +59,7 @@ def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int):
5859
# if the CSDs haven't changed between prev_stream_id and now, we
5960
# know for certain that they haven't changed between prev_stream_id and
6061
# max_stream_id.
61-
return defer.succeed((max_stream_id, []))
62+
return (max_stream_id, [])
6263

6364
def get_current_state_deltas_txn(txn):
6465
# First we calculate the max stream id that will give us less than
@@ -102,7 +103,7 @@ def get_current_state_deltas_txn(txn):
102103
txn.execute(sql, (prev_stream_id, clipped_stream_id))
103104
return clipped_stream_id, self.db_pool.cursor_to_dict(txn)
104105

105-
return self.db_pool.runInteraction(
106+
return await self.db_pool.runInteraction(
106107
"get_current_state_deltas", get_current_state_deltas_txn
107108
)
108109

@@ -114,8 +115,8 @@ def _get_max_stream_id_in_current_state_deltas_txn(self, txn):
114115
retcol="COALESCE(MAX(stream_id), -1)",
115116
)
116117

117-
def get_max_stream_id_in_current_state_deltas(self):
118-
return self.db_pool.runInteraction(
118+
async def get_max_stream_id_in_current_state_deltas(self):
119+
return await self.db_pool.runInteraction(
119120
"get_max_stream_id_in_current_state_deltas",
120121
self._get_max_stream_id_in_current_state_deltas_txn,
121122
)

synapse/storage/databases/main/stream.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -539,16 +539,17 @@ async def get_recent_event_ids_for_room(
539539

540540
return rows, token
541541

542-
def get_room_event_before_stream_ordering(self, room_id: str, stream_ordering: int):
542+
async def get_room_event_before_stream_ordering(
543+
self, room_id: str, stream_ordering: int
544+
) -> Tuple[int, int, str]:
543545
"""Gets details of the first event in a room at or before a stream ordering
544546
545547
Args:
546548
room_id:
547549
stream_ordering:
548550
549551
Returns:
550-
Deferred[(int, int, str)]:
551-
(stream ordering, topological ordering, event_id)
552+
A tuple of (stream ordering, topological ordering, event_id)
552553
"""
553554

554555
def _f(txn):
@@ -563,7 +564,9 @@ def _f(txn):
563564
txn.execute(sql, (room_id, stream_ordering))
564565
return txn.fetchone()
565566

566-
return self.db_pool.runInteraction("get_room_event_before_stream_ordering", _f)
567+
return await self.db_pool.runInteraction(
568+
"get_room_event_before_stream_ordering", _f
569+
)
567570

568571
async def get_room_events_max_id(self, room_id: Optional[str] = None) -> str:
569572
"""Returns the current token for rooms stream.

synapse/storage/databases/state/store.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717
from collections import namedtuple
1818
from typing import Dict, Iterable, List, Set, Tuple
1919

20-
from twisted.internet import defer
21-
2220
from synapse.api.constants import EventTypes
2321
from synapse.storage._base import SQLBaseStore
2422
from synapse.storage.database import DatabasePool
@@ -103,7 +101,7 @@ def get_max_state_group_txn(txn: Cursor):
103101
)
104102

105103
@cached(max_entries=10000, iterable=True)
106-
def get_state_group_delta(self, state_group):
104+
async def get_state_group_delta(self, state_group):
107105
"""Given a state group try to return a previous group and a delta between
108106
the old and the new.
109107
@@ -135,7 +133,7 @@ def _get_state_group_delta_txn(txn):
135133
{(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
136134
)
137135

138-
return self.db_pool.runInteraction(
136+
return await self.db_pool.runInteraction(
139137
"get_state_group_delta", _get_state_group_delta_txn
140138
)
141139

@@ -367,9 +365,9 @@ def _insert_into_cache(
367365
fetched_keys=non_member_types,
368366
)
369367

370-
def store_state_group(
368+
async def store_state_group(
371369
self, event_id, room_id, prev_group, delta_ids, current_state_ids
372-
):
370+
) -> int:
373371
"""Store a new set of state, returning a newly assigned state group.
374372
375373
Args:
@@ -383,7 +381,7 @@ def store_state_group(
383381
to event_id.
384382
385383
Returns:
386-
Deferred[int]: The state group ID
384+
The state group ID
387385
"""
388386

389387
def _store_state_group_txn(txn):
@@ -484,11 +482,13 @@ def _store_state_group_txn(txn):
484482

485483
return state_group
486484

487-
return self.db_pool.runInteraction("store_state_group", _store_state_group_txn)
485+
return await self.db_pool.runInteraction(
486+
"store_state_group", _store_state_group_txn
487+
)
488488

489-
def purge_unreferenced_state_groups(
489+
async def purge_unreferenced_state_groups(
490490
self, room_id: str, state_groups_to_delete
491-
) -> defer.Deferred:
491+
) -> None:
492492
"""Deletes no longer referenced state groups and de-deltas any state
493493
groups that reference them.
494494
@@ -499,7 +499,7 @@ def purge_unreferenced_state_groups(
499499
to delete.
500500
"""
501501

502-
return self.db_pool.runInteraction(
502+
await self.db_pool.runInteraction(
503503
"purge_unreferenced_state_groups",
504504
self._purge_unreferenced_state_groups,
505505
room_id,
@@ -594,15 +594,15 @@ async def get_previous_state_groups(
594594

595595
return {row["state_group"]: row["prev_state_group"] for row in rows}
596596

597-
def purge_room_state(self, room_id, state_groups_to_delete):
597+
async def purge_room_state(self, room_id, state_groups_to_delete):
598598
"""Deletes all record of a room from state tables
599599
600600
Args:
601601
room_id (str):
602602
state_groups_to_delete (list[int]): State groups to delete
603603
"""
604604

605-
return self.db_pool.runInteraction(
605+
await self.db_pool.runInteraction(
606606
"purge_room_state",
607607
self._purge_room_state_txn,
608608
room_id,

synapse/storage/state.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -333,19 +333,19 @@ class StateGroupStorage(object):
333333
def __init__(self, hs, stores):
334334
self.stores = stores
335335

336-
def get_state_group_delta(self, state_group: int):
336+
async def get_state_group_delta(self, state_group: int):
337337
"""Given a state group try to return a previous group and a delta between
338338
the old and the new.
339339
340340
Args:
341341
state_group: The state group used to retrieve state deltas.
342342
343343
Returns:
344-
Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]:
344+
Tuple[Optional[int], Optional[StateMap[str]]]:
345345
(prev_group, delta_ids)
346346
"""
347347

348-
return self.stores.state.get_state_group_delta(state_group)
348+
return await self.stores.state.get_state_group_delta(state_group)
349349

350350
async def get_state_groups_ids(
351351
self, _room_id: str, event_ids: Iterable[str]
@@ -525,7 +525,7 @@ async def get_state_ids_for_event(
525525
state_filter: The state filter used to fetch state from the database.
526526
527527
Returns:
528-
A deferred dict from (type, state_key) -> state_event
528+
A dict from (type, state_key) -> state_event
529529
"""
530530
state_map = await self.get_state_ids_for_events([event_id], state_filter)
531531
return state_map[event_id]
@@ -546,14 +546,14 @@ def _get_state_for_groups(
546546
"""
547547
return self.stores.state._get_state_for_groups(groups, state_filter)
548548

549-
def store_state_group(
549+
async def store_state_group(
550550
self,
551551
event_id: str,
552552
room_id: str,
553553
prev_group: Optional[int],
554554
delta_ids: Optional[dict],
555555
current_state_ids: dict,
556-
):
556+
) -> int:
557557
"""Store a new set of state, returning a newly assigned state group.
558558
559559
Args:
@@ -567,8 +567,8 @@ def store_state_group(
567567
to event_id.
568568
569569
Returns:
570-
Deferred[int]: The state group ID
570+
The state group ID
571571
"""
572-
return self.stores.state.store_state_group(
572+
return await self.stores.state.store_state_group(
573573
event_id, room_id, prev_group, delta_ids, current_state_ids
574574
)

0 commit comments

Comments
 (0)