Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add final type hint to tests.unittest. #15072

Merged
merged 13 commits into from
Feb 14, 2023
2 changes: 1 addition & 1 deletion tests/handlers/test_appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,7 +899,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:

# Mock out application services, and allow defining our own in tests
self._services: List[ApplicationService] = []
self.hs.get_datastores().main.get_app_services = Mock(
self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[assignment]
return_value=self._services
)

Expand Down
8 changes: 4 additions & 4 deletions tests/handlers/test_cas.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_map_cas_user_to_user(self) -> None:

# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]

cas_response = CasResponse("test_user", {})
request = _mock_request()
Expand Down Expand Up @@ -89,7 +89,7 @@ def test_map_cas_user_to_existing_user(self) -> None:

# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]

# Map a user via SSO.
cas_response = CasResponse("test_user", {})
Expand Down Expand Up @@ -129,7 +129,7 @@ def test_map_cas_user_to_invalid_localpart(self) -> None:

# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]

cas_response = CasResponse("föö", {})
request = _mock_request()
Expand Down Expand Up @@ -160,7 +160,7 @@ def test_required_attributes(self) -> None:

# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]

# The response doesn't have the proper userGroup or department.
cas_response = CasResponse("test_user", {})
Expand Down
55 changes: 29 additions & 26 deletions tests/handlers/test_e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.api.errors import Codes, SynapseError
from synapse.handlers.device import DeviceHandler
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
Expand Down Expand Up @@ -187,37 +188,37 @@ def test_fallback_key(self) -> None:
)

# we should now have an unused alg1 key
res = self.get_success(
fallback_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
self.assertEqual(res, ["alg1"])
self.assertEqual(fallback_res, ["alg1"])

# claiming an OTK when no OTKs are available should return the fallback
# key
res = self.get_success(
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
res,
claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
)

# we shouldn't have any unused fallback keys again
res = self.get_success(
unused_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
self.assertEqual(res, [])
self.assertEqual(unused_res, [])

# claiming an OTK again should return the same fallback key
res = self.get_success(
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
res,
claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
)

Expand All @@ -231,10 +232,10 @@ def test_fallback_key(self) -> None:
)
)

res = self.get_success(
unused_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
self.assertEqual(res, [])
self.assertEqual(unused_res, [])

# uploading a new fallback key should result in an unused fallback key
self.get_success(
Expand All @@ -245,10 +246,10 @@ def test_fallback_key(self) -> None:
)
)

res = self.get_success(
unused_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
self.assertEqual(res, ["alg1"])
self.assertEqual(unused_res, ["alg1"])

# if the user uploads a one-time key, the next claim should fetch the
# one-time key, and then go back to the fallback
Expand All @@ -258,23 +259,23 @@ def test_fallback_key(self) -> None:
)
)

res = self.get_success(
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
res,
claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
)

res = self.get_success(
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
res,
claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key2}}},
)

Expand All @@ -287,13 +288,13 @@ def test_fallback_key(self) -> None:
)
)

res = self.get_success(
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
res,
claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
)

Expand Down Expand Up @@ -366,7 +367,7 @@ def test_reupload_signatures(self) -> None:
self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))

