from_thread.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500
  1. from __future__ import annotations
  2. import threading
  3. from asyncio import iscoroutine
  4. from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
  5. from contextlib import AbstractContextManager, contextmanager
  6. from types import TracebackType
  7. from typing import (
  8. Any,
  9. AsyncContextManager,
  10. Awaitable,
  11. Callable,
  12. ContextManager,
  13. Generator,
  14. Generic,
  15. Iterable,
  16. TypeVar,
  17. cast,
  18. overload,
  19. )
  20. from warnings import warn
  21. from ._core import _eventloop
  22. from ._core._eventloop import get_asynclib, get_cancelled_exc_class, threadlocals
  23. from ._core._synchronization import Event
  24. from ._core._tasks import CancelScope, create_task_group
  25. from .abc._tasks import TaskStatus
# Return type of the callables / coroutine functions run through the helpers below.
T_Retval = TypeVar("T_Retval")
# Type of the value produced by a wrapped async context manager's ``__aenter__()``.
T_co = TypeVar("T_co")
  28. def run(func: Callable[..., Awaitable[T_Retval]], *args: object) -> T_Retval:
  29. """
  30. Call a coroutine function from a worker thread.
  31. :param func: a coroutine function
  32. :param args: positional arguments for the callable
  33. :return: the return value of the coroutine function
  34. """
  35. try:
  36. asynclib = threadlocals.current_async_module
  37. except AttributeError:
  38. raise RuntimeError("This function can only be run from an AnyIO worker thread")
  39. return asynclib.run_async_from_thread(func, *args)
  40. def run_async_from_thread(
  41. func: Callable[..., Awaitable[T_Retval]], *args: object
  42. ) -> T_Retval:
  43. warn(
  44. "run_async_from_thread() has been deprecated, use anyio.from_thread.run() instead",
  45. DeprecationWarning,
  46. )
  47. return run(func, *args)
  48. def run_sync(func: Callable[..., T_Retval], *args: object) -> T_Retval:
  49. """
  50. Call a function in the event loop thread from a worker thread.
  51. :param func: a callable
  52. :param args: positional arguments for the callable
  53. :return: the return value of the callable
  54. """
  55. try:
  56. asynclib = threadlocals.current_async_module
  57. except AttributeError:
  58. raise RuntimeError("This function can only be run from an AnyIO worker thread")
  59. return asynclib.run_sync_from_thread(func, *args)
  60. def run_sync_from_thread(func: Callable[..., T_Retval], *args: object) -> T_Retval:
  61. warn(
  62. "run_sync_from_thread() has been deprecated, use anyio.from_thread.run_sync() instead",
  63. DeprecationWarning,
  64. )
  65. return run_sync(func, *args)
