|
15 | 15 | # limitations under the License. |
16 | 16 |
|
17 | 17 | import heapq |
18 | | -from typing import Iterable, Tuple, Type
 | 18 | +from typing import Iterable, List, Tuple, Type
19 | 20 |
|
20 | 21 | import attr |
21 | 22 |
|
22 | | -from ._base import Stream, Token, db_query_to_update_function |
| 23 | +from ._base import Stream, StreamUpdateResult, Token |
23 | 24 |
|
24 | 25 |
|
25 | 26 | """Handling of the 'events' replication stream |
@@ -117,30 +118,106 @@ class EventsStream(Stream): |
117 | 118 | def __init__(self, hs): |
118 | 119 | self._store = hs.get_datastore() |
119 | 120 | super().__init__( |
120 | | - self._store.get_current_events_token, |
121 | | - db_query_to_update_function(self._update_function), |
| 121 | + self._store.get_current_events_token, self._update_function, |
122 | 122 | ) |
123 | 123 |
|
124 | 124 | async def _update_function( |
125 | | - self, from_token: Token, current_token: Token, limit: int |
126 | | - ) -> Iterable[tuple]: |
| 125 | + self, from_token: Token, current_token: Token, target_row_count: int |
| 126 | + ) -> StreamUpdateResult: |
| 127 | + |
| 128 | + # the events stream merges together three separate sources: |
| 129 | + # * new events |
| 130 | + # * current_state changes |
| 131 | + # * events which were previously outliers, but have now been de-outliered. |
| 132 | + # |
| 133 | + # The merge operation is complicated by the fact that we only have a single |
| 134 | + # "stream token" which is supposed to indicate how far we have got through |
| 135 | + # all three streams. It's therefore no good to return rows 1-1000 from the |
| 136 | + # "new events" table if the state_deltas are limited to rows 1-100 by the |
| 137 | + # target_row_count. |
| 138 | + # |
| 139 | + # In other words: we must pick a new upper limit, and must return *all* rows |
| 140 | + # up to that point for each of the three sources. |
| 141 | + # |
| 142 | + # Start by trying to split the target_row_count up. We expect to have a |
| 143 | + # negligible number of ex-outliers, and a rough approximation based on recent |
| 144 | + # traffic on sw1v.org shows that there are approximately the same number of |
| 145 | + # event rows between a given pair of stream ids as there are state |
| 146 | + # updates, so let's split our target_row_count among those two types. The target |
| 147 | + # is only an approximation - it doesn't matter if we end up going a bit over it. |
| 148 | + |
| 149 | + target_row_count //= 2 |
| 150 | + |
| 151 | + # now we fetch up to that many rows from the events table |
| 152 | + |
127 | 153 | event_rows = await self._store.get_all_new_forward_event_rows( |
128 | | - from_token, current_token, limit |
129 | | - ) |
130 | | - event_updates = ( |
131 | | - (row[0], EventsStreamEventRow.TypeId, row[1:]) for row in event_rows |
132 | | - ) |
| 154 | + from_token, current_token, target_row_count |
| 155 | + ) # type: List[Tuple] |
| 156 | + |
| 157 | + # we rely on get_all_new_forward_event_rows strictly honouring the limit, so |
| 158 | + # that we know it is safe to just take upper_limit = event_rows[-1][0]. |
| 159 | + assert ( |
| 160 | + len(event_rows) <= target_row_count |
| 161 | + ), "get_all_new_forward_event_rows did not honour row limit" |
| 162 | + |
| 163 | + # if we hit the limit on event_updates, there's no point in going beyond the |
| 164 | + # last stream_id in the batch for the other sources. |
| 165 | + |
| 166 | + if len(event_rows) == target_row_count: |
| 167 | + limited = True |
| 168 | + upper_limit = event_rows[-1][0] # type: int |
| 169 | + else: |
| 170 | + limited = False |
| 171 | + upper_limit = current_token |
| 172 | + |
| 173 | + # next up is the state delta table |
133 | 174 |
|
134 | 175 | state_rows = await self._store.get_all_updated_current_state_deltas( |
135 | | - from_token, current_token, limit |
136 | | - ) |
137 | | - state_updates = ( |
138 | | - (row[0], EventsStreamCurrentStateRow.TypeId, row[1:]) for row in state_rows |
139 | | - ) |
| 176 | + from_token, upper_limit, target_row_count |
| 177 | + ) # type: List[Tuple] |
| 178 | + |
| 179 | + # again, if we've hit the limit there, we'll need to limit the other sources |
| 180 | +    assert len(state_rows) <= target_row_count
| 181 | + if len(state_rows) == target_row_count: |
| 182 | + assert state_rows[-1][0] <= upper_limit |
| 183 | + upper_limit = state_rows[-1][0] |
| 184 | + limited = True |
| 185 | + |
| 186 | + # FIXME: is it a given that there is only one row per stream_id in the |
| 187 | + # state_deltas table (so that we can be sure that we have got all of the |
| 188 | + # rows for upper_limit)? |
| 189 | + |
| 190 | + # finally, fetch the ex-outliers rows. We assume there are few enough of these |
| 191 | + # not to bother with the limit. |
140 | 192 |
|
141 | | - all_updates = heapq.merge(event_updates, state_updates) |
| 193 | + ex_outliers_rows = await self._store.get_ex_outlier_stream_rows( |
| 194 | + from_token, upper_limit |
| 195 | + ) # type: List[Tuple] |
142 | 196 |
|
143 | | - return all_updates |
| 197 | + # we now need to turn the raw database rows returned into tuples suitable |
| 198 | + # for the replication protocol (basically, we add an identifier to |
| 199 | + # distinguish the row type). At the same time, we can limit the event_rows |
| 200 | + # to the max stream_id from state_rows. |
| 201 | + |
| 202 | + event_updates = ( |
| 203 | + (stream_id, (EventsStreamEventRow.TypeId, rest)) |
| 204 | + for (stream_id, *rest) in event_rows |
| 205 | + if stream_id <= upper_limit |
| 206 | + ) # type: Iterable[Tuple[int, Tuple]] |
| 207 | + |
| 208 | + state_updates = ( |
| 209 | + (stream_id, (EventsStreamCurrentStateRow.TypeId, rest)) |
| 210 | + for (stream_id, *rest) in state_rows |
| 211 | + ) # type: Iterable[Tuple[int, Tuple]] |
| 212 | + |
| 213 | + ex_outliers_updates = ( |
| 214 | + (stream_id, (EventsStreamEventRow.TypeId, rest)) |
| 215 | + for (stream_id, *rest) in ex_outliers_rows |
| 216 | + ) # type: Iterable[Tuple[int, Tuple]] |
| 217 | + |
| 218 | + # we need to return a sorted list, so merge them together. |
| 219 | + updates = list(heapq.merge(event_updates, state_updates, ex_outliers_updates)) |
| 220 | + return updates, upper_limit, limited |
144 | 221 |
|
145 | 222 | @classmethod |
146 | 223 | def parse_row(cls, row): |
|
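To make the limiting logic above easier to follow: the three sources are each sorted by stream id, the first two are fetched with a row budget, and whichever source exhausts its budget pins a common upper bound that all sources must respect before the merge. Here is a minimal standalone sketch of that pattern; the function name `clamp_and_merge`, the row shapes, and the `"ev"`/`"state"` type-id strings are hypothetical stand-ins, not Synapse's actual storage API or row types.

```python
import heapq
from typing import Iterable, List, Tuple

Update = Tuple[int, Tuple]  # (stream_id, (type_id, row_data))


def clamp_and_merge(
    event_rows: List[Tuple],       # sorted by stream_id, at most `budget` rows
    state_rows: List[Tuple],       # sorted by stream_id, at most `budget` rows
    ex_outlier_rows: List[Tuple],  # sorted by stream_id, assumed small
    current_token: int,
    budget: int,
) -> Tuple[List[Update], int, bool]:
    """Pick a common upper bound across the sources, then merge them."""
    limited = False
    upper_limit = current_token

    # if the events batch used its whole budget, rows beyond its last
    # stream_id may be missing, so clamp the upper bound there
    if len(event_rows) == budget:
        limited = True
        upper_limit = event_rows[-1][0]

    # same again for the state deltas (fetched up to upper_limit)
    if len(state_rows) == budget:
        upper_limit = state_rows[-1][0]
        limited = True

    # tag each row with its source, dropping event rows that now fall
    # beyond the (possibly reduced) upper bound
    events = (
        (sid, ("ev", tuple(rest)))
        for (sid, *rest) in event_rows
        if sid <= upper_limit
    )  # type: Iterable[Update]
    state = ((sid, ("state", tuple(rest))) for (sid, *rest) in state_rows)
    ex_outliers = ((sid, ("ev", tuple(rest))) for (sid, *rest) in ex_outlier_rows)

    # heapq.merge requires each input to be individually sorted, and yields
    # a single sorted stream without materialising the inputs first
    return list(heapq.merge(events, state, ex_outliers)), upper_limit, limited
```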
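A quick check of the merge behaviour with toy data: rows from the three sources come out interleaved in stream-id order, and any event row beyond the clamped bound would be dropped.

```python
# toy rows: (stream_id, payload); a budget of 2 is hit by the events batch,
# so the upper bound is clamped to stream_id 3
event_rows = [(1, "ev-a"), (3, "ev-b")]
state_rows = [(2, "state-a")]
ex_outlier_rows = [(3, "ev-old-outlier")]

updates, upper_limit, limited = clamp_and_merge(
    event_rows, state_rows, ex_outlier_rows, current_token=10, budget=2
)
print(upper_limit, limited)         # 3 True
print([sid for sid, _ in updates])  # [1, 2, 3, 3]
```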