taskgroups.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. # Adapted with permission from the EdgeDB project;
  2. # license: PSFL.
# Names exported by "from <this module> import *": only the TaskGroup class.
__all__ = ["TaskGroup"]
  4. from . import events
  5. from . import exceptions
  6. from . import tasks
  7. class TaskGroup:
  8. def __init__(self):
  9. self._entered = False
  10. self._exiting = False
  11. self._aborting = False
  12. self._loop = None
  13. self._parent_task = None
  14. self._parent_cancel_requested = False
  15. self._tasks = set()
  16. self._errors = []
  17. self._base_error = None
  18. self._on_completed_fut = None
  19. def __repr__(self):
  20. info = ['']
  21. if self._tasks:
  22. info.append(f'tasks={len(self._tasks)}')
  23. if self._errors:
  24. info.append(f'errors={len(self._errors)}')
  25. if self._aborting:
  26. info.append('cancelling')
  27. elif self._entered:
  28. info.append('entered')
  29. info_str = ' '.join(info)
  30. return f'<TaskGroup{info_str}>'
  31. async def __aenter__(self):
  32. if self._entered:
  33. raise RuntimeError(
  34. f"TaskGroup {self!r} has been already entered")
  35. self._entered = True
  36. if self._loop is None:
  37. self._loop = events.get_running_loop()
  38. self._parent_task = tasks.current_task(self._loop)
  39. if self._parent_task is None:
  40. raise RuntimeError(
  41. f'TaskGroup {self!r} cannot determine the parent task')
  42. return self
  43. async def __aexit__(self, et, exc, tb):
  44. self._exiting = True
  45. if (exc is not None and
  46. self._is_base_error(exc) and
  47. self._base_error is None):
  48. self._base_error = exc
  49. propagate_cancellation_error = \
  50. exc if et is exceptions.CancelledError else None
  51. if self._parent_cancel_requested:
  52. # If this flag is set we *must* call uncancel().
  53. if self._parent_task.uncancel() == 0:
  54. # If there are no pending cancellations left,
  55. # don't propagate CancelledError.
  56. propagate_cancellation_error = None
  57. if et is not None:
  58. if not self._aborting:
  59. # Our parent task is being cancelled:
  60. #
  61. # async with TaskGroup() as g:
  62. # g.create_task(...)
  63. # await ... # <- CancelledError
  64. #
  65. # or there's an exception in "async with":
  66. #
  67. # async with TaskGroup() as g:
  68. # g.create_task(...)
  69. # 1 / 0
  70. #
  71. self._abort()
  72. # We use while-loop here because "self._on_completed_fut"
  73. # can be cancelled multiple times if our parent task
  74. # is being cancelled repeatedly (or even once, when
  75. # our own cancellation is already in progress)
  76. while self._tasks:
  77. if self._on_completed_fut is None:
  78. self._on_completed_fut = self._loop.create_future()
  79. try:
  80. await self._on_completed_fut
  81. except exceptions.CancelledError as ex:
  82. if not self._aborting:
  83. # Our parent task is being cancelled:
  84. #
  85. # async def wrapper():
  86. # async with TaskGroup() as g:
  87. # g.create_task(foo)
  88. #
  89. # "wrapper" is being cancelled while "foo" is
  90. # still running.
  91. propagate_cancellation_error = ex
  92. self._abort()
  93. self._on_completed_fut = None
  94. assert not self._tasks
  95. if self._base_error is not None:
  96. raise self._base_error
  97. # Propagate CancelledError if there is one, except if there
  98. # are other errors -- those have priority.
  99. if propagate_cancellation_error and not self._errors:
  100. raise propagate_cancellation_error
  101. if et is not None and et is not exceptions.CancelledError:
  102. self._errors.append(exc)
  103. if self._errors:
  104. # Exceptions are heavy objects that can have object
  105. # cycles (bad for GC); let's not keep a reference to
  106. # a bunch of them.
  107. try:
  108. me = BaseExceptionGroup('unhandled errors in a TaskGroup', self._errors)
  109. raise me from None
  110. finally:
  111. self._errors = None
  112. def create_task(self, coro, *, name=None, context=None):
  113. if not self._entered:
  114. raise RuntimeError(f"TaskGroup {self!r} has not been entered")
  115. if self._exiting and not self._tasks:
  116. raise RuntimeError(f"TaskGroup {self!r} is finished")
  117. if self._aborting:
  118. raise RuntimeError(f"TaskGroup {self!r} is shutting down")
  119. if context is None:
  120. task = self._loop.create_task(coro)
  121. else:
  122. task = self._loop.create_task(coro, context=context)
  123. tasks._set_task_name(task, name)
  124. task.add_done_callback(self._on_task_done)
  125. self._tasks.add(task)
  126. return task
  127. # Since Python 3.8 Tasks propagate all exceptions correctly,
  128. # except for KeyboardInterrupt and SystemExit which are
  129. # still considered special.
  130. def _is_base_error(self, exc: BaseException) -> bool:
  131. assert isinstance(exc, BaseException)
  132. return isinstance(exc, (SystemExit, KeyboardInterrupt))
  133. def _abort(self):
  134. self._aborting = True
  135. for t in self._tasks:
  136. if not t.done():
  137. t.cancel()
  138. def _on_task_done(self, task):
  139. self._tasks.discard(task)
  140. if self._on_completed_fut is not None and not self._tasks:
  141. if not self._on_completed_fut.done():
  142. self._on_completed_fut.set_result(True)
  143. if task.cancelled():
  144. return
  145. exc = task.exception()
  146. if exc is None:
  147. return
  148. self._errors.append(exc)
  149. if self._is_base_error(exc) and self._base_error is None:
  150. self._base_error = exc
  151. if self._parent_task.done():
  152. # Not sure if this case is possible, but we want to handle
  153. # it anyways.
  154. self._loop.call_exception_handler({
  155. 'message': f'Task {task!r} has errored out but its parent '
  156. f'task {self._parent_task} is already completed',
  157. 'exception': exc,
  158. 'task': task,
  159. })
  160. return
  161. if not self._aborting and not self._parent_cancel_requested:
  162. # If parent task *is not* being cancelled, it means that we want
  163. # to manually cancel it to abort whatever is being run right now
  164. # in the TaskGroup. But we want to mark parent task as
  165. # "not cancelled" later in __aexit__. Example situation that
  166. # we need to handle:
  167. #
  168. # async def foo():
  169. # try:
  170. # async with TaskGroup() as g:
  171. # g.create_task(crash_soon())
  172. # await something # <- this needs to be canceled
  173. # # by the TaskGroup, e.g.
  174. # # foo() needs to be cancelled
  175. # except Exception:
  176. # # Ignore any exceptions raised in the TaskGroup
  177. # pass
  178. # await something_else # this line has to be called
  179. # # after TaskGroup is finished.
  180. self._abort()
  181. self._parent_cancel_requested = True
  182. self._parent_task.cancel()