lowlevel.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. from __future__ import annotations
  2. import enum
  3. import sys
  4. from dataclasses import dataclass
  5. from typing import Any, Generic, TypeVar, overload
  6. from weakref import WeakKeyDictionary
  7. from ._core._eventloop import get_asynclib
  8. if sys.version_info >= (3, 8):
  9. from typing import Literal
  10. else:
  11. from typing_extensions import Literal
  12. T = TypeVar("T")
  13. D = TypeVar("D")
  14. async def checkpoint() -> None:
  15. """
  16. Check for cancellation and allow the scheduler to switch to another task.
  17. Equivalent to (but more efficient than)::
  18. await checkpoint_if_cancelled()
  19. await cancel_shielded_checkpoint()
  20. .. versionadded:: 3.0
  21. """
  22. await get_asynclib().checkpoint()
  23. async def checkpoint_if_cancelled() -> None:
  24. """
  25. Enter a checkpoint if the enclosing cancel scope has been cancelled.
  26. This does not allow the scheduler to switch to a different task.
  27. .. versionadded:: 3.0
  28. """
  29. await get_asynclib().checkpoint_if_cancelled()
  30. async def cancel_shielded_checkpoint() -> None:
  31. """
  32. Allow the scheduler to switch to another task but without checking for cancellation.
  33. Equivalent to (but potentially more efficient than)::
  34. with CancelScope(shield=True):
  35. await checkpoint()
  36. .. versionadded:: 3.0
  37. """
  38. await get_asynclib().cancel_shielded_checkpoint()
  39. def current_token() -> object:
  40. """Return a backend specific token object that can be used to get back to the event loop."""
  41. return get_asynclib().current_token()
  42. _run_vars: WeakKeyDictionary[Any, dict[str, Any]] = WeakKeyDictionary()
  43. _token_wrappers: dict[Any, _TokenWrapper] = {}
  44. @dataclass(frozen=True)
  45. class _TokenWrapper:
  46. __slots__ = "_token", "__weakref__"
  47. _token: object
  48. class _NoValueSet(enum.Enum):
  49. NO_VALUE_SET = enum.auto()
  50. class RunvarToken(Generic[T]):
  51. __slots__ = "_var", "_value", "_redeemed"
  52. def __init__(self, var: RunVar[T], value: T | Literal[_NoValueSet.NO_VALUE_SET]):
  53. self._var = var
  54. self._value: T | Literal[_NoValueSet.NO_VALUE_SET] = value
  55. self._redeemed = False
  56. class RunVar(Generic[T]):
  57. """
  58. Like a :class:`~contextvars.ContextVar`, except scoped to the running event loop.
  59. """
  60. __slots__ = "_name", "_default"
  61. NO_VALUE_SET: Literal[_NoValueSet.NO_VALUE_SET] = _NoValueSet.NO_VALUE_SET
  62. _token_wrappers: set[_TokenWrapper] = set()
  63. def __init__(
  64. self,
  65. name: str,
  66. default: T | Literal[_NoValueSet.NO_VALUE_SET] = NO_VALUE_SET,
  67. ):
  68. self._name = name
  69. self._default = default
  70. @property
  71. def _current_vars(self) -> dict[str, T]:
  72. token = current_token()
  73. while True:
  74. try:
  75. return _run_vars[token]
  76. except TypeError:
  77. # Happens when token isn't weak referable (TrioToken).
  78. # This workaround does mean that some memory will leak on Trio until the problem
  79. # is fixed on their end.
  80. token = _TokenWrapper(token)
  81. self._token_wrappers.add(token)
  82. except KeyError:
  83. run_vars = _run_vars[token] = {}
  84. return run_vars
  85. @overload
  86. def get(self, default: D) -> T | D:
  87. ...
  88. @overload
  89. def get(self) -> T:
  90. ...
  91. def get(
  92. self, default: D | Literal[_NoValueSet.NO_VALUE_SET] = NO_VALUE_SET
  93. ) -> T | D:
  94. try:
  95. return self._current_vars[self._name]
  96. except KeyError:
  97. if default is not RunVar.NO_VALUE_SET:
  98. return default
  99. elif self._default is not RunVar.NO_VALUE_SET:
  100. return self._default
  101. raise LookupError(
  102. f'Run variable "{self._name}" has no value and no default set'
  103. )
  104. def set(self, value: T) -> RunvarToken[T]:
  105. current_vars = self._current_vars
  106. token = RunvarToken(self, current_vars.get(self._name, RunVar.NO_VALUE_SET))
  107. current_vars[self._name] = value
  108. return token
  109. def reset(self, token: RunvarToken[T]) -> None:
  110. if token._var is not self:
  111. raise ValueError("This token does not belong to this RunVar")
  112. if token._redeemed:
  113. raise ValueError("This token has already been used")
  114. if token._value is _NoValueSet.NO_VALUE_SET:
  115. try:
  116. del self._current_vars[self._name]
  117. except KeyError:
  118. pass
  119. else:
  120. self._current_vars[self._name] = token._value
  121. token._redeemed = True
  122. def __repr__(self) -> str:
  123. return f"<RunVar name={self._name!r}>"