# test_contextlib_async.py
  1. import asyncio
  2. from contextlib import (
  3. asynccontextmanager, AbstractAsyncContextManager,
  4. AsyncExitStack, nullcontext, aclosing, contextmanager)
  5. import functools
  6. from test import support
  7. import unittest
  8. import traceback
  9. from test.test_contextlib import TestBaseExitStack
  10. support.requires_working_socket(module=True)
  11. def _async_test(func):
  12. """Decorator to turn an async function into a test case."""
  13. @functools.wraps(func)
  14. def wrapper(*args, **kwargs):
  15. coro = func(*args, **kwargs)
  16. asyncio.run(coro)
  17. return wrapper
  18. def tearDownModule():
  19. asyncio.set_event_loop_policy(None)
  20. class TestAbstractAsyncContextManager(unittest.TestCase):
  21. @_async_test
  22. async def test_enter(self):
  23. class DefaultEnter(AbstractAsyncContextManager):
  24. async def __aexit__(self, *args):
  25. await super().__aexit__(*args)
  26. manager = DefaultEnter()
  27. self.assertIs(await manager.__aenter__(), manager)
  28. async with manager as context:
  29. self.assertIs(manager, context)
  30. @_async_test
  31. async def test_async_gen_propagates_generator_exit(self):
  32. # A regression test for https://bugs.python.org/issue33786.
  33. @asynccontextmanager
  34. async def ctx():
  35. yield
  36. async def gen():
  37. async with ctx():
  38. yield 11
  39. ret = []
  40. exc = ValueError(22)
  41. with self.assertRaises(ValueError):
  42. async with ctx():
  43. async for val in gen():
  44. ret.append(val)
  45. raise exc
  46. self.assertEqual(ret, [11])
  47. def test_exit_is_abstract(self):
  48. class MissingAexit(AbstractAsyncContextManager):
  49. pass
  50. with self.assertRaises(TypeError):
  51. MissingAexit()
  52. def test_structural_subclassing(self):
  53. class ManagerFromScratch:
  54. async def __aenter__(self):
  55. return self
  56. async def __aexit__(self, exc_type, exc_value, traceback):
  57. return None
  58. self.assertTrue(issubclass(ManagerFromScratch, AbstractAsyncContextManager))
  59. class DefaultEnter(AbstractAsyncContextManager):
  60. async def __aexit__(self, *args):
  61. await super().__aexit__(*args)
  62. self.assertTrue(issubclass(DefaultEnter, AbstractAsyncContextManager))
  63. class NoneAenter(ManagerFromScratch):
  64. __aenter__ = None
  65. self.assertFalse(issubclass(NoneAenter, AbstractAsyncContextManager))
  66. class NoneAexit(ManagerFromScratch):
  67. __aexit__ = None
  68. self.assertFalse(issubclass(NoneAexit, AbstractAsyncContextManager))
  69. class AsyncContextManagerTestCase(unittest.TestCase):
  70. @_async_test
  71. async def test_contextmanager_plain(self):
  72. state = []
  73. @asynccontextmanager
  74. async def woohoo():
  75. state.append(1)
  76. yield 42
  77. state.append(999)
  78. async with woohoo() as x:
  79. self.assertEqual(state, [1])
  80. self.assertEqual(x, 42)
  81. state.append(x)
  82. self.assertEqual(state, [1, 42, 999])
  83. @_async_test
  84. async def test_contextmanager_finally(self):
  85. state = []
  86. @asynccontextmanager
  87. async def woohoo():
  88. state.append(1)
  89. try:
  90. yield 42
  91. finally:
  92. state.append(999)
  93. with self.assertRaises(ZeroDivisionError):
  94. async with woohoo() as x:
  95. self.assertEqual(state, [1])
  96. self.assertEqual(x, 42)
  97. state.append(x)
  98. raise ZeroDivisionError()
  99. self.assertEqual(state, [1, 42, 999])
  100. @_async_test
  101. async def test_contextmanager_traceback(self):
  102. @asynccontextmanager
  103. async def f():
  104. yield
  105. try:
  106. async with f():
  107. 1/0
  108. except ZeroDivisionError as e:
  109. frames = traceback.extract_tb(e.__traceback__)
  110. self.assertEqual(len(frames), 1)
  111. self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
  112. self.assertEqual(frames[0].line, '1/0')
  113. # Repeat with RuntimeError (which goes through a different code path)
  114. class RuntimeErrorSubclass(RuntimeError):
  115. pass
  116. try:
  117. async with f():
  118. raise RuntimeErrorSubclass(42)
  119. except RuntimeErrorSubclass as e:
  120. frames = traceback.extract_tb(e.__traceback__)
  121. self.assertEqual(len(frames), 1)
  122. self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
  123. self.assertEqual(frames[0].line, 'raise RuntimeErrorSubclass(42)')
  124. class StopIterationSubclass(StopIteration):
  125. pass
  126. class StopAsyncIterationSubclass(StopAsyncIteration):
  127. pass
  128. for stop_exc in (
  129. StopIteration('spam'),
  130. StopAsyncIteration('ham'),
  131. StopIterationSubclass('spam'),
  132. StopAsyncIterationSubclass('spam')
  133. ):
  134. with self.subTest(type=type(stop_exc)):
  135. try:
  136. async with f():
  137. raise stop_exc
  138. except type(stop_exc) as e:
  139. self.assertIs(e, stop_exc)
  140. frames = traceback.extract_tb(e.__traceback__)
  141. else:
  142. self.fail(f'{stop_exc} was suppressed')
  143. self.assertEqual(len(frames), 1)
  144. self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
  145. self.assertEqual(frames[0].line, 'raise stop_exc')
  146. @_async_test
  147. async def test_contextmanager_no_reraise(self):
  148. @asynccontextmanager
  149. async def whee():
  150. yield
  151. ctx = whee()
  152. await ctx.__aenter__()
  153. # Calling __aexit__ should not result in an exception
  154. self.assertFalse(await ctx.__aexit__(TypeError, TypeError("foo"), None))
  155. @_async_test
  156. async def test_contextmanager_trap_yield_after_throw(self):
  157. @asynccontextmanager
  158. async def whoo():
  159. try:
  160. yield
  161. except:
  162. yield
  163. ctx = whoo()
  164. await ctx.__aenter__()
  165. with self.assertRaises(RuntimeError):
  166. await ctx.__aexit__(TypeError, TypeError('foo'), None)
  167. @_async_test
  168. async def test_contextmanager_trap_no_yield(self):
  169. @asynccontextmanager
  170. async def whoo():
  171. if False:
  172. yield
  173. ctx = whoo()
  174. with self.assertRaises(RuntimeError):
  175. await ctx.__aenter__()
  176. @_async_test
  177. async def test_contextmanager_trap_second_yield(self):
  178. @asynccontextmanager
  179. async def whoo():
  180. yield
  181. yield
  182. ctx = whoo()
  183. await ctx.__aenter__()
  184. with self.assertRaises(RuntimeError):
  185. await ctx.__aexit__(None, None, None)
  186. @_async_test
  187. async def test_contextmanager_non_normalised(self):
  188. @asynccontextmanager
  189. async def whoo():
  190. try:
  191. yield
  192. except RuntimeError:
  193. raise SyntaxError
  194. ctx = whoo()
  195. await ctx.__aenter__()
  196. with self.assertRaises(SyntaxError):
  197. await ctx.__aexit__(RuntimeError, None, None)
  198. @_async_test
  199. async def test_contextmanager_except(self):
  200. state = []
  201. @asynccontextmanager
  202. async def woohoo():
  203. state.append(1)
  204. try:
  205. yield 42
  206. except ZeroDivisionError as e:
  207. state.append(e.args[0])
  208. self.assertEqual(state, [1, 42, 999])
  209. async with woohoo() as x:
  210. self.assertEqual(state, [1])
  211. self.assertEqual(x, 42)
  212. state.append(x)
  213. raise ZeroDivisionError(999)
  214. self.assertEqual(state, [1, 42, 999])
  215. @_async_test
  216. async def test_contextmanager_except_stopiter(self):
  217. @asynccontextmanager
  218. async def woohoo():
  219. yield
  220. class StopIterationSubclass(StopIteration):
  221. pass
  222. class StopAsyncIterationSubclass(StopAsyncIteration):
  223. pass
  224. for stop_exc in (
  225. StopIteration('spam'),
  226. StopAsyncIteration('ham'),
  227. StopIterationSubclass('spam'),
  228. StopAsyncIterationSubclass('spam')
  229. ):
  230. with self.subTest(type=type(stop_exc)):
  231. try:
  232. async with woohoo():
  233. raise stop_exc
  234. except Exception as ex:
  235. self.assertIs(ex, stop_exc)
  236. else:
  237. self.fail(f'{stop_exc} was suppressed')
  238. @_async_test
  239. async def test_contextmanager_wrap_runtimeerror(self):
  240. @asynccontextmanager
  241. async def woohoo():
  242. try:
  243. yield
  244. except Exception as exc:
  245. raise RuntimeError(f'caught {exc}') from exc
  246. with self.assertRaises(RuntimeError):
  247. async with woohoo():
  248. 1 / 0
  249. # If the context manager wrapped StopAsyncIteration in a RuntimeError,
  250. # we also unwrap it, because we can't tell whether the wrapping was
  251. # done by the generator machinery or by the generator itself.
  252. with self.assertRaises(StopAsyncIteration):
  253. async with woohoo():
  254. raise StopAsyncIteration
  255. def _create_contextmanager_attribs(self):
  256. def attribs(**kw):
  257. def decorate(func):
  258. for k,v in kw.items():
  259. setattr(func,k,v)
  260. return func
  261. return decorate
  262. @asynccontextmanager
  263. @attribs(foo='bar')
  264. async def baz(spam):
  265. """Whee!"""
  266. yield
  267. return baz
  268. def test_contextmanager_attribs(self):
  269. baz = self._create_contextmanager_attribs()
  270. self.assertEqual(baz.__name__,'baz')
  271. self.assertEqual(baz.foo, 'bar')
  272. @support.requires_docstrings
  273. def test_contextmanager_doc_attrib(self):
  274. baz = self._create_contextmanager_attribs()
  275. self.assertEqual(baz.__doc__, "Whee!")
  276. @support.requires_docstrings
  277. @_async_test
  278. async def test_instance_docstring_given_cm_docstring(self):
  279. baz = self._create_contextmanager_attribs()(None)
  280. self.assertEqual(baz.__doc__, "Whee!")
  281. async with baz:
  282. pass # suppress warning
  283. @_async_test
  284. async def test_keywords(self):
  285. # Ensure no keyword arguments are inhibited
  286. @asynccontextmanager
  287. async def woohoo(self, func, args, kwds):
  288. yield (self, func, args, kwds)
  289. async with woohoo(self=11, func=22, args=33, kwds=44) as target:
  290. self.assertEqual(target, (11, 22, 33, 44))
  291. @_async_test
  292. async def test_recursive(self):
  293. depth = 0
  294. ncols = 0
  295. @asynccontextmanager
  296. async def woohoo():
  297. nonlocal ncols
  298. ncols += 1
  299. nonlocal depth
  300. before = depth
  301. depth += 1
  302. yield
  303. depth -= 1
  304. self.assertEqual(depth, before)
  305. @woohoo()
  306. async def recursive():
  307. if depth < 10:
  308. await recursive()
  309. await recursive()
  310. self.assertEqual(ncols, 10)
  311. self.assertEqual(depth, 0)
  312. @_async_test
  313. async def test_decorator(self):
  314. entered = False
  315. @asynccontextmanager
  316. async def context():
  317. nonlocal entered
  318. entered = True
  319. yield
  320. entered = False
  321. @context()
  322. async def test():
  323. self.assertTrue(entered)
  324. self.assertFalse(entered)
  325. await test()
  326. self.assertFalse(entered)
  327. @_async_test
  328. async def test_decorator_with_exception(self):
  329. entered = False
  330. @asynccontextmanager
  331. async def context():
  332. nonlocal entered
  333. try:
  334. entered = True
  335. yield
  336. finally:
  337. entered = False
  338. @context()
  339. async def test():
  340. self.assertTrue(entered)
  341. raise NameError('foo')
  342. self.assertFalse(entered)
  343. with self.assertRaisesRegex(NameError, 'foo'):
  344. await test()
  345. self.assertFalse(entered)
  346. @_async_test
  347. async def test_decorating_method(self):
  348. @asynccontextmanager
  349. async def context():
  350. yield
  351. class Test(object):
  352. @context()
  353. async def method(self, a, b, c=None):
  354. self.a = a
  355. self.b = b
  356. self.c = c
  357. # these tests are for argument passing when used as a decorator
  358. test = Test()
  359. await test.method(1, 2)
  360. self.assertEqual(test.a, 1)
  361. self.assertEqual(test.b, 2)
  362. self.assertEqual(test.c, None)
  363. test = Test()
  364. await test.method('a', 'b', 'c')
  365. self.assertEqual(test.a, 'a')
  366. self.assertEqual(test.b, 'b')
  367. self.assertEqual(test.c, 'c')
  368. test = Test()
  369. await test.method(a=1, b=2)
  370. self.assertEqual(test.a, 1)
  371. self.assertEqual(test.b, 2)
  372. class AclosingTestCase(unittest.TestCase):
  373. @support.requires_docstrings
  374. def test_instance_docs(self):
  375. cm_docstring = aclosing.__doc__
  376. obj = aclosing(None)
  377. self.assertEqual(obj.__doc__, cm_docstring)
  378. @_async_test
  379. async def test_aclosing(self):
  380. state = []
  381. class C:
  382. async def aclose(self):
  383. state.append(1)
  384. x = C()
  385. self.assertEqual(state, [])
  386. async with aclosing(x) as y:
  387. self.assertEqual(x, y)
  388. self.assertEqual(state, [1])
  389. @_async_test
  390. async def test_aclosing_error(self):
  391. state = []
  392. class C:
  393. async def aclose(self):
  394. state.append(1)
  395. x = C()
  396. self.assertEqual(state, [])
  397. with self.assertRaises(ZeroDivisionError):
  398. async with aclosing(x) as y:
  399. self.assertEqual(x, y)
  400. 1 / 0
  401. self.assertEqual(state, [1])
  402. @_async_test
  403. async def test_aclosing_bpo41229(self):
  404. state = []
  405. @contextmanager
  406. def sync_resource():
  407. try:
  408. yield
  409. finally:
  410. state.append(1)
  411. async def agenfunc():
  412. with sync_resource():
  413. yield -1
  414. yield -2
  415. x = agenfunc()
  416. self.assertEqual(state, [])
  417. with self.assertRaises(ZeroDivisionError):
  418. async with aclosing(x) as y:
  419. self.assertEqual(x, y)
  420. self.assertEqual(-1, await x.__anext__())
  421. 1 / 0
  422. self.assertEqual(state, [1])
  423. class TestAsyncExitStack(TestBaseExitStack, unittest.TestCase):
  424. class SyncAsyncExitStack(AsyncExitStack):
  425. @staticmethod
  426. def run_coroutine(coro):
  427. loop = asyncio.get_event_loop_policy().get_event_loop()
  428. t = loop.create_task(coro)
  429. t.add_done_callback(lambda f: loop.stop())
  430. loop.run_forever()
  431. exc = t.exception()
  432. if not exc:
  433. return t.result()
  434. else:
  435. context = exc.__context__
  436. try:
  437. raise exc
  438. except:
  439. exc.__context__ = context
  440. raise exc
  441. def close(self):
  442. return self.run_coroutine(self.aclose())
  443. def __enter__(self):
  444. return self.run_coroutine(self.__aenter__())
  445. def __exit__(self, *exc_details):
  446. return self.run_coroutine(self.__aexit__(*exc_details))
  447. exit_stack = SyncAsyncExitStack
  448. callback_error_internal_frames = [
  449. ('__exit__', 'return self.run_coroutine(self.__aexit__(*exc_details))'),
  450. ('run_coroutine', 'raise exc'),
  451. ('run_coroutine', 'raise exc'),
  452. ('__aexit__', 'raise exc_details[1]'),
  453. ('__aexit__', 'cb_suppress = cb(*exc_details)'),
  454. ]
  455. def setUp(self):
  456. self.loop = asyncio.new_event_loop()
  457. asyncio.set_event_loop(self.loop)
  458. self.addCleanup(self.loop.close)
  459. self.addCleanup(asyncio.set_event_loop_policy, None)
  460. @_async_test
  461. async def test_async_callback(self):
  462. expected = [
  463. ((), {}),
  464. ((1,), {}),
  465. ((1,2), {}),
  466. ((), dict(example=1)),
  467. ((1,), dict(example=1)),
  468. ((1,2), dict(example=1)),
  469. ]
  470. result = []
  471. async def _exit(*args, **kwds):
  472. """Test metadata propagation"""
  473. result.append((args, kwds))
  474. async with AsyncExitStack() as stack:
  475. for args, kwds in reversed(expected):
  476. if args and kwds:
  477. f = stack.push_async_callback(_exit, *args, **kwds)
  478. elif args:
  479. f = stack.push_async_callback(_exit, *args)
  480. elif kwds:
  481. f = stack.push_async_callback(_exit, **kwds)
  482. else:
  483. f = stack.push_async_callback(_exit)
  484. self.assertIs(f, _exit)
  485. for wrapper in stack._exit_callbacks:
  486. self.assertIs(wrapper[1].__wrapped__, _exit)
  487. self.assertNotEqual(wrapper[1].__name__, _exit.__name__)
  488. self.assertIsNone(wrapper[1].__doc__, _exit.__doc__)
  489. self.assertEqual(result, expected)
  490. result = []
  491. async with AsyncExitStack() as stack:
  492. with self.assertRaises(TypeError):
  493. stack.push_async_callback(arg=1)
  494. with self.assertRaises(TypeError):
  495. self.exit_stack.push_async_callback(arg=2)
  496. with self.assertRaises(TypeError):
  497. stack.push_async_callback(callback=_exit, arg=3)
  498. self.assertEqual(result, [])
  499. @_async_test
  500. async def test_async_push(self):
  501. exc_raised = ZeroDivisionError
  502. async def _expect_exc(exc_type, exc, exc_tb):
  503. self.assertIs(exc_type, exc_raised)
  504. async def _suppress_exc(*exc_details):
  505. return True
  506. async def _expect_ok(exc_type, exc, exc_tb):
  507. self.assertIsNone(exc_type)
  508. self.assertIsNone(exc)
  509. self.assertIsNone(exc_tb)
  510. class ExitCM(object):
  511. def __init__(self, check_exc):
  512. self.check_exc = check_exc
  513. async def __aenter__(self):
  514. self.fail("Should not be called!")
  515. async def __aexit__(self, *exc_details):
  516. await self.check_exc(*exc_details)
  517. async with self.exit_stack() as stack:
  518. stack.push_async_exit(_expect_ok)
  519. self.assertIs(stack._exit_callbacks[-1][1], _expect_ok)
  520. cm = ExitCM(_expect_ok)
  521. stack.push_async_exit(cm)
  522. self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
  523. stack.push_async_exit(_suppress_exc)
  524. self.assertIs(stack._exit_callbacks[-1][1], _suppress_exc)
  525. cm = ExitCM(_expect_exc)
  526. stack.push_async_exit(cm)
  527. self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
  528. stack.push_async_exit(_expect_exc)
  529. self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
  530. stack.push_async_exit(_expect_exc)
  531. self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
  532. 1/0
  533. @_async_test
  534. async def test_enter_async_context(self):
  535. class TestCM(object):
  536. async def __aenter__(self):
  537. result.append(1)
  538. async def __aexit__(self, *exc_details):
  539. result.append(3)
  540. result = []
  541. cm = TestCM()
  542. async with AsyncExitStack() as stack:
  543. @stack.push_async_callback # Registered first => cleaned up last
  544. async def _exit():
  545. result.append(4)
  546. self.assertIsNotNone(_exit)
  547. await stack.enter_async_context(cm)
  548. self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
  549. result.append(2)
  550. self.assertEqual(result, [1, 2, 3, 4])
  551. @_async_test
  552. async def test_enter_async_context_errors(self):
  553. class LacksEnterAndExit:
  554. pass
  555. class LacksEnter:
  556. async def __aexit__(self, *exc_info):
  557. pass
  558. class LacksExit:
  559. async def __aenter__(self):
  560. pass
  561. async with self.exit_stack() as stack:
  562. with self.assertRaisesRegex(TypeError, 'asynchronous context manager'):
  563. await stack.enter_async_context(LacksEnterAndExit())
  564. with self.assertRaisesRegex(TypeError, 'asynchronous context manager'):
  565. await stack.enter_async_context(LacksEnter())
  566. with self.assertRaisesRegex(TypeError, 'asynchronous context manager'):
  567. await stack.enter_async_context(LacksExit())
  568. self.assertFalse(stack._exit_callbacks)
  569. @_async_test
  570. async def test_async_exit_exception_chaining(self):
  571. # Ensure exception chaining matches the reference behaviour
  572. async def raise_exc(exc):
  573. raise exc
  574. saved_details = None
  575. async def suppress_exc(*exc_details):
  576. nonlocal saved_details
  577. saved_details = exc_details
  578. return True
  579. try:
  580. async with self.exit_stack() as stack:
  581. stack.push_async_callback(raise_exc, IndexError)
  582. stack.push_async_callback(raise_exc, KeyError)
  583. stack.push_async_callback(raise_exc, AttributeError)
  584. stack.push_async_exit(suppress_exc)
  585. stack.push_async_callback(raise_exc, ValueError)
  586. 1 / 0
  587. except IndexError as exc:
  588. self.assertIsInstance(exc.__context__, KeyError)
  589. self.assertIsInstance(exc.__context__.__context__, AttributeError)
  590. # Inner exceptions were suppressed
  591. self.assertIsNone(exc.__context__.__context__.__context__)
  592. else:
  593. self.fail("Expected IndexError, but no exception was raised")
  594. # Check the inner exceptions
  595. inner_exc = saved_details[1]
  596. self.assertIsInstance(inner_exc, ValueError)
  597. self.assertIsInstance(inner_exc.__context__, ZeroDivisionError)
  598. @_async_test
  599. async def test_async_exit_exception_explicit_none_context(self):
  600. # Ensure AsyncExitStack chaining matches actual nested `with` statements
  601. # regarding explicit __context__ = None.
  602. class MyException(Exception):
  603. pass
  604. @asynccontextmanager
  605. async def my_cm():
  606. try:
  607. yield
  608. except BaseException:
  609. exc = MyException()
  610. try:
  611. raise exc
  612. finally:
  613. exc.__context__ = None
  614. @asynccontextmanager
  615. async def my_cm_with_exit_stack():
  616. async with self.exit_stack() as stack:
  617. await stack.enter_async_context(my_cm())
  618. yield stack
  619. for cm in (my_cm, my_cm_with_exit_stack):
  620. with self.subTest():
  621. try:
  622. async with cm():
  623. raise IndexError()
  624. except MyException as exc:
  625. self.assertIsNone(exc.__context__)
  626. else:
  627. self.fail("Expected IndexError, but no exception was raised")
  628. @_async_test
  629. async def test_instance_bypass_async(self):
  630. class Example(object): pass
  631. cm = Example()
  632. cm.__aenter__ = object()
  633. cm.__aexit__ = object()
  634. stack = self.exit_stack()
  635. with self.assertRaisesRegex(TypeError, 'asynchronous context manager'):
  636. await stack.enter_async_context(cm)
  637. stack.push_async_exit(cm)
  638. self.assertIs(stack._exit_callbacks[-1][1], cm)
  639. class TestAsyncNullcontext(unittest.TestCase):
  640. @_async_test
  641. async def test_async_nullcontext(self):
  642. class C:
  643. pass
  644. c = C()
  645. async with nullcontext(c) as c_in:
  646. self.assertIs(c_in, c)
  647. if __name__ == '__main__':
  648. unittest.main()