mirror of https://github.com/dkmstr/openuds.git synced 2024-12-22 13:34:04 +03:00

Some improvements to type checking on decorators

Adolfo Gómez García 2024-10-12 14:52:36 +02:00
parent 203a46a804
commit 25aa09309b
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23


@@ -44,9 +44,8 @@ import uds.core.exceptions.rest
logger = logging.getLogger(__name__)
# FT = typing.TypeVar('FT', bound=collections.abc.Callable[..., typing.Any])
T = typing.TypeVar('T')
P = typing.ParamSpec('P')
R = typing.TypeVar('R')
@dataclasses.dataclass
class CacheInfo:
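For context on the T → R rename above, the following sketch (not part of the commit; passthrough and add are hypothetical names) shows how a ParamSpec for the parameters plus a TypeVar for the return type lets a decorator preserve the wrapped callable's exact signature for pyright:

import collections.abc
import functools
import typing

P = typing.ParamSpec('P')
R = typing.TypeVar('R')

def passthrough(func: collections.abc.Callable[P, R]) -> collections.abc.Callable[P, R]:
    """Identity decorator: parameter types (P) and the return type (R) survive the wrapping."""
    @functools.wraps(func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        return func(*args, **kwargs)
    return wrapper

@passthrough
def add(a: int, b: int = 1) -> int:
    return a + b

# pyright reports add as (a: int, b: int = ...) -> int rather than (...) -> Any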
@@ -79,13 +78,13 @@ def classproperty(func: collections.abc.Callable[..., typing.Any]) -> ClassPropert
    return ClassPropertyDescriptor(func)
def deprecated(func: collections.abc.Callable[P, T]) -> collections.abc.Callable[P, T]:
def deprecated(func: collections.abc.Callable[P, R]) -> collections.abc.Callable[P, R]:
    """This is a decorator which can be used to mark functions
    as deprecated. It will result in a warning being emitted
    when the function is used."""
    @functools.wraps(func)
    def new_func(*args: P.args, **kwargs: P.kwargs) -> T:
    def new_func(*args: P.args, **kwargs: P.kwargs) -> R:
        try:
            caller = inspect.stack()[1]
            logger.warning(
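A hedged usage sketch of the updated deprecated decorator (legacy_sum is a hypothetical function): the ParamSpec/TypeVar pair keeps the original signature visible to type checkers, while the wrapper logs a warning naming the caller at call time.

@deprecated
def legacy_sum(a: int, b: int) -> int:
    return a + b

result: int = legacy_sum(2, 3)  # still type-checks as (a: int, b: int) -> int; a deprecation warning is logged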
@@ -143,7 +142,7 @@ class _HasConnect(typing.Protocol):
    def connect(self) -> None: ...
# def ensure_connected(func: collections.abc.Callable[P, T]) -> collections.abc.Callable[P, T]:
# def ensure_connected(func: collections.abc.Callable[P, R]) -> collections.abc.Callable[P, R]:
# Keep this, but mypy does not like it... it's perfect with pyright
# We use pyright for type checking, so we will use this
@@ -151,18 +150,29 @@ HasConnect = typing.TypeVar('HasConnect', bound=_HasConnect)
def ensure_connected(
    func: collections.abc.Callable[typing.Concatenate[HasConnect, P], T]
) -> collections.abc.Callable[typing.Concatenate[HasConnect, P], T]:
    func: collections.abc.Callable[typing.Concatenate[HasConnect, P], R]
) -> collections.abc.Callable[typing.Concatenate[HasConnect, P], R]:
    """This decorator calls the "connect" method of the wrapped object before invoking the function"""
    @functools.wraps(func)
    def new_func(obj: HasConnect, /, *args: P.args, **kwargs: P.kwargs) -> T:
    def new_func(obj: HasConnect, /, *args: P.args, **kwargs: P.kwargs) -> R:
        # self = typing.cast(_HasConnect, args[0])
        obj.connect()
        return func(obj, *args, **kwargs)
    return new_func
# To be used in the future, for type checking only
# Currently the problem is that the signature of a function is different
# than the signature of a class method, so we can't use the same decorator
# Also, if we change the return type, abstract methods could not be implemented
# because derived classes would not have the same signature
# Also, R must be covariant for proper type checking
# class CacheMethods(typing.Protocol[P, R]):
# def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: ...
# def cache_clear(self) -> None: ...
# def cache_info(self) -> CacheInfo: ...
# Decorator for caching
# This decorator will cache the result of the function for a given time, and given parameters
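A hedged usage sketch of ensure_connected (RemoteClient and list_vms are hypothetical): typing.Concatenate[HasConnect, P] types the first positional parameter as the object exposing connect(), and P captures the rest, so the decorated method keeps its original signature.

class RemoteClient:
    def connect(self) -> None:
        # open the underlying session here (hypothetical)
        ...

    @ensure_connected
    def list_vms(self, datacenter: str) -> list[str]:
        # connect() has already been called on self by the decorator
        return []

client = RemoteClient()
vms: list[str] = client.list_vms('dc1')  # type-checks with the original (datacenter: str) -> list[str] signature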
@@ -172,7 +182,7 @@ def cached(
    args: typing.Optional[typing.Union[collections.abc.Iterable[int], int]] = None,
    kwargs: typing.Optional[typing.Union[collections.abc.Iterable[str], str]] = None,
    key_helper: typing.Optional[collections.abc.Callable[[typing.Any], str]] = None,
) -> collections.abc.Callable[[collections.abc.Callable[P, T]], collections.abc.Callable[P, T]]:
) -> collections.abc.Callable[[collections.abc.Callable[P, R]], collections.abc.Callable[P, R]]:
    """
    Decorator that gives us a "quick & clean" caching feature on the database.
@@ -200,7 +210,17 @@ def cached(
    hits = misses = exec_time = 0
    def allow_cache_decorator(fnc: collections.abc.Callable[P, T]) -> collections.abc.Callable[P, T]:
    # Add a couple of methods to the wrapper to allow cache statistics access and cache clearing
    def cache_info() -> CacheInfo:
        """Report cache statistics"""
        return CacheInfo(hits, misses, hits + misses, exec_time)
    def cache_clear() -> None:
        """Clear the cache and cache statistics"""
        nonlocal hits, misses, exec_time
        hits = misses = exec_time = 0
    def allow_cache_decorator(fnc: collections.abc.Callable[P, R]) -> collections.abc.Callable[P, R]:
        # If no caching args and no caching kwargs, we will cache the whole call
        # If no parameters provided, try to infer them from function signature
        try:
@@ -223,12 +243,16 @@ def cached(
        except Exception:
            logger.debug('Function %s is not inspectable, no caching possible', fnc.__name__)
            # Not inspectable, no caching possible, return original function
            # Ensure compat with methods of cached functions
            setattr(fnc, 'cache_info', cache_info)
            setattr(fnc, 'cache_clear', cache_clear)
            return fnc
        key_helper_fnc: collections.abc.Callable[[typing.Any], str] = key_helper or (lambda x: fnc.__name__)
        @functools.wraps(fnc)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            nonlocal hits, misses, exec_time
            cache_key: str = prefix or fnc.__name__
@@ -290,19 +314,9 @@ def cached(
            )
            return data
        # Add a couple of methods to the wrapper to allow cache statistics access and cache clearing
        def cache_info() -> CacheInfo:
            """Report cache statistics"""
            return CacheInfo(hits, misses, hits + misses, exec_time)
        def cache_clear() -> None:
            """Clear the cache and cache statistics"""
            nonlocal hits, misses, exec_time
            hits = misses = exec_time = 0
        # Same as lru_cache
        wrapper.cache_info = cache_info  # type: ignore
        wrapper.cache_clear = cache_clear  # type: ignore
        setattr(wrapper, 'cache_info', cache_info)
        setattr(wrapper, 'cache_clear', cache_clear)
        return wrapper
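A hedged usage sketch of cached after this change (get_networks is hypothetical, and calling cached() with no arguments is an assumption; only args, kwargs and key_helper are visible in these hunks). The point is that cache_info and cache_clear are now attached with setattr, both on the wrapper and, for non-inspectable functions, on the original function.

@cached()  # assumption: an empty call is valid; real callers may pass args/kwargs/key_helper
def get_networks(provider_id: int) -> list[str]:
    return []

info = get_networks.cache_info()   # type: ignore[attr-defined]  # attribute attached via setattr, mirroring functools.lru_cache
get_networks.cache_clear()         # type: ignore[attr-defined]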
@@ -325,7 +339,7 @@ def blocker(
    request_attr: typing.Optional[str] = None,
    max_failures: typing.Optional[int] = None,
    ignore_block_config: bool = False,
) -> collections.abc.Callable[[collections.abc.Callable[P, T]], collections.abc.Callable[P, T]]:
) -> collections.abc.Callable[[collections.abc.Callable[P, R]], collections.abc.Callable[P, R]]:
    """
    Decorator that will block the actor if it has more than ALLOWED_FAILS failures in BLOCK_ACTOR_TIME seconds
    GlobalConfig.BLOCK_ACTOR_FAILURES.getBool() --> If true, block actor after ALLOWED_FAILS failures
@@ -348,9 +362,9 @@ def blocker(
    max_failures = max_failures or consts.system.ALLOWED_FAILS
    def decorator(f: collections.abc.Callable[P, T]) -> collections.abc.Callable[P, T]:
    def decorator(f: collections.abc.Callable[P, R]) -> collections.abc.Callable[P, R]:
        @functools.wraps(f)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            if not GlobalConfig.BLOCK_ACTOR_FAILURES.as_bool(True) and not ignore_block_config:
                return f(*args, **kwargs)
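A hedged usage sketch of blocker (Login.post is hypothetical; how the requesting origin is resolved, e.g. via request_attr, and what is returned once blocked are not shown in these hunks):

class Login:
    @blocker(max_failures=5)  # defaults to consts.system.ALLOWED_FAILS when omitted
    def post(self, request: typing.Any) -> dict[str, typing.Any]:
        # After max_failures failures from the same origin within the block window,
        # the decorator is expected to short-circuit instead of calling this body.
        return {'result': 'ok'}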
@@ -390,7 +404,7 @@ def blocker(
def profiler(
    log_file: typing.Optional[str] = None,
) -> collections.abc.Callable[[collections.abc.Callable[P, T]], collections.abc.Callable[P, T]]:
) -> collections.abc.Callable[[collections.abc.Callable[P, R]], collections.abc.Callable[P, R]]:
    """
    Decorator that will profile the wrapped function and log the results to the provided file
@@ -401,10 +415,10 @@ def profiler(
        Decorator
    """
    def decorator(f: collections.abc.Callable[P, T]) -> collections.abc.Callable[P, T]:
    def decorator(f: collections.abc.Callable[P, R]) -> collections.abc.Callable[P, R]:
        @functools.wraps(f)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            nonlocal log_file  # use outer log_file
            import cProfile
            import pstats
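Finally, a hedged usage sketch of profiler (expensive_report and the output path are hypothetical; behaviour when log_file is None is not shown in the visible hunks):

@profiler(log_file='/tmp/uds-report.prof')
def expensive_report(days: int) -> list[dict[str, typing.Any]]:
    # the wrapper runs the call under cProfile and writes pstats data to log_file
    return []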