Skip to content

Commit 8b03f94

Browse files
authored
bpo-38093: Correctly returns AsyncMock for async subclasses. (GH-15947)
1 parent 2702638 commit 8b03f94

5 files changed

Lines changed: 180 additions & 69 deletions

File tree

Doc/library/unittest.mock-examples.rst

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import asyncio
1616
import unittest
17-
from unittest.mock import Mock, MagicMock, patch, call, sentinel
17+
from unittest.mock import Mock, MagicMock, AsyncMock, patch, call, sentinel
1818

1919
class SomeClass:
2020
attribute = 'this is a doctest'
@@ -280,39 +280,42 @@ function returns is what the call returns:
280280
Mocking asynchronous iterators
281281
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
282282

283-
Since Python 3.8, ``MagicMock`` has support to mock :ref:`async-iterators`
284-
through ``__aiter__``. The :attr:`~Mock.return_value` attribute of ``__aiter__``
285-
can be used to set the return values to be used for iteration.
283+
Since Python 3.8, ``AsyncMock`` and ``MagicMock`` have support to mock
284+
:ref:`async-iterators` through ``__aiter__``. The :attr:`~Mock.return_value`
285+
attribute of ``__aiter__`` can be used to set the return values to be used for
286+
iteration.
286287

287-
>>> mock = MagicMock()
288+
>>> mock = MagicMock() # AsyncMock also works here
288289
>>> mock.__aiter__.return_value = [1, 2, 3]
289290
>>> async def main():
290291
... return [i async for i in mock]
292+
...
291293
>>> asyncio.run(main())
292294
[1, 2, 3]
293295

294296

295297
Mocking asynchronous context manager
296298
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
297299

298-
Since Python 3.8, ``MagicMock`` has support to mock
299-
:ref:`async-context-managers` through ``__aenter__`` and ``__aexit__``. The
300-
return value of ``__aenter__`` is an :class:`AsyncMock`.
300+
Since Python 3.8, ``AsyncMock`` and ``MagicMock`` have support to mock
301+
:ref:`async-context-managers` through ``__aenter__`` and ``__aexit__``.
302+
By default, ``__aenter__`` and ``__aexit__`` are ``AsyncMock`` instances that
303+
return an async function.
301304

302305
>>> class AsyncContextManager:
303-
...
304306
... async def __aenter__(self):
305307
... return self
306-
...
307-
... async def __aexit__(self):
308+
... async def __aexit__(self, exc_type, exc, tb):
308309
... pass
309-
>>> mock_instance = MagicMock(AsyncContextManager())
310+
...
311+
>>> mock_instance = MagicMock(AsyncContextManager()) # AsyncMock also works here
310312
>>> async def main():
311313
... async with mock_instance as result:
312314
... pass
315+
...
313316
>>> asyncio.run(main())
314-
>>> mock_instance.__aenter__.assert_called_once()
315-
>>> mock_instance.__aexit__.assert_called_once()
317+
>>> mock_instance.__aenter__.assert_awaited_once()
318+
>>> mock_instance.__aexit__.assert_awaited_once()
316319

317320

318321
Creating a Mock from an Existing Object

Lib/unittest/mock.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -983,9 +983,13 @@ def _get_child_mock(self, /, **kw):
983983
_type = type(self)
984984
if issubclass(_type, MagicMock) and _new_name in _async_method_magics:
985985
klass = AsyncMock
986-
if issubclass(_type, AsyncMockMixin):
986+
elif _new_name in _sync_async_magics:
987+
# Special case these ones b/c users will assume they are async,
988+
# but they are actually sync (ie. __aiter__)
987989
klass = MagicMock
988-
if not issubclass(_type, CallableMixin):
990+
elif issubclass(_type, AsyncMockMixin):
991+
klass = AsyncMock
992+
elif not issubclass(_type, CallableMixin):
989993
if issubclass(_type, NonCallableMagicMock):
990994
klass = MagicMock
991995
elif issubclass(_type, NonCallableMock) :
@@ -1881,7 +1885,7 @@ def _patch_stopall():
18811885
'__reduce__', '__reduce_ex__', '__getinitargs__', '__getnewargs__',
18821886
'__getstate__', '__setstate__', '__getformat__', '__setformat__',
18831887
'__repr__', '__dir__', '__subclasses__', '__format__',
1884-
'__getnewargs_ex__', '__aenter__', '__aexit__', '__anext__', '__aiter__',
1888+
'__getnewargs_ex__',
18851889
}
18861890

