| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770 |
import asyncio
import functools
import traceback
import unittest
from contextlib import (
    AbstractAsyncContextManager,
    AsyncExitStack,
    aclosing,
    asynccontextmanager,
    contextmanager,
    nullcontext,
)

from test import support
from test.test_contextlib import TestBaseExitStack

# Skip this whole module when the platform has no working sockets; the
# ``module=True`` form performs the check once, at import time, instead of
# per test.  NOTE(review): presumably required because the asyncio event
# loops used throughout this module depend on sockets -- confirm.
support.requires_working_socket(module=True)
def _async_test(func):
    """Adapt the coroutine function *func* into a plain synchronous callable.

    Calling the returned wrapper builds the coroutine from the supplied
    arguments and drives it to completion with ``asyncio.run``.  The
    coroutine's own return value is discarded, so the wrapper always returns
    ``None``.  ``functools.wraps`` copies *func*'s name and docstring onto
    the wrapper, so it can stand in for *func* as a test method.
    """
    @functools.wraps(func)
    def run_to_completion(*args, **kwargs):
        asyncio.run(func(*args, **kwargs))

    return run_to_completion
def tearDownModule():
    """unittest module-level fixture, run once after this module's tests.

    Resets the global asyncio event loop policy; passing ``None`` selects
    the default policy again.
    """
    default_policy = None
    asyncio.set_event_loop_policy(default_policy)
class TestAbstractAsyncContextManager(unittest.TestCase):
    """Tests for the ``contextlib.AbstractAsyncContextManager`` ABC."""

    @_async_test
    async def test_enter(self):
        # A subclass defining only __aexit__ inherits an __aenter__ that
        # hands back the manager itself.
        class DefaultEnter(AbstractAsyncContextManager):
            async def __aexit__(self, *args):
                await super().__aexit__(*args)

        cm = DefaultEnter()
        entered = await cm.__aenter__()
        self.assertIs(entered, cm)
        async with cm as bound:
            self.assertIs(cm, bound)

    @_async_test
    async def test_async_gen_propagates_generator_exit(self):
        # A regression test for https://bugs.python.org/issue33786.
        @asynccontextmanager
        async def ctx():
            yield

        async def gen():
            async with ctx():
                yield 11

        seen = []
        error = ValueError(22)
        with self.assertRaises(ValueError):
            async with ctx():
                async for item in gen():
                    seen.append(item)
                    raise error
        self.assertEqual(seen, [11])

    def test_exit_is_abstract(self):
        # Without an __aexit__ the ABC refuses to instantiate.
        class NoAexit(AbstractAsyncContextManager):
            pass

        with self.assertRaises(TypeError):
            NoAexit()

    def test_structural_subclassing(self):
        class ManagerFromScratch:
            async def __aenter__(self):
                return self

            async def __aexit__(self, exc_type, exc_value, traceback):
                return None

        class DefaultEnter(AbstractAsyncContextManager):
            async def __aexit__(self, *args):
                await super().__aexit__(*args)

        # Any class providing both methods counts as a subclass, whether or
        # not it inherits from the ABC...
        for conforming in (ManagerFromScratch, DefaultEnter):
            self.assertTrue(
                issubclass(conforming, AbstractAsyncContextManager))

        # ...but setting either method to None opts the class back out.
        class NoneAenter(ManagerFromScratch):
            __aenter__ = None

        class NoneAexit(ManagerFromScratch):
            __aexit__ = None

        for opted_out in (NoneAenter, NoneAexit):
            self.assertFalse(
                issubclass(opted_out, AbstractAsyncContextManager))
