This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 55da8df

Fix additional type hints from Twisted 21.2.0. (#9591)
1 parent 1e67bff commit 55da8df

18 files changed: +187 −119 lines

changelog.d/9591.misc

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+Fix incorrect type hints.

synapse/api/auth.py

Lines changed: 1 addition & 1 deletion

@@ -164,7 +164,7 @@ def get_public_keys(self, invite_event):

     async def get_user_by_req(
         self,
-        request: Request,
+        request: SynapseRequest,
         allow_guest: bool = False,
         rights: str = "access",
         allow_expired: bool = False,

synapse/federation/federation_server.py

Lines changed: 5 additions & 3 deletions

@@ -880,7 +880,9 @@ def __init__(self, hs: "HomeServer"):
         self.edu_handlers = (
             {}
         )  # type: Dict[str, Callable[[str, dict], Awaitable[None]]]
-        self.query_handlers = {}  # type: Dict[str, Callable[[dict], Awaitable[None]]]
+        self.query_handlers = (
+            {}
+        )  # type: Dict[str, Callable[[dict], Awaitable[JsonDict]]]

         # Map from type to instance names that we should route EDU handling to.
         # We randomly choose one instance from the list to route to for each new
@@ -914,7 +916,7 @@ def register_edu_handler(
         self.edu_handlers[edu_type] = handler

     def register_query_handler(
-        self, query_type: str, handler: Callable[[dict], defer.Deferred]
+        self, query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
     ):
         """Sets the handler callable that will be used to handle an incoming
         federation query of the given type.
@@ -987,7 +989,7 @@ async def on_edu(self, edu_type: str, origin: str, content: dict):
         # Oh well, let's just log and move on.
         logger.warning("No handler registered for EDU type %s", edu_type)

-    async def on_query(self, query_type: str, args: dict):
+    async def on_query(self, query_type: str, args: dict) -> JsonDict:
         handler = self.query_handlers.get(query_type)
         if handler:
             return await handler(args)
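
The updated annotations pin down what a federation query handler must look like: an async callable that takes the query arguments and returns a JSON-serialisable dict rather than None. A minimal sketch of that shape, assuming a FederationHandlerRegistry instance is available as "registry" (the handler name and registration helper below are illustrative, not part of this commit):

from synapse.types import JsonDict

# Hypothetical handler matching Callable[[dict], Awaitable[JsonDict]]:
# it is a coroutine function and always returns a dict.
async def on_profile_query(args: dict) -> JsonDict:
    return {"displayname": args.get("user_id", "")}

def register_example(registry) -> None:
    # "registry" is assumed to be a FederationHandlerRegistry.
    registry.register_query_handler("profile", on_profile_query)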

synapse/handlers/oidc_handler.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from typing_extensions import TypedDict
3535

3636
from twisted.web.client import readBody
37+
from twisted.web.http_headers import Headers
3738

3839
from synapse.config import ConfigError
3940
from synapse.config.oidc_config import (
@@ -538,7 +539,7 @@ async def _exchange_code(self, code: str) -> Token:
538539
"""
539540
metadata = await self.load_metadata()
540541
token_endpoint = metadata.get("token_endpoint")
541-
headers = {
542+
raw_headers = {
542543
"Content-Type": "application/x-www-form-urlencoded",
543544
"User-Agent": self._http_client.user_agent,
544545
"Accept": "application/json",
@@ -552,10 +553,10 @@ async def _exchange_code(self, code: str) -> Token:
552553
body = urlencode(args, True)
553554

554555
# Fill the body/headers with credentials
555-
uri, headers, body = self._client_auth.prepare(
556-
method="POST", uri=token_endpoint, headers=headers, body=body
556+
uri, raw_headers, body = self._client_auth.prepare(
557+
method="POST", uri=token_endpoint, headers=raw_headers, body=body
557558
)
558-
headers = {k: [v] for (k, v) in headers.items()}
559+
headers = Headers({k: [v] for (k, v) in raw_headers.items()})
559560

560561
# Do the actual request
561562
# We're not using the SimpleHttpClient util methods as we don't want to
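
The last hunk wraps the plain dict of header strings in Twisted's Headers class before making the request. Headers stores each header name against a list of values, which is why every single value gets wrapped in a one-element list. A standalone sketch of that construction (the header values are taken from the hunk above; the assert is only illustrative):

from twisted.web.http_headers import Headers

raw_headers = {
    "Content-Type": "application/x-www-form-urlencoded",
    "Accept": "application/json",
}
# Headers maps each header name to a *list* of values, hence the
# one-element lists.
headers = Headers({k: [v] for (k, v) in raw_headers.items()})

assert headers.getRawHeaders("Accept") == ["application/json"]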

synapse/http/client.py

Lines changed: 8 additions & 1 deletion

@@ -57,7 +57,13 @@
 )
 from twisted.web.http import PotentialDataLoss
 from twisted.web.http_headers import Headers
-from twisted.web.iweb import UNKNOWN_LENGTH, IAgent, IBodyProducer, IResponse
+from twisted.web.iweb import (
+    UNKNOWN_LENGTH,
+    IAgent,
+    IBodyProducer,
+    IPolicyForHTTPS,
+    IResponse,
+)

 from synapse.api.errors import Codes, HttpResponseException, SynapseError
 from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
@@ -870,6 +876,7 @@ def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> by
     return query_str.encode("utf8")


+@implementer(IPolicyForHTTPS)
 class InsecureInterceptableContextFactory(ssl.ContextFactory):
     """
     Factory for PyOpenSSL SSL contexts which accepts any certificate for any domain.
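
The added decorator declares that InsecureInterceptableContextFactory provides IPolicyForHTTPS, the interface Twisted's Agent expects for its contextFactory argument under the 21.2.0 type hints. A minimal sketch of the same pattern (the class name below is illustrative, not the one from this commit):

from zope.interface import implementer

from twisted.internet import ssl
from twisted.web.iweb import IPolicyForHTTPS

@implementer(IPolicyForHTTPS)
class PermissiveContextFactory(ssl.ContextFactory):
    def creatorForNetloc(self, hostname: bytes, port: int):
        # IPolicyForHTTPS requires creatorForNetloc; a real implementation
        # returns a client connection creator for this host and port.
        raise NotImplementedError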

synapse/logging/_remote.py

Lines changed: 15 additions & 8 deletions

@@ -32,8 +32,9 @@
     TCP4ClientEndpoint,
     TCP6ClientEndpoint,
 )
-from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint, ITransport
+from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint
 from twisted.internet.protocol import Factory, Protocol
+from twisted.internet.tcp import Connection
 from twisted.python.failure import Failure

 logger = logging.getLogger(__name__)
@@ -52,7 +53,9 @@ class LogProducer:
         format: A callable to format the log record to a string.
     """

-    transport = attr.ib(type=ITransport)
+    # This is essentially ITCPTransport, but that is missing certain fields
+    # (connected and registerProducer) which are part of the implementation.
+    transport = attr.ib(type=Connection)
     _format = attr.ib(type=Callable[[logging.LogRecord], str])
     _buffer = attr.ib(type=deque)
     _paused = attr.ib(default=False, type=bool, init=False)
@@ -149,8 +152,6 @@ def _connect(self) -> None:
         if self._connection_waiter:
             return

-        self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
-
         def fail(failure: Failure) -> None:
             # If the Deferred was cancelled (e.g. during shutdown) do not try to
             # reconnect (this will cause an infinite loop of errors).
@@ -163,9 +164,13 @@ def fail(failure: Failure) -> None:
             self._connect()

         def writer(result: Protocol) -> None:
+            # Force recognising transport as a Connection and not the more
+            # generic ITransport.
+            transport = result.transport  # type: Connection  # type: ignore
+
             # We have a connection. If we already have a producer, and its
             # transport is the same, just trigger a resumeProducing.
-            if self._producer and result.transport is self._producer.transport:
+            if self._producer and transport is self._producer.transport:
                 self._producer.resumeProducing()
                 self._connection_waiter = None
                 return
@@ -177,14 +182,16 @@ def writer(result: Protocol) -> None:
             # Make a new producer and start it.
             self._producer = LogProducer(
                 buffer=self._buffer,
-                transport=result.transport,
+                transport=transport,
                 format=self.format,
             )
-            result.transport.registerProducer(self._producer, True)
+            transport.registerProducer(self._producer, True)
             self._producer.resumeProducing()
             self._connection_waiter = None

-        self._connection_waiter.addCallbacks(writer, fail)
+        deferred = self._service.whenConnected(failAfterFailures=1)  # type: Deferred
+        deferred.addCallbacks(writer, fail)
+        self._connection_waiter = deferred

     def _handle_pressure(self) -> None:
         """

synapse/push/emailpusher.py

Lines changed: 2 additions & 2 deletions

@@ -16,8 +16,8 @@
 import logging
 from typing import TYPE_CHECKING, Dict, List, Optional

-from twisted.internet.base import DelayedCall
 from twisted.internet.error import AlreadyCalled, AlreadyCancelled
+from twisted.internet.interfaces import IDelayedCall

 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.push import Pusher, PusherConfig, ThrottleParams
@@ -66,7 +66,7 @@ def __init__(self, hs: "HomeServer", pusher_config: PusherConfig, mailer: Mailer

         self.store = self.hs.get_datastore()
         self.email = pusher_config.pushkey
-        self.timed_call = None  # type: Optional[DelayedCall]
+        self.timed_call = None  # type: Optional[IDelayedCall]
         self.throttle_params = {}  # type: Dict[str, ThrottleParams]
         self._inited = False

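
DelayedCall is the concrete class in twisted.internet.base, whereas reactor.callLater returns an object providing the IDelayedCall interface, which is what Twisted 21.2.0's stubs report. A minimal sketch of annotating against the interface (the no-op callback is illustrative):

from twisted.internet import reactor
from twisted.internet.interfaces import IDelayedCall

def _noop() -> None:
    pass

# reactor.callLater returns an object providing IDelayedCall.
timed_call = reactor.callLater(60, _noop)  # type: IDelayedCall
timed_call.cancel()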

synapse/replication/tcp/handler.py

Lines changed: 24 additions & 20 deletions

@@ -48,7 +48,7 @@
     UserIpCommand,
     UserSyncCommand,
 )
-from synapse.replication.tcp.protocol import AbstractConnection
+from synapse.replication.tcp.protocol import IReplicationConnection
 from synapse.replication.tcp.streams import (
     STREAMS_MAP,
     AccountDataStream,
@@ -82,7 +82,7 @@

 # the type of the entries in _command_queues_by_stream
 _StreamCommandQueue = Deque[
-    Tuple[Union[RdataCommand, PositionCommand], AbstractConnection]
+    Tuple[Union[RdataCommand, PositionCommand], IReplicationConnection]
 ]


@@ -174,7 +174,7 @@ def __init__(self, hs: "HomeServer"):

         # The currently connected connections. (The list of places we need to send
         # outgoing replication commands to.)
-        self._connections = []  # type: List[AbstractConnection]
+        self._connections = []  # type: List[IReplicationConnection]

         LaterGauge(
             "synapse_replication_tcp_resource_total_connections",
@@ -197,7 +197,7 @@ def __init__(self, hs: "HomeServer"):

         # For each connection, the incoming stream names that have received a POSITION
         # from that connection.
-        self._streams_by_connection = {}  # type: Dict[AbstractConnection, Set[str]]
+        self._streams_by_connection = {}  # type: Dict[IReplicationConnection, Set[str]]

         LaterGauge(
             "synapse_replication_tcp_command_queue",
@@ -220,7 +220,7 @@ def __init__(self, hs: "HomeServer"):
         self._server_notices_sender = hs.get_server_notices_sender()

     def _add_command_to_stream_queue(
-        self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand]
+        self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand]
     ) -> None:
         """Queue the given received command for processing

@@ -267,7 +267,7 @@ async def _unsafe_process_queue(self, stream_name: str):
     async def _process_command(
         self,
         cmd: Union[PositionCommand, RdataCommand],
-        conn: AbstractConnection,
+        conn: IReplicationConnection,
         stream_name: str,
     ) -> None:
         if isinstance(cmd, PositionCommand):
@@ -321,10 +321,10 @@ def get_streams_to_replicate(self) -> List[Stream]:
         """Get a list of streams that this instances replicates."""
         return self._streams_to_replicate

-    def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
+    def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand):
         self.send_positions_to_connection(conn)

-    def send_positions_to_connection(self, conn: AbstractConnection):
+    def send_positions_to_connection(self, conn: IReplicationConnection):
         """Send current position of all streams this process is source of to
         the connection.
         """
@@ -347,7 +347,7 @@ def send_positions_to_connection(self, conn: AbstractConnection):
         )

     def on_USER_SYNC(
-        self, conn: AbstractConnection, cmd: UserSyncCommand
+        self, conn: IReplicationConnection, cmd: UserSyncCommand
     ) -> Optional[Awaitable[None]]:
         user_sync_counter.inc()

@@ -359,21 +359,23 @@ def on_USER_SYNC(
             return None

     def on_CLEAR_USER_SYNC(
-        self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
+        self, conn: IReplicationConnection, cmd: ClearUserSyncsCommand
     ) -> Optional[Awaitable[None]]:
         if self._is_master:
             return self._presence_handler.update_external_syncs_clear(cmd.instance_id)
         else:
             return None

-    def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand):
+    def on_FEDERATION_ACK(
+        self, conn: IReplicationConnection, cmd: FederationAckCommand
+    ):
         federation_ack_counter.inc()

         if self._federation_sender:
             self._federation_sender.federation_ack(cmd.instance_name, cmd.token)

     def on_USER_IP(
-        self, conn: AbstractConnection, cmd: UserIpCommand
+        self, conn: IReplicationConnection, cmd: UserIpCommand
     ) -> Optional[Awaitable[None]]:
         user_ip_cache_counter.inc()

@@ -395,7 +397,7 @@ async def _handle_user_ip(self, cmd: UserIpCommand):
         assert self._server_notices_sender is not None
         await self._server_notices_sender.on_user_ip(cmd.user_id)

-    def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
+    def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand):
         if cmd.instance_name == self._instance_name:
             # Ignore RDATA that are just our own echoes
             return
@@ -412,7 +414,7 @@ def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
         self._add_command_to_stream_queue(conn, cmd)

     async def _process_rdata(
-        self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand
+        self, stream_name: str, conn: IReplicationConnection, cmd: RdataCommand
     ) -> None:
         """Process an RDATA command

@@ -486,7 +488,7 @@ async def on_rdata(
             stream_name, instance_name, token, rows
         )

-    def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
+    def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand):
         if cmd.instance_name == self._instance_name:
             # Ignore POSITION that are just our own echoes
             return
@@ -496,7 +498,7 @@ def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
         self._add_command_to_stream_queue(conn, cmd)

     async def _process_position(
-        self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand
+        self, stream_name: str, conn: IReplicationConnection, cmd: PositionCommand
     ) -> None:
         """Process a POSITION command

@@ -553,7 +555,9 @@ async def _process_position(

         self._streams_by_connection.setdefault(conn, set()).add(stream_name)

-    def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand):
+    def on_REMOTE_SERVER_UP(
+        self, conn: IReplicationConnection, cmd: RemoteServerUpCommand
+    ):
         """"Called when get a new REMOTE_SERVER_UP command."""
         self._replication_data_handler.on_remote_server_up(cmd.data)

@@ -576,7 +580,7 @@ def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpComma
         # between two instances, but that is not currently supported).
         self.send_command(cmd, ignore_conn=conn)

-    def new_connection(self, connection: AbstractConnection):
+    def new_connection(self, connection: IReplicationConnection):
         """Called when we have a new connection."""
         self._connections.append(connection)

@@ -603,7 +607,7 @@ def new_connection(self, connection: AbstractConnection):
                 UserSyncCommand(self._instance_id, user_id, True, now)
             )

-    def lost_connection(self, connection: AbstractConnection):
+    def lost_connection(self, connection: IReplicationConnection):
         """Called when a connection is closed/lost."""
         # we no longer need _streams_by_connection for this connection.
         streams = self._streams_by_connection.pop(connection, None)
@@ -624,7 +628,7 @@ def connected(self) -> bool:
         return bool(self._connections)

     def send_command(
-        self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None
+        self, cmd: Command, ignore_conn: Optional[IReplicationConnection] = None
     ):
         """Send a command to all connected connections.
