Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 68acb0a

Browse files
authored
Include whether the requesting user has participated in a thread. (#11577)
Per updates to MSC3440. This is implement as a separate method since it needs to be cached on a per-user basis, instead of a per-thread basis.
1 parent 251b556 commit 68acb0a

File tree

9 files changed

+85
-18
lines changed

9 files changed

+85
-18
lines changed

changelog.d/11577.feature

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Include whether the requesting user has participated in a thread when generating a summary for [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440).

synapse/handlers/pagination.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,7 @@ async def get_messages(
537537
state_dict = await self.store.get_events(list(state_ids.values()))
538538
state = state_dict.values()
539539

540-
aggregations = await self.store.get_bundled_aggregations(events)
540+
aggregations = await self.store.get_bundled_aggregations(events, user_id)
541541

542542
time_now = self.clock.time_msec()
543543

synapse/handlers/room.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1182,12 +1182,18 @@ async def filter_evts(events: List[EventBase]) -> List[EventBase]:
11821182
results["event"] = filtered[0]
11831183

11841184
# Fetch the aggregations.
1185-
aggregations = await self.store.get_bundled_aggregations([results["event"]])
1185+
aggregations = await self.store.get_bundled_aggregations(
1186+
[results["event"]], user.to_string()
1187+
)
11861188
aggregations.update(
1187-
await self.store.get_bundled_aggregations(results["events_before"])
1189+
await self.store.get_bundled_aggregations(
1190+
results["events_before"], user.to_string()
1191+
)
11881192
)
11891193
aggregations.update(
1190-
await self.store.get_bundled_aggregations(results["events_after"])
1194+
await self.store.get_bundled_aggregations(
1195+
results["events_after"], user.to_string()
1196+
)
11911197
)
11921198
results["aggregations"] = aggregations
11931199

synapse/handlers/sync.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -637,7 +637,9 @@ async def _load_filtered_recents(
637637
# as clients will have all the necessary information.
638638
bundled_aggregations = None
639639
if limited or newly_joined_room:
640-
bundled_aggregations = await self.store.get_bundled_aggregations(recents)
640+
bundled_aggregations = await self.store.get_bundled_aggregations(
641+
recents, sync_config.user.to_string()
642+
)
641643

642644
return TimelineBatch(
643645
events=recents,

synapse/rest/client/relations.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,9 @@ async def on_GET(
118118
)
119119
# The relations returned for the requested event do include their
120120
# bundled aggregations.
121-
aggregations = await self.store.get_bundled_aggregations(events)
121+
aggregations = await self.store.get_bundled_aggregations(
122+
events, requester.user.to_string()
123+
)
122124
serialized_events = self._event_serializer.serialize_events(
123125
events, now, bundle_aggregations=aggregations
124126
)

synapse/rest/client/room.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -663,7 +663,9 @@ async def on_GET(
663663

664664
if event:
665665
# Ensure there are bundled aggregations available.
666-
aggregations = await self._store.get_bundled_aggregations([event])
666+
aggregations = await self._store.get_bundled_aggregations(
667+
[event], requester.user.to_string()
668+
)
667669

668670
time_now = self.clock.time_msec()
669671
event_dict = self._event_serializer.serialize_event(

synapse/storage/databases/main/events.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1793,6 +1793,13 @@ def _handle_event_relations(
17931793
txn.call_after(
17941794
self.store.get_thread_summary.invalidate, (parent_id, event.room_id)
17951795
)
1796+
# It should be safe to only invalidate the cache if the user has not
1797+
# previously participated in the thread, but that's difficult (and
1798+
# potentially error-prone) so it is always invalidated.
1799+
txn.call_after(
1800+
self.store.get_thread_participated.invalidate,
1801+
(parent_id, event.room_id, event.sender),
1802+
)
17961803

17971804
def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
17981805
"""Handles keeping track of insertion events and edges/connections.

synapse/storage/databases/main/relations.py

Lines changed: 55 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -384,8 +384,7 @@ def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:
384384
async def get_thread_summary(
385385
self, event_id: str, room_id: str
386386
) -> Tuple[int, Optional[EventBase]]:
387-
"""Get the number of threaded replies, the senders of those replies, and
388-
the latest reply (if any) for the given event.
387+
"""Get the number of threaded replies and the latest reply (if any) for the given event.
389388
390389
Args:
391390
event_id: Summarize the thread related to this event ID.
@@ -398,7 +397,7 @@ async def get_thread_summary(
398397
def _get_thread_summary_txn(
399398
txn: LoggingTransaction,
400399
) -> Tuple[int, Optional[str]]:
401-
# Fetch the count of threaded events and the latest event ID.
400+
# Fetch the latest event ID in the thread.
402401
# TODO Should this only allow m.room.message events.
403402
sql = """
404403
SELECT event_id
@@ -419,6 +418,7 @@ def _get_thread_summary_txn(
419418

420419
latest_event_id = row[0]
421420

421+
# Fetch the number of threaded replies.
422422
sql = """
423423
SELECT COUNT(event_id)
424424
FROM event_relations
@@ -443,6 +443,44 @@ def _get_thread_summary_txn(
443443

444444
return count, latest_event
445445

446+
@cached()
447+
async def get_thread_participated(
448+
self, event_id: str, room_id: str, user_id: str
449+
) -> bool:
450+
"""Get whether the requesting user participated in a thread.
451+
452+
This is separate from get_thread_summary since that can be cached across
453+
all users while this value is specific to the requeser.
454+
455+
Args:
456+
event_id: The thread related to this event ID.
457+
room_id: The room the event belongs to.
458+
user_id: The user requesting the summary.
459+
460+
Returns:
461+
True if the requesting user participated in the thread, otherwise false.
462+
"""
463+
464+
def _get_thread_summary_txn(txn: LoggingTransaction) -> bool:
465+
# Fetch whether the requester has participated or not.
466+
sql = """
467+
SELECT 1
468+
FROM event_relations
469+
INNER JOIN events USING (event_id)
470+
WHERE
471+
relates_to_id = ?
472+
AND room_id = ?
473+
AND relation_type = ?
474+
AND sender = ?
475+
"""
476+
477+
txn.execute(sql, (event_id, room_id, RelationTypes.THREAD, user_id))
478+
return bool(txn.fetchone())
479+
480+
return await self.db_pool.runInteraction(
481+
"get_thread_summary", _get_thread_summary_txn
482+
)
483+
446484
async def events_have_relations(
447485
self,
448486
parent_ids: List[str],
@@ -546,14 +584,15 @@ def _get_if_user_has_annotated_event(txn: LoggingTransaction) -> bool:
546584
)
547585

548586
async def _get_bundled_aggregation_for_event(
549-
self, event: EventBase
587+
self, event: EventBase, user_id: str
550588
) -> Optional[Dict[str, Any]]:
551589
"""Generate bundled aggregations for an event.
552590
553591
Note that this does not use a cache, but depends on cached methods.
554592
555593
Args:
556594
event: The event to calculate bundled aggregations for.
595+
user_id: The user requesting the bundled aggregations.
557596
558597
Returns:
559598
The bundled aggregations for an event, if bundled aggregations are
@@ -598,27 +637,32 @@ async def _get_bundled_aggregation_for_event(
598637

599638
# If this event is the start of a thread, include a summary of the replies.
600639
if self._msc3440_enabled:
601-
(
602-
thread_count,
603-
latest_thread_event,
604-
) = await self.get_thread_summary(event_id, room_id)
640+
thread_count, latest_thread_event = await self.get_thread_summary(
641+
event_id, room_id
642+
)
643+
participated = await self.get_thread_participated(
644+
event_id, room_id, user_id
645+
)
605646
if latest_thread_event:
606647
aggregations[RelationTypes.THREAD] = {
607-
# Don't bundle aggregations as this could recurse forever.
608648
"latest_event": latest_thread_event,
609649
"count": thread_count,
650+
"current_user_participated": participated,
610651
}
611652

612653
# Store the bundled aggregations in the event metadata for later use.
613654
return aggregations
614655

615656
async def get_bundled_aggregations(
616-
self, events: Iterable[EventBase]
657+
self,
658+
events: Iterable[EventBase],
659+
user_id: str,
617660
) -> Dict[str, Dict[str, Any]]:
618661
"""Generate bundled aggregations for events.
619662
620663
Args:
621664
events: The iterable of events to calculate bundled aggregations for.
665+
user_id: The user requesting the bundled aggregations.
622666
623667
Returns:
624668
A map of event ID to the bundled aggregation for the event. Not all
@@ -631,7 +675,7 @@ async def get_bundled_aggregations(
631675
# TODO Parallelize.
632676
results = {}
633677
for event in events:
634-
event_result = await self._get_bundled_aggregation_for_event(event)
678+
event_result = await self._get_bundled_aggregation_for_event(event, user_id)
635679
if event_result is not None:
636680
results[event.event_id] = event_result
637681

tests/rest/client/test_relations.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,9 @@ def assert_bundle(actual):
515515
2,
516516
actual[RelationTypes.THREAD].get("count"),
517517
)
518+
self.assertTrue(
519+
actual[RelationTypes.THREAD].get("current_user_participated")
520+
)
518521
# The latest thread event has some fields that don't matter.
519522
self.assert_dict(
520523
{

0 commit comments

Comments
 (0)