class AsyncContextManagerTestCase(unittest.TestCase):
    """Tests for async context managers built with ``asynccontextmanager``.

    ``async def`` test methods are wrapped with ``_async_test`` so unittest
    can call them synchronously.
    """

    @_async_test
    async def test_contextmanager_plain(self):
        # Code before the yield runs on entry, the yielded value is bound by
        # ``as``, and code after the yield runs on a clean exit.
        state = []

        @asynccontextmanager
        async def woohoo():
            state.append(1)
            yield 42
            state.append(999)

        async with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
        self.assertEqual(state, [1, 42, 999])

    @_async_test
    async def test_contextmanager_finally(self):
        # A ``finally`` around the yield runs even when the body raises, and
        # the exception still propagates to the caller.
        state = []

        @asynccontextmanager
        async def woohoo():
            state.append(1)
            try:
                yield 42
            finally:
                state.append(999)

        with self.assertRaises(ZeroDivisionError):
            async with woohoo() as x:
                self.assertEqual(state, [1])
                self.assertEqual(x, 42)
                state.append(x)
                raise ZeroDivisionError()
        self.assertEqual(state, [1, 42, 999])

    @_async_test
    async def test_contextmanager_traceback(self):
        # An exception passing through the manager must carry exactly one
        # traceback entry: the line in this test that raised it.  The
        # asserted ``line`` values are compared with this method's source
        # text, so the raising statements below must stay textually as-is.
        @asynccontextmanager
        async def f():
            yield

        try:
            async with f():
                1/0
        except ZeroDivisionError as e:
            frames = traceback.extract_tb(e.__traceback__)
        else:
            # Without this branch a suppressed exception would surface as an
            # UnboundLocalError on ``frames`` rather than a clear failure.
            self.fail('ZeroDivisionError was suppressed')

        self.assertEqual(len(frames), 1)
        self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
        self.assertEqual(frames[0].line, '1/0')

        # Repeat with RuntimeError (which goes through a different code path)
        class RuntimeErrorSubclass(RuntimeError):
            pass

        try:
            async with f():
                raise RuntimeErrorSubclass(42)
        except RuntimeErrorSubclass as e:
            frames = traceback.extract_tb(e.__traceback__)
        else:
            self.fail('RuntimeErrorSubclass was suppressed')

        self.assertEqual(len(frames), 1)
        self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
        self.assertEqual(frames[0].line, 'raise RuntimeErrorSubclass(42)')

        class StopIterationSubclass(StopIteration):
            pass

        class StopAsyncIterationSubclass(StopAsyncIteration):
            pass

        for stop_exc in (
            StopIteration('spam'),
            StopAsyncIteration('ham'),
            StopIterationSubclass('spam'),
            StopAsyncIterationSubclass('spam')
        ):
            with self.subTest(type=type(stop_exc)):
                try:
                    async with f():
                        raise stop_exc
                except type(stop_exc) as e:
                    self.assertIs(e, stop_exc)
                    frames = traceback.extract_tb(e.__traceback__)
                else:
                    self.fail(f'{stop_exc} was suppressed')

                self.assertEqual(len(frames), 1)
                self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
                self.assertEqual(frames[0].line, 'raise stop_exc')

    @_async_test
    async def test_contextmanager_no_reraise(self):
        @asynccontextmanager
        async def whee():
            yield
        ctx = whee()
        await ctx.__aenter__()
        # __aexit__ must not raise the exception itself; it returns a false
        # value, leaving re-raising to the caller.
        self.assertFalse(await ctx.__aexit__(TypeError, TypeError("foo"), None))

    @_async_test
    async def test_contextmanager_trap_yield_after_throw(self):
        # Yielding again after an exception is thrown into the generator is
        # reported as RuntimeError.
        @asynccontextmanager
        async def whoo():
            try:
                yield
            except:  # deliberately bare: swallow whatever is thrown in
                yield
        ctx = whoo()
        await ctx.__aenter__()
        with self.assertRaises(RuntimeError):
            await ctx.__aexit__(TypeError, TypeError('foo'), None)

    @_async_test
    async def test_contextmanager_trap_no_yield(self):
        # A generator that finishes without yielding cannot be entered.
        @asynccontextmanager
        async def whoo():
            if False:
                yield
        ctx = whoo()
        with self.assertRaises(RuntimeError):
            await ctx.__aenter__()

    @_async_test
    async def test_contextmanager_trap_second_yield(self):
        # A second yield on a clean exit is reported as RuntimeError.
        @asynccontextmanager
        async def whoo():
            yield
            yield
        ctx = whoo()
        await ctx.__aenter__()
        with self.assertRaises(RuntimeError):
            await ctx.__aexit__(None, None, None)

    @_async_test
    async def test_contextmanager_non_normalised(self):
        # __aexit__ is given only an exception type (value None); the
        # generator still catches a RuntimeError and its SyntaxError escapes.
        @asynccontextmanager
        async def whoo():
            try:
                yield
            except RuntimeError:
                raise SyntaxError
        ctx = whoo()
        await ctx.__aenter__()
        with self.assertRaises(SyntaxError):
            await ctx.__aexit__(RuntimeError, None, None)

    @_async_test
    async def test_contextmanager_except(self):
        # An exception caught inside the generator is suppressed.
        state = []

        @asynccontextmanager
        async def woohoo():
            state.append(1)
            try:
                yield 42
            except ZeroDivisionError as e:
                state.append(e.args[0])
                self.assertEqual(state, [1, 42, 999])

        async with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
            raise ZeroDivisionError(999)
        self.assertEqual(state, [1, 42, 999])

    @_async_test
    async def test_contextmanager_except_stopiter(self):
        # StopIteration / StopAsyncIteration (and subclasses) raised in the
        # body propagate as the very same object instead of being swallowed.
        @asynccontextmanager
        async def woohoo():
            yield

        class StopIterationSubclass(StopIteration):
            pass

        class StopAsyncIterationSubclass(StopAsyncIteration):
            pass

        for stop_exc in (
            StopIteration('spam'),
            StopAsyncIteration('ham'),
            StopIterationSubclass('spam'),
            StopAsyncIterationSubclass('spam')
        ):
            with self.subTest(type=type(stop_exc)):
                try:
                    async with woohoo():
                        raise stop_exc
                except Exception as ex:
                    self.assertIs(ex, stop_exc)
                else:
                    self.fail(f'{stop_exc} was suppressed')

    @_async_test
    async def test_contextmanager_wrap_runtimeerror(self):
        @asynccontextmanager
        async def woohoo():
            try:
                yield
            except Exception as exc:
                raise RuntimeError(f'caught {exc}') from exc

        with self.assertRaises(RuntimeError):
            async with woohoo():
                1 / 0

        # If the context manager wrapped StopAsyncIteration in a RuntimeError,
        # we also unwrap it, because we can't tell whether the wrapping was
        # done by the generator machinery or by the generator itself.
        with self.assertRaises(StopAsyncIteration):
            async with woohoo():
                raise StopAsyncIteration

    def _create_contextmanager_attribs(self):
        # Build an asynccontextmanager-decorated function carrying a custom
        # attribute (foo='bar') and a docstring, to check both are exposed.
        def attribs(**kw):
            def decorate(func):
                for k, v in kw.items():
                    setattr(func, k, v)
                return func
            return decorate

        @asynccontextmanager
        @attribs(foo='bar')
        async def baz(spam):
            """Whee!"""
            yield
        return baz

    def test_contextmanager_attribs(self):
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__name__, 'baz')
        self.assertEqual(baz.foo, 'bar')

    @support.requires_docstrings
    def test_contextmanager_doc_attrib(self):
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__doc__, "Whee!")

    @support.requires_docstrings
    @_async_test
    async def test_instance_docstring_given_cm_docstring(self):
        baz = self._create_contextmanager_attribs()(None)
        self.assertEqual(baz.__doc__, "Whee!")
        async with baz:
            pass  # suppress warning

    @_async_test
    async def test_keywords(self):
        # Ensure no keyword arguments are inhibited
        @asynccontextmanager
        async def woohoo(self, func, args, kwds):
            yield (self, func, args, kwds)
        async with woohoo(self=11, func=22, args=33, kwds=44) as target:
            self.assertEqual(target, (11, 22, 33, 44))

    @_async_test
    async def test_recursive(self):
        # A manager used as a decorator is re-entered on every recursive
        # call (ncols counts entries) and unwinds back to depth 0.
        depth = 0
        ncols = 0

        @asynccontextmanager
        async def woohoo():
            nonlocal ncols
            ncols += 1
            nonlocal depth
            before = depth
            depth += 1
            yield
            depth -= 1
            self.assertEqual(depth, before)

        @woohoo()
        async def recursive():
            if depth < 10:
                await recursive()

        await recursive()
        self.assertEqual(ncols, 10)
        self.assertEqual(depth, 0)

    @_async_test
    async def test_decorator(self):
        entered = False

        @asynccontextmanager
        async def context():
            nonlocal entered
            entered = True
            yield
            entered = False

        @context()
        async def test():
            self.assertTrue(entered)

        # Decorating does not enter the manager; only calling does.
        self.assertFalse(entered)
        await test()
        self.assertFalse(entered)

    @_async_test
    async def test_decorator_with_exception(self):
        entered = False

        @asynccontextmanager
        async def context():
            nonlocal entered
            try:
                entered = True
                yield
            finally:
                entered = False

        @context()
        async def test():
            self.assertTrue(entered)
            raise NameError('foo')

        self.assertFalse(entered)
        with self.assertRaisesRegex(NameError, 'foo'):
            await test()
        self.assertFalse(entered)

    @_async_test
    async def test_decorating_method(self):
        @asynccontextmanager
        async def context():
            yield

        class Test(object):

            @context()
            async def method(self, a, b, c=None):
                self.a = a
                self.b = b
                self.c = c

        # these tests are for argument passing when used as a decorator
        test = Test()
        await test.method(1, 2)
        self.assertEqual(test.a, 1)
        self.assertEqual(test.b, 2)
        self.assertEqual(test.c, None)

        test = Test()
        await test.method('a', 'b', 'c')
        self.assertEqual(test.a, 'a')
        self.assertEqual(test.b, 'b')
        self.assertEqual(test.c, 'c')

        test = Test()
        await test.method(a=1, b=2)
        self.assertEqual(test.a, 1)
        self.assertEqual(test.b, 2)
        # The omitted keyword argument must still take its default.
        self.assertEqual(test.c, None)
