timeouts.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. import enum
  2. from types import TracebackType
  3. from typing import final, Optional, Type
  4. from . import events
  5. from . import exceptions
  6. from . import tasks
  # Names exported by ``from <module> import *``: the context-manager class
  # and the two factory functions defined further down in this module.
  __all__ = (
      "Timeout",
      "timeout",
      "timeout_at",
  )
  class _State(enum.Enum):
      """Lifecycle states of a Timeout context manager.

      The string values are what Timeout.__repr__ and the RuntimeError
      message raised by Timeout.reschedule() display.
      """

      # Initial state, assigned in Timeout.__init__.
      CREATED = "created"
      # Assigned in Timeout.__aenter__; the only state in which
      # Timeout.reschedule() accepts a new deadline.
      ENTERED = "active"
      # Assigned in Timeout._on_timeout right after the task is cancelled.
      EXPIRING = "expiring"
      # Assigned in Timeout.__aexit__ when the block is left while EXPIRING.
      EXPIRED = "expired"
      # Assigned in Timeout.__aexit__ when the block is left while ENTERED,
      # i.e. the deadline callback never ran.
      EXITED = "finished"
  18. @final
  19. class Timeout:
  20. def __init__(self, when: Optional[float]) -> None:
  21. self._state = _State.CREATED
  22. self._timeout_handler: Optional[events.TimerHandle] = None
  23. self._task: Optional[tasks.Task] = None
  24. self._when = when
  25. def when(self) -> Optional[float]:
  26. return self._when
  27. def reschedule(self, when: Optional[float]) -> None:
  28. assert self._state is not _State.CREATED
  29. if self._state is not _State.ENTERED:
  30. raise RuntimeError(
  31. f"Cannot change state of {self._state.value} Timeout",
  32. )
  33. self._when = when
  34. if self._timeout_handler is not None:
  35. self._timeout_handler.cancel()
  36. if when is None:
  37. self._timeout_handler = None
  38. else:
  39. loop = events.get_running_loop()
  40. if when <= loop.time():
  41. self._timeout_handler = loop.call_soon(self._on_timeout)
  42. else:
  43. self._timeout_handler = loop.call_at(when, self._on_timeout)
  44. def expired(self) -> bool:
  45. """Is timeout expired during execution?"""
  46. return self._state in (_State.EXPIRING, _State.EXPIRED)
  47. def __repr__(self) -> str:
  48. info = ['']
  49. if self._state is _State.ENTERED:
  50. when = round(self._when, 3) if self._when is not None else None
  51. info.append(f"when={when}")
  52. info_str = ' '.join(info)
  53. return f"<Timeout [{self._state.value}]{info_str}>"
  54. async def __aenter__(self) -> "Timeout":
  55. self._state = _State.ENTERED
  56. self._task = tasks.current_task()
  57. if self._task is None:
  58. raise RuntimeError("Timeout should be used inside a task")
  59. self.reschedule(self._when)
  60. return self
  61. async def __aexit__(
  62. self,
  63. exc_type: Optional[Type[BaseException]],
  64. exc_val: Optional[BaseException],
  65. exc_tb: Optional[TracebackType],
  66. ) -> Optional[bool]:
  67. assert self._state in (_State.ENTERED, _State.EXPIRING)
  68. if self._timeout_handler is not None:
  69. self._timeout_handler.cancel()
  70. self._timeout_handler = None
  71. if self._state is _State.EXPIRING:
  72. self._state = _State.EXPIRED
  73. if self._task.uncancel() == 0 and exc_type is exceptions.CancelledError:
  74. # Since there are no outstanding cancel requests, we're
  75. # handling this.
  76. raise TimeoutError
  77. elif self._state is _State.ENTERED:
  78. self._state = _State.EXITED
  79. return None
  80. def _on_timeout(self) -> None:
  81. assert self._state is _State.ENTERED
  82. self._task.cancel()
  83. self._state = _State.EXPIRING
  84. # drop the reference early
  85. self._timeout_handler = None
  86. def timeout(delay: Optional[float]) -> Timeout:
  87. """Timeout async context manager.
  88. Useful in cases when you want to apply timeout logic around block
  89. of code or in cases when asyncio.wait_for is not suitable. For example:
  90. >>> async with asyncio.timeout(10): # 10 seconds timeout
  91. ... await long_running_task()
  92. delay - value in seconds or None to disable timeout logic
  93. long_running_task() is interrupted by raising asyncio.CancelledError,
  94. the top-most affected timeout() context manager converts CancelledError
  95. into TimeoutError.
  96. """
  97. loop = events.get_running_loop()
  98. return Timeout(loop.time() + delay if delay is not None else None)
  99. def timeout_at(when: Optional[float]) -> Timeout:
  100. """Schedule the timeout at absolute time.
  101. Like timeout() but argument gives absolute time in the same clock system
  102. as loop.time().
  103. Please note: it is not POSIX time but a time with
  104. undefined starting base, e.g. the time of the system power on.
  105. >>> async with asyncio.timeout_at(loop.time() + 10):
  106. ... await long_running_task()
  107. when - a deadline when timeout occurs or None to disable timeout logic
  108. long_running_task() is interrupted by raising asyncio.CancelledError,
  109. the top-most affected timeout() context manager converts CancelledError
  110. into TimeoutError.
  111. """
  112. return Timeout(when)