# functools.py — asynchronous equivalents of functools.cache, functools.lru_cache
# and functools.reduce.
  1. from __future__ import annotations
  2. __all__ = (
  3. "AsyncCacheInfo",
  4. "AsyncCacheParameters",
  5. "AsyncLRUCacheWrapper",
  6. "cache",
  7. "lru_cache",
  8. "reduce",
  9. )
  10. import functools
  11. import sys
  12. from collections import OrderedDict
  13. from collections.abc import (
  14. AsyncIterable,
  15. Awaitable,
  16. Callable,
  17. Coroutine,
  18. Hashable,
  19. Iterable,
  20. )
  21. from functools import update_wrapper
  22. from inspect import iscoroutinefunction
  23. from typing import (
  24. Any,
  25. Generic,
  26. NamedTuple,
  27. TypedDict,
  28. TypeVar,
  29. cast,
  30. final,
  31. overload,
  32. )
  33. from weakref import WeakKeyDictionary
  34. from ._core._synchronization import Lock
  35. from .lowlevel import RunVar, checkpoint
  36. if sys.version_info >= (3, 11):
  37. from typing import ParamSpec
  38. else:
  39. from typing_extensions import ParamSpec
  40. T = TypeVar("T")
  41. S = TypeVar("S")
  42. P = ParamSpec("P")
# Per-event-loop cache storage (RunVar): maps each wrapper — weakly, so a
# garbage-collected wrapper drops its cache — to an ordered mapping of call keys.
# Each entry is either (initial_missing, Lock) while the first call is still
# computing the value, or (cached_value, None) once the value has been stored.
lru_cache_items: RunVar[
    WeakKeyDictionary[
        AsyncLRUCacheWrapper[Any, Any],
        OrderedDict[Hashable, tuple[_InitialMissingType, Lock] | tuple[Any, None]],
    ]
] = RunVar("lru_cache_items")
  49. class _InitialMissingType:
  50. pass
  51. initial_missing: _InitialMissingType = _InitialMissingType()
class AsyncCacheInfo(NamedTuple):
    """Cache statistics, as returned by :meth:`AsyncLRUCacheWrapper.cache_info`."""

    # calls answered from the cache
    hits: int
    # calls that had to invoke the wrapped function
    misses: int
    # configured size limit (None = unbounded)
    maxsize: int | None
    # number of entries currently cached
    currsize: int
class AsyncCacheParameters(TypedDict):
    """Configuration, as returned by :meth:`AsyncLRUCacheWrapper.cache_parameters`."""

    # maximum number of cached entries (None = unbounded)
    maxsize: int | None
    # if True, argument types are part of the cache key
    typed: bool
    # if True, every call checkpoints at least once
    always_checkpoint: bool
