Mirror of https://github.com/dkmstr/openuds.git (synced 2024-12-22 13:34:04 +03:00)
Improved decorators signatures & removed transaction.atomic from cached call
This commit is contained in:
parent 6ab0307bdd
commit a3f50e739a
@@ -1,6 +1,7 @@
[mypy]
#plugins =
# mypy_django_plugin.main
python_version = 3.11

# Exclude all .*/transports/.*/scripts/.* directories
exclude = .*/transports/.*/scripts/.*
@@ -37,8 +37,6 @@ import time
import typing
import collections.abc

-from django.db import transaction
-
from uds.core import consts, types, exceptions
from uds.core.util import singleton
@@ -46,7 +44,9 @@ import uds.core.exceptions.rest

logger = logging.getLogger(__name__)

-FT = typing.TypeVar('FT', bound=collections.abc.Callable[..., typing.Any])
+# FT = typing.TypeVar('FT', bound=collections.abc.Callable[..., typing.Any])
+T = typing.TypeVar('T')
+P = typing.ParamSpec('P')


# Caching statistics
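Replacing the single FT TypeVar (bound to Callable[..., Any]) with a ParamSpec/TypeVar pair lets type checkers carry each wrapped function's real parameter and return types through the decorators instead of erasing them to Any. A minimal standalone sketch of the pattern (stdlib only, not the project's code; ParamSpec needs Python 3.10+, and the mypy config above targets 3.11):

import collections.abc
import functools
import typing

T = typing.TypeVar('T')
P = typing.ParamSpec('P')


def log_calls(func: collections.abc.Callable[P, T]) -> collections.abc.Callable[P, T]:
    """Toy decorator with the same shape as the new signatures in this commit."""

    @functools.wraps(func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        print(f'calling {func.__name__}')
        return func(*args, **kwargs)

    return wrapper  # no typing.cast needed: the types line up by construction


@log_calls
def add(a: int, b: int) -> int:
    return a + b


total = add(1, 2)   # checkers know total is int
# add('1', 2)       # ...and would flag this argument type error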
@@ -130,13 +130,13 @@ def classproperty(func: collections.abc.Callable[..., typing.Any]) -> ClassPrope
    return ClassPropertyDescriptor(func)


-def deprecated(func: FT) -> FT:
+def deprecated(func: collections.abc.Callable[P, T]) -> collections.abc.Callable[P, T]:
    """This is a decorator which can be used to mark functions
    as deprecated. It will result in a warning being emitted
    when the function is used."""

    @functools.wraps(func)
-    def new_func(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
+    def new_func(*args: P.args, **kwargs: P.kwargs) -> T:
        try:
            caller = inspect.stack()[1]
            logger.warning(
@@ -150,7 +150,7 @@ def deprecated(func: FT) -> FT:

        return func(*args, **kwargs)

-    return typing.cast(FT, new_func)
+    return new_func


def deprecated_class_value(new_var_name: str) -> collections.abc.Callable[..., typing.Any]:
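With the ParamSpec-based signature, a decorated call keeps its static signature while still emitting the runtime warning. A hypothetical usage, assuming the module path uds.core.util.decorators used elsewhere in the project:

from uds.core.util.decorators import deprecated


@deprecated
def old_api(value: int) -> int:
    return value * 2


old_api(21)  # returns 42 and logs a deprecation warning that references the caller (via inspect.stack()[1] above)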
@@ -189,14 +189,15 @@ def deprecated_class_value(new_var_name: str) -> collections.abc.Callable[..., t
    return functools.partial(innerDeprecated, newVarName=new_var_name)


-def ensure_connected(func: FT) -> FT:
+def ensure_connected(func: collections.abc.Callable[P, T]) -> collections.abc.Callable[P, T]:
    """This decorator calls "connect" method of the class of the wrapped object"""

    @functools.wraps(func)
-    def new_func(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
-        args[0].connect()
+    def new_func(*args: P.args, **kwargs: P.kwargs) -> T:
+        args[0].connect() # type: ignore
        return func(*args, **kwargs)

-    return typing.cast(FT, new_func)
+    return new_func


# Decorator for caching
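ensure_connected relies on args[0] being the bound instance (self); ParamSpec cannot express that constraint, which is why the rewritten body needs the added "# type: ignore" on args[0].connect(). A standalone sketch of the same pattern:

import collections.abc
import functools
import typing

T = typing.TypeVar('T')
P = typing.ParamSpec('P')


def ensure_connected(func: collections.abc.Callable[P, T]) -> collections.abc.Callable[P, T]:
    """Call connect() on the instance before running the wrapped method."""

    @functools.wraps(func)
    def new_func(*args: P.args, **kwargs: P.kwargs) -> T:
        args[0].connect()  # type: ignore  # args[0] is "self"; ParamSpec cannot prove it has connect()
        return func(*args, **kwargs)

    return new_func


class Client:
    connected = False

    def connect(self) -> None:
        self.connected = True

    @ensure_connected
    def fetch(self) -> str:
        return 'data' if self.connected else 'not connected'


print(Client().fetch())  # prints "data": connect() ran first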
@@ -207,21 +208,22 @@ def cached(
    args: typing.Optional[typing.Union[collections.abc.Iterable[int], int]] = None,
    kwargs: typing.Optional[typing.Union[collections.abc.Iterable[str], str]] = None,
    key_helper: typing.Optional[collections.abc.Callable[[typing.Any], str]] = None,
-) -> collections.abc.Callable[[FT], FT]:
-    """Decorator that give us a "quick& clean" caching feature on db.
-    The "cached" element must provide a "cache" variable, which is a cache object
+) -> collections.abc.Callable[[collections.abc.Callable[P, T]], collections.abc.Callable[P, T]]:
+    """
+    Decorator that gives us a "quick & clean" caching feature on the database.

    Parameters:
-        prefix: Prefix to use for cache key
-        timeout: Timeout for cache
-        args: List of arguments to use for cache key (i.e. [0, 1] will use first and second argument for cache key, 0 will use "self" if a method, and 1 will use first argument)
-        kwargs: List of keyword arguments to use for cache key (i.e. ['a', 'b'] will use "a" and "b" keyword arguments for cache key)
-        key_fnc: Function to use for cache key. If provided, this function will be called with the same arguments as the wrapped function, and must return a string to use as cache key
+        prefix (str): Prefix to use for the cache key.
+        timeout (Union[Callable[[], int], int], optional): Timeout for the cache in seconds. If -1, it will use the default timeout. Defaults to -1.
+        args (Optional[Union[Iterable[int], int]], optional): List of arguments to use for the cache key. If an integer is provided, it will be treated as a single argument. Defaults to None.
+        kwargs (Optional[Union[Iterable[str], str]], optional): List of keyword arguments to use for the cache key. If a string is provided, it will be treated as a single keyword argument. Defaults to None.
+        key_helper (Optional[Callable[[Any], str]], optional): Function to use for improving the calculated cache key. Defaults to None.

    Note:
-        If args and kwargs are not provided, all parameters (except *args and **kwargs) will be used for building cache key
-
+        If `args` and `kwargs` are not provided, all parameters (except `*args` and `**kwargs`) will be used for building the cache key.
+
+    Note:
+        The `key_helper` function will receive the first argument of the function (`self`) and must return a string that will be appended to the cache key.
    """
    from uds.core.util.cache import Cache # To avoid circular references
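A hypothetical use of the documented parameters (class, attribute and value names below are illustrative only, and the import path is assumed to be uds.core.util.decorators):

from uds.core.util.decorators import cached


class ProviderStats:
    name = 'prov1'  # used by key_helper below to namespace the cache key

    @cached(prefix='prov_stats', timeout=60, kwargs=['detailed'], key_helper=lambda self: self.name)
    def stats(self, detailed: bool = False) -> dict:
        # expensive DB/API work; the returned value is cached for 60 seconds
        return {'detailed': detailed, 'users': 42}


# Callers can bypass the cache with the reserved 'force' keyword, which the wrapper strips:
# ProviderStats().stats(detailed=True, force=True)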
@@ -229,13 +231,14 @@ def cached(
    args_list: list[int] = [args] if isinstance(args, int) else list(args or [])
    kwargs_list = [kwargs] if isinstance(kwargs, str) else list(kwargs or [])

    # Lock for stats concurrency
    lock = threading.Lock()

    hits = misses = exec_time = 0

-    def allow_cache_decorator(fnc: FT) -> FT:
+    def allow_cache_decorator(fnc: collections.abc.Callable[P, T]) -> collections.abc.Callable[P, T]:
        # If no caching args and no caching kwargs, we will cache the whole call
-        # If no parameters provider, try to infer them from function signature
+        # If no parameters provided, try to infer them from function signature
        try:
            if not args_list and not kwargs_list:
                for pos, (param_name, param) in enumerate(inspect.signature(fnc).parameters.items()):
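When neither args nor kwargs is given, the decorator walks inspect.signature(fnc).parameters to decide what goes into the key, skipping *args/**kwargs as the comment in the next hunk notes. A standalone sketch of that kind of inference (the split between positional indices and keyword names is an assumption here, since the hunk does not show the exact rule):

import inspect
import typing


def infer_cache_params(fnc: typing.Callable[..., typing.Any]) -> tuple[list[int], list[str]]:
    """Collect positional indices and keyword names usable for a cache key."""
    args_list: list[int] = []
    kwargs_list: list[str] = []
    for pos, (name, param) in enumerate(inspect.signature(fnc).parameters.items()):
        if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
            continue  # *args / **kwargs cannot participate in the key
        if param.default is param.empty:
            args_list.append(pos)     # required parameter: keyed by position
        else:
            kwargs_list.append(name)  # defaulted parameter: keyed by name
    return args_list, kwargs_list


def example(self, uuid: str, *, detailed: bool = False, **extra: typing.Any) -> None: ...


print(infer_cache_params(example))  # ([0, 1], ['detailed'])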
@@ -254,68 +257,71 @@ def cached(
                        kwargs_list.append(param_name)
                    # *args and **kwargs are not supported as cache parameters
        except Exception:
            logger.debug('Function %s is not inspectable, no caching possible', fnc.__name__)
            # Not inspectable, no caching possible, return original function
            return fnc

-        lkey_fnc: collections.abc.Callable[[str], str] = key_helper or (lambda x: fnc.__name__)
+        key_helper_fnc: collections.abc.Callable[[typing.Any], str] = key_helper or (lambda x: fnc.__name__)

        @functools.wraps(fnc)
-        def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
+        def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            nonlocal hits, misses, exec_time
-            with transaction.atomic(): # On its own transaction (for cache operations, that are on DB)
-                cache_key: str = prefix
-                for i in args_list:
-                    if i < len(args):
-                        cache_key += str(args[i])
-                for s in kwargs_list:
-                    cache_key += str(kwargs.get(s, ''))
-                # Append key data
-                cache_key += lkey_fnc(args[0] if len(args) > 0 else fnc.__name__)
+            cache_key: str = prefix
+            for i in args_list:
+                if i < len(args):
+                    cache_key += str(args[i])
+            for s in kwargs_list:
+                cache_key += str(kwargs.get(s, ''))
+            # Append key data
+            cache_key += key_helper_fnc(args[0] if len(args) > 0 else fnc.__name__)

-                # Note tha this value (cache_key) will be hashed by cache, so it's not a problem if it's too long
+            # Note tha this value (cache_key) will be hashed by cache, so it's not a problem if it's too long

-                # Get cache from object if present, or use the global 'functionCache' (generic, common to all objects)
-                inner_cache: 'Cache|None' = getattr(args[0], 'cache', None) if len(args) > 0 else None
-                cache: 'Cache' = inner_cache or Cache('functionCache')
+            # Get cache from object if present, or use the global 'functionCache' (generic, common to all objects)
+            inner_cache: 'Cache|None' = getattr(args[0], 'cache', None) if len(args) > 0 else None
+            cache: 'Cache' = inner_cache or Cache('functionCache')

-                # if timeout is a function, call it
-                ltimeout = timeout() if callable(timeout) else timeout
+            # if timeout is a function, call it
+            effective_timeout = timeout() if callable(timeout) else timeout

-                data: typing.Any = None
-                # If misses is 0, we are starting, so we will not try to get from cache
-                if not kwargs.get('force', False) and ltimeout > 0 and misses > 0:
-                    data = cache.get(cache_key, default=consts.cache.CACHE_NOT_FOUND)
-                    if data is not consts.cache.CACHE_NOT_FOUND:
-                        with lock:
-                            hits += 1
-                            CacheStats.manager().add_hit(exec_time // hits) # Use mean execution time
-                        return data
+            data: typing.Any = None
+            # If misses is 0, we are starting, so we will not try to get from cache
+            if not kwargs.get('force', False) and effective_timeout > 0 and misses > 0:
+                data = cache.get(cache_key, default=consts.cache.CACHE_NOT_FOUND)
+                if data is not consts.cache.CACHE_NOT_FOUND:
+                    with lock:
+                        hits += 1
+                        CacheStats.manager().add_hit(exec_time // hits) # Use mean execution time
+                    return data

-                with lock:
-                    misses += 1
-                    CacheStats.manager().add_miss()
+            with lock:
+                misses += 1
+                CacheStats.manager().add_miss()

-                if 'force' in kwargs:
-                    # Remove force key
-                    del kwargs['force']
+            if 'force' in kwargs:
+                # Remove force key
+                del kwargs['force']

-                t = time.thread_time_ns()
-                data = fnc(*args, **kwargs)
-                exec_time += time.thread_time_ns() - t
+            # Execute the function outside the DB transaction
+            t = time.thread_time_ns()
+            data = fnc(*args, **kwargs)
+            exec_time += time.thread_time_ns() - t

-                try:
-                    # Maybe returned data is not serializable. In that case, cache will fail but no harm is done with this
-                    cache.put(cache_key, data, ltimeout)
-                except Exception as e:
-                    logger.debug(
-                        'Data for %s is not serializable on call to %s, not cached. %s (%s)',
-                        cache_key,
-                        fnc.__name__,
-                        data,
-                        e,
-                    )
-                return data
+            try:
+                # Maybe returned data is not serializable. In that case, cache will fail but no harm is done with this
+                cache.put(cache_key, data, effective_timeout)
+            except Exception as e:
+                logger.debug(
+                    'Data for %s is not serializable on call to %s, not cached. %s (%s)',
+                    cache_key,
+                    fnc.__name__,
+                    data,
+                    e,
+                )
+            return data

        # Add a couple of methods to the wrapper to allow cache statistics access and cache clearing
        def cache_info() -> CacheInfo:
            """Report cache statistics"""
            with lock:
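Dropping transaction.atomic() means the wrapped function and the cache lookups around it no longer run inside their own database transaction; the key-building logic itself is unchanged, only dedented and renamed. A standalone sketch of just that key construction (illustrative; the real decorator lets the Cache layer hash the resulting string):

import typing


def build_cache_key(
    prefix: str,
    args_list: list[int],
    kwargs_list: list[str],
    key_helper: typing.Callable[[typing.Any], str],
    *args: typing.Any,
    **kwargs: typing.Any,
) -> str:
    key = prefix
    for i in args_list:               # selected positional arguments
        if i < len(args):
            key += str(args[i])
    for name in kwargs_list:          # selected keyword arguments
        key += str(kwargs.get(name, ''))
    # key_helper receives the first argument (usually "self") and appends its own suffix
    key += key_helper(args[0] if args else prefix)
    return key


print(build_cache_key('userinfo', [1], ['detailed'], lambda obj: 'v1', object(), 42, detailed=True))
# -> 'userinfo42Truev1'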
@@ -331,28 +337,28 @@ def cached(
        wrapper.cache_info = cache_info # type: ignore
        wrapper.cache_clear = cache_clear # type: ignore

-        return typing.cast(FT, wrapper)
+        return wrapper

    return allow_cache_decorator


# Decorator to execute method in a thread
-def threaded(func: FT) -> FT:
+def threaded(func: collections.abc.Callable[P, None]) -> collections.abc.Callable[P, None]:
    """Decorator to execute method in a thread"""

    @functools.wraps(func)
-    def wrapper(*args: typing.Any, **kwargs: typing.Any) -> None:
+    def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
        thread = threading.Thread(target=func, args=args, kwargs=kwargs)
        thread.start()

-    return typing.cast(FT, wrapper)
+    return wrapper


def blocker(
    request_attr: typing.Optional[str] = None,
    max_failures: typing.Optional[int] = None,
    ignore_block_config: bool = False,
-) -> collections.abc.Callable[[FT], FT]:
+) -> collections.abc.Callable[[collections.abc.Callable[P, T]], collections.abc.Callable[P, T]]:
    """
    Decorator that will block the actor if it has more than ALLOWED_FAILS failures in BLOCK_ACTOR_TIME seconds
    GlobalConfig.BLOCK_ACTOR_FAILURES.getBool() --> If true, block actor after ALLOWED_FAILS failures
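Typing threaded as Callable[P, None] makes the fire-and-forget contract explicit: the caller never sees a return value, so only None-returning callables should be wrapped. A minimal standalone sketch of the pattern:

import collections.abc
import functools
import threading
import typing

P = typing.ParamSpec('P')


def threaded(func: collections.abc.Callable[P, None]) -> collections.abc.Callable[P, None]:
    """Run the wrapped callable in a background thread; the caller does not wait for it."""

    @functools.wraps(func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
        threading.Thread(target=func, args=args, kwargs=kwargs).start()

    return wrapper


@threaded
def notify(msg: str) -> None:
    print(f'[worker] {msg}')


notify('hello')  # returns immediately; the print happens on the worker thread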
@@ -371,13 +377,13 @@ def blocker(
    from uds.core.util.cache import Cache # To avoid circular references
    from uds.core.util.config import GlobalConfig

-    blockCache = Cache('uds:blocker') # Cache for blocked ips
+    mycache = Cache('uds:blocker') # Cache for blocked ips

    max_failures = max_failures or consts.system.ALLOWED_FAILS

-    def decorator(f: FT) -> FT:
+    def decorator(f: collections.abc.Callable[P, T]) -> collections.abc.Callable[P, T]:
        @functools.wraps(f)
-        def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
+        def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            if not GlobalConfig.BLOCK_ACTOR_FAILURES.as_bool(True) and not ignore_block_config:
                return f(*args, **kwargs)
@@ -392,33 +398,33 @@ def blocker(
            ip = request.ip

            # if ip is blocked, raise exception
-            failuresCount = blockCache.get(ip, 0)
-            if failuresCount >= max_failures:
+            failures_count = mycache.get(ip, 0)
+            if failures_count >= max_failures:
                raise exceptions.rest.AccessDenied

            try:
                result = f(*args, **kwargs)
            except uds.core.exceptions.rest.BlockAccess:
                # Increment
-                blockCache.put(ip, failuresCount + 1, GlobalConfig.LOGIN_BLOCK.as_int())
+                mycache.put(ip, failures_count + 1, GlobalConfig.LOGIN_BLOCK.as_int())
                raise exceptions.rest.AccessDenied
            # Any other exception will be raised
            except Exception:
                raise

            # If we are here, it means that the call was successfull, so we reset the counter
-            blockCache.delete(ip)
+            mycache.delete(ip)

            return result

-        return typing.cast(FT, wrapper)
+        return wrapper

    return decorator


def profile(
    log_file: typing.Optional[str] = None,
-) -> collections.abc.Callable[[FT], FT]:
+) -> collections.abc.Callable[[collections.abc.Callable[P, T]], collections.abc.Callable[P, T]]:
    """
    Decorator that will profile the wrapped function and log the results to the provided file
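For context, a hypothetical handler guarded by blocker: raising BlockAccess counts one failure for the caller's IP, max_failures of them lead to AccessDenied until the LOGIN_BLOCK window expires, and a successful call clears the counter. Names other than the decorator and the exception classes are illustrative, and both the import path and the way request_attr locates the request object are assumptions:

import typing

import uds.core.exceptions.rest
from uds.core.util.decorators import blocker


class LoginHandler:
    _request: typing.Any  # assumed to expose an "ip" attribute, since the wrapper reads request.ip

    @blocker(request_attr='_request', max_failures=5)
    def login(self, username: str, password: str) -> dict:
        if not self._credentials_ok(username, password):  # hypothetical helper
            # Each BlockAccess increments the per-IP failure counter kept in the 'uds:blocker' cache
            raise uds.core.exceptions.rest.BlockAccess()
        return {'user': username}  # a successful call resets the counter

    def _credentials_ok(self, username: str, password: str) -> bool:
        return bool(username and password)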
@@ -429,9 +435,11 @@ def profile(
        Decorator
    """

-    def decorator(f: FT) -> FT:
-        def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
-            nonlocal log_file
+    def decorator(f: collections.abc.Callable[P, T]) -> collections.abc.Callable[P, T]:
+
+        @functools.wraps(f)
+        def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
+            nonlocal log_file # use outer log_file
            import cProfile
            import pstats
            import tempfile
@@ -446,6 +454,6 @@ def profile(
            stats.dump_stats(log_file)
            return result

-        return typing.cast(FT, wrapper)
+        return wrapper

    return decorator
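A hypothetical use of profile (import path assumed as above); per the hunks, the pstats data is dumped to log_file, or, given the tempfile import, presumably to a temporary file when none is given:

from uds.core.util.decorators import profile


@profile(log_file='/tmp/slow_report.prof')
def slow_report() -> list[int]:
    return [n * n for n in range(100_000)]


slow_report()
# Inspect the dump afterwards with the stdlib browser:
#   python -m pstats /tmp/slow_report.prof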
@@ -44,7 +44,7 @@ from .linux_osmanager import LinuxOsManager
from .linux_randompass_osmanager import LinuxRandomPassManager
from .linux_ad_osmanager import LinuxOsADManager

-_mypath = os.path.dirname(__spec__.origin)
+_mypath = os.path.dirname(__spec__.origin) # type: ignore[name-defined] # mypy incorrectly report __spec__ as not beind defined
# Old version, using spec is better, but we can use __package__ as well
#_mypath = os.path.dirname(typing.cast(str, sys.modules[__package__].__file__)) # pyright: ignore
@@ -42,7 +42,7 @@ from .windows import WindowsOsManager
from .windows_domain import WinDomainOsManager
from .windows_random import WinRandomPassManager

-_mypath = os.path.dirname(__spec__.origin)
+_mypath = os.path.dirname(__spec__.origin) # type: ignore[name-defined] # mypy incorrectly report __spec__ as not beind defined
# Old version, using spec is better, but we can use __package__ as well
#_mypath = os.path.dirname(typing.cast(str, sys.modules[__package__].__file__)) # pyright: ignore
@@ -64,20 +64,26 @@ VOLUMES_ENDPOINT_TYPES = [
COMPUTE_ENDPOINT_TYPES = ['compute', 'compute_legacy']
NETWORKS_ENDPOINT_TYPES = ['network']

+T = typing.TypeVar('T')
+P = typing.ParamSpec('P')


# Decorators
-def auth_required(for_project: bool = False) -> collections.abc.Callable[[decorators.FT], decorators.FT]:
+def auth_required(
+    for_project: bool = False,
+) -> collections.abc.Callable[[collections.abc.Callable[P, T]], collections.abc.Callable[P, T]]:

-    def decorator(func: decorators.FT) -> decorators.FT:
+    def decorator(func: collections.abc.Callable[P, T]) -> collections.abc.Callable[P, T]:
        @functools.wraps(func)
-        def wrapper(obj: 'OpenstackClient', *args: typing.Any, **kwargs: typing.Any) -> typing.Any:
+        def wrapper(*args: P.args, **kwargs: P.kwargs) -> typing.Any:
+            obj = typing.cast('OpenstackClient', args[0])
            if for_project is True:
                if obj._projectid is None:
                    raise Exception('Need a project for method {}'.format(func))
            obj.ensure_authenticated()
-            return func(obj, *args, **kwargs)
+            return func(*args, **kwargs)

-        return typing.cast(decorators.FT, wrapper)
+        return wrapper

    return decorator
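With the rewrite, wrapper receives only *args/**kwargs and recovers the client as args[0] via typing.cast, so func is now called with the original argument list unchanged. A hypothetical method in the same module using the decorator (method names below are illustrative):

class OpenstackClient:
    ...

    @auth_required(for_project=True)
    def list_volumes(self) -> list[typing.Any]:
        # ensure_authenticated() has already run, and for_project=True guarantees self._projectid is set
        return self._request_volumes()  # hypothetical request helper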