Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Directly lookup local membership instead of getting all members in a room first (get_users_in_room mis-use) #13608

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/13608.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor `get_users_in_room(room_id)` mis-use to lookup single local user with dedicated `check_local_user_in_room(...)` function.
9 changes: 6 additions & 3 deletions synapse/handlers/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ async def get_event(
"""Retrieve a single specified event.

Args:
user: The user requesting the event
user: The local user requesting the event
room_id: The expected room id. We'll return None if the
event's room does not match.
event_id: The event ID to obtain.
Expand All @@ -173,8 +173,11 @@ async def get_event(
if not event:
return None

users = await self.store.get_users_in_room(event.room_id)
is_peeking = user.to_string() not in users
is_user_in_room = await self.store.check_local_user_in_room(
user_id=user.to_string(), room_id=event.room_id
)
# The user is peeking if they aren't in the room already
is_peeking = not is_user_in_room

filtered = await filter_events_for_client(
self._storage_controllers, user.to_string(), [event], is_peeking=is_peeking
Expand Down
6 changes: 4 additions & 2 deletions synapse/handlers/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,8 +761,10 @@ async def _is_exempt_from_privacy_policy(
async def _is_server_notices_room(self, room_id: str) -> bool:
if self.config.servernotices.server_notices_mxid is None:
return False
user_ids = await self.store.get_users_in_room(room_id)
return self.config.servernotices.server_notices_mxid in user_ids
is_server_notices_room = await self.store.check_local_user_in_room(
user_id=self.config.servernotices.server_notices_mxid, room_id=room_id
)
return is_server_notices_room
Comment on lines 761 to +767
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are several of these is_server_notices_room kind of functions but this PR doesn't deduplicate them.


async def assert_accepted_privacy_policy(self, requester: Requester) -> None:
"""Check if a user has accepted the privacy policy
Expand Down
7 changes: 5 additions & 2 deletions synapse/handlers/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -1284,8 +1284,11 @@ async def get_event_context(
before_limit = math.floor(limit / 2.0)
after_limit = limit - before_limit

users = await self.store.get_users_in_room(room_id)
is_peeking = user.to_string() not in users
is_user_in_room = await self.store.check_local_user_in_room(
user_id=user.to_string(), room_id=room_id
)
# The user is peeking if they aren't in the room already
is_peeking = not is_user_in_room

async def filter_evts(events: List[EventBase]) -> List[EventBase]:
if use_admin_priviledge:
Expand Down
6 changes: 4 additions & 2 deletions synapse/handlers/room_member.py
Original file line number Diff line number Diff line change
Expand Up @@ -1620,8 +1620,10 @@ async def _is_host_in_room(self, current_state_ids: StateMap[str]) -> bool:
async def _is_server_notice_room(self, room_id: str) -> bool:
if self._server_notices_mxid is None:
return False
user_ids = await self.store.get_users_in_room(room_id)
return self._server_notices_mxid in user_ids
is_server_notices_room = await self.store.check_local_user_in_room(
user_id=self._server_notices_mxid, room_id=room_id
)
return is_server_notices_room


class RoomMemberMasterHandler(RoomMemberHandler):
Expand Down
10 changes: 8 additions & 2 deletions synapse/server_notices/server_notices_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ async def maybe_get_notice_room_for_user(self, user_id: str) -> Optional[str]:
Returns:
The room's ID, or None if no room could be found.
"""
# If there is no server notices MXID, then there is no server notices room
if self.server_notices_mxid is None:
return None

rooms = await self._store.get_rooms_for_local_user_where_membership_is(
user_id, [Membership.INVITE, Membership.JOIN]
)
Expand All @@ -111,8 +115,10 @@ async def maybe_get_notice_room_for_user(self, user_id: str) -> Optional[str]:
# be joined. This is kinda deliberate, in that if somebody somehow
# manages to invite the system user to a room, that doesn't make it
# the server notices room.
user_ids = await self._store.get_users_in_room(room.room_id)
if len(user_ids) <= 2 and self.server_notices_mxid in user_ids:
is_server_notices_room = await self._store.check_local_user_in_room(
user_id=self.server_notices_mxid, room_id=room.room_id
)
if is_server_notices_room:
# we found a room which our user shares with the system notice
# user
return room.room_id
Expand Down
26 changes: 26 additions & 0 deletions synapse/storage/databases/main/roommember.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,32 @@ async def get_local_users_in_room(self, room_id: str) -> List[str]:
desc="get_local_users_in_room",
)

async def check_local_user_in_room(self, user_id: str, room_id: str) -> bool:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is similar to check_user_in_room that is in synapse/api/auth.py but that one has some specific logic and can check for leave so I decided to leave it as-is.

async def check_user_in_room(
self,
room_id: str,
requester: Requester,
allow_departed_users: bool = False,
) -> Tuple[str, Optional[str]]:
"""Check if the user is in the room, or was at some point.
Args:
room_id: The room to check.
requester: The user making the request, according to the access token.
current_state: Optional map of the current state of the room.
If provided then that map is used to check whether they are a
member of the room. Otherwise the current membership is
loaded from the database.
allow_departed_users: if True, accept users that were previously
members but have now departed.
Raises:
AuthError if the user is/was not in the room.
Returns:
The current membership of the user in the room and the
membership event ID of the user.
"""
user_id = requester.user.to_string()
(
membership,
member_event_id,
) = await self.store.get_local_current_membership_for_user_in_room(
user_id=user_id,
room_id=room_id,
)
if membership:
if membership == Membership.JOIN:
return membership, member_event_id
# XXX this looks totally bogus. Why do we not allow users who have been banned,
# or those who were members previously and have been re-invited?
if allow_departed_users and membership == Membership.LEAVE:
forgot = await self.store.did_forget(user_id, room_id)
if not forgot:
return membership, member_event_id
raise UnstableSpecAuthError(
403,
"User %s not in room %s" % (user_id, room_id),
errcode=Codes.NOT_JOINED,
)

"""
Check whether a given local user is currently joined to the given room.

Returns:
A boolean indicating whether the user is currently joined to the room

Raises:
Exeption when called with a non-local user to this homeserver
"""
if not self.hs.is_mine_id(user_id):
raise Exception(
"Cannot call 'check_local_user_in_room' on "
"non-local user %s" % (user_id,),
)

(
membership,
member_event_id,
) = await self.get_local_current_membership_for_user_in_room(
user_id=user_id,
room_id=room_id,
)

return membership == Membership.JOIN

async def get_local_current_membership_for_user_in_room(
self, user_id: str, room_id: str
) -> Tuple[Optional[str], Optional[str]]:
Expand Down
12 changes: 6 additions & 6 deletions tests/rest/client/test_relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -999,7 +999,7 @@ def assert_annotations(bundled_aggregations: JsonDict) -> None:
bundled_aggregations,
)

self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 6)
self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My guess is the difference between cached calls of get_users_in_room


def test_annotation_to_annotation(self) -> None:
"""Any relation to an annotation should be ignored."""
Expand Down Expand Up @@ -1035,7 +1035,7 @@ def assert_annotations(bundled_aggregations: JsonDict) -> None:
bundled_aggregations,
)

self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 6)
self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7)

def test_thread(self) -> None:
"""
Expand Down Expand Up @@ -1080,21 +1080,21 @@ def assert_thread(bundled_aggregations: JsonDict) -> None:

# The "user" sent the root event and is making queries for the bundled
# aggregations: they have participated.
self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8)
self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 9)
# The "user2" sent replies in the thread and is making queries for the
# bundled aggregations: they have participated.
#
# Note that this re-uses some cached values, so the total number of
# queries is much smaller.
self._test_bundled_aggregations(
RelationTypes.THREAD, _gen_assert(True), 2, access_token=self.user2_token
RelationTypes.THREAD, _gen_assert(True), 3, access_token=self.user2_token
)

# A user with no interactions with the thread: they have not participated.
user3_id, user3_token = self._create_user("charlie")
self.helper.join(self.room, user=user3_id, tok=user3_token)
self._test_bundled_aggregations(
RelationTypes.THREAD, _gen_assert(False), 2, access_token=user3_token
RelationTypes.THREAD, _gen_assert(False), 3, access_token=user3_token
)

def test_thread_with_bundled_aggregations_for_latest(self) -> None:
Expand Down Expand Up @@ -1142,7 +1142,7 @@ def assert_thread(bundled_aggregations: JsonDict) -> None:
bundled_aggregations["latest_event"].get("unsigned"),
)

self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8)
self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9)

def test_nested_thread(self) -> None:
"""
Expand Down