Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 62c75e6

Browse files
committed
Improve get auth chain difference algorithm. (#7095)
* commit '4a17a647a': Improve get auth chain difference algorithm. (#7095)
2 parents 6ed566e + 4a17a64 commit 62c75e6

File tree

6 files changed

+310
-71
lines changed

6 files changed

+310
-71
lines changed

changelog.d/7095.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Attempt to improve performance of state res v2 algorithm.

synapse/state/__init__.py

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -662,28 +662,16 @@ def get_events(self, event_ids, allow_rejected=False):
662662
allow_rejected=allow_rejected,
663663
)
664664

665-
def get_auth_chain(self, event_ids: List[str], ignore_events: Set[str]):
666-
"""Gets the full auth chain for a set of events (including rejected
667-
events).
668-
669-
Includes the given event IDs in the result.
670-
671-
Note that:
672-
1. All events must be state events.
673-
2. For v1 rooms this may not have the full auth chain in the
674-
presence of rejected events
675-
676-
Args:
677-
event_ids: The event IDs of the events to fetch the auth chain for.
678-
Must be state events.
679-
ignore_events: Set of events to exclude from the returned auth
680-
chain.
665+
def get_auth_chain_difference(self, state_sets: List[Set[str]]):
666+
"""Given sets of state events figure out the auth chain difference (as
667+
per state res v2 algorithm).
681668
669+
This equivalent to fetching the full auth chain for each set of state
670+
and returning the events that don't appear in each and every auth
671+
chain.
682672
683673
Returns:
684-
Deferred[list[str]]: List of event IDs of the auth chain.
674+
Deferred[Set[str]]: Set of event IDs.
685675
"""
686676

687-
return self.store.get_auth_chain_ids(
688-
event_ids, include_given=True, ignore_events=ignore_events,
689-
)
677+
return self.store.get_auth_chain_difference(state_sets)

synapse/state/v2.py

Lines changed: 4 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -227,36 +227,12 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store):
227227
Returns:
228228
Deferred[set[str]]: Set of event IDs
229229
"""
230-
common = set(itervalues(state_sets[0])).intersection(
231-
*(itervalues(s) for s in state_sets[1:])
232-
)
233-
234-
auth_sets = []
235-
for state_set in state_sets:
236-
auth_ids = {
237-
eid
238-
for key, eid in iteritems(state_set)
239-
if (
240-
key[0] in (EventTypes.Member, EventTypes.ThirdPartyInvite)
241-
or key
242-
in (
243-
(EventTypes.PowerLevels, ""),
244-
(EventTypes.Create, ""),
245-
(EventTypes.JoinRules, ""),
246-
)
247-
)
248-
and eid not in common
249-
}
250230

251-
auth_chain = yield state_res_store.get_auth_chain(auth_ids, common)
252-
auth_ids.update(auth_chain)
253-
254-
auth_sets.append(auth_ids)
255-
256-
intersection = set(auth_sets[0]).intersection(*auth_sets[1:])
257-
union = set().union(*auth_sets)
231+
difference = yield state_res_store.get_auth_chain_difference(
232+
[set(state_set.values()) for state_set in state_sets]
233+
)
258234

259-
return union - intersection
235+
return difference
260236

261237

262238
def _seperate(state_sets):

synapse/storage/data_stores/main/event_federation.py

Lines changed: 149 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515
import itertools
1616
import logging
17-
from typing import List, Optional, Set
17+
from typing import Dict, List, Optional, Set, Tuple
1818

1919
from six.moves.queue import Empty, PriorityQueue
2020

@@ -103,6 +103,154 @@ def _get_auth_chain_ids_txn(self, txn, event_ids, include_given, ignore_events):
103103

104104
return list(results)
105105

106+
def get_auth_chain_difference(self, state_sets: List[Set[str]]):
107+
"""Given sets of state events figure out the auth chain difference (as
108+
per state res v2 algorithm).
109+
110+
This equivalent to fetching the full auth chain for each set of state
111+
and returning the events that don't appear in each and every auth
112+
chain.
113+
114+
Returns:
115+
Deferred[Set[str]]
116+
"""
117+
118+
return self.db.runInteraction(
119+
"get_auth_chain_difference",
120+
self._get_auth_chain_difference_txn,
121+
state_sets,
122+
)
123+
124+
def _get_auth_chain_difference_txn(
125+
self, txn, state_sets: List[Set[str]]
126+
) -> Set[str]:
127+
128+
# Algorithm Description
129+
# ~~~~~~~~~~~~~~~~~~~~~
130+
#
131+
# The idea here is to basically walk the auth graph of each state set in
132+
# tandem, keeping track of which auth events are reachable by each state
133+
# set. If we reach an auth event we've already visited (via a different
134+
# state set) then we mark that auth event and all ancestors as reachable
135+
# by the state set. This requires that we keep track of the auth chains
136+
# in memory.
137+
#
138+
# Doing it in a such a way means that we can stop early if all auth
139+
# events we're currently walking are reachable by all state sets.
140+
#
141+
# *Note*: We can't stop walking an event's auth chain if it is reachable
142+
# by all state sets. This is because other auth chains we're walking
143+
# might be reachable only via the original auth chain. For example,
144+
# given the following auth chain:
145+
#
146+
# A -> C -> D -> E
147+
# / /
148+
# B -´---------´
149+
#
150+
# and state sets {A} and {B} then walking the auth chains of A and B
151+
# would immediately show that C is reachable by both. However, if we
152+
# stopped at C then we'd only reach E via the auth chain of B and so E
153+
# would errornously get included in the returned difference.
154+
#
155+
# The other thing that we do is limit the number of auth chains we walk
156+
# at once, due to practical limits (i.e. we can only query the database
157+
# with a limited set of parameters). We pick the auth chains we walk
158+
# each iteration based on their depth, in the hope that events with a
159+
# lower depth are likely reachable by those with higher depths.
160+
#
161+
# We could use any ordering that we believe would give a rough
162+
# topological ordering, e.g. origin server timestamp. If the ordering
163+
# chosen is not topological then the algorithm still produces the right
164+
# result, but perhaps a bit more inefficiently. This is why it is safe
165+
# to use "depth" here.
166+
167+
initial_events = set(state_sets[0]).union(*state_sets[1:])
168+
169+
# Dict from events in auth chains to which sets *cannot* reach them.
170+
# I.e. if the set is empty then all sets can reach the event.
171+
event_to_missing_sets = {
172+
event_id: {i for i, a in enumerate(state_sets) if event_id not in a}
173+
for event_id in initial_events
174+
}
175+
176+
# We need to get the depth of the initial events for sorting purposes.
177+
sql = """
178+
SELECT depth, event_id FROM events
179+
WHERE %s
180+
ORDER BY depth ASC
181+
"""
182+
clause, args = make_in_list_sql_clause(
183+
txn.database_engine, "event_id", initial_events
184+
)
185+
txn.execute(sql % (clause,), args)
186+
187+
# The sorted list of events whose auth chains we should walk.
188+
search = txn.fetchall() # type: List[Tuple[int, str]]
189+
190+
# Map from event to its auth events
191+
event_to_auth_events = {} # type: Dict[str, Set[str]]
192+
193+
base_sql = """
194+
SELECT a.event_id, auth_id, depth
195+
FROM event_auth AS a
196+
INNER JOIN events AS e ON (e.event_id = a.auth_id)
197+
WHERE
198+
"""
199+
200+
while search:
201+
# Check whether all our current walks are reachable by all state
202+
# sets. If so we can bail.
203+
if all(not event_to_missing_sets[eid] for _, eid in search):
204+
break
205+
206+
# Fetch the auth events and their depths of the N last events we're
207+
# currently walking
208+
search, chunk = search[:-100], search[-100:]
209+
clause, args = make_in_list_sql_clause(
210+
txn.database_engine, "a.event_id", [e_id for _, e_id in chunk]
211+
)
212+
txn.execute(base_sql + clause, args)
213+
214+
for event_id, auth_event_id, auth_event_depth in txn:
215+
event_to_auth_events.setdefault(event_id, set()).add(auth_event_id)
216+
217+
sets = event_to_missing_sets.get(auth_event_id)
218+
if sets is None:
219+
# First time we're seeing this event, so we add it to the
220+
# queue of things to fetch.
221+
search.append((auth_event_depth, auth_event_id))
222+
223+
# Assume that this event is unreachable from any of the
224+
# state sets until proven otherwise
225+
sets = event_to_missing_sets[auth_event_id] = set(
226+
range(len(state_sets))
227+
)
228+
else:
229+
# We've previously seen this event, so look up its auth
230+
# events and recursively mark all ancestors as reachable
231+
# by the current event's state set.
232+
a_ids = event_to_auth_events.get(auth_event_id)
233+
while a_ids:
234+
new_aids = set()
235+
for a_id in a_ids:
236+
event_to_missing_sets[a_id].intersection_update(
237+
event_to_missing_sets[event_id]
238+
)
239+
240+
b = event_to_auth_events.get(a_id)
241+
if b:
242+
new_aids.update(b)
243+
244+
a_ids = new_aids
245+
246+
# Mark that the auth event is reachable by the approriate sets.
247+
sets.intersection_update(event_to_missing_sets[event_id])
248+
249+
search.sort()
250+
251+
# Return all events where not all sets can reach them.
252+
return {eid for eid, n in event_to_missing_sets.items() if n}
253+
106254
def get_oldest_events_in_room(self, room_id):
107255
return self.db.runInteraction(
108256
"get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id

tests/state/test_v2.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -603,7 +603,7 @@ def get_events(self, event_ids, allow_rejected=False):
603603

604604
return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
605605

606-
def get_auth_chain(self, event_ids, ignore_events):
606+
def _get_auth_chain(self, event_ids):
607607
"""Gets the full auth chain for a set of events (including rejected
608608
events).
609609
@@ -617,9 +617,6 @@ def get_auth_chain(self, event_ids, ignore_events):
617617
Args:
618618
event_ids (list): The event IDs of the events to fetch the auth
619619
chain for. Must be state events.
620-
ignore_events: Set of events to exclude from the returned auth
621-
chain.
622-
623620
Returns:
624621
Deferred[list[str]]: List of event IDs of the auth chain.
625622
"""
@@ -629,7 +626,7 @@ def get_auth_chain(self, event_ids, ignore_events):
629626
stack = list(event_ids)
630627
while stack:
631628
event_id = stack.pop()
632-
if event_id in result or event_id in ignore_events:
629+
if event_id in result:
633630
continue
634631

635632
result.add(event_id)
@@ -639,3 +636,9 @@ def get_auth_chain(self, event_ids, ignore_events):
639636
stack.append(aid)
640637

641638
return list(result)
639+
640+
def get_auth_chain_difference(self, auth_sets):
641+
chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
642+
643+
common = set(chains[0]).intersection(*chains[1:])
644+
return set(chains[0]).union(*chains[1:]) - common

0 commit comments

Comments
 (0)