Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Recursively fetch the thread ID when calculating notifications.
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep committed Sep 22, 2022
1 parent 2f31cc7 commit 2183f0a
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 0 deletions.
6 changes: 6 additions & 0 deletions synapse/push/bulk_push_rule_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,14 @@ async def action_for_event_by_user(
relation.parent_id,
itertools.chain(*(r.rules() for r in rules_by_user.values())),
)
# Recursively attempt to find the thread this event relates to.
if relation.rel_type == RelationTypes.THREAD:
thread_id = relation.parent_id
else:
# Since the event has not yet been persisted we check whether
# the parent is parent of a thread.
thread_id = await self.store.get_thread_id(relation.parent_id) or "main"
print(f"Found {thread_id} for {event.event_id}")

evaluator = PushRuleEvaluatorForEvent(
event,
Expand Down
34 changes: 34 additions & 0 deletions synapse/storage/databases/main/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,40 @@ def _get_event_relations(
"get_event_relations", _get_event_relations
)

@cached()
async def get_thread_id(self, event_id: str) -> Optional[str]:
"""
Get the thread ID for an event. This considers multi-level relations,
e.g. an annotation to an event which is part of a thread.
Args:
event_id: The event ID to fetch the thread ID for.
Returns:
The event ID of the root event in the thread, if this event is part
of a thread. None, otherwise.
"""
sql = """
WITH RECURSIVE related_events AS (
SELECT event_id, relates_to_id, relation_type
FROM event_relations
WHERE event_id = ?
UNION SELECT e.event_id, e.relates_to_id, e.relation_type
FROM event_relations e
INNER JOIN related_events r ON r.relates_to_id = e.event_id
) SELECT relates_to_id FROM related_events WHERE relation_type = 'm.thread';
"""

def _get_thread_id(txn: LoggingTransaction) -> Optional[str]:
txn.execute(sql, (event_id,))
# TODO Should we ensure there's only a single result here?
row = txn.fetchone()
if row:
return row[0]
return None

return await self.db_pool.runInteraction("get_thread_id", _get_thread_id)


class RelationsStore(RelationsWorkerStore):
pass
102 changes: 102 additions & 0 deletions tests/storage/test_event_push_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,108 @@ def _mark_read(event_id: str, thread_id: Optional[str] = None) -> None:
_rotate()
_assert_counts(0, 0, 0, 0)

def test_recursive_thread(self) -> None:
"""
Events related to events in a thread should still be considered part of
that thread.
"""

# Create a user to receive notifications and send receipts.
user_id = self.register_user("user1235", "pass")
token = self.login("user1235", "pass")

# And another users to send events.
other_id = self.register_user("other", "pass")
other_token = self.login("other", "pass")

# Create a room and put both users in it.
room_id = self.helper.create_room_as(user_id, tok=token)
self.helper.join(room_id, other_id, tok=other_token)

# Update the user's push rules to care about reaction events.
self.get_success(
self.store.add_push_rule(
user_id,
"related_events",
priority_class=5,
conditions=[
{"kind": "event_match", "key": "type", "pattern": "m.reaction"}
],
actions=["notify"],
)
)

def _create_event(type: str, content: JsonDict) -> str:
result = self.helper.send_event(
room_id, type=type, content=content, tok=other_token
)
return result["event_id"]

def _assert_counts(noitf_count: int, thread_notif_count: int) -> None:
counts = self.get_success(
self.store.db_pool.runInteraction(
"get-unread-counts",
self.store._get_unread_counts_by_receipt_txn,
room_id,
user_id,
)
)
self.assertEqual(
counts.main_timeline,
NotifCounts(
notify_count=noitf_count, unread_count=0, highlight_count=0
),
)
if thread_notif_count:
self.assertEqual(
counts.threads,
{
thread_id: NotifCounts(
notify_count=thread_notif_count,
unread_count=0,
highlight_count=0,
),
},
)
else:
self.assertEqual(counts.threads, {})

# Create a root event.
thread_id = _create_event(
"m.room.message", {"msgtype": "m.text", "body": "msg"}
)
print(f"Root: {thread_id}")
_assert_counts(1, 0)

# Reply, creating a thread.
reply_id = _create_event(
"m.room.message",
{
"msgtype": "m.text",
"body": "msg",
"m.relates_to": {
"rel_type": "m.thread",
"event_id": thread_id,
},
},
)
print(f"Reply: {reply_id}")
_assert_counts(1, 1)

# Create an event related to a thread event, this should still appear in
# the thread.
_create_event(
type="m.reaction",
content={
"m.relates_to": {
"rel_type": "m.annotation",
"event_id": reply_id,
"key": "A",
}
},
)
_assert_counts(1, 2)

def test_find_first_stream_ordering_after_ts(self) -> None:
def add_event(so: int, ts: int) -> None:
self.get_success(
Expand Down

0 comments on commit 2183f0a

Please sign in to comment.