Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 88cd6f9

Browse files
authored
Allow retrieving the relations of a redacted event. (#12130)
This is allowed per MSC2675, although the original implementation did not allow for it and would return an empty chunk / not bundle aggregations. The main things to improve are that the various caches get cleared properly when an event is redacted, and that edits must not leak if the original event is redacted (since that would presumably leak something similar to the original event content).
1 parent 3e4af36 commit 88cd6f9

File tree

8 files changed

+122
-83
lines changed

8 files changed

+122
-83
lines changed

changelog.d/12130.bugfix

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix a long-standing bug when redacting events with relations.

changelog.d/12189.bugfix

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix a long-standing bug when redacting events with relations.

changelog.d/12189.misc

Lines changed: 0 additions & 1 deletion
This file was deleted.

synapse/rest/client/relations.py

Lines changed: 38 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from synapse.http.servlet import RestServlet, parse_integer, parse_string
2828
from synapse.http.site import SynapseRequest
2929
from synapse.rest.client._base import client_patterns
30-
from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
30+
from synapse.storage.relations import AggregationPaginationToken
3131
from synapse.types import JsonDict, StreamToken
3232

3333
if TYPE_CHECKING:
@@ -82,28 +82,25 @@ async def on_GET(
8282
from_token_str = parse_string(request, "from")
8383
to_token_str = parse_string(request, "to")
8484

85-
if event.internal_metadata.is_redacted():
86-
# If the event is redacted, return an empty list of relations
87-
pagination_chunk = PaginationChunk(chunk=[])
88-
else:
89-
# Return the relations
90-
from_token = None
91-
if from_token_str:
92-
from_token = await StreamToken.from_string(self.store, from_token_str)
93-
to_token = None
94-
if to_token_str:
95-
to_token = await StreamToken.from_string(self.store, to_token_str)
96-
97-
pagination_chunk = await self.store.get_relations_for_event(
98-
event_id=parent_id,
99-
room_id=room_id,
100-
relation_type=relation_type,
101-
event_type=event_type,
102-
limit=limit,
103-
direction=direction,
104-
from_token=from_token,
105-
to_token=to_token,
106-
)
85+
# Return the relations
86+
from_token = None
87+
if from_token_str:
88+
from_token = await StreamToken.from_string(self.store, from_token_str)
89+
to_token = None
90+
if to_token_str:
91+
to_token = await StreamToken.from_string(self.store, to_token_str)
92+
93+
pagination_chunk = await self.store.get_relations_for_event(
94+
event_id=parent_id,
95+
event=event,
96+
room_id=room_id,
97+
relation_type=relation_type,
98+
event_type=event_type,
99+
limit=limit,
100+
direction=direction,
101+
from_token=from_token,
102+
to_token=to_token,
103+
)
107104

108105
events = await self.store.get_events_as_list(
109106
[c["event_id"] for c in pagination_chunk.chunk]
@@ -193,27 +190,23 @@ async def on_GET(
193190
from_token_str = parse_string(request, "from")
194191
to_token_str = parse_string(request, "to")
195192

196-
if event.internal_metadata.is_redacted():
197-
# If the event is redacted, return an empty list of relations
198-
pagination_chunk = PaginationChunk(chunk=[])
199-
else:
200-
# Return the relations
201-
from_token = None
202-
if from_token_str:
203-
from_token = AggregationPaginationToken.from_string(from_token_str)
204-
205-
to_token = None
206-
if to_token_str:
207-
to_token = AggregationPaginationToken.from_string(to_token_str)
208-
209-
pagination_chunk = await self.store.get_aggregation_groups_for_event(
210-
event_id=parent_id,
211-
room_id=room_id,
212-
event_type=event_type,
213-
limit=limit,
214-
from_token=from_token,
215-
to_token=to_token,
216-
)
193+
# Return the relations
194+
from_token = None
195+
if from_token_str:
196+
from_token = AggregationPaginationToken.from_string(from_token_str)
197+
198+
to_token = None
199+
if to_token_str:
200+
to_token = AggregationPaginationToken.from_string(to_token_str)
201+
202+
pagination_chunk = await self.store.get_aggregation_groups_for_event(
203+
event_id=parent_id,
204+
room_id=room_id,
205+
event_type=event_type,
206+
limit=limit,
207+
from_token=from_token,
208+
to_token=to_token,
209+
)
217210

218211
return 200, await pagination_chunk.to_dict(self.store)
219212

@@ -295,6 +288,7 @@ async def on_GET(
295288

296289
result = await self.store.get_relations_for_event(
297290
event_id=parent_id,
291+
event=event,
298292
room_id=room_id,
299293
relation_type=relation_type,
300294
event_type=event_type,

synapse/storage/databases/main/cache.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,10 @@ def _invalidate_caches_for_event(
191191

192192
if redacts:
193193
self._invalidate_get_event_cache(redacts)
194+
# Caches which might leak edits must be invalidated for the event being
195+
# redacted.
196+
self.get_relations_for_event.invalidate((redacts,))
197+
self.get_applicable_edit.invalidate((redacts,))
194198

195199
if etype == EventTypes.Member:
196200
self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)

synapse/storage/databases/main/events.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1619,9 +1619,12 @@ def prefill():
16191619

16201620
txn.call_after(prefill)
16211621

1622-
def _store_redaction(self, txn, event):
1623-
# invalidate the cache for the redacted event
1622+
def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None:
1623+
# Invalidate the caches for the redacted event, note that these caches
1624+
# are also cleared as part of event replication in _invalidate_caches_for_event.
16241625
txn.call_after(self.store._invalidate_get_event_cache, event.redacts)
1626+
txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,))
1627+
txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,))
16251628

16261629
self.db_pool.simple_upsert_txn(
16271630
txn,
@@ -1812,9 +1815,7 @@ def _handle_event_relations(
18121815
txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
18131816

18141817
if rel_type == RelationTypes.THREAD:
1815-
txn.call_after(
1816-
self.store.get_thread_summary.invalidate, (parent_id, event.room_id)
1817-
)
1818+
txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,))
18181819
# It should be safe to only invalidate the cache if the user has not
18191820
# previously participated in the thread, but that's difficult (and
18201821
# potentially error-prone) so it is always invalidated.

synapse/storage/databases/main/relations.py

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,11 @@ def __init__(
9191

9292
self._msc3440_enabled = hs.config.experimental.msc3440_enabled
9393

94-
@cached(tree=True)
94+
@cached(uncached_args=("event",), tree=True)
9595
async def get_relations_for_event(
9696
self,
9797
event_id: str,
98+
event: EventBase,
9899
room_id: str,
99100
relation_type: Optional[str] = None,
100101
event_type: Optional[str] = None,
@@ -108,6 +109,7 @@ async def get_relations_for_event(
108109
109110
Args:
110111
event_id: Fetch events that relate to this event ID.
112+
event: The matching EventBase to event_id.
111113
room_id: The room the event belongs to.
112114
relation_type: Only fetch events with this relation type, if given.
113115
event_type: Only fetch events with this event type, if given.
@@ -122,9 +124,13 @@ async def get_relations_for_event(
122124
List of event IDs that match relations requested. The rows are of
123125
the form `{"event_id": "..."}`.
124126
"""
127+
# We don't use `event_id`, it's there so that we can cache based on
128+
# it. The `event_id` must match the `event.event_id`.
129+
assert event.event_id == event_id
125130

126131
where_clause = ["relates_to_id = ?", "room_id = ?"]
127-
where_args: List[Union[str, int]] = [event_id, room_id]
132+
where_args: List[Union[str, int]] = [event.event_id, room_id]
133+
is_redacted = event.internal_metadata.is_redacted()
128134

129135
if relation_type is not None:
130136
where_clause.append("relation_type = ?")
@@ -157,7 +163,7 @@ async def get_relations_for_event(
157163
order = "ASC"
158164

159165
sql = """
160-
SELECT event_id, topological_ordering, stream_ordering
166+
SELECT event_id, relation_type, topological_ordering, stream_ordering
161167
FROM event_relations
162168
INNER JOIN events USING (event_id)
163169
WHERE %s
@@ -178,9 +184,12 @@ def _get_recent_references_for_event_txn(
178184
last_stream_id = None
179185
events = []
180186
for row in txn:
181-
events.append({"event_id": row[0]})
182-
last_topo_id = row[1]
183-
last_stream_id = row[2]
187+
# Do not include edits for redacted events as they leak event
188+
# content.
189+
if not is_redacted or row[1] != RelationTypes.REPLACE:
190+
events.append({"event_id": row[0]})
191+
last_topo_id = row[2]
192+
last_stream_id = row[3]
184193

185194
# If there are more events, generate the next pagination key.
186195
next_token = None
@@ -776,7 +785,7 @@ async def _get_bundled_aggregation_for_event(
776785
)
777786

778787
references = await self.get_relations_for_event(
779-
event_id, room_id, RelationTypes.REFERENCE, direction="f"
788+
event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
780789
)
781790
if references.chunk:
782791
aggregations.references = await references.to_dict(cast("DataStore", self))
@@ -797,41 +806,36 @@ async def get_bundled_aggregations(
797806
A map of event ID to the bundled aggregation for the event. Not all
798807
events may have bundled aggregations in the results.
799808
"""
800-
# The already processed event IDs. Tracked separately from the result
801-
# since the result omits events which do not have bundled aggregations.
802-
seen_event_ids = set()
803-
804-
# State events and redacted events do not get bundled aggregations.
805-
events = [
806-
event
807-
for event in events
808-
if not event.is_state() and not event.internal_metadata.is_redacted()
809-
]
809+
# De-duplicate events by ID to handle the same event requested multiple times.
810+
#
811+
# State events do not get bundled aggregations.
812+
events_by_id = {
813+
event.event_id: event for event in events if not event.is_state()
814+
}
810815

811816
# event ID -> bundled aggregation in non-serialized form.
812817
results: Dict[str, BundledAggregations] = {}
813818

814819
# Fetch other relations per event.
815-
for event in events:
816-
# De-duplicate events by ID to handle the same event requested multiple
817-
# times. The caches that _get_bundled_aggregation_for_event use should
818-
# capture this, but best to reduce work.
819-
if event.event_id in seen_event_ids:
820-
continue
821-
seen_event_ids.add(event.event_id)
822-
820+
for event in events_by_id.values():
823821
event_result = await self._get_bundled_aggregation_for_event(event, user_id)
824822
if event_result:
825823
results[event.event_id] = event_result
826824

827-
# Fetch any edits.
828-
edits = await self._get_applicable_edits(seen_event_ids)
825+
# Fetch any edits (but not for redacted events).
826+
edits = await self._get_applicable_edits(
827+
[
828+
event_id
829+
for event_id, event in events_by_id.items()
830+
if not event.internal_metadata.is_redacted()
831+
]
832+
)
829833
for event_id, edit in edits.items():
830834
results.setdefault(event_id, BundledAggregations()).replace = edit
831835

832836
# Fetch thread summaries.
833837
if self._msc3440_enabled:
834-
summaries = await self._get_thread_summaries(seen_event_ids)
838+
summaries = await self._get_thread_summaries(events_by_id.keys())
835839
# Only fetch participated for a limited selection based on what had
836840
# summaries.
837841
participated = await self._get_threads_participated(

tests/rest/client/test_relations.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1475,12 +1475,13 @@ def test_redact_parent_edit(self) -> None:
14751475
self.assertEqual(relations, {})
14761476

14771477
def test_redact_parent_annotation(self) -> None:
1478-
"""Test that annotations of an event are redacted when the original event
1478+
"""Test that annotations of an event are viewable when the original event
14791479
is redacted.
14801480
"""
14811481
# Add a relation
14821482
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
14831483
self.assertEqual(200, channel.code, channel.json_body)
1484+
related_event_id = channel.json_body["event_id"]
14841485

14851486
# The relations should exist.
14861487
event_ids, relations = self._make_relation_requests()
@@ -1494,11 +1495,45 @@ def test_redact_parent_annotation(self) -> None:
14941495
# Redact the original event.
14951496
self._redact(self.parent_id)
14961497

1497-
# The relations are not returned.
1498+
# The relations are returned.
14981499
event_ids, relations = self._make_relation_requests()
1499-
self.assertEqual(event_ids, [])
1500-
self.assertEqual(relations, {})
1500+
self.assertEquals(event_ids, [related_event_id])
1501+
self.assertEquals(
1502+
relations["m.annotation"],
1503+
{"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]},
1504+
)
15011505

15021506
# The aggregation should still be returned for the redacted parent.
15031507
chunk = self._get_aggregations()
1504-
self.assertEqual(chunk, [])
1508+
self.assertEqual(chunk, [{"count": 1, "key": "👍", "type": "m.reaction"}])
1509+
1510+
@unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
1511+
def test_redact_parent_thread(self) -> None:
1512+
"""
1513+
Test that thread replies are still available when the root event is redacted.
1514+
"""
1515+
channel = self._send_relation(
1516+
RelationTypes.THREAD,
1517+
EventTypes.Message,
1518+
content={"body": "reply 1", "msgtype": "m.text"},
1519+
)
1520+
self.assertEqual(200, channel.code, channel.json_body)
1521+
related_event_id = channel.json_body["event_id"]
1522+
1523+
# Redact the thread root event.
1524+
self._redact(self.parent_id)
1525+
1526+
# The unredacted relation should still exist.
1527+
event_ids, relations = self._make_relation_requests()
1528+
self.assertEquals(len(event_ids), 1)
1529+
self.assertDictContainsSubset(
1530+
{
1531+
"count": 1,
1532+
"current_user_participated": True,
1533+
},
1534+
relations[RelationTypes.THREAD],
1535+
)
1536+
self.assertEqual(
1537+
relations[RelationTypes.THREAD]["latest_event"]["event_id"],
1538+
related_event_id,
1539+
)

0 commit comments

Comments
 (0)