concurrency.py 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. from contextlib import AsyncExitStack as AsyncExitStack # noqa
  2. from contextlib import asynccontextmanager as asynccontextmanager
  3. from typing import AsyncGenerator, ContextManager, TypeVar
  4. import anyio
  5. from anyio import CapacityLimiter
  6. from starlette.concurrency import iterate_in_threadpool as iterate_in_threadpool # noqa
  7. from starlette.concurrency import run_in_threadpool as run_in_threadpool # noqa
  8. from starlette.concurrency import ( # noqa
  9. run_until_first_complete as run_until_first_complete,
  10. )
  11. _T = TypeVar("_T")
  12. @asynccontextmanager
  13. async def contextmanager_in_threadpool(
  14. cm: ContextManager[_T],
  15. ) -> AsyncGenerator[_T, None]:
  16. # blocking __exit__ from running waiting on a free thread
  17. # can create race conditions/deadlocks if the context manager itself
  18. # has its own internal pool (e.g. a database connection pool)
  19. # to avoid this we let __exit__ run without a capacity limit
  20. # since we're creating a new limiter for each call, any non-zero limit
  21. # works (1 is arbitrary)
  22. exit_limiter = CapacityLimiter(1)
  23. try:
  24. yield await run_in_threadpool(cm.__enter__)
  25. except Exception as e:
  26. ok = bool(
  27. await anyio.to_thread.run_sync(
  28. cm.__exit__, type(e), e, None, limiter=exit_limiter
  29. )
  30. )
  31. if not ok:
  32. raise e
  33. else:
  34. await anyio.to_thread.run_sync(
  35. cm.__exit__, None, None, None, limiter=exit_limiter
  36. )