From 485520f402d4f67ba3d0f0e796f17c71c69d7b6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adolfo=20G=C3=B3mez=20Garc=C3=ADa?= Date: Fri, 26 Jan 2024 01:30:40 +0100 Subject: [PATCH] Remodeled cache updater, minor fixes and more refactoring --- server/src/uds/REST/methods/actor_v3.py | 34 +-- server/src/uds/REST/methods/services.py | 2 +- .../src/uds/auths/RegexLdap/authenticator.py | 28 +- .../src/uds/auths/SimpleLDAP/authenticator.py | 45 +-- .../uds/core/managers/servers_api/events.py | 12 +- server/src/uds/core/managers/user_service.py | 25 +- .../core/managers/userservice/opchecker.py | 64 ++-- server/src/uds/core/serializable.py | 4 +- server/src/uds/core/services/publication.py | 8 +- server/src/uds/core/services/service.py | 46 +-- server/src/uds/core/services/user_service.py | 2 +- server/src/uds/core/ui/user_interface.py | 1 + server/src/uds/core/util/auto_attributes.py | 4 +- server/src/uds/core/util/auto_serializable.py | 34 ++- server/src/uds/core/util/fields.py | 66 ++-- .../workers/servicepools_cache_updater.py | 285 ++++++++++-------- server/src/uds/management/commands/tree.py | 4 +- server/src/uds/models/__init__.py | 2 +- .../models/{os_manager.py => osmanager.py} | 0 server/src/uds/models/service_pool.py | 4 +- .../uds/models/service_pool_publication.py | 27 +- server/src/uds/models/user_service.py | 57 ++-- server/src/uds/services/OpenGnsys/service.py | 5 +- .../PhysicalMachines/service_multi.py | 2 +- server/src/uds/services/Test/provider.py | 2 +- .../auths/regex_ldap/test_serialization.py | 43 ++- server/tests/auths/simple_ldap/__init__.py | 0 .../auths/simple_ldap/test_serialization.py | 124 ++++++++ .../test_servicepools_cache_updater.py | 28 +- 29 files changed, 572 insertions(+), 386 deletions(-) rename server/src/uds/models/{os_manager.py => osmanager.py} (100%) create mode 100644 server/tests/auths/simple_ldap/__init__.py create mode 100644 server/tests/auths/simple_ldap/test_serialization.py diff --git a/server/src/uds/REST/methods/actor_v3.py b/server/src/uds/REST/methods/actor_v3.py index f2f4a4bfe..436e24cca 100644 --- a/server/src/uds/REST/methods/actor_v3.py +++ b/server/src/uds/REST/methods/actor_v3.py @@ -446,9 +446,9 @@ class Initialize(ActorV3Action): # Set last seen actor version userService.actor_version = self._params['version'] osData: collections.abc.MutableMapping[str, typing.Any] = {} - osManager = userService.get_osmanager_instance() - if osManager: - osData = osManager.actor_data(userService) + osmanager = userService.get_osmanager_instance() + if osmanager: + osData = osmanager.actor_data(userService) if service and not alias_token: # Is a service managed by UDS # Create a new alias for it, and save @@ -501,11 +501,11 @@ class BaseReadyChange(ActorV3Action): if userService.os_state != State.USABLE: userService.setOsState(State.USABLE) - # Notify osManager or readyness if has os manager - osManager = userService.get_osmanager_instance() + # Notify osmanager or readyness if has os manager + osmanager = userService.get_osmanager_instance() - if osManager: - osManager.to_ready(userService) + if osmanager: + osmanager.to_ready(userService) UserServiceManager().notify_ready_from_os_manager(userService, '') # Generates a certificate and send it to client. @@ -590,10 +590,10 @@ class Login(ActorV3Action): @staticmethod def process_login(userservice: UserService, username: str) -> typing.Optional[osmanagers.OSManager]: - osManager: typing.Optional[osmanagers.OSManager] = userservice.get_osmanager_instance() + osmanager: typing.Optional[osmanagers.OSManager] = userservice.get_osmanager_instance() if not userservice.in_use: # If already logged in, do not add a second login (windows does this i.e.) osmanagers.OSManager.logged_in(userservice, username) - return osManager + return osmanager def action(self) -> dict[str, typing.Any]: isManaged = self._params.get('type') != consts.actor.UNMANAGED @@ -605,17 +605,17 @@ class Login(ActorV3Action): try: userservice: UserService = self.get_userservice() - os_manager = Login.process_login(userservice, self._params.get('username') or '') + osmanager = Login.process_login(userservice, self._params.get('username') or '') - max_idle = os_manager.max_idle() if os_manager else None + max_idle = osmanager.max_idle() if osmanager else None logger.debug('Max idle: %s', max_idle) src = userservice.getConnectionSource() session_id = userservice.start_session() # creates a session for every login requested - if os_manager: # For os managed services, let's check if we honor deadline - if os_manager.ignore_deadline(): + if osmanager: # For os managed services, let's check if we honor deadline + if osmanager.ignore_deadline(): deadline = userservice.deployed_service.get_deadline() else: deadline = None @@ -653,7 +653,7 @@ class Logout(ActorV3Action): """ This method is static so can be invoked from elsewhere """ - osManager: typing.Optional[osmanagers.OSManager] = userservice.get_osmanager_instance() + osmanager: typing.Optional[osmanagers.OSManager] = userservice.get_osmanager_instance() # Close session # For compat, we have taken '' as "all sessions" @@ -661,9 +661,9 @@ class Logout(ActorV3Action): if userservice.in_use: # If already logged out, do not add a second logout (windows does this i.e.) osmanagers.OSManager.logged_out(userservice, username) - if osManager: - if osManager.is_removable_on_logout(userservice): - logger.debug('Removable on logout: %s', osManager) + if osmanager: + if osmanager.is_removable_on_logout(userservice): + logger.debug('Removable on logout: %s', osmanager) userservice.remove() else: userservice.remove() diff --git a/server/src/uds/REST/methods/services.py b/server/src/uds/REST/methods/services.py index 4f483a09e..92577e62e 100644 --- a/server/src/uds/REST/methods/services.py +++ b/server/src/uds/REST/methods/services.py @@ -72,7 +72,7 @@ class Services(DetailHandler): # pylint: disable=too-many-public-methods return { 'icon': info.icon64().replace('\n', ''), 'needs_publication': info.publication_type is not None, - 'max_deployed': info.max_user_services, + 'max_deployed': info.userservices_limit, 'uses_cache': info.uses_cache and info.overrided_fields is None, 'uses_cache_l2': info.uses_cache_l2, 'cache_tooltip': _(info.cache_tooltip), diff --git a/server/src/uds/auths/RegexLdap/authenticator.py b/server/src/uds/auths/RegexLdap/authenticator.py index bc016a42b..ba3aaa551 100644 --- a/server/src/uds/auths/RegexLdap/authenticator.py +++ b/server/src/uds/auths/RegexLdap/authenticator.py @@ -98,23 +98,10 @@ class RegexLdap(auths.Authenticator): required=True, tab=types.ui.Tab.CREDENTIALS, ) - timeout = gui.NumericField( - length=3, - label=_('Timeout'), - default=10, - order=6, - tooltip=_('Timeout in seconds of connection to LDAP'), - required=True, - min_value=1, - ) - verify_ssl = gui.CheckBoxField( - label=_('Verify SSL'), - default=True, - order=11, - tooltip=_('If checked, SSL verification will be enforced. If not, SSL verification will be disabled'), - tab=types.ui.Tab.ADVANCED, - old_field_name='verifySsl', - ) + + timeout = fields.timeout_field(tab=False, default=10) # Use "main tab" + verify_ssl = fields.verify_ssl_field(order=11) + certificate = gui.TextField( length=8192, lines=4, @@ -123,7 +110,6 @@ class RegexLdap(auths.Authenticator): tooltip=_('Certificate to use for SSL verification'), required=False, tab=types.ui.Tab.ADVANCED, - old_field_name='certificate', ) ldap_base = gui.TextField( length=64, @@ -132,7 +118,6 @@ class RegexLdap(auths.Authenticator): tooltip=_('Common search base (used for "users" and "groups")'), required=True, tab=_('Ldap info'), - old_field_name='ldapBase', ) user_class = gui.TextField( length=64, @@ -142,7 +127,6 @@ class RegexLdap(auths.Authenticator): tooltip=_('Class for LDAP users (normally posixAccount)'), required=True, tab=_('Ldap info'), - old_field_name='userClass', ) userid_attr = gui.TextField( length=64, @@ -152,7 +136,6 @@ class RegexLdap(auths.Authenticator): tooltip=_('Attribute that contains the user id.'), required=True, tab=_('Ldap info'), - old_field_name='userIdAttr', ) username_attr = gui.TextField( length=640, @@ -165,7 +148,6 @@ class RegexLdap(auths.Authenticator): ), required=True, tab=_('Ldap info'), - old_field_name='userNameAttr', ) groupname_attr = gui.TextField( length=640, @@ -178,7 +160,6 @@ class RegexLdap(auths.Authenticator): ), required=True, tab=_('Ldap info'), - old_field_name='groupNameAttr', ) # regex = gui.TextField(length=64, label = _('Regular Exp. for groups'), defvalue = '^(.*)', order = 12, tooltip = _('Regular Expression to extract the group name'), required = True) @@ -190,7 +171,6 @@ class RegexLdap(auths.Authenticator): tooltip=_('Class for LDAP objects that will be also checked for groups retrieval (normally empty)'), required=False, tab=_('Advanced'), - old_field_name='altClass', ) mfa_attribute = fields.mfa_attr_field() diff --git a/server/src/uds/auths/SimpleLDAP/authenticator.py b/server/src/uds/auths/SimpleLDAP/authenticator.py index 584b94813..fcec362dd 100644 --- a/server/src/uds/auths/SimpleLDAP/authenticator.py +++ b/server/src/uds/auths/SimpleLDAP/authenticator.py @@ -1,6 +1,6 @@ # pylint: disable=no-member # ldap module gives errors to pylint # -# Copyright (c) 2012-2021 Virtual Cable S.L.U. +# Copyright (c) 2024 Virtual Cable S.L.U. # All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, @@ -27,7 +27,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ''' -@author: Adolfo Gómez, dkmaster at dkmon dot com +Author: Adolfo Gómez, dkmaster at dkmon dot com ''' import logging import typing @@ -40,7 +40,7 @@ from django.utils.translation import gettext_noop as _ from uds.core import auths, types, consts, exceptions from uds.core.auths.auth import log_login from uds.core.ui import gui -from uds.core.util import ldaputil +from uds.core.util import fields, ldaputil # Not imported at runtime, just for type checking if typing.TYPE_CHECKING: @@ -68,7 +68,7 @@ class SimpleLDAPAuthenticator(auths.Authenticator): tooltip=_('Ldap port (usually 389 for non ssl and 636 for ssl)'), required=True, ) - ssl = gui.CheckBoxField( + use_ssl = gui.CheckBoxField( label=_('Use SSL'), order=3, tooltip=_('If checked, the connection will be ssl, using port 636 instead of 389'), @@ -89,23 +89,10 @@ class SimpleLDAPAuthenticator(auths.Authenticator): required=True, tab=types.ui.Tab.CREDENTIALS, ) - timeout = gui.NumericField( - length=3, - label=_('Timeout'), - default=10, - order=10, - tooltip=_('Timeout in seconds of connection to LDAP'), - required=True, - min_value=1, - tab=types.ui.Tab.ADVANCED, - ) - verifySsl = gui.CheckBoxField( - label=_('Verify SSL'), - default=True, - order=11, - tooltip=_('If checked, SSL verification will be enforced. If not, SSL verification will be disabled'), - tab=types.ui.Tab.ADVANCED, - ) + + timeout = fields.timeout_field(tab=False, default=10) # Use "main tab" + verify_ssl = fields.verify_ssl_field(order=11) + certificate = gui.TextField( length=8192, lines=4, @@ -116,7 +103,7 @@ class SimpleLDAPAuthenticator(auths.Authenticator): tab=types.ui.Tab.ADVANCED, ) - ldapBase = gui.TextField( + ldap_base = gui.TextField( length=64, label=_('Base'), order=30, @@ -124,7 +111,7 @@ class SimpleLDAPAuthenticator(auths.Authenticator): required=True, tab=_('Ldap info'), ) - userClass = gui.TextField( + user_class = gui.TextField( length=64, label=_('User class'), default='posixAccount', @@ -133,7 +120,7 @@ class SimpleLDAPAuthenticator(auths.Authenticator): required=True, tab=_('Ldap info'), ) - userIdAttr = gui.TextField( + user_id_attr = gui.TextField( length=64, label=_('User Id Attr'), default='uid', @@ -142,7 +129,7 @@ class SimpleLDAPAuthenticator(auths.Authenticator): required=True, tab=_('Ldap info'), ) - userNameAttr = gui.TextField( + username_attr = gui.TextField( length=64, label=_('User Name Attr'), default='uid', @@ -151,7 +138,7 @@ class SimpleLDAPAuthenticator(auths.Authenticator): required=True, tab=_('Ldap info'), ) - groupClass = gui.TextField( + group_class = gui.TextField( length=64, label=_('Group class'), default='posixGroup', @@ -160,7 +147,7 @@ class SimpleLDAPAuthenticator(auths.Authenticator): required=True, tab=_('Ldap info'), ) - groupIdAttr = gui.TextField( + group_id_attr = gui.TextField( length=64, label=_('Group Id Attr'), default='cn', @@ -169,7 +156,7 @@ class SimpleLDAPAuthenticator(auths.Authenticator): required=True, tab=_('Ldap info'), ) - memberAttr = gui.TextField( + member_attr = gui.TextField( length=64, label=_('Group membership attr'), default='memberUid', @@ -178,7 +165,7 @@ class SimpleLDAPAuthenticator(auths.Authenticator): required=True, tab=_('Ldap info'), ) - mfaAttr = gui.TextField( + mfa_attribute = gui.TextField( length=2048, lines=2, label=_('MFA attribute'), diff --git a/server/src/uds/core/managers/servers_api/events.py b/server/src/uds/core/managers/servers_api/events.py index 59ee73011..d667aefb0 100644 --- a/server/src/uds/core/managers/servers_api/events.py +++ b/server/src/uds/core/managers/servers_api/events.py @@ -99,13 +99,13 @@ def process_login(server: 'models.Server', data: dict[str, typing.Any]) -> typin src = userService.getConnectionSource() session_id = userService.start_session() # creates a session for every login requested - osManager: typing.Optional[osmanagers.OSManager] = userService.get_osmanager_instance() - maxIdle = osManager.max_idle() if osManager else None + osmanager: typing.Optional[osmanagers.OSManager] = userService.get_osmanager_instance() + maxIdle = osmanager.max_idle() if osmanager else None logger.debug('Max idle: %s', maxIdle) deadLine = ( - userService.deployed_service.get_deadline() if not osManager or osManager.ignore_deadline() else None + userService.deployed_service.get_deadline() if not osmanager or osmanager.ignore_deadline() else None ) result = { 'ip': src.ip, @@ -138,9 +138,9 @@ def process_logout(server: 'models.Server', data: dict[str, typing.Any]) -> typi if userService.in_use: # If already logged out, do not add a second logout (windows does this i.e.) osmanagers.OSManager.logged_out(userService, data['username']) - osManager: typing.Optional[osmanagers.OSManager] = userService.get_osmanager_instance() - if not osManager or osManager.is_removable_on_logout(userService): - logger.debug('Removable on logout: %s', osManager) + osmanager: typing.Optional[osmanagers.OSManager] = userService.get_osmanager_instance() + if not osmanager or osmanager.is_removable_on_logout(userService): + logger.debug('Removable on logout: %s', osmanager) userService.remove() return rest_result(consts.OK) diff --git a/server/src/uds/core/managers/user_service.py b/server/src/uds/core/managers/user_service.py index 4fae1c833..b71f7b132 100644 --- a/server/src/uds/core/managers/user_service.py +++ b/server/src/uds/core/managers/user_service.py @@ -86,7 +86,7 @@ class UserServiceManager(metaclass=singleton.Singleton): states = [State.PREPARING, State.USABLE, State.REMOVING, State.REMOVABLE] return Q(state__in=states) - def _check_if_max_user_services_reached(self, service_pool: ServicePool) -> None: + def _check_user_services_limit_reached(self, service_pool: ServicePool) -> None: """ Checks if max_user_services for the service has been reached, and, if so, raises an exception that no more services of this kind can be reached @@ -110,10 +110,10 @@ class UserServiceManager(metaclass=singleton.Singleton): """ serviceInstance = service.get_instance() # Early return, so no database count is needed - if serviceInstance.max_user_services == consts.UNLIMITED: + if serviceInstance.userservices_limit == consts.UNLIMITED: return False - if self.get_existing_user_services(service) >= serviceInstance.max_user_services: + if self.get_existing_user_services(service) >= serviceInstance.userservices_limit: return True return False @@ -123,7 +123,7 @@ class UserServiceManager(metaclass=singleton.Singleton): Private method to instatiate a cache element at database with default states """ # Checks if max_user_services has been reached and if so, raises an exception - self._check_if_max_user_services_reached(publication.deployed_service) + self._check_user_services_limit_reached(publication.deployed_service) now = sql_datetime() return publication.userServices.create( cache_level=cacheLevel, @@ -141,7 +141,7 @@ class UserServiceManager(metaclass=singleton.Singleton): """ Private method to instatiate an assigned element at database with default state """ - self._check_if_max_user_services_reached(publication.deployed_service) + self._check_user_services_limit_reached(publication.deployed_service) now = sql_datetime() return publication.userServices.create( cache_level=0, @@ -161,7 +161,7 @@ class UserServiceManager(metaclass=singleton.Singleton): There is cases where deployed services do not have publications (do not need them), so we need this method to create an UserService with no publications, and create them from an ServicePool """ - self._check_if_max_user_services_reached(service_pool) + self._check_user_services_limit_reached(service_pool) now = sql_datetime() return service_pool.userServices.create( cache_level=0, @@ -296,7 +296,7 @@ class UserServiceManager(metaclass=singleton.Singleton): # Data will be serialized on makeUnique process UserServiceOpChecker.make_unique(cache, cacheInstance, state) - def cancel(self, user_service: UserService) -> UserService: + def cancel(self, user_service: UserService) -> None: """ Cancels an user service creation @return: the Uservice canceling @@ -325,9 +325,8 @@ class UserServiceManager(metaclass=singleton.Singleton): # opchecker will set state to "removable" UserServiceOpChecker.make_unique(user_service, user_service_instance, state) - return user_service - def remove(self, userservice: UserService) -> UserService: + def remove(self, userservice: UserService) -> None: """ Removes a uService element """ @@ -347,9 +346,7 @@ class UserServiceManager(metaclass=singleton.Singleton): # Data will be serialized on makeUnique process UserServiceOpChecker.make_unique(userservice, userServiceInstance, state) - return userservice - - def remove_or_cancel(self, user_service: UserService): + def remove_or_cancel(self, user_service: UserService) -> None: if user_service.is_usable() or State.from_str(user_service.state).is_removable(): return self.remove(user_service) @@ -629,9 +626,9 @@ class UserServiceManager(metaclass=singleton.Singleton): This method is used by UserService when a request for setInUse(False) is made This checks that the service can continue existing or not """ - osManager = user_service.deployed_service.osmanager + osmanager = user_service.deployed_service.osmanager # If os manager says "machine is persistent", do not try to delete "previous version" assigned machines - doPublicationCleanup = True if not osManager else not osManager.get_instance().is_persistent() + doPublicationCleanup = True if not osmanager else not osmanager.get_instance().is_persistent() if doPublicationCleanup: remove = False diff --git a/server/src/uds/core/managers/userservice/opchecker.py b/server/src/uds/core/managers/userservice/opchecker.py index 0b664a0c2..087cbf068 100644 --- a/server/src/uds/core/managers/userservice/opchecker.py +++ b/server/src/uds/core/managers/userservice/opchecker.py @@ -117,17 +117,17 @@ class StateUpdater: class UpdateFromPreparing(StateUpdater): def check_os_manager_related(self) -> str: - osManager = self.user_service_instance.osmanager() + osmanager = self.user_service_instance.osmanager() state = State.USABLE # and make this usable if os manager says that it is usable, else it pass to configuring state # This is an "early check" for os manager, so if we do not have os manager, or os manager # already notifies "ready" for this, we - if osManager is not None and State.from_str(self.user_service.os_state).is_preparing(): + if osmanager is not None and State.from_str(self.user_service.os_state).is_preparing(): logger.debug('Has valid osmanager for %s', self.user_service.friendly_name) - stateOs = osManager.check_state(self.user_service) + stateOs = osmanager.check_state(self.user_service) else: stateOs = State.FINISHED @@ -175,18 +175,18 @@ class UpdateFromPreparing(StateUpdater): class UpdateFromRemoving(StateUpdater): def finish(self): - osManager = self.user_service_instance.osmanager() - if osManager is not None: - osManager.release(self.user_service) + osmanager = self.user_service_instance.osmanager() + if osmanager is not None: + osmanager.release(self.user_service) self.save(State.REMOVED) class UpdateFromCanceling(StateUpdater): def finish(self): - osManager = self.user_service_instance.osmanager() - if osManager is not None: - osManager.release(self.user_service) + osmanager = self.user_service_instance.osmanager() + if osmanager is not None: + osmanager.release(self.user_service) self.save(State.CANCELED) @@ -203,33 +203,35 @@ class UserServiceOpChecker(DelayedTask): """ This is the delayed task responsible of executing the service tasks and the service state transitions """ + _svrId: int + _state: str - def __init__(self, service): + def __init__(self, service: 'UserService'): super().__init__() self._svrId = service.id self._state = service.state @staticmethod - def make_unique(userService: UserService, userServiceInstance: services.UserService, state: str): + def make_unique(userservice: UserService, userservice_instance: services.UserService, state: str): """ This method ensures that there will be only one delayedtask related to the userService indicated """ - DelayedTaskRunner.runner().remove(USERSERVICE_TAG + userService.uuid) - UserServiceOpChecker.state_updater(userService, userServiceInstance, state) + DelayedTaskRunner.runner().remove(USERSERVICE_TAG + userservice.uuid) + UserServiceOpChecker.state_updater(userservice, userservice_instance, state) @staticmethod - def state_updater(userService: UserService, userServiceInstance: services.UserService, state: str): + def state_updater(userservice: UserService, userservice_instance: services.UserService, state: str): """ Checks the value returned from invocation to publish or checkPublishingState, updating the servicePoolPub database object Return True if it has to continue checking, False if finished """ try: # Fills up basic data - userService.unique_id = userServiceInstance.get_unique_id() # Updates uniqueId - userService.friendly_name = ( - userServiceInstance.get_name() + userservice.unique_id = userservice_instance.get_unique_id() # Updates uniqueId + userservice.friendly_name = ( + userservice_instance.get_name() ) # And name, both methods can modify serviceInstance, so we save it later - userService.save(update_fields=['unique_id', 'friendly_name']) + userservice.save(update_fields=['unique_id', 'friendly_name']) updater = typing.cast( type[StateUpdater], @@ -237,39 +239,39 @@ class UserServiceOpChecker(DelayedTask): State.PREPARING: UpdateFromPreparing, State.REMOVING: UpdateFromRemoving, State.CANCELING: UpdateFromCanceling, - }.get(State.from_str(userService.state), UpdateFromOther), + }.get(State.from_str(userservice.state), UpdateFromOther), ) logger.debug( 'Updating %s from %s with updater %s and state %s', - userService.friendly_name, - State.from_str(userService.state).literal, + userservice.friendly_name, + State.from_str(userservice.state).literal, updater, state, ) - updater(userService, userServiceInstance).run(state) + updater(userservice, userservice_instance).run(state) except Exception as e: logger.exception('Checking service state') - log.log(userService, log.LogLevel.ERROR, f'Exception: {e}', log.LogSource.INTERNAL) - userService.set_state(State.ERROR) - userService.save(update_fields=['data']) + log.log(userservice, log.LogLevel.ERROR, f'Exception: {e}', log.LogSource.INTERNAL) + userservice.set_state(State.ERROR) + userservice.save(update_fields=['data']) @staticmethod - def check_later(userService, ci): + def check_later(userservice: 'UserService', instance: 'services.UserService'): """ - Inserts a task in the delayedTaskRunner so we can check the state of this publication + Inserts a task in the delayedTaskRunner so we can check the state of this service later @param dps: Database object for ServicePoolPublication @param pi: Instance of Publication manager for the object """ # Do not add task if already exists one that updates this service - if DelayedTaskRunner.runner().tag_exists(USERSERVICE_TAG + userService.uuid): + if DelayedTaskRunner.runner().tag_exists(USERSERVICE_TAG + userservice.uuid): return DelayedTaskRunner.runner().insert( - UserServiceOpChecker(userService), - ci.suggested_delay, - USERSERVICE_TAG + userService.uuid, + UserServiceOpChecker(userservice), + instance.suggested_delay, + USERSERVICE_TAG + userservice.uuid, ) def run(self) -> None: diff --git a/server/src/uds/core/serializable.py b/server/src/uds/core/serializable.py index f7fbf734e..c57c91674 100644 --- a/server/src/uds/core/serializable.py +++ b/server/src/uds/core/serializable.py @@ -38,7 +38,7 @@ class Serializable: """ This class represents the interface that all serializable objects must provide. - Every single serializable class must implement marshall & unmarshall methods. Also, the class must allow + Every single serializable class must implement marshal & unmarshal methods. Also, the class must allow to be initialized without parameters, so we can: - Initialize the object with default values - Read values from seralized data @@ -109,7 +109,7 @@ class Serializable: self.unmarshal(base64.b64decode(data)) # For remarshalling purposes - # These allows us to faster migration of old data formats to new ones + # These facilitates a faster migration of old data formats to new ones # alowing us to remove old format support as soon as possible def flag_for_upgrade(self, value: bool = True) -> None: """ diff --git a/server/src/uds/core/services/publication.py b/server/src/uds/core/services/publication.py index e82c18cf9..9a889d37e 100644 --- a/server/src/uds/core/services/publication.py +++ b/server/src/uds/core/services/publication.py @@ -100,10 +100,10 @@ class Publication(Environmentable, Serializable): """ Environmentable.__init__(self, environment) Serializable.__init__(self) - self._osManager = kwargs.get('osManager', None) + self._osManager = kwargs.get('osmanager', None) self._service = kwargs['service'] # Raises an exception if service is not included self._revision = kwargs.get('revision', -1) - self._servicepool_name = kwargs.get('dsName', 'Unknown') + self._servicepool_name = kwargs.get('servicepool_name', 'Unknown') self._uuid = kwargs.get('uuid', '') self.initialize() @@ -114,7 +114,7 @@ class Publication(Environmentable, Serializable): This is provided so you don't have to provide your own __init__ method, and invoke base class __init__. This will get invoked when all initialization stuff is done, so - you can here access service, osManager, ... + you can here access service, osmanager, ... """ def db_obj(self) -> 'models.ServicePoolPublication': @@ -137,7 +137,7 @@ class Publication(Environmentable, Serializable): """ return self._service - def os_manager(self) -> typing.Optional['osmanagers.OSManager']: + def osmanager(self) -> typing.Optional['osmanagers.OSManager']: """ Utility method to access os manager for this publication. diff --git a/server/src/uds/core/services/service.py b/server/src/uds/core/services/service.py index 6b550a01c..acd5423b4 100644 --- a/server/src/uds/core/services/service.py +++ b/server/src/uds/core/services/service.py @@ -124,11 +124,11 @@ class Service(Module): # : for providing user services. This attribute can be set here or # : modified at instance level, core will access always to it using an instance object. # : Note: you can override this value on service instantiation by providing a "maxService": - # : - If maxServices is an integer, it will be used as max_user_services - # : - If maxServices is a gui.NumericField, it will be used as max_user_services (.num() will be called) - # : - If maxServices is a callable, it will be called and the result will be used as max_user_services - # : - If maxServices is None, max_user_services will be set to consts.UNLIMITED (as default) - max_user_services: int = consts.UNLIMITED + # : - If userservices_limit is an integer, it will be used as max_user_services + # : - If userservices_limit is a gui.NumericField, it will be used as max_user_services (.num() will be called) + # : - If userservices_limit is a callable, it will be called and the result will be used as max_user_services + # : - If userservices_limit is None, max_user_services will be set to consts.UNLIMITED (as default) + userservices_limit: int = consts.UNLIMITED # : If this item "has overrided fields", on deployed service edition, defined keys will overwrite defined ones # : That is, this Dicionary will OVERWRITE fields ON ServicePool (normally cache related ones) dictionary from a REST api save invocation!! @@ -271,34 +271,34 @@ class Service(Module): return True def unmarshal(self, data: bytes) -> None: - # In fact, we will not unmarshall anything here, but setup maxDeployed - # if maxServices exists and it is a gui.NumericField - # Invoke base unmarshall, so "gui fields" gets loaded from data + # In fact, we will not unmarshal anything here, but setup maxDeployed + # if services_limit exists and it is a gui.NumericField + # Invoke base unmarshal, so "gui fields" gets loaded from data super().unmarshal(data) - if hasattr(self, 'maxServices'): + if hasattr(self, 'services_limit'): # Fix self "max_user_services" value after loading fields try: - maxServices = getattr(self, 'maxServices', None) - if isinstance(maxServices, int): - self.max_user_services = maxServices - elif isinstance(maxServices, gui.NumericField): - self.max_user_services = maxServices.as_int() + services_limit = getattr(self, 'services_limit', None) + if isinstance(services_limit, int): + self.userservices_limit = services_limit + elif isinstance(services_limit, gui.NumericField): + self.userservices_limit = services_limit.as_int() # For 0 values on max_user_services field, we will set it to UNLIMITED - if self.max_user_services == 0: - self.max_user_services = consts.UNLIMITED - elif callable(maxServices): - self.max_user_services = maxServices() + if self.userservices_limit == 0: + self.userservices_limit = consts.UNLIMITED + elif callable(services_limit): + self.userservices_limit = services_limit() else: - self.max_user_services = consts.UNLIMITED + self.userservices_limit = consts.UNLIMITED except Exception: - self.max_user_services = consts.UNLIMITED + self.userservices_limit = consts.UNLIMITED # Ensure that max_user_services is not negative - if self.max_user_services < 0: - self.max_user_services = consts.UNLIMITED + if self.userservices_limit < 0: + self.userservices_limit = consts.UNLIMITED - # Keep untouched if maxServices is not present + # Keep untouched if services_limit is not present def user_services_for_assignation(self, **kwargs) -> collections.abc.Iterable['UserService']: """ diff --git a/server/src/uds/core/services/user_service.py b/server/src/uds/core/services/user_service.py index 8f0fb0ab1..90f1b7549 100644 --- a/server/src/uds/core/services/user_service.py +++ b/server/src/uds/core/services/user_service.py @@ -164,7 +164,7 @@ class UserService(Environmentable, Serializable): This is provided so you don't have to provide your own __init__ method, and invoke base class __init__. This will get invoked when all initialization stuff is done, so - you can here access publication, service, osManager, ... + you can here access publication, service, osmanager, ... """ def db_obj(self) -> 'models.UserService': diff --git a/server/src/uds/core/ui/user_interface.py b/server/src/uds/core/ui/user_interface.py index 81d6bc337..e5b5fcb61 100644 --- a/server/src/uds/core/ui/user_interface.py +++ b/server/src/uds/core/ui/user_interface.py @@ -1488,6 +1488,7 @@ class UserInterface(metaclass=UserInterfaceAbstract): # Any unexpected type will raise an exception # Note that currently, we will store old field name on db # to allow "backwards" migration if needed, but will be removed on a future version + # this allows even several name changes, as long as we keep "old_field_name" as original one... fields = [ (field.old_field_name() or field_name, field.type.name, FIELDS_ENCODERS[field.type](field)) for field_name, field in self._gui.items() diff --git a/server/src/uds/core/util/auto_attributes.py b/server/src/uds/core/util/auto_attributes.py index ff176b282..1c21a2b0d 100644 --- a/server/src/uds/core/util/auto_attributes.py +++ b/server/src/uds/core/util/auto_attributes.py @@ -77,8 +77,8 @@ class AutoAttributes(Serializable): attrs: collections.abc.MutableMapping[str, Attribute] def __init__(self, **kwargs): + self.attrs = {} # Ensure attrs is created BEFORE calling super, that can contain _ variables Serializable.__init__(self) - self.attrs = {} self.declare(**kwargs) def __getattribute__(self, name) -> typing.Any: @@ -97,7 +97,7 @@ class AutoAttributes(Serializable): for key, typ in kwargs.items(): d[key] = Attribute(typ) self.attrs = d - + def marshal(self) -> bytes: return b'v1' + pickle.dumps(self.attrs) diff --git a/server/src/uds/core/util/auto_serializable.py b/server/src/uds/core/util/auto_serializable.py index 4652447ac..910ce8b38 100644 --- a/server/src/uds/core/util/auto_serializable.py +++ b/server/src/uds/core/util/auto_serializable.py @@ -60,6 +60,8 @@ from cryptography import fernet from django.conf import settings from requests import get +from uds.core.serializable import Serializable + # pylint: disable=too-few-public-methods class _Unassigned: @@ -77,9 +79,9 @@ logger = logging.getLogger(__name__) # Constants # Headers for the serialized data -HEADER_BASE: typing.Final[bytes] = b'MGB1' -HEADER_COMPRESSED: typing.Final[bytes] = b'MGZ1' -HEADER_ENCRYPTED: typing.Final[bytes] = b'MGE1' +HEADER_BASE: typing.Final[bytes] = b'MGBAS1' +HEADER_COMPRESSED: typing.Final[bytes] = b'MGZAS1' +HEADER_ENCRYPTED: typing.Final[bytes] = b'MGEAS1' # Size of crc32 checksum CRC_SIZE: typing.Final[int] = 4 @@ -100,6 +102,19 @@ def fernet_key(crypt_key: bytes) -> str: # Generate an URL-Safe base64 encoded 32 bytes key for Fernet return base64.b64encode(hashlib.sha256(crypt_key).digest()).decode() +# checker for autoserializable data +def is_autoserializable_data(data: bytes) -> bool: + """Check if data is is from an autoserializable class + + Args: + data: Data to check + + Returns: + True if data is autoserializable, False otherwise + """ + return data[: len(HEADER_BASE)] == HEADER_BASE + + # pylint: disable=unnecessary-dunder-call class _SerializableField(typing.Generic[T]): @@ -212,16 +227,16 @@ class BoolField(_SerializableField[bool]): self.__set__(instance, data == b'1') -class ListField(_SerializableField[list]): +class ListField(_SerializableField[list[T]]): """List field Note: - All elements in the list must be serializable. + All elements in the list must be serializable in JSON, but can be of different types. """ def __init__( self, - default: typing.Union[list, collections.abc.Callable[[], list]] = lambda: [], + default: typing.Union[list[T], collections.abc.Callable[[], list[T]]] = lambda: [], ): super().__init__(list, default) @@ -323,7 +338,7 @@ class _FieldNameSetter(type): return super().__new__(mcs, name, bases, attrs) -class AutoSerializable(metaclass=_FieldNameSetter): +class AutoSerializable(Serializable, metaclass=_FieldNameSetter): """This class allows the automatic serialization of fields in a class. Example: @@ -366,7 +381,6 @@ class AutoSerializable(metaclass=_FieldNameSetter): """ return bytes(a ^ b for a, b in zip(data, itertools.cycle(header))) - @typing.final def marshal(self) -> bytes: # Iterate over own members and extract fields fields = {} @@ -396,8 +410,8 @@ class AutoSerializable(metaclass=_FieldNameSetter): # Return data processed with header return header + self.process_data(header, data) - # final method, do not override - @typing.final + # Only override this for checking if data is valid + # and, alternatively, retrieve it from a different source def unmarshal(self, data: bytes) -> None: # Check header if data[: len(HEADER_BASE)] != HEADER_BASE: diff --git a/server/src/uds/core/util/fields.py b/server/src/uds/core/util/fields.py index e89cca5a3..64968cd70 100644 --- a/server/src/uds/core/util/fields.py +++ b/server/src/uds/core/util/fields.py @@ -191,7 +191,9 @@ def get_certificates_from_field( # Timeout def timeout_field( default: int = 3, - order: int = 90, tab: 'types.ui.Tab|str|None' = None, old_field_name: typing.Optional[str] = None + order: int = 90, + tab: 'types.ui.Tab|str|None|bool' = None, + old_field_name: typing.Optional[str] = None, ) -> ui.gui.NumericField: return ui.gui.NumericField( length=3, @@ -201,50 +203,49 @@ def timeout_field( tooltip=_('Timeout in seconds for network connections'), required=True, min_value=1, - tab=tab or types.ui.Tab.ADVANCED, + tab=None if tab is False else None if tab is None else types.ui.Tab.ADVANCED, old_field_name=old_field_name, ) -# Ssl verification + +# Ssl verification def verify_ssl_field( default: bool = True, - order: int = 92, tab: 'types.ui.Tab|str|None' = None, old_field_name: typing.Optional[str] = None -) -> ui.gui.CheckBoxField: - return ui.gui.CheckBoxField( + order: int = 92, + tab: 'types.ui.Tab|str|None|bool' = None, + old_field_name: typing.Optional[str] = None, +) -> ui.gui.CheckBoxField: + return ui.gui.CheckBoxField( label=_('Verify SSL'), default=default, order=order, - tooltip=_( - 'If checked, SSL verification will be enforced. If not, SSL verification will be disabled' - ), - tab=tab or types.ui.Tab.ADVANCED, + tooltip=_('If checked, SSL verification will be enforced. If not, SSL verification will be disabled'), + tab=None if tab is False else None if tab is None else types.ui.Tab.ADVANCED, old_field_name=old_field_name, ) - + # Basename field -def basename_field(order: int = 32, tab: 'types.ui.Tab|str|None' = None) -> ui.gui.TextField: +def basename_field(order: int = 32, tab: 'types.ui.Tab|str|None|bool' = None) -> ui.gui.TextField: return ui.gui.TextField( label=_('Base Name'), order=order, tooltip=_('Base name for clones from this service'), - tab=tab, + tab=None if tab is False else None if tab is None else types.ui.Tab.ADVANCED, required=True, old_field_name='baseName', ) # Length of name field -def lenname_field( - order: int = 33, tab: 'types.ui.Tab|str|None' = None -) -> ui.gui.NumericField: +def lenname_field(order: int = 33, tab: 'types.ui.Tab|str|None|bool' = None) -> ui.gui.NumericField: return ui.gui.NumericField( length=1, label=_('Name Length'), default=3, order=order, tooltip=_('Size of numeric part for the names derived from this service'), - tab=tab, + tab=None if tab is False else None if tab is None else types.ui.Tab.ADVANCED, required=True, old_field_name='lenName', ) @@ -252,7 +253,7 @@ def lenname_field( # Max preparing services field def concurrent_creation_limit_field( - order: int = 50, tab: typing.Optional[types.ui.Tab] = None + order: int = 50, tab: 'types.ui.Tab|str|None|bool' = None ) -> ui.gui.NumericField: # Advanced tab return ui.gui.NumericField( @@ -264,13 +265,13 @@ def concurrent_creation_limit_field( order=order, tooltip=_('Maximum number of concurrently creating VMs'), required=True, - tab=tab or types.ui.Tab.ADVANCED, + tab=None if tab is False else None if tab is None else types.ui.Tab.ADVANCED, old_field_name='maxPreparingServices', ) def concurrent_removal_limit_field( - order: int = 51, tab: 'types.ui.Tab|str|None' = None + order: int = 51, tab: 'types.ui.Tab|str|None|bool' = None ) -> ui.gui.NumericField: return ui.gui.NumericField( length=3, @@ -281,27 +282,25 @@ def concurrent_removal_limit_field( order=order, tooltip=_('Maximum number of concurrently removing VMs'), required=True, - tab=tab or types.ui.Tab.ADVANCED, + tab=None if tab is False else None if tab is None else types.ui.Tab.ADVANCED, old_field_name='maxRemovingServices', ) -def remove_duplicates_field( - order: int = 102, tab: 'types.ui.Tab|str|None' = None -) -> ui.gui.CheckBoxField: +def remove_duplicates_field(order: int = 102, tab: 'types.ui.Tab|str|None|bool' = None) -> ui.gui.CheckBoxField: return ui.gui.CheckBoxField( label=_('Remove found duplicates'), default=True, order=order, tooltip=_('If active, found duplicates vApps for this service will be removed'), - tab=tab or types.ui.Tab.ADVANCED, + tab=None if tab is False else None if tab is None else types.ui.Tab.ADVANCED, old_field_name='removeDuplicates', ) def soft_shutdown_field( order: int = 103, - tab: 'types.ui.Tab|str|None' = None, + tab: 'types.ui.Tab|str|None|bool' = None, old_field_name: typing.Optional[str] = None, ) -> ui.gui.CheckBoxField: return ui.gui.CheckBoxField( @@ -311,14 +310,14 @@ def soft_shutdown_field( tooltip=_( 'If active, UDS will try to shutdown (soft) the machine using Nutanix ACPI. Will delay 30 seconds the power off of hanged machines.' ), - tab=tab or types.ui.Tab.ADVANCED, + tab=None if tab is False else None if tab is None else types.ui.Tab.ADVANCED, old_field_name=old_field_name, ) def keep_on_access_error_field( order: int = 104, - tab: 'types.ui.Tab|str|None' = None, + tab: 'types.ui.Tab|str|None|bool' = None, old_field_name: typing.Optional[str] = None, ) -> ui.gui.CheckBoxField: return ui.gui.CheckBoxField( @@ -326,7 +325,7 @@ def keep_on_access_error_field( value=False, order=order, tooltip=_('If active, access errors found on machine will not be considered errors.'), - tab=tab or types.ui.Tab.ADVANCED, + tab=None if tab is False else None if tab is None else types.ui.Tab.ADVANCED, old_field_name=old_field_name, ) @@ -334,7 +333,7 @@ def keep_on_access_error_field( def macs_range_field( default: str, order: int = 91, - tab: 'types.ui.Tab|str|None' = None, + tab: 'types.ui.Tab|str|None|bool' = None, readonly: bool = False, ) -> ui.gui.TextField: return ui.gui.TextField( @@ -347,11 +346,12 @@ def macs_range_field( default=default ), required=True, - tab=tab or types.ui.Tab.ADVANCED, + tab=None if tab is False else None if tab is None else types.ui.Tab.ADVANCED, old_field_name='macsRange', ) -def mfa_attr_field(order: int = 20, tab: 'types.ui.Tab|str|None' = None) -> ui.gui.TextField: + +def mfa_attr_field(order: int = 20, tab: 'types.ui.Tab|str|None|bool' = None) -> ui.gui.TextField: return ui.gui.TextField( length=2048, lines=2, @@ -359,6 +359,6 @@ def mfa_attr_field(order: int = 20, tab: 'types.ui.Tab|str|None' = None) -> ui.g order=order, tooltip=_('Attribute from where to extract the MFA code'), required=False, - tab=tab or types.ui.Tab.MFA, + tab=None if tab is False else None if tab is None else types.ui.Tab.MFA, old_field_name='mfaAttr', ) diff --git a/server/src/uds/core/workers/servicepools_cache_updater.py b/server/src/uds/core/workers/servicepools_cache_updater.py index 65afaa328..29f2de98b 100644 --- a/server/src/uds/core/workers/servicepools_cache_updater.py +++ b/server/src/uds/core/workers/servicepools_cache_updater.py @@ -29,7 +29,9 @@ """ @author: Adolfo Gómez, dkmaster at dkmon dot com """ +import dataclasses import logging +from multiprocessing import pool import typing import collections.abc @@ -46,6 +48,60 @@ from uds.core.jobs import Job logger = logging.getLogger(__name__) +# The functionallyty of counters are: +# * while we have less items than initial, we need to create cached l1 items +# * Once initial is reached, we have to keep cached l1 items at cache_l1_srvs +# * We stop creating cached l1 items when max is reached +# * If we have more than initial, we can remove cached l1 items until we reach cache_l1_srvs +# * If we have more than max, we can remove cached l1 items until we have no more than max +# * l2 is independent of any other counter, and will be created until cache_l2_srvs is reached +# * l2 will be removed until cache_l2_srvs is reached + + +@dataclasses.dataclass(slots=True) +class ServicePoolStats: + servicepool: ServicePool + l1_cache_count: int + l2_cache_count: int + assigned_count: int + + def l1_cache_overflow(self) -> bool: + """Checks if L1 cache is overflown + + Overflows if: + * l1_assigned_count > max_srvs + (this is, if we have more than max, we can remove cached l1 items until we reach max) + * l1_assigned_count > initial_srvs and l1_cache_count > cache_l1_srvs + (this is, if we have more than initial, we can remove cached l1 items until we reach cache_l1_srvs) + """ + l1_assigned_count = self.l1_cache_count + self.assigned_count + return l1_assigned_count > self.servicepool.max_srvs or ( + l1_assigned_count > self.servicepool.initial_srvs + and self.l1_cache_count > self.servicepool.cache_l1_srvs + ) + + def l1_cache_needed(self) -> bool: + """Checks if L1 cache is needed + + Grow L1 cache if: + * l1_assigned_count < max_srvs and (l1_assigned_count < initial_srvs or l1_cache_count < cache_l1_srvs) + (this is, if we have not reached max, and we have not reached initial or cache_l1_srvs, we need to grow L1 cache) + + """ + l1_assigned_count = self.l1_cache_count + self.assigned_count + return l1_assigned_count < self.servicepool.max_srvs and ( + l1_assigned_count < self.servicepool.initial_srvs + or self.l1_cache_count < self.servicepool.cache_l1_srvs + ) + + def l2_cache_overflow(self) -> bool: + """Checks if L2 cache is overflown""" + return self.l2_cache_count > self.servicepool.cache_l2_srvs + + def l2_cache_needed(self) -> bool: + """Checks if L2 cache is needed""" + return self.l2_cache_count < self.servicepool.cache_l2_srvs + class ServiceCacheUpdater(Job): """ @@ -73,11 +129,11 @@ class ServiceCacheUpdater(Job): def service_pools_needing_cache_update( self, - ) -> list[tuple[ServicePool, int, int, int]]: + ) -> list[ServicePoolStats]: # State filter for cached and inAssigned objects # First we get all deployed services that could need cache generation # We start filtering out the deployed services that do not need caching at all. - servicePoolsNeedingCaching: collections.abc.Iterable[ServicePool] = ( + candidate_servicepools: collections.abc.Iterable[ServicePool] = ( ServicePool.objects.filter(Q(initial_srvs__gte=0) | Q(cache_l1_srvs__gte=0)) .filter( max_srvs__gt=0, @@ -88,126 +144,121 @@ class ServiceCacheUpdater(Job): ) # We will get the one that proportionally needs more cache - servicesPools: list[tuple[ServicePool, int, int, int]] = [] - for servicePool in servicePoolsNeedingCaching: - servicePool.userServices.update() # Cleans cached queries + servicepools_numbers: list[ServicePoolStats] = [] + for servicepool in candidate_servicepools: + servicepool.user_services.update() # Cleans cached queries # If this deployedService don't have a publication active and needs it, ignore it - spServiceInstance = servicePool.service.get_instance() # type: ignore + service_instance = servicepool.service.get_instance() # type: ignore - if spServiceInstance.uses_cache is False: + if service_instance.uses_cache is False: logger.debug( 'Skipping cache generation for service pool that does not uses cache: %s', - servicePool.name, + servicepool.name, ) continue - if servicePool.active_publication() is None and spServiceInstance.publication_type is not None: + if servicepool.active_publication() is None and service_instance.publication_type is not None: logger.debug( 'Skipping. %s Needs publication but do not have one', - servicePool.name, + servicepool.name, ) continue # If it has any running publication, do not generate cache anymore - if servicePool.publications.filter(state=State.PREPARING).count() > 0: + if servicepool.publications.filter(state=State.PREPARING).count() > 0: logger.debug( 'Skipping cache generation for service pool with publication running: %s', - servicePool.name, + servicepool.name, ) continue - if servicePool.is_restrained(): + if servicepool.is_restrained(): logger.debug( 'StopSkippingped cache generation for restrained service pool: %s', - servicePool.name, + servicepool.name, ) - ServiceCacheUpdater._notify_restrain(servicePool) + ServiceCacheUpdater._notify_restrain(servicepool) continue # Get data related to actual state of cache # Before we were removing the elements marked to be destroyed after creation, but this makes us # to create new items over the limit stablisshed, so we will not remove them anymore - inCacheL1: int = ( - servicePool.cached_users_services() - .filter(UserServiceManager().get_cache_state_filter(servicePool, services.UserService.L1_CACHE)) + l1_cache_count: int = ( + servicepool.cached_users_services() + .filter(UserServiceManager().get_cache_state_filter(servicepool, services.UserService.L1_CACHE)) .count() ) - inCacheL2: int = ( + l2_cache_count: int = ( ( - servicePool.cached_users_services() + servicepool.cached_users_services() .filter( - UserServiceManager().get_cache_state_filter(servicePool, services.UserService.L2_CACHE) + UserServiceManager().get_cache_state_filter(servicepool, services.UserService.L2_CACHE) ) .count() ) - if spServiceInstance.uses_cache_l2 + if service_instance.uses_cache_l2 else 0 ) - inAssigned: int = ( - servicePool.assigned_user_services() - .filter(UserServiceManager().get_state_filter(servicePool.service)) # type: ignore + assigned_count: int = ( + servicepool.assigned_user_services() + .filter(UserServiceManager().get_state_filter(servicepool.service)) # type: ignore .count() ) + pool_stat = ServicePoolStats(servicepool, l1_cache_count, l2_cache_count, assigned_count) # if we bypasses max cache, we will reduce it in first place. This is so because this will free resources on service provider logger.debug( "Examining %s with %s in cache L1 and %s in cache L2, %s inAssigned", - servicePool.name, - inCacheL1, - inCacheL2, - inAssigned, + servicepool.name, + l1_cache_count, + l2_cache_count, + assigned_count, ) - totalL1Assigned = inCacheL1 + inAssigned + l1_assigned_count = l1_cache_count + assigned_count # We have more than we want - if totalL1Assigned > servicePool.max_srvs: - logger.debug('We have more services than max configured. skipping.') - servicesPools.append((servicePool, inCacheL1, inCacheL2, inAssigned)) - continue - # We have more in L1 cache than needed - if totalL1Assigned > servicePool.initial_srvs and inCacheL1 > servicePool.cache_l1_srvs: - logger.debug('We have more services in cache L1 than configured, appending') - servicesPools.append((servicePool, inCacheL1, inCacheL2, inAssigned)) + if pool_stat.l1_cache_overflow(): + logger.debug('We have more services than max configured. Reducing..') + servicepools_numbers.append( + ServicePoolStats(servicepool, l1_cache_count, l2_cache_count, assigned_count) + ) continue # If we have more in L2 cache than needed, decrease L2 cache, but int this case, we continue checking cause L2 cache removal # has less priority than l1 creations or removals, but higher. In this case, we will simply take last l2 oversized found and reduce it - if spServiceInstance.uses_cache_l2 and inCacheL2 > servicePool.cache_l2_srvs: - logger.debug('We have more services in L2 cache than configured, appending') - servicesPools.append((servicePool, inCacheL1, inCacheL2, inAssigned)) + if pool_stat.l2_cache_overflow(): + logger.debug('We have more services in L2 cache than configured, reducing') + servicepools_numbers.append( + ServicePoolStats(servicepool, l1_cache_count, l2_cache_count, assigned_count) + ) continue - + # If this service don't allows more starting user services, continue - if not UserServiceManager().can_grow_service_pool(servicePool): + if not UserServiceManager().can_grow_service_pool(servicepool): logger.debug( 'This pool cannot grow rithg now: %s', - servicePool, + servicepool, + ) + continue + + if pool_stat.l1_cache_needed(): + logger.debug('Needs to grow L1 cache for %s', servicepool) + servicepools_numbers.append( + ServicePoolStats(servicepool, l1_cache_count, l2_cache_count, assigned_count) + ) + continue + + if pool_stat.l2_cache_needed(): + logger.debug('Needs to grow L2 cache for %s', servicepool) + servicepools_numbers.append( + ServicePoolStats(servicepool, l1_cache_count, l2_cache_count, assigned_count) ) continue - # If wee need to grow l2 cache, annotate it - # Whe check this before checking the total, because the l2 cache is independent of max services or l1 cache. - # It reflects a value that must be keeped in cache for futre fast use. - if inCacheL2 < servicePool.cache_l2_srvs: - logger.debug('Needs to grow L2 cache for %s', servicePool) - servicesPools.append((servicePool, inCacheL1, inCacheL2, inAssigned)) - continue - - # We skip it if already at max - if totalL1Assigned == servicePool.max_srvs: - continue - - if totalL1Assigned < servicePool.initial_srvs or inCacheL1 < servicePool.cache_l1_srvs: - logger.debug('Needs to grow L1 cache for %s', servicePool) - servicesPools.append((servicePool, inCacheL1, inCacheL2, inAssigned)) - # We also return calculated values so we can reuse then - return servicesPools + return servicepools_numbers def grow_l1_cache( self, - servicePool: ServicePool, - cacheL1: int, # pylint: disable=unused-argument - cacheL2: int, - assigned: int, # pylint: disable=unused-argument + servicepool_stats: ServicePoolStats, ) -> None: """ This method tries to enlarge L1 cache. @@ -216,16 +267,18 @@ class ServiceCacheUpdater(Job): and PREPARING, assigned, L1 and L2) is over max allowed service deployments, this method will not grow the L1 cache """ - logger.debug('Growing L1 cache creating a new service for %s', servicePool.name) + logger.debug('Growing L1 cache creating a new service for %s', servicepool_stats.servicepool.name) # First, we try to assign from L2 cache - if cacheL2 > 0: + if servicepool_stats.l2_cache_count > 0: valid = None with transaction.atomic(): for n in ( - servicePool.cached_users_services() + servicepool_stats.servicepool.cached_users_services() .select_for_update() .filter( - UserServiceManager().get_cache_state_filter(servicePool, services.UserService.L2_CACHE) + UserServiceManager().get_cache_state_filter( + servicepool_stats.servicepool, services.UserService.L2_CACHE + ) ) .order_by('creation_date') ): @@ -246,30 +299,27 @@ class ServiceCacheUpdater(Job): try: # This has a velid publication, or it will not be here UserServiceManager().create_cache_for( - typing.cast(ServicePoolPublication, servicePool.active_publication()), + typing.cast(ServicePoolPublication, servicepool_stats.servicepool.active_publication()), services.UserService.L1_CACHE, ) except MaxServicesReachedError: log.log( - servicePool, + servicepool_stats.servicepool, log.LogLevel.ERROR, 'Max number of services reached for this service', log.LogSource.INTERNAL, ) logger.warning( 'Max user services reached for %s: %s. Cache not created', - servicePool.name, - servicePool.max_srvs, + servicepool_stats.servicepool.name, + servicepool_stats.servicepool.max_srvs, ) except Exception: logger.exception('Exception') def grow_l2_cache( self, - servicePool: ServicePool, - cacheL1: int, # pylint: disable=unused-argument - cacheL2: int, # pylint: disable=unused-argument - assigned: int, # pylint: disable=unused-argument + servicepool_stats: ServicePoolStats, ) -> None: """ Tries to grow L2 cache of service. @@ -278,35 +328,36 @@ class ServiceCacheUpdater(Job): and PREPARING, assigned, L1 and L2) is over max allowed service deployments, this method will not grow the L1 cache """ - logger.debug("Growing L2 cache creating a new service for %s", servicePool.name) + logger.debug("Growing L2 cache creating a new service for %s", servicepool_stats.servicepool.name) try: # This has a velid publication, or it will not be here UserServiceManager().create_cache_for( - typing.cast(ServicePoolPublication, servicePool.active_publication()), + typing.cast(ServicePoolPublication, servicepool_stats.servicepool.active_publication()), services.UserService.L2_CACHE, ) except MaxServicesReachedError: logger.warning( 'Max user services reached for %s: %s. Cache not created', - servicePool.name, - servicePool.max_srvs, + servicepool_stats.servicepool.name, + servicepool_stats.servicepool.max_srvs, ) # TODO: When alerts are ready, notify this def reduce_l1_cache( self, - servicePool: ServicePool, - cacheL1: int, # pylint: disable=unused-argument - cacheL2: int, - assigned: int, # pylint: disable=unused-argument - ): - logger.debug("Reducing L1 cache erasing a service in cache for %s", servicePool) - # We will try to destroy the newest cacheL1 element that is USABLE if the deployer can't cancel a new service creation + servicepool_stats: ServicePoolStats, + ) -> None: + logger.debug("Reducing L1 cache erasing a service in cache for %s", servicepool_stats.servicepool) + # We will try to destroy the newest l1_cache_count element that is USABLE if the deployer can't cancel a new service creation # Here, we will take into account the "remove_after" marked user services, so we don't try to remove them cacheItems: list[UserService] = [ i - for i in servicePool.cached_users_services() - .filter(UserServiceManager().get_cache_state_filter(servicePool, services.UserService.L1_CACHE)) + for i in servicepool_stats.servicepool.cached_users_services() + .filter( + UserServiceManager().get_cache_state_filter( + servicepool_stats.servicepool, services.UserService.L1_CACHE + ) + ) .order_by('-creation_date') .iterator() if not i.destroy_after @@ -318,7 +369,7 @@ class ServiceCacheUpdater(Job): ) return - if cacheL2 < servicePool.cache_l2_srvs: + if servicepool_stats.l2_cache_count < servicepool_stats.servicepool.cache_l2_srvs: valid = None for n in cacheItems: if n.needsOsManager(): @@ -338,17 +389,16 @@ class ServiceCacheUpdater(Job): def reduce_l2_cache( self, - service_pool: ServicePool, - cacheL1: int, # pylint: disable=unused-argument - cacheL2: int, - assigned: int, # pylint: disable=unused-argument + servicepool_stats: ServicePoolStats, ): - logger.debug("Reducing L2 cache erasing a service in cache for %s", service_pool.name) - if cacheL2 > 0: + logger.debug("Reducing L2 cache erasing a service in cache for %s", servicepool_stats.servicepool.name) + if servicepool_stats.l2_cache_count > 0: cacheItems = ( - service_pool.cached_users_services() + servicepool_stats.servicepool.cached_users_services() .filter( - UserServiceManager().get_cache_state_filter(service_pool, services.UserService.L2_CACHE) + UserServiceManager().get_cache_state_filter( + servicepool_stats.servicepool, services.UserService.L2_CACHE + ) ) .order_by('creation_date') ) @@ -359,29 +409,18 @@ class ServiceCacheUpdater(Job): def run(self) -> None: logger.debug('Starting cache checking') # We need to get - servicesThatNeedsUpdate = self.service_pools_needing_cache_update() - for servicePool, cacheL1, cacheL2, assigned in servicesThatNeedsUpdate: + for servicepool_stat in self.service_pools_needing_cache_update(): # We have cache to update?? - logger.debug("Updating cache for %s", servicePool) - totalL1Assigned = cacheL1 + assigned + logger.debug("Updating cache for %s", servicepool_stat) - # We try first to reduce cache before tring to increase it. - # This means that if there is excesive number of user deployments - # for L1 or L2 cache, this will be reduced untill they have good numbers. - # This is so because service can have limited the number of services and, - # if we try to increase cache before having reduced whatever needed - # first, the service will get lock until someone removes something. - if totalL1Assigned > servicePool.max_srvs: - self.reduce_l1_cache(servicePool, cacheL1, cacheL2, assigned) - elif totalL1Assigned > servicePool.initial_srvs and cacheL1 > servicePool.cache_l1_srvs: - self.reduce_l1_cache(servicePool, cacheL1, cacheL2, assigned) - elif cacheL2 > servicePool.cache_l2_srvs: # We have excesives L2 items - self.reduce_l2_cache(servicePool, cacheL1, cacheL2, assigned) - elif totalL1Assigned < servicePool.max_srvs and ( - totalL1Assigned < servicePool.initial_srvs or cacheL1 < servicePool.cache_l1_srvs - ): # We need more services - self.grow_l1_cache(servicePool, cacheL1, cacheL2, assigned) - elif cacheL2 < servicePool.cache_l2_srvs: # We need more L2 items - self.grow_l2_cache(servicePool, cacheL1, cacheL2, assigned) - else: - logger.warning("We have more services than max requested for %s", servicePool.name) + # Treat l1 and l2 cache independently + # first, try to reduce cache and then grow it + if servicepool_stat.l1_cache_overflow(): + self.reduce_l1_cache(servicepool_stat) + elif servicepool_stat.l1_cache_needed(): # We need more L1 items + self.grow_l1_cache(servicepool_stat) + # Treat l1 and l2 cache independently + if servicepool_stat.l2_cache_overflow(): + self.reduce_l2_cache(servicepool_stat) + elif servicepool_stat.l2_cache_needed(): # We need more L2 items + self.grow_l2_cache(servicepool_stat) diff --git a/server/src/uds/management/commands/tree.py b/server/src/uds/management/commands/tree.py index 4db1acbbb..1a314c39c 100644 --- a/server/src/uds/management/commands/tree.py +++ b/server/src/uds/management/commands/tree.py @@ -278,8 +278,8 @@ class Command(BaseCommand): # os managers osManagers: dict[str, typing.Any] = {} - for osManager in models.OSManager.objects.all(): - osManagers[osManager.name] = get_serialized_from_managed_object(osManager) + for osmanager in models.OSManager.objects.all(): + osManagers[osmanager.name] = get_serialized_from_managed_object(osmanager) tree[counter('OSMANAGERS')] = osManagers diff --git a/server/src/uds/models/__init__.py b/server/src/uds/models/__init__.py index c73ba3176..afc7cc94a 100644 --- a/server/src/uds/models/__init__.py +++ b/server/src/uds/models/__init__.py @@ -42,7 +42,7 @@ from .provider import Provider from .service import Service, ServiceTokenAlias # Os managers -from .os_manager import OSManager +from .osmanager import OSManager # Transports from .transport import Transport diff --git a/server/src/uds/models/os_manager.py b/server/src/uds/models/osmanager.py similarity index 100% rename from server/src/uds/models/os_manager.py rename to server/src/uds/models/osmanager.py diff --git a/server/src/uds/models/service_pool.py b/server/src/uds/models/service_pool.py index 5a0db2279..8454f0fdc 100644 --- a/server/src/uds/models/service_pool.py +++ b/server/src/uds/models/service_pool.py @@ -47,7 +47,7 @@ from uds.core.util.model import sql_datetime from .account import Account from .group import Group from .image import Image -from .os_manager import OSManager +from .osmanager import OSManager from .service import Service from .service_pool_group import ServicePoolGroup from .tag import TaggingMixin @@ -642,7 +642,7 @@ class ServicePool(UUIDModel, TaggingMixin): # type: ignore """ maxs = self.max_srvs if maxs == 0 and self.service: - maxs = self.service.get_instance().max_user_services + maxs = self.service.get_instance().userservices_limit if cachedValue == -1: cachedValue = ( diff --git a/server/src/uds/models/service_pool_publication.py b/server/src/uds/models/service_pool_publication.py index dbb726812..a5555bc85 100644 --- a/server/src/uds/models/service_pool_publication.py +++ b/server/src/uds/models/service_pool_publication.py @@ -36,6 +36,7 @@ import typing import collections.abc from django.db import models +import public from uds.core.managers import publication_manager from uds.core.types.states import State @@ -132,32 +133,36 @@ class ServicePoolPublication(UUIDModel): """ if not self.deployed_service.service: raise Exception('No service assigned to publication') - serviceInstance = self.deployed_service.service.get_instance() - osManager = self.deployed_service.osmanager - osManagerInstance = osManager.get_instance() if osManager else None + service_instance = self.deployed_service.service.get_instance() + osmanager = self.deployed_service.osmanager + osmanager_instance = osmanager.get_instance() if osmanager else None # Sanity check, so it's easier to find when we have created # a service that needs publication but do not have - if serviceInstance.publication_type is None: + if service_instance.publication_type is None: raise Exception( - f'Class {serviceInstance.__class__.__name__} do not have defined publication_type but needs to be published!!!' + f'Class {service_instance.__class__.__name__} do not have defined publication_type but needs to be published!!!' ) - publication = serviceInstance.publication_type( + publication = service_instance.publication_type( self.get_environment(), - service=serviceInstance, - osManager=osManagerInstance, + service=service_instance, + osmanager=osmanager_instance, revision=self.revision, - dsName=self.deployed_service.name, + servicepool_name=self.deployed_service.name, uuid=self.uuid, ) # Only invokes deserialization if data has something. '' is nothing if self.data: publication.deserialize(self.data) + if publication.needs_upgrade(): + self.update_data(publication) + publication.flag_for_upgrade(False) + return publication - def update_data(self, publication): + def update_data(self, publication_instance: 'services.Publication') -> None: """ Updates the data field with the serialized uds.core.services.Publication @@ -166,7 +171,7 @@ class ServicePoolPublication(UUIDModel): :note: This method do not saves the updated record, just updates the field """ - self.data = publication.serialize() + self.data = publication_instance.serialize() self.save(update_fields=['data']) def set_state(self, state: str) -> None: diff --git a/server/src/uds/models/user_service.py b/server/src/uds/models/user_service.py index 3526314c9..255e3a1a7 100644 --- a/server/src/uds/models/user_service.py +++ b/server/src/uds/models/user_service.py @@ -146,7 +146,10 @@ class UserService(UUIDModel, properties.PropertiesMixin): """ Returns True if this service is to be removed """ - return self.properties.get('destroy_after', False) in ('y', True) # Compare to str to keep compatibility with old values + return self.properties.get('destroy_after', False) in ( + 'y', + True, + ) # Compare to str to keep compatibility with old values @destroy_after.setter def destroy_after(self, value: bool) -> None: @@ -201,19 +204,19 @@ class UserService(UUIDModel, properties.PropertiesMixin): Raises: """ # We get the service instance, publication instance and osmanager instance - servicePool = self.deployed_service - if not servicePool.service: + servicepool = self.deployed_service + if not servicepool.service: raise Exception('Service not found') - serviceInstance = servicePool.service.get_instance() - if serviceInstance.needs_manager is False or not servicePool.osmanager: - osmanagerInstance = None + service_instance = servicepool.service.get_instance() + if service_instance.needs_manager is False or not servicepool.osmanager: + osmanager_instance = None else: - osmanagerInstance = servicePool.osmanager.get_instance() + osmanager_instance = servicepool.osmanager.get_instance() # We get active publication - publicationInstance = None + publication_instance = None try: # We may have deleted publication... if self.publication is not None: - publicationInstance = self.publication.get_instance() + publication_instance = self.publication.get_instance() except Exception: # The publication to witch this item points to, does not exists self.publication = None # type: ignore @@ -221,20 +224,29 @@ class UserService(UUIDModel, properties.PropertiesMixin): 'Got exception at get_instance of an userService %s (seems that publication does not exists!)', self, ) - if serviceInstance.user_service_type is None: + if service_instance.user_service_type is None: raise Exception( - f'Class {serviceInstance.__class__.__name__} needs user_service_type but it is not defined!!!' + f'Class {service_instance.__class__.__name__} needs user_service_type but it is not defined!!!' ) - us = serviceInstance.user_service_type( + us = service_instance.user_service_type( self.get_environment(), - service=serviceInstance, - publication=publicationInstance, - osmanager=osmanagerInstance, + service=service_instance, + publication=publication_instance, + osmanager=osmanager_instance, uuid=self.uuid, ) - if self.data != '' and self.data is not None: + if self.data: try: us.deserialize(self.data) + + # if needs upgrade, we will serialize it again to ensure it is upgraded ASAP + # Eventually, it will be upgraded anyway, but could take too much time... + # This way, if we instantiate it, it will be upgraded + if us.needs_upgrade(): + self.data = us.serialize() + self.save(update_fields=['data']) + us.flag_for_upgrade(False) + except Exception: logger.exception( 'Error unserializing %s//%s : %s', @@ -244,7 +256,7 @@ class UserService(UUIDModel, properties.PropertiesMixin): ) return us - def update_data(self, userServiceInstance: 'services.UserService'): + def update_data(self, userservice_instance: 'services.UserService'): """ Updates the data field with the serialized :py:class:uds.core.services.UserDeployment @@ -253,7 +265,7 @@ class UserService(UUIDModel, properties.PropertiesMixin): :note: This method SAVES the updated record, just updates the field """ - self.data = userServiceInstance.serialize() + self.data = userservice_instance.serialize() self.save(update_fields=['data']) def get_name(self) -> str: @@ -347,9 +359,9 @@ class UserService(UUIDModel, properties.PropertiesMixin): return self.deployed_service.osmanager def get_osmanager_instance(self) -> typing.Optional['osmanagers.OSManager']: - osManager = self.getOsManager() - if osManager: - return osManager.get_instance() + osmanager = self.getOsManager() + if osmanager: + return osmanager.get_instance() return None def needsOsManager(self) -> bool: @@ -544,6 +556,7 @@ class UserService(UUIDModel, properties.PropertiesMixin): # pylint: disable=import-outside-toplevel from uds.core.managers.user_service import UserServiceManager + # Cancel is a "forced" operation, so they are not checked against limits UserServiceManager().cancel(self) def remove_or_cancel(self) -> None: @@ -597,7 +610,7 @@ class UserService(UUIDModel, properties.PropertiesMixin): @property def actor_version(self) -> str: return self.properties.get('actor_version') or '0.0.0' - + @actor_version.setter def actor_version(self, version: str) -> None: self.properties['actor_version'] = version diff --git a/server/src/uds/services/OpenGnsys/service.py b/server/src/uds/services/OpenGnsys/service.py index a75c84f34..909ddf662 100644 --- a/server/src/uds/services/OpenGnsys/service.py +++ b/server/src/uds/services/OpenGnsys/service.py @@ -144,7 +144,7 @@ class OGService(services.Service): ), ) - maxServices = gui.NumericField( + services_limit = gui.NumericField( order=4, label=_("Max. Allowed services"), min_value=0, @@ -153,7 +153,8 @@ class OGService(services.Service): readonly=False, tooltip=_('Maximum number of allowed services (0 or less means no limit)'), required=True, - tab=types.ui.Tab.ADVANCED + tab=types.ui.Tab.ADVANCED, + old_field_name='maxServices', ) ov = gui.HiddenField(value=None) diff --git a/server/src/uds/services/PhysicalMachines/service_multi.py b/server/src/uds/services/PhysicalMachines/service_multi.py index 5a977e6e2..8b1b9c664 100644 --- a/server/src/uds/services/PhysicalMachines/service_multi.py +++ b/server/src/uds/services/PhysicalMachines/service_multi.py @@ -239,7 +239,7 @@ class IPMachinesService(IPServiceBase): self._useRandomIp = gui.as_bool(values[6].decode()) # Sets maximum services for this - self.max_user_services = len(self._ips) + self.userservices_limit = len(self._ips) def canBeUsed(self, locked: typing.Optional[typing.Union[str, int]], now: int) -> int: # If _maxSessionForMachine is 0, it can be used only if not locked diff --git a/server/src/uds/services/Test/provider.py b/server/src/uds/services/Test/provider.py index 5ade09b76..f95a8b5b1 100644 --- a/server/src/uds/services/Test/provider.py +++ b/server/src/uds/services/Test/provider.py @@ -95,7 +95,7 @@ class TestProvider(services.ServiceProvider): self.data.name = ''.join(random.SystemRandom().choices(string.ascii_letters, k=10)) self.data.integer = random.randint(0, 100) return super().initialize(values) - + @staticmethod def test( env: 'Environment', data: dict[str, str] diff --git a/server/tests/auths/regex_ldap/test_serialization.py b/server/tests/auths/regex_ldap/test_serialization.py index 822dbc2c7..1d7ba9daf 100644 --- a/server/tests/auths/regex_ldap/test_serialization.py +++ b/server/tests/auths/regex_ldap/test_serialization.py @@ -1,13 +1,34 @@ -# -*- coding: utf-8 -*- - +# pylint: disable=no-member # ldap module gives errors to pylint # -# Copyright (c) 2022 Virtual Cable S.L.U. +# Copyright (c) 2024 Virtual Cable S.L.U. # All rights reserved. # +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of Virtual Cable S.L.U. nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -""" +''' Author: Adolfo Gómez, dkmaster at dkmon dot com -""" +''' import typing # We use commit/rollback @@ -55,7 +76,7 @@ SERIALIZED_AUTH_DATA: typing.Final[typing.Mapping[str, bytes]] = { class RegexSerializationTest(UDSTestCase): - def check_provider(self, version: str, instance: 'authenticator.RegexLdap'): + def check_provider(self, version: str, instance: 'authenticator.RegexLdap') -> None: self.assertEqual(instance.host.as_str(), 'host') self.assertEqual(instance.port.as_int(), 166) self.assertEqual(instance.use_ssl.as_bool(), True) @@ -69,32 +90,32 @@ class RegexSerializationTest(UDSTestCase): if version >= 'v2': self.assertEqual(instance.username_attr.as_str(), 'usernattr') - def test_unmarshall_all_versions(self): + def test_unmarshall_all_versions(self) -> None: for v in range(1, len(SERIALIZED_AUTH_DATA) + 1): instance = authenticator.RegexLdap(environment=Environment.get_temporary_environment()) instance.unmarshal(SERIALIZED_AUTH_DATA['v{}'.format(v)]) self.check_provider(f'v{v}', instance) - def test_marshaling(self): + def test_marshaling(self) -> None: # Unmarshall last version, remarshall and check that is marshalled using new marshalling format LAST_VERSION = 'v{}'.format(len(SERIALIZED_AUTH_DATA)) instance = authenticator.RegexLdap( environment=Environment.get_temporary_environment() ) instance.unmarshal(SERIALIZED_AUTH_DATA[LAST_VERSION]) - marshalled_data = instance.marshal() + marshaled_data = instance.marshal() # Ensure remarshalled flag is set self.assertTrue(instance.needs_upgrade()) instance.flag_for_upgrade(False) # reset flag # Ensure fields has been marshalled using new format - self.assertFalse(marshalled_data.startswith(b'v')) + self.assertFalse(marshaled_data.startswith(b'v')) # Reunmarshall again and check that remarshalled flag is not set instance = authenticator.RegexLdap( environment=Environment.get_temporary_environment() ) - instance.unmarshal(marshalled_data) + instance.unmarshal(marshaled_data) self.assertFalse(instance.needs_upgrade()) # Check that data is correct diff --git a/server/tests/auths/simple_ldap/__init__.py b/server/tests/auths/simple_ldap/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/server/tests/auths/simple_ldap/test_serialization.py b/server/tests/auths/simple_ldap/test_serialization.py new file mode 100644 index 000000000..491bbc282 --- /dev/null +++ b/server/tests/auths/simple_ldap/test_serialization.py @@ -0,0 +1,124 @@ +# pylint: disable=no-member # ldap module gives errors to pylint +# +# Copyright (c) 2024 Virtual Cable S.L.U. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of Virtual Cable S.L.U. nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +''' +Author: Adolfo Gómez, dkmaster at dkmon dot com +''' +import typing + +# We use commit/rollback + +from tests.utils.test import UDSTestCase +from uds.core import ui +from uds.core.ui.user_interface import gui, UDSB, UDSK +from uds.core.environment import Environment +from uds.core.managers import crypto + +from django.conf import settings + + +from uds.auths.SimpleLDAP import authenticator + +PASSWD: typing.Final[str] = 'PASSWD' + +# v1: +# self.host.value, +# self.port.value, +# self.use_ssl.value, +# self.username.value, +# self.password.value, +# self.timeout.value, +# self.ldap_base.value, +# self.user_class.value, +# self.userid_attr.value, +# self.groupname_attr.value, +# v2: +# self.username_attr.value = vals[11] +# v3: +# self.alternate_class.value = vals[12] +# v4: +# self.mfa_attribute.value = vals[13] +# v5: +# self.verify_ssl.value = vals[14] +# self.certificate.value = vals[15] +SERIALIZED_AUTH_DATA: typing.Final[typing.Mapping[str, bytes]] = { + 'v1': b'v1\thost\t166\t1\tuame\t' + PASSWD.encode('utf8') + b'\t99\tdc=dom,dc=m\tuclass\tuseridAttr\tgroup_attr\t\tusernattr', + 'v2': b'v2\thost\t166\t1\tuame\t' + PASSWD.encode('utf8') + b'\t99\tdc=dom,dc=m\tuclass\tuseridAttr\tgroup_attr\tusernattr', + 'v3': b'v3\thost\t166\t1\tuame\t' + PASSWD.encode('utf8') + b'\t99\tdc=dom,dc=m\tuclass\tuseridAttr\tgroup_attr\tusernattr\taltClass', + 'v4': b'v4\thost\t166\t1\tuame\t' + PASSWD.encode('utf8') + b'\t99\tdc=dom,dc=m\tuclass\tuseridAttr\tgroup_attr\tusernattr\taltClass\tmfa', + 'v5': b'v5\thost\t166\t1\tuame\t' + PASSWD.encode('utf8') + b'\t99\tdc=dom,dc=m\tuclass\tuseridAttr\tgroup_attr\tusernattr\taltClass\tmfa\tTRUE\tcert', +} + + +class RegexSerializationTest(UDSTestCase): + def check_provider(self, version: str, instance: 'authenticator.SimpleLDAPAuthenticator'): + self.assertEqual(instance.host.as_str(), 'host') + self.assertEqual(instance.port.as_int(), 166) + self.assertEqual(instance.use_ssl.as_bool(), True) + self.assertEqual(instance.username.as_str(), 'uame') + self.assertEqual(instance.password.as_str(), PASSWD) + self.assertEqual(instance.timeout.as_int(), 99) + self.assertEqual(instance.ldap_base.as_str(), 'dc=dom,dc=m') + self.assertEqual(instance.user_class.as_str(), 'uclass') + self.assertEqual(instance.userid_attr.as_str(), 'useridAttr') + self.assertEqual(instance.groupname_attr.as_str(), 'group_attr') + if version >= 'v2': + self.assertEqual(instance.username_attr.as_str(), 'usernattr') + + def test_unmarshall_all_versions(self): + return + for v in range(1, len(SERIALIZED_AUTH_DATA) + 1): + instance = authenticator.SimpleLDAPAuthenticator(environment=Environment.get_temporary_environment()) + instance.unmarshal(SERIALIZED_AUTH_DATA['v{}'.format(v)]) + self.check_provider(f'v{v}', instance) + + def test_marshaling(self): + return + # Unmarshall last version, remarshall and check that is marshalled using new marshalling format + LAST_VERSION = 'v{}'.format(len(SERIALIZED_AUTH_DATA)) + instance = authenticator.SimpleLDAPAuthenticator( + environment=Environment.get_temporary_environment() + ) + instance.unmarshal(SERIALIZED_AUTH_DATA[LAST_VERSION]) + marshaled_data = instance.marshal() + + # Ensure remarshalled flag is set + self.assertTrue(instance.needs_upgrade()) + instance.flag_for_upgrade(False) # reset flag + + # Ensure fields has been marshalled using new format + self.assertFalse(marshaled_data.startswith(b'v')) + # Reunmarshall again and check that remarshalled flag is not set + instance = authenticator.SimpleLDAPAuthenticator( + environment=Environment.get_temporary_environment() + ) + instance.unmarshal(marshaled_data) + self.assertFalse(instance.needs_upgrade()) + + # Check that data is correct + self.check_provider(LAST_VERSION, instance) diff --git a/server/tests/core/workers/test_servicepools_cache_updater.py b/server/tests/core/workers/test_servicepools_cache_updater.py index fe8e3af26..9eb610156 100644 --- a/server/tests/core/workers/test_servicepools_cache_updater.py +++ b/server/tests/core/workers/test_servicepools_cache_updater.py @@ -56,15 +56,15 @@ class ServiceCacheUpdaterTest(UDSTestCase): # Default values for max TestProvider.concurrent_creation_limit = 1000 TestProvider.concurrent_removal_limit = 1000 - TestServiceCache.max_user_services = 1000 - TestServiceNoCache.max_user_services = 1000 + TestServiceCache.userservices_limit = 1000 + TestServiceNoCache.userservices_limit = 1000 ServiceCacheUpdater.setup() userService = services_fixtures.create_cache_testing_userservices()[0] self.servicePool = userService.deployed_service userService.delete() # empty all - def numberOfRemovingOrCanced(self) -> int: + def removing_or_canceled_count(self) -> int: return self.servicePool.userServices.filter( state__in=[State.REMOVABLE, State.CANCELED] ).count() @@ -74,7 +74,7 @@ class ServiceCacheUpdaterTest(UDSTestCase): updater = ServiceCacheUpdater(Environment.get_temporary_environment()) updater.run() # Test user service will cancel automatically so it will not get in "removable" state (on remove start, it will tell it has been removed) - return self.servicePool.userServices.count() - self.numberOfRemovingOrCanced() + return self.servicePool.userServices.count() - self.removing_or_canceled_count() def setCache( self, @@ -113,10 +113,10 @@ class ServiceCacheUpdaterTest(UDSTestCase): self.setCache(cache=10) self.assertEqual( - self.runCacheUpdater(mustDelete), self.servicePool.initial_srvs + self.runCacheUpdater(mustDelete*2), self.servicePool.initial_srvs ) - self.assertEqual(self.numberOfRemovingOrCanced(), mustDelete) + self.assertEqual(self.removing_or_canceled_count(), mustDelete) def test_max(self) -> None: self.setCache(initial=100, cache=10, max=50) @@ -154,8 +154,9 @@ class ServiceCacheUpdaterTest(UDSTestCase): TestProvider.concurrent_creation_limit = 0 self.assertEqual(self.runCacheUpdater(self.servicePool.cache_l1_srvs + 10), 1) - def test_provider_removing_limits(self) -> None: - TestProvider.concurrent_removal_limit = 10 + def test_provider_no_removing_limits(self) -> None: + # Removing limits are appliend in fact when EXECUTING removal, not when marking as removable + # Note that "cancel" also overpass this limit self.setCache(initial=0, cache=50, max=50) # Try to "overcreate" cache elements but provider limits it to 10 @@ -164,21 +165,22 @@ class ServiceCacheUpdaterTest(UDSTestCase): # Now set cache to a lower value self.setCache(cache=10) - # Execute updater, must remove 10 elements (concurrent_removal_limit) - self.assertEqual(self.runCacheUpdater(10), 40) + # Execute updater, must remove as long as runs elements (we use cancle here, so it will be removed) + # removes until 10, that is the limit due to cache + self.assertEqual(self.runCacheUpdater(50), 10) def test_service_max_deployed(self) -> None: - TestServiceCache.max_user_services = 22 + TestServiceCache.userservices_limit = 22 self.setCache(initial=100, cache=100, max=50) # Try to "overcreate" cache elements but provider limits it to 10 - self.assertEqual(self.runCacheUpdater(self.servicePool.cache_l1_srvs + 10), TestServiceCache.max_user_services) + self.assertEqual(self.runCacheUpdater(self.servicePool.cache_l1_srvs + 10), TestServiceCache.userservices_limit) # Delete all userServices self.servicePool.userServices.all().delete() # We again allow masUserServices to be zero (meaning that no service will be created) # This allows us to "honor" some external providers that, in some cases, will not have services available... - TestServiceCache.max_user_services = 0 + TestServiceCache.userservices_limit = 0 self.assertEqual(self.runCacheUpdater(self.servicePool.cache_l1_srvs + 10), 0)