1
0
mirror of https://github.com/dkmstr/openuds.git synced 2024-12-22 13:34:04 +03:00

Improved decorators signatures & removed transaction.atomic from cached call

This commit is contained in:
Adolfo Gómez García 2024-03-10 16:22:10 +01:00
parent 6ab0307bdd
commit a3f50e739a
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
5 changed files with 109 additions and 94 deletions

View File

@ -1,6 +1,7 @@
[mypy]
#plugins =
# mypy_django_plugin.main
python_version = 3.11
# Exclude all .*/transports/.*/scripts/.* directories
exclude = .*/transports/.*/scripts/.*

View File

@ -37,8 +37,6 @@ import time
import typing
import collections.abc
from django.db import transaction
from uds.core import consts, types, exceptions
from uds.core.util import singleton
@ -46,7 +44,9 @@ import uds.core.exceptions.rest
logger = logging.getLogger(__name__)
FT = typing.TypeVar('FT', bound=collections.abc.Callable[..., typing.Any])
# FT = typing.TypeVar('FT', bound=collections.abc.Callable[..., typing.Any])
T = typing.TypeVar('T')
P = typing.ParamSpec('P')
# Caching statistics
@ -130,13 +130,13 @@ def classproperty(func: collections.abc.Callable[..., typing.Any]) -> ClassPrope
return ClassPropertyDescriptor(func)
def deprecated(func: FT) -> FT:
def deprecated(func: collections.abc.Callable[P, T]) -> collections.abc.Callable[P, T]:
"""This is a decorator which can be used to mark functions
as deprecated. It will result in a warning being emitted
when the function is used."""
@functools.wraps(func)
def new_func(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
def new_func(*args: P.args, **kwargs: P.kwargs) -> T:
try:
caller = inspect.stack()[1]
logger.warning(
@ -150,7 +150,7 @@ def deprecated(func: FT) -> FT:
return func(*args, **kwargs)
return typing.cast(FT, new_func)
return new_func
def deprecated_class_value(new_var_name: str) -> collections.abc.Callable[..., typing.Any]:
@ -189,14 +189,15 @@ def deprecated_class_value(new_var_name: str) -> collections.abc.Callable[..., t
return functools.partial(innerDeprecated, newVarName=new_var_name)
def ensure_connected(func: FT) -> FT:
def ensure_connected(func: collections.abc.Callable[P, T]) -> collections.abc.Callable[P, T]:
"""This decorator calls "connect" method of the class of the wrapped object"""
@functools.wraps(func)
def new_func(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
args[0].connect()
def new_func(*args: P.args, **kwargs: P.kwargs) -> T:
args[0].connect() # type: ignore
return func(*args, **kwargs)
return typing.cast(FT, new_func)
return new_func
# Decorator for caching
@ -207,21 +208,22 @@ def cached(
args: typing.Optional[typing.Union[collections.abc.Iterable[int], int]] = None,
kwargs: typing.Optional[typing.Union[collections.abc.Iterable[str], str]] = None,
key_helper: typing.Optional[collections.abc.Callable[[typing.Any], str]] = None,
) -> collections.abc.Callable[[FT], FT]:
"""Decorator that give us a "quick& clean" caching feature on db.
The "cached" element must provide a "cache" variable, which is a cache object
) -> collections.abc.Callable[[collections.abc.Callable[P, T]], collections.abc.Callable[P, T]]:
"""
Decorator that gives us a "quick & clean" caching feature on the database.
Parameters:
prefix: Prefix to use for cache key
timeout: Timeout for cache
args: List of arguments to use for cache key (i.e. [0, 1] will use first and second argument for cache key, 0 will use "self" if a method, and 1 will use first argument)
kwargs: List of keyword arguments to use for cache key (i.e. ['a', 'b'] will use "a" and "b" keyword arguments for cache key)
key_fnc: Function to use for cache key. If provided, this function will be called with the same arguments as the wrapped function, and must return a string to use as cache key
prefix (str): Prefix to use for the cache key.
timeout (Union[Callable[[], int], int], optional): Timeout for the cache in seconds. If -1, it will use the default timeout. Defaults to -1.
args (Optional[Union[Iterable[int], int]], optional): List of arguments to use for the cache key. If an integer is provided, it will be treated as a single argument. Defaults to None.
kwargs (Optional[Union[Iterable[str], str]], optional): List of keyword arguments to use for the cache key. If a string is provided, it will be treated as a single keyword argument. Defaults to None.
key_helper (Optional[Callable[[Any], str]], optional): Function to use for improving the calculated cache key. Defaults to None.
Note:
If args and kwargs are not provided, all parameters (except *args and **kwargs) will be used for building cache key
If `args` and `kwargs` are not provided, all parameters (except `*args` and `**kwargs`) will be used for building the cache key.
Note:
The `key_helper` function will receive the first argument of the function (`self`) and must return a string that will be appended to the cache key.
"""
from uds.core.util.cache import Cache # To avoid circular references
@ -229,13 +231,14 @@ def cached(
args_list: list[int] = [args] if isinstance(args, int) else list(args or [])
kwargs_list = [kwargs] if isinstance(kwargs, str) else list(kwargs or [])
# Lock for stats concurrency
lock = threading.Lock()
hits = misses = exec_time = 0
def allow_cache_decorator(fnc: FT) -> FT:
def allow_cache_decorator(fnc: collections.abc.Callable[P, T]) -> collections.abc.Callable[P, T]:
# If no caching args and no caching kwargs, we will cache the whole call
# If no parameters provider, try to infer them from function signature
# If no parameters provided, try to infer them from function signature
try:
if not args_list and not kwargs_list:
for pos, (param_name, param) in enumerate(inspect.signature(fnc).parameters.items()):
@ -254,68 +257,71 @@ def cached(
kwargs_list.append(param_name)
# *args and **kwargs are not supported as cache parameters
except Exception:
logger.debug('Function %s is not inspectable, no caching possible', fnc.__name__)
# Not inspectable, no caching possible, return original function
return fnc
lkey_fnc: collections.abc.Callable[[str], str] = key_helper or (lambda x: fnc.__name__)
key_helper_fnc: collections.abc.Callable[[typing.Any], str] = key_helper or (lambda x: fnc.__name__)
@functools.wraps(fnc)
def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
nonlocal hits, misses, exec_time
with transaction.atomic(): # On its own transaction (for cache operations, that are on DB)
cache_key: str = prefix
for i in args_list:
if i < len(args):
cache_key += str(args[i])
for s in kwargs_list:
cache_key += str(kwargs.get(s, ''))
# Append key data
cache_key += lkey_fnc(args[0] if len(args) > 0 else fnc.__name__)
cache_key: str = prefix
for i in args_list:
if i < len(args):
cache_key += str(args[i])
for s in kwargs_list:
cache_key += str(kwargs.get(s, ''))
# Append key data
cache_key += key_helper_fnc(args[0] if len(args) > 0 else fnc.__name__)
# Note tha this value (cache_key) will be hashed by cache, so it's not a problem if it's too long
# Note tha this value (cache_key) will be hashed by cache, so it's not a problem if it's too long
# Get cache from object if present, or use the global 'functionCache' (generic, common to all objects)
inner_cache: 'Cache|None' = getattr(args[0], 'cache', None) if len(args) > 0 else None
cache: 'Cache' = inner_cache or Cache('functionCache')
# Get cache from object if present, or use the global 'functionCache' (generic, common to all objects)
inner_cache: 'Cache|None' = getattr(args[0], 'cache', None) if len(args) > 0 else None
cache: 'Cache' = inner_cache or Cache('functionCache')
# if timeout is a function, call it
ltimeout = timeout() if callable(timeout) else timeout
# if timeout is a function, call it
effective_timeout = timeout() if callable(timeout) else timeout
data: typing.Any = None
# If misses is 0, we are starting, so we will not try to get from cache
if not kwargs.get('force', False) and ltimeout > 0 and misses > 0:
data = cache.get(cache_key, default=consts.cache.CACHE_NOT_FOUND)
if data is not consts.cache.CACHE_NOT_FOUND:
with lock:
hits += 1
CacheStats.manager().add_hit(exec_time // hits) # Use mean execution time
return data
data: typing.Any = None
# If misses is 0, we are starting, so we will not try to get from cache
if not kwargs.get('force', False) and effective_timeout > 0 and misses > 0:
data = cache.get(cache_key, default=consts.cache.CACHE_NOT_FOUND)
if data is not consts.cache.CACHE_NOT_FOUND:
with lock:
hits += 1
CacheStats.manager().add_hit(exec_time // hits) # Use mean execution time
return data
with lock:
misses += 1
CacheStats.manager().add_miss()
with lock:
misses += 1
CacheStats.manager().add_miss()
if 'force' in kwargs:
# Remove force key
del kwargs['force']
if 'force' in kwargs:
# Remove force key
del kwargs['force']
t = time.thread_time_ns()
data = fnc(*args, **kwargs)
exec_time += time.thread_time_ns() - t
# Execute the function outside the DB transaction
t = time.thread_time_ns()
data = fnc(*args, **kwargs)
exec_time += time.thread_time_ns() - t
try:
# Maybe returned data is not serializable. In that case, cache will fail but no harm is done with this
cache.put(cache_key, data, ltimeout)
except Exception as e:
logger.debug(
'Data for %s is not serializable on call to %s, not cached. %s (%s)',
cache_key,
fnc.__name__,
data,
e,
)
return data
try:
# Maybe returned data is not serializable. In that case, cache will fail but no harm is done with this
cache.put(cache_key, data, effective_timeout)
except Exception as e:
logger.debug(
'Data for %s is not serializable on call to %s, not cached. %s (%s)',
cache_key,
fnc.__name__,
data,
e,
)
return data
# Add a couple of methods to the wrapper to allow cache statistics access and cache clearing
def cache_info() -> CacheInfo:
"""Report cache statistics"""
with lock:
@ -331,28 +337,28 @@ def cached(
wrapper.cache_info = cache_info # type: ignore
wrapper.cache_clear = cache_clear # type: ignore
return typing.cast(FT, wrapper)
return wrapper
return allow_cache_decorator
# Decorator to execute method in a thread
def threaded(func: FT) -> FT:
def threaded(func: collections.abc.Callable[P, None]) -> collections.abc.Callable[P, None]:
"""Decorator to execute method in a thread"""
@functools.wraps(func)
def wrapper(*args: typing.Any, **kwargs: typing.Any) -> None:
def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
thread = threading.Thread(target=func, args=args, kwargs=kwargs)
thread.start()
return typing.cast(FT, wrapper)
return wrapper
def blocker(
request_attr: typing.Optional[str] = None,
max_failures: typing.Optional[int] = None,
ignore_block_config: bool = False,
) -> collections.abc.Callable[[FT], FT]:
) -> collections.abc.Callable[[collections.abc.Callable[P, T]], collections.abc.Callable[P, T]]:
"""
Decorator that will block the actor if it has more than ALLOWED_FAILS failures in BLOCK_ACTOR_TIME seconds
GlobalConfig.BLOCK_ACTOR_FAILURES.getBool() --> If true, block actor after ALLOWED_FAILS failures
@ -371,13 +377,13 @@ def blocker(
from uds.core.util.cache import Cache # To avoid circular references
from uds.core.util.config import GlobalConfig
blockCache = Cache('uds:blocker') # Cache for blocked ips
mycache = Cache('uds:blocker') # Cache for blocked ips
max_failures = max_failures or consts.system.ALLOWED_FAILS
def decorator(f: FT) -> FT:
def decorator(f: collections.abc.Callable[P, T]) -> collections.abc.Callable[P, T]:
@functools.wraps(f)
def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
if not GlobalConfig.BLOCK_ACTOR_FAILURES.as_bool(True) and not ignore_block_config:
return f(*args, **kwargs)
@ -392,33 +398,33 @@ def blocker(
ip = request.ip
# if ip is blocked, raise exception
failuresCount = blockCache.get(ip, 0)
if failuresCount >= max_failures:
failures_count = mycache.get(ip, 0)
if failures_count >= max_failures:
raise exceptions.rest.AccessDenied
try:
result = f(*args, **kwargs)
except uds.core.exceptions.rest.BlockAccess:
# Increment
blockCache.put(ip, failuresCount + 1, GlobalConfig.LOGIN_BLOCK.as_int())
mycache.put(ip, failures_count + 1, GlobalConfig.LOGIN_BLOCK.as_int())
raise exceptions.rest.AccessDenied
# Any other exception will be raised
except Exception:
raise
# If we are here, it means that the call was successfull, so we reset the counter
blockCache.delete(ip)
mycache.delete(ip)
return result
return typing.cast(FT, wrapper)
return wrapper
return decorator
def profile(
log_file: typing.Optional[str] = None,
) -> collections.abc.Callable[[FT], FT]:
) -> collections.abc.Callable[[collections.abc.Callable[P, T]], collections.abc.Callable[P, T]]:
"""
Decorator that will profile the wrapped function and log the results to the provided file
@ -429,9 +435,11 @@ def profile(
Decorator
"""
def decorator(f: FT) -> FT:
def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
nonlocal log_file
def decorator(f: collections.abc.Callable[P, T]) -> collections.abc.Callable[P, T]:
@functools.wraps(f)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
nonlocal log_file # use outer log_file
import cProfile
import pstats
import tempfile
@ -446,6 +454,6 @@ def profile(
stats.dump_stats(log_file)
return result
return typing.cast(FT, wrapper)
return wrapper
return decorator

View File

@ -44,7 +44,7 @@ from .linux_osmanager import LinuxOsManager
from .linux_randompass_osmanager import LinuxRandomPassManager
from .linux_ad_osmanager import LinuxOsADManager
_mypath = os.path.dirname(__spec__.origin)
_mypath = os.path.dirname(__spec__.origin) # type: ignore[name-defined] # mypy incorrectly report __spec__ as not beind defined
# Old version, using spec is better, but we can use __package__ as well
#_mypath = os.path.dirname(typing.cast(str, sys.modules[__package__].__file__)) # pyright: ignore

View File

@ -42,7 +42,7 @@ from .windows import WindowsOsManager
from .windows_domain import WinDomainOsManager
from .windows_random import WinRandomPassManager
_mypath = os.path.dirname(__spec__.origin)
_mypath = os.path.dirname(__spec__.origin) # type: ignore[name-defined] # mypy incorrectly report __spec__ as not beind defined
# Old version, using spec is better, but we can use __package__ as well
#_mypath = os.path.dirname(typing.cast(str, sys.modules[__package__].__file__)) # pyright: ignore

View File

@ -64,20 +64,26 @@ VOLUMES_ENDPOINT_TYPES = [
COMPUTE_ENDPOINT_TYPES = ['compute', 'compute_legacy']
NETWORKS_ENDPOINT_TYPES = ['network']
T = typing.TypeVar('T')
P = typing.ParamSpec('P')
# Decorators
def auth_required(for_project: bool = False) -> collections.abc.Callable[[decorators.FT], decorators.FT]:
def auth_required(
for_project: bool = False,
) -> collections.abc.Callable[[collections.abc.Callable[P, T]], collections.abc.Callable[P, T]]:
def decorator(func: decorators.FT) -> decorators.FT:
def decorator(func: collections.abc.Callable[P, T]) -> collections.abc.Callable[P, T]:
@functools.wraps(func)
def wrapper(obj: 'OpenstackClient', *args: typing.Any, **kwargs: typing.Any) -> typing.Any:
def wrapper(*args: P.args, **kwargs: P.kwargs) -> typing.Any:
obj = typing.cast('OpenstackClient', args[0])
if for_project is True:
if obj._projectid is None:
raise Exception('Need a project for method {}'.format(func))
obj.ensure_authenticated()
return func(obj, *args, **kwargs)
return func(*args, **kwargs)
return typing.cast(decorators.FT, wrapper)
return wrapper
return decorator