1
0
mirror of https://github.com/dkmstr/openuds.git synced 2025-01-08 21:18:00 +03:00

Remodeled cache updater, minor fixes and more refactoring

This commit is contained in:
Adolfo Gómez García 2024-01-26 01:30:40 +01:00
parent 42e042d4d4
commit 485520f402
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
29 changed files with 572 additions and 386 deletions

View File

@ -446,9 +446,9 @@ class Initialize(ActorV3Action):
# Set last seen actor version
userService.actor_version = self._params['version']
osData: collections.abc.MutableMapping[str, typing.Any] = {}
osManager = userService.get_osmanager_instance()
if osManager:
osData = osManager.actor_data(userService)
osmanager = userService.get_osmanager_instance()
if osmanager:
osData = osmanager.actor_data(userService)
if service and not alias_token: # Is a service managed by UDS
# Create a new alias for it, and save
@ -501,11 +501,11 @@ class BaseReadyChange(ActorV3Action):
if userService.os_state != State.USABLE:
userService.setOsState(State.USABLE)
# Notify osManager or readyness if has os manager
osManager = userService.get_osmanager_instance()
# Notify osmanager or readyness if has os manager
osmanager = userService.get_osmanager_instance()
if osManager:
osManager.to_ready(userService)
if osmanager:
osmanager.to_ready(userService)
UserServiceManager().notify_ready_from_os_manager(userService, '')
# Generates a certificate and send it to client.
@ -590,10 +590,10 @@ class Login(ActorV3Action):
@staticmethod
def process_login(userservice: UserService, username: str) -> typing.Optional[osmanagers.OSManager]:
osManager: typing.Optional[osmanagers.OSManager] = userservice.get_osmanager_instance()
osmanager: typing.Optional[osmanagers.OSManager] = userservice.get_osmanager_instance()
if not userservice.in_use: # If already logged in, do not add a second login (windows does this i.e.)
osmanagers.OSManager.logged_in(userservice, username)
return osManager
return osmanager
def action(self) -> dict[str, typing.Any]:
isManaged = self._params.get('type') != consts.actor.UNMANAGED
@ -605,17 +605,17 @@ class Login(ActorV3Action):
try:
userservice: UserService = self.get_userservice()
os_manager = Login.process_login(userservice, self._params.get('username') or '')
osmanager = Login.process_login(userservice, self._params.get('username') or '')
max_idle = os_manager.max_idle() if os_manager else None
max_idle = osmanager.max_idle() if osmanager else None
logger.debug('Max idle: %s', max_idle)
src = userservice.getConnectionSource()
session_id = userservice.start_session() # creates a session for every login requested
if os_manager: # For os managed services, let's check if we honor deadline
if os_manager.ignore_deadline():
if osmanager: # For os managed services, let's check if we honor deadline
if osmanager.ignore_deadline():
deadline = userservice.deployed_service.get_deadline()
else:
deadline = None
@ -653,7 +653,7 @@ class Logout(ActorV3Action):
"""
This method is static so can be invoked from elsewhere
"""
osManager: typing.Optional[osmanagers.OSManager] = userservice.get_osmanager_instance()
osmanager: typing.Optional[osmanagers.OSManager] = userservice.get_osmanager_instance()
# Close session
# For compat, we have taken '' as "all sessions"
@ -661,9 +661,9 @@ class Logout(ActorV3Action):
if userservice.in_use: # If already logged out, do not add a second logout (windows does this i.e.)
osmanagers.OSManager.logged_out(userservice, username)
if osManager:
if osManager.is_removable_on_logout(userservice):
logger.debug('Removable on logout: %s', osManager)
if osmanager:
if osmanager.is_removable_on_logout(userservice):
logger.debug('Removable on logout: %s', osmanager)
userservice.remove()
else:
userservice.remove()

View File

@ -72,7 +72,7 @@ class Services(DetailHandler): # pylint: disable=too-many-public-methods
return {
'icon': info.icon64().replace('\n', ''),
'needs_publication': info.publication_type is not None,
'max_deployed': info.max_user_services,
'max_deployed': info.userservices_limit,
'uses_cache': info.uses_cache and info.overrided_fields is None,
'uses_cache_l2': info.uses_cache_l2,
'cache_tooltip': _(info.cache_tooltip),

View File

@ -98,23 +98,10 @@ class RegexLdap(auths.Authenticator):
required=True,
tab=types.ui.Tab.CREDENTIALS,
)
timeout = gui.NumericField(
length=3,
label=_('Timeout'),
default=10,
order=6,
tooltip=_('Timeout in seconds of connection to LDAP'),
required=True,
min_value=1,
)
verify_ssl = gui.CheckBoxField(
label=_('Verify SSL'),
default=True,
order=11,
tooltip=_('If checked, SSL verification will be enforced. If not, SSL verification will be disabled'),
tab=types.ui.Tab.ADVANCED,
old_field_name='verifySsl',
)
timeout = fields.timeout_field(tab=False, default=10) # Use "main tab"
verify_ssl = fields.verify_ssl_field(order=11)
certificate = gui.TextField(
length=8192,
lines=4,
@ -123,7 +110,6 @@ class RegexLdap(auths.Authenticator):
tooltip=_('Certificate to use for SSL verification'),
required=False,
tab=types.ui.Tab.ADVANCED,
old_field_name='certificate',
)
ldap_base = gui.TextField(
length=64,
@ -132,7 +118,6 @@ class RegexLdap(auths.Authenticator):
tooltip=_('Common search base (used for "users" and "groups")'),
required=True,
tab=_('Ldap info'),
old_field_name='ldapBase',
)
user_class = gui.TextField(
length=64,
@ -142,7 +127,6 @@ class RegexLdap(auths.Authenticator):
tooltip=_('Class for LDAP users (normally posixAccount)'),
required=True,
tab=_('Ldap info'),
old_field_name='userClass',
)
userid_attr = gui.TextField(
length=64,
@ -152,7 +136,6 @@ class RegexLdap(auths.Authenticator):
tooltip=_('Attribute that contains the user id.'),
required=True,
tab=_('Ldap info'),
old_field_name='userIdAttr',
)
username_attr = gui.TextField(
length=640,
@ -165,7 +148,6 @@ class RegexLdap(auths.Authenticator):
),
required=True,
tab=_('Ldap info'),
old_field_name='userNameAttr',
)
groupname_attr = gui.TextField(
length=640,
@ -178,7 +160,6 @@ class RegexLdap(auths.Authenticator):
),
required=True,
tab=_('Ldap info'),
old_field_name='groupNameAttr',
)
# regex = gui.TextField(length=64, label = _('Regular Exp. for groups'), defvalue = '^(.*)', order = 12, tooltip = _('Regular Expression to extract the group name'), required = True)
@ -190,7 +171,6 @@ class RegexLdap(auths.Authenticator):
tooltip=_('Class for LDAP objects that will be also checked for groups retrieval (normally empty)'),
required=False,
tab=_('Advanced'),
old_field_name='altClass',
)
mfa_attribute = fields.mfa_attr_field()

View File

