Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 9372971

Browse files
Use inline type hints in tests/ (#10350)
This PR is tantamount to running: python3.8 -m com2ann -v 6 tests/ (com2ann requires python 3.8 to run)
1 parent 89cfc3d commit 9372971

18 files changed

+62
-63
lines changed

changelog.d/10350.misc

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Convert internal type variable syntax to reflect wider ecosystem use.

tests/events/test_presence_router.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def test_receiving_all_presence(self):
152152
)
153153
self.assertEqual(len(presence_updates), 1)
154154

155-
presence_update = presence_updates[0] # type: UserPresenceState
155+
presence_update: UserPresenceState = presence_updates[0]
156156
self.assertEqual(presence_update.user_id, self.other_user_one_id)
157157
self.assertEqual(presence_update.state, "online")
158158
self.assertEqual(presence_update.status_msg, "boop")
@@ -274,7 +274,7 @@ def test_send_local_online_presence_to_with_module(self):
274274
presence_updates, _ = sync_presence(self, self.other_user_id)
275275
self.assertEqual(len(presence_updates), 1)
276276

277-
presence_update = presence_updates[0] # type: UserPresenceState
277+
presence_update: UserPresenceState = presence_updates[0]
278278
self.assertEqual(presence_update.user_id, self.other_user_id)
279279
self.assertEqual(presence_update.state, "online")
280280
self.assertEqual(presence_update.status_msg, "I'm online!")
@@ -320,7 +320,7 @@ def test_send_local_online_presence_to_with_module(self):
320320
)
321321
for call in calls:
322322
call_args = call[0]
323-
federation_transaction = call_args[0] # type: Transaction
323+
federation_transaction: Transaction = call_args[0]
324324

325325
# Get the sent EDUs in this transaction
326326
edus = federation_transaction.get_dict()["edus"]

tests/module_api/test_api.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,9 @@ def test_sending_events_into_room(self):
100100
"content": content,
101101
"sender": user_id,
102102
}
103-
event = self.get_success(
103+
event: EventBase = self.get_success(
104104
self.module_api.create_and_send_event_into_room(event_dict)
105-
) # type: EventBase
105+
)
106106
self.assertEqual(event.sender, user_id)
107107
self.assertEqual(event.type, "m.room.message")
108108
self.assertEqual(event.room_id, room_id)
@@ -136,9 +136,9 @@ def test_sending_events_into_room(self):
136136
"sender": user_id,
137137
"state_key": "",
138138
}
139-
event = self.get_success(
139+
event: EventBase = self.get_success(
140140
self.module_api.create_and_send_event_into_room(event_dict)
141-
) # type: EventBase
141+
)
142142
self.assertEqual(event.sender, user_id)
143143
self.assertEqual(event.type, "m.room.power_levels")
144144
self.assertEqual(event.room_id, room_id)
@@ -281,7 +281,7 @@ def test_send_local_online_presence_to_federation(self):
281281
)
282282
for call in calls:
283283
call_args = call[0]
284-
federation_transaction = call_args[0] # type: Transaction
284+
federation_transaction: Transaction = call_args[0]
285285

286286
# Get the sent EDUs in this transaction
287287
edus = federation_transaction.get_dict()["edus"]
@@ -390,7 +390,7 @@ def _test_sending_local_online_presence_to_local_user(
390390
)
391391
test_case.assertEqual(len(presence_updates), 1)
392392

393-
presence_update = presence_updates[0] # type: UserPresenceState
393+
presence_update: UserPresenceState = presence_updates[0]
394394
test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
395395
test_case.assertEqual(presence_update.state, "online")
396396

@@ -443,7 +443,7 @@ def _test_sending_local_online_presence_to_local_user(
443443
)
444444
test_case.assertEqual(len(presence_updates), 1)
445445

446-
presence_update = presence_updates[0] # type: UserPresenceState
446+
presence_update: UserPresenceState = presence_updates[0]
447447
test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
448448
test_case.assertEqual(presence_update.state, "online")
449449

@@ -454,7 +454,7 @@ def _test_sending_local_online_presence_to_local_user(
454454
)
455455
test_case.assertEqual(len(presence_updates), 1)
456456

457-
presence_update = presence_updates[0] # type: UserPresenceState
457+
presence_update: UserPresenceState = presence_updates[0]
458458
test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
459459
test_case.assertEqual(presence_update.state, "online")
460460

