_utils.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. import asyncio
  2. import functools
  3. import sys
  4. import typing
  5. from types import TracebackType
  6. if sys.version_info < (3, 8): # pragma: no cover
  7. from typing_extensions import Protocol
  8. else: # pragma: no cover
  9. from typing import Protocol
  10. def is_async_callable(obj: typing.Any) -> bool:
  11. while isinstance(obj, functools.partial):
  12. obj = obj.func
  13. return asyncio.iscoroutinefunction(obj) or (
  14. callable(obj) and asyncio.iscoroutinefunction(obj.__call__)
  15. )
  16. T_co = typing.TypeVar("T_co", covariant=True)
  17. # TODO: once 3.8 is the minimum supported version (27 Jun 2023)
  18. # this can just become
  19. # class AwaitableOrContextManager(
  20. # typing.Awaitable[T_co],
  21. # typing.AsyncContextManager[T_co],
  22. # typing.Protocol[T_co],
  23. # ):
  24. # pass
  25. class AwaitableOrContextManager(Protocol[T_co]):
  26. def __await__(self) -> typing.Generator[typing.Any, None, T_co]:
  27. ... # pragma: no cover
  28. async def __aenter__(self) -> T_co:
  29. ... # pragma: no cover
  30. async def __aexit__(
  31. self,
  32. __exc_type: typing.Optional[typing.Type[BaseException]],
  33. __exc_value: typing.Optional[BaseException],
  34. __traceback: typing.Optional[TracebackType],
  35. ) -> typing.Union[bool, None]:
  36. ... # pragma: no cover
  37. class SupportsAsyncClose(Protocol):
  38. async def close(self) -> None:
  39. ... # pragma: no cover
  40. SupportsAsyncCloseType = typing.TypeVar(
  41. "SupportsAsyncCloseType", bound=SupportsAsyncClose, covariant=False
  42. )
  43. class AwaitableOrContextManagerWrapper(typing.Generic[SupportsAsyncCloseType]):
  44. __slots__ = ("aw", "entered")
  45. def __init__(self, aw: typing.Awaitable[SupportsAsyncCloseType]) -> None:
  46. self.aw = aw
  47. def __await__(self) -> typing.Generator[typing.Any, None, SupportsAsyncCloseType]:
  48. return self.aw.__await__()
  49. async def __aenter__(self) -> SupportsAsyncCloseType:
  50. self.entered = await self.aw
  51. return self.entered
  52. async def __aexit__(self, *args: typing.Any) -> typing.Union[None, bool]:
  53. await self.entered.close()
  54. return None