testasync.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090
  1. import asyncio
  2. import gc
  3. import inspect
  4. import re
  5. import unittest
  6. from contextlib import contextmanager
  7. from test import support
  8. support.requires_working_socket(module=True)
  9. from asyncio import run, iscoroutinefunction
  10. from unittest import IsolatedAsyncioTestCase
  11. from unittest.mock import (ANY, call, AsyncMock, patch, MagicMock, Mock,
  12. create_autospec, sentinel, _CallList, seal)
  13. def tearDownModule():
  14. asyncio.set_event_loop_policy(None)
  15. class AsyncClass:
  16. def __init__(self): pass
  17. async def async_method(self): pass
  18. def normal_method(self): pass
  19. @classmethod
  20. async def async_class_method(cls): pass
  21. @staticmethod
  22. async def async_static_method(): pass
  23. class AwaitableClass:
  24. def __await__(self): yield
  25. async def async_func(): pass
  26. async def async_func_args(a, b, *, c): pass
  27. def normal_func(): pass
  28. class NormalClass(object):
  29. def a(self): pass
  30. async_foo_name = f'{__name__}.AsyncClass'
  31. normal_foo_name = f'{__name__}.NormalClass'
  32. @contextmanager
  33. def assertNeverAwaited(test):
  34. with test.assertWarnsRegex(RuntimeWarning, "was never awaited$"):
  35. yield
  36. # In non-CPython implementations of Python, this is needed because timely
  37. # deallocation is not guaranteed by the garbage collector.
  38. gc.collect()
class AsyncPatchDecoratorTest(unittest.TestCase):
    """``patch`` / ``patch.object`` applied as decorators to async targets.

    Each test decorates a nested function and then calls it, so the patch
    is only in effect while that nested function runs.
    """

    def test_is_coroutine_function_patch(self):
        # The mock substituted for an ``async def`` method must be seen as a
        # coroutine function.
        @patch.object(AsyncClass, 'async_method')
        def test_async(mock_method):
            self.assertTrue(iscoroutinefunction(mock_method))

        test_async()

    def test_is_async_patch(self):
        # Calling the substituted mock returns an awaitable. It is run to
        # completion so no coroutine is left un-awaited.
        @patch.object(AsyncClass, 'async_method')
        def test_async(mock_method):
            m = mock_method()
            self.assertTrue(inspect.isawaitable(m))
            run(m)

        # Same check with the target given as a dotted string path rather
        # than an (object, attribute) pair.
        @patch(f'{async_foo_name}.async_method')
        def test_no_parent_attribute(mock_method):
            m = mock_method()
            self.assertTrue(inspect.isawaitable(m))
            run(m)

        test_async()
        test_no_parent_attribute()

    def test_is_AsyncMock_patch(self):
        # Patching an async instance method substitutes an AsyncMock.
        @patch.object(AsyncClass, 'async_method')
        def test_async(mock_method):
            self.assertIsInstance(mock_method, AsyncMock)

        test_async()

    def test_is_AsyncMock_patch_staticmethod(self):
        # Same expectation for an async staticmethod.
        @patch.object(AsyncClass, 'async_static_method')
        def test_async(mock_method):
            self.assertIsInstance(mock_method, AsyncMock)

        test_async()

    def test_is_AsyncMock_patch_classmethod(self):
        # Same expectation for an async classmethod.
        @patch.object(AsyncClass, 'async_class_method')
        def test_async(mock_method):
            self.assertIsInstance(mock_method, AsyncMock)

        test_async()

    def test_async_def_patch(self):
        # Stacked patch decorators on an ``async def``. The innermost patch
        # (async_func_args) provides the first mock argument, as the
        # _mock_name checks below confirm.
        @patch(f"{__name__}.async_func", return_value=1)
        @patch(f"{__name__}.async_func_args", return_value=2)
        async def test_async(func_args_mock, func_mock):
            self.assertEqual(func_args_mock._mock_name, "async_func_args")
            self.assertEqual(func_mock._mock_name, "async_func")

            # While the coroutine runs, the module-level names resolve to
            # the mocks and return the configured return_value.
            self.assertIsInstance(async_func, AsyncMock)
            self.assertIsInstance(async_func_args, AsyncMock)

            self.assertEqual(await async_func(), 1)
            self.assertEqual(await async_func_args(1, 2, c=3), 2)

        run(test_async())
        # Checked after the patched coroutine has finished running.
        self.assertTrue(inspect.iscoroutinefunction(async_func))
