You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
347 lines
9.2 KiB
Python
347 lines
9.2 KiB
Python
import asyncio
|
|
import dataclasses
|
|
import sys
|
|
from asyncio.coroutines import _is_coroutine # type: ignore[attr-defined]
|
|
from functools import _CacheInfo, _make_key, partial, partialmethod
|
|
from typing import (
|
|
Any,
|
|
Callable,
|
|
Coroutine,
|
|
Generic,
|
|
Hashable,
|
|
Optional,
|
|
OrderedDict,
|
|
Set,
|
|
Type,
|
|
TypedDict,
|
|
TypeVar,
|
|
Union,
|
|
cast,
|
|
final,
|
|
overload,
|
|
)
|
|
|
|
|
|
if sys.version_info >= (3, 11):
|
|
from typing import Self
|
|
else:
|
|
from typing_extensions import Self
|
|
|
|
|
|
__version__ = "2.0.5"

# Only the decorator is public; the wrapper classes are implementation details.
__all__ = ("alru_cache",)


# _T: the instance/owner type used when a cached coroutine is bound as a method.
# _R: the result type produced by the cached coroutine.
_T = TypeVar("_T")
_R = TypeVar("_R")

# A coroutine object that eventually produces an _R.
_Coro = Coroutine[Any, Any, _R]
# A callable, taking arbitrary arguments, that returns such a coroutine.
_CB = Callable[..., _Coro[_R]]
# Everything alru_cache accepts as the decorated object: a coroutine function,
# or a functools.partial / functools.partialmethod wrapping one.
_CBP = Union[_CB[_R], "partial[_Coro[_R]]", "partialmethod[_Coro[_R]]"]
|
|
|
|
|
|
@final
|
|
class _CacheParameters(TypedDict):
|
|
typed: bool
|
|
maxsize: Optional[int]
|
|
tasks: int
|
|
closed: bool
|
|
|
|
|
|
@final
@dataclasses.dataclass
class _CacheItem(Generic[_R]):
    """One cache slot: the shared result future plus an optional expiry timer."""

    # Future that every caller with this key awaits.
    fut: "asyncio.Future[_R]"
    # Scheduled TTL-expiry callback, or None when no timer is pending.
    later_call: Optional[asyncio.Handle]

    def cancel(self) -> None:
        """Cancel the pending expiry callback, if any, and forget it."""
        handle, self.later_call = self.later_call, None
        if handle is not None:
            handle.cancel()
