1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414from typing import Callable , FrozenSet , List , Optional , Set
15- from unittest .mock import Mock
15+ from unittest .mock import AsyncMock , Mock
1616
1717from signedjson import key , sign
1818from signedjson .types import BaseKey , SigningKey
2929from synapse .types import JsonDict , ReadReceipt
3030from synapse .util import Clock
3131
32- from tests .test_utils import make_awaitable
3332from tests .unittest import HomeserverTestCase
3433
3534
@@ -43,12 +42,13 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
4342
4443 def make_homeserver (self , reactor : MemoryReactor , clock : Clock ) -> HomeServer :
4544 self .federation_transport_client = Mock (spec = ["send_transaction" ])
45+ self .federation_transport_client .send_transaction = AsyncMock ()
4646 hs = self .setup_test_homeserver (
4747 federation_transport_client = self .federation_transport_client ,
4848 )
4949
50- hs .get_storage_controllers ().state .get_current_hosts_in_room = Mock ( # type: ignore[assignment]
51- return_value = make_awaitable ( {"test" , "host2" })
50+ hs .get_storage_controllers ().state .get_current_hosts_in_room = AsyncMock ( # type: ignore[assignment]
51+ return_value = {"test" , "host2" }
5252 )
5353
5454 hs .get_storage_controllers ().state .get_current_hosts_in_room_or_partial_state_approximation = ( # type: ignore[assignment]
@@ -64,7 +64,7 @@ def default_config(self) -> JsonDict:
6464
6565 def test_send_receipts (self ) -> None :
6666 mock_send_transaction = self .federation_transport_client .send_transaction
67- mock_send_transaction .return_value = make_awaitable ({})
67+ mock_send_transaction .return_value = {}
6868
6969 sender = self .hs .get_federation_sender ()
7070 receipt = ReadReceipt (
@@ -104,7 +104,7 @@ def test_send_receipts(self) -> None:
104104
105105 def test_send_receipts_thread (self ) -> None :
106106 mock_send_transaction = self .federation_transport_client .send_transaction
107- mock_send_transaction .return_value = make_awaitable ({})
107+ mock_send_transaction .return_value = {}
108108
109109 # Create receipts for:
110110 #
@@ -180,7 +180,7 @@ def test_send_receipts_with_backoff(self) -> None:
180180 """Send two receipts in quick succession; the second should be flushed, but
181181 only after 20ms"""
182182 mock_send_transaction = self .federation_transport_client .send_transaction
183- mock_send_transaction .return_value = make_awaitable ({})
183+ mock_send_transaction .return_value = {}
184184
185185 sender = self .hs .get_federation_sender ()
186186 receipt = ReadReceipt (
@@ -276,6 +276,8 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
276276 self .federation_transport_client = Mock (
277277 spec = ["send_transaction" , "query_user_devices" ]
278278 )
279+ self .federation_transport_client .send_transaction = AsyncMock ()
280+ self .federation_transport_client .query_user_devices = AsyncMock ()
279281 return self .setup_test_homeserver (
280282 federation_transport_client = self .federation_transport_client ,
281283 )
@@ -317,13 +319,13 @@ async def get_current_hosts_in_room(room_id: str) -> Set[str]:
317319 self .record_transaction
318320 )
319321
320- def record_transaction (
322+ async def record_transaction (
321323 self , txn : Transaction , json_cb : Optional [Callable [[], JsonDict ]] = None
322- ) -> "defer.Deferred[ JsonDict]" :
324+ ) -> JsonDict :
323325 assert json_cb is not None
324326 data = json_cb ()
325327 self .edus .extend (data ["edus" ])
326- return defer . succeed ({})
328+ return {}
327329
328330 def test_send_device_updates (self ) -> None :
329331 """Basic case: each device update should result in an EDU"""
@@ -354,15 +356,11 @@ def test_dont_send_device_updates_for_remote_users(self) -> None:
354356
355357 # Send the server a device list EDU for the other user, this will cause
356358 # it to try and resync the device lists.
357- self .federation_transport_client .query_user_devices .return_value = (
358- make_awaitable (
359- {
360- "stream_id" : "1" ,
361- "user_id" : "@user2:host2" ,
362- "devices" : [{"device_id" : "D1" }],
363- }
364- )
365- )
359+ self .federation_transport_client .query_user_devices .return_value = {
360+ "stream_id" : "1" ,
361+ "user_id" : "@user2:host2" ,
362+ "devices" : [{"device_id" : "D1" }],
363+ }
366364
367365 self .get_success (
368366 self .device_handler .device_list_updater .incoming_device_list_update (
@@ -533,7 +531,7 @@ def test_unreachable_server(self) -> None:
533531 recovery
534532 """
535533 mock_send_txn = self .federation_transport_client .send_transaction
536- mock_send_txn .side_effect = lambda t , cb : defer . fail ( AssertionError ("fail" ) )
534+ mock_send_txn .side_effect = AssertionError ("fail" )
537535
538536 # create devices
539537 u1 = self .register_user ("user" , "pass" )
@@ -578,7 +576,7 @@ def test_prune_outbound_device_pokes1(self) -> None:
578576 This case tests the behaviour when the server has never been reachable.
579577 """
580578 mock_send_txn = self .federation_transport_client .send_transaction
581- mock_send_txn .side_effect = lambda t , cb : defer . fail ( AssertionError ("fail" ) )
579+ mock_send_txn .side_effect = AssertionError ("fail" )
582580
583581 # create devices
584582 u1 = self .register_user ("user" , "pass" )
@@ -636,7 +634,7 @@ def test_prune_outbound_device_pokes2(self) -> None:
636634
637635 # now the server goes offline
638636 mock_send_txn = self .federation_transport_client .send_transaction
639- mock_send_txn .side_effect = lambda t , cb : defer . fail ( AssertionError ("fail" ) )
637+ mock_send_txn .side_effect = AssertionError ("fail" )
640638
641639 self .login ("user" , "pass" , device_id = "D2" )
642640 self .login ("user" , "pass" , device_id = "D3" )
0 commit comments