|
13 | 13 | # See the License for the specific language governing permissions and
|
14 | 14 | # limitations under the License.
|
15 | 15 | import logging
|
| 16 | +from collections import defaultdict |
16 | 17 | from typing import (
|
17 | 18 | TYPE_CHECKING,
|
18 | 19 | Collection,
|
@@ -681,38 +682,27 @@ async def get_rooms_for_users(
|
681 | 682 | Returns:
|
682 | 683 | Map from user_id to set of rooms that is currently in.
|
683 | 684 | """
|
684 |
| - return await self.db_pool.runInteraction( |
685 |
| - "get_rooms_for_users", |
686 |
| - self._get_rooms_for_users_txn, |
687 |
| - user_ids, |
688 |
| - ) |
689 |
| - |
690 |
| - def _get_rooms_for_users_txn( |
691 |
| - self, txn: LoggingTransaction, user_ids: Collection[str] |
692 |
| - ) -> Dict[str, FrozenSet[str]]: |
693 | 685 |
|
694 |
| - clause, args = make_in_list_sql_clause( |
695 |
| - self.database_engine, |
696 |
| - "c.state_key", |
697 |
| - user_ids, |
| 686 | + rows = await self.db_pool.simple_select_many_batch( |
| 687 | + table="current_state_events", |
| 688 | + column="state_key", |
| 689 | + iterable=user_ids, |
| 690 | + retcols=( |
| 691 | + "user_id", |
| 692 | + "room_id", |
| 693 | + ), |
| 694 | + keyvalues={ |
| 695 | + "type": EventTypes.Member, |
| 696 | + "membership": Membership.JOIN, |
| 697 | + }, |
| 698 | + desc="get_rooms_for_users", |
698 | 699 | )
|
699 | 700 |
|
700 |
| - sql = f""" |
701 |
| - SELECT c.state_key, room_id |
702 |
| - FROM current_state_events AS c |
703 |
| - WHERE |
704 |
| - c.type = 'm.room.member' |
705 |
| - AND c.membership = ? |
706 |
| - AND {clause} |
707 |
| - """ |
708 |
| - |
709 |
| - txn.execute(sql, [Membership.JOIN] + args) |
710 |
| - |
711 |
| - result: Dict[str, Set[str]] = {user_id: set() for user_id in user_ids} |
712 |
| - for user_id, room_id in txn: |
713 |
| - result[user_id].add(room_id) |
| 701 | + user_rooms: Dict[str, Set[str]] = defaultdict(set) |
| 702 | + for row in rows: |
| 703 | + user_rooms[row["user_id"]].add(row["room_id"]) |
714 | 704 |
|
715 |
| - return {user_id: frozenset(v) for user_id, v in result.items()} |
| 705 | + return {key: frozenset(rooms) for key, rooms in user_rooms.items()} |
716 | 706 |
|
717 | 707 | @cached(max_entries=10000)
|
718 | 708 | async def does_pair_of_users_share_a_room(
|
|
0 commit comments