|
|
|
|
|
|
@final
class _LRUCacheWrapper(Generic[_R]):
    """Memoizing wrapper around a coroutine function.

    Each distinct argument key (built with ``functools._make_key``) maps to a
    ``_CacheItem``. The item's future is resolved by a task that runs the
    wrapped coroutine, and concurrent callers with the same key await that
    single future.

    Entries are evicted least-recently-used first once ``maxsize`` is
    exceeded. When ``ttl`` is set, an entry also expires ``ttl`` seconds after
    its task completes. Entries whose call raises or is cancelled are removed
    as soon as the task finishes.
    """

    def __init__(
        self,
        fn: _CB[_R],
        maxsize: Optional[int],
        typed: bool,
        ttl: Optional[float],
    ) -> None:
        # Copy the usual wrapper metadata. Each attribute is optional because
        # some callables (e.g. functools.partial objects) lack a few of them.
        for attr in (
            "__module__",
            "__name__",
            "__qualname__",
            "__doc__",
            "__annotations__",
        ):
            try:
                setattr(self, attr, getattr(fn, attr))
            except AttributeError:
                pass
        try:
            self.__dict__.update(fn.__dict__)
        except AttributeError:
            pass
        # set __wrapped__ last so we don't inadvertently copy it
        # from the wrapped function when updating __dict__
        # Marker checked by asyncio.iscoroutinefunction(), so the wrapper
        # itself is recognised as a coroutine function.
        self._is_coroutine = _is_coroutine
        self.__wrapped__ = fn
        self.__maxsize = maxsize
        self.__typed = typed
        self.__ttl = ttl
        self.__cache: OrderedDict[Hashable, _CacheItem[_R]] = OrderedDict()
        self.__closed = False
        self.__hits = 0
        self.__misses = 0
        # Strong references to in-flight tasks. The event loop keeps only weak
        # references, and cache_close() needs these to cancel or await them.
        self.__tasks: Set["asyncio.Task[_R]"] = set()

    def cache_invalidate(self, /, *args: Hashable, **kwargs: Any) -> bool:
        """Drop the entry for ``(*args, **kwargs)``.

        Returns:
            True if an entry existed and was removed, False otherwise.
        """
        key = _make_key(args, kwargs, self.__typed)

        cache_item = self.__cache.pop(key, None)
        if cache_item is None:
            return False

        cache_item.cancel()
        return True

    def cache_clear(self) -> None:
        """Reset hit/miss statistics and drop every cached entry.

        In-flight tasks keep running (their current callers still receive a
        result) and stay tracked, so cache_close() can still cancel or await
        them. Each task removes itself from tracking when it finishes.
        """
        self.__hits = 0
        self.__misses = 0

        for cache_item in self.__cache.values():
            cache_item.cancel()
        self.__cache.clear()

    async def cache_close(self, *, wait: bool = False) -> None:
        """Reject further calls and settle every tracked task.

        Args:
            wait: if False, pending tasks are cancelled first; if True, they
                are allowed to finish. Either way this coroutine returns once
                all of them are done. Their exceptions are not propagated.
        """
        self.__closed = True

        tasks = list(self.__tasks)
        if not tasks:
            return

        if not wait:
            for task in tasks:
                if not task.done():
                    task.cancel()

        await asyncio.gather(*tasks, return_exceptions=True)

    def cache_info(self) -> _CacheInfo:
        """Return ``(hits, misses, maxsize, currsize)`` like functools.lru_cache."""
        return _CacheInfo(
            self.__hits,
            self.__misses,
            self.__maxsize,
            len(self.__cache),
        )

    def cache_parameters(self) -> _CacheParameters:
        """Return the configuration and current task/closed state."""
        return _CacheParameters(
            maxsize=self.__maxsize,
            typed=self.__typed,
            tasks=len(self.__tasks),
            closed=self.__closed,
        )

    def _cache_hit(self, key: Hashable) -> None:
        """Count a hit and mark *key* as most recently used."""
        self.__hits += 1
        self.__cache.move_to_end(key)

    def _cache_miss(self, key: Hashable) -> None:
        """Count a miss."""
        self.__misses += 1

    def _task_done_callback(
        self, fut: "asyncio.Future[_R]", key: Hashable, task: "asyncio.Task[_R]"
    ) -> None:
        """Propagate *task*'s outcome into *fut* and update *key*'s entry.

        The cache entry is only touched if it still owns *fut*. After
        cache_invalidate(), cache_clear() or LRU eviction, a newer call may
        have stored a fresh entry under the same key. A stale task finishing
        must neither evict that entry nor attach its own TTL timer to it.
        """
        self.__tasks.discard(task)

        cache_item = self.__cache.get(key)
        if cache_item is not None and cache_item.fut is not fut:
            cache_item = None  # the entry belongs to a newer call

        if task.cancelled():
            fut.cancel()
            if cache_item is not None:
                del self.__cache[key]
            return

        exc = task.exception()
        if exc is not None:
            fut.set_exception(exc)
            if cache_item is not None:
                del self.__cache[key]
            return

        if self.__ttl is not None and cache_item is not None:
            loop = asyncio.get_running_loop()
            cache_item.later_call = loop.call_later(
                self.__ttl, self.__cache.pop, key, None
            )

        fut.set_result(task.result())

    async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:
        """Return the cached result for the arguments, computing it on a miss.

        Raises:
            RuntimeError: if cache_close() has been called.
            Any exception raised by the wrapped coroutine.
        """
        if self.__closed:
            raise RuntimeError(f"alru_cache is closed for {self}")

        loop = asyncio.get_running_loop()

        key = _make_key(fn_args, fn_kwargs, self.__typed)

        cache_item = self.__cache.get(key)

        if cache_item is not None:
            self._cache_hit(key)
            if not cache_item.fut.done():
                # Shield so that cancelling this caller does not cancel the
                # future that other callers with the same key may be awaiting.
                return await asyncio.shield(cache_item.fut)

            return cache_item.fut.result()

        fut = loop.create_future()
        coro = self.__wrapped__(*fn_args, **fn_kwargs)
        task: asyncio.Task[_R] = loop.create_task(coro)
        self.__tasks.add(task)
        task.add_done_callback(partial(self._task_done_callback, fut, key))

        self.__cache[key] = _CacheItem(fut, None)

        if self.__maxsize is not None and len(self.__cache) > self.__maxsize:
            _, evicted = self.__cache.popitem(last=False)
            evicted.cancel()

        self._cache_miss(key)
        return await asyncio.shield(fut)

    def __get__(
        self, instance: _T, owner: Optional[Type[_T]]
    ) -> Union[Self, "_LRUCacheWrapperInstanceMethod[_R, _T]"]:
        """Descriptor hook: bind the wrapper to *instance* on attribute access."""
        # NOTE(review): class-level access (instance is None, owner set) still
        # returns a wrapper bound to None; confirm this is intended.
        if owner is None:
            return self
        return _LRUCacheWrapperInstanceMethod(self, instance)
