Skip to content

Commit 03937a1

Browse files
Sliding Sync: Return room tags in account data extension (#17707)
The account data extension was also updated to avoid copies when we pull the data out of the cache. Fix #17694
1 parent 285de43 commit 03937a1

File tree

5 files changed

+226
-65
lines changed

5 files changed

+226
-65
lines changed

changelog.d/17707.feature

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Return room tags in Sliding Sync account data extension.

synapse/handlers/sliding_sync/extensions.py

Lines changed: 92 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,19 @@
1414

1515
import itertools
1616
import logging
17-
from typing import TYPE_CHECKING, AbstractSet, Dict, Mapping, Optional, Sequence, Set
17+
from typing import (
18+
TYPE_CHECKING,
19+
AbstractSet,
20+
ChainMap,
21+
Dict,
22+
List,
23+
Mapping,
24+
MutableMapping,
25+
Optional,
26+
Sequence,
27+
Set,
28+
cast,
29+
)
1830

1931
from typing_extensions import assert_never
2032

@@ -381,29 +393,47 @@ async def get_account_data_extension_response(
381393
)
382394
)
383395

396+
# TODO: This should take into account the `from_token` and `to_token`
384397
have_push_rules_changed = await self.store.have_push_rules_changed_for_user(
385398
user_id, from_token.stream_token.push_rules_key
386399
)
387400
if have_push_rules_changed:
388-
global_account_data_map = dict(global_account_data_map)
389401
# TODO: This should take into account the `from_token` and `to_token`
390402
global_account_data_map[
391403
AccountDataTypes.PUSH_RULES
392404
] = await self.push_rules_handler.push_rules_for_user(sync_config.user)
393405
else:
394406
# TODO: This should take into account the `to_token`
395-
all_global_account_data = await self.store.get_global_account_data_for_user(
396-
user_id
407+
immutable_global_account_data_map = (
408+
await self.store.get_global_account_data_for_user(user_id)
397409
)
398410

399-
global_account_data_map = dict(all_global_account_data)
400-
# TODO: This should take into account the `to_token`
401-
global_account_data_map[
402-
AccountDataTypes.PUSH_RULES
403-
] = await self.push_rules_handler.push_rules_for_user(sync_config.user)
411+
# Use a `ChainMap` to avoid copying the immutable data from the cache
412+
global_account_data_map = ChainMap(
413+
{
414+
# TODO: This should take into account the `to_token`
415+
AccountDataTypes.PUSH_RULES: await self.push_rules_handler.push_rules_for_user(
416+
sync_config.user
417+
)
418+
},
419+
# Cast is safe because `ChainMap` only mutates the top-most map,
420+
# see https://github.com/python/typeshed/issues/8430
421+
cast(
422+
MutableMapping[str, JsonMapping], immutable_global_account_data_map
423+
),
424+
)
404425

405426
# Fetch room account data
406-
account_data_by_room_map: Mapping[str, Mapping[str, JsonMapping]] = {}
427+
#
428+
# List of -> Mapping from room_id to mapping of `type` to `content` of room
429+
# account data events.
430+
#
431+
# This is is a list so we can avoid making copies of immutable data and instead
432+
# just provide multiple maps that need to be combined. Normally, we could
433+
# reach for `ChainMap` in this scenario, but this is a nested map and accessing
434+
# the ChainMap by room_id won't combine the two maps for that room (we would
435+
# need a new `NestedChainMap` type class).
436+
account_data_by_room_maps: List[Mapping[str, Mapping[str, JsonMapping]]] = []
407437
relevant_room_ids = self.find_relevant_room_ids_for_extension(
408438
requested_lists=account_data_request.lists,
409439
requested_room_ids=account_data_request.rooms,
@@ -418,22 +448,66 @@ async def get_account_data_extension_response(
418448
user_id, from_token.stream_token.account_data_key
419449
)
420450
)
451+
452+
# Add room tags
453+
#
454+
# TODO: This should take into account the `from_token` and `to_token`
455+
tags_by_room = await self.store.get_updated_tags(
456+
user_id, from_token.stream_token.account_data_key
457+
)
458+
for room_id, tags in tags_by_room.items():
459+
account_data_by_room_map.setdefault(room_id, {})[
460+
AccountDataTypes.TAG
461+
] = {"tags": tags}
462+
463+
account_data_by_room_maps.append(account_data_by_room_map)
421464
else:
422465
# TODO: This should take into account the `to_token`
423-
account_data_by_room_map = (
466+
immutable_account_data_by_room_map = (
424467
await self.store.get_room_account_data_for_user(user_id)
425468
)
469+
account_data_by_room_maps.append(immutable_account_data_by_room_map)
426470

427-
# Filter down to the relevant rooms
428-
account_data_by_room_map = {
429-
room_id: account_data_map
430-
for room_id, account_data_map in account_data_by_room_map.items()
431-
if room_id in relevant_room_ids
432-
}
471+
# Add room tags
472+
#
473+
# TODO: This should take into account the `to_token`
474+
tags_by_room = await self.store.get_tags_for_user(user_id)
475+
account_data_by_room_maps.append(
476+
{
477+
room_id: {AccountDataTypes.TAG: {"tags": tags}}
478+
for room_id, tags in tags_by_room.items()
479+
}
480+
)
481+
482+
# Filter down to the relevant rooms ... and combine the maps
483+
relevant_account_data_by_room_map: MutableMapping[
484+
str, Mapping[str, JsonMapping]
485+
] = {}
486+
for room_id in relevant_room_ids:
487+
# We want to avoid adding empty maps for relevant rooms that have no room
488+
# account data so do a quick check to see if it's in any of the maps.
489+
is_room_in_maps = False
490+
for room_map in account_data_by_room_maps:
491+
if room_id in room_map:
492+
is_room_in_maps = True
493+
break
494+
495+
# If we found the room in any of the maps, combine the maps for that room
496+
if is_room_in_maps:
497+
relevant_account_data_by_room_map[room_id] = ChainMap(
498+
{},
499+
*(
500+
# Cast is safe because `ChainMap` only mutates the top-most map,
501+
# see https://github.com/python/typeshed/issues/8430
502+
cast(MutableMapping[str, JsonMapping], room_map[room_id])
503+
for room_map in account_data_by_room_maps
504+
if room_map.get(room_id)
505+
),
506+
)
433507

434508
return SlidingSyncResult.Extensions.AccountDataExtension(
435509
global_account_data_map=global_account_data_map,
436-
account_data_by_room_map=account_data_by_room_map,
510+
account_data_by_room_map=relevant_account_data_by_room_map,
437511
)
438512

439513
@trace

synapse/storage/databases/main/account_data.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ async def get_room_account_data_for_user(
177177

178178
def get_room_account_data_for_user_txn(
179179
txn: LoggingTransaction,
180-
) -> Dict[str, Dict[str, JsonDict]]:
180+
) -> Dict[str, Dict[str, JsonMapping]]:
181181
# The 'content != '{}' condition below prevents us from using
182182
# `simple_select_list_txn` here, as it doesn't support conditions
183183
# other than 'equals'.
@@ -194,7 +194,7 @@ def get_room_account_data_for_user_txn(
194194

195195
txn.execute(sql, (user_id,))
196196

197-
by_room: Dict[str, Dict[str, JsonDict]] = {}
197+
by_room: Dict[str, Dict[str, JsonMapping]] = {}
198198
for room_id, account_data_type, content in txn:
199199
room_data = by_room.setdefault(room_id, {})
200200

@@ -394,7 +394,7 @@ def get_updated_room_account_data_txn(
394394

395395
async def get_updated_global_account_data_for_user(
396396
self, user_id: str, stream_id: int
397-
) -> Mapping[str, JsonMapping]:
397+
) -> Dict[str, JsonMapping]:
398398
"""Get all the global account_data that's changed for a user.
399399
400400
Args:
@@ -407,7 +407,7 @@ async def get_updated_global_account_data_for_user(
407407

408408
def get_updated_global_account_data_for_user(
409409
txn: LoggingTransaction,
410-
) -> Dict[str, JsonDict]:
410+
) -> Dict[str, JsonMapping]:
411411
sql = """
412412
SELECT account_data_type, content FROM account_data
413413
WHERE user_id = ? AND stream_id > ?
@@ -429,7 +429,7 @@ def get_updated_global_account_data_for_user(
429429

430430
async def get_updated_room_account_data_for_user(
431431
self, user_id: str, stream_id: int
432-
) -> Dict[str, Dict[str, JsonDict]]:
432+
) -> Dict[str, Dict[str, JsonMapping]]:
433433
"""Get all the room account_data that's changed for a user.
434434
435435
Args:
@@ -442,14 +442,14 @@ async def get_updated_room_account_data_for_user(
442442

443443
def get_updated_room_account_data_for_user_txn(
444444
txn: LoggingTransaction,
445-
) -> Dict[str, Dict[str, JsonDict]]:
445+
) -> Dict[str, Dict[str, JsonMapping]]:
446446
sql = """
447447
SELECT room_id, account_data_type, content FROM room_account_data
448448
WHERE user_id = ? AND stream_id > ?
449449
"""
450450
txn.execute(sql, (user_id, stream_id))
451451

452-
account_data_by_room: Dict[str, Dict[str, JsonDict]] = {}
452+
account_data_by_room: Dict[str, Dict[str, JsonMapping]] = {}
453453
for row in txn:
454454
room_account_data = account_data_by_room.setdefault(row[0], {})
455455
room_account_data[row[1]] = db_to_json(row[2])

synapse/types/handlers/sliding_sync.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -314,8 +314,8 @@ class AccountDataExtension:
314314
"""The Account Data extension (MSC3959)
315315
316316
Attributes:
317-
global_account_data_map: Mapping from `type` to `content` of global account
318-
data events.
317+
global_account_data_map: Mapping from `type` to `content` of global
318+
account data events.
319319
account_data_by_room_map: Mapping from room_id to mapping of `type` to
320320
`content` of room account data events.
321321
"""

0 commit comments

Comments
 (0)