| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347 |
- from __future__ import annotations
# Names that make up this module's public API
__all__ = (
    "AsyncCacheInfo",
    "AsyncCacheParameters",
    "AsyncLRUCacheWrapper",
    "cache",
    "lru_cache",
    "reduce",
)
- import functools
- import sys
- from collections import OrderedDict
- from collections.abc import (
- AsyncIterable,
- Awaitable,
- Callable,
- Coroutine,
- Hashable,
- Iterable,
- )
- from functools import update_wrapper
- from inspect import iscoroutinefunction
- from typing import (
- Any,
- Generic,
- NamedTuple,
- TypedDict,
- TypeVar,
- cast,
- final,
- overload,
- )
- from weakref import WeakKeyDictionary
- from ._core._synchronization import Lock
- from .lowlevel import RunVar, checkpoint
- if sys.version_info >= (3, 11):
- from typing import ParamSpec
- else:
- from typing_extensions import ParamSpec
# Type variables shared by the generic signatures in this module
T = TypeVar("T")
S = TypeVar("S")
P = ParamSpec("P")

# Cache storage, kept in a RunVar so that caches are managed per event loop.
# It maps each async cache wrapper (held weakly, via WeakKeyDictionary) to an
# OrderedDict of that wrapper's entries. Cache hits move an entry to the end, and
# eviction pops from the front, so the front holds the least recently used entry.
# An entry is (initial_missing, Lock) while its value is being computed, and
# (value, None) once the value has been cached.
lru_cache_items: RunVar[
    WeakKeyDictionary[
        AsyncLRUCacheWrapper[Any, Any],
        OrderedDict[Hashable, tuple[_InitialMissingType, Lock] | tuple[Any, None]],
    ]
] = RunVar("lru_cache_items")
class _InitialMissingType:
    """Type of the private ``initial_missing`` sentinel (compared by identity)."""


# Sentinel meaning "no value here". It marks cache entries whose value is still
# being computed and an omitted ``initial`` argument to reduce(). It also acts as a
# separator inside cache keys.
initial_missing: _InitialMissingType = _InitialMissingType()
class AsyncCacheInfo(NamedTuple):
    """Statistics snapshot returned by ``AsyncLRUCacheWrapper.cache_info()``."""

    hits: int  # calls answered from the cache
    misses: int  # calls that had to await the wrapped function
    maxsize: int | None  # configured size limit (None means unbounded)
    currsize: int  # entries currently counted against maxsize
class AsyncCacheParameters(TypedDict):
    """Settings returned by ``AsyncLRUCacheWrapper.cache_parameters()``."""

    maxsize: int | None  # effective size limit (None means unbounded)
    typed: bool  # whether argument types are part of the cache key
    always_checkpoint: bool  # whether cache hits still checkpoint