tests/replication/_base.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,9 @@ def prepare(self, reactor, clock, hs):
5353
# build a replication server
5454
server_factory = ReplicationStreamProtocolFactory(hs)
5555
self.streamer = hs.get_replication_streamer()
56-
self.server = server_factory.buildProtocol(
56+
self.server: ServerReplicationStreamProtocol = server_factory.buildProtocol(
5757
None
58-
) # type: ServerReplicationStreamProtocol
58+
)
5959

6060
# Make a new HomeServer object for the worker
6161
self.reactor.lookups["testserv"] = "1.2.3.4"
@@ -195,7 +195,7 @@ def assert_request_is_get_repl_stream_updates(
195195
fetching updates for given stream.
196196
"""
197197

198-
path = request.path # type: bytes # type: ignore
198+
path: bytes = request.path # type: ignore
199199
self.assertRegex(
200200
path,
201201
br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$"
@@ -212,7 +212,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
212212
unlike `BaseStreamTestCase`.
213213
"""
214214

215-
servlets = [] # type: List[Callable[[HomeServer, JsonResource], None]]
215+
servlets: List[Callable[[HomeServer, JsonResource], None]] = []
216216

217217
def setUp(self):
218218
super().setUp()
@@ -448,7 +448,7 @@ def __init__(self, hs: HomeServer):
448448
super().__init__(hs)
449449

450450
# list of received (stream_name, token, row) tuples
451-
self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]]
451+
self.received_rdata_rows: List[Tuple[str, int, Any]] = []
452452

453453
async def on_rdata(self, stream_name, instance_name, token, rows):
454454
await super().on_rdata(stream_name, instance_name, token, rows)
@@ -484,7 +484,7 @@ def buildProtocol(self, addr):
484484
class FakeRedisPubSubProtocol(Protocol):
485485
"""A connection from a client talking to the fake Redis server."""
486486

487-
transport = None # type: Optional[FakeTransport]
487+
transport: Optional[FakeTransport] = None
488488

489489
def __init__(self, server: FakeRedisPubSubServer):
490490
self._server = server

tests/replication/tcp/streams/test_events.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,9 @@ def test_update_function_huge_state_change(self):
135135
)
136136

137137
# this is the point in the DAG where we make a fork
138-
fork_point = self.get_success(
138+
fork_point: List[str] = self.get_success(
139139
self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id)
140-
) # type: List[str]
140+
)
141141

142142
events = [
143143
self._inject_state_event(sender=OTHER_USER)
@@ -238,7 +238,7 @@ def test_update_function_huge_state_change(self):
238238
self.assertEqual(row.data.event_id, pl_event.event_id)
239239

240240
# the state rows are unsorted
241-
state_rows = [] # type: List[EventsStreamCurrentStateRow]
241+
state_rows: List[EventsStreamCurrentStateRow] = []
242242
for stream_name, _, row in received_rows:
243243
self.assertEqual("events", stream_name)
244244
self.assertIsInstance(row, EventsStreamRow)
@@ -290,11 +290,11 @@ def test_update_function_state_row_limit(self):
290290
)
291291

292292
# this is the point in the DAG where we make a fork
293-
fork_point = self.get_success(
293+
fork_point: List[str] = self.get_success(
294294
self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id)
295-
) # type: List[str]
295+
)
296296