# upload two device keys, which will be signed later by the self-signing key
device_key_1 = {
device_key_1: JsonDict = {
"user_id": local_user,
"device_id": "abc",
"algorithms": [
Expand All @@ -379,7 +380,7 @@ def test_reupload_signatures(self) -> None:
},
"signatures": {local_user: {"ed25519:abc": "base64+signature"}},
}
device_key_2 = {
device_key_2: JsonDict = {
"user_id": local_user,
"device_id": "def",
"algorithms": [
Expand Down Expand Up @@ -451,8 +452,10 @@ def test_self_signing_key_doesnt_show_up_as_device(self) -> None:
}
self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))

device_handler = self.hs.get_device_handler()
assert isinstance(device_handler, DeviceHandler)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to distinguish from DeviceWorkerHandler, I'm guessing?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pretty much, yes. 👍

e = self.get_failure(
self.hs.get_device_handler().check_device_registered(
device_handler.check_device_registered(
user_id=local_user,
device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
initial_device_display_name="new display name",
Expand All @@ -475,7 +478,7 @@ def test_upload_signatures(self) -> None:
device_id = "xyz"
# private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA
device_pubkey = "NnHhnqiMFQkq969szYkooLaBAXW244ZOxgukCvm2ZeY"
device_key = {
device_key: JsonDict = {
"user_id": local_user,
"device_id": device_id,
"algorithms": [
Expand All @@ -497,7 +500,7 @@ def test_upload_signatures(self) -> None:

# private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0
master_pubkey = "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk"
master_key = {
master_key: JsonDict = {
"user_id": local_user,
"usage": ["master"],
"keys": {"ed25519:" + master_pubkey: master_pubkey},
Expand Down Expand Up @@ -540,7 +543,7 @@ def test_upload_signatures(self) -> None:
# the first user
other_user = "@otherboris:" + self.hs.hostname
other_master_pubkey = "fHZ3NPiKxoLQm5OoZbKa99SYxprOjNs4TwJUKP+twCM"
other_master_key = {
other_master_key: JsonDict = {
# private key: oyw2ZUx0O4GifbfFYM0nQvj9CL0b8B7cyN4FprtK8OI
"user_id": other_user,
"usage": ["master"],
Expand Down Expand Up @@ -702,7 +705,7 @@ def test_query_devices_remote_no_sync(self) -> None:
remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"

self.hs.get_federation_client().query_client_keys = mock.Mock(
self.hs.get_federation_client().query_client_keys = mock.Mock( # type: ignore[assignment]
return_value=make_awaitable(
{
"device_keys": {remote_user_id: {}},
Expand Down Expand Up @@ -782,7 +785,7 @@ def test_query_devices_remote_sync(self) -> None:
remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"

self.hs.get_federation_client().query_user_devices = mock.Mock(
self.hs.get_federation_client().query_user_devices = mock.Mock( # type: ignore[assignment]
return_value=make_awaitable(
{
"user_id": remote_user_id,
Expand Down
18 changes: 9 additions & 9 deletions tests/handlers/test_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,14 +371,14 @@ def test_backfill_ignores_known_events(self) -> None:
# We mock out the FederationClient.backfill method, to pretend that a remote
# server has returned our fake event.
federation_client_backfill_mock = Mock(return_value=make_awaitable([event]))
self.hs.get_federation_client().backfill = federation_client_backfill_mock
self.hs.get_federation_client().backfill = federation_client_backfill_mock # type: ignore[assignment]

# We also mock the persist method with a side effect of itself. This allows us
# to track when it has been called while preserving its function.
persist_events_and_notify_mock = Mock(
side_effect=self.hs.get_federation_event_handler().persist_events_and_notify
)
self.hs.get_federation_event_handler().persist_events_and_notify = (
self.hs.get_federation_event_handler().persist_events_and_notify = ( # type: ignore[assignment]
persist_events_and_notify_mock
)

Expand Down Expand Up @@ -712,12 +712,12 @@ async def sync_partial_state_room(
fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
# Start the partial state sync.
fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 1)

# Try to start another partial state sync.
# Nothing should happen.
fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 1)

# End the partial state sync
Expand All @@ -729,7 +729,7 @@ async def sync_partial_state_room(

# The next attempt to start the partial state sync should work.
is_partial_state = True
fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 2)

def test_partial_state_room_sync_restart(self) -> None:
Expand Down Expand Up @@ -764,7 +764,7 @@ async def sync_partial_state_room(
fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
# Start the partial state sync.
fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 1)

# Fail the partial state sync.
Expand All @@ -773,11 +773,11 @@ async def sync_partial_state_room(
self.assertEqual(mock_sync_partial_state_room.call_count, 1)

# Start the partial state sync again.
fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 2)

# Deduplicate another partial state sync.
fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 2)

# Fail the partial state sync.
Expand All @@ -786,6 +786,6 @@ async def sync_partial_state_room(
self.assertEqual(mock_sync_partial_state_room.call_count, 3)
mock_sync_partial_state_room.assert_called_with(
initial_destination="hs3",
other_destinations=["hs2"],
other_destinations={"hs2"},
room_id="room_id",
)
6 changes: 4 additions & 2 deletions tests/handlers/test_federation_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.state import StateResolutionStore
from synapse.state.v2 import _mainline_sort, _reverse_topological_power_sort
from synapse.types import JsonDict
from synapse.util import Clock
Expand Down Expand Up @@ -161,6 +162,7 @@ def _test_process_pulled_event_with_missing_state(
if prev_exists_as_outlier:
prev_event.internal_metadata.outlier = True
persistence = self.hs.get_storage_controllers().persistence
assert persistence is not None
self.get_success(
persistence.persist_event(
prev_event,
Expand Down Expand Up @@ -861,7 +863,7 @@ def test_process_pulled_event_with_rejected_missing_state(self) -> None:
bert_member_event.event_id: bert_member_event,
rejected_kick_event.event_id: rejected_kick_event,
},
state_res_store=main_store,
state_res_store=StateResolutionStore(main_store),
)
),
[bert_member_event.event_id, rejected_kick_event.event_id],
Expand Down Expand Up @@ -906,7 +908,7 @@ def test_process_pulled_event_with_rejected_missing_state(self) -> None:
rejected_power_levels_event.event_id,
],
event_map={},
state_res_store=main_store,
state_res_store=StateResolutionStore(main_store),
full_conflicted_set=set(),
)
),
Expand Down
11 changes: 6 additions & 5 deletions tests/handlers/test_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,20 +41,21 @@ class EventCreationTestCase(unittest.HomeserverTestCase):

def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = self.hs.get_event_creation_handler()
self._persist_event_storage_controller = (
self.hs.get_storage_controllers().persistence
)
persistence = self.hs.get_storage_controllers().persistence
assert persistence is not None
self._persist_event_storage_controller = persistence

self.user_id = self.register_user("tester", "foobar")
self.access_token = self.login("tester", "foobar")
self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)

self.info = self.get_success(
info = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(
self.access_token,
)
)
self.token_id = self.info.token_id
assert info is not None
self.token_id = info.token_id

self.requester = create_requester(self.user_id, access_token_id=self.token_id)

Expand Down
2 changes: 1 addition & 1 deletion tests/handlers/test_password_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,7 +852,7 @@ def _test_3pid_allowed(self, username: str, registration: bool) -> None:
username: The username to use for the test.
registration: Whether to test with registration URLs.
"""
self.hs.get_identity_handler().send_threepid_validation = Mock(
self.hs.get_identity_handler().send_threepid_validation = Mock( # type: ignore[assignment]
return_value=make_awaitable(0),
)

Expand Down
Loading