From 25aa09309b53d70d8ef8ededdcc13e33312e8822 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adolfo=20G=C3=B3mez=20Garc=C3=ADa?= Date: Sat, 12 Oct 2024 14:52:36 +0200 Subject: [PATCH] Some improvements to type checking on decorators --- server/src/uds/core/util/decorators.py | 72 +++++++++++++++----------- 1 file changed, 43 insertions(+), 29 deletions(-) diff --git a/server/src/uds/core/util/decorators.py b/server/src/uds/core/util/decorators.py index 588851fd0..39f3f48f5 100644 --- a/server/src/uds/core/util/decorators.py +++ b/server/src/uds/core/util/decorators.py @@ -44,9 +44,8 @@ import uds.core.exceptions.rest logger = logging.getLogger(__name__) # FT = typing.TypeVar('FT', bound=collections.abc.Callable[..., typing.Any]) -T = typing.TypeVar('T') P = typing.ParamSpec('P') - +R = typing.TypeVar('R') @dataclasses.dataclass class CacheInfo: @@ -79,13 +78,13 @@ def classproperty(func: collections.abc.Callable[..., typing.Any]) -> ClassPrope return ClassPropertyDescriptor(func) -def deprecated(func: collections.abc.Callable[P, T]) -> collections.abc.Callable[P, T]: +def deprecated(func: collections.abc.Callable[P, R]) -> collections.abc.Callable[P, R]: """This is a decorator which can be used to mark functions as deprecated. It will result in a warning being emitted when the function is used.""" @functools.wraps(func) - def new_func(*args: P.args, **kwargs: P.kwargs) -> T: + def new_func(*args: P.args, **kwargs: P.kwargs) -> R: try: caller = inspect.stack()[1] logger.warning( @@ -143,7 +142,7 @@ class _HasConnect(typing.Protocol): def connect(self) -> None: ... -# def ensure_connected(func: collections.abc.Callable[P, T]) -> collections.abc.Callable[P, T]: +# def ensure_connected(func: collections.abc.Callable[P, R]) -> collections.abc.Callable[P, R]: # Keep this, but mypy does not likes it... it's perfect with pyright # We use pyright for type checking, so we will use this @@ -151,18 +150,29 @@ HasConnect = typing.TypeVar('HasConnect', bound=_HasConnect) def ensure_connected( - func: collections.abc.Callable[typing.Concatenate[HasConnect, P], T] -) -> collections.abc.Callable[typing.Concatenate[HasConnect, P], T]: + func: collections.abc.Callable[typing.Concatenate[HasConnect, P], R] +) -> collections.abc.Callable[typing.Concatenate[HasConnect, P], R]: """This decorator calls "connect" method of the class of the wrapped object""" @functools.wraps(func) - def new_func(obj: HasConnect, /, *args: P.args, **kwargs: P.kwargs) -> T: + def new_func(obj: HasConnect, /, *args: P.args, **kwargs: P.kwargs) -> R: # self = typing.cast(_HasConnect, args[0]) obj.connect() return func(obj, *args, **kwargs) return new_func +# To be used in a future, for type checking only +# currently the problem is that the signature of a function is diferent +# thant the signature of a class method, so we can't use the same decorator +# Also, if we change the return type, abstract methods will not be able to be implemented +# because derieved will not have the same signature +# Also, R must be covariant for proper type checking +# class CacheMethods(typing.Protocol[P, R]): +# def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: ... +# def cache_clear(self) -> None: ... +# def cache_info(self) -> CacheInfo: ... + # Decorator for caching # This decorator will cache the result of the function for a given time, and given parameters @@ -172,7 +182,7 @@ def cached( args: typing.Optional[typing.Union[collections.abc.Iterable[int], int]] = None, kwargs: typing.Optional[typing.Union[collections.abc.Iterable[str], str]] = None, key_helper: typing.Optional[collections.abc.Callable[[typing.Any], str]] = None, -) -> collections.abc.Callable[[collections.abc.Callable[P, T]], collections.abc.Callable[P, T]]: +) -> collections.abc.Callable[[collections.abc.Callable[P, R]], collections.abc.Callable[P, R]]: """ Decorator that gives us a "quick & clean" caching feature on the database. @@ -200,7 +210,17 @@ def cached( hits = misses = exec_time = 0 - def allow_cache_decorator(fnc: collections.abc.Callable[P, T]) -> collections.abc.Callable[P, T]: + # Add a couple of methods to the wrapper to allow cache statistics access and cache clearing + def cache_info() -> CacheInfo: + """Report cache statistics""" + return CacheInfo(hits, misses, hits + misses, exec_time) + + def cache_clear() -> None: + """Clear the cache and cache statistics""" + nonlocal hits, misses, exec_time + hits = misses = exec_time = 0 + + def allow_cache_decorator(fnc: collections.abc.Callable[P, R]) -> collections.abc.Callable[P, R]: # If no caching args and no caching kwargs, we will cache the whole call # If no parameters provided, try to infer them from function signature try: @@ -223,12 +243,16 @@ def cached( except Exception: logger.debug('Function %s is not inspectable, no caching possible', fnc.__name__) # Not inspectable, no caching possible, return original function + + # Ensure compat with methods of cached functions + setattr(fnc, 'cache_info', cache_info) + setattr(fnc, 'cache_clear', cache_clear) return fnc key_helper_fnc: collections.abc.Callable[[typing.Any], str] = key_helper or (lambda x: fnc.__name__) @functools.wraps(fnc) - def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: nonlocal hits, misses, exec_time cache_key: str = prefix or fnc.__name__ @@ -290,19 +314,9 @@ def cached( ) return data - # Add a couple of methods to the wrapper to allow cache statistics access and cache clearing - def cache_info() -> CacheInfo: - """Report cache statistics""" - return CacheInfo(hits, misses, hits + misses, exec_time) - - def cache_clear() -> None: - """Clear the cache and cache statistics""" - nonlocal hits, misses, exec_time - hits = misses = exec_time = 0 - # Same as lru_cache - wrapper.cache_info = cache_info # type: ignore - wrapper.cache_clear = cache_clear # type: ignore + setattr(wrapper, 'cache_info', cache_info) + setattr(wrapper, 'cache_clear', cache_clear) return wrapper @@ -325,7 +339,7 @@ def blocker( request_attr: typing.Optional[str] = None, max_failures: typing.Optional[int] = None, ignore_block_config: bool = False, -) -> collections.abc.Callable[[collections.abc.Callable[P, T]], collections.abc.Callable[P, T]]: +) -> collections.abc.Callable[[collections.abc.Callable[P, R]], collections.abc.Callable[P, R]]: """ Decorator that will block the actor if it has more than ALLOWED_FAILS failures in BLOCK_ACTOR_TIME seconds GlobalConfig.BLOCK_ACTOR_FAILURES.getBool() --> If true, block actor after ALLOWED_FAILS failures @@ -348,9 +362,9 @@ def blocker( max_failures = max_failures or consts.system.ALLOWED_FAILS - def decorator(f: collections.abc.Callable[P, T]) -> collections.abc.Callable[P, T]: + def decorator(f: collections.abc.Callable[P, R]) -> collections.abc.Callable[P, R]: @functools.wraps(f) - def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: if not GlobalConfig.BLOCK_ACTOR_FAILURES.as_bool(True) and not ignore_block_config: return f(*args, **kwargs) @@ -390,7 +404,7 @@ def blocker( def profiler( log_file: typing.Optional[str] = None, -) -> collections.abc.Callable[[collections.abc.Callable[P, T]], collections.abc.Callable[P, T]]: +) -> collections.abc.Callable[[collections.abc.Callable[P, R]], collections.abc.Callable[P, R]]: """ Decorator that will profile the wrapped function and log the results to the provided file @@ -401,10 +415,10 @@ def profiler( Decorator """ - def decorator(f: collections.abc.Callable[P, T]) -> collections.abc.Callable[P, T]: + def decorator(f: collections.abc.Callable[P, R]) -> collections.abc.Callable[P, R]: @functools.wraps(f) - def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: nonlocal log_file # use outer log_file import cProfile import pstats