1
0
mirror of https://github.com/dkmstr/openuds.git synced 2025-01-08 21:18:00 +03:00

Improved decorators signatures & removed transaction.atomic from cached call

This commit is contained in:
Adolfo Gómez García 2024-03-10 16:22:10 +01:00
parent 6ab0307bdd
commit a3f50e739a
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
5 changed files with 109 additions and 94 deletions

View File

@ -1,6 +1,7 @@
[mypy] [mypy]
#plugins = #plugins =
# mypy_django_plugin.main # mypy_django_plugin.main
python_version = 3.11
# Exclude all .*/transports/.*/scripts/.* directories # Exclude all .*/transports/.*/scripts/.* directories
exclude = .*/transports/.*/scripts/.* exclude = .*/transports/.*/scripts/.*

View File

@ -37,8 +37,6 @@ import time
import typing import typing
import collections.abc import collections.abc
from django.db import transaction
from uds.core import consts, types, exceptions from uds.core import consts, types, exceptions
from uds.core.util import singleton from uds.core.util import singleton
@ -46,7 +44,9 @@ import uds.core.exceptions.rest
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
FT = typing.TypeVar('FT', bound=collections.abc.Callable[..., typing.Any]) # FT = typing.TypeVar('FT', bound=collections.abc.Callable[..., typing.Any])
T = typing.TypeVar('T')
P = typing.ParamSpec('P')
# Caching statistics # Caching statistics
@ -130,13 +130,13 @@ def classproperty(func: collections.abc.Callable[..., typing.Any]) -> ClassPrope
return ClassPropertyDescriptor(func) return ClassPropertyDescriptor(func)
def deprecated(func: FT) -> FT: def deprecated(func: collections.abc.Callable[P, T]) -> collections.abc.Callable[P, T]:
"""This is a decorator which can be used to mark functions """This is a decorator which can be used to mark functions
as deprecated. It will result in a warning being emitted as deprecated. It will result in a warning being emitted
when the function is used.""" when the function is used."""
@functools.wraps(func) @functools.wraps(func)
def new_func(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: def new_func(*args: P.args, **kwargs: P.kwargs) -> T:
try: try:
caller = inspect.stack()[1] caller = inspect.stack()[1]
logger.warning( logger.warning(
@ -150,7 +150,7 @@ def deprecated(func: FT) -> FT:
return func(*args, **kwargs) return func(*args, **kwargs)
return typing.cast(FT, new_func) return new_func
def deprecated_class_value(new_var_name: str) -> collections.abc.Callable[..., typing.Any]: def deprecated_class_value(new_var_name: str) -> collections.abc.Callable[..., typing.Any]:
@ -189,14 +189,15 @@ def deprecated_class_value(new_var_name: str) -> collections.abc.Callable[..., t
return functools.partial(innerDeprecated, newVarName=new_var_name) return functools.partial(innerDeprecated, newVarName=new_var_name)
def ensure_connected(func: FT) -> FT: def ensure_connected(func: collections.abc.Callable[P, T]) -> collections.abc.Callable[P, T]:
"""This decorator calls "connect" method of the class of the wrapped object""" """This decorator calls "connect" method of the class of the wrapped object"""
@functools.wraps(func) @functools.wraps(func)
def new_func(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: def new_func(*args: P.args, **kwargs: P.kwargs) -> T:
args[0].connect() args[0].connect() # type: ignore
return func(*args, **kwargs) return func(*args, **kwargs)
return typing.cast(FT, new_func) return new_func
# Decorator for caching # Decorator for caching
@ -207,21 +208,22 @@ def cached(
args: typing.Optional[typing.Union[collections.abc.Iterable[int], int]] = None, args: typing.Optional[typing.Union[collections.abc.Iterable[int], int]] = None,
kwargs: typing.Optional[typing.Union[collections.abc.Iterable[str], str]] = None, kwargs: typing.Optional[typing.Union[collections.abc.Iterable[str], str]] = None,
key_helper: typing.Optional[collections.abc.Callable[[typing.Any], str]] = None, key_helper: typing.Optional[collections.abc.Callable[[typing.Any], str]] = None,
) -> collections.abc.Callable[[FT], FT]: ) -> collections.abc.Callable[[collections.abc.Callable[P, T]], collections.abc.Callable[P, T]]:
"""Decorator that give us a "quick& clean" caching feature on db. """
The "cached" element must provide a "cache" variable, which is a cache object Decorator that gives us a "quick & clean" caching feature on the database.
Parameters: Parameters:
prefix: Prefix to use for cache key prefix (str): Prefix to use for the cache key.
timeout: Timeout for cache timeout (Union[Callable[[], int], int], optional): Timeout for the cache in seconds. If -1, it will use the default timeout. Defaults to -1.
args: List of arguments to use for cache key (i.e. [0, 1] will use first and second argument for cache key, 0 will use "self" if a method, and 1 will use first argument) args (Optional[Union[Iterable[int], int]], optional): List of arguments to use for the cache key. If an integer is provided, it will be treated as a single argument. Defaults to None.
kwargs: List of keyword arguments to use for cache key (i.e. ['a', 'b'] will use "a" and "b" keyword arguments for cache key) kwargs (Optional[Union[Iterable[str], str]], optional): List of keyword arguments to use for the cache key. If a string is provided, it will be treated as a single keyword argument. Defaults to None.
key_fnc: Function to use for cache key. If provided, this function will be called with the same arguments as the wrapped function, and must return a string to use as cache key key_helper (Optional[Callable[[Any], str]], optional): Function to use for improving the calculated cache key. Defaults to None.
Note: Note:
If args and kwargs are not provided, all parameters (except *args and **kwargs) will be used for building cache key If `args` and `kwargs` are not provided, all parameters (except `*args` and `**kwargs`) will be used for building the cache key.
Note:
The `key_helper` function will receive the first argument of the function (`self`) and must return a string that will be appended to the cache key.
""" """
from uds.core.util.cache import Cache # To avoid circular references from uds.core.util.cache import Cache # To avoid circular references
@ -229,13 +231,14 @@ def cached(
args_list: list[int] = [args] if isinstance(args, int) else list(args or []) args_list: list[int] = [args] if isinstance(args, int) else list(args or [])
kwargs_list = [kwargs] if isinstance(kwargs, str) else list(kwargs or []) kwargs_list = [kwargs] if isinstance(kwargs, str) else list(kwargs or [])
# Lock for stats concurrency
lock = threading.Lock() lock = threading.Lock()
hits = misses = exec_time = 0 hits = misses = exec_time = 0
def allow_cache_decorator(fnc: FT) -> FT: def allow_cache_decorator(fnc: collections.abc.Callable[P, T]) -> collections.abc.Callable[P, T]:
# If no caching args and no caching kwargs, we will cache the whole call # If no caching args and no caching kwargs, we will cache the whole call
# If no parameters provider, try to infer them from function signature # If no parameters provided, try to infer them from function signature
try: try:
if not args_list and not kwargs_list: if not args_list and not kwargs_list:
for pos, (param_name, param) in enumerate(inspect.signature(fnc).parameters.items()): for pos, (param_name, param) in enumerate(inspect.signature(fnc).parameters.items()):
@ -254,15 +257,16 @@ def cached(
kwargs_list.append(param_name) kwargs_list.append(param_name)
# *args and **kwargs are not supported as cache parameters # *args and **kwargs are not supported as cache parameters
except Exception: except Exception:
logger.debug('Function %s is not inspectable, no caching possible', fnc.__name__)
# Not inspectable, no caching possible, return original function # Not inspectable, no caching possible, return original function
return fnc return fnc
lkey_fnc: collections.abc.Callable[[str], str] = key_helper or (lambda x: fnc.__name__) key_helper_fnc: collections.abc.Callable[[typing.Any], str] = key_helper or (lambda x: fnc.__name__)
@functools.wraps(fnc) @functools.wraps(fnc)
def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
nonlocal hits, misses, exec_time nonlocal hits, misses, exec_time
with transaction.atomic(): # On its own transaction (for cache operations, that are on DB)
cache_key: str = prefix cache_key: str = prefix
for i in args_list: for i in args_list:
if i < len(args): if i < len(args):
@ -270,7 +274,7 @@ def cached(
for s in kwargs_list: for s in kwargs_list:
cache_key += str(kwargs.get(s, '')) cache_key += str(kwargs.get(s, ''))
# Append key data # Append key data
cache_key += lkey_fnc(args[0] if len(args) > 0 else fnc.__name__) cache_key += key_helper_fnc(args[0] if len(args) > 0 else fnc.__name__)
# Note tha this value (cache_key) will be hashed by cache, so it's not a problem if it's too long # Note tha this value (cache_key) will be hashed by cache, so it's not a problem if it's too long
@ -279,11 +283,11 @@ def cached(
cache: 'Cache' = inner_cache or Cache('functionCache') cache: 'Cache' = inner_cache or Cache('functionCache')
# if timeout is a function, call it # if timeout is a function, call it
ltimeout = timeout() if callable(timeout) else timeout effective_timeout = timeout() if callable(timeout) else timeout
data: typing.Any = None data: typing.Any = None
# If misses is 0, we are starting, so we will not try to get from cache # If misses is 0, we are starting, so we will not try to get from cache
if not kwargs.get('force', False) and ltimeout > 0 and misses > 0: if not kwargs.get('force', False) and effective_timeout > 0 and misses > 0:
data = cache.get(cache_key, default=consts.cache.CACHE_NOT_FOUND) data = cache.get(cache_key, default=consts.cache.CACHE_NOT_FOUND)
if data is not consts.cache.CACHE_NOT_FOUND: if data is not consts.cache.CACHE_NOT_FOUND:
with lock: with lock:
@ -299,13 +303,14 @@ def cached(
# Remove force key # Remove force key
del kwargs['force'] del kwargs['force']
# Execute the function outside the DB transaction
t = time.thread_time_ns() t = time.thread_time_ns()
data = fnc(*args, **kwargs) data = fnc(*args, **kwargs)
exec_time += time.thread_time_ns() - t exec_time += time.thread_time_ns() - t
try: try:
# Maybe returned data is not serializable. In that case, cache will fail but no harm is done with this # Maybe returned data is not serializable. In that case, cache will fail but no harm is done with this
cache.put(cache_key, data, ltimeout) cache.put(cache_key, data, effective_timeout)
except Exception as e: except Exception as e:
logger.debug( logger.debug(
'Data for %s is not serializable on call to %s, not cached. %s (%s)', 'Data for %s is not serializable on call to %s, not cached. %s (%s)',
@ -316,6 +321,7 @@ def cached(
) )
return data return data
# Add a couple of methods to the wrapper to allow cache statistics access and cache clearing
def cache_info() -> CacheInfo: def cache_info() -> CacheInfo:
"""Report cache statistics""" """Report cache statistics"""
with lock: with lock:
@ -331,28 +337,28 @@ def cached(
wrapper.cache_info = cache_info # type: ignore wrapper.cache_info = cache_info # type: ignore
wrapper.cache_clear = cache_clear # type: ignore wrapper.cache_clear = cache_clear # type: ignore
return typing.cast(FT, wrapper) return wrapper
return allow_cache_decorator return allow_cache_decorator
# Decorator to execute method in a thread # Decorator to execute method in a thread
def threaded(func: FT) -> FT: def threaded(func: collections.abc.Callable[P, None]) -> collections.abc.Callable[P, None]:
"""Decorator to execute method in a thread""" """Decorator to execute method in a thread"""
@functools.wraps(func) @functools.wraps(func)
def wrapper(*args: typing.Any, **kwargs: typing.Any) -> None: def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
thread = threading.Thread(target=func, args=args, kwargs=kwargs) thread = threading.Thread(target=func, args=args, kwargs=kwargs)
thread.start() thread.start()
return typing.cast(FT, wrapper) return wrapper
def blocker( def blocker(
request_attr: typing.Optional[str] = None, request_attr: typing.Optional[str] = None,
max_failures: typing.Optional[int] = None, max_failures: typing.Optional[int] = None,
ignore_block_config: bool = False, ignore_block_config: bool = False,
) -> collections.abc.Callable[[FT], FT]: ) -> collections.abc.Callable[[collections.abc.Callable[P, T]], collections.abc.Callable[P, T]]:
""" """
Decorator that will block the actor if it has more than ALLOWED_FAILS failures in BLOCK_ACTOR_TIME seconds Decorator that will block the actor if it has more than ALLOWED_FAILS failures in BLOCK_ACTOR_TIME seconds
GlobalConfig.BLOCK_ACTOR_FAILURES.getBool() --> If true, block actor after ALLOWED_FAILS failures GlobalConfig.BLOCK_ACTOR_FAILURES.getBool() --> If true, block actor after ALLOWED_FAILS failures
@ -371,13 +377,13 @@ def blocker(
from uds.core.util.cache import Cache # To avoid circular references from uds.core.util.cache import Cache # To avoid circular references
from uds.core.util.config import GlobalConfig from uds.core.util.config import GlobalConfig
blockCache = Cache('uds:blocker') # Cache for blocked ips mycache = Cache('uds:blocker') # Cache for blocked ips
max_failures = max_failures or consts.system.ALLOWED_FAILS max_failures = max_failures or consts.system.ALLOWED_FAILS
def decorator(f: FT) -> FT: def decorator(f: collections.abc.Callable[P, T]) -> collections.abc.Callable[P, T]:
@functools.wraps(f) @functools.wraps(f)
def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
if not GlobalConfig.BLOCK_ACTOR_FAILURES.as_bool(True) and not ignore_block_config: if not GlobalConfig.BLOCK_ACTOR_FAILURES.as_bool(True) and not ignore_block_config:
return f(*args, **kwargs) return f(*args, **kwargs)
@ -392,33 +398,33 @@ def blocker(
ip = request.ip ip = request.ip
# if ip is blocked, raise exception # if ip is blocked, raise exception
failuresCount = blockCache.get(ip, 0) failures_count = mycache.get(ip, 0)
if failuresCount >= max_failures: if failures_count >= max_failures:
raise exceptions.rest.AccessDenied raise exceptions.rest.AccessDenied
try: try:
result = f(*args, **kwargs) result = f(*args, **kwargs)
except uds.core.exceptions.rest.BlockAccess: except uds.core.exceptions.rest.BlockAccess:
# Increment # Increment
blockCache.put(ip, failuresCount + 1, GlobalConfig.LOGIN_BLOCK.as_int()) mycache.put(ip, failures_count + 1, GlobalConfig.LOGIN_BLOCK.as_int())
raise exceptions.rest.AccessDenied raise exceptions.rest.AccessDenied
# Any other exception will be raised # Any other exception will be raised
except Exception: except Exception:
raise raise
# If we are here, it means that the call was successfull, so we reset the counter # If we are here, it means that the call was successfull, so we reset the counter
blockCache.delete(ip) mycache.delete(ip)
return result return result
return typing.cast(FT, wrapper) return wrapper
return decorator return decorator
def profile( def profile(
log_file: typing.Optional[str] = None, log_file: typing.Optional[str] = None,
) -> collections.abc.Callable[[FT], FT]: ) -> collections.abc.Callable[[collections.abc.Callable[P, T]], collections.abc.Callable[P, T]]:
""" """
Decorator that will profile the wrapped function and log the results to the provided file Decorator that will profile the wrapped function and log the results to the provided file
@ -429,9 +435,11 @@ def profile(
Decorator Decorator
""" """
def decorator(f: FT) -> FT: def decorator(f: collections.abc.Callable[P, T]) -> collections.abc.Callable[P, T]:
def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
nonlocal log_file @functools.wraps(f)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
nonlocal log_file # use outer log_file
import cProfile import cProfile
import pstats import pstats
import tempfile import tempfile
@ -446,6 +454,6 @@ def profile(
stats.dump_stats(log_file) stats.dump_stats(log_file)
return result return result
return typing.cast(FT, wrapper) return wrapper
return decorator return decorator

View File

@ -44,7 +44,7 @@ from .linux_osmanager import LinuxOsManager
from .linux_randompass_osmanager import LinuxRandomPassManager from .linux_randompass_osmanager import LinuxRandomPassManager
from .linux_ad_osmanager import LinuxOsADManager from .linux_ad_osmanager import LinuxOsADManager
_mypath = os.path.dirname(__spec__.origin) _mypath = os.path.dirname(__spec__.origin) # type: ignore[name-defined] # mypy incorrectly report __spec__ as not beind defined
# Old version, using spec is better, but we can use __package__ as well # Old version, using spec is better, but we can use __package__ as well
#_mypath = os.path.dirname(typing.cast(str, sys.modules[__package__].__file__)) # pyright: ignore #_mypath = os.path.dirname(typing.cast(str, sys.modules[__package__].__file__)) # pyright: ignore

View File

@ -42,7 +42,7 @@ from .windows import WindowsOsManager
from .windows_domain import WinDomainOsManager from .windows_domain import WinDomainOsManager
from .windows_random import WinRandomPassManager from .windows_random import WinRandomPassManager
_mypath = os.path.dirname(__spec__.origin) _mypath = os.path.dirname(__spec__.origin) # type: ignore[name-defined] # mypy incorrectly report __spec__ as not beind defined
# Old version, using spec is better, but we can use __package__ as well # Old version, using spec is better, but we can use __package__ as well
#_mypath = os.path.dirname(typing.cast(str, sys.modules[__package__].__file__)) # pyright: ignore #_mypath = os.path.dirname(typing.cast(str, sys.modules[__package__].__file__)) # pyright: ignore

View File

@ -64,20 +64,26 @@ VOLUMES_ENDPOINT_TYPES = [
COMPUTE_ENDPOINT_TYPES = ['compute', 'compute_legacy'] COMPUTE_ENDPOINT_TYPES = ['compute', 'compute_legacy']
NETWORKS_ENDPOINT_TYPES = ['network'] NETWORKS_ENDPOINT_TYPES = ['network']
T = typing.TypeVar('T')
P = typing.ParamSpec('P')
# Decorators # Decorators
def auth_required(for_project: bool = False) -> collections.abc.Callable[[decorators.FT], decorators.FT]: def auth_required(
for_project: bool = False,
) -> collections.abc.Callable[[collections.abc.Callable[P, T]], collections.abc.Callable[P, T]]:
def decorator(func: decorators.FT) -> decorators.FT: def decorator(func: collections.abc.Callable[P, T]) -> collections.abc.Callable[P, T]:
@functools.wraps(func) @functools.wraps(func)
def wrapper(obj: 'OpenstackClient', *args: typing.Any, **kwargs: typing.Any) -> typing.Any: def wrapper(*args: P.args, **kwargs: P.kwargs) -> typing.Any:
obj = typing.cast('OpenstackClient', args[0])
if for_project is True: if for_project is True:
if obj._projectid is None: if obj._projectid is None:
raise Exception('Need a project for method {}'.format(func)) raise Exception('Need a project for method {}'.format(func))
obj.ensure_authenticated() obj.ensure_authenticated()
return func(obj, *args, **kwargs) return func(*args, **kwargs)
return typing.cast(decorators.FT, wrapper) return wrapper
return decorator return decorator