concurrency.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. import functools
  2. import sys
  3. import typing
  4. import warnings
  5. import anyio
  6. if sys.version_info >= (3, 10): # pragma: no cover
  7. from typing import ParamSpec
  8. else: # pragma: no cover
  9. from typing_extensions import ParamSpec
  10. T = typing.TypeVar("T")
  11. P = ParamSpec("P")
  12. async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -> None:
  13. warnings.warn(
  14. "run_until_first_complete is deprecated "
  15. "and will be removed in a future version.",
  16. DeprecationWarning,
  17. )
  18. async with anyio.create_task_group() as task_group:
  19. async def run(func: typing.Callable[[], typing.Coroutine]) -> None:
  20. await func()
  21. task_group.cancel_scope.cancel()
  22. for func, kwargs in args:
  23. task_group.start_soon(run, functools.partial(func, **kwargs))
  24. async def run_in_threadpool(
  25. func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs
  26. ) -> T:
  27. if kwargs: # pragma: no cover
  28. # run_sync doesn't accept 'kwargs', so bind them in here
  29. func = functools.partial(func, **kwargs)
  30. return await anyio.to_thread.run_sync(func, *args)
  31. class _StopIteration(Exception):
  32. pass
  33. def _next(iterator: typing.Iterator[T]) -> T:
  34. # We can't raise `StopIteration` from within the threadpool iterator
  35. # and catch it outside that context, so we coerce them into a different
  36. # exception type.
  37. try:
  38. return next(iterator)
  39. except StopIteration:
  40. raise _StopIteration
  41. async def iterate_in_threadpool(
  42. iterator: typing.Iterator[T],
  43. ) -> typing.AsyncIterator[T]:
  44. while True:
  45. try:
  46. yield await anyio.to_thread.run_sync(_next, iterator)
  47. except _StopIteration:
  48. break