@@ -382,35 +382,88 @@ def test_add_side_effect_iterable(self):
382382class AsyncContextManagerTest (unittest .TestCase ):
383383
384384 class WithAsyncContextManager :
385-
386385 async def __aenter__ (self , * args , ** kwargs ):
387386 return self
388387
389388 async def __aexit__ (self , * args , ** kwargs ):
390389 pass
391390
392- def test_magic_methods_are_async_mocks (self ):
393- mock = MagicMock (self .WithAsyncContextManager ())
394- self .assertIsInstance (mock .__aenter__ , AsyncMock )
395- self .assertIsInstance (mock .__aexit__ , AsyncMock )
391+ class WithSyncContextManager :
392+ def __enter__ (self , * args , ** kwargs ):
393+ return self
394+
395+ def __exit__ (self , * args , ** kwargs ):
396+ pass
397+
398+ class ProductionCode :
399+ # Example real-world(ish) code
400+ def __init__ (self ):
401+ self .session = None
402+
403+ async def main (self ):
404+ async with self .session .post ('https://python.org' ) as response :
405+ val = await response .json ()
406+ return val
407+
408+ def test_async_magic_methods_are_async_mocks_with_magicmock (self ):
409+ cm_mock = MagicMock (self .WithAsyncContextManager ())
410+ self .assertIsInstance (cm_mock .__aenter__ , AsyncMock )
411+ self .assertIsInstance (cm_mock .__aexit__ , AsyncMock )
412+
413+ def test_magicmock_has_async_magic_methods (self ):
414+ cm = MagicMock (name = 'magic_cm' )
415+ self .assertTrue (hasattr (cm , "__aenter__" ))
416+ self .assertTrue (hasattr (cm , "__aexit__" ))
417+
418+ def test_magic_methods_are_async_functions (self ):
419+ cm = MagicMock (name = 'magic_cm' )
420+ self .assertIsInstance (cm .__aenter__ , AsyncMock )
421+ self .assertIsInstance (cm .__aexit__ , AsyncMock )
422+ # AsyncMocks are also coroutine functions
423+ self .assertTrue (asyncio .iscoroutinefunction (cm .__aenter__ ))
424+ self .assertTrue (asyncio .iscoroutinefunction (cm .__aexit__ ))
425+
426+ def test_set_return_value_of_aenter (self ):
427+ def inner_test (mock_type ):
428+ pc = self .ProductionCode ()
429+ pc .session = MagicMock (name = 'sessionmock' )
430+ cm = mock_type (name = 'magic_cm' )
431+ response = AsyncMock (name = 'response' )
432+ response .json = AsyncMock (return_value = {'json' : 123 })
433+ cm .__aenter__ .return_value = response
434+ pc .session .post .return_value = cm
435+ result = asyncio .run (pc .main ())
436+ self .assertEqual (result , {'json' : 123 })
437+
438+ for mock_type in [AsyncMock , MagicMock ]:
439+ with self .subTest (f"test set return value of aenter with { mock_type } " ):
440+ inner_test (mock_type )
396441
397442 def test_mock_supports_async_context_manager (self ):
398- called = False
399- instance = self .WithAsyncContextManager ()
400- mock_instance = MagicMock (instance )
443+ def inner_test (mock_type ):
444+ called = False
445+ cm = self .WithAsyncContextManager ()
446+ cm_mock = mock_type (cm )
447+
448+ async def use_context_manager ():
449+ nonlocal called
450+ async with cm_mock as result :
451+ called = True
452+ return result
401453
402- async def use_context_manager ():
403- nonlocal called
404- async with mock_instance as result :
405- called = True
406- return result
454+ cm_result = asyncio .run (use_context_manager ())
455+ self .assertTrue (called )
456+ self .assertTrue (cm_mock .__aenter__ .called )
457+ self .assertTrue (cm_mock .__aexit__ .called )
458+ cm_mock .__aenter__ .assert_awaited ()
459+ cm_mock .__aexit__ .assert_awaited ()
460+ # We mock __aenter__ so it does not return self
461+ self .assertIsNot (cm_mock , cm_result )
462+
463+ for mock_type in [AsyncMock , MagicMock ]:
464+ with self .subTest (f"test context manager magics with { mock_type } " ):
465+ inner_test (mock_type )
407466
408- result = asyncio .run (use_context_manager ())
409- self .assertTrue (called )
410- self .assertTrue (mock_instance .__aenter__ .called )
411- self .assertTrue (mock_instance .__aexit__ .called )
412- self .assertIsNot (mock_instance , result )
413- self .assertIsInstance (result , AsyncMock )
414467
415468 def test_mock_customize_async_context_manager (self ):
416469 instance = self .WithAsyncContextManager ()
@@ -478,27 +531,30 @@ async def __anext__(self):
478531
479532 raise StopAsyncIteration
480533
481- def test_mock_aiter_and_anext (self ):
482- instance = self .WithAsyncIterator ()
483- mock_instance = MagicMock (instance )
484-
485- self .assertEqual (asyncio .iscoroutine (instance .__aiter__ ),
486- asyncio .iscoroutine (mock_instance .__aiter__ ))
487- self .assertEqual (asyncio .iscoroutine (instance .__anext__ ),
488- asyncio .iscoroutine (mock_instance .__anext__ ))
489-
490- iterator = instance .__aiter__ ()
491- if asyncio .iscoroutine (iterator ):
492- iterator = asyncio .run (iterator )
493-
494- mock_iterator = mock_instance .__aiter__ ()
495- if asyncio .iscoroutine (mock_iterator ):
496- mock_iterator = asyncio .run (mock_iterator )
534+ def test_aiter_set_return_value (self ):
535+ mock_iter = AsyncMock (name = "tester" )
536+ mock_iter .__aiter__ .return_value = [1 , 2 , 3 ]
537+ async def main ():
538+ return [i async for i in mock_iter ]
539+ result = asyncio .run (main ())
540+ self .assertEqual (result , [1 , 2 , 3 ])
541+
542+ def test_mock_aiter_and_anext_asyncmock (self ):
543+ def inner_test (mock_type ):
544+ instance = self .WithAsyncIterator ()
545+ mock_instance = mock_type (instance )
546+ # Check that the mock and the real thing bahave the same
547+ # __aiter__ is not actually async, so not a coroutinefunction
548+ self .assertFalse (asyncio .iscoroutinefunction (instance .__aiter__ ))
549+ self .assertFalse (asyncio .iscoroutinefunction (mock_instance .__aiter__ ))
550+ # __anext__ is async
551+ self .assertTrue (asyncio .iscoroutinefunction (instance .__anext__ ))
552+ self .assertTrue (asyncio .iscoroutinefunction (mock_instance .__anext__ ))
553+
554+ for mock_type in [AsyncMock , MagicMock ]:
555+ with self .subTest (f"test aiter and anext corourtine with { mock_type } " ):
556+ inner_test (mock_type )
497557
498- self .assertEqual (asyncio .iscoroutine (iterator .__aiter__ ),
499- asyncio .iscoroutine (mock_iterator .__aiter__ ))
500- self .assertEqual (asyncio .iscoroutine (iterator .__anext__ ),
501- asyncio .iscoroutine (mock_iterator .__anext__ ))
502558
503559 def test_mock_async_for (self ):
504560 async def iterate (iterator ):
@@ -509,19 +565,30 @@ async def iterate(iterator):
509565 return accumulator
510566
511567 expected = ["FOO" , "BAR" , "BAZ" ]
512- with self .subTest ("iterate through default value" ):
513- mock_instance = MagicMock (self .WithAsyncIterator ())
514- self .assertEqual ([], asyncio .run (iterate (mock_instance )))
568+ def test_default (mock_type ):
569+ mock_instance = mock_type (self .WithAsyncIterator ())
570+ self .assertEqual (asyncio .run (iterate (mock_instance )), [])
571+
515572
516- with self . subTest ( "iterate through set return_value" ):
517- mock_instance = MagicMock (self .WithAsyncIterator ())
573+ def test_set_return_value ( mock_type ):
574+ mock_instance = mock_type (self .WithAsyncIterator ())
518575 mock_instance .__aiter__ .return_value = expected [:]
519- self .assertEqual (expected , asyncio .run (iterate (mock_instance )))
576+ self .assertEqual (asyncio .run (iterate (mock_instance )), expected )
520577
521- with self . subTest ( "iterate through set return_value iterator" ):
522- mock_instance = MagicMock (self .WithAsyncIterator ())
578+ def test_set_return_value_iter ( mock_type ):
579+ mock_instance = mock_type (self .WithAsyncIterator ())
523580 mock_instance .__aiter__ .return_value = iter (expected [:])
524- self .assertEqual (expected , asyncio .run (iterate (mock_instance )))
581+ self .assertEqual (asyncio .run (iterate (mock_instance )), expected )
582+
583+ for mock_type in [AsyncMock , MagicMock ]:
584+ with self .subTest (f"default value with { mock_type } " ):
585+ test_default (mock_type )
586+
587+ with self .subTest (f"set return_value with { mock_type } " ):
588+ test_set_return_value (mock_type )
589+
590+ with self .subTest (f"set return_value iterator with { mock_type } " ):
591+ test_set_return_value_iter (mock_type )
525592
526593
527594class AsyncMockAssert (unittest .TestCase ):
0 commit comments