297-
events = [] # type: List[EventBase]
297+
events: List[EventBase] = []
298298
for user in user_ids:
299299
events.extend(
300300
self._inject_state_event(sender=user) for _ in range(STATES_PER_USER)
@@ -355,7 +355,7 @@ def test_update_function_state_row_limit(self):
355355
self.assertEqual(row.data.event_id, pl_events[i].event_id)
356356

357357
# the state rows are unsorted
358-
state_rows = [] # type: List[EventsStreamCurrentStateRow]
358+
state_rows: List[EventsStreamCurrentStateRow] = []
359359
for _ in range(STATES_PER_USER + 1):
360360
stream_name, token, row = received_rows.pop(0)
361361
self.assertEqual("events", stream_name)

tests/replication/tcp/streams/test_receipts.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def test_receipt(self):
4343
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
4444
self.assertEqual(stream_name, "receipts")
4545
self.assertEqual(1, len(rdata_rows))
46-
row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow
46+
row: ReceiptsStream.ReceiptsStreamRow = rdata_rows[0]
4747
self.assertEqual("!room:blue", row.room_id)
4848
self.assertEqual("m.read", row.receipt_type)
4949
self.assertEqual(USER_ID, row.user_id)
@@ -75,7 +75,7 @@ def test_receipt(self):
7575
self.assertEqual(token, 3)
7676
self.assertEqual(1, len(rdata_rows))
7777

78-
row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow
78+
row: ReceiptsStream.ReceiptsStreamRow = rdata_rows[0]
7979
self.assertEqual("!room2:blue", row.room_id)
8080
self.assertEqual("m.read", row.receipt_type)
8181
self.assertEqual(USER_ID, row.user_id)

tests/replication/tcp/streams/test_typing.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def test_typing(self):
4747
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
4848
self.assertEqual(stream_name, "typing")
4949
self.assertEqual(1, len(rdata_rows))
50-
row = rdata_rows[0] # type: TypingStream.TypingStreamRow
50+
row: TypingStream.TypingStreamRow = rdata_rows[0]
5151
self.assertEqual(ROOM_ID, row.room_id)
5252
self.assertEqual([USER_ID], row.user_ids)
5353

@@ -102,7 +102,7 @@ def test_reset(self):
102102
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
103103
self.assertEqual(stream_name, "typing")
104104
self.assertEqual(1, len(rdata_rows))
105-
row = rdata_rows[0] # type: TypingStream.TypingStreamRow
105+
row: TypingStream.TypingStreamRow = rdata_rows[0]
106106
self.assertEqual(ROOM_ID, row.room_id)
107107
self.assertEqual([USER_ID], row.user_ids)
108108

tests/replication/test_multi_media_repo.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
logger = logging.getLogger(__name__)
3333

34-
test_server_connection_factory = None # type: Optional[TestServerTLSConnectionFactory]
34+
test_server_connection_factory: Optional[TestServerTLSConnectionFactory] = None
3535

3636

3737
class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):

tests/rest/client/test_third_party_rules.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -233,11 +233,11 @@ def test_send_event(self):
233233
"content": content,
234234
"sender": self.user_id,
235235
}
236-
event = self.get_success(
236+
event: EventBase = self.get_success(
237237
current_rules_module().module_api.create_and_send_event_into_room(
238238
event_dict
239239
)
240-
) # type: EventBase
240+
)
241241

242242
self.assertEquals(event.sender, self.user_id)
243243
self.assertEquals(event.room_id, self.room_id)

tests/rest/client/v1/test_login.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ def test_get_msc2858_login_flows(self):
453453
self.assertEqual(channel.code, 200, channel.result)
454454