@final
class AsyncLRUCacheWrapper(Generic[P, T]):
    """
    Caching wrapper around a coroutine function.

    :func:`lru_cache` and :func:`cache` return this when they decorate a coroutine
    function. Cached values and per-key locks are stored in the ``lru_cache_items``
    run variable, so they are kept per event loop. The hit/miss/size counters are
    attributes of the wrapper itself.
    """

    def __init__(
        self,
        func: Callable[..., Awaitable[T]],
        maxsize: int | None,
        typed: bool,
        always_checkpoint: bool,
    ):
        self.__wrapped__ = func
        self._hits: int = 0
        self._misses: int = 0
        # Negative sizes behave like 0 (no caching); None means no size limit
        self._maxsize = max(maxsize, 0) if maxsize is not None else None
        self._currsize: int = 0
        self._typed = typed
        self._always_checkpoint = always_checkpoint
        update_wrapper(self, func)

    def cache_info(self) -> AsyncCacheInfo:
        """Return the hit/miss counters, the size limit and the current size."""
        return AsyncCacheInfo(self._hits, self._misses, self._maxsize, self._currsize)

    def cache_parameters(self) -> AsyncCacheParameters:
        """Return the settings this wrapper was created with."""
        return {
            "maxsize": self._maxsize,
            "typed": self._typed,
            "always_checkpoint": self._always_checkpoint,
        }

    def cache_clear(self) -> None:
        """Drop this wrapper's entries for the current event loop; reset counters."""
        if cache := lru_cache_items.get(None):
            cache.pop(self, None)

        self._hits = self._misses = self._currsize = 0

    async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
        """
        Return the cached result for these arguments.

        On a miss, await the wrapped function and cache its result. Concurrent calls
        with the same arguments share one lock, so the function runs only once for
        them.
        """
        # Easy case first: if maxsize == 0, no caching is done
        if self._maxsize == 0:
            value = await self.__wrapped__(*args, **kwargs)
            self._misses += 1
            return value

        # The key is constructed as a flat tuple to avoid memory overhead
        key: tuple[Any, ...] = args
        if kwargs:
            # initial_missing is used as a separator
            key += (initial_missing,) + sum(kwargs.items(), ())

        if self._typed:
            key += tuple(type(arg) for arg in args)
            if kwargs:
                key += (initial_missing,) + tuple(type(val) for val in kwargs.values())

        try:
            cache = lru_cache_items.get()
        except LookupError:
            cache = WeakKeyDictionary()
            lru_cache_items.set(cache)

        try:
            cache_entry = cache[self]
        except KeyError:
            cache_entry = cache[self] = OrderedDict()

        cached_value: T | _InitialMissingType
        try:
            cached_value, lock = cache_entry[key]
        except KeyError:
            # We're the first task to call this function with these arguments
            cached_value, lock = (
                initial_missing,
                Lock(fast_acquire=not self._always_checkpoint),
            )
            cache_entry[key] = cached_value, lock

        if lock is None:
            # The value was already cached
            self._hits += 1
            cache_entry.move_to_end(key)
            if self._always_checkpoint:
                await checkpoint()

            return cast(T, cached_value)

        async with lock:
            # Another task may have filled the cache while we were waiting for the
            # lock. The entry may also be gone altogether: evicted by a miss on
            # another key, or dropped because the previous lock holder failed.
            entry = cache_entry.get(key)
            if entry is not None and entry[1] is None:
                self._hits += 1
                cache_entry.move_to_end(key)
                return cast(T, entry[0])

            self._misses += 1
            if entry is None:
                # Put a placeholder back so that new callers wait on our lock, and
                # so that the cleanup below has an entry of our own to remove
                cache_entry[key] = initial_missing, lock

            # Reserve a slot for the new value, evicting the least recently used
            # entry if the cache is full
            if self._maxsize is not None and self._currsize >= self._maxsize:
                cache_entry.popitem(last=False)
            else:
                self._currsize += 1

            try:
                value = await self.__wrapped__(*args, **kwargs)
            except BaseException:
                # Nothing gets cached for this key, so drop our placeholder and
                # give back the slot reserved above. Skip this if the placeholder
                # was already evicted (the evicting call took over its slot) or the
                # cache was cleared meanwhile (which already reset the counter).
                entry = cache_entry.get(key)
                if entry is not None and entry[1] is lock:
                    del cache_entry[key]
                    if cache.get(self) is cache_entry:
                        self._currsize -= 1

                raise

            cache_entry[key] = value, None

        return value
class _LRUCacheWrapper(Generic[T]):
    """
    Decorator returned by ``lru_cache(...)`` when it is called with options only.

    Applying it to a function picks the async or the standard-library cache,
    depending on whether that function is a coroutine function.
    """

    def __init__(self, maxsize: int | None, typed: bool, always_checkpoint: bool):
        # Stored as-is and forwarded to whichever cache implementation is chosen
        self._maxsize = maxsize
        self._typed = typed
        self._always_checkpoint = always_checkpoint

    @overload
    def __call__(  # type: ignore[overload-overlap]
        self, func: Callable[P, Coroutine[Any, Any, T]], /
    ) -> AsyncLRUCacheWrapper[P, T]: ...

    @overload
    def __call__(
        self, func: Callable[..., T], /
    ) -> functools._lru_cache_wrapper[T]: ...

    def __call__(
        self, f: Callable[P, Coroutine[Any, Any, T]] | Callable[..., T], /
    ) -> AsyncLRUCacheWrapper[P, T] | functools._lru_cache_wrapper[T]:
        if not iscoroutinefunction(f):
            # Plain callables are handed over to the standard library
            stdlib_decorator = functools.lru_cache(
                maxsize=self._maxsize, typed=self._typed
            )
            return stdlib_decorator(f)  # type: ignore[arg-type]

        return AsyncLRUCacheWrapper(
            f, self._maxsize, self._typed, self._always_checkpoint
        )
@overload
def cache(  # type: ignore[overload-overlap]
    func: Callable[P, Coroutine[Any, Any, T]], /
) -> AsyncLRUCacheWrapper[P, T]: ...


@overload
def cache(func: Callable[..., T], /) -> functools._lru_cache_wrapper[T]: ...