@ -1,6 +1,6 @@
# pylint: disable=no-member # ldap module gives errors to pylint
#
# Copyright (c) 2012-2021 Virtual Cable S.L.U.
# Copyright (c) 2024 Virtual Cable S.L.U.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
@ -27,7 +27,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
'''
@author: Adolfo Gómez, dkmaster at dkmon dot com
Author: Adolfo Gómez, dkmaster at dkmon dot com
'''
import logging
import typing
@ -40,7 +40,7 @@ from django.utils.translation import gettext_noop as _
from uds.core import auths, types, consts, exceptions
from uds.core.auths.auth import log_login
from uds.core.ui import gui
from uds.core.util import ldaputil
from uds.core.util import fields, ldaputil
# Not imported at runtime, just for type checking
if typing.TYPE_CHECKING:
@ -68,7 +68,7 @@ class SimpleLDAPAuthenticator(auths.Authenticator):
tooltip=_('Ldap port (usually 389 for non ssl and 636 for ssl)'),
required=True,
)
ssl = gui.CheckBoxField(
use_ssl = gui.CheckBoxField(
label=_('Use SSL'),
order=3,
tooltip=_('If checked, the connection will be ssl, using port 636 instead of 389'),
@ -89,23 +89,10 @@ class SimpleLDAPAuthenticator(auths.Authenticator):
required=True,
tab=types.ui.Tab.CREDENTIALS,
)
timeout = gui.NumericField(
length=3,
label=_('Timeout'),
default=10,
order=10,
tooltip=_('Timeout in seconds of connection to LDAP'),
required=True,
min_value=1,
tab=types.ui.Tab.ADVANCED,
)
verifySsl = gui.CheckBoxField(
label=_('Verify SSL'),
default=True,
order=11,
tooltip=_('If checked, SSL verification will be enforced. If not, SSL verification will be disabled'),
tab=types.ui.Tab.ADVANCED,
)
timeout = fields.timeout_field(tab=False, default=10) # Use "main tab"
verify_ssl = fields.verify_ssl_field(order=11)
certificate = gui.TextField(
length=8192,
lines=4,
@ -116,7 +103,7 @@ class SimpleLDAPAuthenticator(auths.Authenticator):
tab=types.ui.Tab.ADVANCED,
)
ldapBase = gui.TextField(
ldap_base = gui.TextField(
length=64,
label=_('Base'),
order=30,
@ -124,7 +111,7 @@ class SimpleLDAPAuthenticator(auths.Authenticator):
required=True,
tab=_('Ldap info'),
)
userClass = gui.TextField(
user_class = gui.TextField(
length=64,
label=_('User class'),
default='posixAccount',
@ -133,7 +120,7 @@ class SimpleLDAPAuthenticator(auths.Authenticator):
required=True,
tab=_('Ldap info'),
)
userIdAttr = gui.TextField(
user_id_attr = gui.TextField(
length=64,
label=_('User Id Attr'),
default='uid',
@ -142,7 +129,7 @@ class SimpleLDAPAuthenticator(auths.Authenticator):
required=True,
tab=_('Ldap info'),
)
userNameAttr = gui.TextField(
username_attr = gui.TextField(
length=64,
label=_('User Name Attr'),
default='uid',
@ -151,7 +138,7 @@ class SimpleLDAPAuthenticator(auths.Authenticator):
required=True,
tab=_('Ldap info'),
)
groupClass = gui.TextField(
group_class = gui.TextField(
length=64,
label=_('Group class'),
default='posixGroup',
@ -160,7 +147,7 @@ class SimpleLDAPAuthenticator(auths.Authenticator):
required=True,
tab=_('Ldap info'),
)
groupIdAttr = gui.TextField(
group_id_attr = gui.TextField(
length=64,
label=_('Group Id Attr'),
default='cn',
@ -169,7 +156,7 @@ class SimpleLDAPAuthenticator(auths.Authenticator):
required=True,
tab=_('Ldap info'),
)
memberAttr = gui.TextField(
member_attr = gui.TextField(
length=64,
label=_('Group membership attr'),
default='memberUid',
@ -178,7 +165,7 @@ class SimpleLDAPAuthenticator(auths.Authenticator):
required=True,
tab=_('Ldap info'),
)
mfaAttr = gui.TextField(
mfa_attribute = gui.TextField(
length=2048,
lines=2,
label=_('MFA attribute'),

View File

@ -99,13 +99,13 @@ def process_login(server: 'models.Server', data: dict[str, typing.Any]) -> typin
src = userService.getConnectionSource()
session_id = userService.start_session() # creates a session for every login requested
osManager: typing.Optional[osmanagers.OSManager] = userService.get_osmanager_instance()
maxIdle = osManager.max_idle() if osManager else None
osmanager: typing.Optional[osmanagers.OSManager] = userService.get_osmanager_instance()
maxIdle = osmanager.max_idle() if osmanager else None
logger.debug('Max idle: %s', maxIdle)
deadLine = (
userService.deployed_service.get_deadline() if not osManager or osManager.ignore_deadline() else None
userService.deployed_service.get_deadline() if not osmanager or osmanager.ignore_deadline() else None
)
result = {
'ip': src.ip,
@ -138,9 +138,9 @@ def process_logout(server: 'models.Server', data: dict[str, typing.Any]) -> typi
if userService.in_use: # If already logged out, do not add a second logout (windows does this i.e.)
osmanagers.OSManager.logged_out(userService, data['username'])
osManager: typing.Optional[osmanagers.OSManager] = userService.get_osmanager_instance()
if not osManager or osManager.is_removable_on_logout(userService):
logger.debug('Removable on logout: %s', osManager)
osmanager: typing.Optional[osmanagers.OSManager] = userService.get_osmanager_instance()
if not osmanager or osmanager.is_removable_on_logout(userService):
logger.debug('Removable on logout: %s', osmanager)
userService.remove()
return rest_result(consts.OK)

View File

@ -86,7 +86,7 @@ class UserServiceManager(metaclass=singleton.Singleton):
states = [State.PREPARING, State.USABLE, State.REMOVING, State.REMOVABLE]
return Q(state__in=states)
def _check_if_max_user_services_reached(self, service_pool: ServicePool) -> None:
def _check_user_services_limit_reached(self, service_pool: ServicePool) -> None:
"""
Checks if max_user_services for the service has been reached, and, if so,
raises an exception that no more services of this kind can be reached
@ -110,10 +110,10 @@ class UserServiceManager(metaclass=singleton.Singleton):
"""
serviceInstance = service.get_instance()
# Early return, so no database count is needed
if serviceInstance.max_user_services == consts.UNLIMITED:
if serviceInstance.userservices_limit == consts.UNLIMITED:
return False
if self.get_existing_user_services(service) >= serviceInstance.max_user_services:
if self.get_existing_user_services(service) >= serviceInstance.userservices_limit:
return True
return False
@ -123,7 +123,7 @@ class UserServiceManager(metaclass=singleton.Singleton):
Private method to instatiate a cache element at database with default states
"""
# Checks if max_user_services has been reached and if so, raises an exception
self._check_if_max_user_services_reached(publication.deployed_service)
self._check_user_services_limit_reached(publication.deployed_service)
now = sql_datetime()
return publication.userServices.create(
cache_level=cacheLevel,
@ -141,7 +141,7 @@ class UserServiceManager(metaclass=singleton.Singleton):
"""
Private method to instatiate an assigned element at database with default state
"""
self._check_if_max_user_services_reached(publication.deployed_service)
self._check_user_services_limit_reached(publication.deployed_service)
now = sql_datetime()
return publication.userServices.create(
cache_level=0,
@ -161,7 +161,7 @@ class UserServiceManager(metaclass=singleton.Singleton):
There is cases where deployed services do not have publications (do not need them), so we need this method to create
an UserService with no publications, and create them from an ServicePool
"""
self._check_if_max_user_services_reached(service_pool)
self._check_user_services_limit_reached(service_pool)
now = sql_datetime()
return service_pool.userServices.create(
cache_level=0,
@ -296,7 +296,7 @@ class UserServiceManager(metaclass=singleton.Singleton):
# Data will be serialized on makeUnique process
UserServiceOpChecker.make_unique(cache, cacheInstance, state)
def cancel(self, user_service: UserService) -> UserService:
def cancel(self, user_service: UserService) -> None:
"""
Cancels an user service creation
@return: the Uservice canceling
@ -325,9 +325,8 @@ class UserServiceManager(metaclass=singleton.Singleton):
# opchecker will set state to "removable"
UserServiceOpChecker.make_unique(user_service, user_service_instance, state)
return user_service
def remove(self, userservice: UserService) -> UserService:
def remove(self, userservice: UserService) -> None:
"""
Removes a uService element
"""
@ -347,9 +346,7 @@ class UserServiceManager(metaclass=singleton.Singleton):
# Data will be serialized on makeUnique process
UserServiceOpChecker.make_unique(userservice, userServiceInstance, state)
return userservice
def remove_or_cancel(self, user_service: UserService):
def remove_or_cancel(self, user_service: UserService) -> None:
if user_service.is_usable() or State.from_str(user_service.state).is_removable():
return self.remove(user_service)
@ -629,9 +626,9 @@ class UserServiceManager(metaclass=singleton.Singleton):
This method is used by UserService when a request for setInUse(False) is made
This checks that the service can continue existing or not
"""
osManager = user_service.deployed_service.osmanager
osmanager = user_service.deployed_service.osmanager
# If os manager says "machine is persistent", do not try to delete "previous version" assigned machines
doPublicationCleanup = True if not osManager else not osManager.get_instance().is_persistent()
doPublicationCleanup = True if not osmanager else not osmanager.get_instance().is_persistent()
if doPublicationCleanup:
remove = False

View File

@ -117,17 +117,17 @@ class StateUpdater:
class UpdateFromPreparing(StateUpdater):
def check_os_manager_related(self) -> str:
osManager = self.user_service_instance.osmanager()
osmanager = self.user_service_instance.osmanager()
state = State.USABLE
# and make this usable if os manager says that it is usable, else it pass to configuring state
# This is an "early check" for os manager, so if we do not have os manager, or os manager
# already notifies "ready" for this, we
if osManager is not None and State.from_str(self.user_service.os_state).is_preparing():
if osmanager is not None and State.from_str(self.user_service.os_state).is_preparing():
logger.debug('Has valid osmanager for %s', self.user_service.friendly_name)
stateOs = osManager.check_state(self.user_service)
stateOs = osmanager.check_state(self.user_service)
else:
stateOs = State.FINISHED
@ -175,18 +175,18 @@ class UpdateFromPreparing(StateUpdater):
class UpdateFromRemoving(StateUpdater):
def finish(self):
osManager = self.user_service_instance.osmanager()
if osManager is not None:
osManager.release(self.user_service)
osmanager = self.user_service_instance.osmanager()
if osmanager is not None:
osmanager.release(self.user_service)
self.save(State.REMOVED)
class UpdateFromCanceling(StateUpdater):
def finish(self):
osManager = self.user_service_instance.osmanager()
if osManager is not None:
osManager.release(self.user_service)
osmanager = self.user_service_instance.osmanager()
if osmanager is not None:
osmanager.release(self.user_service)
self.save(State.CANCELED)
@ -203,33 +203,35 @@ class UserServiceOpChecker(DelayedTask):
"""
This is the delayed task responsible of executing the service tasks and the service state transitions
"""
_svrId: int
_state: str
def __init__(self, service):
def __init__(self, service: 'UserService'):
super().__init__()
self._svrId = service.id
self._state = service.state
@staticmethod
def make_unique(userService: UserService, userServiceInstance: services.UserService, state: str):
def make_unique(userservice: UserService, userservice_instance: services.UserService, state: str):
"""
This method ensures that there will be only one delayedtask related to the userService indicated
"""
DelayedTaskRunner.runner().remove(USERSERVICE_TAG + userService.uuid)
UserServiceOpChecker.state_updater(userService, userServiceInstance, state)
DelayedTaskRunner.runner().remove(USERSERVICE_TAG + userservice.uuid)
UserServiceOpChecker.state_updater(userservice, userservice_instance, state)
@staticmethod
def state_updater(userService: UserService, userServiceInstance: services.UserService, state: str):
def state_updater(userservice: UserService, userservice_instance: services.UserService, state: str):
"""
Checks the value returned from invocation to publish or checkPublishingState, updating the servicePoolPub database object
Return True if it has to continue checking, False if finished
"""
try:
# Fills up basic data
userService.unique_id = userServiceInstance.get_unique_id() # Updates uniqueId
userService.friendly_name = (
userServiceInstance.get_name()
userservice.unique_id = userservice_instance.get_unique_id() # Updates uniqueId
userservice.friendly_name = (
userservice_instance.get_name()
) # And name, both methods can modify serviceInstance, so we save it later
userService.save(update_fields=['unique_id', 'friendly_name'])
userservice.save(update_fields=['unique_id', 'friendly_name'])
updater = typing.cast(
type[StateUpdater],
@ -237,39 +239,39 @@ class UserServiceOpChecker(DelayedTask):
State.PREPARING: UpdateFromPreparing,
State.REMOVING: UpdateFromRemoving,
State.CANCELING: UpdateFromCanceling,
}.get(State.from_str(userService.state), UpdateFromOther),
}.get(State.from_str(userservice.state), UpdateFromOther),
)
logger.debug(
'Updating %s from %s with updater %s and state %s',
userService.friendly_name,
State.from_str(userService.state).literal,
userservice.friendly_name,
State.from_str(userservice.state).literal,
updater,
state,
)
updater(userService, userServiceInstance).run(state)
updater(userservice, userservice_instance).run(state)
except Exception as e:
logger.exception('Checking service state')
log.log(userService, log.LogLevel.ERROR, f'Exception: {e}', log.LogSource.INTERNAL)
userService.set_state(State.ERROR)
userService.save(update_fields=['data'])
log.log(userservice, log.LogLevel.ERROR, f'Exception: {e}', log.LogSource.INTERNAL)
userservice.set_state(State.ERROR)
userservice.save(update_fields=['data'])
@staticmethod
def check_later(userService, ci):
def check_later(userservice: 'UserService', instance: 'services.UserService'):
"""
Inserts a task in the delayedTaskRunner so we can check the state of this publication
Inserts a task in the delayedTaskRunner so we can check the state of this service later
@param dps: Database object for ServicePoolPublication
@param pi: Instance of Publication manager for the object
"""
# Do not add task if already exists one that updates this service
if DelayedTaskRunner.runner().tag_exists(USERSERVICE_TAG + userService.uuid):
if DelayedTaskRunner.runner().tag_exists(USERSERVICE_TAG + userservice.uuid):
return
DelayedTaskRunner.runner().insert(
UserServiceOpChecker(userService),
ci.suggested_delay,
USERSERVICE_TAG + userService.uuid,
UserServiceOpChecker(userservice),
instance.suggested_delay,
USERSERVICE_TAG + userservice.uuid,
)
def run(self) -> None:

View File

@ -38,7 +38,7 @@ class Serializable:
"""
This class represents the interface that all serializable objects must provide.
Every single serializable class must implement marshall & unmarshall methods. Also, the class must allow
Every single serializable class must implement marshal & unmarshal methods. Also, the class must allow
to be initialized without parameters, so we can:
- Initialize the object with default values
- Read values from seralized data
@ -109,7 +109,7 @@ class Serializable:
self.unmarshal(base64.b64decode(data))
# For remarshalling purposes
# These allows us to faster migration of old data formats to new ones
# These facilitates a faster migration of old data formats to new ones
# alowing us to remove old format support as soon as possible
def flag_for_upgrade(self, value: bool = True) -> None:
"""

View File

@ -100,10 +100,10 @@ class Publication(Environmentable, Serializable):
"""
Environmentable.__init__(self, environment)
Serializable.__init__(self)
self._osManager = kwargs.get('osManager', None)
self._osManager = kwargs.get('osmanager', None)
self._service = kwargs['service'] # Raises an exception if service is not included
self._revision = kwargs.get('revision', -1)
self._servicepool_name = kwargs.get('dsName', 'Unknown')
self._servicepool_name = kwargs.get('servicepool_name', 'Unknown')
self._uuid = kwargs.get('uuid', '')
self.initialize()
@ -114,7 +114,7 @@ class Publication(Environmentable, Serializable):
This is provided so you don't have to provide your own __init__ method,
and invoke base class __init__.
This will get invoked when all initialization stuff is done, so
you can here access service, osManager, ...
you can here access service, osmanager, ...
"""
def db_obj(self) -> 'models.ServicePoolPublication':
@ -137,7 +137,7 @@ class Publication(Environmentable, Serializable):
"""
return self._service
def os_manager(self) -> typing.Optional['osmanagers.OSManager']:
def osmanager(self) -> typing.Optional['osmanagers.OSManager']:
"""
Utility method to access os manager for this publication.

View File

@ -124,11 +124,11 @@ class Service(Module):
# : for providing user services. This attribute can be set here or
# : modified at instance level, core will access always to it using an instance object.
# : Note: you can override this value on service instantiation by providing a "maxService":
# : - If maxServices is an integer, it will be used as max_user_services
# : - If maxServices is a gui.NumericField, it will be used as max_user_services (.num() will be called)
# : - If maxServices is a callable, it will be called and the result will be used as max_user_services
# : - If maxServices is None, max_user_services will be set to consts.UNLIMITED (as default)
max_user_services: int = consts.UNLIMITED
# : - If userservices_limit is an integer, it will be used as max_user_services
# : - If userservices_limit is a gui.NumericField, it will be used as max_user_services (.num() will be called)
# : - If userservices_limit is a callable, it will be called and the result will be used as max_user_services
# : - If userservices_limit is None, max_user_services will be set to consts.UNLIMITED (as default)
userservices_limit: int = consts.UNLIMITED
# : If this item "has overrided fields", on deployed service edition, defined keys will overwrite defined ones
# : That is, this Dicionary will OVERWRITE fields ON ServicePool (normally cache related ones) dictionary from a REST api save invocation!!
@ -271,34 +271,34 @@ class Service(Module):
return True
def unmarshal(self, data: bytes) -> None:
# In fact, we will not unmarshall anything here, but setup maxDeployed
# if maxServices exists and it is a gui.NumericField
# Invoke base unmarshall, so "gui fields" gets loaded from data
# In fact, we will not unmarshal anything here, but setup maxDeployed
# if services_limit exists and it is a gui.NumericField
# Invoke base unmarshal, so "gui fields" gets loaded from data
super().unmarshal(data)
if hasattr(self, 'maxServices'):
if hasattr(self, 'services_limit'):
# Fix self "max_user_services" value after loading fields
try:
maxServices = getattr(self, 'maxServices', None)
if isinstance(maxServices, int):
self.max_user_services = maxServices
elif isinstance(maxServices, gui.NumericField):
self.max_user_services = maxServices.as_int()
services_limit = getattr(self, 'services_limit', None)
if isinstance(services_limit, int):
self.userservices_limit = services_limit
elif isinstance(services_limit, gui.NumericField):
self.userservices_limit = services_limit.as_int()
# For 0 values on max_user_services field, we will set it to UNLIMITED
if self.max_user_services == 0:
self.max_user_services = consts.UNLIMITED
elif callable(maxServices):
self.max_user_services = maxServices()
if self.userservices_limit == 0:
self.userservices_limit = consts.UNLIMITED
elif callable(services_limit):
self.userservices_limit = services_limit()
else:
self.max_user_services = consts.UNLIMITED
self.userservices_limit = consts.UNLIMITED
except Exception:
self.max_user_services = consts.UNLIMITED
self.userservices_limit = consts.UNLIMITED
# Ensure that max_user_services is not negative
if self.max_user_services < 0:
self.max_user_services = consts.UNLIMITED
if self.userservices_limit < 0:
self.userservices_limit = consts.UNLIMITED
# Keep untouched if maxServices is not present
# Keep untouched if services_limit is not present
def user_services_for_assignation(self, **kwargs) -> collections.abc.Iterable['UserService']:
"""

View File

@ -164,7 +164,7 @@ class UserService(Environmentable, Serializable):
This is provided so you don't have to provide your own __init__ method,
and invoke base class __init__.
This will get invoked when all initialization stuff is done, so
you can here access publication, service, osManager, ...
you can here access publication, service, osmanager, ...
"""
def db_obj(self) -> 'models.UserService':

View File

@ -1488,6 +1488,7 @@ class UserInterface(metaclass=UserInterfaceAbstract):
# Any unexpected type will raise an exception
# Note that currently, we will store old field name on db
# to allow "backwards" migration if needed, but will be removed on a future version
# this allows even several name changes, as long as we keep "old_field_name" as original one...
fields = [
(field.old_field_name() or field_name, field.type.name, FIELDS_ENCODERS[field.type](field))
for field_name, field in self._gui.items()

View File

@ -77,8 +77,8 @@ class AutoAttributes(Serializable):
attrs: collections.abc.MutableMapping[str, Attribute]
def __init__(self, **kwargs):
self.attrs = {} # Ensure attrs is created BEFORE calling super, that can contain _ variables
Serializable.__init__(self)
self.attrs = {}
self.declare(**kwargs)
def __getattribute__(self, name) -> typing.Any:
@ -97,7 +97,7 @@ class AutoAttributes(Serializable):
for key, typ in kwargs.items():
d[key] = Attribute(typ)
self.attrs = d
def marshal(self) -> bytes:
return b'v1' + pickle.dumps(self.attrs)

View File

@ -60,6 +60,8 @@ from cryptography import fernet
from django.conf import settings
from requests import get
from uds.core.serializable import Serializable
# pylint: disable=too-few-public-methods
class _Unassigned:
@ -77,9 +79,9 @@ logger = logging.getLogger(__name__)
# Constants
# Headers for the serialized data
HEADER_BASE: typing.Final[bytes] = b'MGB1'
HEADER_COMPRESSED: typing.Final[bytes] = b'MGZ1'
HEADER_ENCRYPTED: typing.Final[bytes] = b'MGE1'
HEADER_BASE: typing.Final[bytes] = b'MGBAS1'
HEADER_COMPRESSED: typing.Final[bytes] = b'MGZAS1'
HEADER_ENCRYPTED: typing.Final[bytes] = b'MGEAS1'
# Size of crc32 checksum
CRC_SIZE: typing.Final[int] = 4
@ -100,6 +102,19 @@ def fernet_key(crypt_key: bytes) -> str:
# Generate an URL-Safe base64 encoded 32 bytes key for Fernet
return base64.b64encode(hashlib.sha256(crypt_key).digest()).decode()
# checker for autoserializable data
def is_autoserializable_data(data: bytes) -> bool:
"""Check if data is is from an autoserializable class
Args:
data: Data to check
Returns:
True if data is autoserializable, False otherwise
"""
return data[: len(HEADER_BASE)] == HEADER_BASE
# pylint: disable=unnecessary-dunder-call
class _SerializableField(typing.Generic[T]):
@ -212,16 +227,16 @@ class BoolField(_SerializableField[bool]):
self.__set__(instance, data == b'1')
class ListField(_SerializableField[list]):
class ListField(_SerializableField[list[T]]):
"""List field
Note:
All elements in the list must be serializable.
All elements in the list must be serializable in JSON, but can be of different types.
"""
def __init__(
self,
default: typing.Union[list, collections.abc.Callable[[], list]] = lambda: [],
default: typing.Union[list[T], collections.abc.Callable[[], list[T]]] = lambda: [],
):
super().__init__(list, default)
@ -323,7 +338,7 @@ class _FieldNameSetter(type):
return super().__new__(mcs, name, bases, attrs)
class AutoSerializable(metaclass=_FieldNameSetter):
class AutoSerializable(Serializable, metaclass=_FieldNameSetter):
"""This class allows the automatic serialization of fields in a class.
Example:
@ -366,7 +381,6 @@ class AutoSerializable(metaclass=_FieldNameSetter):
"""
return bytes(a ^ b for a, b in zip(data, itertools.cycle(header)))
@typing.final
def marshal(self) -> bytes:
# Iterate over own members and extract fields
fields = {}
@ -396,8 +410,8 @@ class AutoSerializable(metaclass=_FieldNameSetter):
# Return data processed with header
return header + self.process_data(header, data)
# final method, do not override
@typing.final
# Only override this for checking if data is valid
# and, alternatively, retrieve it from a different source
def unmarshal(self, data: bytes) -> None:
# Check header
if data[: len(HEADER_BASE)] != HEADER_BASE:

View File

@ -191,7 +191,9 @@ def get_certificates_from_field(
# Timeout
def timeout_field(
default: int = 3,
order: int = 90, tab: 'types.ui.Tab|str|None' = None, old_field_name: typing.Optional[str] = None
order: int = 90,
tab: 'types.ui.Tab|str|None|bool' = None,
old_field_name: typing.Optional[str] = None,
) -> ui.gui.NumericField:
return ui.gui.NumericField(
length=3,
@ -201,50 +203,49 @@ def timeout_field(
tooltip=_('Timeout in seconds for network connections'),
required=True,
min_value=1,
tab=tab or types.ui.Tab.ADVANCED,
tab=None if tab is False else None if tab is None else types.ui.Tab.ADVANCED,
old_field_name=old_field_name,
)
# Ssl verification
# Ssl verification
def verify_ssl_field(
default: bool = True,
order: int = 92, tab: 'types.ui.Tab|str|None' = None, old_field_name: typing.Optional[str] = None
) -> ui.gui.CheckBoxField:
return ui.gui.CheckBoxField(
order: int = 92,
tab: 'types.ui.Tab|str|None|bool' = None,
old_field_name: typing.Optional[str] = None,
) -> ui.gui.CheckBoxField:
return ui.gui.CheckBoxField(
label=_('Verify SSL'),
default=default,
order=order,
tooltip=_(
'If checked, SSL verification will be enforced. If not, SSL verification will be disabled'
),
tab=tab or types.ui.Tab.ADVANCED,
tooltip=_('If checked, SSL verification will be enforced. If not, SSL verification will be disabled'),
tab=None if tab is False else None if tab is None else types.ui.Tab.ADVANCED,
old_field_name=old_field_name,
)
# Basename field
def basename_field(order: int = 32, tab: 'types.ui.Tab|str|None' = None) -> ui.gui.TextField:
def basename_field(order: int = 32, tab: 'types.ui.Tab|str|None|bool' = None) -> ui.gui.TextField:
return ui.gui.TextField(
label=_('Base Name'),
order=order,
tooltip=_('Base name for clones from this service'),
tab=tab,
tab=None if tab is False else None if tab is None else types.ui.Tab.ADVANCED,
required=True,
old_field_name='baseName',
)
# Length of name field
def lenname_field(
order: int = 33, tab: 'types.ui.Tab|str|None' = None
) -> ui.gui.NumericField:
def lenname_field(order: int = 33, tab: 'types.ui.Tab|str|None|bool' = None) -> ui.gui.NumericField:
return ui.gui.NumericField(
length=1,
label=_('Name Length'),
default=3,
order=order,
tooltip=_('Size of numeric part for the names derived from this service'),
tab=tab,
tab=None if tab is False else None if tab is None else types.ui.Tab.ADVANCED,
required=True,
old_field_name='lenName',
)
@ -252,7 +253,7 @@ def lenname_field(
# Max preparing services field
def concurrent_creation_limit_field(
order: int = 50, tab: typing.Optional[types.ui.Tab] = None
order: int = 50, tab: 'types.ui.Tab|str|None|bool' = None
) -> ui.gui.NumericField:
# Advanced tab
return ui.gui.NumericField(
@ -264,13 +265,13 @@ def concurrent_creation_limit_field(
order=order,
tooltip=_('Maximum number of concurrently creating VMs'),
required=True,
tab=tab or types.ui.Tab.ADVANCED,
tab=None if tab is False else None if tab is None else types.ui.Tab.ADVANCED,
old_field_name='maxPreparingServices',
)
def concurrent_removal_limit_field(
order: int = 51, tab: 'types.ui.Tab|str|None' = None
order: int = 51, tab: 'types.ui.Tab|str|None|bool' = None
) -> ui.gui.NumericField:
return ui.gui.NumericField(
length=3,
@ -281,27 +282,25 @@ def concurrent_removal_limit_field(
order=order,
tooltip=_('Maximum number of concurrently removing VMs'),
required=True,
tab=tab or types.ui.Tab.ADVANCED,
tab=None if tab is False else None if tab is None else types.ui.Tab.ADVANCED,
old_field_name='maxRemovingServices',
)
def remove_duplicates_field(
order: int = 102, tab: 'types.ui.Tab|str|None' = None
) -> ui.gui.CheckBoxField:
def remove_duplicates_field(order: int = 102, tab: 'types.ui.Tab|str|None|bool' = None) -> ui.gui.CheckBoxField:
return ui.gui.CheckBoxField(
label=_('Remove found duplicates'),
default=True,
order=order,
tooltip=_('If active, found duplicates vApps for this service will be removed'),
tab=tab or types.ui.Tab.ADVANCED,
tab=None if tab is False else None if tab is None else types.ui.Tab.ADVANCED,
old_field_name='removeDuplicates',
)
def soft_shutdown_field(
order: int = 103,
tab: 'types.ui.Tab|str|None' = None,
tab: 'types.ui.Tab|str|None|bool' = None,
old_field_name: typing.Optional[str] = None,
) -> ui.gui.CheckBoxField:
return ui.gui.CheckBoxField(
@ -311,14 +310,14 @@ def soft_shutdown_field(
tooltip=_(
'If active, UDS will try to shutdown (soft) the machine using Nutanix ACPI. Will delay 30 seconds the power off of hanged machines.'
),
tab=tab or types.ui.Tab.ADVANCED,
tab=None if tab is False else None if tab is None else types.ui.Tab.ADVANCED,
old_field_name=old_field_name,
)
def keep_on_access_error_field(
order: int = 104,
tab: 'types.ui.Tab|str|None' = None,
tab: 'types.ui.Tab|str|None|bool' = None,
old_field_name: typing.Optional[str] = None,
) -> ui.gui.CheckBoxField:
return ui.gui.CheckBoxField(
@ -326,7 +325,7 @@ def keep_on_access_error_field(
value=False,
order=order,
tooltip=_('If active, access errors found on machine will not be considered errors.'),
tab=tab or types.ui.Tab.ADVANCED,
tab=None if tab is False else None if tab is None else types.ui.Tab.ADVANCED,
old_field_name=old_field_name,
)
@ -334,7 +333,7 @@ def keep_on_access_error_field(
def macs_range_field(
default: str,
order: int = 91,
tab: 'types.ui.Tab|str|None' = None,
tab: 'types.ui.Tab|str|None|bool' = None,
readonly: bool = False,
) -> ui.gui.TextField:
return ui.gui.TextField(
@ -347,11 +346,12 @@ def macs_range_field(
default=default
),
required=True,
tab=tab or types.ui.Tab.ADVANCED,
tab=None if tab is False else None if tab is None else types.ui.Tab.ADVANCED,
old_field_name='macsRange',
)
def mfa_attr_field(order: int = 20, tab: 'types.ui.Tab|str|None' = None) -> ui.gui.TextField:
def mfa_attr_field(order: int = 20, tab: 'types.ui.Tab|str|None|bool' = None) -> ui.gui.TextField:
return ui.gui.TextField(
length=2048,
lines=2,
@ -359,6 +359,6 @@ def mfa_attr_field(order: int = 20, tab: 'types.ui.Tab|str|None' = None) -> ui.g
order=order,
tooltip=_('Attribute from where to extract the MFA code'),
required=False,
tab=tab or types.ui.Tab.MFA,
tab=None if tab is False else None if tab is None else types.ui.Tab.MFA,
old_field_name='mfaAttr',
)

View File

@ -29,7 +29,9 @@
"""
@author: Adolfo Gómez, dkmaster at dkmon dot com
"""
import dataclasses
import logging
from multiprocessing import pool
import typing
import collections.abc
@ -46,6 +48,60 @@ from uds.core.jobs import Job
logger = logging.getLogger(__name__)
# The functionallyty of counters are:
# * while we have less items than initial, we need to create cached l1 items
# * Once initial is reached, we have to keep cached l1 items at cache_l1_srvs
# * We stop creating cached l1 items when max is reached
# * If we have more than initial, we can remove cached l1 items until we reach cache_l1_srvs
# * If we have more than max, we can remove cached l1 items until we have no more than max
# * l2 is independent of any other counter, and will be created until cache_l2_srvs is reached
# * l2 will be removed until cache_l2_srvs is reached
@dataclasses.dataclass(slots=True)
class ServicePoolStats:
servicepool: ServicePool
l1_cache_count: int
l2_cache_count: int
assigned_count: int
def l1_cache_overflow(self) -> bool:
"""Checks if L1 cache is overflown
Overflows if:
* l1_assigned_count > max_srvs
(this is, if we have more than max, we can remove cached l1 items until we reach max)
* l1_assigned_count > initial_srvs and l1_cache_count > cache_l1_srvs
(this is, if we have more than initial, we can remove cached l1 items until we reach cache_l1_srvs)
"""
l1_assigned_count = self.l1_cache_count + self.assigned_count
return l1_assigned_count > self.servicepool.max_srvs or (
l1_assigned_count > self.servicepool.initial_srvs
and self.l1_cache_count > self.servicepool.cache_l1_srvs
)
def l1_cache_needed(self) -> bool:
"""Checks if L1 cache is needed
Grow L1 cache if:
* l1_assigned_count < max_srvs and (l1_assigned_count < initial_srvs or l1_cache_count < cache_l1_srvs)
(this is, if we have not reached max, and we have not reached initial or cache_l1_srvs, we need to grow L1 cache)
"""
l1_assigned_count = self.l1_cache_count + self.assigned_count
return l1_assigned_count < self.servicepool.max_srvs and (
l1_assigned_count < self.servicepool.initial_srvs
or self.l1_cache_count < self.servicepool.cache_l1_srvs
)
def l2_cache_overflow(self) -> bool:
"""Checks if L2 cache is overflown"""
return self.l2_cache_count > self.servicepool.cache_l2_srvs
def l2_cache_needed(self) -> bool:
"""Checks if L2 cache is needed"""
return self.l2_cache_count < self.servicepool.cache_l2_srvs
class ServiceCacheUpdater(Job):
"""
@ -73,11 +129,11 @@ class ServiceCacheUpdater(Job):
def service_pools_needing_cache_update(
self,
) -> list[tuple[ServicePool, int, int, int]]:
) -> list[ServicePoolStats]:
# State filter for cached and inAssigned objects
# First we get all deployed services that could need cache generation
# We start filtering out the deployed services that do not need caching at all.
servicePoolsNeedingCaching: collections.abc.Iterable[ServicePool] = (
candidate_servicepools: collections.abc.Iterable[ServicePool] = (
ServicePool.objects.filter(Q(initial_srvs__gte=0) | Q(cache_l1_srvs__gte=0))
.filter(
max_srvs__gt=0,
@ -88,126 +144,121 @@ class ServiceCacheUpdater(Job):
)
# We will get the one that proportionally needs more cache
servicesPools: list[tuple[ServicePool, int, int, int]] = []
for servicePool in servicePoolsNeedingCaching:
servicePool.userServices.update() # Cleans cached queries
servicepools_numbers: list[ServicePoolStats] = []
for servicepool in candidate_servicepools:
servicepool.user_services.update() # Cleans cached queries
# If this deployedService don't have a publication active and needs it, ignore it
spServiceInstance = servicePool.service.get_instance() # type: ignore
service_instance = servicepool.service.get_instance() # type: ignore
if spServiceInstance.uses_cache is False:
if service_instance.uses_cache is False:
logger.debug(
'Skipping cache generation for service pool that does not uses cache: %s',
servicePool.name,
servicepool.name,
)
continue
if servicePool.active_publication() is None and spServiceInstance.publication_type is not None:
if servicepool.active_publication() is None and service_instance.publication_type is not None:
logger.debug(
'Skipping. %s Needs publication but do not have one',
servicePool.name,
servicepool.name,
)
continue
# If it has any running publication, do not generate cache anymore
if servicePool.publications.filter(state=State.PREPARING).count() > 0:
if servicepool.publications.filter(state=State.PREPARING).count() > 0:
logger.debug(
'Skipping cache generation for service pool with publication running: %s',
servicePool.name,
servicepool.name,
)
continue
if servicePool.is_restrained():
if servicepool.is_restrained():
logger.debug(
'StopSkippingped cache generation for restrained service pool: %s',
servicePool.name,
servicepool.name,
)
ServiceCacheUpdater._notify_restrain(servicePool)
ServiceCacheUpdater._notify_restrain(servicepool)
continue
# Get data related to actual state of cache
# Before we were removing the elements marked to be destroyed after creation, but this makes us
# to create new items over the limit stablisshed, so we will not remove them anymore
inCacheL1: int = (
servicePool.cached_users_services()
.filter(UserServiceManager().get_cache_state_filter(servicePool, services.UserService.L1_CACHE))
l1_cache_count: int = (
servicepool.cached_users_services()
.filter(UserServiceManager().get_cache_state_filter(servicepool, services.UserService.L1_CACHE))
.count()
)
inCacheL2: int = (
l2_cache_count: int = (
(
servicePool.cached_users_services()
servicepool.cached_users_services()
.filter(
UserServiceManager().get_cache_state_filter(servicePool, services.UserService.L2_CACHE)
UserServiceManager().get_cache_state_filter(servicepool, services.UserService.L2_CACHE)
)
.count()
)
if spServiceInstance.uses_cache_l2
if service_instance.uses_cache_l2
else 0
)
inAssigned: int = (
servicePool.assigned_user_services()
.filter(UserServiceManager().get_state_filter(servicePool.service)) # type: ignore
assigned_count: int = (
servicepool.assigned_user_services()
.filter(UserServiceManager().get_state_filter(servicepool.service)) # type: ignore
.count()
)
pool_stat = ServicePoolStats(servicepool, l1_cache_count, l2_cache_count, assigned_count)
# if we bypasses max cache, we will reduce it in first place. This is so because this will free resources on service provider
logger.debug(
"Examining %s with %s in cache L1 and %s in cache L2, %s inAssigned",
servicePool.name,
inCacheL1,
inCacheL2,
inAssigned,
servicepool.name,
l1_cache_count,
l2_cache_count,
assigned_count,
)
totalL1Assigned = inCacheL1 + inAssigned
l1_assigned_count = l1_cache_count + assigned_count
# We have more than we want
if totalL1Assigned > servicePool.max_srvs:
logger.debug('We have more services than max configured. skipping.')
servicesPools.append((servicePool, inCacheL1, inCacheL2, inAssigned))
continue
# We have more in L1 cache than needed
if totalL1Assigned > servicePool.initial_srvs and inCacheL1 > servicePool.cache_l1_srvs:
logger.debug('We have more services in cache L1 than configured, appending')
servicesPools.append((servicePool, inCacheL1, inCacheL2, inAssigned))
if pool_stat.l1_cache_overflow():
logger.debug('We have more services than max configured. Reducing..')
servicepools_numbers.append(
ServicePoolStats(servicepool, l1_cache_count, l2_cache_count, assigned_count)
)
continue
# If we have more in L2 cache than needed, decrease L2 cache, but int this case, we continue checking cause L2 cache removal
# has less priority than l1 creations or removals, but higher. In this case, we will simply take last l2 oversized found and reduce it
if spServiceInstance.uses_cache_l2 and inCacheL2 > servicePool.cache_l2_srvs:
logger.debug('We have more services in L2 cache than configured, appending')
servicesPools.append((servicePool, inCacheL1, inCacheL2, inAssigned))
if pool_stat.l2_cache_overflow():
logger.debug('We have more services in L2 cache than configured, reducing')
servicepools_numbers.append(
ServicePoolStats(servicepool, l1_cache_count, l2_cache_count, assigned_count)
)
continue
# If this service don't allows more starting user services, continue
if not UserServiceManager().can_grow_service_pool(servicePool):
if not UserServiceManager().can_grow_service_pool(servicepool):
logger.debug(
'This pool cannot grow rithg now: %s',
servicePool,
servicepool,
)
continue
if pool_stat.l1_cache_needed():
logger.debug('Needs to grow L1 cache for %s', servicepool)
servicepools_numbers.append(
ServicePoolStats(servicepool, l1_cache_count, l2_cache_count, assigned_count)
)
continue
if pool_stat.l2_cache_needed():
logger.debug('Needs to grow L2 cache for %s', servicepool)
servicepools_numbers.append(
ServicePoolStats(servicepool, l1_cache_count, l2_cache_count, assigned_count)
)
continue
# If wee need to grow l2 cache, annotate it
# Whe check this before checking the total, because the l2 cache is independent of max services or l1 cache.
# It reflects a value that must be keeped in cache for futre fast use.
if inCacheL2 < servicePool.cache_l2_srvs:
logger.debug('Needs to grow L2 cache for %s', servicePool)
servicesPools.append((servicePool, inCacheL1, inCacheL2, inAssigned))
continue
# We skip it if already at max
if totalL1Assigned == servicePool.max_srvs:
continue
if totalL1Assigned < servicePool.initial_srvs or inCacheL1 < servicePool.cache_l1_srvs:
logger.debug('Needs to grow L1 cache for %s', servicePool)
servicesPools.append((servicePool, inCacheL1, inCacheL2, inAssigned))
# We also return calculated values so we can reuse then
return servicesPools
return servicepools_numbers
def grow_l1_cache(
self,
servicePool: ServicePool,
cacheL1: int, # pylint: disable=unused-argument
cacheL2: int,
assigned: int, # pylint: disable=unused-argument
servicepool_stats: ServicePoolStats,
) -> None:
"""
This method tries to enlarge L1 cache.
@ -216,16 +267,18 @@ class ServiceCacheUpdater(Job):
and PREPARING, assigned, L1 and L2) is over max allowed service deployments,
this method will not grow the L1 cache
"""
logger.debug('Growing L1 cache creating a new service for %s', servicePool.name)
logger.debug('Growing L1 cache creating a new service for %s', servicepool_stats.servicepool.name)
# First, we try to assign from L2 cache
if cacheL2 > 0:
if servicepool_stats.l2_cache_count > 0:
valid = None
with transaction.atomic():
for n in (
servicePool.cached_users_services()
servicepool_stats.servicepool.cached_users_services()
.select_for_update()
.filter(
UserServiceManager().get_cache_state_filter(servicePool, services.UserService.L2_CACHE)
UserServiceManager().get_cache_state_filter(
servicepool_stats.servicepool, services.UserService.L2_CACHE
)
)
.order_by('creation_date')
):
@ -246,30 +299,27 @@ class ServiceCacheUpdater(Job):
try:
# This has a velid publication, or it will not be here
UserServiceManager().create_cache_for(
typing.cast(ServicePoolPublication, servicePool.active_publication()),
typing.cast(ServicePoolPublication, servicepool_stats.servicepool.active_publication()),
services.UserService.L1_CACHE,
)
except MaxServicesReachedError:
log.log(
servicePool,
servicepool_stats.servicepool,
log.LogLevel.ERROR,
'Max number of services reached for this service',
log.LogSource.INTERNAL,
)
logger.warning(
'Max user services reached for %s: %s. Cache not created',
servicePool.name,
servicePool.max_srvs,
servicepool_stats.servicepool.name,
servicepool_stats.servicepool.max_srvs,
)
except Exception:
logger.exception('Exception')
def grow_l2_cache(
self,
servicePool: ServicePool,
cacheL1: int, # pylint: disable=unused-argument
cacheL2: int, # pylint: disable=unused-argument
assigned: int, # pylint: disable=unused-argument
servicepool_stats: ServicePoolStats,
) -> None:
"""
Tries to grow L2 cache of service.
@ -278,35 +328,36 @@ class ServiceCacheUpdater(Job):
and PREPARING, assigned, L1 and L2) is over max allowed service deployments,
this method will not grow the L1 cache
"""
logger.debug("Growing L2 cache creating a new service for %s", servicePool.name)
logger.debug("Growing L2 cache creating a new service for %s", servicepool_stats.servicepool.name)
try:
# This has a velid publication, or it will not be here
UserServiceManager().create_cache_for(
typing.cast(ServicePoolPublication, servicePool.active_publication()),
typing.cast(ServicePoolPublication, servicepool_stats.servicepool.active_publication()),
services.UserService.L2_CACHE,
)
except MaxServicesReachedError:
logger.warning(
'Max user services reached for %s: %s. Cache not created',
servicePool.name,
servicePool.max_srvs,
servicepool_stats.servicepool.name,
servicepool_stats.servicepool.max_srvs,
)
# TODO: When alerts are ready, notify this
def reduce_l1_cache(
self,
servicePool: ServicePool,
cacheL1: int, # pylint: disable=unused-argument
cacheL2: int,
assigned: int, # pylint: disable=unused-argument
):
logger.debug("Reducing L1 cache erasing a service in cache for %s", servicePool)
# We will try to destroy the newest cacheL1 element that is USABLE if the deployer can't cancel a new service creation
servicepool_stats: ServicePoolStats,
) -> None:
logger.debug("Reducing L1 cache erasing a service in cache for %s", servicepool_stats.servicepool)
# We will try to destroy the newest l1_cache_count element that is USABLE if the deployer can't cancel a new service creation
# Here, we will take into account the "remove_after" marked user services, so we don't try to remove them
cacheItems: list[UserService] = [
i
for i in servicePool.cached_users_services()
.filter(UserServiceManager().get_cache_state_filter(servicePool, services.UserService.L1_CACHE))
for i in servicepool_stats.servicepool.cached_users_services()
.filter(
UserServiceManager().get_cache_state_filter(
servicepool_stats.servicepool, services.UserService.L1_CACHE
)
)
.order_by('-creation_date')
.iterator()
if not i.destroy_after
@ -318,7 +369,7 @@ class ServiceCacheUpdater(Job):
)
return
if cacheL2 < servicePool.cache_l2_srvs:
if servicepool_stats.l2_cache_count < servicepool_stats.servicepool.cache_l2_srvs:
valid = None
for n in cacheItems:
if n.needsOsManager():
@ -338,17 +389,16 @@ class ServiceCacheUpdater(Job):
def reduce_l2_cache(
self,
service_pool: ServicePool,
cacheL1: int, # pylint: disable=unused-argument
cacheL2: int,
assigned: int, # pylint: disable=unused-argument
servicepool_stats: ServicePoolStats,
):
logger.debug("Reducing L2 cache erasing a service in cache for %s", service_pool.name)
if cacheL2 > 0:
logger.debug("Reducing L2 cache erasing a service in cache for %s", servicepool_stats.servicepool.name)
if servicepool_stats.l2_cache_count > 0:
cacheItems = (
service_pool.cached_users_services()
servicepool_stats.servicepool.cached_users_services()
.filter(
UserServiceManager().get_cache_state_filter(service_pool, services.UserService.L2_CACHE)
UserServiceManager().get_cache_state_filter(
servicepool_stats.servicepool, services.UserService.L2_CACHE
)
)
.order_by('creation_date')
)
@ -359,29 +409,18 @@ class ServiceCacheUpdater(Job):
def run(self) -> None:
logger.debug('Starting cache checking')
# We need to get
servicesThatNeedsUpdate = self.service_pools_needing_cache_update()
for servicePool, cacheL1, cacheL2, assigned in servicesThatNeedsUpdate:
for servicepool_stat in self.service_pools_needing_cache_update():
# We have cache to update??
logger.debug("Updating cache for %s", servicePool)
totalL1Assigned = cacheL1 + assigned
logger.debug("Updating cache for %s", servicepool_stat)
# We try first to reduce cache before tring to increase it.
# This means that if there is excesive number of user deployments
# for L1 or L2 cache, this will be reduced untill they have good numbers.
# This is so because service can have limited the number of services and,
# if we try to increase cache before having reduced whatever needed
# first, the service will get lock until someone removes something.
if totalL1Assigned > servicePool.max_srvs:
self.reduce_l1_cache(servicePool, cacheL1, cacheL2, assigned)
elif totalL1Assigned > servicePool.initial_srvs and cacheL1 > servicePool.cache_l1_srvs:
self.reduce_l1_cache(servicePool, cacheL1, cacheL2, assigned)
elif cacheL2 > servicePool.cache_l2_srvs: # We have excesives L2 items
self.reduce_l2_cache(servicePool, cacheL1, cacheL2, assigned)
elif totalL1Assigned < servicePool.max_srvs and (
totalL1Assigned < servicePool.initial_srvs or cacheL1 < servicePool.cache_l1_srvs
): # We need more services
self.grow_l1_cache(servicePool, cacheL1, cacheL2, assigned)
elif cacheL2 < servicePool.cache_l2_srvs: # We need more L2 items
self.grow_l2_cache(servicePool, cacheL1, cacheL2, assigned)
else:
logger.warning("We have more services than max requested for %s", servicePool.name)
# Treat l1 and l2 cache independently
# first, try to reduce cache and then grow it
if servicepool_stat.l1_cache_overflow():
self.reduce_l1_cache(servicepool_stat)
elif servicepool_stat.l1_cache_needed(): # We need more L1 items
self.grow_l1_cache(servicepool_stat)
# Treat l1 and l2 cache independently
if servicepool_stat.l2_cache_overflow():
self.reduce_l2_cache(servicepool_stat)
elif servicepool_stat.l2_cache_needed(): # We need more L2 items
self.grow_l2_cache(servicepool_stat)

View File

@ -278,8 +278,8 @@ class Command(BaseCommand):
# os managers
osManagers: dict[str, typing.Any] = {}
for osManager in models.OSManager.objects.all():
osManagers[osManager.name] = get_serialized_from_managed_object(osManager)
for osmanager in models.OSManager.objects.all():
osManagers[osmanager.name] = get_serialized_from_managed_object(osmanager)
tree[counter('OSMANAGERS')] = osManagers

View File

@ -42,7 +42,7 @@ from .provider import Provider
from .service import Service, ServiceTokenAlias
# Os managers
from .os_manager import OSManager
from .osmanager import OSManager
# Transports
from .transport import Transport

View File

@ -47,7 +47,7 @@ from uds.core.util.model import sql_datetime
from .account import Account
from .group import Group
from .image import Image
from .os_manager import OSManager
from .osmanager import OSManager
from .service import Service
from .service_pool_group import ServicePoolGroup
from .tag import TaggingMixin
@ -642,7 +642,7 @@ class ServicePool(UUIDModel, TaggingMixin): # type: ignore
"""
maxs = self.max_srvs
if maxs == 0 and self.service:
maxs = self.service.get_instance().max_user_services
maxs = self.service.get_instance().userservices_limit
if cachedValue == -1:
cachedValue = (

View File

@ -36,6 +36,7 @@ import typing
import collections.abc
from django.db import models
import public
from uds.core.managers import publication_manager
from uds.core.types.states import State
@ -132,32 +133,36 @@ class ServicePoolPublication(UUIDModel):
"""
if not self.deployed_service.service:
raise Exception('No service assigned to publication')
serviceInstance = self.deployed_service.service.get_instance()
osManager = self.deployed_service.osmanager
osManagerInstance = osManager.get_instance() if osManager else None
service_instance = self.deployed_service.service.get_instance()
osmanager = self.deployed_service.osmanager
osmanager_instance = osmanager.get_instance() if osmanager else None
# Sanity check, so it's easier to find when we have created
# a service that needs publication but do not have
if serviceInstance.publication_type is None:
if service_instance.publication_type is None:
raise Exception(
f'Class {serviceInstance.__class__.__name__} do not have defined publication_type but needs to be published!!!'
f'Class {service_instance.__class__.__name__} do not have defined publication_type but needs to be published!!!'
)
publication = serviceInstance.publication_type(
publication = service_instance.publication_type(
self.get_environment(),
service=serviceInstance,
osManager=osManagerInstance,
service=service_instance,
osmanager=osmanager_instance,
revision=self.revision,
dsName=self.deployed_service.name,
servicepool_name=self.deployed_service.name,
uuid=self.uuid,
)
# Only invokes deserialization if data has something. '' is nothing
if self.data:
publication.deserialize(self.data)
if publication.needs_upgrade():
self.update_data(publication)
publication.flag_for_upgrade(False)
return publication
def update_data(self, publication):
def update_data(self, publication_instance: 'services.Publication') -> None:
"""
Updates the data field with the serialized uds.core.services.Publication
@ -166,7 +171,7 @@ class ServicePoolPublication(UUIDModel):
:note: This method do not saves the updated record, just updates the field
"""
self.data = publication.serialize()
self.data = publication_instance.serialize()
self.save(update_fields=['data'])
def set_state(self, state: str) -> None:

View File

@ -146,7 +146,10 @@ class UserService(UUIDModel, properties.PropertiesMixin):
"""
Returns True if this service is to be removed
"""
return self.properties.get('destroy_after', False) in ('y', True) # Compare to str to keep compatibility with old values
return self.properties.get('destroy_after', False) in (
'y',
True,
) # Compare to str to keep compatibility with old values
@destroy_after.setter
def destroy_after(self, value: bool) -> None:
@ -201,19 +204,19 @@ class UserService(UUIDModel, properties.PropertiesMixin):
Raises:
"""
# We get the service instance, publication instance and osmanager instance
servicePool = self.deployed_service
if not servicePool.service:
servicepool = self.deployed_service
if not servicepool.service:
raise Exception('Service not found')
serviceInstance = servicePool.service.get_instance()
if serviceInstance.needs_manager is False or not servicePool.osmanager:
osmanagerInstance = None
service_instance = servicepool.service.get_instance()
if service_instance.needs_manager is False or not servicepool.osmanager:
osmanager_instance = None
else:
osmanagerInstance = servicePool.osmanager.get_instance()
osmanager_instance = servicepool.osmanager.get_instance()
# We get active publication
publicationInstance = None
publication_instance = None
try: # We may have deleted publication...
if self.publication is not None:
publicationInstance = self.publication.get_instance()
publication_instance = self.publication.get_instance()
except Exception:
# The publication to witch this item points to, does not exists
self.publication = None # type: ignore
@ -221,20 +224,29 @@ class UserService(UUIDModel, properties.PropertiesMixin):
'Got exception at get_instance of an userService %s (seems that publication does not exists!)',
self,
)
if serviceInstance.user_service_type is None:
if service_instance.user_service_type is None:
raise Exception(
f'Class {serviceInstance.__class__.__name__} needs user_service_type but it is not defined!!!'
f'Class {service_instance.__class__.__name__} needs user_service_type but it is not defined!!!'
)
us = serviceInstance.user_service_type(
us = service_instance.user_service_type(
self.get_environment(),
service=serviceInstance,
publication=publicationInstance,
osmanager=osmanagerInstance,
service=service_instance,
publication=publication_instance,
osmanager=osmanager_instance,
uuid=self.uuid,
)
if self.data != '' and self.data is not None:
if self.data:
try:
us.deserialize(self.data)
# if needs upgrade, we will serialize it again to ensure it is upgraded ASAP
# Eventually, it will be upgraded anyway, but could take too much time...
# This way, if we instantiate it, it will be upgraded
if us.needs_upgrade():
self.data = us.serialize()
self.save(update_fields=['data'])
us.flag_for_upgrade(False)
except Exception:
logger.exception(
'Error unserializing %s//%s : %s',
@ -244,7 +256,7 @@ class UserService(UUIDModel, properties.PropertiesMixin):
)
return us
def update_data(self, userServiceInstance: 'services.UserService'):
def update_data(self, userservice_instance: 'services.UserService'):
"""
Updates the data field with the serialized :py:class:uds.core.services.UserDeployment
@ -253,7 +265,7 @@ class UserService(UUIDModel, properties.PropertiesMixin):
:note: This method SAVES the updated record, just updates the field
"""
self.data = userServiceInstance.serialize()
self.data = userservice_instance.serialize()
self.save(update_fields=['data'])
def get_name(self) -> str:
@ -347,9 +359,9 @@ class UserService(UUIDModel, properties.PropertiesMixin):
return self.deployed_service.osmanager
def get_osmanager_instance(self) -> typing.Optional['osmanagers.OSManager']:
osManager = self.getOsManager()
if osManager:
return osManager.get_instance()
osmanager = self.getOsManager()
if osmanager:
return osmanager.get_instance()
return None
def needsOsManager(self) -> bool:
@ -544,6 +556,7 @@ class UserService(UUIDModel, properties.PropertiesMixin):
# pylint: disable=import-outside-toplevel
from uds.core.managers.user_service import UserServiceManager
# Cancel is a "forced" operation, so they are not checked against limits
UserServiceManager().cancel(self)
def remove_or_cancel(self) -> None:
@ -597,7 +610,7 @@ class UserService(UUIDModel, properties.PropertiesMixin):
@property
def actor_version(self) -> str:
return self.properties.get('actor_version') or '0.0.0'
@actor_version.setter
def actor_version(self, version: str) -> None:
self.properties['actor_version'] = version

View File

@ -144,7 +144,7 @@ class OGService(services.Service):
),
)
maxServices = gui.NumericField(
services_limit = gui.NumericField(
order=4,
label=_("Max. Allowed services"),
min_value=0,
@ -153,7 +153,8 @@ class OGService(services.Service):
readonly=False,
tooltip=_('Maximum number of allowed services (0 or less means no limit)'),
required=True,
tab=types.ui.Tab.ADVANCED
tab=types.ui.Tab.ADVANCED,
old_field_name='maxServices',
)
ov = gui.HiddenField(value=None)

View File

@ -239,7 +239,7 @@ class IPMachinesService(IPServiceBase):
self._useRandomIp = gui.as_bool(values[6].decode())
# Sets maximum services for this
self.max_user_services = len(self._ips)
self.userservices_limit = len(self._ips)
def canBeUsed(self, locked: typing.Optional[typing.Union[str, int]], now: int) -> int:
# If _maxSessionForMachine is 0, it can be used only if not locked

View File

@ -95,7 +95,7 @@ class TestProvider(services.ServiceProvider):
self.data.name = ''.join(random.SystemRandom().choices(string.ascii_letters, k=10))
self.data.integer = random.randint(0, 100)
return super().initialize(values)
@staticmethod
def test(
env: 'Environment', data: dict[str, str]

View File

@ -1,13 +1,34 @@
# -*- coding: utf-8 -*-
# pylint: disable=no-member # ldap module gives errors to pylint
#
# Copyright (c) 2022 Virtual Cable S.L.U.
# Copyright (c) 2024 Virtual Cable S.L.U.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of Virtual Cable S.L.U. nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
'''
Author: Adolfo Gómez, dkmaster at dkmon dot com
"""
'''
import typing
# We use commit/rollback
@ -55,7 +76,7 @@ SERIALIZED_AUTH_DATA: typing.Final[typing.Mapping[str, bytes]] = {
class RegexSerializationTest(UDSTestCase):
def check_provider(self, version: str, instance: 'authenticator.RegexLdap'):
def check_provider(self, version: str, instance: 'authenticator.RegexLdap') -> None:
self.assertEqual(instance.host.as_str(), 'host')
self.assertEqual(instance.port.as_int(), 166)
self.assertEqual(instance.use_ssl.as_bool(), True)
@ -69,32 +90,32 @@ class RegexSerializationTest(UDSTestCase):
if version >= 'v2':
self.assertEqual(instance.username_attr.as_str(), 'usernattr')
def test_unmarshall_all_versions(self):
def test_unmarshall_all_versions(self) -> None:
for v in range(1, len(SERIALIZED_AUTH_DATA) + 1):
instance = authenticator.RegexLdap(environment=Environment.get_temporary_environment())
instance.unmarshal(SERIALIZED_AUTH_DATA['v{}'.format(v)])
self.check_provider(f'v{v}', instance)
def test_marshaling(self):
def test_marshaling(self) -> None:
# Unmarshall last version, remarshall and check that is marshalled using new marshalling format
LAST_VERSION = 'v{}'.format(len(SERIALIZED_AUTH_DATA))
instance = authenticator.RegexLdap(
environment=Environment.get_temporary_environment()
)
instance.unmarshal(SERIALIZED_AUTH_DATA[LAST_VERSION])
marshalled_data = instance.marshal()
marshaled_data = instance.marshal()
# Ensure remarshalled flag is set
self.assertTrue(instance.needs_upgrade())
instance.flag_for_upgrade(False) # reset flag
# Ensure fields has been marshalled using new format
self.assertFalse(marshalled_data.startswith(b'v'))
self.assertFalse(marshaled_data.startswith(b'v'))
# Reunmarshall again and check that remarshalled flag is not set
instance = authenticator.RegexLdap(
environment=Environment.get_temporary_environment()
)
instance.unmarshal(marshalled_data)
instance.unmarshal(marshaled_data)
self.assertFalse(instance.needs_upgrade())
# Check that data is correct

View File

@ -0,0 +1,124 @@
# pylint: disable=no-member # ldap module gives errors to pylint
#
# Copyright (c) 2024 Virtual Cable S.L.U.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of Virtual Cable S.L.U. nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
'''
Author: Adolfo Gómez, dkmaster at dkmon dot com
'''
import typing
# We use commit/rollback
from tests.utils.test import UDSTestCase
from uds.core import ui
from uds.core.ui.user_interface import gui, UDSB, UDSK
from uds.core.environment import Environment
from uds.core.managers import crypto
from django.conf import settings
from uds.auths.SimpleLDAP import authenticator
PASSWD: typing.Final[str] = 'PASSWD'
# v1:
# self.host.value,
# self.port.value,
# self.use_ssl.value,
# self.username.value,
# self.password.value,
# self.timeout.value,
# self.ldap_base.value,
# self.user_class.value,
# self.userid_attr.value,
# self.groupname_attr.value,
# v2:
# self.username_attr.value = vals[11]
# v3:
# self.alternate_class.value = vals[12]
# v4:
# self.mfa_attribute.value = vals[13]
# v5:
# self.verify_ssl.value = vals[14]
# self.certificate.value = vals[15]
SERIALIZED_AUTH_DATA: typing.Final[typing.Mapping[str, bytes]] = {
'v1': b'v1\thost\t166\t1\tuame\t' + PASSWD.encode('utf8') + b'\t99\tdc=dom,dc=m\tuclass\tuseridAttr\tgroup_attr\t\tusernattr',
'v2': b'v2\thost\t166\t1\tuame\t' + PASSWD.encode('utf8') + b'\t99\tdc=dom,dc=m\tuclass\tuseridAttr\tgroup_attr\tusernattr',
'v3': b'v3\thost\t166\t1\tuame\t' + PASSWD.encode('utf8') + b'\t99\tdc=dom,dc=m\tuclass\tuseridAttr\tgroup_attr\tusernattr\taltClass',
'v4': b'v4\thost\t166\t1\tuame\t' + PASSWD.encode('utf8') + b'\t99\tdc=dom,dc=m\tuclass\tuseridAttr\tgroup_attr\tusernattr\taltClass\tmfa',
'v5': b'v5\thost\t166\t1\tuame\t' + PASSWD.encode('utf8') + b'\t99\tdc=dom,dc=m\tuclass\tuseridAttr\tgroup_attr\tusernattr\taltClass\tmfa\tTRUE\tcert',
}
class RegexSerializationTest(UDSTestCase):
def check_provider(self, version: str, instance: 'authenticator.SimpleLDAPAuthenticator'):
self.assertEqual(instance.host.as_str(), 'host')
self.assertEqual(instance.port.as_int(), 166)
self.assertEqual(instance.use_ssl.as_bool(), True)
self.assertEqual(instance.username.as_str(), 'uame')
self.assertEqual(instance.password.as_str(), PASSWD)
self.assertEqual(instance.timeout.as_int(), 99)
self.assertEqual(instance.ldap_base.as_str(), 'dc=dom,dc=m')
self.assertEqual(instance.user_class.as_str(), 'uclass')
self.assertEqual(instance.userid_attr.as_str(), 'useridAttr')
self.assertEqual(instance.groupname_attr.as_str(), 'group_attr')
if version >= 'v2':
self.assertEqual(instance.username_attr.as_str(), 'usernattr')
def test_unmarshall_all_versions(self):
return
for v in range(1, len(SERIALIZED_AUTH_DATA) + 1):
instance = authenticator.SimpleLDAPAuthenticator(environment=Environment.get_temporary_environment())
instance.unmarshal(SERIALIZED_AUTH_DATA['v{}'.format(v)])
self.check_provider(f'v{v}', instance)
def test_marshaling(self):
return
# Unmarshall last version, remarshall and check that is marshalled using new marshalling format
LAST_VERSION = 'v{}'.format(len(SERIALIZED_AUTH_DATA))
instance = authenticator.SimpleLDAPAuthenticator(
environment=Environment.get_temporary_environment()
)
instance.unmarshal(SERIALIZED_AUTH_DATA[LAST_VERSION])
marshaled_data = instance.marshal()
# Ensure remarshalled flag is set
self.assertTrue(instance.needs_upgrade())
instance.flag_for_upgrade(False) # reset flag
# Ensure fields has been marshalled using new format
self.assertFalse(marshaled_data.startswith(b'v'))
# Reunmarshall again and check that remarshalled flag is not set
instance = authenticator.SimpleLDAPAuthenticator(
environment=Environment.get_temporary_environment()
)
instance.unmarshal(marshaled_data)
self.assertFalse(instance.needs_upgrade())
# Check that data is correct
self.check_provider(LAST_VERSION, instance)

View File

@ -56,15 +56,15 @@ class ServiceCacheUpdaterTest(UDSTestCase):
# Default values for max
TestProvider.concurrent_creation_limit = 1000
TestProvider.concurrent_removal_limit = 1000
TestServiceCache.max_user_services = 1000
TestServiceNoCache.max_user_services = 1000
TestServiceCache.userservices_limit = 1000
TestServiceNoCache.userservices_limit = 1000
ServiceCacheUpdater.setup()
userService = services_fixtures.create_cache_testing_userservices()[0]
self.servicePool = userService.deployed_service
userService.delete() # empty all
def numberOfRemovingOrCanced(self) -> int:
def removing_or_canceled_count(self) -> int:
return self.servicePool.userServices.filter(
state__in=[State.REMOVABLE, State.CANCELED]
).count()
@ -74,7 +74,7 @@ class ServiceCacheUpdaterTest(UDSTestCase):
updater = ServiceCacheUpdater(Environment.get_temporary_environment())
updater.run()
# Test user service will cancel automatically so it will not get in "removable" state (on remove start, it will tell it has been removed)
return self.servicePool.userServices.count() - self.numberOfRemovingOrCanced()
return self.servicePool.userServices.count() - self.removing_or_canceled_count()
def setCache(
self,
@ -113,10 +113,10 @@ class ServiceCacheUpdaterTest(UDSTestCase):
self.setCache(cache=10)
self.assertEqual(
self.runCacheUpdater(mustDelete), self.servicePool.initial_srvs
self.runCacheUpdater(mustDelete*2), self.servicePool.initial_srvs
)
self.assertEqual(self.numberOfRemovingOrCanced(), mustDelete)
self.assertEqual(self.removing_or_canceled_count(), mustDelete)
def test_max(self) -> None:
self.setCache(initial=100, cache=10, max=50)
@ -154,8 +154,9 @@ class ServiceCacheUpdaterTest(UDSTestCase):
TestProvider.concurrent_creation_limit = 0
self.assertEqual(self.runCacheUpdater(self.servicePool.cache_l1_srvs + 10), 1)
def test_provider_removing_limits(self) -> None:
TestProvider.concurrent_removal_limit = 10
def test_provider_no_removing_limits(self) -> None:
# Removing limits are appliend in fact when EXECUTING removal, not when marking as removable
# Note that "cancel" also overpass this limit
self.setCache(initial=0, cache=50, max=50)
# Try to "overcreate" cache elements but provider limits it to 10
@ -164,21 +165,22 @@ class ServiceCacheUpdaterTest(UDSTestCase):
# Now set cache to a lower value
self.setCache(cache=10)
# Execute updater, must remove 10 elements (concurrent_removal_limit)
self.assertEqual(self.runCacheUpdater(10), 40)
# Execute updater, must remove as long as runs elements (we use cancle here, so it will be removed)
# removes until 10, that is the limit due to cache
self.assertEqual(self.runCacheUpdater(50), 10)
def test_service_max_deployed(self) -> None:
TestServiceCache.max_user_services = 22
TestServiceCache.userservices_limit = 22
self.setCache(initial=100, cache=100, max=50)
# Try to "overcreate" cache elements but provider limits it to 10
self.assertEqual(self.runCacheUpdater(self.servicePool.cache_l1_srvs + 10), TestServiceCache.max_user_services)
self.assertEqual(self.runCacheUpdater(self.servicePool.cache_l1_srvs + 10), TestServiceCache.userservices_limit)
# Delete all userServices
self.servicePool.userServices.all().delete()
# We again allow masUserServices to be zero (meaning that no service will be created)
# This allows us to "honor" some external providers that, in some cases, will not have services available...
TestServiceCache.max_user_services = 0
TestServiceCache.userservices_limit = 0
self.assertEqual(self.runCacheUpdater(self.servicePool.cache_l1_srvs + 10), 0)