class AsyncPatchCMTest(unittest.TestCase):
    """``patch`` / ``patch.object`` / ``patch.dict`` used as context managers
    around async targets."""

    def test_is_async_function_cm(self):
        # Inside the ``with``, the substitute for an async method is a
        # coroutine function.
        def test_async():
            with patch.object(AsyncClass, 'async_method') as mock_method:
                self.assertTrue(iscoroutinefunction(mock_method))

        test_async()

    def test_is_async_cm(self):
        # Calling the substitute yields an awaitable, which is run so it is
        # not left un-awaited.
        def test_async():
            with patch.object(AsyncClass, 'async_method') as mock_method:
                m = mock_method()
                self.assertTrue(inspect.isawaitable(m))
                run(m)

        test_async()

    def test_is_AsyncMock_cm(self):
        # The substitute for an async method is an AsyncMock.
        def test_async():
            with patch.object(AsyncClass, 'async_method') as mock_method:
                self.assertIsInstance(mock_method, AsyncMock)

        test_async()

    def test_async_def_cm(self):
        # Patching a module-level coroutine function from inside a
        # coroutine, with an explicit AsyncMock replacement.
        async def test_async():
            with patch(f"{__name__}.async_func", AsyncMock()):
                self.assertIsInstance(async_func, AsyncMock)
            # Checked after the ``with`` block has exited.
            self.assertTrue(inspect.iscoroutinefunction(async_func))

        run(test_async())

    def test_patch_dict_async_def(self):
        # patch.dict used as a decorator keeps an ``async def`` a coroutine
        # function and applies the override while it runs.
        foo = {'a': 'a'}
        @patch.dict(foo, {'a': 'b'})
        async def test_async():
            self.assertEqual(foo['a'], 'b')

        self.assertTrue(iscoroutinefunction(test_async))
        run(test_async())

    def test_patch_dict_async_def_context(self):
        # patch.dict used as a context manager inside a coroutine.
        foo = {'a': 'a'}
        async def test_async():
            with patch.dict(foo, {'a': 'b'}):
                self.assertEqual(foo['a'], 'b')

        run(test_async())
  122. class AsyncMockTest(unittest.TestCase):
  123. def test_iscoroutinefunction_default(self):
  124. mock = AsyncMock()
  125. self.assertTrue(iscoroutinefunction(mock))
  126. def test_iscoroutinefunction_function(self):
  127. async def foo(): pass
  128. mock = AsyncMock(foo)
  129. self.assertTrue(iscoroutinefunction(mock))
  130. self.assertTrue(inspect.iscoroutinefunction(mock))
  131. def test_isawaitable(self):
  132. mock = AsyncMock()
  133. m = mock()
  134. self.assertTrue(inspect.isawaitable(m))
  135. run(m)
  136. self.assertIn('assert_awaited', dir(mock))
  137. def test_iscoroutinefunction_normal_function(self):
  138. def foo(): pass
  139. mock = AsyncMock(foo)
  140. self.assertTrue(iscoroutinefunction(mock))
  141. self.assertTrue(inspect.iscoroutinefunction(mock))
  142. def test_future_isfuture(self):
  143. loop = asyncio.new_event_loop()
  144. fut = loop.create_future()
  145. loop.stop()
  146. loop.close()
  147. mock = AsyncMock(fut)
  148. self.assertIsInstance(mock, asyncio.Future)
class AsyncAutospecTest(unittest.TestCase):
    """Autospeccing (``autospec=True`` / ``create_autospec``) of async code."""

    def test_is_AsyncMock_patch(self):
        # Autospeccing the class: its async method becomes an AsyncMock,
        # while the class mock and its sync method are MagicMocks.
        @patch(async_foo_name, autospec=True)
        def test_async(mock_method):
            self.assertIsInstance(mock_method.async_method, AsyncMock)
            self.assertIsInstance(mock_method, MagicMock)

        @patch(async_foo_name, autospec=True)
        def test_normal_method(mock_method):
            self.assertIsInstance(mock_method.normal_method, MagicMock)

        test_async()
        test_normal_method()

    def test_create_autospec_instance(self):
        # instance=True is rejected for a coroutine function.
        with self.assertRaises(RuntimeError):
            create_autospec(async_func, instance=True)

    @unittest.skip('Broken test from https://bugs.python.org/issue37251')
    def test_create_autospec_awaitable_class(self):
        self.assertIsInstance(create_autospec(AwaitableClass), AsyncMock)

    def test_create_autospec(self):
        # Calling the autospec records a call but not an await. Await
        # tracking changes only once the returned coroutine is awaited.
        spec = create_autospec(async_func_args)
        awaitable = spec(1, 2, c=3)
        async def main():
            await awaitable

        # Before awaiting: no await recorded.
        self.assertEqual(spec.await_count, 0)
        self.assertIsNone(spec.await_args)
        self.assertEqual(spec.await_args_list, [])
        spec.assert_not_awaited()

        run(main())

        # After awaiting: exactly one await, with the original call's args.
        self.assertTrue(iscoroutinefunction(spec))
        self.assertTrue(asyncio.iscoroutine(awaitable))
        self.assertEqual(spec.await_count, 1)
        self.assertEqual(spec.await_args, call(1, 2, c=3))
        self.assertEqual(spec.await_args_list, [call(1, 2, c=3)])
        spec.assert_awaited_once()
        spec.assert_awaited_once_with(1, 2, c=3)
        spec.assert_awaited_with(1, 2, c=3)
        spec.assert_awaited()

        # An await with arguments that never happened must not match.
        with self.assertRaises(AssertionError):
            spec.assert_any_await(e=1)

    def test_patch_with_autospec(self):
        # Same await-tracking lifecycle as test_create_autospec, but through
        # patch(..., autospec=True); the autospecced object exposes the
        # underlying AsyncMock via ``.mock``.

        async def test_async():
            with patch(f"{__name__}.async_func_args", autospec=True) as mock_method:
                awaitable = mock_method(1, 2, c=3)
                self.assertIsInstance(mock_method.mock, AsyncMock)

                self.assertTrue(iscoroutinefunction(mock_method))
                self.assertTrue(asyncio.iscoroutine(awaitable))
                self.assertTrue(inspect.isawaitable(awaitable))

                # Verify the default values during mock setup
                self.assertEqual(mock_method.await_count, 0)
                self.assertEqual(mock_method.await_args_list, [])
                self.assertIsNone(mock_method.await_args)
                mock_method.assert_not_awaited()

                await awaitable

            self.assertEqual(mock_method.await_count, 1)
            self.assertEqual(mock_method.await_args, call(1, 2, c=3))
            self.assertEqual(mock_method.await_args_list, [call(1, 2, c=3)])
            mock_method.assert_awaited_once()
            mock_method.assert_awaited_once_with(1, 2, c=3)
            mock_method.assert_awaited_with(1, 2, c=3)
            mock_method.assert_awaited()

            # reset_mock() must also clear the await bookkeeping.
            mock_method.reset_mock()
            self.assertEqual(mock_method.await_count, 0)
            self.assertIsNone(mock_method.await_args)
            self.assertEqual(mock_method.await_args_list, [])

        run(test_async())
