lowlevel.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. from __future__ import annotations
  2. __all__ = (
  3. "EventLoopToken",
  4. "RunvarToken",
  5. "RunVar",
  6. "checkpoint",
  7. "checkpoint_if_cancelled",
  8. "cancel_shielded_checkpoint",
  9. "current_token",
  10. )
  11. import enum
  12. from dataclasses import dataclass
  13. from types import TracebackType
  14. from typing import Any, Generic, Literal, TypeVar, final, overload
  15. from weakref import WeakKeyDictionary
  16. from ._core._eventloop import get_async_backend
  17. from .abc import AsyncBackend
  18. T = TypeVar("T")
  19. D = TypeVar("D")
  20. async def checkpoint() -> None:
  21. """
  22. Check for cancellation and allow the scheduler to switch to another task.
  23. Equivalent to (but more efficient than)::
  24. await checkpoint_if_cancelled()
  25. await cancel_shielded_checkpoint()
  26. .. versionadded:: 3.0
  27. """
  28. await get_async_backend().checkpoint()
  29. async def checkpoint_if_cancelled() -> None:
  30. """
  31. Enter a checkpoint if the enclosing cancel scope has been cancelled.
  32. This does not allow the scheduler to switch to a different task.
  33. .. versionadded:: 3.0
  34. """
  35. await get_async_backend().checkpoint_if_cancelled()
  36. async def cancel_shielded_checkpoint() -> None:
  37. """
  38. Allow the scheduler to switch to another task but without checking for cancellation.
  39. Equivalent to (but potentially more efficient than)::
  40. with CancelScope(shield=True):
  41. await checkpoint()
  42. .. versionadded:: 3.0
  43. """
  44. await get_async_backend().cancel_shielded_checkpoint()
  45. @final
  46. @dataclass(frozen=True, repr=False)
  47. class EventLoopToken:
  48. """
  49. An opaque object that holds a reference to an event loop.
  50. .. versionadded:: 4.11.0
  51. """
  52. backend_class: type[AsyncBackend]
  53. native_token: object
  54. def current_token() -> EventLoopToken:
  55. """
  56. Return a token object that can be used to call code in the current event loop from
  57. another thread.
  58. .. versionadded:: 4.11.0
  59. """
  60. backend_class = get_async_backend()
  61. raw_token = backend_class.current_token()
  62. return EventLoopToken(backend_class, raw_token)
# Storage for RunVar values, one dict per event loop, keyed by the backend's native
# event loop token (``current_token().native_token``). Because the keys are weak, an
# entry is discarded once nothing else references its token object — presumably when
# the event loop goes away, but that depends on the backend's native token; verify.
_run_vars: WeakKeyDictionary[object, dict[RunVar[Any], Any]] = WeakKeyDictionary()
  64. class _NoValueSet(enum.Enum):
  65. NO_VALUE_SET = enum.auto()
  66. class RunvarToken(Generic[T]):
  67. __slots__ = "_var", "_value", "_redeemed"
  68. def __init__(self, var: RunVar[T], value: T | Literal[_NoValueSet.NO_VALUE_SET]):
  69. self._var = var
  70. self._value: T | Literal[_NoValueSet.NO_VALUE_SET] = value
  71. self._redeemed = False
  72. def __enter__(self) -> RunvarToken[T]:
  73. return self
  74. def __exit__(
  75. self,
  76. exc_type: type[BaseException] | None,
  77. exc_val: BaseException | None,
  78. exc_tb: TracebackType | None,
  79. ) -> None:
  80. self._var.reset(self)
  81. class RunVar(Generic[T]):
  82. """
  83. Like a :class:`~contextvars.ContextVar`, except scoped to the running event loop.
  84. Can be used as a context manager, Just like :class:`~contextvars.ContextVar`, that
  85. will reset the variable to its previous value when the context block is exited.
  86. """
  87. __slots__ = "_name", "_default"
  88. NO_VALUE_SET: Literal[_NoValueSet.NO_VALUE_SET] = _NoValueSet.NO_VALUE_SET
  89. def __init__(
  90. self, name: str, default: T | Literal[_NoValueSet.NO_VALUE_SET] = NO_VALUE_SET
  91. ):
  92. self._name = name
  93. self._default = default
  94. @property
  95. def _current_vars(self) -> dict[RunVar[T], T]:
  96. native_token = current_token().native_token
  97. try:
  98. return _run_vars[native_token]
  99. except KeyError:
  100. run_vars = _run_vars[native_token] = {}
  101. return run_vars
  102. @overload
  103. def get(self, default: D) -> T | D: ...
  104. @overload
  105. def get(self) -> T: ...
  106. def get(
  107. self, default: D | Literal[_NoValueSet.NO_VALUE_SET] = NO_VALUE_SET
  108. ) -> T | D:
  109. try:
  110. return self._current_vars[self]
  111. except KeyError:
  112. if default is not RunVar.NO_VALUE_SET:
  113. return default
  114. elif self._default is not RunVar.NO_VALUE_SET:
  115. return self._default
  116. raise LookupError(
  117. f'Run variable "{self._name}" has no value and no default set'
  118. )
  119. def set(self, value: T) -> RunvarToken[T]:
  120. current_vars = self._current_vars
  121. token = RunvarToken(self, current_vars.get(self, RunVar.NO_VALUE_SET))
  122. current_vars[self] = value
  123. return token
  124. def reset(self, token: RunvarToken[T]) -> None:
  125. if token._var is not self:
  126. raise ValueError("This token does not belong to this RunVar")
  127. if token._redeemed:
  128. raise ValueError("This token has already been used")
  129. if token._value is _NoValueSet.NO_VALUE_SET:
  130. try:
  131. del self._current_vars[self]
  132. except KeyError:
  133. pass
  134. else:
  135. self._current_vars[self] = token._value
  136. token._redeemed = True
  137. def __repr__(self) -> str:
  138. return f"<RunVar name={self._name!r}>"