class AclosingTestCase(unittest.TestCase):
    """Tests for ``contextlib.aclosing``."""

    @support.requires_docstrings
    def test_instance_docs(self):
        # Instances expose the class docstring.
        instance = aclosing(None)
        self.assertEqual(instance.__doc__, aclosing.__doc__)

    @_async_test
    async def test_aclosing(self):
        # aclose() runs exactly once, when the block exits normally.
        closed = []

        class Resource:
            async def aclose(self):
                closed.append(1)

        resource = Resource()
        self.assertEqual(closed, [])
        async with aclosing(resource) as bound:
            self.assertEqual(resource, bound)
        self.assertEqual(closed, [1])

    @_async_test
    async def test_aclosing_error(self):
        # aclose() still runs when the block raises, and the error escapes.
        closed = []

        class Resource:
            async def aclose(self):
                closed.append(1)

        resource = Resource()
        self.assertEqual(closed, [])
        with self.assertRaises(ZeroDivisionError):
            async with aclosing(resource) as bound:
                self.assertEqual(resource, bound)
                1 / 0
        self.assertEqual(closed, [1])

    @_async_test
    async def test_aclosing_bpo41229(self):
        # Regression test for bpo-41229: closing an async generator that is
        # suspended inside a synchronous context manager must run that
        # manager's cleanup.
        closed = []

        @contextmanager
        def sync_resource():
            try:
                yield
            finally:
                closed.append(1)

        async def agenfunc():
            with sync_resource():
                yield -1
                yield -2

        agen = agenfunc()
        self.assertEqual(closed, [])
        with self.assertRaises(ZeroDivisionError):
            async with aclosing(agen) as bound:
                self.assertEqual(agen, bound)
                self.assertEqual(-1, await agen.__anext__())
                1 / 0
        self.assertEqual(closed, [1])
