# to_interpreter.py

from __future__ import annotations

__all__ = (
    "run_sync",
    "current_default_interpreter_limiter",
)

import atexit
import os
import sys
from collections import deque
from collections.abc import Callable
from typing import Any, Final, TypeVar

from . import current_time, to_thread
from ._core._exceptions import BrokenWorkerInterpreter
from ._core._synchronization import CapacityLimiter
from .lowlevel import RunVar

if sys.version_info >= (3, 11):
    from typing import TypeVarTuple, Unpack
else:
    from typing_extensions import TypeVarTuple, Unpack
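
# On Python 3.14 and later, the public concurrent.interpreters API is available,
# so workers can invoke functions in a subinterpreter directly.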
if sys.version_info >= (3, 14):
    from concurrent.interpreters import ExecutionFailed, create

    def _interp_call(
        func: Callable[..., Any], args: tuple[Any, ...]
    ) -> tuple[Any, bool]:
        try:
            retval = func(*args)
        except BaseException as exc:
            return exc, True
        else:
            return retval, False
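
    # Each worker wraps a single subinterpreter; exceptions raised by the target
    # function are returned as values and re-raised in the caller.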
    class _Worker:
        last_used: float = 0

        def __init__(self) -> None:
            self._interpreter = create()

        def destroy(self) -> None:
            self._interpreter.close()

        def call(
            self,
            func: Callable[..., T_Retval],
            args: tuple[Any, ...],
        ) -> T_Retval:
            try:
                res, is_exception = self._interpreter.call(_interp_call, func, args)
            except ExecutionFailed as exc:
                raise BrokenWorkerInterpreter(exc.excinfo) from exc

            if is_exception:
                raise res

            return res
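
# On Python 3.13, concurrent.interpreters does not exist yet, so this path falls
# back to the private _interpreters/_interpqueues modules (see the warning in
# run_sync's docstring).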
elif sys.version_info >= (3, 13):
    import _interpqueues
    import _interpreters

    UNBOUND: Final = 2  # I have no clue how this works, but it was used in the stdlib
    FMT_UNPICKLED: Final = 0
    FMT_PICKLED: Final = 1
    QUEUE_PICKLE_ARGS: Final = (FMT_PICKLED, UNBOUND)
    QUEUE_UNPICKLE_ARGS: Final = (FMT_UNPICKLED, UNBOUND)
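
    # Script executed inside the subinterpreter: it pulls a pickled (func, args)
    # pair off the queue, runs it, and puts back (retval, is_exception), pickling
    # the result only if it is not directly shareable across interpreters.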
    _run_func = compile(
        """
import _interpqueues
from _interpreters import NotShareableError
from pickle import loads, dumps, HIGHEST_PROTOCOL

QUEUE_PICKLE_ARGS = (1, 2)
QUEUE_UNPICKLE_ARGS = (0, 2)

item = _interpqueues.get(queue_id)[0]
try:
    func, args = loads(item)
    retval = func(*args)
except BaseException as exc:
    is_exception = True
    retval = exc
else:
    is_exception = False

try:
    _interpqueues.put(queue_id, (retval, is_exception), *QUEUE_UNPICKLE_ARGS)
except NotShareableError:
    retval = dumps(retval, HIGHEST_PROTOCOL)
    _interpqueues.put(queue_id, (retval, is_exception), *QUEUE_PICKLE_ARGS)
""",
        "<string>",
        "exec",
    )
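
    # The main-interpreter side of the same protocol: pickle (func, args), push it
    # to the worker's queue, execute the script above, then pull the result back.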
    class _Worker:
        last_used: float = 0

        def __init__(self) -> None:
            self._interpreter_id = _interpreters.create()
            self._queue_id = _interpqueues.create(1, *QUEUE_UNPICKLE_ARGS)
            _interpreters.set___main___attrs(
                self._interpreter_id, {"queue_id": self._queue_id}
            )

        def destroy(self) -> None:
            _interpqueues.destroy(self._queue_id)
            _interpreters.destroy(self._interpreter_id)

        def call(
            self,
            func: Callable[..., T_Retval],
            args: tuple[Any, ...],
        ) -> T_Retval:
            import pickle

            item = pickle.dumps((func, args), pickle.HIGHEST_PROTOCOL)
            _interpqueues.put(self._queue_id, item, *QUEUE_PICKLE_ARGS)
            exc_info = _interpreters.exec(self._interpreter_id, _run_func)
            if exc_info:
                raise BrokenWorkerInterpreter(exc_info)

            res = _interpqueues.get(self._queue_id)
            (res, is_exception), fmt = res[:2]
            if fmt == FMT_PICKLED:
                res = pickle.loads(res)

            if is_exception:
                raise res

            return res
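
# On Python < 3.13 there is no subinterpreter support at all; this stub fails at
# construction time with a clear error.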
else:

    class _Worker:
        last_used: float = 0

        def __init__(self) -> None:
            raise RuntimeError("subinterpreters require at least Python 3.13")

        def call(
            self,
            func: Callable[..., T_Retval],
            args: tuple[Any, ...],
        ) -> T_Retval:
            raise NotImplementedError

        def destroy(self) -> None:
            pass


DEFAULT_CPU_COUNT: Final = 8  # this is just an arbitrarily selected value
MAX_WORKER_IDLE_TIME = (
    30  # seconds a subinterpreter can be idle before becoming eligible for pruning
)

T_Retval = TypeVar("T_Retval")
PosArgsT = TypeVarTuple("PosArgsT")
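
# Both variables below are RunVars, so each running event loop gets its own
# worker pool and default limiter.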
_idle_workers = RunVar[deque[_Worker]]("_available_workers")
_default_interpreter_limiter = RunVar[CapacityLimiter]("_default_interpreter_limiter")


def _stop_workers(workers: deque[_Worker]) -> None:
    for worker in workers:
        worker.destroy()

    workers.clear()


async def run_sync(
    func: Callable[[Unpack[PosArgsT]], T_Retval],
    *args: Unpack[PosArgsT],
    limiter: CapacityLimiter | None = None,
) -> T_Retval:
    """
    Call the given function with the given arguments in a subinterpreter.

    .. warning:: On Python 3.13, the :mod:`concurrent.interpreters` module was not yet
       available, so the code path for that Python version relies on an undocumented,
       private API. As such, it is recommended not to rely on this function for
       anything mission-critical on Python 3.13.

    :param func: a callable
    :param args: the positional arguments for the callable
    :param limiter: capacity limiter to use to limit the total number of
        subinterpreters running (if omitted, the default limiter is used)
    :return: the result of the call
    :raises BrokenWorkerInterpreter: if there's an internal error in a subinterpreter

    """
    if limiter is None:
        limiter = current_default_interpreter_limiter()

    try:
        idle_workers = _idle_workers.get()
    except LookupError:
        idle_workers = deque()
        _idle_workers.set(idle_workers)
        atexit.register(_stop_workers, idle_workers)

    async with limiter:
        try:
            worker = idle_workers.pop()
        except IndexError:
            worker = _Worker()

        try:
            return await to_thread.run_sync(
                worker.call,
                func,
                args,
                limiter=limiter,
            )
        finally:
            # Prune workers that have been idle for too long
            now = current_time()
            while idle_workers:
                if now - idle_workers[0].last_used <= MAX_WORKER_IDLE_TIME:
                    break

                await to_thread.run_sync(
                    idle_workers.popleft().destroy, limiter=limiter
                )

            worker.last_used = current_time()
            idle_workers.append(worker)
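
# A minimal usage sketch (hypothetical caller code, not part of this module):
#
#     import anyio
#     from anyio import to_interpreter
#
#     def fib(n: int) -> int:
#         return n if n < 2 else fib(n - 1) + fib(n - 2)
#
#     async def main() -> None:
#         # fib(30) runs in a pooled subinterpreter via a worker thread
#         print(await to_interpreter.run_sync(fib, 30))
#
#     anyio.run(main)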


def current_default_interpreter_limiter() -> CapacityLimiter:
    """
    Return the capacity limiter used by default to limit the number of concurrently
    running subinterpreters.

    Defaults to the number of CPU cores.

    :return: a capacity limiter object

    """
    try:
        return _default_interpreter_limiter.get()
    except LookupError:
        limiter = CapacityLimiter(os.cpu_count() or DEFAULT_CPU_COUNT)
        _default_interpreter_limiter.set(limiter)
        return limiter