class AsyncSpecTest(unittest.TestCase):
    """How ``spec`` decides between AsyncMock and sync mocks, for the mock
    itself and for its child attributes."""

    def test_spec_normal_methods_on_class(self):
        # With a class spec, async methods become AsyncMock children and
        # sync methods become MagicMock children.
        def inner_test(mock_type):
            mock = mock_type(AsyncClass)
            self.assertIsInstance(mock.async_method, AsyncMock)
            self.assertIsInstance(mock.normal_method, MagicMock)

        for mock_type in [AsyncMock, MagicMock]:
            with self.subTest(f"test method types with {mock_type}"):
                inner_test(mock_type)

    def test_spec_normal_methods_on_class_with_mock(self):
        # Same split for a plain Mock parent; the sync child is a Mock.
        mock = Mock(AsyncClass)
        self.assertIsInstance(mock.async_method, AsyncMock)
        self.assertIsInstance(mock.normal_method, Mock)

    def test_spec_normal_methods_on_class_with_mock_seal(self):
        # After seal(), attributes that were never accessed raise
        # AttributeError, including spec'd sync and async methods.
        mock = Mock(AsyncClass)
        seal(mock)
        with self.assertRaises(AttributeError):
            mock.normal_method
        with self.assertRaises(AttributeError):
            mock.async_method

    def test_spec_mock_type_kw(self):
        # spec= given by keyword. Calling a mock spec'd on an async function
        # returns an awaitable. The result is deliberately discarded
        # un-awaited, so assertNeverAwaited sees the RuntimeWarning.
        def inner_test(mock_type):
            async_mock = mock_type(spec=async_func)
            self.assertIsInstance(async_mock, mock_type)
            with assertNeverAwaited(self):
                self.assertTrue(inspect.isawaitable(async_mock()))

            sync_mock = mock_type(spec=normal_func)
            self.assertIsInstance(sync_mock, mock_type)

        for mock_type in [AsyncMock, MagicMock, Mock]:
            with self.subTest(f"test spec kwarg with {mock_type}"):
                inner_test(mock_type)

    def test_spec_mock_type_positional(self):
        # Same as test_spec_mock_type_kw, with the spec passed positionally.
        def inner_test(mock_type):
            async_mock = mock_type(async_func)
            self.assertIsInstance(async_mock, mock_type)
            with assertNeverAwaited(self):
                self.assertTrue(inspect.isawaitable(async_mock()))

            sync_mock = mock_type(normal_func)
            self.assertIsInstance(sync_mock, mock_type)

        for mock_type in [AsyncMock, MagicMock, Mock]:
            with self.subTest(f"test spec positional with {mock_type}"):
                inner_test(mock_type)

    def test_spec_as_normal_kw_AsyncMock(self):
        # An AsyncMock stays awaitable even when spec'd on a sync function.
        mock = AsyncMock(spec=normal_func)
        self.assertIsInstance(mock, AsyncMock)
        m = mock()
        self.assertTrue(inspect.isawaitable(m))
        run(m)

    def test_spec_as_normal_positional_AsyncMock(self):
        # Same, with the sync spec passed positionally.
        mock = AsyncMock(normal_func)
        self.assertIsInstance(mock, AsyncMock)
        m = mock()
        self.assertTrue(inspect.isawaitable(m))
        run(m)

    def test_spec_async_mock(self):
        # spec=True on an async method target yields an AsyncMock.
        @patch.object(AsyncClass, 'async_method', spec=True)
        def test_async(mock_method):
            self.assertIsInstance(mock_method, AsyncMock)

        test_async()

    def test_spec_parent_not_async_attribute_is(self):
        # A class target becomes a MagicMock, while its async method child
        # is an AsyncMock.
        @patch(async_foo_name, spec=True)
        def test_async(mock_method):
            self.assertIsInstance(mock_method, MagicMock)
            self.assertIsInstance(mock_method.async_method, AsyncMock)

        test_async()

    def test_target_async_spec_not(self):
        # An explicit sync spec overrides the async target: the substitute
        # is a MagicMock, neither a coroutine nor awaitable.
        @patch.object(AsyncClass, 'async_method', spec=NormalClass.a)
        def test_async_attribute(mock_method):
            self.assertIsInstance(mock_method, MagicMock)
            self.assertFalse(inspect.iscoroutine(mock_method))
            self.assertFalse(inspect.isawaitable(mock_method))

        test_async_attribute()

    def test_target_not_async_spec_is(self):
        # Conversely, an async spec on a sync target yields an AsyncMock.
        @patch.object(NormalClass, 'a', spec=async_func)
        def test_attribute_not_async_spec_is(mock_async_func):
            self.assertIsInstance(mock_async_func, AsyncMock)
        test_attribute_not_async_spec_is()

    def test_spec_async_attributes(self):
        # A sync class patched with an async-bearing class spec gets
        # AsyncMock children for the spec's async methods.
        @patch(normal_foo_name, spec=AsyncClass)
        def test_async_attributes_coroutines(MockNormalClass):
            self.assertIsInstance(MockNormalClass.async_method, AsyncMock)
            self.assertIsInstance(MockNormalClass, MagicMock)

        test_async_attributes_coroutines()