class _BlockingAsyncContextManager(Generic[T_co], AbstractContextManager):
    """
    Synchronous wrapper around an asynchronous context manager, driven via a portal.

    ``__enter__()`` schedules :meth:`run_async_cm` as a task in the portal and blocks
    until ``__aenter__()`` has completed; ``__exit__()`` signals that task to call
    ``__aexit__()`` and then blocks until the task returns.
    """

    # Resolved by run_async_cm() with the value (or exception) from __aenter__().
    _enter_future: Future
    # Future of the portal task running run_async_cm(); its result is the return value
    # of __aexit__().
    _exit_future: Future
    # Set through the portal by __exit__() so run_async_cm() can go on to __aexit__().
    _exit_event: Event
    # Exception details handed to __aexit__(); stays (None, None, None) unless
    # __exit__() stored the real values before run_async_cm() reached its finally block.
    _exit_exc_info: tuple[
        type[BaseException] | None, BaseException | None, TracebackType | None
    ] = (None, None, None)

    def __init__(self, async_cm: AsyncContextManager[T_co], portal: BlockingPortal):
        # async_cm: the asynchronous context manager being wrapped
        # portal: the portal whose task group runs __aenter__() and __aexit__()
        self._async_cm = async_cm
        self._portal = portal

    async def run_async_cm(self) -> bool | None:
        """
        Enter the async context manager, wait for the sync side to exit, then exit it.

        Runs as a task started by :meth:`__enter__` through the portal.

        :return: the return value of ``__aexit__()``
        """
        try:
            # Created inside the task rather than in __init__() -- presumably because
            # the Event must be created within the running event loop; TODO confirm.
            self._exit_event = Event()
            value = await self._async_cm.__aenter__()
        except BaseException as exc:
            # Unblock __enter__() with the failure, and let it propagate in the task too.
            self._enter_future.set_exception(exc)
            raise
        else:
            self._enter_future.set_result(value)

        try:
            # Wait for the sync context manager to exit.
            # This next statement can raise `get_cancelled_exc_class()` if
            # something went wrong in a task group in this async context
            # manager.
            await self._exit_event.wait()
        finally:
            # In case of cancellation, it could be that we end up here before
            # `_BlockingAsyncContextManager.__exit__` is called, and an
            # `_exit_exc_info` has been set.
            result = await self._async_cm.__aexit__(*self._exit_exc_info)
            # NOTE: returning from inside ``finally`` discards any exception still
            # propagating at this point (e.g. a cancellation raised by the wait above).
            return result

    def __enter__(self) -> T_co:
        """
        Start :meth:`run_async_cm` in the portal and block until ``__aenter__()`` is done.

        :return: the value produced by ``__aenter__()``
        :raises BaseException: whatever ``__aenter__()`` raised (re-raised by
            ``Future.result()``)
        """
        self._enter_future = Future()
        self._exit_future = self._portal.start_task_soon(self.run_async_cm)
        cm = self._enter_future.result()
        return cast(T_co, cm)

    def __exit__(
        self,
        __exc_type: type[BaseException] | None,
        __exc_value: BaseException | None,
        __traceback: TracebackType | None,
    ) -> bool | None:
        """
        Pass the exception info to the portal task, release it, and wait for it to finish.

        :return: the return value of ``__aexit__()``; a truthy value suppresses the
            exception raised in the ``with`` block
        """
        # Must be stored before the event is set, so run_async_cm() sees it in its
        # finally block.
        self._exit_exc_info = __exc_type, __exc_value, __traceback
        # Set the event inside the event loop thread via the portal.
        self._portal.call(self._exit_event.set)
        return self._exit_future.result()
  111. class _BlockingPortalTaskStatus(TaskStatus):
  112. def __init__(self, future: Future):
  113. self._future = future
  114. def started(self, value: object = None) -> None:
  115. self._future.set_result(value)
  116. class BlockingPortal:
  117. """An object that lets external threads run code in an asynchronous event loop."""
  118. def __new__(cls) -> BlockingPortal:
  119. return get_asynclib().BlockingPortal()
  120. def __init__(self) -> None:
  121. self._event_loop_thread_id: int | None = threading.get_ident()
  122. self._stop_event = Event()
  123. self._task_group = create_task_group()
  124. self._cancelled_exc_class = get_cancelled_exc_class()
  125. async def __aenter__(self) -> BlockingPortal:
  126. await self._task_group.__aenter__()
  127. return self
  128. async def __aexit__(
  129. self,
  130. exc_type: type[BaseException] | None,
  131. exc_val: BaseException | None,
  132. exc_tb: TracebackType | None,
  133. ) -> bool | None:
  134. await self.stop()
  135. return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
  136. def _check_running(self) -> None:
  137. if self._event_loop_thread_id is None:
  138. raise RuntimeError("This portal is not running")
  139. if self._event_loop_thread_id == threading.get_ident():
  140. raise RuntimeError(
  141. "This method cannot be called from the event loop thread"
  142. )
  143. async def sleep_until_stopped(self) -> None:
  144. """Sleep until :meth:`stop` is called."""
  145. await self._stop_event.wait()
  146. async def stop(self, cancel_remaining: bool = False) -> None:
  147. """
  148. Signal the portal to shut down.
  149. This marks the portal as no longer accepting new calls and exits from
  150. :meth:`sleep_until_stopped`.
  151. :param cancel_remaining: ``True`` to cancel all the remaining tasks, ``False`` to let them
  152. finish before returning
  153. """
  154. self._event_loop_thread_id = None
  155. self._stop_event.set()
  156. if cancel_remaining:
  157. self._task_group.cancel_scope.cancel()
  158. async def _call_func(
  159. self, func: Callable, args: tuple, kwargs: dict[str, Any], future: Future
  160. ) -> None:
  161. def callback(f: Future) -> None:
  162. if f.cancelled() and self._event_loop_thread_id not in (
  163. None,
  164. threading.get_ident(),
  165. ):
  166. self.call(scope.cancel)
  167. try:
  168. retval = func(*args, **kwargs)
  169. if iscoroutine(retval):
  170. with CancelScope() as scope:
  171. if future.cancelled():
  172. scope.cancel()
  173. else:
  174. future.add_done_callback(callback)
  175. retval = await retval
  176. except self._cancelled_exc_class:
  177. future.cancel()
  178. except BaseException as exc:
  179. if not future.cancelled():
  180. future.set_exception(exc)
  181. # Let base exceptions fall through
  182. if not isinstance(exc, Exception):
  183. raise
  184. else:
  185. if not future.cancelled():
  186. future.set_result(retval)
  187. finally:
  188. scope = None # type: ignore[assignment]
  189. def _spawn_task_from_thread(
  190. self,
  191. func: Callable,
  192. args: tuple,
  193. kwargs: dict[str, Any],
  194. name: object,
  195. future: Future,
  196. ) -> None:
  197. """
  198. Spawn a new task using the given callable.
  199. Implementors must ensure that the future is resolved when the task finishes.
  200. :param func: a callable
  201. :param args: positional arguments to be passed to the callable
  202. :param kwargs: keyword arguments to be passed to the callable
  203. :param name: name of the task (will be coerced to a string if not ``None``)
  204. :param future: a future that will resolve to the return value of the callable, or the
  205. exception raised during its execution
  206. """
  207. raise NotImplementedError
  208. @overload
  209. def call(self, func: Callable[..., Awaitable[T_Retval]], *args: object) -> T_Retval:
  210. ...
  211. @overload
  212. def call(self, func: Callable[..., T_Retval], *args: object) -> T_Retval:
  213. ...
  214. def call(
  215. self, func: Callable[..., Awaitable[T_Retval] | T_Retval], *args: object
  216. ) -> T_Retval:
  217. """
  218. Call the given function in the event loop thread.
  219. If the callable returns a coroutine object, it is awaited on.
  220. :param func: any callable
  221. :raises RuntimeError: if the portal is not running or if this method is called from within
  222. the event loop thread
  223. """
  224. return cast(T_Retval, self.start_task_soon(func, *args).result())
  225. @overload
  226. def spawn_task(
  227. self,
  228. func: Callable[..., Awaitable[T_Retval]],
  229. *args: object,
  230. name: object = None,
  231. ) -> Future[T_Retval]:
  232. ...
  233. @overload
  234. def spawn_task(
  235. self, func: Callable[..., T_Retval], *args: object, name: object = None
  236. ) -> Future[T_Retval]:
  237. ...
  238. def spawn_task(
  239. self,
  240. func: Callable[..., Awaitable[T_Retval] | T_Retval],
  241. *args: object,
  242. name: object = None,
  243. ) -> Future[T_Retval]:
  244. """
  245. Start a task in the portal's task group.
  246. :param func: the target coroutine function
  247. :param args: positional arguments passed to ``func``
  248. :param name: name of the task (will be coerced to a string if not ``None``)
  249. :return: a future that resolves with the return value of the callable if the task completes
  250. successfully, or with the exception raised in the task
  251. :raises RuntimeError: if the portal is not running or if this method is called from within
  252. the event loop thread
  253. .. versionadded:: 2.1
  254. .. deprecated:: 3.0
  255. Use :meth:`start_task_soon` instead. If your code needs AnyIO 2 compatibility, you
  256. can keep using this until AnyIO 4.
  257. """
  258. warn(
  259. "spawn_task() is deprecated -- use start_task_soon() instead",
  260. DeprecationWarning,
  261. )
  262. return self.start_task_soon(func, *args, name=name) # type: ignore[arg-type]
  263. @overload
  264. def start_task_soon(
  265. self,
  266. func: Callable[..., Awaitable[T_Retval]],
  267. *args: object,
  268. name: object = None,
  269. ) -> Future[T_Retval]:
  270. ...
  271. @overload
  272. def start_task_soon(
  273. self, func: Callable[..., T_Retval], *args: object, name: object = None
  274. ) -> Future[T_Retval]:
  275. ...
  276. def start_task_soon(
  277. self,
  278. func: Callable[..., Awaitable[T_Retval] | T_Retval],
  279. *args: object,
  280. name: object = None,
  281. ) -> Future[T_Retval]:
  282. """
  283. Start a task in the portal's task group.
  284. The task will be run inside a cancel scope which can be cancelled by cancelling the
  285. returned future.
  286. :param func: the target function
  287. :param args: positional arguments passed to ``func``
  288. :param name: name of the task (will be coerced to a string if not ``None``)
  289. :return: a future that resolves with the return value of the callable if the
  290. task completes successfully, or with the exception raised in the task
  291. :raises RuntimeError: if the portal is not running or if this method is called
  292. from within the event loop thread
  293. :rtype: concurrent.futures.Future[T_Retval]
  294. .. versionadded:: 3.0
  295. """
  296. self._check_running()
  297. f: Future = Future()
  298. self._spawn_task_from_thread(func, args, {}, name, f)
  299. return f
  300. def start_task(
  301. self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
  302. ) -> tuple[Future[Any], Any]:
  303. """
  304. Start a task in the portal's task group and wait until it signals for readiness.
  305. This method works the same way as :meth:`.abc.TaskGroup.start`.
  306. :param func: the target function
  307. :param args: positional arguments passed to ``func``
  308. :param name: name of the task (will be coerced to a string if not ``None``)
  309. :return: a tuple of (future, task_status_value) where the ``task_status_value``
  310. is the value passed to ``task_status.started()`` from within the target
  311. function
  312. :rtype: tuple[concurrent.futures.Future[Any], Any]
  313. .. versionadded:: 3.0
  314. """
  315. def task_done(future: Future) -> None:
  316. if not task_status_future.done():
  317. if future.cancelled():
  318. task_status_future.cancel()
  319. elif future.exception():
  320. task_status_future.set_exception(future.exception())
  321. else:
  322. exc = RuntimeError(
  323. "Task exited without calling task_status.started()"
  324. )
  325. task_status_future.set_exception(exc)
  326. self._check_running()
  327. task_status_future: Future = Future()
  328. task_status = _BlockingPortalTaskStatus(task_status_future)
  329. f: Future = Future()
  330. f.add_done_callback(task_done)
  331. self._spawn_task_from_thread(func, args, {"task_status": task_status}, name, f)
  332. return f, task_status_future.result()
  333. def wrap_async_context_manager(
  334. self, cm: AsyncContextManager[T_co]
  335. ) -> ContextManager[T_co]:
  336. """
  337. Wrap an async context manager as a synchronous context manager via this portal.
  338. Spawns a task that will call both ``__aenter__()`` and ``__aexit__()``, stopping in the
  339. middle until the synchronous context manager exits.
  340. :param cm: an asynchronous context manager
  341. :return: a synchronous context manager
  342. .. versionadded:: 2.1
  343. """
  344. return _BlockingAsyncContextManager(cm, self)
  345. def create_blocking_portal() -> BlockingPortal:
  346. """
  347. Create a portal for running functions in the event loop thread from external threads.
  348. Use this function in asynchronous code when you need to allow external threads access to the
  349. event loop where your asynchronous code is currently running.
  350. .. deprecated:: 3.0
  351. Use :class:`.BlockingPortal` directly.
  352. """
  353. warn(
  354. "create_blocking_portal() has been deprecated -- use anyio.from_thread.BlockingPortal() "
  355. "directly",
  356. DeprecationWarning,
  357. )
  358. return BlockingPortal()