455455
# stick the flows results in a dict by type
456-
flow_results = {} # type: Dict[str, Any]
456+
flow_results: Dict[str, Any] = {}
457457
for f in channel.json_body["flows"]:
458458
flow_type = f["type"]
459459
self.assertNotIn(
@@ -501,7 +501,7 @@ def test_multi_sso_redirect(self):
501501
p.close()
502502

503503
# there should be a link for each href
504-
returned_idps = [] # type: List[str]
504+
returned_idps: List[str] = []
505505
for link in p.links:
506506
path, query = link.split("?", 1)
507507
self.assertEqual(path, "pick_idp")
@@ -582,7 +582,7 @@ def test_login_via_oidc(self):
582582
# ... and should have set a cookie including the redirect url
583583
cookie_headers = channel.headers.getRawHeaders("Set-Cookie")
584584
assert cookie_headers
585-
cookies = {} # type: Dict[str, str]
585+
cookies: Dict[str, str] = {}
586586
for h in cookie_headers:
587587
key, value = h.split(";")[0].split("=", maxsplit=1)
588588
cookies[key] = value
@@ -874,9 +874,7 @@ def make_homeserver(self, reactor, clock):
874874

875875
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
876876
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
877-
result = jwt.encode(
878-
payload, secret, self.jwt_algorithm
879-
) # type: Union[str, bytes]
877+
result: Union[str, bytes] = jwt.encode(payload, secret, self.jwt_algorithm)
880878
if isinstance(result, bytes):
881879
return result.decode("ascii")
882880
return result
@@ -1084,7 +1082,7 @@ def make_homeserver(self, reactor, clock):
10841082

10851083
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
10861084
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
1087-
result = jwt.encode(payload, secret, "RS256") # type: Union[bytes,str]
1085+
result: Union[bytes, str] = jwt.encode(payload, secret, "RS256")
10881086
if isinstance(result, bytes):
10891087
return result.decode("ascii")
10901088
return result
@@ -1272,7 +1270,7 @@ def test_username_picker(self):
12721270
self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details")
12731271

12741272
# ... with a username_mapping_session cookie
1275-
cookies = {} # type: Dict[str,str]
1273+
cookies: Dict[str, str] = {}
12761274
channel.extract_cookies(cookies)
12771275
self.assertIn("username_mapping_session", cookies)
12781276
session_id = cookies["username_mapping_session"]

tests/server.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class FakeChannel:
5252
_reactor = attr.ib()
5353
result = attr.ib(type=dict, default=attr.Factory(dict))
5454
_ip = attr.ib(type=str, default="127.0.0.1")
55-
_producer = None # type: Optional[Union[IPullProducer, IPushProducer]]
55+
_producer: Optional[Union[IPullProducer, IPushProducer]] = None
5656

5757
@property
5858
def json_body(self):
@@ -316,8 +316,10 @@ def __init__(self):
316316

317317
self._tcp_callbacks = {}
318318
self._udp = []
319-
lookups = self.lookups = {} # type: Dict[str, str]
320-
self._thread_callbacks = deque() # type: Deque[Callable[[], None]]
319+
self.lookups: Dict[str, str] = {}
320+
self._thread_callbacks: Deque[Callable[[], None]] = deque()
321+
322+
lookups = self.lookups
321323

322324
@implementer(IResolverSimple)
323325
class FakeResolver:

tests/storage/test_background_update.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@
77

88
class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
99
def prepare(self, reactor, clock, homeserver):
10-
self.updates = (
11-
self.hs.get_datastore().db_pool.updates
12-
) # type: BackgroundUpdater
10+
self.updates: BackgroundUpdater = self.hs.get_datastore().db_pool.updates
1311
# the base test class should have run the real bg updates for us
1412
self.assertTrue(
1513
self.get_success(self.updates.has_completed_background_updates())

tests/storage/test_id_generators.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
2727

2828
def prepare(self, reactor, clock, hs):
2929
self.store = hs.get_datastore()
30-
self.db_pool = self.store.db_pool # type: DatabasePool
30+
self.db_pool: DatabasePool = self.store.db_pool
3131

3232
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
3333

@@ -460,7 +460,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
460460

461461
def prepare(self, reactor, clock, hs):
462462
self.store = hs.get_datastore()
463-
self.db_pool = self.store.db_pool # type: DatabasePool
463+
self.db_pool: DatabasePool = self.store.db_pool
464464

465465
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
466466

@@ -586,7 +586,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
586586

587587
def prepare(self, reactor, clock, hs):
588588
self.store = hs.get_datastore()
589-
self.db_pool = self.store.db_pool # type: DatabasePool
589+
self.db_pool: DatabasePool = self.store.db_pool
590590

591591
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
592592

tests/test_state.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def test_branch_no_conflict(self):
199199

200200
self.store.register_events(graph.walk())
201201

202-
context_store = {} # type: dict[str, EventContext]
202+
context_store: dict[str, EventContext] = {}
203203

204204
for event in graph.walk():
205205
context = yield defer.ensureDeferred(

tests/test_utils/html_parsers.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@ def __init__(self):
2323
super().__init__()
2424

2525
# a list of links found in the doc
26-
self.links = [] # type: List[str]
26+
self.links: List[str] = []
2727

2828
# the values of any hidden <input>s: map from name to value
29-
self.hiddens = {} # type: Dict[str, Optional[str]]
29+
self.hiddens: Dict[str, Optional[str]] = {}
3030

3131
# the values of any radio buttons: map from name to list of values
32-
self.radios = {} # type: Dict[str, List[Optional[str]]]
32+
self.radios: Dict[str, List[Optional[str]]] = {}
3333

3434
def handle_starttag(
3535
self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]

tests/unittest.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,7 @@ def get_success_or_raise(self, d, by=0.0):
520520
if not isinstance(deferred, Deferred):
521521
return d
522522

523-
results = [] # type: list
523+
results: list = []
524524
deferred.addBoth(results.append)
525525

526526
self.pump(by=by)

0 commit comments

Comments
 (0)