class TestAsyncExitStack(TestBaseExitStack, unittest.TestCase):
    """Run the shared ExitStack tests against AsyncExitStack, plus async-only tests."""

    class SyncAsyncExitStack(AsyncExitStack):
        """AsyncExitStack with a synchronous facade for the inherited tests.

        Each synchronous entry point drives the matching coroutine to
        completion on the current event loop (the one ``setUp`` installs).
        The ``raise exc`` lines and the ``__exit__`` return line are listed
        verbatim in ``callback_error_internal_frames`` below, so their text
        must not change.
        """

        @staticmethod
        def run_coroutine(coro):
            # Run *coro* as a task, stopping the loop once it is done, then
            # return its result or re-raise its exception.
            loop = asyncio.get_event_loop_policy().get_event_loop()
            t = loop.create_task(coro)
            t.add_done_callback(lambda f: loop.stop())
            loop.run_forever()

            exc = t.exception()
            if not exc:
                return t.result()
            else:
                # Raising ``exc`` here could overwrite its __context__ through
                # implicit chaining with an exception currently being handled;
                # save the original context and restore it before re-raising.
                context = exc.__context__

                try:
                    raise exc
                except:  # deliberately bare: always re-raised below
                    exc.__context__ = context
                    raise exc

        def close(self):
            return self.run_coroutine(self.aclose())

        def __enter__(self):
            return self.run_coroutine(self.__aenter__())

        def __exit__(self, *exc_details):
            return self.run_coroutine(self.__aexit__(*exc_details))

    exit_stack = SyncAsyncExitStack
    # (function name, source line) traceback entries contributed by the sync
    # facade and AsyncExitStack internals.  NOTE(review): presumably consumed
    # by a test inherited from TestBaseExitStack -- confirm there.
    callback_error_internal_frames = [
        ('__exit__', 'return self.run_coroutine(self.__aexit__(*exc_details))'),
        ('run_coroutine', 'raise exc'),
        ('run_coroutine', 'raise exc'),
        ('__aexit__', 'raise exc_details[1]'),
        ('__aexit__', 'cb_suppress = cb(*exc_details)'),
    ]

    def setUp(self):
        # Give each test its own loop for SyncAsyncExitStack.run_coroutine.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.addCleanup(self.loop.close)
        self.addCleanup(asyncio.set_event_loop_policy, None)

    @_async_test
    async def test_async_callback(self):
        expected = [
            ((), {}),
            ((1,), {}),
            ((1,2), {}),
            ((), dict(example=1)),
            ((1,), dict(example=1)),
            ((1,2), dict(example=1)),
        ]
        result = []
        async def _exit(*args, **kwds):
            """Test metadata propagation"""
            result.append((args, kwds))

        async with AsyncExitStack() as stack:
            for args, kwds in reversed(expected):
                if args and kwds:
                    f = stack.push_async_callback(_exit, *args, **kwds)
                elif args:
                    f = stack.push_async_callback(_exit, *args)
                elif kwds:
                    f = stack.push_async_callback(_exit, **kwds)
                else:
                    f = stack.push_async_callback(_exit)
                # The callback itself is returned, so it can be a decorator.
                self.assertIs(f, _exit)
            for wrapper in stack._exit_callbacks:
                self.assertIs(wrapper[1].__wrapped__, _exit)
                self.assertNotEqual(wrapper[1].__name__, _exit.__name__)
                self.assertIsNone(wrapper[1].__doc__)
        self.assertEqual(result, expected)

        result = []
        async with AsyncExitStack() as stack:
            with self.assertRaises(TypeError):
                stack.push_async_callback(arg=1)
            with self.assertRaises(TypeError):
                self.exit_stack.push_async_callback(arg=2)
            with self.assertRaises(TypeError):
                stack.push_async_callback(callback=_exit, arg=3)
        self.assertEqual(result, [])

    @_async_test
    async def test_async_push(self):
        exc_raised = ZeroDivisionError
        async def _expect_exc(exc_type, exc, exc_tb):
            self.assertIs(exc_type, exc_raised)
        async def _suppress_exc(*exc_details):
            return True
        async def _expect_ok(exc_type, exc, exc_tb):
            self.assertIsNone(exc_type)
            self.assertIsNone(exc)
            self.assertIsNone(exc_tb)
        class ExitCM(object):
            def __init__(self, check_exc):
                self.check_exc = check_exc
            async def __aenter__(self):
                # ``self`` is the ExitCM here, not the TestCase, so there is
                # no ``self.fail``; raise the assertion error directly.
                raise AssertionError("Should not be called!")
            async def __aexit__(self, *exc_details):
                await self.check_exc(*exc_details)

        # Callbacks run in reverse push order: the two _expect_exc callbacks
        # and the ExitCM see the ZeroDivisionError, _suppress_exc swallows
        # it, and the remaining ones see a clean exit.
        async with self.exit_stack() as stack:
            stack.push_async_exit(_expect_ok)
            self.assertIs(stack._exit_callbacks[-1][1], _expect_ok)
            cm = ExitCM(_expect_ok)
            stack.push_async_exit(cm)
            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
            stack.push_async_exit(_suppress_exc)
            self.assertIs(stack._exit_callbacks[-1][1], _suppress_exc)
            cm = ExitCM(_expect_exc)
            stack.push_async_exit(cm)
            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
            stack.push_async_exit(_expect_exc)
            self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
            stack.push_async_exit(_expect_exc)
            self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
            1/0

    @_async_test
    async def test_enter_async_context(self):
        class TestCM(object):
            async def __aenter__(self):
                result.append(1)
            async def __aexit__(self, *exc_details):
                result.append(3)

        result = []
        cm = TestCM()

        async with AsyncExitStack() as stack:
            @stack.push_async_callback  # Registered first => cleaned up last
            async def _exit():
                result.append(4)
            self.assertIsNotNone(_exit)
            await stack.enter_async_context(cm)
            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
            result.append(2)

        self.assertEqual(result, [1, 2, 3, 4])

    @_async_test
    async def test_enter_async_context_errors(self):
        class LacksEnterAndExit:
            pass
        class LacksEnter:
            async def __aexit__(self, *exc_info):
                pass
        class LacksExit:
            async def __aenter__(self):
                pass

        async with self.exit_stack() as stack:
            with self.assertRaisesRegex(TypeError, 'asynchronous context manager'):
                await stack.enter_async_context(LacksEnterAndExit())
            with self.assertRaisesRegex(TypeError, 'asynchronous context manager'):
                await stack.enter_async_context(LacksEnter())
            with self.assertRaisesRegex(TypeError, 'asynchronous context manager'):
                await stack.enter_async_context(LacksExit())
            # Rejected objects must not leave callbacks behind.
            self.assertFalse(stack._exit_callbacks)

    @_async_test
    async def test_async_exit_exception_chaining(self):
        # Ensure exception chaining matches the reference behaviour
        async def raise_exc(exc):
            raise exc

        saved_details = None
        async def suppress_exc(*exc_details):
            nonlocal saved_details
            saved_details = exc_details
            return True

        try:
            async with self.exit_stack() as stack:
                stack.push_async_callback(raise_exc, IndexError)
                stack.push_async_callback(raise_exc, KeyError)
                stack.push_async_callback(raise_exc, AttributeError)
                stack.push_async_exit(suppress_exc)
                stack.push_async_callback(raise_exc, ValueError)
                1 / 0
        except IndexError as exc:
            self.assertIsInstance(exc.__context__, KeyError)
            self.assertIsInstance(exc.__context__.__context__, AttributeError)
            # Inner exceptions were suppressed
            self.assertIsNone(exc.__context__.__context__.__context__)
        else:
            self.fail("Expected IndexError, but no exception was raised")
        # Check the inner exceptions
        inner_exc = saved_details[1]
        self.assertIsInstance(inner_exc, ValueError)
        self.assertIsInstance(inner_exc.__context__, ZeroDivisionError)

    @_async_test
    async def test_async_exit_exception_explicit_none_context(self):
        # Ensure AsyncExitStack chaining matches actual nested `with` statements
        # regarding explicit __context__ = None.

        class MyException(Exception):
            pass

        @asynccontextmanager
        async def my_cm():
            try:
                yield
            except BaseException:
                exc = MyException()
                try:
                    raise exc
                finally:
                    exc.__context__ = None

        @asynccontextmanager
        async def my_cm_with_exit_stack():
            async with self.exit_stack() as stack:
                await stack.enter_async_context(my_cm())
                yield stack

        for cm in (my_cm, my_cm_with_exit_stack):
            with self.subTest(cm=cm.__name__):
                try:
                    async with cm():
                        raise IndexError()
                except MyException as exc:
                    self.assertIsNone(exc.__context__)
                else:
                    self.fail("Expected MyException, but no exception was raised")

    @_async_test
    async def test_instance_bypass_async(self):
        # Special methods are looked up on the type, so instance attributes
        # named __aenter__/__aexit__ do not make a context manager; the
        # object is still accepted as a plain exit callback.
        class Example(object): pass
        cm = Example()
        cm.__aenter__ = object()
        cm.__aexit__ = object()
        stack = self.exit_stack()
        with self.assertRaisesRegex(TypeError, 'asynchronous context manager'):
            await stack.enter_async_context(cm)
        stack.push_async_exit(cm)
        self.assertIs(stack._exit_callbacks[-1][1], cm)
class TestAsyncNullcontext(unittest.TestCase):
    """Tests for using ``contextlib.nullcontext`` with ``async with``."""

    @_async_test
    async def test_async_nullcontext(self):
        # nullcontext(obj) hands obj back unchanged as the ``as`` target.
        class Sentinel:
            pass

        sentinel = Sentinel()
        async with nullcontext(sentinel) as received:
            self.assertIs(received, sentinel)
if __name__ == '__main__':
    # Executed directly rather than imported: hand control to unittest's
    # command-line runner for the tests defined in this module.
    unittest.main()
|