
Commit c2e1a21

Fix limit logic for EventsStream (#7358)
* Factor out functions for injecting events into database

  I want to add some more flexibility to the tools for injecting events into the
  database, and I don't want to clutter up HomeserverTestCase with them, so let's
  factor them out to a new file.

* Rework TestReplicationDataHandler

  This wasn't very easy to work with: the mock wrapping was largely superfluous,
  and it's useful to be able to inspect the received rows, and clear out the
  received list.

* Fix AssertionErrors being thrown by EventsStream

  Part of the problem was that there was an off-by-one error in the assertion,
  but also the limit logic was too simple. Fix it all up and add some tests.
1 parent eeef963 commit c2e1a21
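At the heart of the change is the contract between a replication stream and its storage layer: an update function fetches at most `target_row_count` rows, and must report both how far it actually got and whether it was cut short. A minimal, self-contained sketch of that contract (the names `fetch_rows` and `update_function` and the sample data are illustrative, not Synapse's API):

import asyncio
from typing import Any, List, Tuple

# stand-in for a database table of (stream_id, update) rows (illustrative only)
_ROWS = [(i, "update-%d" % i) for i in range(1, 21)]

async def fetch_rows(from_token: int, to_token: int, limit: int) -> List[Tuple[int, Any]]:
    return [r for r in _ROWS if from_token < r[0] <= to_token][:limit]

async def update_function(
    from_token: int, current_token: int, target_row_count: int
) -> Tuple[List[Tuple[int, Any]], int, bool]:
    """Fetch up to target_row_count updates in (from_token, current_token].

    Returns (updates, upper_limit, limited): when limited is True, the
    caller polls again from upper_limit to pick up the remainder.
    """
    rows = await fetch_rows(from_token, current_token, target_row_count)
    if len(rows) < target_row_count:
        # everything up to current_token fitted in a single batch
        return rows, current_token, False
    # truncated: advance only as far as the last row actually returned
    return rows, rows[-1][0], True

print(asyncio.run(update_function(0, 20, 5)))  # a limited batch ending at token 5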

File tree: 14 files changed, +658 -67 lines


changelog.d/7358.bugfix

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+Fix a bug where event updates might not be sent over replication to worker processes after the stream falls behind.

synapse/replication/tcp/handler.py

Lines changed: 3 additions & 1 deletion
@@ -87,7 +87,9 @@ def __init__(self, hs):
             stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
         }  # type: Dict[str, Stream]

-        self._position_linearizer = Linearizer("replication_position")
+        self._position_linearizer = Linearizer(
+            "replication_position", clock=self._clock
+        )

         # Map of stream to batched updates. See RdataCommand for info on how
         # batching works.
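Passing the homeserver's clock into the Linearizer here presumably lets tests drive it from a fake clock rather than a default real one. Conceptually, a linearizer serialises concurrent work per key, much like a per-key lock; a rough sketch of the idea (not Synapse's implementation):

import asyncio
from collections import defaultdict
from typing import DefaultDict

class PerKeyLinearizer:
    """Waiters on the same key run one at a time; different keys run
    concurrently. This mimics what Synapse's Linearizer provides, minus
    its clock integration."""

    def __init__(self) -> None:
        self._locks: DefaultDict[str, asyncio.Lock] = defaultdict(asyncio.Lock)

    def queue(self, key: str) -> asyncio.Lock:
        return self._locks[key]

async def on_position(linearizer: PerKeyLinearizer, stream_name: str) -> None:
    # serialise POSITION handling per stream, as the handler does with
    # self._position_linearizer
    async with linearizer.queue(stream_name):
        await asyncio.sleep(0)  # ...catch up on any missed rows here...

asyncio.run(on_position(PerKeyLinearizer(), "events"))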

synapse/replication/tcp/streams/events.py

Lines changed: 8 additions & 14 deletions
@@ -170,22 +170,16 @@ async def _update_function(
         limited = False
         upper_limit = current_token

-        # next up is the state delta table
-
-        state_rows = await self._store.get_all_updated_current_state_deltas(
+        # next up is the state delta table.
+        (
+            state_rows,
+            upper_limit,
+            state_rows_limited,
+        ) = await self._store.get_all_updated_current_state_deltas(
             from_token, upper_limit, target_row_count
-        )  # type: List[Tuple]
-
-        # again, if we've hit the limit there, we'll need to limit the other sources
-        assert len(state_rows) < target_row_count
-        if len(state_rows) == target_row_count:
-            assert state_rows[-1][0] <= upper_limit
-            upper_limit = state_rows[-1][0]
-            limited = True
+        )

-        # FIXME: is it a given that there is only one row per stream_id in the
-        # state_deltas table (so that we can be sure that we have got all of the
-        # rows for upper_limit)?
+        limited = limited or state_rows_limited

         # finally, fetch the ex-outliers rows. We assume there are few enough of these
         # not to bother with the limit.
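The removed assertion is the off-by-one the commit message mentions: `assert len(state_rows) < target_row_count` fails in exactly the case that the following `if len(state_rows) == target_row_count:` branch was meant to handle, so hitting the limit raised an AssertionError instead of limiting the stream. The deleted FIXME also points at a deeper problem: several rows can share one stream id, so truncating at the last row's stream id can silently drop rows. A small illustration of that hazard and of the boundary-safe alternative now implemented in the store (the data here is made up):

from typing import List, Tuple

# several state deltas can share one stream id (illustrative rows):
rows: List[Tuple[int, str]] = [(5, "a"), (6, "b"), (7, "c"), (7, "d"), (7, "e")]

# what the removed code effectively did: take the first `limit` rows and
# advance the token to the last stream id seen
limit = 4
batch = rows[:limit]
naive_token = batch[-1][0]  # 7 -- but row (7, "e") was never delivered

# what the store now does instead: cut the batch back to a stream-id
# boundary, so the next poll re-fetches all of stream id 7
safe_batch = [r for r in batch if r[0] < batch[-1][0]]
safe_token = batch[-1][0] - 1  # 6

print(naive_token, len(batch), safe_token, len(safe_batch))  # 7 4 6 2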

synapse/server.pyi

Lines changed: 5 additions & 0 deletions
@@ -25,6 +25,7 @@ import synapse.server_notices.server_notices_manager
 import synapse.server_notices.server_notices_sender
 import synapse.state
 import synapse.storage
+from synapse.events.builder import EventBuilderFactory

 class HomeServer(object):
     @property
@@ -121,3 +122,7 @@ class HomeServer(object):
         pass
     def get_instance_id(self) -> str:
         pass
+    def get_event_builder_factory(self) -> EventBuilderFactory:
+        pass
+    def get_storage(self) -> synapse.storage.Storage:
+        pass
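These stub additions presumably exist so the new event-injection test helpers type-check against HomeServer. A hedged sketch of what such a helper might look like; the HomeServer accessors come from the stubs above, but the rest (`for_room_version`, `create_new_client_event`, `persistence.persist_event`) is an assumption about Synapse internals of this era rather than a documented API:

from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.server import HomeServer

async def inject_event(hs: HomeServer, room_version: str, **kwargs) -> EventBase:
    """Build, create and persist an event directly against the database.

    A sketch of the kind of test helper these stubs enable; the exact
    helper added by this commit may differ.
    """
    builder = hs.get_event_builder_factory().for_room_version(
        KNOWN_ROOM_VERSIONS[room_version], kwargs
    )
    event, context = await hs.get_event_creation_handler().create_new_client_event(
        builder
    )
    # get_storage() is one of the accessors stubbed above
    await hs.get_storage().persistence.persist_event(event, context)
    return event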

synapse/storage/data_stores/main/events_worker.py

Lines changed: 60 additions & 4 deletions
@@ -19,7 +19,7 @@
 import logging
 import threading
 from collections import namedtuple
-from typing import List, Optional
+from typing import List, Optional, Tuple

 from canonicaljson import json
 from constantly import NamedConstant, Names
@@ -1084,18 +1084,74 @@ def get_all_new_backfill_event_rows(txn):
             "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
         )

-    def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
+    async def get_all_updated_current_state_deltas(
+        self, from_token: int, to_token: int, target_row_count: int
+    ) -> Tuple[List[Tuple], int, bool]:
+        """Fetch updates from current_state_delta_stream
+
+        Args:
+            from_token: The previous stream token. Updates from this stream id will
+                be excluded.
+
+            to_token: The current stream token (ie the upper limit). Updates up to this
+                stream id will be included (modulo the 'limit' param)
+
+            target_row_count: The number of rows to try to return. If more rows are
+                available, we will set 'limited' in the result. In the event of a large
+                batch, we may return more rows than this.
+        Returns:
+            A triplet `(updates, new_last_token, limited)`, where:
+               * `updates` is a list of database tuples.
+               * `new_last_token` is the new position in stream.
+               * `limited` is whether there are more updates to fetch.
+        """
+
         def get_all_updated_current_state_deltas_txn(txn):
             sql = """
                 SELECT stream_id, room_id, type, state_key, event_id
                 FROM current_state_delta_stream
                 WHERE ? < stream_id AND stream_id <= ?
                 ORDER BY stream_id ASC LIMIT ?
             """
-            txn.execute(sql, (from_token, to_token, limit))
+            txn.execute(sql, (from_token, to_token, target_row_count))
             return txn.fetchall()

-        return self.db.runInteraction(
+        def get_deltas_for_stream_id_txn(txn, stream_id):
+            sql = """
+                SELECT stream_id, room_id, type, state_key, event_id
+                FROM current_state_delta_stream
+                WHERE stream_id = ?
+            """
+            txn.execute(sql, [stream_id])
+            return txn.fetchall()
+
+        # we need to make sure that, for every stream id in the results, we get *all*
+        # the rows with that stream id.
+
+        rows = await self.db.runInteraction(
             "get_all_updated_current_state_deltas",
             get_all_updated_current_state_deltas_txn,
+        )  # type: List[Tuple]
+
+        # if we've got fewer rows than the limit, we're good
+        if len(rows) < target_row_count:
+            return rows, to_token, False
+
+        # we hit the limit, so reduce the upper limit so that we exclude the stream id
+        # of the last row in the result.
+        assert rows[-1][0] <= to_token
+        to_token = rows[-1][0] - 1
+
+        # search backwards through the list for the point to truncate
+        for idx in range(len(rows) - 1, 0, -1):
+            if rows[idx - 1][0] <= to_token:
+                return rows[:idx], to_token, True
+
+        # bother. We didn't get a full set of changes for even a single
+        # stream id. let's run the query again, without a row limit, but for
+        # just one stream id.
+        to_token += 1
+        rows = await self.db.runInteraction(
+            "get_deltas_for_stream_id", get_deltas_for_stream_id_txn, to_token
        )
+        return rows, to_token, True
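The truncation logic above is self-contained enough to replay in isolation. A pure-Python replica with a worked example (the data is made up, and where the real method re-queries the database in the all-one-stream-id case, this sketch just flags it in a comment):

from typing import List, Tuple

def truncate_at_boundary(
    rows: List[Tuple[int, str]], to_token: int, target_row_count: int
) -> Tuple[List[Tuple[int, str]], int, bool]:
    """Illustrative replica of the truncation step in the diff above."""
    if len(rows) < target_row_count:
        return rows, to_token, False
    # exclude the (possibly incomplete) last stream id from this batch
    to_token = rows[-1][0] - 1
    for idx in range(len(rows) - 1, 0, -1):
        if rows[idx - 1][0] <= to_token:
            return rows[:idx], to_token, True
    # every row shares one stream id; the real code re-queries without a
    # row limit for just that id rather than returning this batch as-is
    return rows, rows[-1][0], True

# the limit of 4 means we cannot tell whether stream id 103 has further
# rows beyond the batch, so the whole of 103 is deferred to the next poll:
rows = [(101, "a"), (102, "b"), (103, "c"), (103, "d")]
print(truncate_at_boundary(rows, 110, 4))
# -> ([(101, 'a'), (102, 'b')], 102, True)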

tests/replication/tcp/streams/_base.py

Lines changed: 24 additions & 17 deletions
@@ -12,10 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import logging
-from typing import Optional

-from mock import Mock
+import logging
+from typing import Any, Dict, List, Optional, Tuple

 import attr

@@ -25,6 +24,7 @@

 from synapse.app.generic_worker import GenericWorkerServer
 from synapse.http.site import SynapseRequest
+from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.tcp.client import ReplicationDataHandler
 from synapse.replication.tcp.handler import ReplicationCommandHandler
 from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
@@ -65,9 +65,7 @@ def prepare(self, reactor, clock, hs):
         # databases objects are the same.
         self.worker_hs.get_datastore().db = hs.get_datastore().db

-        self.test_handler = Mock(
-            wraps=TestReplicationDataHandler(self.worker_hs.get_datastore())
-        )
+        self.test_handler = self._build_replication_data_handler()
         self.worker_hs.replication_data_handler = self.test_handler

         repl_handler = ReplicationCommandHandler(self.worker_hs)
@@ -78,6 +76,9 @@ def prepare(self, reactor, clock, hs):
         self._client_transport = None
         self._server_transport = None

+    def _build_replication_data_handler(self):
+        return TestReplicationDataHandler(self.worker_hs.get_datastore())
+
     def reconnect(self):
         if self._client_transport:
             self.client.close()
@@ -174,22 +175,28 @@ def assert_request_is_get_repl_stream_updates(
 class TestReplicationDataHandler(ReplicationDataHandler):
     """Drop-in for ReplicationDataHandler which just collects RDATA rows"""

-    def __init__(self, hs):
-        super().__init__(hs)
-        self.streams = set()
-        self._received_rdata_rows = []
+    def __init__(self, store: BaseSlavedStore):
+        super().__init__(store)
+
+        # streams to subscribe to: map from stream id to position
+        self.stream_positions = {}  # type: Dict[str, int]
+
+        # list of received (stream_name, token, row) tuples
+        self.received_rdata_rows = []  # type: List[Tuple[str, int, Any]]

     def get_streams_to_replicate(self):
-        positions = {s: 0 for s in self.streams}
-        for stream, token, _ in self._received_rdata_rows:
-            if stream in self.streams:
-                positions[stream] = max(token, positions.get(stream, 0))
-        return positions
+        return self.stream_positions

     async def on_rdata(self, stream_name, token, rows):
         await super().on_rdata(stream_name, token, rows)
         for r in rows:
-            self._received_rdata_rows.append((stream_name, token, r))
+            self.received_rdata_rows.append((stream_name, token, r))
+
+        if (
+            stream_name in self.stream_positions
+            and token > self.stream_positions[stream_name]
+        ):
+            self.stream_positions[stream_name] = token


 @attr.s()
@@ -221,7 +228,7 @@ def __init__(self, reactor: IReactorTime):
         super().__init__()
         self.reactor = reactor

-        self._pull_to_push_producer = None
+        self._pull_to_push_producer = None  # type: Optional[_PullToPushProducer]

     def registerProducer(self, producer, streaming):
         # Convert pull producers to push producer.
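With the rework, a test seeds `stream_positions` before connecting and then inspects (or clears) `received_rdata_rows` directly, with no mock in the way. A sketch of how a subclass of this file's base test case might use it (the stream name and the assertion are illustrative, not taken from the commit's tests):

class EventsStreamTestCase(BaseStreamTestCase):
    def test_event_updates_are_received(self):
        # subscribe to the events stream from the start of the stream
        self.test_handler.stream_positions["events"] = 0
        self.reconnect()

        # ... inject some events into the database here ...

        # the reworked handler records every row it sees, so the test can
        # inspect the received list directly, and clear it between steps
        received = self.test_handler.received_rdata_rows
        self.assertTrue(all(name == "events" for name, _token, _row in received))
        self.test_handler.received_rdata_rows.clear()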