def cache(
    func: Callable[..., T] | Callable[P, Coroutine[Any, Any, T]], /
) -> AsyncLRUCacheWrapper[P, T] | functools._lru_cache_wrapper[T]:
    """
    Cache every result of ``func``, with no limit on the number of entries.

    This is shorthand for ``lru_cache(maxsize=None)(func)``. It is the asynchronous
    counterpart of :func:`functools.cache`.
    """
    unbounded_decorator = lru_cache(maxsize=None)
    return unbounded_decorator(func)
@overload
def lru_cache(
    *, maxsize: int | None = ..., typed: bool = ..., always_checkpoint: bool = ...
) -> _LRUCacheWrapper[Any]: ...


@overload
def lru_cache(  # type: ignore[overload-overlap]
    func: Callable[P, Coroutine[Any, Any, T]], /
) -> AsyncLRUCacheWrapper[P, T]: ...


@overload
def lru_cache(func: Callable[..., T], /) -> functools._lru_cache_wrapper[T]: ...


def lru_cache(
    func: Callable[P, Coroutine[Any, Any, T]] | Callable[..., T] | None = None,
    /,
    *,
    maxsize: int | None = 128,
    typed: bool = False,
    always_checkpoint: bool = False,
) -> (
    AsyncLRUCacheWrapper[P, T] | functools._lru_cache_wrapper[T] | _LRUCacheWrapper[Any]
):
    """
    An asynchronous version of :func:`functools.lru_cache`.

    Use it either directly as ``@lru_cache`` or with options as
    ``@lru_cache(maxsize=...)``. If a synchronous function is passed, the standard
    library :func:`functools.lru_cache` is applied instead.

    :param maxsize: the maximum number of results to keep (``None`` for no limit;
        ``0`` or a negative number disables caching)
    :param typed: if ``True``, arguments of different types are cached separately
    :param always_checkpoint: if ``True``, every call to the cached function will be
        guaranteed to yield control to the event loop at least once
    :raises TypeError: if the first argument is given but is not callable

    .. note:: Caches and locks are managed on a per-event loop basis.
    """
    if func is not None:
        if not callable(func):
            raise TypeError("the first argument must be callable")

        return _LRUCacheWrapper[T](maxsize, typed, always_checkpoint)(func)

    # Called with options only: return a decorator to be applied later
    return _LRUCacheWrapper[Any](maxsize, typed, always_checkpoint)
@overload
async def reduce(
    function: Callable[[T, S], Awaitable[T]],
    iterable: Iterable[S] | AsyncIterable[S],
    /,
    initial: T,
) -> T: ...


@overload
async def reduce(
    function: Callable[[T, T], Awaitable[T]],
    iterable: Iterable[T] | AsyncIterable[T],
    /,
) -> T: ...


async def reduce(  # type: ignore[misc]
    function: Callable[[T, T], Awaitable[T]] | Callable[[T, S], Awaitable[T]],
    iterable: Iterable[T] | Iterable[S] | AsyncIterable[T] | AsyncIterable[S],
    /,
    initial: T | _InitialMissingType = initial_missing,
) -> T:
    """
    Asynchronous version of :func:`functools.reduce`.

    Fold the items of ``iterable`` from left to right by awaiting
    ``function(accumulated, item)`` for each item.

    :param function: a coroutine function that takes two arguments: the accumulated
        value and the next element from the iterable
    :param iterable: an iterable or async iterable
    :param initial: the initial value (if missing, the first element of the iterable is
        used as the initial value)
    :raises TypeError: if ``iterable`` is neither an iterable nor an async iterable,
        or if it is empty and no initial value was given
    """
    if not isinstance(iterable, (AsyncIterable, Iterable)):
        raise TypeError("reduce() argument 2 must be an iterable or async iterable")

    element: Any
    made_call = False
    if isinstance(iterable, AsyncIterable):
        aiterator = iterable.__aiter__()
        if initial is not initial_missing:
            accumulator = cast(T, initial)
        else:
            try:
                accumulator = cast(T, await aiterator.__anext__())
            except StopAsyncIteration:
                raise TypeError(
                    "reduce() of empty sequence with no initial value"
                ) from None

        async for element in aiterator:
            accumulator = await function(accumulator, element)
            made_call = True
    else:
        iterator = iter(iterable)
        if initial is not initial_missing:
            accumulator = cast(T, initial)
        else:
            try:
                accumulator = cast(T, next(iterator))
            except StopIteration:
                raise TypeError(
                    "reduce() of empty sequence with no initial value"
                ) from None

        for element in iterator:
            accumulator = await function(accumulator, element)
            made_call = True

    # If the function was never awaited (empty iterable with an initial value, or a
    # single element without one), checkpoint explicitly so at least one happens
    if not made_call:
        await checkpoint()

    return accumulator
|