Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Implement MSC3816, consider the root event for thread participation. (#…
Browse files Browse the repository at this point in the history
…12766)

As opposed to only considering a user to have "participated" if they
replied to the thread.
  • Loading branch information
clokep authored Jun 6, 2022
1 parent fcd8703 commit 1acc897
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 47 deletions.
1 change: 1 addition & 0 deletions changelog.d/12766.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Implement [MSC3816](https://github.com/matrix-org/matrix-spec-proposals/pull/3816): sending the root event in a thread should count as "participated" in it.
58 changes: 37 additions & 21 deletions synapse/handlers/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import (
TYPE_CHECKING,
Collection,
Dict,
FrozenSet,
Iterable,
List,
Optional,
Tuple,
)
from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple

import attr

Expand Down Expand Up @@ -256,13 +247,19 @@ async def get_annotations_for_event(

return filtered_results

async def get_threads_for_events(
self, event_ids: Collection[str], user_id: str, ignored_users: FrozenSet[str]
async def _get_threads_for_events(
self,
events_by_id: Dict[str, EventBase],
relations_by_id: Dict[str, str],
user_id: str,
ignored_users: FrozenSet[str],
) -> Dict[str, _ThreadAggregation]:
"""Get the bundled aggregations for threads for the requested events.
Args:
event_ids: Events to get aggregations for threads.
events_by_id: A map of event_id to events to get aggregations for threads.
relations_by_id: A map of event_id to the relation type, if one exists
for that event.
user_id: The user requesting the bundled aggregations.
ignored_users: The users ignored by the requesting user.
Expand All @@ -273,16 +270,34 @@ async def get_threads_for_events(
"""
user = UserID.from_string(user_id)

# It is not valid to start a thread on an event which itself relates to another event.
event_ids = [eid for eid in events_by_id.keys() if eid not in relations_by_id]

# Fetch thread summaries.
summaries = await self._main_store.get_thread_summaries(event_ids)

# Only fetch participated for a limited selection based on what had
# summaries.
# Limit fetching whether the requester has participated in a thread to
# events which are thread roots.
thread_event_ids = [
event_id for event_id, summary in summaries.items() if summary
]
participated = await self._main_store.get_threads_participated(
thread_event_ids, user_id

# Pre-seed thread participation with whether the requester sent the event.
participated = {
event_id: events_by_id[event_id].sender == user_id
for event_id in thread_event_ids
}
# For events the requester did not send, check the database for whether
# the requester sent a threaded reply.
participated.update(
await self._main_store.get_threads_participated(
[
event_id
for event_id in thread_event_ids
if not participated[event_id]
],
user_id,
)
)

# Then subtract off the results for any ignored users.
Expand Down Expand Up @@ -343,7 +358,8 @@ async def get_threads_for_events(
count=thread_count,
# If there's a thread summary it must also exist in the
# participated dictionary.
current_user_participated=participated[event_id],
current_user_participated=events_by_id[event_id].sender == user_id
or participated[event_id],
)

return results
Expand Down Expand Up @@ -401,9 +417,9 @@ async def get_bundled_aggregations(
# events to be fetched. Thus, we check those first!

# Fetch thread summaries (but only for the directly requested events).
threads = await self.get_threads_for_events(
# It is not valid to start a thread on an event which itself relates to another event.
[eid for eid in events_by_id.keys() if eid not in relations_by_id],
threads = await self._get_threads_for_events(
events_by_id,
relations_by_id,
user_id,
ignored_users,
)
Expand Down
85 changes: 59 additions & 26 deletions tests/rest/client/test_relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,7 @@ def _test_bundled_aggregations(
relation_type: str,
assertion_callable: Callable[[JsonDict], None],
expected_db_txn_for_event: int,
access_token: Optional[str] = None,
) -> None:
"""
Makes requests to various endpoints which should include bundled aggregations
Expand All @@ -907,7 +908,9 @@ def _test_bundled_aggregations(
for relation-specific assertions.
expected_db_txn_for_event: The number of database transactions which
are expected for a call to /event/.
access_token: The access token to user, defaults to self.user_token.
"""
access_token = access_token or self.user_token

def assert_bundle(event_json: JsonDict) -> None:
"""Assert the expected values of the bundled aggregations."""
Expand All @@ -921,7 +924,7 @@ def assert_bundle(event_json: JsonDict) -> None:
channel = self.make_request(
"GET",
f"/rooms/{self.room}/event/{self.parent_id}",
access_token=self.user_token,
access_token=access_token,
)
self.assertEqual(200, channel.code, channel.json_body)
assert_bundle(channel.json_body)
Expand All @@ -932,7 +935,7 @@ def assert_bundle(event_json: JsonDict) -> None:
channel = self.make_request(
"GET",
f"/rooms/{self.room}/messages?dir=b",
access_token=self.user_token,
access_token=access_token,
)
self.assertEqual(200, channel.code, channel.json_body)
assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
Expand All @@ -941,15 +944,15 @@ def assert_bundle(event_json: JsonDict) -> None:
channel = self.make_request(
"GET",
f"/rooms/{self.room}/context/{self.parent_id}",
access_token=self.user_token,
access_token=access_token,
)
self.assertEqual(200, channel.code, channel.json_body)
assert_bundle(channel.json_body["event"])

# Request sync.
filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 4}}}')
channel = self.make_request(
"GET", f"/sync?filter={filter}", access_token=self.user_token
"GET", f"/sync?filter={filter}", access_token=access_token
)
self.assertEqual(200, channel.code, channel.json_body)
room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
Expand All @@ -962,7 +965,7 @@ def assert_bundle(event_json: JsonDict) -> None:
"/search",
# Search term matches the parent message.
content={"search_categories": {"room_events": {"search_term": "Hi"}}},
access_token=self.user_token,
access_token=access_token,
)
self.assertEqual(200, channel.code, channel.json_body)
chunk = [
Expand Down Expand Up @@ -1037,30 +1040,60 @@ def test_thread(self) -> None:
"""
Test that threads get correctly bundled.
"""
self._send_relation(RelationTypes.THREAD, "m.room.test")
channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
# The root message is from "user", send replies as "user2".
self._send_relation(
RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
)
channel = self._send_relation(
RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
)
thread_2 = channel.json_body["event_id"]

def assert_thread(bundled_aggregations: JsonDict) -> None:
self.assertEqual(2, bundled_aggregations.get("count"))
self.assertTrue(bundled_aggregations.get("current_user_participated"))
# The latest thread event has some fields that don't matter.
self.assert_dict(
{
"content": {
"m.relates_to": {
"event_id": self.parent_id,
"rel_type": RelationTypes.THREAD,
}
# This needs two assertion functions which are identical except for whether
# the current_user_participated flag is True, create a factory for the
# two versions.
def _gen_assert(participated: bool) -> Callable[[JsonDict], None]:
def assert_thread(bundled_aggregations: JsonDict) -> None:
self.assertEqual(2, bundled_aggregations.get("count"))
self.assertEqual(
participated, bundled_aggregations.get("current_user_participated")
)
# The latest thread event has some fields that don't matter.
self.assert_dict(
{
"content": {
"m.relates_to": {
"event_id": self.parent_id,
"rel_type": RelationTypes.THREAD,
}
},
"event_id": thread_2,
"sender": self.user2_id,
"type": "m.room.test",
},
"event_id": thread_2,
"sender": self.user_id,
"type": "m.room.test",
},
bundled_aggregations.get("latest_event"),
)
bundled_aggregations.get("latest_event"),
)

self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9)
return assert_thread

# The "user" sent the root event and is making queries for the bundled
# aggregations: they have participated.
self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8)
# The "user2" sent replies in the thread and is making queries for the
# bundled aggregations: they have participated.
#
# Note that this re-uses some cached values, so the total number of
# queries is much smaller.
self._test_bundled_aggregations(
RelationTypes.THREAD, _gen_assert(True), 2, access_token=self.user2_token
)

# A user with no interactions with the thread: they have not participated.
user3_id, user3_token = self._create_user("charlie")
self.helper.join(self.room, user=user3_id, tok=user3_token)
self._test_bundled_aggregations(
RelationTypes.THREAD, _gen_assert(False), 2, access_token=user3_token
)

def test_thread_with_bundled_aggregations_for_latest(self) -> None:
"""
Expand Down Expand Up @@ -1106,7 +1139,7 @@ def assert_thread(bundled_aggregations: JsonDict) -> None:
bundled_aggregations["latest_event"].get("unsigned"),
)

self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9)
self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8)

def test_nested_thread(self) -> None:
"""
Expand Down

0 comments on commit 1acc897

Please sign in to comment.