class AsyncSpecSetTest(unittest.TestCase):
    """``spec_set`` behaves like ``spec`` when choosing async vs sync mocks."""

    def test_is_AsyncMock_patch(self):
        # spec_set=True on an async method target yields an AsyncMock.
        @patch.object(AsyncClass, 'async_method', spec_set=True)
        def test_async(async_method):
            self.assertIsInstance(async_method, AsyncMock)
        test_async()

    def test_is_async_AsyncMock(self):
        # An AsyncMock spec_set on an async method is a coroutine function.
        mock = AsyncMock(spec_set=AsyncClass.async_method)
        self.assertTrue(iscoroutinefunction(mock))
        self.assertIsInstance(mock, AsyncMock)

    def test_is_child_AsyncMock(self):
        # Children follow the spec_set class: async -> AsyncMock and coroutine
        # function; sync -> MagicMock and not a coroutine function.
        mock = MagicMock(spec_set=AsyncClass)
        self.assertTrue(iscoroutinefunction(mock.async_method))
        self.assertFalse(iscoroutinefunction(mock.normal_method))
        self.assertIsInstance(mock.async_method, AsyncMock)
        self.assertIsInstance(mock.normal_method, MagicMock)
        self.assertIsInstance(mock, MagicMock)

    def test_magicmock_lambda_spec(self):
        # Patching an attribute whose current value is a MagicMock spec'd on
        # a (sync) lambda substitutes a MagicMock.
        mock_obj = MagicMock()
        mock_obj.mock_func = MagicMock(spec=lambda x: x)

        with patch.object(mock_obj, "mock_func") as cm:
            self.assertIsInstance(cm, MagicMock)
  318. class AsyncArguments(IsolatedAsyncioTestCase):
  319. async def test_add_return_value(self):
  320. async def addition(self, var): pass
  321. mock = AsyncMock(addition, return_value=10)
  322. output = await mock(5)
  323. self.assertEqual(output, 10)
  324. async def test_add_side_effect_exception(self):
  325. async def addition(var): pass
  326. mock = AsyncMock(addition, side_effect=Exception('err'))
  327. with self.assertRaises(Exception):
  328. await mock(5)
  329. async def test_add_side_effect_coroutine(self):
  330. async def addition(var):
  331. return var + 1
  332. mock = AsyncMock(side_effect=addition)
  333. result = await mock(5)
  334. self.assertEqual(result, 6)
  335. async def test_add_side_effect_normal_function(self):
  336. def addition(var):
  337. return var + 1
  338. mock = AsyncMock(side_effect=addition)
  339. result = await mock(5)
  340. self.assertEqual(result, 6)
  341. async def test_add_side_effect_iterable(self):
  342. vals = [1, 2, 3]
  343. mock = AsyncMock(side_effect=vals)
  344. for item in vals:
  345. self.assertEqual(await mock(), item)
  346. with self.assertRaises(StopAsyncIteration) as e:
  347. await mock()
  348. async def test_add_side_effect_exception_iterable(self):
  349. class SampleException(Exception):
  350. pass
  351. vals = [1, SampleException("foo")]
  352. mock = AsyncMock(side_effect=vals)
  353. self.assertEqual(await mock(), 1)
  354. with self.assertRaises(SampleException) as e:
  355. await mock()
  356. async def test_return_value_AsyncMock(self):
  357. value = AsyncMock(return_value=10)
  358. mock = AsyncMock(return_value=value)
  359. result = await mock()
  360. self.assertIs(result, value)
  361. async def test_return_value_awaitable(self):
  362. fut = asyncio.Future()
  363. fut.set_result(None)
  364. mock = AsyncMock(return_value=fut)
  365. result = await mock()
  366. self.assertIsInstance(result, asyncio.Future)
  367. async def test_side_effect_awaitable_values(self):
  368. fut = asyncio.Future()
  369. fut.set_result(None)
  370. mock = AsyncMock(side_effect=[fut])
  371. result = await mock()
  372. self.assertIsInstance(result, asyncio.Future)
  373. with self.assertRaises(StopAsyncIteration):
  374. await mock()
  375. async def test_side_effect_is_AsyncMock(self):
  376. effect = AsyncMock(return_value=10)
  377. mock = AsyncMock(side_effect=effect)
  378. result = await mock()
  379. self.assertEqual(result, 10)
  380. async def test_wraps_coroutine(self):
  381. value = asyncio.Future()
  382. ran = False
  383. async def inner():
  384. nonlocal ran
  385. ran = True
  386. return value
  387. mock = AsyncMock(wraps=inner)
  388. result = await mock()
  389. self.assertEqual(result, value)
  390. mock.assert_awaited()
  391. self.assertTrue(ran)
  392. async def test_wraps_normal_function(self):
  393. value = 1
  394. ran = False
  395. def inner():
  396. nonlocal ran
  397. ran = True
  398. return value
  399. mock = AsyncMock(wraps=inner)
  400. result = await mock()
  401. self.assertEqual(result, value)
  402. mock.assert_awaited()
  403. self.assertTrue(ran)
  404. async def test_await_args_list_order(self):
  405. async_mock = AsyncMock()
  406. mock2 = async_mock(2)
  407. mock1 = async_mock(1)
  408. await mock1
  409. await mock2
  410. async_mock.assert_has_awaits([call(1), call(2)])
  411. self.assertEqual(async_mock.await_args_list, [call(1), call(2)])
  412. self.assertEqual(async_mock.call_args_list, [call(2), call(1)])
  413. class AsyncMagicMethods(unittest.TestCase):
  414. def test_async_magic_methods_return_async_mocks(self):
  415. m_mock = MagicMock()
  416. self.assertIsInstance(m_mock.__aenter__, AsyncMock)
  417. self.assertIsInstance(m_mock.__aexit__, AsyncMock)
  418. self.assertIsInstance(m_mock.__anext__, AsyncMock)
  419. # __aiter__ is actually a synchronous object
  420. # so should return a MagicMock
  421. self.assertIsInstance(m_mock.__aiter__, MagicMock)
  422. def test_sync_magic_methods_return_magic_mocks(self):
  423. a_mock = AsyncMock()
  424. self.assertIsInstance(a_mock.__enter__, MagicMock)
  425. self.assertIsInstance(a_mock.__exit__, MagicMock)
  426. self.assertIsInstance(a_mock.__next__, MagicMock)
  427. self.assertIsInstance(a_mock.__len__, MagicMock)
  428. def test_magicmock_has_async_magic_methods(self):
  429. m_mock = MagicMock()
  430. self.assertTrue(hasattr(m_mock, "__aenter__"))
  431. self.assertTrue(hasattr(m_mock, "__aexit__"))
  432. self.assertTrue(hasattr(m_mock, "__anext__"))
  433. def test_asyncmock_has_sync_magic_methods(self):
  434. a_mock = AsyncMock()
  435. self.assertTrue(hasattr(a_mock, "__enter__"))
  436. self.assertTrue(hasattr(a_mock, "__exit__"))
  437. self.assertTrue(hasattr(a_mock, "__next__"))
  438. self.assertTrue(hasattr(a_mock, "__len__"))
  439. def test_magic_methods_are_async_functions(self):
  440. m_mock = MagicMock()
  441. self.assertIsInstance(m_mock.__aenter__, AsyncMock)
  442. self.assertIsInstance(m_mock.__aexit__, AsyncMock)
  443. # AsyncMocks are also coroutine functions
  444. self.assertTrue(iscoroutinefunction(m_mock.__aenter__))
  445. self.assertTrue(iscoroutinefunction(m_mock.__aexit__))
