|
14 | 14 | # limitations under the License. |
15 | 15 | import itertools |
16 | 16 | import logging |
17 | | -from typing import List, Optional, Set |
| 17 | +from typing import Dict, List, Optional, Set, Tuple |
18 | 18 |
|
19 | 19 | from six.moves.queue import Empty, PriorityQueue |
20 | 20 |
|
@@ -103,6 +103,154 @@ def _get_auth_chain_ids_txn(self, txn, event_ids, include_given, ignore_events): |
103 | 103 |
|
104 | 104 | return list(results) |
105 | 105 |
|
def get_auth_chain_difference(self, state_sets: List[Set[str]]):
    """Compute the auth chain difference for the given state sets (as per
    the state res v2 algorithm).

    The result is the same as fetching the full auth chain of every state
    set and keeping only the events that are missing from at least one of
    those chains.

    Args:
        state_sets: The sets of state event IDs to compare.

    Returns:
        Deferred[Set[str]]
    """
    # All the work happens inside a single database transaction; this
    # method only schedules it.
    run_interaction = self.db.runInteraction
    return run_interaction(
        "get_auth_chain_difference",
        self._get_auth_chain_difference_txn,
        state_sets,
    )
| 123 | + |
| 124 | + def _get_auth_chain_difference_txn( |
| 125 | + self, txn, state_sets: List[Set[str]] |
| 126 | + ) -> Set[str]: |
| 127 | + |
| 128 | + # Algorithm Description |
| 129 | + # ~~~~~~~~~~~~~~~~~~~~~ |
| 130 | + # |
| 131 | + # The idea here is to basically walk the auth graph of each state set in |
| 132 | + # tandem, keeping track of which auth events are reachable by each state |
| 133 | + # set. If we reach an auth event we've already visited (via a different |
| 134 | + # state set) then we mark that auth event and all ancestors as reachable |
| 135 | + # by the state set. This requires that we keep track of the auth chains |
| 136 | + # in memory. |
| 137 | + # |
| 138 | + # Doing it in a such a way means that we can stop early if all auth |
| 139 | + # events we're currently walking are reachable by all state sets. |
| 140 | + # |
| 141 | + # *Note*: We can't stop walking an event's auth chain if it is reachable |
| 142 | + # by all state sets. This is because other auth chains we're walking |
| 143 | + # might be reachable only via the original auth chain. For example, |
| 144 | + # given the following auth chain: |
| 145 | + # |
| 146 | + # A -> C -> D -> E |
| 147 | + # / / |
| 148 | + # B -´---------´ |
| 149 | + # |
| 150 | + # and state sets {A} and {B} then walking the auth chains of A and B |
| 151 | + # would immediately show that C is reachable by both. However, if we |
| 152 | + # stopped at C then we'd only reach E via the auth chain of B and so E |
| 153 | + # would errornously get included in the returned difference. |
| 154 | + # |
| 155 | + # The other thing that we do is limit the number of auth chains we walk |
| 156 | + # at once, due to practical limits (i.e. we can only query the database |
| 157 | + # with a limited set of parameters). We pick the auth chains we walk |
| 158 | + # each iteration based on their depth, in the hope that events with a |
| 159 | + # lower depth are likely reachable by those with higher depths. |
| 160 | + # |
| 161 | + # We could use any ordering that we believe would give a rough |
| 162 | + # topological ordering, e.g. origin server timestamp. If the ordering |
| 163 | + # chosen is not topological then the algorithm still produces the right |
| 164 | + # result, but perhaps a bit more inefficiently. This is why it is safe |
| 165 | + # to use "depth" here. |
| 166 | + |
| 167 | + initial_events = set(state_sets[0]).union(*state_sets[1:]) |
| 168 | + |
| 169 | + # Dict from events in auth chains to which sets *cannot* reach them. |
| 170 | + # I.e. if the set is empty then all sets can reach the event. |
| 171 | + event_to_missing_sets = { |
| 172 | + event_id: {i for i, a in enumerate(state_sets) if event_id not in a} |
| 173 | + for event_id in initial_events |
| 174 | + } |
| 175 | + |
| 176 | + # We need to get the depth of the initial events for sorting purposes. |
| 177 | + sql = """ |
| 178 | + SELECT depth, event_id FROM events |
| 179 | + WHERE %s |
| 180 | + ORDER BY depth ASC |
| 181 | + """ |
| 182 | + clause, args = make_in_list_sql_clause( |
| 183 | + txn.database_engine, "event_id", initial_events |
| 184 | + ) |
| 185 | + txn.execute(sql % (clause,), args) |
| 186 | + |
| 187 | + # The sorted list of events whose auth chains we should walk. |
| 188 | + search = txn.fetchall() # type: List[Tuple[int, str]] |
| 189 | + |
| 190 | + # Map from event to its auth events |
| 191 | + event_to_auth_events = {} # type: Dict[str, Set[str]] |
| 192 | + |
| 193 | + base_sql = """ |
| 194 | + SELECT a.event_id, auth_id, depth |
| 195 | + FROM event_auth AS a |
| 196 | + INNER JOIN events AS e ON (e.event_id = a.auth_id) |
| 197 | + WHERE |
| 198 | + """ |
| 199 | + |
| 200 | + while search: |
| 201 | + # Check whether all our current walks are reachable by all state |
| 202 | + # sets. If so we can bail. |
| 203 | + if all(not event_to_missing_sets[eid] for _, eid in search): |
| 204 | + break |
| 205 | + |
| 206 | + # Fetch the auth events and their depths of the N last events we're |
| 207 | + # currently walking |
| 208 | + search, chunk = search[:-100], search[-100:] |
| 209 | + clause, args = make_in_list_sql_clause( |
| 210 | + txn.database_engine, "a.event_id", [e_id for _, e_id in chunk] |
| 211 | + ) |
| 212 | + txn.execute(base_sql + clause, args) |
| 213 | + |
| 214 | + for event_id, auth_event_id, auth_event_depth in txn: |
| 215 | + event_to_auth_events.setdefault(event_id, set()).add(auth_event_id) |
| 216 | + |
| 217 | + sets = event_to_missing_sets.get(auth_event_id) |
| 218 | + if sets is None: |
| 219 | + # First time we're seeing this event, so we add it to the |
| 220 | + # queue of things to fetch. |
| 221 | + search.append((auth_event_depth, auth_event_id)) |
| 222 | + |
| 223 | + # Assume that this event is unreachable from any of the |
| 224 | + # state sets until proven otherwise |
| 225 | + sets = event_to_missing_sets[auth_event_id] = set( |
| 226 | + range(len(state_sets)) |
| 227 | + ) |
| 228 | + else: |
| 229 | + # We've previously seen this event, so look up its auth |
| 230 | + # events and recursively mark all ancestors as reachable |
| 231 | + # by the current event's state set. |
| 232 | + a_ids = event_to_auth_events.get(auth_event_id) |
| 233 | + while a_ids: |
| 234 | + new_aids = set() |
| 235 | + for a_id in a_ids: |
| 236 | + event_to_missing_sets[a_id].intersection_update( |
| 237 | + event_to_missing_sets[event_id] |
| 238 | + ) |
| 239 | + |
| 240 | + b = event_to_auth_events.get(a_id) |
| 241 | + if b: |
| 242 | + new_aids.update(b) |
| 243 | + |
| 244 | + a_ids = new_aids |
| 245 | + |
| 246 | + # Mark that the auth event is reachable by the approriate sets. |
| 247 | + sets.intersection_update(event_to_missing_sets[event_id]) |
| 248 | + |
| 249 | + search.sort() |
| 250 | + |
| 251 | + # Return all events where not all sets can reach them. |
| 252 | + return {eid for eid, n in event_to_missing_sets.items() if n} |
| 253 | + |
106 | 254 | def get_oldest_events_in_room(self, room_id): |
107 | 255 | return self.db.runInteraction( |
108 | 256 | "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id |
|
0 commit comments