@contextmanager
def start_blocking_portal(
    backend: str = "asyncio", backend_options: dict[str, Any] | None = None
) -> Generator[BlockingPortal, Any, None]:
    """
    Start a new event loop in a new thread and run a blocking portal in its main task.

    The parameters are the same as for :func:`~anyio.run`.

    :param backend: name of the backend
    :param backend_options: backend options
    :return: a context manager that yields a blocking portal

    .. versionchanged:: 3.0
        Usage as a context manager is now required.

    """

    async def run_portal() -> None:
        # Runs in the executor thread's event loop: open a portal, hand it to the calling
        # thread through ``future`` (unless that future was already cancelled), then keep
        # the portal alive until it is stopped.
        async with BlockingPortal() as portal_:
            if future.set_running_or_notify_cancel():
                future.set_result(portal_)
                await portal_.sleep_until_stopped()

    future: Future[BlockingPortal] = Future()
    # A single worker thread hosts the event loop; leaving this ``with`` block waits
    # for that thread's work to finish.
    with ThreadPoolExecutor(1) as executor:
        run_future = executor.submit(
            _eventloop.run,
            run_portal,  # type: ignore[arg-type]
            backend=backend,
            backend_options=backend_options,
        )
        try:
            # Block until either the portal has been handed over or the event loop
            # call has already completed (for example because it raised).
            wait(
                cast(Iterable[Future], [run_future, future]),
                return_when=FIRST_COMPLETED,
            )
        except BaseException:
            # Interrupted while waiting: try to abort the startup before re-raising.
            future.cancel()
            run_future.cancel()
            raise

        if future.done():
            portal = future.result()
            # Cancel tasks still running in the portal only if the ``with`` body raised;
            # otherwise let them finish.
            cancel_remaining_tasks = False
            try:
                yield portal
            except BaseException:
                cancel_remaining_tasks = True
                raise
            finally:
                try:
                    portal.call(portal.stop, cancel_remaining_tasks)
                except RuntimeError:
                    # call() raises RuntimeError when the portal is no longer running
                    # (e.g. it was already stopped), so there is nothing left to stop.
                    pass

        # Propagate any exception raised by the event loop thread.
        run_future.result()