from_thread.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570
  1. from __future__ import annotations
  # Public API of this module: the names exported by ``from ... import *``.
  __all__ = (
      "BlockingPortal",
      "BlockingPortalProvider",
      "check_cancelled",
      "run",
      "run_sync",
      "start_blocking_portal",
  )
  10. import sys
  11. from collections.abc import Awaitable, Callable, Generator
  12. from concurrent.futures import Future
  13. from contextlib import (
  14. AbstractAsyncContextManager,
  15. AbstractContextManager,
  16. contextmanager,
  17. )
  18. from dataclasses import dataclass, field
  19. from inspect import isawaitable
  20. from threading import Lock, Thread, current_thread, get_ident
  21. from types import TracebackType
  22. from typing import (
  23. Any,
  24. Generic,
  25. TypeVar,
  26. cast,
  27. overload,
  28. )
  29. from ._core._eventloop import (
  30. get_async_backend,
  31. get_cancelled_exc_class,
  32. threadlocals,
  33. )
  34. from ._core._eventloop import run as run_eventloop
  35. from ._core._exceptions import NoEventLoopError
  36. from ._core._synchronization import Event
  37. from ._core._tasks import CancelScope, create_task_group
  38. from .abc._tasks import TaskStatus
  39. from .lowlevel import EventLoopToken
  40. if sys.version_info >= (3, 11):
  41. from typing import TypeVarTuple, Unpack
  42. else:
  43. from typing_extensions import TypeVarTuple, Unpack
  # Return type of the callables dispatched to the event loop thread.
  T_Retval = TypeVar("T_Retval")
  # Covariant value type produced by the (wrapped) async context managers.
  T_co = TypeVar("T_co", covariant=True)
  # Variadic positional-argument types, so that ``*args`` are checked against ``func``.
  PosArgsT = TypeVarTuple("PosArgsT")
  47. def _token_or_error(token: EventLoopToken | None) -> EventLoopToken:
  48. if token is not None:
  49. return token
  50. try:
  51. return threadlocals.current_token
  52. except AttributeError:
  53. raise NoEventLoopError(
  54. "Not running inside an AnyIO worker thread, and no event loop token was "
  55. "provided"
  56. ) from None
  def run(
      func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
      *args: Unpack[PosArgsT],
      token: EventLoopToken | None = None,
  ) -> T_Retval:
      """
      Call a coroutine function from a worker thread.

      :param func: a coroutine function
      :param args: positional arguments for the callable
      :param token: an event loop token to use to get back to the event loop thread
          (required if calling this function from outside an AnyIO worker thread)
      :return: the return value of the coroutine function
      :raises NoEventLoopError: if no token was provided and called from outside an
          AnyIO worker thread
      :raises RunFinishedError: if the event loop tied to ``token`` is no longer running

      .. versionchanged:: 4.11.0
          Added the ``token`` parameter.

      """
      # Only hand the native token to the backend when the caller passed one
      # explicitly; otherwise ``None`` is passed (presumably the backend then uses
      # its worker-thread state — confirm in the backend implementation).
      explicit_token = token is not None
      token = _token_or_error(token)
      return token.backend_class.run_async_from_thread(
          func, args, token=token.native_token if explicit_token else None
      )
  def run_sync(
      func: Callable[[Unpack[PosArgsT]], T_Retval],
      *args: Unpack[PosArgsT],
      token: EventLoopToken | None = None,
  ) -> T_Retval:
      """
      Call a function in the event loop thread from a worker thread.

      :param func: a callable
      :param args: positional arguments for the callable
      :param token: an event loop token to use to get back to the event loop thread
          (required if calling this function from outside an AnyIO worker thread)
      :return: the return value of the callable
      :raises NoEventLoopError: if no token was provided and called from outside an
          AnyIO worker thread
      :raises RunFinishedError: if the event loop tied to ``token`` is no longer running

      .. versionchanged:: 4.11.0
          Added the ``token`` parameter.

      """
      # Only hand the native token to the backend when the caller passed one
      # explicitly; otherwise ``None`` is passed (presumably the backend then uses
      # its worker-thread state — confirm in the backend implementation).
      explicit_token = token is not None
      token = _token_or_error(token)
      return token.backend_class.run_sync_from_thread(
          func, args, token=token.native_token if explicit_token else None
      )
  class _BlockingAsyncContextManager(Generic[T_co], AbstractContextManager):
      """
      Synchronous context manager that drives an asynchronous one via a portal.

      ``__enter__()`` schedules :meth:`run_async_cm` as a portal task and blocks until
      the async ``__aenter__()`` has finished; ``__exit__()`` releases that task so it
      calls ``__aexit__()``, and blocks until that has returned.
      """

      # Resolved by run_async_cm() with the value (or exception) of __aenter__()
      _enter_future: Future[T_co]
      # Future of the portal task running run_async_cm(); holds __aexit__()'s result
      _exit_future: Future[bool | None]
      # Set via the portal by __exit__() to let run_async_cm() proceed to __aexit__()
      _exit_event: Event
      # Exception info handed to __aexit__(); the class-level default is used if the
      # task reaches __aexit__() before __exit__() has stored anything
      _exit_exc_info: tuple[
          type[BaseException] | None, BaseException | None, TracebackType | None
      ] = (None, None, None)

      def __init__(
          self, async_cm: AbstractAsyncContextManager[T_co], portal: BlockingPortal
      ):
          self._async_cm = async_cm
          self._portal = portal

      async def run_async_cm(self) -> bool | None:
          """
          Enter the async context manager, wait for the sync side to exit, then exit it.

          Runs as a task started through the portal.

          :return: the return value of the async context manager's ``__aexit__()``
          """
          try:
              self._exit_event = Event()
              value = await self._async_cm.__aenter__()
          except BaseException as exc:
              # Unblock __enter__() with the failure, and still fail this task
              self._enter_future.set_exception(exc)
              raise
          else:
              self._enter_future.set_result(value)

          try:
              # Wait for the sync context manager to exit.
              # This next statement can raise `get_cancelled_exc_class()` if
              # something went wrong in a task group in this async context
              # manager.
              await self._exit_event.wait()
          finally:
              # In case of cancellation, it could be that we end up here before
              # `_BlockingAsyncContextManager.__exit__` is called, and an
              # `_exit_exc_info` has been set.
              result = await self._async_cm.__aexit__(*self._exit_exc_info)

          # Returning after (not inside) the ``finally`` lets an exception raised
          # while waiting, such as a cancellation, propagate instead of being swallowed
          return result

      def __enter__(self) -> T_co:
          """Start the portal task and block until ``__aenter__()`` has completed."""
          self._enter_future = Future()
          self._exit_future = self._portal.start_task_soon(self.run_async_cm)
          # Re-raises here any exception that __aenter__() raised
          return self._enter_future.result()

      def __exit__(
          self,
          __exc_type: type[BaseException] | None,
          __exc_value: BaseException | None,
          __traceback: TracebackType | None,
      ) -> bool | None:
          """
          Forward the exception info to ``__aexit__()`` and block until it returns.

          :return: the return value of ``__aexit__()`` (a true value suppresses the
              exception raised in the ``with`` block)
          """
          self._exit_exc_info = __exc_type, __exc_value, __traceback
          # Set the event from the event loop thread by calling through the portal
          self._portal.call(self._exit_event.set)
          return self._exit_future.result()
  class _BlockingPortalTaskStatus(TaskStatus):
      """
      :class:`TaskStatus` implementation that reports readiness through a
      :class:`concurrent.futures.Future`.
      """

      def __init__(self, future: Future):
          # The future resolved by started(); BlockingPortal.start_task() waits on it
          self._future = future

      def started(self, value: object = None) -> None:
          """Resolve the wrapped future with ``value`` to signal that the task is ready."""
          readiness = self._future
          readiness.set_result(value)
  class BlockingPortal:
      """An object that lets external threads run code in an asynchronous event loop."""

      def __new__(cls) -> BlockingPortal:
          # The portal object itself is created by the current async backend
          return get_async_backend().create_blocking_portal()

      def __init__(self) -> None:
          # Identity of the thread running the event loop; stop() resets it to None,
          # after which _check_running() rejects further calls
          self._event_loop_thread_id: int | None = get_ident()
          # Set by stop() to wake up sleep_until_stopped()
          self._stop_event = Event()
          # The portal's task group, in which started tasks run
          self._task_group = create_task_group()
          # Backend-specific cancellation exception, caught in _call_func()
          self._cancelled_exc_class = get_cancelled_exc_class()

      async def __aenter__(self) -> BlockingPortal:
          await self._task_group.__aenter__()
          return self

      async def __aexit__(
          self,
          exc_type: type[BaseException] | None,
          exc_val: BaseException | None,
          exc_tb: TracebackType | None,
      ) -> bool:
          # Stop accepting new calls first, then exit the task group
          await self.stop()
          return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)

      def _check_running(self) -> None:
          """
          Ensure that the portal can accept a call from the current thread.

          :raises RuntimeError: if the portal has been stopped, or if called from the
              event loop thread itself
          """
          if self._event_loop_thread_id is None:
              raise RuntimeError("This portal is not running")
          if self._event_loop_thread_id == get_ident():
              raise RuntimeError(
                  "This method cannot be called from the event loop thread"
              )

      async def sleep_until_stopped(self) -> None:
          """Sleep until :meth:`stop` is called."""
          await self._stop_event.wait()

      async def stop(self, cancel_remaining: bool = False) -> None:
          """
          Signal the portal to shut down.

          This marks the portal as no longer accepting new calls and exits from
          :meth:`sleep_until_stopped`.

          :param cancel_remaining: ``True`` to cancel all the remaining tasks, ``False``
              to let them finish before returning

          """
          self._event_loop_thread_id = None
          self._stop_event.set()
          if cancel_remaining:
              self._task_group.cancel_scope.cancel("the blocking portal is shutting down")

      async def _call_func(
          self,
          func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
          args: tuple[Unpack[PosArgsT]],
          kwargs: dict[str, Any],
          future: Future[T_Retval],
      ) -> None:
          """
          Run ``func`` in the event loop thread and resolve ``future`` with the outcome.

          If ``func`` returns an awaitable, it is awaited inside a cancel scope that is
          cancelled when ``future`` gets cancelled. Exceptions are stored in
          ``future``; base exceptions that are not :class:`Exception` subclasses are
          re-raised as well.
          """

          def callback(f: Future[T_Retval]) -> None:
              # Propagate cancellation of the future to the cancel scope, hopping to
              # the event loop thread if the future was cancelled from another thread
              if f.cancelled():
                  if self._event_loop_thread_id == get_ident():
                      scope.cancel("the future was cancelled")
                  elif self._event_loop_thread_id is not None:
                      self.call(scope.cancel, "the future was cancelled")

          try:
              retval_or_awaitable = func(*args, **kwargs)
              if isawaitable(retval_or_awaitable):
                  with CancelScope() as scope:
                      future.add_done_callback(callback)
                      retval = await retval_or_awaitable
              else:
                  retval = retval_or_awaitable
          except self._cancelled_exc_class:
              future.cancel()
              future.set_running_or_notify_cancel()
          except BaseException as exc:
              if not future.cancelled():
                  future.set_exception(exc)

              # Let base exceptions fall through
              if not isinstance(exc, Exception):
                  raise
          else:
              if not future.cancelled():
                  future.set_result(retval)
          finally:
              # Drop the reference to the cancel scope, which the callback closure
              # also refers to (presumably to avoid keeping it alive via a cycle)
              scope = None  # type: ignore[assignment]

      def _spawn_task_from_thread(
          self,
          func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
          args: tuple[Unpack[PosArgsT]],
          kwargs: dict[str, Any],
          name: object,
          future: Future[T_Retval],
      ) -> None:
          """
          Spawn a new task using the given callable.

          Implementers must ensure that the future is resolved when the task finishes.

          :param func: a callable
          :param args: positional arguments to be passed to the callable
          :param kwargs: keyword arguments to be passed to the callable
          :param name: name of the task (will be coerced to a string if not ``None``)
          :param future: a future that will resolve to the return value of the callable,
              or the exception raised during its execution

          """
          raise NotImplementedError

      @overload
      def call(
          self,
          func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
          *args: Unpack[PosArgsT],
      ) -> T_Retval: ...

      @overload
      def call(
          self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
      ) -> T_Retval: ...

      def call(
          self,
          func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
          *args: Unpack[PosArgsT],
      ) -> T_Retval:
          """
          Call the given function in the event loop thread.

          If the callable returns a coroutine object, it is awaited on.

          :param func: any callable
          :raises RuntimeError: if the portal is not running or if this method is called
              from within the event loop thread

          """
          # Blocks the calling thread until the task has produced its result
          return cast(T_Retval, self.start_task_soon(func, *args).result())

      @overload
      def start_task_soon(
          self,
          func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
          *args: Unpack[PosArgsT],
          name: object = None,
      ) -> Future[T_Retval]: ...

      @overload
      def start_task_soon(
          self,
          func: Callable[[Unpack[PosArgsT]], T_Retval],
          *args: Unpack[PosArgsT],
          name: object = None,
      ) -> Future[T_Retval]: ...

      def start_task_soon(
          self,
          func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
          *args: Unpack[PosArgsT],
          name: object = None,
      ) -> Future[T_Retval]:
          """
          Start a task in the portal's task group.

          The task will be run inside a cancel scope which can be cancelled by cancelling
          the returned future.

          :param func: the target function
          :param args: positional arguments passed to ``func``
          :param name: name of the task (will be coerced to a string if not ``None``)
          :return: a future that resolves with the return value of the callable if the
              task completes successfully, or with the exception raised in the task
          :raises RuntimeError: if the portal is not running or if this method is called
              from within the event loop thread
          :rtype: concurrent.futures.Future[T_Retval]

          .. versionadded:: 3.0

          """
          self._check_running()
          f: Future[T_Retval] = Future()
          self._spawn_task_from_thread(func, args, {}, name, f)
          return f

      def start_task(
          self,
          func: Callable[..., Awaitable[T_Retval]],
          *args: object,
          name: object = None,
      ) -> tuple[Future[T_Retval], Any]:
          """
          Start a task in the portal's task group and wait until it signals for readiness.

          This method works the same way as :meth:`.abc.TaskGroup.start`.

          :param func: the target function
          :param args: positional arguments passed to ``func``
          :param name: name of the task (will be coerced to a string if not ``None``)
          :return: a tuple of (future, task_status_value) where the ``task_status_value``
              is the value passed to ``task_status.started()`` from within the target
              function
          :rtype: tuple[concurrent.futures.Future[T_Retval], Any]

          .. versionadded:: 3.0

          """

          def task_done(future: Future[T_Retval]) -> None:
              # If the task finished before calling task_status.started(), unblock the
              # waiting caller with the task's outcome instead
              if not task_status_future.done():
                  if future.cancelled():
                      task_status_future.cancel()
                  else:
                      # Fetch the exception once and compare against None rather than
                      # relying on the exception object's truthiness
                      exc = future.exception()
                      if exc is None:
                          exc = RuntimeError(
                              "Task exited without calling task_status.started()"
                          )

                      task_status_future.set_exception(exc)

          self._check_running()
          task_status_future: Future = Future()
          task_status = _BlockingPortalTaskStatus(task_status_future)
          f: Future = Future()
          f.add_done_callback(task_done)
          self._spawn_task_from_thread(func, args, {"task_status": task_status}, name, f)
          return f, task_status_future.result()

      def wrap_async_context_manager(
          self, cm: AbstractAsyncContextManager[T_co]
      ) -> AbstractContextManager[T_co]:
          """
          Wrap an async context manager as a synchronous context manager via this portal.

          Spawns a task that will call both ``__aenter__()`` and ``__aexit__()``, stopping
          in the middle until the synchronous context manager exits.

          :param cm: an asynchronous context manager
          :return: a synchronous context manager

          .. versionadded:: 2.1

          """
          return _BlockingAsyncContextManager(cm, self)
  @dataclass
  class BlockingPortalProvider:
      """
      A manager for a blocking portal. Used as a context manager. The first thread to
      enter this context manager causes a blocking portal to be started with the specific
      parameters, and the last thread to exit causes the portal to be shut down. Thus,
      there will be exactly one blocking portal running in this context as long as at
      least one thread has entered this context manager.

      The parameters are the same as for :func:`~anyio.run`.

      :param backend: name of the backend
      :param backend_options: backend options

      .. versionadded:: 4.4
      """

      backend: str = "asyncio"
      backend_options: dict[str, Any] | None = None
      # Guards _leases, _portal and _portal_cm across threads
      _lock: Lock = field(init=False, default_factory=Lock)
      # Number of threads currently inside this context manager
      _leases: int = field(init=False, default=0)
      # The shared portal; only set while a portal is running
      _portal: BlockingPortal = field(init=False)
      # Context manager from start_blocking_portal(); None while no portal is running
      _portal_cm: AbstractContextManager[BlockingPortal] | None = field(
          init=False, default=None
      )

      def __enter__(self) -> BlockingPortal:
          with self._lock:
              if self._portal_cm is None:
                  portal_cm = start_blocking_portal(self.backend, self.backend_options)
                  # Enter before publishing any state: if the portal fails to start,
                  # the provider stays idle and a later __enter__() starts a fresh
                  # portal instead of reusing a broken context manager.
                  self._portal = portal_cm.__enter__()
                  self._portal_cm = portal_cm

              self._leases += 1
              return self._portal

      def __exit__(
          self,
          exc_type: type[BaseException] | None,
          exc_val: BaseException | None,
          exc_tb: TracebackType | None,
      ) -> None:
          portal_cm: AbstractContextManager[BlockingPortal] | None = None
          with self._lock:
              assert self._portal_cm
              assert self._leases > 0
              self._leases -= 1
              if not self._leases:
                  # Last lease released: detach the portal so it can be shut down
                  portal_cm = self._portal_cm
                  self._portal_cm = None
                  del self._portal

          # Shut the portal down outside the lock, so other threads entering or
          # exiting are not blocked while the portal thread is being joined
          if portal_cm is not None:
              portal_cm.__exit__(None, None, None)
  @contextmanager
  def start_blocking_portal(
      backend: str = "asyncio",
      backend_options: dict[str, Any] | None = None,
      *,
      name: str | None = None,
  ) -> Generator[BlockingPortal, Any, None]:
      """
      Start a new event loop in a new thread and run a blocking portal in its main task.

      The parameters are the same as for :func:`~anyio.run`.

      :param backend: name of the backend
      :param backend_options: backend options
      :param name: name of the thread
      :return: a context manager that yields a blocking portal

      .. versionchanged:: 3.0
          Usage as a context manager is now required.

      """
      # Carries the running portal (or a startup failure) back to this thread
      portal_future: Future[BlockingPortal] = Future()

      async def serve_portal() -> None:
          # Main task of the new event loop: publish the portal, then idle until
          # the portal is stopped
          async with BlockingPortal() as running_portal:
              if name is None:
                  current_thread().name = f"{backend}-portal-{id(running_portal):x}"

              portal_future.set_result(running_portal)
              await running_portal.sleep_until_stopped()

      def host_event_loop() -> None:
          # Target of the portal thread: runs the event loop until the portal stops
          if not portal_future.set_running_or_notify_cancel():
              return

          try:
              run_eventloop(
                  serve_portal, backend=backend, backend_options=backend_options
              )
          except BaseException as exc:
              # Only failures before the portal was published are reported; once
              # the future is resolved, later errors are not propagated from here
              if not portal_future.done():
                  portal_future.set_exception(exc)

      loop_thread = Thread(target=host_event_loop, daemon=True, name=name)
      loop_thread.start()
      try:
          cancel_remaining = False
          portal = portal_future.result()
          try:
              yield portal
          except BaseException:
              # The with-block failed: cancel whatever is still running in the portal
              cancel_remaining = True
              raise
          finally:
              try:
                  portal.call(portal.stop, cancel_remaining)
              except RuntimeError:
                  # Presumably the portal has already stopped
                  pass
      finally:
          loop_thread.join()
  def check_cancelled() -> None:
      """
      Check whether the cancel scope of the host task running the current worker
      thread has been cancelled.

      If that cancel scope has indeed been cancelled, the backend-specific
      cancellation exception is raised.

      :raises NoEventLoopError: if the current thread was not spawned by
          :func:`.to_thread.run_sync`
      """
      try:
          current: EventLoopToken = threadlocals.current_token
      except AttributeError:
          raise NoEventLoopError(
              "This function can only be called inside an AnyIO worker thread"
          ) from None
      else:
          current.backend_class.check_cancelled()