|
|
|
|
|
|
@final
class _LRUCacheWrapperInstanceMethod(Generic[_R, _T]):
    """An ``_LRUCacheWrapper`` bound to a single instance.

    Created by ``_LRUCacheWrapper.__get__`` on attribute access. Calls and
    cache_invalidate() prepend the bound instance to the arguments, so the
    instance becomes part of the key in the underlying shared cache. The other
    cache-management methods act on that shared cache as a whole.
    """

    def __init__(
        self,
        wrapper: _LRUCacheWrapper[_R],
        instance: _T,
    ) -> None:
        # Copy the wrapper's metadata; each attribute is optional.
        for attr in (
            "__module__",
            "__name__",
            "__qualname__",
            "__doc__",
            "__annotations__",
        ):
            try:
                setattr(self, attr, getattr(wrapper, attr))
            except AttributeError:
                pass
        try:
            self.__dict__.update(wrapper.__dict__)
        except AttributeError:
            pass
        # set __wrapped__ last so we don't inadvertently copy it
        # from the wrapped function when updating __dict__
        self._is_coroutine = _is_coroutine
        self.__wrapped__ = wrapper.__wrapped__
        self.__instance = instance
        self.__wrapper = wrapper

    def cache_invalidate(self, /, *args: Hashable, **kwargs: Any) -> bool:
        """Drop the entry for this instance and ``(*args, **kwargs)``."""
        return self.__wrapper.cache_invalidate(self.__instance, *args, **kwargs)

    def cache_clear(self) -> None:
        """Clear the shared cache (for every instance, not just this one)."""
        self.__wrapper.cache_clear()

    async def cache_close(
        self,
        *,
        wait: bool = False,
        cancel: bool = False,
        return_exceptions: bool = True,
    ) -> None:
        """Close the shared cache.

        Args:
            wait: forwarded to the underlying wrapper. False cancels pending
                tasks; True lets them finish.
            cancel: accepted only for backward compatibility; has no effect.
            return_exceptions: accepted only for backward compatibility; has
                no effect (task exceptions are never propagated).
        """
        await self.__wrapper.cache_close(wait=wait)

    def cache_info(self) -> _CacheInfo:
        """Return statistics of the shared cache."""
        return self.__wrapper.cache_info()

    def cache_parameters(self) -> _CacheParameters:
        """Return configuration/state of the shared cache."""
        return self.__wrapper.cache_parameters()

    async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:
        """Call the cached coroutine with the bound instance prepended."""
        return await self.__wrapper(self.__instance, *fn_args, **fn_kwargs)
|
|
|
|
|
|
def _make_wrapper(
    maxsize: Optional[int],
    typed: bool,
    ttl: Optional[float] = None,
) -> Callable[[_CBP[_R]], _LRUCacheWrapper[_R]]:
    """Return a decorator that caches a coroutine function with these settings.

    The decorator accepts the following:

    - a coroutine function;
    - a ``functools.partial`` of a coroutine function;
    - a ``functools.partialmethod`` of a coroutine function.

    It raises RuntimeError when the innermost callable is not a coroutine
    function.
    """

    def wrapper(fn: _CBP[_R]) -> _LRUCacheWrapper[_R]:
        # Peel off any (nested) partial/partialmethod layers to reach the
        # callable whose coroutine-ness actually matters.
        inner = fn
        while isinstance(inner, (partial, partialmethod)):
            inner = inner.func

        if not asyncio.iscoroutinefunction(inner):
            raise RuntimeError(f"Coroutine function is required, got {fn!r}")

        # functools.partialmethod support: turn the descriptor into a plain
        # function that takes the instance as its first argument.
        target = (
            fn._make_unbound_method()
            if hasattr(fn, "_make_unbound_method")
            else fn
        )

        return _LRUCacheWrapper(
            cast(_CB[_R], target),
            maxsize=maxsize,
            typed=typed,
            ttl=ttl,
        )

    return wrapper
|
|
|
|
|
|
@overload
def alru_cache(
    maxsize: Optional[int] = 128,
    typed: bool = False,
    *,
    ttl: Optional[float] = None,
) -> Callable[[_CBP[_R]], _LRUCacheWrapper[_R]]:
    ...


@overload
def alru_cache(
    maxsize: _CBP[_R],
    /,
) -> _LRUCacheWrapper[_R]:
    ...


def alru_cache(
    maxsize: Union[Optional[int], _CBP[_R]] = 128,
    typed: bool = False,
    *,
    ttl: Optional[float] = None,
) -> Union[Callable[[_CBP[_R]], _LRUCacheWrapper[_R]], _LRUCacheWrapper[_R]]:
    """LRU cache decorator for coroutine functions.

    It can be used in two forms:

    - ``@alru_cache(maxsize, typed, ttl=...)`` returns a decorator.
    - Bare ``@alru_cache`` decorates the function directly, using
      ``maxsize=128``.

    Args:
        maxsize: maximum number of cached entries (None for unbounded), or the
            coroutine function itself in the bare form.
        typed: if True, arguments of different types are cached separately.
        ttl: seconds an entry stays cached after its call completes; None
            keeps entries until evicted.

    Raises:
        NotImplementedError: if *maxsize* is neither an int, None, a callable,
            nor a functools.partialmethod.
    """
    if maxsize is None or isinstance(maxsize, int):
        return _make_wrapper(maxsize, typed, ttl)

    fn = cast(_CB[_R], maxsize)

    # partialmethod objects are not callable, hence the extra attribute check.
    if callable(fn) or hasattr(fn, "_make_unbound_method"):
        # Bare form: default maxsize, but honour any explicitly passed
        # typed/ttl (functools.lru_cache likewise forwards ``typed``) instead
        # of silently discarding them.
        return _make_wrapper(128, typed, ttl)(fn)

    raise NotImplementedError(f"{fn!r} decorating is not supported")
|