@@ -384,8 +384,7 @@ def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:
384384 async def get_thread_summary (
385385 self , event_id : str , room_id : str
386386 ) -> Tuple [int , Optional [EventBase ]]:
387- """Get the number of threaded replies, the senders of those replies, and
388- the latest reply (if any) for the given event.
387+ """Get the number of threaded replies and the latest reply (if any) for the given event.
389388
390389 Args:
391390 event_id: Summarize the thread related to this event ID.
@@ -398,7 +397,7 @@ async def get_thread_summary(
398397 def _get_thread_summary_txn (
399398 txn : LoggingTransaction ,
400399 ) -> Tuple [int , Optional [str ]]:
401- # Fetch the count of threaded events and the latest event ID .
400+ # Fetch the latest event ID in the thread .
402401 # TODO Should this only allow m.room.message events.
403402 sql = """
404403 SELECT event_id
@@ -419,6 +418,7 @@ def _get_thread_summary_txn(
419418
420419 latest_event_id = row [0 ]
421420
421+ # Fetch the number of threaded replies.
422422 sql = """
423423 SELECT COUNT(event_id)
424424 FROM event_relations
@@ -443,6 +443,44 @@ def _get_thread_summary_txn(
443443
444444 return count , latest_event
445445
446+ @cached ()
447+ async def get_thread_participated (
448+ self , event_id : str , room_id : str , user_id : str
449+ ) -> bool :
450+ """Get whether the requesting user participated in a thread.
451+
452+ This is separate from get_thread_summary since that can be cached across
453+ all users while this value is specific to the requeser.
454+
455+ Args:
456+ event_id: The thread related to this event ID.
457+ room_id: The room the event belongs to.
458+ user_id: The user requesting the summary.
459+
460+ Returns:
461+ True if the requesting user participated in the thread, otherwise false.
462+ """
463+
464+ def _get_thread_summary_txn (txn : LoggingTransaction ) -> bool :
465+ # Fetch whether the requester has participated or not.
466+ sql = """
467+ SELECT 1
468+ FROM event_relations
469+ INNER JOIN events USING (event_id)
470+ WHERE
471+ relates_to_id = ?
472+ AND room_id = ?
473+ AND relation_type = ?
474+ AND sender = ?
475+ """
476+
477+ txn .execute (sql , (event_id , room_id , RelationTypes .THREAD , user_id ))
478+ return bool (txn .fetchone ())
479+
480+ return await self .db_pool .runInteraction (
481+ "get_thread_summary" , _get_thread_summary_txn
482+ )
483+
446484 async def events_have_relations (
447485 self ,
448486 parent_ids : List [str ],
@@ -546,14 +584,15 @@ def _get_if_user_has_annotated_event(txn: LoggingTransaction) -> bool:
546584 )
547585
548586 async def _get_bundled_aggregation_for_event (
549- self , event : EventBase
587+ self , event : EventBase , user_id : str
550588 ) -> Optional [Dict [str , Any ]]:
551589 """Generate bundled aggregations for an event.
552590
553591 Note that this does not use a cache, but depends on cached methods.
554592
555593 Args:
556594 event: The event to calculate bundled aggregations for.
595+ user_id: The user requesting the bundled aggregations.
557596
558597 Returns:
559598 The bundled aggregations for an event, if bundled aggregations are
@@ -598,27 +637,32 @@ async def _get_bundled_aggregation_for_event(
598637
599638 # If this event is the start of a thread, include a summary of the replies.
600639 if self ._msc3440_enabled :
601- (
602- thread_count ,
603- latest_thread_event ,
604- ) = await self .get_thread_summary (event_id , room_id )
640+ thread_count , latest_thread_event = await self .get_thread_summary (
641+ event_id , room_id
642+ )
643+ participated = await self .get_thread_participated (
644+ event_id , room_id , user_id
645+ )
605646 if latest_thread_event :
606647 aggregations [RelationTypes .THREAD ] = {
607- # Don't bundle aggregations as this could recurse forever.
608648 "latest_event" : latest_thread_event ,
609649 "count" : thread_count ,
650+ "current_user_participated" : participated ,
610651 }
611652
612653 # Store the bundled aggregations in the event metadata for later use.
613654 return aggregations
614655
615656 async def get_bundled_aggregations (
616- self , events : Iterable [EventBase ]
657+ self ,
658+ events : Iterable [EventBase ],
659+ user_id : str ,
617660 ) -> Dict [str , Dict [str , Any ]]:
618661 """Generate bundled aggregations for events.
619662
620663 Args:
621664 events: The iterable of events to calculate bundled aggregations for.
665+ user_id: The user requesting the bundled aggregations.
622666
623667 Returns:
624668 A map of event ID to the bundled aggregation for the event. Not all
@@ -631,7 +675,7 @@ async def get_bundled_aggregations(
631675 # TODO Parallelize.
632676 results = {}
633677 for event in events :
634- event_result = await self ._get_bundled_aggregation_for_event (event )
678+ event_result = await self ._get_bundled_aggregation_for_event (event , user_id )
635679 if event_result is not None :
636680 results [event .event_id ] = event_result
637681
0 commit comments