@final
class AsyncLRUCacheWrapper(Generic[P, T]):
    """
    Wraps an async function with LRU caching.

    Cached values and their locks are stored per event loop (via
    :data:`lru_cache_items`); the hit/miss/size counters live on the wrapper
    itself.  The public surface mirrors the wrapper returned by
    :func:`functools.lru_cache`: :meth:`cache_info`, :meth:`cache_parameters`
    and :meth:`cache_clear`.
    """

    def __init__(
        self,
        func: Callable[..., Awaitable[T]],
        maxsize: int | None,
        typed: bool,
        always_checkpoint: bool,
    ):
        self.__wrapped__ = func
        self._hits: int = 0
        self._misses: int = 0
        # Clamp negative maxsize to 0 (no caching), like functools.lru_cache
        self._maxsize = max(maxsize, 0) if maxsize is not None else None
        self._currsize: int = 0
        self._typed = typed
        self._always_checkpoint = always_checkpoint
        # Copy __name__, __doc__ etc. from the wrapped function
        update_wrapper(self, func)

    def cache_info(self) -> AsyncCacheInfo:
        """Return a snapshot of the hit/miss counters and cache sizes."""
        return AsyncCacheInfo(self._hits, self._misses, self._maxsize, self._currsize)

    def cache_parameters(self) -> AsyncCacheParameters:
        """Return the parameters this wrapper was configured with."""
        return {
            "maxsize": self._maxsize,
            "typed": self._typed,
            "always_checkpoint": self._always_checkpoint,
        }

    def cache_clear(self) -> None:
        """Drop this wrapper's cached entries (current event loop) and reset counters."""
        if cache := lru_cache_items.get(None):
            cache.pop(self, None)

        self._hits = self._misses = self._currsize = 0

    async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
        # Easy case first: if maxsize == 0, no caching is done
        if self._maxsize == 0:
            value = await self.__wrapped__(*args, **kwargs)
            self._misses += 1
            return value

        # The key is constructed as a flat tuple to avoid memory overhead
        key: tuple[Any, ...] = args
        if kwargs:
            # initial_missing is used as a separator
            key += (initial_missing,) + sum(kwargs.items(), ())

        if self._typed:
            # Include argument types so e.g. 1 and 1.0 get distinct entries
            key += tuple(type(arg) for arg in args)
            if kwargs:
                key += (initial_missing,) + tuple(type(val) for val in kwargs.values())

        # Lazily create the per-event-loop cache store on first use
        try:
            cache = lru_cache_items.get()
        except LookupError:
            cache = WeakKeyDictionary()
            lru_cache_items.set(cache)

        try:
            cache_entry = cache[self]
        except KeyError:
            cache_entry = cache[self] = OrderedDict()

        cached_value: T | _InitialMissingType
        try:
            cached_value, lock = cache_entry[key]
        except KeyError:
            # We're the first task to call this function
            cached_value, lock = (
                initial_missing,
                Lock(fast_acquire=not self._always_checkpoint),
            )
            cache_entry[key] = cached_value, lock

        if lock is None:
            # The value was already cached
            self._hits += 1
            # Mark this key as most recently used
            cache_entry.move_to_end(key)
            if self._always_checkpoint:
                await checkpoint()

            return cast(T, cached_value)

        async with lock:
            # Check if another task filled the cache while we acquired the lock
            if (cached_value := cache_entry[key][0]) is initial_missing:
                self._misses += 1
                if self._maxsize is not None and self._currsize >= self._maxsize:
                    # Evict the least recently used entry to stay within maxsize
                    cache_entry.popitem(last=False)
                else:
                    self._currsize += 1

                value = await self.__wrapped__(*args, **kwargs)
                # Store the value; a None lock marks the entry as filled
                cache_entry[key] = value, None
            else:
                # Another task filled the cache while we were waiting for the lock
                self._hits += 1
                cache_entry.move_to_end(key)
                value = cast(T, cached_value)

        return value
  147. class _LRUCacheWrapper(Generic[T]):
  148. def __init__(self, maxsize: int | None, typed: bool, always_checkpoint: bool):
  149. self._maxsize = maxsize
  150. self._typed = typed
  151. self._always_checkpoint = always_checkpoint
  152. @overload
  153. def __call__( # type: ignore[overload-overlap]
  154. self, func: Callable[P, Coroutine[Any, Any, T]], /
  155. ) -> AsyncLRUCacheWrapper[P, T]: ...
  156. @overload
  157. def __call__(
  158. self, func: Callable[..., T], /
  159. ) -> functools._lru_cache_wrapper[T]: ...
  160. def __call__(
  161. self, f: Callable[P, Coroutine[Any, Any, T]] | Callable[..., T], /
  162. ) -> AsyncLRUCacheWrapper[P, T] | functools._lru_cache_wrapper[T]:
  163. if iscoroutinefunction(f):
  164. return AsyncLRUCacheWrapper(
  165. f, self._maxsize, self._typed, self._always_checkpoint
  166. )
  167. return functools.lru_cache(maxsize=self._maxsize, typed=self._typed)(f) # type: ignore[arg-type]
  168. @overload
  169. def cache( # type: ignore[overload-overlap]
  170. func: Callable[P, Coroutine[Any, Any, T]], /
  171. ) -> AsyncLRUCacheWrapper[P, T]: ...
  172. @overload
  173. def cache(func: Callable[..., T], /) -> functools._lru_cache_wrapper[T]: ...
  174. def cache(
  175. func: Callable[..., T] | Callable[P, Coroutine[Any, Any, T]], /
  176. ) -> AsyncLRUCacheWrapper[P, T] | functools._lru_cache_wrapper[T]:
  177. """
  178. A convenient shortcut for :func:`lru_cache` with ``maxsize=None``.
  179. This is the asynchronous equivalent to :func:`functools.cache`.
  180. """
  181. return lru_cache(maxsize=None)(func)
  182. @overload
  183. def lru_cache(
  184. *, maxsize: int | None = ..., typed: bool = ..., always_checkpoint: bool = ...
  185. ) -> _LRUCacheWrapper[Any]: ...
  186. @overload
  187. def lru_cache( # type: ignore[overload-overlap]
  188. func: Callable[P, Coroutine[Any, Any, T]], /
  189. ) -> AsyncLRUCacheWrapper[P, T]: ...
  190. @overload
  191. def lru_cache(func: Callable[..., T], /) -> functools._lru_cache_wrapper[T]: ...
  192. def lru_cache(
  193. func: Callable[P, Coroutine[Any, Any, T]] | Callable[..., T] | None = None,
  194. /,
  195. *,
  196. maxsize: int | None = 128,
  197. typed: bool = False,
  198. always_checkpoint: bool = False,
  199. ) -> (
  200. AsyncLRUCacheWrapper[P, T] | functools._lru_cache_wrapper[T] | _LRUCacheWrapper[Any]
  201. ):
  202. """
  203. An asynchronous version of :func:`functools.lru_cache`.
  204. If a synchronous function is passed, the standard library
  205. :func:`functools.lru_cache` is applied instead.
  206. :param always_checkpoint: if ``True``, every call to the cached function will be
  207. guaranteed to yield control to the event loop at least once
  208. .. note:: Caches and locks are managed on a per-event loop basis.
  209. """
  210. if func is None:
  211. return _LRUCacheWrapper[Any](maxsize, typed, always_checkpoint)
  212. if not callable(func):
  213. raise TypeError("the first argument must be callable")
  214. return _LRUCacheWrapper[T](maxsize, typed, always_checkpoint)(func)
  215. @overload
  216. async def reduce(
  217. function: Callable[[T, S], Awaitable[T]],
  218. iterable: Iterable[S] | AsyncIterable[S],
  219. /,
  220. initial: T,
  221. ) -> T: ...
  222. @overload
  223. async def reduce(
  224. function: Callable[[T, T], Awaitable[T]],
  225. iterable: Iterable[T] | AsyncIterable[T],
  226. /,
  227. ) -> T: ...
  228. async def reduce( # type: ignore[misc]
  229. function: Callable[[T, T], Awaitable[T]] | Callable[[T, S], Awaitable[T]],
  230. iterable: Iterable[T] | Iterable[S] | AsyncIterable[T] | AsyncIterable[S],
  231. /,
  232. initial: T | _InitialMissingType = initial_missing,
  233. ) -> T:
  234. """
  235. Asynchronous version of :func:`functools.reduce`.
  236. :param function: a coroutine function that takes two arguments: the accumulated
  237. value and the next element from the iterable
  238. :param iterable: an iterable or async iterable
  239. :param initial: the initial value (if missing, the first element of the iterable is
  240. used as the initial value)
  241. """
  242. element: Any
  243. function_called = False
  244. if isinstance(iterable, AsyncIterable):
  245. async_it = iterable.__aiter__()
  246. if initial is initial_missing:
  247. try:
  248. value = cast(T, await async_it.__anext__())
  249. except StopAsyncIteration:
  250. raise TypeError(
  251. "reduce() of empty sequence with no initial value"
  252. ) from None
  253. else:
  254. value = cast(T, initial)
  255. async for element in async_it:
  256. value = await function(value, element)
  257. function_called = True
  258. elif isinstance(iterable, Iterable):
  259. it = iter(iterable)
  260. if initial is initial_missing:
  261. try:
  262. value = cast(T, next(it))
  263. except StopIteration:
  264. raise TypeError(
  265. "reduce() of empty sequence with no initial value"
  266. ) from None
  267. else:
  268. value = cast(T, initial)
  269. for element in it:
  270. value = await function(value, element)
  271. function_called = True
  272. else:
  273. raise TypeError("reduce() argument 2 must be an iterable or async iterable")
  274. # Make sure there is at least one checkpoint, even if an empty iterable and an
  275. # initial value were given
  276. if not function_called:
  277. await checkpoint()
  278. return value