This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit d0e78af

Add missing type hints to synapse.replication. (#11938)
1 parent 8c94b3a commit d0e78af


19 files changed (+209 lines, -147 lines)


changelog.d/11938.misc

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+Add missing type hints to replication code.

mypy.ini

Lines changed: 3 additions & 0 deletions
@@ -169,6 +169,9 @@ disallow_untyped_defs = True
 [mypy-synapse.push.*]
 disallow_untyped_defs = True
 
+[mypy-synapse.replication.*]
+disallow_untyped_defs = True
+
 [mypy-synapse.rest.*]
 disallow_untyped_defs = True
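For reference, disallow_untyped_defs = True makes mypy reject any function in the matching package whose parameters or return type are left unannotated, which is what drives the signature changes in the rest of this commit. A minimal, hypothetical illustration (not part of the diff):

from typing import Optional


def advance(instance_name, new_id):
    # mypy with disallow_untyped_defs = True reports:
    # error: Function is missing a type annotation
    return None


def advance_annotated(instance_name: Optional[str], new_id: int) -> None:
    # Fully annotated, so it passes the stricter check.
    return None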

synapse/replication/slave/storage/_slaved_id_tracker.py

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ def __init__(
         for table, column in extra_tables:
             self.advance(None, _load_current_id(db_conn, table, column))
 
-    def advance(self, instance_name: Optional[str], new_id: int):
+    def advance(self, instance_name: Optional[str], new_id: int) -> None:
         self._current = (max if self.step > 0 else min)(self._current, new_id)
 
     def get_current_token(self) -> int:
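advance only mutates state, so the added hint is simply -> None, while get_current_token already declared -> int. A simplified, hypothetical stand-in showing the same max/min-based tracker update in isolation:

from typing import Optional


class ExampleIdTracker:
    """Hypothetical, simplified stand-in for a stream position tracker."""

    def __init__(self, current: int, step: int = 1) -> None:
        self._current = current
        self.step = step

    def advance(self, instance_name: Optional[str], new_id: int) -> None:
        # Move forwards (or backwards for negative-step streams), never regress.
        self._current = (max if self.step > 0 else min)(self._current, new_id)

    def get_current_token(self) -> int:
        return self._current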

synapse/replication/slave/storage/client_ips.py

Lines changed: 3 additions & 1 deletion
@@ -37,7 +37,9 @@ def __init__(
             cache_name="client_ip_last_seen", max_size=50000
         )
 
-    async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
+    async def insert_client_ip(
+        self, user_id: str, access_token: str, ip: str, user_agent: str, device_id: str
+    ) -> None:
         now = int(self._clock.time_msec())
         key = (user_id, access_token, ip)

synapse/replication/slave/storage/devices.py

Lines changed: 7 additions & 3 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Iterable
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
@@ -60,7 +60,9 @@ def __init__(
     def get_device_stream_token(self) -> int:
         return self._device_list_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
+    ) -> None:
         if stream_name == DeviceListsStream.NAME:
             self._device_list_id_gen.advance(instance_name, token)
             self._invalidate_caches_for_devices(token, rows)
@@ -70,7 +72,9 @@ def process_replication_rows(self, stream_name, instance_name, token, rows):
                 self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
         return super().process_replication_rows(stream_name, instance_name, token, rows)
 
-    def _invalidate_caches_for_devices(self, token, rows):
+    def _invalidate_caches_for_devices(
+        self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]
+    ) -> None:
         for row in rows:
             # The entities are either user IDs (starting with '@') whose devices
             # have changed, or remote servers that we need to tell about
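rows is hinted as Iterable[Any] here because each replication stream carries its own row type, and the rows are only treated as a concrete type after matching on stream_name. A rough, hypothetical sketch of that dispatch pattern (names invented for illustration, not Synapse code):

from typing import Any, Iterable


class ExampleStreamRow:
    """Made-up row type for a single stream."""

    def __init__(self, entity: str) -> None:
        self.entity = entity


def handle_rows(stream_name: str, token: int, rows: Iterable[Any]) -> None:
    if stream_name == "example":
        # Inside this branch the rows are known (by convention, not by the
        # type checker) to be ExampleStreamRow instances.
        for row in rows:
            print(token, row.entity)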

synapse/replication/slave/storage/groups.py

Lines changed: 5 additions & 3 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Iterable
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
@@ -44,10 +44,12 @@ def __init__(
             self._group_updates_id_gen.get_current_token(),
         )
 
-    def get_group_stream_token(self):
+    def get_group_stream_token(self) -> int:
         return self._group_updates_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
+    ) -> None:
         if stream_name == GroupServerStream.NAME:
             self._group_updates_id_gen.advance(instance_name, token)
             for row in rows:

synapse/replication/slave/storage/push_rule.py

Lines changed: 5 additions & 2 deletions
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Any, Iterable
 
 from synapse.replication.tcp.streams import PushRulesStream
 from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
@@ -20,10 +21,12 @@
 
 
 class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
-    def get_max_push_rules_stream_id(self):
+    def get_max_push_rules_stream_id(self) -> int:
         return self._push_rules_stream_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
+    ) -> None:
         if stream_name == PushRulesStream.NAME:
             self._push_rules_stream_id_gen.advance(instance_name, token)
             for row in rows:

synapse/replication/slave/storage/pushers.py

Lines changed: 3 additions & 3 deletions
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Iterable
 
 from synapse.replication.tcp.streams import PushersStream
 from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
@@ -41,8 +41,8 @@ def get_pushers_stream_token(self) -> int:
         return self._pushers_id_gen.get_current_token()
 
     def process_replication_rows(
-        self, stream_name: str, instance_name: str, token, rows
+        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
     ) -> None:
         if stream_name == PushersStream.NAME:
-            self._pushers_id_gen.advance(instance_name, token)  # type: ignore
+            self._pushers_id_gen.advance(instance_name, token)
         return super().process_replication_rows(stream_name, instance_name, token, rows)

synapse/replication/tcp/client.py

Lines changed: 26 additions & 19 deletions
@@ -14,10 +14,12 @@
 """A replication client for use by synapse workers.
 """
 import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
 
 from twisted.internet.defer import Deferred
+from twisted.internet.interfaces import IAddress, IConnector
 from twisted.internet.protocol import ReconnectingClientFactory
+from twisted.python.failure import Failure
 
 from synapse.api.constants import EventTypes
 from synapse.federation import send_queue
@@ -79,10 +81,10 @@ def __init__(
 
         hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)
 
-    def startedConnecting(self, connector):
+    def startedConnecting(self, connector: IConnector) -> None:
         logger.info("Connecting to replication: %r", connector.getDestination())
 
-    def buildProtocol(self, addr):
+    def buildProtocol(self, addr: IAddress) -> ClientReplicationStreamProtocol:
         logger.info("Connected to replication: %r", addr)
         return ClientReplicationStreamProtocol(
             self.hs,
@@ -92,11 +94,11 @@ def buildProtocol(self, addr):
             self.command_handler,
         )
 
-    def clientConnectionLost(self, connector, reason):
+    def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None:
         logger.error("Lost replication conn: %r", reason)
         ReconnectingClientFactory.clientConnectionLost(self, connector, reason)
 
-    def clientConnectionFailed(self, connector, reason):
+    def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None:
         logger.error("Failed to connect to replication: %r", reason)
         ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
 
@@ -131,7 +133,7 @@ def __init__(self, hs: "HomeServer"):
 
     async def on_rdata(
         self, stream_name: str, instance_name: str, token: int, rows: list
-    ):
+    ) -> None:
         """Called to handle a batch of replication data with a given stream token.
 
         By default this just pokes the slave store. Can be overridden in subclasses to
@@ -252,14 +254,16 @@ async def on_rdata(
         # loop. (This maintains the order so no need to resort)
         waiting_list[:] = waiting_list[index_of_first_deferred_not_called:]
 
-    async def on_position(self, stream_name: str, instance_name: str, token: int):
+    async def on_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
         await self.on_rdata(stream_name, instance_name, token, [])
 
         # We poke the generic "replication" notifier to wake anything up that
         # may be streaming.
         self.notifier.notify_replication()
 
-    def on_remote_server_up(self, server: str):
+    def on_remote_server_up(self, server: str) -> None:
         """Called when get a new REMOTE_SERVER_UP command."""
 
         # Let's wake up the transaction queue for the server in case we have
@@ -269,7 +273,7 @@ def on_remote_server_up(self, server: str):
 
     async def wait_for_stream_position(
         self, instance_name: str, stream_name: str, position: int
-    ):
+    ) -> None:
         """Wait until this instance has received updates up to and including
         the given stream position.
         """
@@ -304,7 +308,7 @@ async def wait_for_stream_position(
             "Finished waiting for repl stream %r to reach %s", stream_name, position
         )
 
-    def stop_pusher(self, user_id, app_id, pushkey):
+    def stop_pusher(self, user_id: str, app_id: str, pushkey: str) -> None:
         if not self._notify_pushers:
             return
 
@@ -316,13 +320,13 @@ def stop_pusher(self, user_id, app_id, pushkey):
         logger.info("Stopping pusher %r / %r", user_id, key)
         pusher.on_stop()
 
-    async def start_pusher(self, user_id, app_id, pushkey):
+    async def start_pusher(self, user_id: str, app_id: str, pushkey: str) -> None:
         if not self._notify_pushers:
            return
 
        key = "%s:%s" % (app_id, pushkey)
        logger.info("Starting pusher %r / %r", user_id, key)
-        return await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id)
+        await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id)
 
 
 class FederationSenderHandler:
@@ -353,10 +357,12 @@ def __init__(self, hs: "HomeServer"):
 
         self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
 
-    def wake_destination(self, server: str):
+    def wake_destination(self, server: str) -> None:
         self.federation_sender.wake_destination(server)
 
-    async def process_replication_rows(self, stream_name, token, rows):
+    async def process_replication_rows(
+        self, stream_name: str, token: int, rows: list
+    ) -> None:
         # The federation stream contains things that we want to send out, e.g.
         # presence, typing, etc.
         if stream_name == "federation":
@@ -384,11 +390,12 @@ async def process_replication_rows(self, stream_name, token, rows):
             for host in hosts:
                 self.federation_sender.send_device_messages(host)
 
-    async def _on_new_receipts(self, rows):
+    async def _on_new_receipts(
+        self, rows: Iterable[ReceiptsStream.ReceiptsStreamRow]
+    ) -> None:
         """
         Args:
-            rows (Iterable[synapse.replication.tcp.streams.ReceiptsStream.ReceiptsStreamRow]):
-                new receipts to be processed
+            rows: new receipts to be processed
         """
         for receipt in rows:
             # we only want to send on receipts for our own users
@@ -408,7 +415,7 @@ async def _on_new_receipts(self, rows):
             )
             await self.federation_sender.send_read_receipt(receipt_info)
 
-    async def update_token(self, token):
+    async def update_token(self, token: int) -> None:
         """Update the record of where we have processed to in the federation stream.
 
         Called after we have processed a an update received over replication. Sends
@@ -428,7 +435,7 @@ async def update_token(self, token):
 
         run_as_background_process("_save_and_send_ack", self._save_and_send_ack)
 
-    async def _save_and_send_ack(self):
+    async def _save_and_send_ack(self) -> None:
         """Save the current federation position in the database and send an ACK
         to master with where we're up to.
         """