18871891

@@ -1900,10 +1904,12 @@ def method(self, /, *args, **kw):
19001904

19011905
# Magic methods used for async `with` statements
19021906
_async_method_magics = {"__aenter__", "__aexit__", "__anext__"}
1903-
# `__aiter__` is a plain function but used with async calls
1904-
_async_magics = _async_method_magics | {"__aiter__"}
1907+
# Magic methods that are only used with async calls but are synchronous functions themselves
1908+
_sync_async_magics = {"__aiter__"}
1909+
_async_magics = _async_method_magics | _sync_async_magics
19051910

1906-
_all_magics = _magics | _non_defaults
1911+
_all_sync_magics = _magics | _non_defaults
1912+
_all_magics = _all_sync_magics | _async_magics
19071913

19081914
_unsupported_magics = {
19091915
'__getattr__', '__setattr__',

Lib/unittest/test/testmock/testasync.py

Lines changed: 115 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -382,35 +382,88 @@ def test_add_side_effect_iterable(self):
382382
class AsyncContextManagerTest(unittest.TestCase):
383383

384384
class WithAsyncContextManager:
385-
386385
async def __aenter__(self, *args, **kwargs):
387386
return self
388387

389388
async def __aexit__(self, *args, **kwargs):
390389
pass
391390

392-
def test_magic_methods_are_async_mocks(self):
393-
mock = MagicMock(self.WithAsyncContextManager())
394-
self.assertIsInstance(mock.__aenter__, AsyncMock)
395-
self.assertIsInstance(mock.__aexit__, AsyncMock)
391+
class WithSyncContextManager:
392+
def __enter__(self, *args, **kwargs):
393+
return self
394+
395+
def __exit__(self, *args, **kwargs):
396+
pass
397+
398+
class ProductionCode:
399+
# Example real-world(ish) code
400+
def __init__(self):
401+
self.session = None
402+
403+
async def main(self):
404+
async with self.session.post('https://python.org') as response:
405+
val = await response.json()
406+
return val
407+
408+
def test_async_magic_methods_are_async_mocks_with_magicmock(self):
409+
cm_mock = MagicMock(self.WithAsyncContextManager())
410+
self.assertIsInstance(cm_mock.__aenter__, AsyncMock)
411+
self.assertIsInstance(cm_mock.__aexit__, AsyncMock)
412+
413+
def test_magicmock_has_async_magic_methods(self):
414+
cm = MagicMock(name='magic_cm')
415+
self.assertTrue(hasattr(cm, "__aenter__"))
416+
self.assertTrue(hasattr(cm, "__aexit__"))
417+
418+
def test_magic_methods_are_async_functions(self):
419+
cm = MagicMock(name='magic_cm')
420+
self.assertIsInstance(cm.__aenter__, AsyncMock)
421+
self.assertIsInstance(cm.__aexit__, AsyncMock)
422+
# AsyncMocks are also coroutine functions
423+
self.assertTrue(asyncio.iscoroutinefunction(cm.__aenter__))
424+
self.assertTrue(asyncio.iscoroutinefunction(cm.__aexit__))
425+
426+
def test_set_return_value_of_aenter(self):
427+
def inner_test(mock_type):
428+
pc = self.ProductionCode()
429+
pc.session = MagicMock(name='sessionmock')
430+
cm = mock_type(name='magic_cm')
431+
response = AsyncMock(name='response')
432+
response.json = AsyncMock(return_value={'json': 123})
433+
cm.__aenter__.return_value = response
434+
pc.session.post.return_value = cm
435+
result = asyncio.run(pc.main())
436+
self.assertEqual(result, {'json': 123})
437+
438+
for mock_type in [AsyncMock, MagicMock]:
439+
with self.subTest(f"test set return value of aenter with {mock_type}"):
440+
inner_test(mock_type)
396441

397442
def test_mock_supports_async_context_manager(self):
398-
called = False
399-
instance = self.WithAsyncContextManager()
400-
mock_instance = MagicMock(instance)
443+
def inner_test(mock_type):
444+
called = False
445+
cm = self.WithAsyncContextManager()
446+
cm_mock = mock_type(cm)
447+
448+
async def use_context_manager():
449+
nonlocal called
450+
async with cm_mock as result:
451+
called = True
452+
return result
401453

402-
async def use_context_manager():
403-
nonlocal called
404-
async with mock_instance as result:
405-
called = True
406-
return result
454+
cm_result = asyncio.run(use_context_manager())
455+
self.assertTrue(called)
456+
self.assertTrue(cm_mock.__aenter__.called)
457+
self.assertTrue(cm_mock.__aexit__.called)
458+
cm_mock.__aenter__.assert_awaited()
459+
cm_mock.__aexit__.assert_awaited()
460+
# We mock __aenter__ so it does not return self
461+
self.assertIsNot(cm_mock, cm_result)
462+
463+
for mock_type in [AsyncMock, MagicMock]:
464+
with self.subTest(f"test context manager magics with {mock_type}"):
465+
inner_test(mock_type)
407466

408-
result = asyncio.run(use_context_manager())
409-
self.assertTrue(called)
410-
self.assertTrue(mock_instance.__aenter__.called)
411-
self.assertTrue(mock_instance.__aexit__.called)
412-
self.assertIsNot(mock_instance, result)
413-
self.assertIsInstance(result, AsyncMock)
414467

415468
def test_mock_customize_async_context_manager(self):
416469
instance = self.WithAsyncContextManager()
@@ -478,27 +531,30 @@ async def __anext__(self):
478531

479532
raise StopAsyncIteration
480533

481-
def test_mock_aiter_and_anext(self):
482-
instance = self.WithAsyncIterator()
483-
mock_instance = MagicMock(instance)
484-
485-
self.assertEqual(asyncio.iscoroutine(instance.__aiter__),
486-
asyncio.iscoroutine(mock_instance.__aiter__))
487-
self.assertEqual(asyncio.iscoroutine(instance.__anext__),
488-
asyncio.iscoroutine(mock_instance.__anext__))
489-
490-
iterator = instance.__aiter__()
491-
if asyncio.iscoroutine(iterator):
492-
iterator = asyncio.run(iterator)
493-
494-
mock_iterator = mock_instance.__aiter__()
495-
if asyncio.iscoroutine(mock_iterator):
496-
mock_iterator = asyncio.run(mock_iterator)
534+
def test_aiter_set_return_value(self):
535+
mock_iter = AsyncMock(name="tester")
536+
mock_iter.__aiter__.return_value = [1, 2, 3]
537+
async def main():
538+
return [i async for i in mock_iter]
539+
result = asyncio.run(main())
540+
self.assertEqual(result, [1, 2, 3])
541+
542+
def test_mock_aiter_and_anext_asyncmock(self):
543+
def inner_test(mock_type):
544+
instance = self.WithAsyncIterator()
545+
mock_instance = mock_type(instance)
546+
# Check that the mock and the real thing bahave the same
547+
# __aiter__ is not actually async, so not a coroutinefunction
548+
self.assertFalse(asyncio.iscoroutinefunction(instance.__aiter__))
549+
self.assertFalse(asyncio.iscoroutinefunction(mock_instance.__aiter__))
550+
# __anext__ is async
551+
self.assertTrue(asyncio.iscoroutinefunction(instance.__anext__))
552+
self.assertTrue(asyncio.iscoroutinefunction(mock_instance.__anext__))
553+
554+
for mock_type in [AsyncMock, MagicMock]:
555+
with self.subTest(f"test aiter and anext corourtine with {mock_type}"):
556+
inner_test(mock_type)
497557

498-
self.assertEqual(asyncio.iscoroutine(iterator.__aiter__),
499-
asyncio.iscoroutine(mock_iterator.__aiter__))
500-
self.assertEqual(asyncio.iscoroutine(iterator.__anext__),
501-
asyncio.iscoroutine(mock_iterator.__anext__))
502558

503559
def test_mock_async_for(self):
504560
async def iterate(iterator):
@@ -509,19 +565,30 @@ async def iterate(iterator):
509565
return accumulator
510566

511567
expected = ["FOO", "BAR", "BAZ"]
512-
with self.subTest("iterate through default value"):
513-
mock_instance = MagicMock(self.WithAsyncIterator())
514-
self.assertEqual([], asyncio.run(iterate(mock_instance)))
568+
def test_default(mock_type):
569+
mock_instance = mock_type(self.WithAsyncIterator())
570+
self.assertEqual(asyncio.run(iterate(mock_instance)), [])
571+
515572

516-
with self.subTest("iterate through set return_value"):
517-
mock_instance = MagicMock(self.WithAsyncIterator())
573+
def test_set_return_value(mock_type):
574+
mock_instance = mock_type(self.WithAsyncIterator())
518575
mock_instance.__aiter__.return_value = expected[:]
519-
self.assertEqual(expected, asyncio.run(iterate(mock_instance)))
576+
self.assertEqual(asyncio.run(iterate(mock_instance)), expected)
520577

521-
with self.subTest("iterate through set return_value iterator"):
522-
mock_instance = MagicMock(self.WithAsyncIterator())
578+
def test_set_return_value_iter(mock_type):
579+
mock_instance = mock_type(self.WithAsyncIterator())
523580
mock_instance.__aiter__.return_value = iter(expected[:])
524-
self.assertEqual(expected, asyncio.run(iterate(mock_instance)))
581+
self.assertEqual(asyncio.run(iterate(mock_instance)), expected)
582+
583+
for mock_type in [AsyncMock, MagicMock]:
584+
with self.subTest(f"default value with {mock_type}"):
585+
test_default(mock_type)
586+
587+
with self.subTest(f"set return_value with {mock_type}"):
588+
test_set_return_value(mock_type)
589+
590+
with self.subTest(f"set return_value iterator with {mock_type}"):
591+
test_set_return_value_iter(mock_type)
525592

526593

527594
class AsyncMockAssert(unittest.TestCase):

Lib/unittest/test/testmock/testmagicmethods.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
import asyncio
12
import math
23
import unittest
34
import os
45
import sys
5-
from unittest.mock import Mock, MagicMock, _magics
6+
from unittest.mock import AsyncMock, Mock, MagicMock, _magics
67

78

89

@@ -271,6 +272,34 @@ def test_magic_mock_equality(self):
271272
self.assertEqual(mock != mock, False)
272273

273274

275+
# This should be fixed with issue38163
276+
@unittest.expectedFailure
277+
def test_asyncmock_defaults(self):
278+
mock = AsyncMock()
279+
self.assertEqual(int(mock), 1)
280+
self.assertEqual(complex(mock), 1j)
281+
self.assertEqual(float(mock), 1.0)
282+
self.assertNotIn(object(), mock)
283+
self.assertEqual(len(mock), 0)
284+
self.assertEqual(list(mock), [])
285+
self.assertEqual(hash(mock), object.__hash__(mock))
286+
self.assertEqual(str(mock), object.__str__(mock))
287+
self.assertTrue(bool(mock))
288+
self.assertEqual(round(mock), mock.__round__())
289+
self.assertEqual(math.trunc(mock), mock.__trunc__())
290+
self.assertEqual(math.floor(mock), mock.__floor__())
291+
self.assertEqual(math.ceil(mock), mock.__ceil__())
292+
self.assertTrue(asyncio.iscoroutinefunction(mock.__aexit__))
293+
self.assertTrue(asyncio.iscoroutinefunction(mock.__aenter__))
294+
self.assertIsInstance(mock.__aenter__, AsyncMock)
295+
self.assertIsInstance(mock.__aexit__, AsyncMock)
296+
297+
# in Python 3 oct and hex use __index__
298+
# so these tests are for __index__ in py3k
299+
self.assertEqual(oct(mock), '0o1')
300+
self.assertEqual(hex(mock), '0x1')
301+
# how to test __sizeof__ ?
302+
274303
def test_magicmock_defaults(self):
275304
mock = MagicMock()
276305
self.assertEqual(int(mock), 1)
@@ -286,6 +315,10 @@ def test_magicmock_defaults(self):
286315
self.assertEqual(math.trunc(mock), mock.__trunc__())
287316
self.assertEqual(math.floor(mock), mock.__floor__())
288317
self.assertEqual(math.ceil(mock), mock.__ceil__())
318+
self.assertTrue(asyncio.iscoroutinefunction(mock.__aexit__))
319+
self.assertTrue(asyncio.iscoroutinefunction(mock.__aenter__))
320+
self.assertIsInstance(mock.__aenter__, AsyncMock)
321+
self.assertIsInstance(mock.__aexit__, AsyncMock)
289322

290323
# in Python 3 oct and hex use __index__
291324
# so these tests are for __index__ in py3k
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Fixes AsyncMock so it doesn't crash when used with AsyncContextManagers
2+
or AsyncIterators.

0 commit comments

Comments
 (0)