|
1 | 1 | import asyncio |
| 2 | +from contextlib import contextmanager |
2 | 3 | import sys |
3 | 4 | import unittest |
4 | 5 |
|
@@ -26,6 +27,19 @@ async def mock_coro(*args, **kwargs): |
26 | 27 | return mock_coro |
27 | 28 |
|
28 | 29 |
|
| 30 | +@contextmanager |
| 31 | +def mock_wait_for(): |
| 32 | + async def fake_wait_for(coro, timeout): |
| 33 | + await coro |
| 34 | + await fake_wait_for._mock(timeout) |
| 35 | + |
| 36 | + original_wait_for = asyncio.wait_for |
| 37 | + asyncio.wait_for = fake_wait_for |
| 38 | + fake_wait_for._mock = AsyncMock() |
| 39 | + yield |
| 40 | + asyncio.wait_for = original_wait_for |
| 41 | + |
| 42 | + |
29 | 43 | def _run(coro): |
30 | 44 | """Run the given coroutine.""" |
31 | 45 | return asyncio.get_event_loop().run_until_complete(coro) |
@@ -542,51 +556,64 @@ def on_foo(self, a, b): |
542 | 556 | _run(c._trigger_event('foo', '/', 1, '2')) |
543 | 557 | self.assertEqual(result, [1, '2']) |
544 | 558 |
|
| 559 | + @mock.patch('asyncio.wait_for', new_callable=AsyncMock, |
| 560 | + side_effect=asyncio.TimeoutError) |
545 | 561 | @mock.patch('socketio.client.random.random', side_effect=[1, 0, 0.5]) |
546 | | - def test_handle_reconnect(self, random): |
| 562 | + def test_handle_reconnect(self, random, wait_for): |
547 | 563 | c = asyncio_client.AsyncClient() |
548 | 564 | c._reconnect_task = 'foo' |
549 | | - c.sleep = AsyncMock() |
550 | 565 | c.connect = AsyncMock( |
551 | 566 | side_effect=[ValueError, exceptions.ConnectionError, None]) |
552 | 567 | _run(c._handle_reconnect()) |
553 | | - self.assertEqual(c.sleep.mock.call_count, 3) |
554 | | - self.assertEqual(c.sleep.mock.call_args_list, [ |
555 | | - mock.call(1.5), |
556 | | - mock.call(1.5), |
557 | | - mock.call(4.0) |
558 | | - ]) |
| 568 | + self.assertEqual(wait_for.mock.call_count, 3) |
| 569 | + self.assertEqual( |
| 570 | + [x[0][1] for x in asyncio.wait_for.mock.call_args_list], |
| 571 | + [1.5, 1.5, 4.0]) |
559 | 572 | self.assertEqual(c._reconnect_task, None) |
560 | 573 |
|
| 574 | + @mock.patch('asyncio.wait_for', new_callable=AsyncMock, |
| 575 | + side_effect=asyncio.TimeoutError) |
561 | 576 | @mock.patch('socketio.client.random.random', side_effect=[1, 0, 0.5]) |
562 | | - def test_handle_reconnect_max_delay(self, random): |
| 577 | + def test_handle_reconnect_max_delay(self, random, wait_for): |
563 | 578 | c = asyncio_client.AsyncClient(reconnection_delay_max=3) |
564 | 579 | c._reconnect_task = 'foo' |
565 | | - c.sleep = AsyncMock() |
566 | 580 | c.connect = AsyncMock( |
567 | 581 | side_effect=[ValueError, exceptions.ConnectionError, None]) |
568 | 582 | _run(c._handle_reconnect()) |
569 | | - self.assertEqual(c.sleep.mock.call_count, 3) |
570 | | - self.assertEqual(c.sleep.mock.call_args_list, [ |
571 | | - mock.call(1.5), |
572 | | - mock.call(1.5), |
573 | | - mock.call(3.0) |
574 | | - ]) |
| 583 | + self.assertEqual(wait_for.mock.call_count, 3) |
| 584 | + self.assertEqual( |
| 585 | + [x[0][1] for x in asyncio.wait_for.mock.call_args_list], |
| 586 | + [1.5, 1.5, 3.0]) |
575 | 587 | self.assertEqual(c._reconnect_task, None) |
576 | 588 |
|
| 589 | + @mock.patch('asyncio.wait_for', new_callable=AsyncMock, |
| 590 | + side_effect=asyncio.TimeoutError) |
577 | 591 | @mock.patch('socketio.client.random.random', side_effect=[1, 0, 0.5]) |
578 | | - def test_handle_reconnect_max_attempts(self, random): |
| 592 | + def test_handle_reconnect_max_attempts(self, random, wait_for): |
579 | 593 | c = asyncio_client.AsyncClient(reconnection_attempts=2) |
580 | 594 | c._reconnect_task = 'foo' |
581 | | - c.sleep = AsyncMock() |
582 | 595 | c.connect = AsyncMock( |
583 | 596 | side_effect=[ValueError, exceptions.ConnectionError, None]) |
584 | 597 | _run(c._handle_reconnect()) |
585 | | - self.assertEqual(c.sleep.mock.call_count, 2) |
586 | | - self.assertEqual(c.sleep.mock.call_args_list, [ |
587 | | - mock.call(1.5), |
588 | | - mock.call(1.5) |
589 | | - ]) |
| 598 | + self.assertEqual(wait_for.mock.call_count, 2) |
| 599 | + self.assertEqual( |
| 600 | + [x[0][1] for x in asyncio.wait_for.mock.call_args_list], |
| 601 | + [1.5, 1.5]) |
| 602 | + self.assertEqual(c._reconnect_task, 'foo') |
| 603 | + |
| 604 | + @mock.patch('asyncio.wait_for', new_callable=AsyncMock, |
| 605 | + side_effect=[asyncio.TimeoutError, None]) |
| 606 | + @mock.patch('socketio.client.random.random', side_effect=[1, 0, 0.5]) |
| 607 | + def test_handle_reconnect_aborted(self, random, wait_for): |
| 608 | + c = asyncio_client.AsyncClient() |
| 609 | + c._reconnect_task = 'foo' |
| 610 | + c.connect = AsyncMock( |
| 611 | + side_effect=[ValueError, exceptions.ConnectionError, None]) |
| 612 | + _run(c._handle_reconnect()) |
| 613 | + self.assertEqual(wait_for.mock.call_count, 2) |
| 614 | + self.assertEqual( |
| 615 | + [x[0][1] for x in asyncio.wait_for.mock.call_args_list], |
| 616 | + [1.5, 1.5]) |
590 | 617 | self.assertEqual(c._reconnect_task, 'foo') |
591 | 618 |
|
592 | 619 | def test_eio_connect(self): |
|
0 commit comments