1414
1515import itertools
1616import logging
17- from typing import TYPE_CHECKING , AbstractSet , Dict , Mapping , Optional , Sequence , Set
17+ from typing import (
18+ TYPE_CHECKING ,
19+ AbstractSet ,
20+ ChainMap ,
21+ Dict ,
22+ List ,
23+ Mapping ,
24+ MutableMapping ,
25+ Optional ,
26+ Sequence ,
27+ Set ,
28+ cast ,
29+ )
1830
1931from typing_extensions import assert_never
2032
@@ -381,29 +393,47 @@ async def get_account_data_extension_response(
381393 )
382394 )
383395
396+ # TODO: This should take into account the `from_token` and `to_token`
384397 have_push_rules_changed = await self .store .have_push_rules_changed_for_user (
385398 user_id , from_token .stream_token .push_rules_key
386399 )
387400 if have_push_rules_changed :
388- global_account_data_map = dict (global_account_data_map )
389401 # TODO: This should take into account the `from_token` and `to_token`
390402 global_account_data_map [
391403 AccountDataTypes .PUSH_RULES
392404 ] = await self .push_rules_handler .push_rules_for_user (sync_config .user )
393405 else :
394406 # TODO: This should take into account the `to_token`
395- all_global_account_data = await self . store . get_global_account_data_for_user (
396- user_id
407+ immutable_global_account_data_map = (
408+ await self . store . get_global_account_data_for_user ( user_id )
397409 )
398410
399- global_account_data_map = dict (all_global_account_data )
400- # TODO: This should take into account the `to_token`
401- global_account_data_map [
402- AccountDataTypes .PUSH_RULES
403- ] = await self .push_rules_handler .push_rules_for_user (sync_config .user )
411+ # Use a `ChainMap` to avoid copying the immutable data from the cache
412+ global_account_data_map = ChainMap (
413+ {
414+ # TODO: This should take into account the `to_token`
415+ AccountDataTypes .PUSH_RULES : await self .push_rules_handler .push_rules_for_user (
416+ sync_config .user
417+ )
418+ },
419+ # Cast is safe because `ChainMap` only mutates the top-most map,
420+ # see https://github.com/python/typeshed/issues/8430
421+ cast (
422+ MutableMapping [str , JsonMapping ], immutable_global_account_data_map
423+ ),
424+ )
404425
405426 # Fetch room account data
406- account_data_by_room_map : Mapping [str , Mapping [str , JsonMapping ]] = {}
427+ #
428+ # List of -> Mapping from room_id to mapping of `type` to `content` of room
429+ # account data events.
430+ #
431+ # This is is a list so we can avoid making copies of immutable data and instead
432+ # just provide multiple maps that need to be combined. Normally, we could
433+ # reach for `ChainMap` in this scenario, but this is a nested map and accessing
434+ # the ChainMap by room_id won't combine the two maps for that room (we would
435+ # need a new `NestedChainMap` type class).
436+ account_data_by_room_maps : List [Mapping [str , Mapping [str , JsonMapping ]]] = []
407437 relevant_room_ids = self .find_relevant_room_ids_for_extension (
408438 requested_lists = account_data_request .lists ,
409439 requested_room_ids = account_data_request .rooms ,
@@ -418,22 +448,66 @@ async def get_account_data_extension_response(
418448 user_id , from_token .stream_token .account_data_key
419449 )
420450 )
451+
452+ # Add room tags
453+ #
454+ # TODO: This should take into account the `from_token` and `to_token`
455+ tags_by_room = await self .store .get_updated_tags (
456+ user_id , from_token .stream_token .account_data_key
457+ )
458+ for room_id , tags in tags_by_room .items ():
459+ account_data_by_room_map .setdefault (room_id , {})[
460+ AccountDataTypes .TAG
461+ ] = {"tags" : tags }
462+
463+ account_data_by_room_maps .append (account_data_by_room_map )
421464 else :
422465 # TODO: This should take into account the `to_token`
423- account_data_by_room_map = (
466+ immutable_account_data_by_room_map = (
424467 await self .store .get_room_account_data_for_user (user_id )
425468 )
469+ account_data_by_room_maps .append (immutable_account_data_by_room_map )
426470
427- # Filter down to the relevant rooms
428- account_data_by_room_map = {
429- room_id : account_data_map
430- for room_id , account_data_map in account_data_by_room_map .items ()
431- if room_id in relevant_room_ids
432- }
471+ # Add room tags
472+ #
473+ # TODO: This should take into account the `to_token`
474+ tags_by_room = await self .store .get_tags_for_user (user_id )
475+ account_data_by_room_maps .append (
476+ {
477+ room_id : {AccountDataTypes .TAG : {"tags" : tags }}
478+ for room_id , tags in tags_by_room .items ()
479+ }
480+ )
481+
482+ # Filter down to the relevant rooms ... and combine the maps
483+ relevant_account_data_by_room_map : MutableMapping [
484+ str , Mapping [str , JsonMapping ]
485+ ] = {}
486+ for room_id in relevant_room_ids :
487+ # We want to avoid adding empty maps for relevant rooms that have no room
488+ # account data so do a quick check to see if it's in any of the maps.
489+ is_room_in_maps = False
490+ for room_map in account_data_by_room_maps :
491+ if room_id in room_map :
492+ is_room_in_maps = True
493+ break
494+
495+ # If we found the room in any of the maps, combine the maps for that room
496+ if is_room_in_maps :
497+ relevant_account_data_by_room_map [room_id ] = ChainMap (
498+ {},
499+ * (
500+ # Cast is safe because `ChainMap` only mutates the top-most map,
501+ # see https://github.com/python/typeshed/issues/8430
502+ cast (MutableMapping [str , JsonMapping ], room_map [room_id ])
503+ for room_map in account_data_by_room_maps
504+ if room_map .get (room_id )
505+ ),
506+ )
433507
434508 return SlidingSyncResult .Extensions .AccountDataExtension (
435509 global_account_data_map = global_account_data_map ,
436- account_data_by_room_map = account_data_by_room_map ,
510+ account_data_by_room_map = relevant_account_data_by_room_map ,
437511 )
438512
439513 @trace
0 commit comments