test_contextlib.py 39 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254
  1. """Unit tests for contextlib.py, and other context managers."""
  2. import io
  3. import os
  4. import sys
  5. import tempfile
  6. import threading
  7. import traceback
  8. import unittest
  9. from contextlib import * # Tests __all__
  10. from test import support
  11. from test.support import os_helper
  12. import weakref
  13. class TestAbstractContextManager(unittest.TestCase):
  14. def test_enter(self):
  15. class DefaultEnter(AbstractContextManager):
  16. def __exit__(self, *args):
  17. super().__exit__(*args)
  18. manager = DefaultEnter()
  19. self.assertIs(manager.__enter__(), manager)
  20. def test_exit_is_abstract(self):
  21. class MissingExit(AbstractContextManager):
  22. pass
  23. with self.assertRaises(TypeError):
  24. MissingExit()
  25. def test_structural_subclassing(self):
  26. class ManagerFromScratch:
  27. def __enter__(self):
  28. return self
  29. def __exit__(self, exc_type, exc_value, traceback):
  30. return None
  31. self.assertTrue(issubclass(ManagerFromScratch, AbstractContextManager))
  32. class DefaultEnter(AbstractContextManager):
  33. def __exit__(self, *args):
  34. super().__exit__(*args)
  35. self.assertTrue(issubclass(DefaultEnter, AbstractContextManager))
  36. class NoEnter(ManagerFromScratch):
  37. __enter__ = None
  38. self.assertFalse(issubclass(NoEnter, AbstractContextManager))
  39. class NoExit(ManagerFromScratch):
  40. __exit__ = None
  41. self.assertFalse(issubclass(NoExit, AbstractContextManager))
  42. class ContextManagerTestCase(unittest.TestCase):
  43. def test_contextmanager_plain(self):
  44. state = []
  45. @contextmanager
  46. def woohoo():
  47. state.append(1)
  48. yield 42
  49. state.append(999)
  50. with woohoo() as x:
  51. self.assertEqual(state, [1])
  52. self.assertEqual(x, 42)
  53. state.append(x)
  54. self.assertEqual(state, [1, 42, 999])
  55. def test_contextmanager_finally(self):
  56. state = []
  57. @contextmanager
  58. def woohoo():
  59. state.append(1)
  60. try:
  61. yield 42
  62. finally:
  63. state.append(999)
  64. with self.assertRaises(ZeroDivisionError):
  65. with woohoo() as x:
  66. self.assertEqual(state, [1])
  67. self.assertEqual(x, 42)
  68. state.append(x)
  69. raise ZeroDivisionError()
  70. self.assertEqual(state, [1, 42, 999])
  71. def test_contextmanager_traceback(self):
  72. @contextmanager
  73. def f():
  74. yield
  75. try:
  76. with f():
  77. 1/0
  78. except ZeroDivisionError as e:
  79. frames = traceback.extract_tb(e.__traceback__)
  80. self.assertEqual(len(frames), 1)
  81. self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
  82. self.assertEqual(frames[0].line, '1/0')
  83. # Repeat with RuntimeError (which goes through a different code path)
  84. class RuntimeErrorSubclass(RuntimeError):
  85. pass
  86. try:
  87. with f():
  88. raise RuntimeErrorSubclass(42)
  89. except RuntimeErrorSubclass as e:
  90. frames = traceback.extract_tb(e.__traceback__)
  91. self.assertEqual(len(frames), 1)
  92. self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
  93. self.assertEqual(frames[0].line, 'raise RuntimeErrorSubclass(42)')
  94. class StopIterationSubclass(StopIteration):
  95. pass
  96. for stop_exc in (
  97. StopIteration('spam'),
  98. StopIterationSubclass('spam'),
  99. ):
  100. with self.subTest(type=type(stop_exc)):
  101. try:
  102. with f():
  103. raise stop_exc
  104. except type(stop_exc) as e:
  105. self.assertIs(e, stop_exc)
  106. frames = traceback.extract_tb(e.__traceback__)
  107. else:
  108. self.fail(f'{stop_exc} was suppressed')
  109. self.assertEqual(len(frames), 1)
  110. self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
  111. self.assertEqual(frames[0].line, 'raise stop_exc')
  112. def test_contextmanager_no_reraise(self):
  113. @contextmanager
  114. def whee():
  115. yield
  116. ctx = whee()
  117. ctx.__enter__()
  118. # Calling __exit__ should not result in an exception
  119. self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None))
  120. def test_contextmanager_trap_yield_after_throw(self):
  121. @contextmanager
  122. def whoo():
  123. try:
  124. yield
  125. except:
  126. yield
  127. ctx = whoo()
  128. ctx.__enter__()
  129. self.assertRaises(
  130. RuntimeError, ctx.__exit__, TypeError, TypeError("foo"), None
  131. )
  132. def test_contextmanager_except(self):
  133. state = []
  134. @contextmanager
  135. def woohoo():
  136. state.append(1)
  137. try:
  138. yield 42
  139. except ZeroDivisionError as e:
  140. state.append(e.args[0])
  141. self.assertEqual(state, [1, 42, 999])
  142. with woohoo() as x:
  143. self.assertEqual(state, [1])
  144. self.assertEqual(x, 42)
  145. state.append(x)
  146. raise ZeroDivisionError(999)
  147. self.assertEqual(state, [1, 42, 999])
  148. def test_contextmanager_except_stopiter(self):
  149. @contextmanager
  150. def woohoo():
  151. yield
  152. class StopIterationSubclass(StopIteration):
  153. pass
  154. for stop_exc in (StopIteration('spam'), StopIterationSubclass('spam')):
  155. with self.subTest(type=type(stop_exc)):
  156. try:
  157. with woohoo():
  158. raise stop_exc
  159. except Exception as ex:
  160. self.assertIs(ex, stop_exc)
  161. else:
  162. self.fail(f'{stop_exc} was suppressed')
  163. def test_contextmanager_except_pep479(self):
  164. code = """\
  165. from __future__ import generator_stop
  166. from contextlib import contextmanager
  167. @contextmanager
  168. def woohoo():
  169. yield
  170. """
  171. locals = {}
  172. exec(code, locals, locals)
  173. woohoo = locals['woohoo']
  174. stop_exc = StopIteration('spam')
  175. try:
  176. with woohoo():
  177. raise stop_exc
  178. except Exception as ex:
  179. self.assertIs(ex, stop_exc)
  180. else:
  181. self.fail('StopIteration was suppressed')
  182. def test_contextmanager_do_not_unchain_non_stopiteration_exceptions(self):
  183. @contextmanager
  184. def test_issue29692():
  185. try:
  186. yield
  187. except Exception as exc:
  188. raise RuntimeError('issue29692:Chained') from exc
  189. try:
  190. with test_issue29692():
  191. raise ZeroDivisionError
  192. except Exception as ex:
  193. self.assertIs(type(ex), RuntimeError)
  194. self.assertEqual(ex.args[0], 'issue29692:Chained')
  195. self.assertIsInstance(ex.__cause__, ZeroDivisionError)
  196. try:
  197. with test_issue29692():
  198. raise StopIteration('issue29692:Unchained')
  199. except Exception as ex:
  200. self.assertIs(type(ex), StopIteration)
  201. self.assertEqual(ex.args[0], 'issue29692:Unchained')
  202. self.assertIsNone(ex.__cause__)
  203. def _create_contextmanager_attribs(self):
  204. def attribs(**kw):
  205. def decorate(func):
  206. for k,v in kw.items():
  207. setattr(func,k,v)
  208. return func
  209. return decorate
  210. @contextmanager
  211. @attribs(foo='bar')
  212. def baz(spam):
  213. """Whee!"""
  214. return baz
  215. def test_contextmanager_attribs(self):
  216. baz = self._create_contextmanager_attribs()
  217. self.assertEqual(baz.__name__,'baz')
  218. self.assertEqual(baz.foo, 'bar')
  219. @support.requires_docstrings
  220. def test_contextmanager_doc_attrib(self):
  221. baz = self._create_contextmanager_attribs()
  222. self.assertEqual(baz.__doc__, "Whee!")
  223. @support.requires_docstrings
  224. def test_instance_docstring_given_cm_docstring(self):
  225. baz = self._create_contextmanager_attribs()(None)
  226. self.assertEqual(baz.__doc__, "Whee!")
  227. def test_keywords(self):
  228. # Ensure no keyword arguments are inhibited
  229. @contextmanager
  230. def woohoo(self, func, args, kwds):
  231. yield (self, func, args, kwds)
  232. with woohoo(self=11, func=22, args=33, kwds=44) as target:
  233. self.assertEqual(target, (11, 22, 33, 44))
  234. def test_nokeepref(self):
  235. class A:
  236. pass
  237. @contextmanager
  238. def woohoo(a, b):
  239. a = weakref.ref(a)
  240. b = weakref.ref(b)
  241. # Allow test to work with a non-refcounted GC
  242. support.gc_collect()
  243. self.assertIsNone(a())
  244. self.assertIsNone(b())
  245. yield
  246. with woohoo(A(), b=A()):
  247. pass
  248. def test_param_errors(self):
  249. @contextmanager
  250. def woohoo(a, *, b):
  251. yield
  252. with self.assertRaises(TypeError):
  253. woohoo()
  254. with self.assertRaises(TypeError):
  255. woohoo(3, 5)
  256. with self.assertRaises(TypeError):
  257. woohoo(b=3)
  258. def test_recursive(self):
  259. depth = 0
  260. @contextmanager
  261. def woohoo():
  262. nonlocal depth
  263. before = depth
  264. depth += 1
  265. yield
  266. depth -= 1
  267. self.assertEqual(depth, before)
  268. @woohoo()
  269. def recursive():
  270. if depth < 10:
  271. recursive()
  272. recursive()
  273. self.assertEqual(depth, 0)
  274. class ClosingTestCase(unittest.TestCase):
  275. @support.requires_docstrings
  276. def test_instance_docs(self):
  277. # Issue 19330: ensure context manager instances have good docstrings
  278. cm_docstring = closing.__doc__
  279. obj = closing(None)
  280. self.assertEqual(obj.__doc__, cm_docstring)
  281. def test_closing(self):
  282. state = []
  283. class C:
  284. def close(self):
  285. state.append(1)
  286. x = C()
  287. self.assertEqual(state, [])
  288. with closing(x) as y:
  289. self.assertEqual(x, y)
  290. self.assertEqual(state, [1])
  291. def test_closing_error(self):
  292. state = []
  293. class C:
  294. def close(self):
  295. state.append(1)
  296. x = C()
  297. self.assertEqual(state, [])
  298. with self.assertRaises(ZeroDivisionError):
  299. with closing(x) as y:
  300. self.assertEqual(x, y)
  301. 1 / 0
  302. self.assertEqual(state, [1])
  303. class NullcontextTestCase(unittest.TestCase):
  304. def test_nullcontext(self):
  305. class C:
  306. pass
  307. c = C()
  308. with nullcontext(c) as c_in:
  309. self.assertIs(c_in, c)
  310. class FileContextTestCase(unittest.TestCase):
  311. def testWithOpen(self):
  312. tfn = tempfile.mktemp()
  313. try:
  314. f = None
  315. with open(tfn, "w", encoding="utf-8") as f:
  316. self.assertFalse(f.closed)
  317. f.write("Booh\n")
  318. self.assertTrue(f.closed)
  319. f = None
  320. with self.assertRaises(ZeroDivisionError):
  321. with open(tfn, "r", encoding="utf-8") as f:
  322. self.assertFalse(f.closed)
  323. self.assertEqual(f.read(), "Booh\n")
  324. 1 / 0
  325. self.assertTrue(f.closed)
  326. finally:
  327. os_helper.unlink(tfn)
  328. class LockContextTestCase(unittest.TestCase):
  329. def boilerPlate(self, lock, locked):
  330. self.assertFalse(locked())
  331. with lock:
  332. self.assertTrue(locked())
  333. self.assertFalse(locked())
  334. with self.assertRaises(ZeroDivisionError):
  335. with lock:
  336. self.assertTrue(locked())
  337. 1 / 0
  338. self.assertFalse(locked())
  339. def testWithLock(self):
  340. lock = threading.Lock()
  341. self.boilerPlate(lock, lock.locked)
  342. def testWithRLock(self):
  343. lock = threading.RLock()
  344. self.boilerPlate(lock, lock._is_owned)
  345. def testWithCondition(self):
  346. lock = threading.Condition()
  347. def locked():
  348. return lock._is_owned()
  349. self.boilerPlate(lock, locked)
  350. def testWithSemaphore(self):
  351. lock = threading.Semaphore()
  352. def locked():
  353. if lock.acquire(False):
  354. lock.release()
  355. return False
  356. else:
  357. return True
  358. self.boilerPlate(lock, locked)
  359. def testWithBoundedSemaphore(self):
  360. lock = threading.BoundedSemaphore()
  361. def locked():
  362. if lock.acquire(False):
  363. lock.release()
  364. return False
  365. else:
  366. return True
  367. self.boilerPlate(lock, locked)
  368. class mycontext(ContextDecorator):
  369. """Example decoration-compatible context manager for testing"""
  370. started = False
  371. exc = None
  372. catch = False
  373. def __enter__(self):
  374. self.started = True
  375. return self
  376. def __exit__(self, *exc):
  377. self.exc = exc
  378. return self.catch
  379. class TestContextDecorator(unittest.TestCase):
  380. @support.requires_docstrings
  381. def test_instance_docs(self):
  382. # Issue 19330: ensure context manager instances have good docstrings
  383. cm_docstring = mycontext.__doc__
  384. obj = mycontext()
  385. self.assertEqual(obj.__doc__, cm_docstring)
  386. def test_contextdecorator(self):
  387. context = mycontext()
  388. with context as result:
  389. self.assertIs(result, context)
  390. self.assertTrue(context.started)
  391. self.assertEqual(context.exc, (None, None, None))
  392. def test_contextdecorator_with_exception(self):
  393. context = mycontext()
  394. with self.assertRaisesRegex(NameError, 'foo'):
  395. with context:
  396. raise NameError('foo')
  397. self.assertIsNotNone(context.exc)
  398. self.assertIs(context.exc[0], NameError)
  399. context = mycontext()
  400. context.catch = True
  401. with context:
  402. raise NameError('foo')
  403. self.assertIsNotNone(context.exc)
  404. self.assertIs(context.exc[0], NameError)
  405. def test_decorator(self):
  406. context = mycontext()
  407. @context
  408. def test():
  409. self.assertIsNone(context.exc)
  410. self.assertTrue(context.started)
  411. test()
  412. self.assertEqual(context.exc, (None, None, None))
  413. def test_decorator_with_exception(self):
  414. context = mycontext()
  415. @context
  416. def test():
  417. self.assertIsNone(context.exc)
  418. self.assertTrue(context.started)
  419. raise NameError('foo')
  420. with self.assertRaisesRegex(NameError, 'foo'):
  421. test()
  422. self.assertIsNotNone(context.exc)
  423. self.assertIs(context.exc[0], NameError)
  424. def test_decorating_method(self):
  425. context = mycontext()
  426. class Test(object):
  427. @context
  428. def method(self, a, b, c=None):
  429. self.a = a
  430. self.b = b
  431. self.c = c
  432. # these tests are for argument passing when used as a decorator
  433. test = Test()
  434. test.method(1, 2)
  435. self.assertEqual(test.a, 1)
  436. self.assertEqual(test.b, 2)
  437. self.assertEqual(test.c, None)
  438. test = Test()
  439. test.method('a', 'b', 'c')
  440. self.assertEqual(test.a, 'a')
  441. self.assertEqual(test.b, 'b')
  442. self.assertEqual(test.c, 'c')
  443. test = Test()
  444. test.method(a=1, b=2)
  445. self.assertEqual(test.a, 1)
  446. self.assertEqual(test.b, 2)
  447. def test_typo_enter(self):
  448. class mycontext(ContextDecorator):
  449. def __unter__(self):
  450. pass
  451. def __exit__(self, *exc):
  452. pass
  453. with self.assertRaisesRegex(TypeError, 'the context manager'):
  454. with mycontext():
  455. pass
  456. def test_typo_exit(self):
  457. class mycontext(ContextDecorator):
  458. def __enter__(self):
  459. pass
  460. def __uxit__(self, *exc):
  461. pass
  462. with self.assertRaisesRegex(TypeError, 'the context manager.*__exit__'):
  463. with mycontext():
  464. pass
  465. def test_contextdecorator_as_mixin(self):
  466. class somecontext(object):
  467. started = False
  468. exc = None
  469. def __enter__(self):
  470. self.started = True
  471. return self
  472. def __exit__(self, *exc):
  473. self.exc = exc
  474. class mycontext(somecontext, ContextDecorator):
  475. pass
  476. context = mycontext()
  477. @context
  478. def test():
  479. self.assertIsNone(context.exc)
  480. self.assertTrue(context.started)
  481. test()
  482. self.assertEqual(context.exc, (None, None, None))
  483. def test_contextmanager_as_decorator(self):
  484. @contextmanager
  485. def woohoo(y):
  486. state.append(y)
  487. yield
  488. state.append(999)
  489. state = []
  490. @woohoo(1)
  491. def test(x):
  492. self.assertEqual(state, [1])
  493. state.append(x)
  494. test('something')
  495. self.assertEqual(state, [1, 'something', 999])
  496. # Issue #11647: Ensure the decorated function is 'reusable'
  497. state = []
  498. test('something else')
  499. self.assertEqual(state, [1, 'something else', 999])
  500. class TestBaseExitStack:
  501. exit_stack = None
  502. @support.requires_docstrings
  503. def test_instance_docs(self):
  504. # Issue 19330: ensure context manager instances have good docstrings
  505. cm_docstring = self.exit_stack.__doc__
  506. obj = self.exit_stack()
  507. self.assertEqual(obj.__doc__, cm_docstring)
  508. def test_no_resources(self):
  509. with self.exit_stack():
  510. pass
  511. def test_callback(self):
  512. expected = [
  513. ((), {}),
  514. ((1,), {}),
  515. ((1,2), {}),
  516. ((), dict(example=1)),
  517. ((1,), dict(example=1)),
  518. ((1,2), dict(example=1)),
  519. ((1,2), dict(self=3, callback=4)),
  520. ]
  521. result = []
  522. def _exit(*args, **kwds):
  523. """Test metadata propagation"""
  524. result.append((args, kwds))
  525. with self.exit_stack() as stack:
  526. for args, kwds in reversed(expected):
  527. if args and kwds:
  528. f = stack.callback(_exit, *args, **kwds)
  529. elif args:
  530. f = stack.callback(_exit, *args)
  531. elif kwds:
  532. f = stack.callback(_exit, **kwds)
  533. else:
  534. f = stack.callback(_exit)
  535. self.assertIs(f, _exit)
  536. for wrapper in stack._exit_callbacks:
  537. self.assertIs(wrapper[1].__wrapped__, _exit)
  538. self.assertNotEqual(wrapper[1].__name__, _exit.__name__)
  539. self.assertIsNone(wrapper[1].__doc__, _exit.__doc__)
  540. self.assertEqual(result, expected)
  541. result = []
  542. with self.exit_stack() as stack:
  543. with self.assertRaises(TypeError):
  544. stack.callback(arg=1)
  545. with self.assertRaises(TypeError):
  546. self.exit_stack.callback(arg=2)
  547. with self.assertRaises(TypeError):
  548. stack.callback(callback=_exit, arg=3)
  549. self.assertEqual(result, [])
  550. def test_push(self):
  551. exc_raised = ZeroDivisionError
  552. def _expect_exc(exc_type, exc, exc_tb):
  553. self.assertIs(exc_type, exc_raised)
  554. def _suppress_exc(*exc_details):
  555. return True
  556. def _expect_ok(exc_type, exc, exc_tb):
  557. self.assertIsNone(exc_type)
  558. self.assertIsNone(exc)
  559. self.assertIsNone(exc_tb)
  560. class ExitCM(object):
  561. def __init__(self, check_exc):
  562. self.check_exc = check_exc
  563. def __enter__(self):
  564. self.fail("Should not be called!")
  565. def __exit__(self, *exc_details):
  566. self.check_exc(*exc_details)
  567. with self.exit_stack() as stack:
  568. stack.push(_expect_ok)
  569. self.assertIs(stack._exit_callbacks[-1][1], _expect_ok)
  570. cm = ExitCM(_expect_ok)
  571. stack.push(cm)
  572. self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
  573. stack.push(_suppress_exc)
  574. self.assertIs(stack._exit_callbacks[-1][1], _suppress_exc)
  575. cm = ExitCM(_expect_exc)
  576. stack.push(cm)
  577. self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
  578. stack.push(_expect_exc)
  579. self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
  580. stack.push(_expect_exc)
  581. self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
  582. 1/0
  583. def test_enter_context(self):
  584. class TestCM(object):
  585. def __enter__(self):
  586. result.append(1)
  587. def __exit__(self, *exc_details):
  588. result.append(3)
  589. result = []
  590. cm = TestCM()
  591. with self.exit_stack() as stack:
  592. @stack.callback # Registered first => cleaned up last
  593. def _exit():
  594. result.append(4)
  595. self.assertIsNotNone(_exit)
  596. stack.enter_context(cm)
  597. self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
  598. result.append(2)
  599. self.assertEqual(result, [1, 2, 3, 4])
  600. def test_enter_context_errors(self):
  601. class LacksEnterAndExit:
  602. pass
  603. class LacksEnter:
  604. def __exit__(self, *exc_info):
  605. pass
  606. class LacksExit:
  607. def __enter__(self):
  608. pass
  609. with self.exit_stack() as stack:
  610. with self.assertRaisesRegex(TypeError, 'the context manager'):
  611. stack.enter_context(LacksEnterAndExit())
  612. with self.assertRaisesRegex(TypeError, 'the context manager'):
  613. stack.enter_context(LacksEnter())
  614. with self.assertRaisesRegex(TypeError, 'the context manager'):
  615. stack.enter_context(LacksExit())
  616. self.assertFalse(stack._exit_callbacks)
  617. def test_close(self):
  618. result = []
  619. with self.exit_stack() as stack:
  620. @stack.callback
  621. def _exit():
  622. result.append(1)
  623. self.assertIsNotNone(_exit)
  624. stack.close()
  625. result.append(2)
  626. self.assertEqual(result, [1, 2])
  627. def test_pop_all(self):
  628. result = []
  629. with self.exit_stack() as stack:
  630. @stack.callback
  631. def _exit():
  632. result.append(3)
  633. self.assertIsNotNone(_exit)
  634. new_stack = stack.pop_all()
  635. result.append(1)
  636. result.append(2)
  637. new_stack.close()
  638. self.assertEqual(result, [1, 2, 3])
  639. def test_exit_raise(self):
  640. with self.assertRaises(ZeroDivisionError):
  641. with self.exit_stack() as stack:
  642. stack.push(lambda *exc: False)
  643. 1/0
  644. def test_exit_suppress(self):
  645. with self.exit_stack() as stack:
  646. stack.push(lambda *exc: True)
  647. 1/0
  648. def test_exit_exception_traceback(self):
  649. # This test captures the current behavior of ExitStack so that we know
  650. # if we ever unintendedly change it. It is not a statement of what the
  651. # desired behavior is (for instance, we may want to remove some of the
  652. # internal contextlib frames).
  653. def raise_exc(exc):
  654. raise exc
  655. try:
  656. with self.exit_stack() as stack:
  657. stack.callback(raise_exc, ValueError)
  658. 1/0
  659. except ValueError as e:
  660. exc = e
  661. self.assertIsInstance(exc, ValueError)
  662. ve_frames = traceback.extract_tb(exc.__traceback__)
  663. expected = \
  664. [('test_exit_exception_traceback', 'with self.exit_stack() as stack:')] + \
  665. self.callback_error_internal_frames + \
  666. [('_exit_wrapper', 'callback(*args, **kwds)'),
  667. ('raise_exc', 'raise exc')]
  668. self.assertEqual(
  669. [(f.name, f.line) for f in ve_frames], expected)
  670. self.assertIsInstance(exc.__context__, ZeroDivisionError)
  671. zde_frames = traceback.extract_tb(exc.__context__.__traceback__)
  672. self.assertEqual([(f.name, f.line) for f in zde_frames],
  673. [('test_exit_exception_traceback', '1/0')])
  674. def test_exit_exception_chaining_reference(self):
  675. # Sanity check to make sure that ExitStack chaining matches
  676. # actual nested with statements
  677. class RaiseExc:
  678. def __init__(self, exc):
  679. self.exc = exc
  680. def __enter__(self):
  681. return self
  682. def __exit__(self, *exc_details):
  683. raise self.exc
  684. class RaiseExcWithContext:
  685. def __init__(self, outer, inner):
  686. self.outer = outer
  687. self.inner = inner
  688. def __enter__(self):
  689. return self
  690. def __exit__(self, *exc_details):
  691. try:
  692. raise self.inner
  693. except:
  694. raise self.outer
  695. class SuppressExc:
  696. def __enter__(self):
  697. return self
  698. def __exit__(self, *exc_details):
  699. type(self).saved_details = exc_details
  700. return True
  701. try:
  702. with RaiseExc(IndexError):
  703. with RaiseExcWithContext(KeyError, AttributeError):
  704. with SuppressExc():
  705. with RaiseExc(ValueError):
  706. 1 / 0
  707. except IndexError as exc:
  708. self.assertIsInstance(exc.__context__, KeyError)
  709. self.assertIsInstance(exc.__context__.__context__, AttributeError)
  710. # Inner exceptions were suppressed
  711. self.assertIsNone(exc.__context__.__context__.__context__)
  712. else:
  713. self.fail("Expected IndexError, but no exception was raised")
  714. # Check the inner exceptions
  715. inner_exc = SuppressExc.saved_details[1]
  716. self.assertIsInstance(inner_exc, ValueError)
  717. self.assertIsInstance(inner_exc.__context__, ZeroDivisionError)
  718. def test_exit_exception_chaining(self):
  719. # Ensure exception chaining matches the reference behaviour
  720. def raise_exc(exc):
  721. raise exc
  722. saved_details = None
  723. def suppress_exc(*exc_details):
  724. nonlocal saved_details
  725. saved_details = exc_details
  726. return True
  727. try:
  728. with self.exit_stack() as stack:
  729. stack.callback(raise_exc, IndexError)
  730. stack.callback(raise_exc, KeyError)
  731. stack.callback(raise_exc, AttributeError)
  732. stack.push(suppress_exc)
  733. stack.callback(raise_exc, ValueError)
  734. 1 / 0
  735. except IndexError as exc:
  736. self.assertIsInstance(exc.__context__, KeyError)
  737. self.assertIsInstance(exc.__context__.__context__, AttributeError)
  738. # Inner exceptions were suppressed
  739. self.assertIsNone(exc.__context__.__context__.__context__)
  740. else:
  741. self.fail("Expected IndexError, but no exception was raised")
  742. # Check the inner exceptions
  743. inner_exc = saved_details[1]
  744. self.assertIsInstance(inner_exc, ValueError)
  745. self.assertIsInstance(inner_exc.__context__, ZeroDivisionError)
  746. def test_exit_exception_explicit_none_context(self):
  747. # Ensure ExitStack chaining matches actual nested `with` statements
  748. # regarding explicit __context__ = None.
  749. class MyException(Exception):
  750. pass
  751. @contextmanager
  752. def my_cm():
  753. try:
  754. yield
  755. except BaseException:
  756. exc = MyException()
  757. try:
  758. raise exc
  759. finally:
  760. exc.__context__ = None
  761. @contextmanager
  762. def my_cm_with_exit_stack():
  763. with self.exit_stack() as stack:
  764. stack.enter_context(my_cm())
  765. yield stack
  766. for cm in (my_cm, my_cm_with_exit_stack):
  767. with self.subTest():
  768. try:
  769. with cm():
  770. raise IndexError()
  771. except MyException as exc:
  772. self.assertIsNone(exc.__context__)
  773. else:
  774. self.fail("Expected IndexError, but no exception was raised")
  775. def test_exit_exception_non_suppressing(self):
  776. # http://bugs.python.org/issue19092
  777. def raise_exc(exc):
  778. raise exc
  779. def suppress_exc(*exc_details):
  780. return True
  781. try:
  782. with self.exit_stack() as stack:
  783. stack.callback(lambda: None)
  784. stack.callback(raise_exc, IndexError)
  785. except Exception as exc:
  786. self.assertIsInstance(exc, IndexError)
  787. else:
  788. self.fail("Expected IndexError, but no exception was raised")
  789. try:
  790. with self.exit_stack() as stack:
  791. stack.callback(raise_exc, KeyError)
  792. stack.push(suppress_exc)
  793. stack.callback(raise_exc, IndexError)
  794. except Exception as exc:
  795. self.assertIsInstance(exc, KeyError)
  796. else:
  797. self.fail("Expected KeyError, but no exception was raised")
  798. def test_exit_exception_with_correct_context(self):
  799. # http://bugs.python.org/issue20317
  800. @contextmanager
  801. def gets_the_context_right(exc):
  802. try:
  803. yield
  804. finally:
  805. raise exc
  806. exc1 = Exception(1)
  807. exc2 = Exception(2)
  808. exc3 = Exception(3)
  809. exc4 = Exception(4)
  810. # The contextmanager already fixes the context, so prior to the
  811. # fix, ExitStack would try to fix it *again* and get into an
  812. # infinite self-referential loop
  813. try:
  814. with self.exit_stack() as stack:
  815. stack.enter_context(gets_the_context_right(exc4))
  816. stack.enter_context(gets_the_context_right(exc3))
  817. stack.enter_context(gets_the_context_right(exc2))
  818. raise exc1
  819. except Exception as exc:
  820. self.assertIs(exc, exc4)
  821. self.assertIs(exc.__context__, exc3)
  822. self.assertIs(exc.__context__.__context__, exc2)
  823. self.assertIs(exc.__context__.__context__.__context__, exc1)
  824. self.assertIsNone(
  825. exc.__context__.__context__.__context__.__context__)
  826. def test_exit_exception_with_existing_context(self):
  827. # Addresses a lack of test coverage discovered after checking in a
  828. # fix for issue 20317 that still contained debugging code.
  829. def raise_nested(inner_exc, outer_exc):
  830. try:
  831. raise inner_exc
  832. finally:
  833. raise outer_exc
  834. exc1 = Exception(1)
  835. exc2 = Exception(2)
  836. exc3 = Exception(3)
  837. exc4 = Exception(4)
  838. exc5 = Exception(5)
  839. try:
  840. with self.exit_stack() as stack:
  841. stack.callback(raise_nested, exc4, exc5)
  842. stack.callback(raise_nested, exc2, exc3)
  843. raise exc1
  844. except Exception as exc:
  845. self.assertIs(exc, exc5)
  846. self.assertIs(exc.__context__, exc4)
  847. self.assertIs(exc.__context__.__context__, exc3)
  848. self.assertIs(exc.__context__.__context__.__context__, exc2)
  849. self.assertIs(
  850. exc.__context__.__context__.__context__.__context__, exc1)
  851. self.assertIsNone(
  852. exc.__context__.__context__.__context__.__context__.__context__)
  853. def test_body_exception_suppress(self):
  854. def suppress_exc(*exc_details):
  855. return True
  856. try:
  857. with self.exit_stack() as stack:
  858. stack.push(suppress_exc)
  859. 1/0
  860. except IndexError as exc:
  861. self.fail("Expected no exception, got IndexError")
  862. def test_exit_exception_chaining_suppress(self):
  863. with self.exit_stack() as stack:
  864. stack.push(lambda *exc: True)
  865. stack.push(lambda *exc: 1/0)
  866. stack.push(lambda *exc: {}[1])
  867. def test_excessive_nesting(self):
  868. # The original implementation would die with RecursionError here
  869. with self.exit_stack() as stack:
  870. for i in range(10000):
  871. stack.callback(int)
  872. def test_instance_bypass(self):
  873. class Example(object): pass
  874. cm = Example()
  875. cm.__enter__ = object()
  876. cm.__exit__ = object()
  877. stack = self.exit_stack()
  878. with self.assertRaisesRegex(TypeError, 'the context manager'):
  879. stack.enter_context(cm)
  880. stack.push(cm)
  881. self.assertIs(stack._exit_callbacks[-1][1], cm)
  882. def test_dont_reraise_RuntimeError(self):
  883. # https://bugs.python.org/issue27122
  884. class UniqueException(Exception): pass
  885. class UniqueRuntimeError(RuntimeError): pass
  886. @contextmanager
  887. def second():
  888. try:
  889. yield 1
  890. except Exception as exc:
  891. raise UniqueException("new exception") from exc
  892. @contextmanager
  893. def first():
  894. try:
  895. yield 1
  896. except Exception as exc:
  897. raise exc
  898. # The UniqueRuntimeError should be caught by second()'s exception
  899. # handler which chain raised a new UniqueException.
  900. with self.assertRaises(UniqueException) as err_ctx:
  901. with self.exit_stack() as es_ctx:
  902. es_ctx.enter_context(second())
  903. es_ctx.enter_context(first())
  904. raise UniqueRuntimeError("please no infinite loop.")
  905. exc = err_ctx.exception
  906. self.assertIsInstance(exc, UniqueException)
  907. self.assertIsInstance(exc.__context__, UniqueRuntimeError)
  908. self.assertIsNone(exc.__context__.__context__)
  909. self.assertIsNone(exc.__context__.__cause__)
  910. self.assertIs(exc.__cause__, exc.__context__)
  911. class TestExitStack(TestBaseExitStack, unittest.TestCase):
  912. exit_stack = ExitStack
  913. callback_error_internal_frames = [
  914. ('__exit__', 'raise exc_details[1]'),
  915. ('__exit__', 'if cb(*exc_details):'),
  916. ]
  917. class TestRedirectStream:
  918. redirect_stream = None
  919. orig_stream = None
  920. @support.requires_docstrings
  921. def test_instance_docs(self):
  922. # Issue 19330: ensure context manager instances have good docstrings
  923. cm_docstring = self.redirect_stream.__doc__
  924. obj = self.redirect_stream(None)
  925. self.assertEqual(obj.__doc__, cm_docstring)
  926. def test_no_redirect_in_init(self):
  927. orig_stdout = getattr(sys, self.orig_stream)
  928. self.redirect_stream(None)
  929. self.assertIs(getattr(sys, self.orig_stream), orig_stdout)
  930. def test_redirect_to_string_io(self):
  931. f = io.StringIO()
  932. msg = "Consider an API like help(), which prints directly to stdout"
  933. orig_stdout = getattr(sys, self.orig_stream)
  934. with self.redirect_stream(f):
  935. print(msg, file=getattr(sys, self.orig_stream))
  936. self.assertIs(getattr(sys, self.orig_stream), orig_stdout)
  937. s = f.getvalue().strip()
  938. self.assertEqual(s, msg)
  939. def test_enter_result_is_target(self):
  940. f = io.StringIO()
  941. with self.redirect_stream(f) as enter_result:
  942. self.assertIs(enter_result, f)
  943. def test_cm_is_reusable(self):
  944. f = io.StringIO()
  945. write_to_f = self.redirect_stream(f)
  946. orig_stdout = getattr(sys, self.orig_stream)
  947. with write_to_f:
  948. print("Hello", end=" ", file=getattr(sys, self.orig_stream))
  949. with write_to_f:
  950. print("World!", file=getattr(sys, self.orig_stream))
  951. self.assertIs(getattr(sys, self.orig_stream), orig_stdout)
  952. s = f.getvalue()
  953. self.assertEqual(s, "Hello World!\n")
  954. def test_cm_is_reentrant(self):
  955. f = io.StringIO()
  956. write_to_f = self.redirect_stream(f)
  957. orig_stdout = getattr(sys, self.orig_stream)
  958. with write_to_f:
  959. print("Hello", end=" ", file=getattr(sys, self.orig_stream))
  960. with write_to_f:
  961. print("World!", file=getattr(sys, self.orig_stream))
  962. self.assertIs(getattr(sys, self.orig_stream), orig_stdout)
  963. s = f.getvalue()
  964. self.assertEqual(s, "Hello World!\n")
  965. class TestRedirectStdout(TestRedirectStream, unittest.TestCase):
  966. redirect_stream = redirect_stdout
  967. orig_stream = "stdout"
  968. class TestRedirectStderr(TestRedirectStream, unittest.TestCase):
  969. redirect_stream = redirect_stderr
  970. orig_stream = "stderr"
  971. class TestSuppress(unittest.TestCase):
  972. @support.requires_docstrings
  973. def test_instance_docs(self):
  974. # Issue 19330: ensure context manager instances have good docstrings
  975. cm_docstring = suppress.__doc__
  976. obj = suppress()
  977. self.assertEqual(obj.__doc__, cm_docstring)
  978. def test_no_result_from_enter(self):
  979. with suppress(ValueError) as enter_result:
  980. self.assertIsNone(enter_result)
  981. def test_no_exception(self):
  982. with suppress(ValueError):
  983. self.assertEqual(pow(2, 5), 32)
  984. def test_exact_exception(self):
  985. with suppress(TypeError):
  986. len(5)
  987. def test_exception_hierarchy(self):
  988. with suppress(LookupError):
  989. 'Hello'[50]
  990. def test_other_exception(self):
  991. with self.assertRaises(ZeroDivisionError):
  992. with suppress(TypeError):
  993. 1/0
  994. def test_no_args(self):
  995. with self.assertRaises(ZeroDivisionError):
  996. with suppress():
  997. 1/0
  998. def test_multiple_exception_args(self):
  999. with suppress(ZeroDivisionError, TypeError):
  1000. 1/0
  1001. with suppress(ZeroDivisionError, TypeError):
  1002. len(5)
  1003. def test_cm_is_reentrant(self):
  1004. ignore_exceptions = suppress(Exception)
  1005. with ignore_exceptions:
  1006. pass
  1007. with ignore_exceptions:
  1008. len(5)
  1009. with ignore_exceptions:
  1010. with ignore_exceptions: # Check nested usage
  1011. len(5)
  1012. outer_continued = True
  1013. 1/0
  1014. self.assertTrue(outer_continued)
  1015. class TestChdir(unittest.TestCase):
  1016. def make_relative_path(self, *parts):
  1017. return os.path.join(
  1018. os.path.dirname(os.path.realpath(__file__)),
  1019. *parts,
  1020. )
  1021. def test_simple(self):
  1022. old_cwd = os.getcwd()
  1023. target = self.make_relative_path('data')
  1024. self.assertNotEqual(old_cwd, target)
  1025. with chdir(target):
  1026. self.assertEqual(os.getcwd(), target)
  1027. self.assertEqual(os.getcwd(), old_cwd)
  1028. def test_reentrant(self):
  1029. old_cwd = os.getcwd()
  1030. target1 = self.make_relative_path('data')
  1031. target2 = self.make_relative_path('ziptestdata')
  1032. self.assertNotIn(old_cwd, (target1, target2))
  1033. chdir1, chdir2 = chdir(target1), chdir(target2)
  1034. with chdir1:
  1035. self.assertEqual(os.getcwd(), target1)
  1036. with chdir2:
  1037. self.assertEqual(os.getcwd(), target2)
  1038. with chdir1:
  1039. self.assertEqual(os.getcwd(), target1)
  1040. self.assertEqual(os.getcwd(), target2)
  1041. self.assertEqual(os.getcwd(), target1)
  1042. self.assertEqual(os.getcwd(), old_cwd)
  1043. def test_exception(self):
  1044. old_cwd = os.getcwd()
  1045. target = self.make_relative_path('data')
  1046. self.assertNotEqual(old_cwd, target)
  1047. try:
  1048. with chdir(target):
  1049. self.assertEqual(os.getcwd(), target)
  1050. raise RuntimeError("boom")
  1051. except RuntimeError as re:
  1052. self.assertEqual(str(re), "boom")
  1053. self.assertEqual(os.getcwd(), old_cwd)
  1054. if __name__ == "__main__":
  1055. unittest.main()