Skip to content

Commit 54f743e

Browse files
tirkarthicjw296
authored andcommitted
Improve test coverage for AsyncMock. (GH-17906)
* Add test for nested async decorator patch. * Add test for side_effect and wraps with a function. * Add test for side_effect with an exception in the iterable.
1 parent 45cf5db commit 54f743e

File tree

1 file changed

+49
-4
lines changed

1 file changed

+49
-4
lines changed

Lib/unittest/test/testmock/testasync.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,17 @@ def test_async(mock_method):
7272
test_async()
7373

7474
def test_async_def_patch(self):
75-
@patch(f"{__name__}.async_func", AsyncMock())
76-
async def test_async():
75+
@patch(f"{__name__}.async_func", return_value=1)
76+
@patch(f"{__name__}.async_func_args", return_value=2)
77+
async def test_async(func_args_mock, func_mock):
78+
self.assertEqual(func_args_mock._mock_name, "async_func_args")
79+
self.assertEqual(func_mock._mock_name, "async_func")
80+
7781
self.assertIsInstance(async_func, AsyncMock)
82+
self.assertIsInstance(async_func_args, AsyncMock)
83+
84+
self.assertEqual(await async_func(), 1)
85+
self.assertEqual(await async_func_args(1, 2, c=3), 2)
7886

7987
asyncio.run(test_async())
8088
self.assertTrue(inspect.iscoroutinefunction(async_func))
@@ -375,22 +383,40 @@ async def addition(var):
375383
with self.assertRaises(Exception):
376384
await mock(5)
377385

378-
async def test_add_side_effect_function(self):
386+
async def test_add_side_effect_coroutine(self):
379387
async def addition(var):
380388
return var + 1
381389
mock = AsyncMock(side_effect=addition)
382390
result = await mock(5)
383391
self.assertEqual(result, 6)
384392

393+
async def test_add_side_effect_normal_function(self):
394+
def addition(var):
395+
return var + 1
396+
mock = AsyncMock(side_effect=addition)
397+
result = await mock(5)
398+
self.assertEqual(result, 6)
399+
385400
async def test_add_side_effect_iterable(self):
386401
vals = [1, 2, 3]
387402
mock = AsyncMock(side_effect=vals)
388403
for item in vals:
389-
self.assertEqual(item, await mock())
404+
self.assertEqual(await mock(), item)
390405

391406
with self.assertRaises(StopAsyncIteration) as e:
392407
await mock()
393408

409+
async def test_add_side_effect_exception_iterable(self):
410+
class SampleException(Exception):
411+
pass
412+
413+
vals = [1, SampleException("foo")]
414+
mock = AsyncMock(side_effect=vals)
415+
self.assertEqual(await mock(), 1)
416+
417+
with self.assertRaises(SampleException) as e:
418+
await mock()
419+
394420
async def test_return_value_AsyncMock(self):
395421
value = AsyncMock(return_value=10)
396422
mock = AsyncMock(return_value=value)
@@ -437,6 +463,21 @@ async def inner():
437463
mock.assert_awaited()
438464
self.assertTrue(ran)
439465

466+
async def test_wraps_normal_function(self):
467+
value = 1
468+
469+
ran = False
470+
def inner():
471+
nonlocal ran
472+
ran = True
473+
return value
474+
475+
mock = AsyncMock(wraps=inner)
476+
result = await mock()
477+
self.assertEqual(result, value)
478+
mock.assert_awaited()
479+
self.assertTrue(ran)
480+
440481
class AsyncMagicMethods(unittest.TestCase):
441482
def test_async_magic_methods_return_async_mocks(self):
442483
m_mock = MagicMock()
@@ -860,6 +901,10 @@ def test_assert_awaited_once(self):
860901
self.mock.assert_awaited_once()
861902

862903
def test_assert_awaited_with(self):
904+
msg = 'Not awaited'
905+
with self.assertRaisesRegex(AssertionError, msg):
906+
self.mock.assert_awaited_with('foo')
907+
863908
asyncio.run(self._runnable_test())
864909
msg = 'expected await not found'
865910
with self.assertRaisesRegex(AssertionError, msg):

0 commit comments

Comments
 (0)