|
15 | 15 |
|
16 | 16 | from mock import Mock
|
17 | 17 |
|
| 18 | +from synapse.app.generic_worker import GenericWorkerServer |
| 19 | +from synapse.replication.tcp.client import ReplicationDataHandler |
18 | 20 | from synapse.replication.tcp.handler import ReplicationCommandHandler
|
19 | 21 | from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
|
20 | 22 | from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
|
|
26 | 28 | class BaseStreamTestCase(unittest.HomeserverTestCase):
|
27 | 29 | """Base class for tests of the replication streams"""
|
28 | 30 |
|
29 |
| - def make_homeserver(self, reactor, clock): |
30 |
| - self.test_handler = Mock(wraps=TestReplicationDataHandler()) |
31 |
| - return self.setup_test_homeserver(replication_data_handler=self.test_handler) |
32 |
| - |
33 | 31 | def prepare(self, reactor, clock, hs):
|
34 | 32 | # build a replication server
|
35 | 33 | server_factory = ReplicationStreamProtocolFactory(hs)
|
36 | 34 | self.streamer = hs.get_replication_streamer()
|
37 | 35 | self.server = server_factory.buildProtocol(None)
|
38 | 36 |
|
39 |
| - repl_handler = ReplicationCommandHandler(hs) |
40 |
| - repl_handler.handler = self.test_handler |
| 37 | + # Make a new HomeServer object for the worker |
| 38 | + config = self.default_config() |
| 39 | + config["worker_app"] = "synapse.app.generic_worker" |
| 40 | + |
| 41 | + self.worker_hs = self.setup_test_homeserver( |
| 42 | + http_client=None, |
| 43 | + homeserverToUse=GenericWorkerServer, |
| 44 | + config=config, |
| 45 | + reactor=self.reactor, |
| 46 | + ) |
| 47 | + |
| 48 | + self.test_handler = Mock( |
| 49 | + wraps=TestReplicationDataHandler(self.worker_hs.get_datastore()) |
| 50 | + ) |
| 51 | + self.worker_hs.replication_data_handler = self.test_handler |
| 52 | + |
| 53 | + # Since we use sqlite in memory databases we need to make sure the |
| 54 | + # databases objects are the same. |
| 55 | + self.worker_hs.get_datastore().db = hs.get_datastore().db |
| 56 | + |
| 57 | + repl_handler = ReplicationCommandHandler(self.worker_hs) |
| 58 | + |
41 | 59 | self.client = ClientReplicationStreamProtocol(
|
42 | 60 | hs, "client", "test", clock, repl_handler,
|
43 | 61 | )
|
@@ -75,16 +93,15 @@ def replicate(self):
|
75 | 93 | self.pump(0.1)
|
76 | 94 |
|
77 | 95 |
|
78 |
| -class TestReplicationDataHandler: |
| 96 | +class TestReplicationDataHandler(ReplicationDataHandler): |
79 | 97 | """Drop-in for ReplicationDataHandler which just collects RDATA rows"""
|
80 | 98 |
|
81 |
| - def __init__(self): |
| 99 | + def __init__(self, hs): |
| 100 | + super().__init__(hs) |
82 | 101 | self.streams = set()
|
83 | 102 | self._received_rdata_rows = []
|
84 | 103 |
|
85 | 104 | async def on_rdata(self, stream_name, token, rows):
|
| 105 | + await super().on_rdata(stream_name, token, rows) |
86 | 106 | for r in rows:
|
87 | 107 | self._received_rdata_rows.append((stream_name, token, r))
|
88 |
| - |
89 |
| - async def on_position(self, stream_name, token): |
90 |
| - pass |
|
0 commit comments