class AsyncContextManagerTest(unittest.TestCase):
    """Mocks standing in for objects used with ``async with``."""

    # Real async context manager; used as the spec for mocks below.
    class WithAsyncContextManager:
        async def __aenter__(self, *args, **kwargs): pass

        async def __aexit__(self, *args, **kwargs): pass

    # Sync counterpart (not referenced by the tests in this class).
    class WithSyncContextManager:
        def __enter__(self, *args, **kwargs): pass

        def __exit__(self, *args, **kwargs): pass

    class ProductionCode:
        # Example real-world(ish) code

        def __init__(self):
            self.session = None

        async def main(self):
            async with self.session.post('https://python.org') as response:
                val = await response.json()
                return val

    def test_set_return_value_of_aenter(self):
        # The object bound by ``async with ... as response`` is whatever
        # __aenter__.return_value is set to.
        def inner_test(mock_type):
            pc = self.ProductionCode()
            pc.session = MagicMock(name='sessionmock')
            cm = mock_type(name='magic_cm')
            response = AsyncMock(name='response')
            response.json = AsyncMock(return_value={'json': 123})
            cm.__aenter__.return_value = response
            pc.session.post.return_value = cm
            result = run(pc.main())
            self.assertEqual(result, {'json': 123})

        for mock_type in [AsyncMock, MagicMock]:
            with self.subTest(f"test set return value of aenter with {mock_type}"):
                inner_test(mock_type)

    def test_mock_supports_async_context_manager(self):
        # A spec'd mock works in ``async with``: the body runs, and both
        # __aenter__ and __aexit__ are called and awaited.
        def inner_test(mock_type):
            called = False
            cm = self.WithAsyncContextManager()
            cm_mock = mock_type(cm)

            async def use_context_manager():
                nonlocal called
                async with cm_mock as result:
                    called = True
                return result

            cm_result = run(use_context_manager())
            self.assertTrue(called)
            self.assertTrue(cm_mock.__aenter__.called)
            self.assertTrue(cm_mock.__aexit__.called)
            cm_mock.__aenter__.assert_awaited()
            cm_mock.__aexit__.assert_awaited()
            # We mock __aenter__ so it does not return self
            self.assertIsNot(cm_mock, cm_result)

        for mock_type in [AsyncMock, MagicMock]:
            with self.subTest(f"test context manager magics with {mock_type}"):
                inner_test(mock_type)

    def test_mock_customize_async_context_manager(self):
        # A custom __aenter__.return_value is what ``as result`` binds.
        instance = self.WithAsyncContextManager()
        mock_instance = MagicMock(instance)

        expected_result = object()
        mock_instance.__aenter__.return_value = expected_result

        async def use_context_manager():
            async with mock_instance as result:
                return result

        self.assertIs(run(use_context_manager()), expected_result)

    def test_mock_customize_async_context_manager_with_coroutine(self):
        # Replacing __aenter__/__aexit__ with real coroutine functions: both
        # get invoked by ``async with``, as the flags they set show.
        enter_called = False
        exit_called = False

        async def enter_coroutine(*args):
            nonlocal enter_called
            enter_called = True

        async def exit_coroutine(*args):
            nonlocal exit_called
            exit_called = True

        instance = self.WithAsyncContextManager()
        mock_instance = MagicMock(instance)

        mock_instance.__aenter__ = enter_coroutine
        mock_instance.__aexit__ = exit_coroutine

        async def use_context_manager():
            async with mock_instance:
                pass

        run(use_context_manager())
        self.assertTrue(enter_called)
        self.assertTrue(exit_called)

    def test_context_manager_raise_exception_by_default(self):
        # The test expects the default mocked __aexit__ not to suppress an
        # exception raised in the body, so TypeError propagates to run().
        async def raise_in(context_manager):
            async with context_manager:
                raise TypeError()

        instance = self.WithAsyncContextManager()
        mock_instance = MagicMock(instance)
        with self.assertRaises(TypeError):
            run(raise_in(mock_instance))
  532. class AsyncIteratorTest(unittest.TestCase):
  533. class WithAsyncIterator(object):
  534. def __init__(self):
  535. self.items = ["foo", "NormalFoo", "baz"]
  536. def __aiter__(self): pass
  537. async def __anext__(self): pass
  538. def test_aiter_set_return_value(self):
  539. mock_iter = AsyncMock(name="tester")
  540. mock_iter.__aiter__.return_value = [1, 2, 3]
  541. async def main():
  542. return [i async for i in mock_iter]
  543. result = run(main())
  544. self.assertEqual(result, [1, 2, 3])
  545. def test_mock_aiter_and_anext_asyncmock(self):
  546. def inner_test(mock_type):
  547. instance = self.WithAsyncIterator()
  548. mock_instance = mock_type(instance)
  549. # Check that the mock and the real thing bahave the same
  550. # __aiter__ is not actually async, so not a coroutinefunction
  551. self.assertFalse(iscoroutinefunction(instance.__aiter__))
  552. self.assertFalse(iscoroutinefunction(mock_instance.__aiter__))
  553. # __anext__ is async
  554. self.assertTrue(iscoroutinefunction(instance.__anext__))
  555. self.assertTrue(iscoroutinefunction(mock_instance.__anext__))
  556. for mock_type in [AsyncMock, MagicMock]:
  557. with self.subTest(f"test aiter and anext corourtine with {mock_type}"):
  558. inner_test(mock_type)
  559. def test_mock_async_for(self):
  560. async def iterate(iterator):
  561. accumulator = []
  562. async for item in iterator:
  563. accumulator.append(item)
  564. return accumulator
  565. expected = ["FOO", "BAR", "BAZ"]
  566. def test_default(mock_type):
  567. mock_instance = mock_type(self.WithAsyncIterator())
  568. self.assertEqual(run(iterate(mock_instance)), [])
  569. def test_set_return_value(mock_type):
  570. mock_instance = mock_type(self.WithAsyncIterator())
  571. mock_instance.__aiter__.return_value = expected[:]
  572. self.assertEqual(run(iterate(mock_instance)), expected)
  573. def test_set_return_value_iter(mock_type):
  574. mock_instance = mock_type(self.WithAsyncIterator())
  575. mock_instance.__aiter__.return_value = iter(expected[:])
  576. self.assertEqual(run(iterate(mock_instance)), expected)
  577. for mock_type in [AsyncMock, MagicMock]:
  578. with self.subTest(f"default value with {mock_type}"):
  579. test_default(mock_type)
  580. with self.subTest(f"set return_value with {mock_type}"):
  581. test_set_return_value(mock_type)
  582. with self.subTest(f"set return_value iterator with {mock_type}"):
  583. test_set_return_value_iter(mock_type)
  584. class AsyncMockAssert(unittest.TestCase):
  585. def setUp(self):
  586. self.mock = AsyncMock()
  587. async def _runnable_test(self, *args, **kwargs):
  588. await self.mock(*args, **kwargs)
  589. async def _await_coroutine(self, coroutine):
  590. return await coroutine
    def test_assert_called_but_not_awaited(self):
        # Calling an async method without awaiting it counts as a call but
        # not as an await, on both the child and the parent mock.
        mock = AsyncMock(AsyncClass)
        with assertNeverAwaited(self):
            mock.async_method()
        self.assertTrue(iscoroutinefunction(mock.async_method))
        mock.async_method.assert_called()
        mock.async_method.assert_called_once()
        mock.async_method.assert_called_once_with()
        with self.assertRaises(AssertionError):
            mock.assert_awaited()
        with self.assertRaises(AssertionError):
            mock.async_method.assert_awaited()
    def test_assert_called_then_awaited(self):
        # The call is recorded immediately; the await is recorded only once
        # the returned coroutine is actually awaited.
        mock = AsyncMock(AsyncClass)
        mock_coroutine = mock.async_method()
        mock.async_method.assert_called()
        mock.async_method.assert_called_once()
        mock.async_method.assert_called_once_with()
        with self.assertRaises(AssertionError):
            mock.async_method.assert_awaited()

        run(self._await_coroutine(mock_coroutine))
        # Assert we haven't re-called the function
        mock.async_method.assert_called_once()
        mock.async_method.assert_awaited()
        mock.async_method.assert_awaited_once()
        mock.async_method.assert_awaited_once_with()
    def test_assert_called_and_awaited_at_same_time(self):
        # Starting from the fresh mock made in setUp: neither called nor
        # awaited; a single call-and-await then registers exactly once each.
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited()

        with self.assertRaises(AssertionError):
            self.mock.assert_called()

        run(self._runnable_test())
        self.mock.assert_called_once()
        self.mock.assert_awaited_once()
    def test_assert_called_twice_and_awaited_once(self):
        # Two calls, only the first of which is awaited: await bookkeeping
        # must report a single await.
        mock = AsyncMock(AsyncClass)
        coroutine = mock.async_method()
        # The first call will be awaited so no warning there
        # But this call will never get awaited, so it will warn here
        with assertNeverAwaited(self):
            mock.async_method()
        with self.assertRaises(AssertionError):
            mock.async_method.assert_awaited()
        mock.async_method.assert_called()
        run(self._await_coroutine(coroutine))
        mock.async_method.assert_awaited()
        mock.async_method.assert_awaited_once()
    def test_assert_called_once_and_awaited_twice(self):
        # Re-awaiting the same coroutine object fails with RuntimeError; the
        # earlier successful await is still recorded.
        mock = AsyncMock(AsyncClass)
        coroutine = mock.async_method()
        mock.async_method.assert_called_once()
        run(self._await_coroutine(coroutine))
        with self.assertRaises(RuntimeError):
            # Cannot reuse already awaited coroutine
            run(self._await_coroutine(coroutine))
        mock.async_method.assert_awaited()
    def test_assert_awaited_but_not_called(self):
        # Awaiting the mock object itself (rather than the coroutine returned
        # by calling it) raises TypeError and records neither a call nor an
        # await.
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited()
        with self.assertRaises(AssertionError):
            self.mock.assert_called()
        with self.assertRaises(TypeError):
            # You cannot await an AsyncMock, it must be a coroutine
            run(self._await_coroutine(self.mock))

        with self.assertRaises(AssertionError):
            self.mock.assert_awaited()
        with self.assertRaises(AssertionError):
            self.mock.assert_called()
    def test_assert_has_calls_not_awaits(self):
        # A call that is never awaited satisfies assert_has_calls but not
        # assert_has_awaits.
        kalls = [call('foo')]
        with assertNeverAwaited(self):
            self.mock('foo')
        self.mock.assert_has_calls(kalls)
        with self.assertRaises(AssertionError):
            self.mock.assert_has_awaits(kalls)
    def test_assert_has_mock_calls_on_async_mock_no_spec(self):
        # mock_calls on an un-spec'd AsyncMock accumulates every call in
        # order, whether or not the call was awaited.
        with assertNeverAwaited(self):
            self.mock()
        kalls_empty = [('', (), {})]
        self.assertEqual(self.mock.mock_calls, kalls_empty)

        with assertNeverAwaited(self):
            self.mock('foo')
        with assertNeverAwaited(self):
            self.mock('baz')
        mock_kalls = ([call(), call('foo'), call('baz')])
        self.assertEqual(self.mock.mock_calls, mock_kalls)
  677. def test_assert_has_mock_calls_on_async_mock_with_spec(self):
  678. a_class_mock = AsyncMock(AsyncClass)
  679. with assertNeverAwaited(self):
  680. a_class_mock.async_method()
  681. kalls_empty = [('', (), {})]
  682. self.assertEqual(a_class_mock.async_method.mock_calls, kalls_empty)
  683. self.assertEqual(a_class_mock.mock_calls, [call.async_method()])
  684. with assertNeverAwaited(self):
  685. a_class_mock.async_method(1, 2, 3, a=4, b=5)
  686. method_kalls = [call(), call(1, 2, 3, a=4, b=5)]
  687. mock_kalls = [call.async_method(), call.async_method(1, 2, 3, a=4, b=5)]
  688. self.assertEqual(a_class_mock.async_method.mock_calls, method_kalls)
  689. self.assertEqual(a_class_mock.mock_calls, mock_kalls)
  690. def test_async_method_calls_recorded(self):
  691. with assertNeverAwaited(self):
  692. self.mock.something(3, fish=None)
  693. with assertNeverAwaited(self):
  694. self.mock.something_else.something(6, cake=sentinel.Cake)
  695. self.assertEqual(self.mock.method_calls, [
  696. ("something", (3,), {'fish': None}),
  697. ("something_else.something", (6,), {'cake': sentinel.Cake})
  698. ],
  699. "method calls not recorded correctly")
  700. self.assertEqual(self.mock.something_else.method_calls,
  701. [("something", (6,), {'cake': sentinel.Cake})],
  702. "method calls not recorded correctly")
  703. def test_async_arg_lists(self):
  704. def assert_attrs(mock):
  705. names = ('call_args_list', 'method_calls', 'mock_calls')
  706. for name in names:
  707. attr = getattr(mock, name)
  708. self.assertIsInstance(attr, _CallList)
  709. self.assertIsInstance(attr, list)
  710. self.assertEqual(attr, [])
  711. assert_attrs(self.mock)
  712. with assertNeverAwaited(self):
  713. self.mock()
  714. with assertNeverAwaited(self):
  715. self.mock(1, 2)
  716. with assertNeverAwaited(self):
  717. self.mock(a=3)
  718. self.mock.reset_mock()
  719. assert_attrs(self.mock)
  720. a_mock = AsyncMock(AsyncClass)
  721. with assertNeverAwaited(self):
  722. a_mock.async_method()
  723. with assertNeverAwaited(self):
  724. a_mock.async_method(1, a=3)
  725. a_mock.reset_mock()
  726. assert_attrs(a_mock)
  727. def test_assert_awaited(self):
  728. with self.assertRaises(AssertionError):
  729. self.mock.assert_awaited()
  730. run(self._runnable_test())
  731. self.mock.assert_awaited()
  732. def test_assert_awaited_once(self):
  733. with self.assertRaises(AssertionError):
  734. self.mock.assert_awaited_once()
  735. run(self._runnable_test())
  736. self.mock.assert_awaited_once()
  737. run(self._runnable_test())
  738. with self.assertRaises(AssertionError):
  739. self.mock.assert_awaited_once()
  740. def test_assert_awaited_with(self):
  741. msg = 'Not awaited'
  742. with self.assertRaisesRegex(AssertionError, msg):
  743. self.mock.assert_awaited_with('foo')
  744. run(self._runnable_test())
  745. msg = 'expected await not found'
  746. with self.assertRaisesRegex(AssertionError, msg):
  747. self.mock.assert_awaited_with('foo')
  748. run(self._runnable_test('foo'))
  749. self.mock.assert_awaited_with('foo')
  750. run(self._runnable_test('SomethingElse'))
  751. with self.assertRaises(AssertionError):
  752. self.mock.assert_awaited_with('foo')
  753. def test_assert_awaited_once_with(self):
  754. with self.assertRaises(AssertionError):
  755. self.mock.assert_awaited_once_with('foo')
  756. run(self._runnable_test('foo'))
  757. self.mock.assert_awaited_once_with('foo')
  758. run(self._runnable_test('foo'))
  759. with self.assertRaises(AssertionError):
  760. self.mock.assert_awaited_once_with('foo')
  761. def test_assert_any_wait(self):
  762. with self.assertRaises(AssertionError):
  763. self.mock.assert_any_await('foo')
  764. run(self._runnable_test('baz'))
  765. with self.assertRaises(AssertionError):
  766. self.mock.assert_any_await('foo')
  767. run(self._runnable_test('foo'))
  768. self.mock.assert_any_await('foo')
  769. run(self._runnable_test('SomethingElse'))
  770. self.mock.assert_any_await('foo')
  771. def test_assert_has_awaits_no_order(self):
  772. calls = [call('foo'), call('baz')]
  773. with self.assertRaises(AssertionError) as cm:
  774. self.mock.assert_has_awaits(calls)
  775. self.assertEqual(len(cm.exception.args), 1)
  776. run(self._runnable_test('foo'))
  777. with self.assertRaises(AssertionError):
  778. self.mock.assert_has_awaits(calls)
  779. run(self._runnable_test('foo'))
  780. with self.assertRaises(AssertionError):
  781. self.mock.assert_has_awaits(calls)
  782. run(self._runnable_test('baz'))
  783. self.mock.assert_has_awaits(calls)
  784. run(self._runnable_test('SomethingElse'))
  785. self.mock.assert_has_awaits(calls)
  786. def test_awaits_asserts_with_any(self):
  787. class Foo:
  788. def __eq__(self, other): pass
  789. run(self._runnable_test(Foo(), 1))
  790. self.mock.assert_has_awaits([call(ANY, 1)])
  791. self.mock.assert_awaited_with(ANY, 1)
  792. self.mock.assert_any_await(ANY, 1)
  793. def test_awaits_asserts_with_spec_and_any(self):
  794. class Foo:
  795. def __eq__(self, other): pass
  796. mock_with_spec = AsyncMock(spec=Foo)
  797. async def _custom_mock_runnable_test(*args):
  798. await mock_with_spec(*args)
  799. run(_custom_mock_runnable_test(Foo(), 1))
  800. mock_with_spec.assert_has_awaits([call(ANY, 1)])
  801. mock_with_spec.assert_awaited_with(ANY, 1)
  802. mock_with_spec.assert_any_await(ANY, 1)
  803. def test_assert_has_awaits_ordered(self):
  804. calls = [call('foo'), call('baz')]
  805. with self.assertRaises(AssertionError):
  806. self.mock.assert_has_awaits(calls, any_order=True)
  807. run(self._runnable_test('baz'))
  808. with self.assertRaises(AssertionError):
  809. self.mock.assert_has_awaits(calls, any_order=True)
  810. run(self._runnable_test('bamf'))
  811. with self.assertRaises(AssertionError):
  812. self.mock.assert_has_awaits(calls, any_order=True)
  813. run(self._runnable_test('foo'))
  814. self.mock.assert_has_awaits(calls, any_order=True)
  815. run(self._runnable_test('qux'))
  816. self.mock.assert_has_awaits(calls, any_order=True)
  817. def test_assert_not_awaited(self):
  818. self.mock.assert_not_awaited()
  819. run(self._runnable_test())
  820. with self.assertRaises(AssertionError):
  821. self.mock.assert_not_awaited()
    def test_assert_has_awaits_not_matching_spec_error(self):
        # Checks the failure messages (and exception chaining) of
        # assert_has_awaits on a mock spec'd with an async function.
        async def f(x=None): pass
        # Rebind self.mock: the tests in this class await self.mock via
        # _runnable_test (defined outside this view), so the await below
        # goes to this spec'd mock.
        self.mock = AsyncMock(spec=f)
        run(self._runnable_test(1))
        # Case 1: the expected call() fits f's signature but simply does not
        # match the actual await call(1) -> plain "Awaits not found" message
        # and no chained cause.
        with self.assertRaisesRegex(
            AssertionError,
            '^{}$'.format(
                re.escape('Awaits not found.\n'
                          'Expected: [call()]\n'
                          'Actual: [call(1)]'))) as cm:
            self.mock.assert_has_awaits([call()])
        self.assertIsNone(cm.exception.__cause__)
        # Case 2: call(1, 2) cannot be bound to f(x=None). The message lists
        # one error slot per expected call (None for the one that bound fine)
        # and the TypeError is chained as __cause__.
        with self.assertRaisesRegex(
            AssertionError,
            '^{}$'.format(
                re.escape(
                    'Error processing expected awaits.\n'
                    "Errors: [None, TypeError('too many positional "
                    "arguments')]\n"
                    'Expected: [call(), call(1, 2)]\n'
                    'Actual: [call(1)]'))) as cm:
            self.mock.assert_has_awaits([call(), call(1, 2)])
        self.assertIsInstance(cm.exception.__cause__, TypeError)
# Allow running this test module directly as a script through unittest's
# command-line entry point.
if __name__ == '__main__':
    unittest.main()