@@ -72,9 +72,17 @@ def test_async(mock_method):
7272 test_async ()
7373
7474 def test_async_def_patch (self ):
75- @patch (f"{ __name__ } .async_func" , AsyncMock ())
76- async def test_async ():
75+ @patch (f"{ __name__ } .async_func" , return_value = 1 )
76+ @patch (f"{ __name__ } .async_func_args" , return_value = 2 )
77+ async def test_async (func_args_mock , func_mock ):
78+ self .assertEqual (func_args_mock ._mock_name , "async_func_args" )
79+ self .assertEqual (func_mock ._mock_name , "async_func" )
80+
7781 self .assertIsInstance (async_func , AsyncMock )
82+ self .assertIsInstance (async_func_args , AsyncMock )
83+
84+ self .assertEqual (await async_func (), 1 )
85+ self .assertEqual (await async_func_args (1 , 2 , c = 3 ), 2 )
7886
7987 asyncio .run (test_async ())
8088 self .assertTrue (inspect .iscoroutinefunction (async_func ))
@@ -375,22 +383,40 @@ async def addition(var):
375383 with self .assertRaises (Exception ):
376384 await mock (5 )
377385
378- async def test_add_side_effect_function (self ):
386+ async def test_add_side_effect_coroutine (self ):
379387 async def addition (var ):
380388 return var + 1
381389 mock = AsyncMock (side_effect = addition )
382390 result = await mock (5 )
383391 self .assertEqual (result , 6 )
384392
393+ async def test_add_side_effect_normal_function (self ):
394+ def addition (var ):
395+ return var + 1
396+ mock = AsyncMock (side_effect = addition )
397+ result = await mock (5 )
398+ self .assertEqual (result , 6 )
399+
385400 async def test_add_side_effect_iterable (self ):
386401 vals = [1 , 2 , 3 ]
387402 mock = AsyncMock (side_effect = vals )
388403 for item in vals :
389- self .assertEqual (item , await mock ())
404+ self .assertEqual (await mock (), item )
390405
391406 with self .assertRaises (StopAsyncIteration ) as e :
392407 await mock ()
393408
409+ async def test_add_side_effect_exception_iterable (self ):
410+ class SampleException (Exception ):
411+ pass
412+
413+ vals = [1 , SampleException ("foo" )]
414+ mock = AsyncMock (side_effect = vals )
415+ self .assertEqual (await mock (), 1 )
416+
417+ with self .assertRaises (SampleException ) as e :
418+ await mock ()
419+
394420 async def test_return_value_AsyncMock (self ):
395421 value = AsyncMock (return_value = 10 )
396422 mock = AsyncMock (return_value = value )
@@ -437,6 +463,21 @@ async def inner():
437463 mock .assert_awaited ()
438464 self .assertTrue (ran )
439465
466+ async def test_wraps_normal_function (self ):
467+ value = 1
468+
469+ ran = False
470+ def inner ():
471+ nonlocal ran
472+ ran = True
473+ return value
474+
475+ mock = AsyncMock (wraps = inner )
476+ result = await mock ()
477+ self .assertEqual (result , value )
478+ mock .assert_awaited ()
479+ self .assertTrue (ran )
480+
440481class AsyncMagicMethods (unittest .TestCase ):
441482 def test_async_magic_methods_return_async_mocks (self ):
442483 m_mock = MagicMock ()
@@ -860,6 +901,10 @@ def test_assert_awaited_once(self):
860901 self .mock .assert_awaited_once ()
861902
862903 def test_assert_awaited_with (self ):
904+ msg = 'Not awaited'
905+ with self .assertRaisesRegex (AssertionError , msg ):
906+ self .mock .assert_awaited_with ('foo' )
907+
863908 asyncio .run (self ._runnable_test ())
864909 msg = 'expected await not found'
865910 with self .assertRaisesRegex (AssertionError , msg ):
0 commit comments