From 4ebb03841ce2d69e8012acd6d67d7b0160c512a4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 22 Sep 2022 14:41:06 -0400 Subject: [PATCH] Return a dict from _get_receipts_by_room_txn. --- .../databases/main/event_push_actions.py | 35 +++++++++++-------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 6b8668d2dcfe..dc28c725754e 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -559,7 +559,18 @@ def f(txn: LoggingTransaction) -> List[str]: def _get_receipts_by_room_txn( self, txn: LoggingTransaction, user_id: str - ) -> List[Tuple[str, int]]: + ) -> Dict[str, int]: + """ + Generate a map of room ID to the latest stream ordering that has been + read by the given user. + + Args: + txn: + user_id: The user to fetch receipts for. + + Returns: + A map of room ID to stream ordering for all rooms the user has a receipt in. + """ receipt_types_clause, args = make_in_list_sql_clause( self.database_engine, "receipt_type", @@ -580,7 +591,7 @@ def _get_receipts_by_room_txn( args.extend((user_id,)) txn.execute(sql, args) - return cast(List[Tuple[str, int]], txn.fetchall()) + return dict(cast(List[Tuple[str, int]], txn.fetchall())) async def get_unread_push_actions_for_user_in_range_for_http( self, @@ -605,12 +616,10 @@ async def get_unread_push_actions_for_user_in_range_for_http( The list will have between 0~limit entries. """ - receipts_by_room = dict( - await self.db_pool.runInteraction( - "get_unread_push_actions_for_user_in_range_http_receipts", - self._get_receipts_by_room_txn, - user_id=user_id, - ), + receipts_by_room = await self.db_pool.runInteraction( + "get_unread_push_actions_for_user_in_range_http_receipts", + self._get_receipts_by_room_txn, + user_id=user_id, ) def get_push_actions_txn( @@ -679,12 +688,10 @@ async def get_unread_push_actions_for_user_in_range_for_email( The list will have between 0~limit entries. """ - receipts_by_room = dict( - await self.db_pool.runInteraction( - "get_unread_push_actions_for_user_in_range_email_receipts", - self._get_receipts_by_room_txn, - user_id=user_id, - ), + receipts_by_room = await self.db_pool.runInteraction( + "get_unread_push_actions_for_user_in_range_email_receipts", + self._get_receipts_by_room_txn, + user_id=user_id, ) def get_push_actions_txn(