1
0
mirror of https://github.com/dkmstr/openuds.git synced 2025-03-12 04:58:34 +03:00

Improving environment class

This commit is contained in:
Adolfo Gómez García 2024-01-29 02:52:27 +01:00
parent b42f93839f
commit 33cdf27375
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
35 changed files with 355 additions and 304 deletions

View File

@ -106,7 +106,7 @@ class Authenticators(ModelHandler):
authType = auths.factory().lookup(type_)
if authType:
# Create a new instance of the authenticator to access to its GUI
with Environment.get_unique_environment() as env:
with Environment.temporary_environment() as env:
authInstance = authType(env, None)
field = self.add_default_fields(
authInstance.gui_description(),
@ -234,10 +234,11 @@ class Authenticators(ModelHandler):
dct = self._params.copy()
dct['_request'] = self._request
res = authType.test(Environment.get_temporary_environment(), dct)
if res[0]:
return self.success()
return res[1]
with Environment.temporary_environment() as env:
res = authType.test(env, dct)
if res[0]:
return self.success()
return res[1]
def pre_save(
self, fields: dict[str, typing.Any]

View File

@ -73,35 +73,36 @@ class MFA(ModelHandler):
raise self.invalid_item_response()
# Create a temporal instance to get the gui
mfa = mfaType(Environment.get_temporary_environment(), None)
with Environment.temporary_environment() as env:
mfa = mfaType(env, None)
localGui = self.add_default_fields(mfa.gui_description(), ['name', 'comments', 'tags'])
self.add_field(
localGui,
{
'name': 'remember_device',
'value': '0',
'min_value': '0',
'label': gettext('Device Caching'),
'tooltip': gettext('Time in hours to cache device so MFA is not required again. User based.'),
'type': types.ui.FieldType.NUMERIC,
'order': 111,
},
)
self.add_field(
localGui,
{
'name': 'validity',
'value': '5',
'min_value': '0',
'label': gettext('MFA code validity'),
'tooltip': gettext('Time in minutes to allow MFA code to be used.'),
'type': types.ui.FieldType.NUMERIC,
'order': 112,
},
)
localGui = self.add_default_fields(mfa.gui_description(), ['name', 'comments', 'tags'])
self.add_field(
localGui,
{
'name': 'remember_device',
'value': '0',
'min_value': '0',
'label': gettext('Device Caching'),
'tooltip': gettext('Time in hours to cache device so MFA is not required again. User based.'),
'type': types.ui.FieldType.NUMERIC,
'order': 111,
},
)
self.add_field(
localGui,
{
'name': 'validity',
'value': '5',
'min_value': '0',
'label': gettext('MFA code validity'),
'tooltip': gettext('Time in minutes to allow MFA code to be used.'),
'type': types.ui.FieldType.NUMERIC,
'order': 112,
},
)
return localGui
return localGui
def item_as_dict(self, item: 'Model') -> dict[str, typing.Any]:
item = ensure.is_instance(item, models.MFA)

View File

@ -84,33 +84,34 @@ class Notifiers(ModelHandler):
if not notifierType:
raise self.invalid_item_response()
notifier = notifierType(Environment.get_temporary_environment(), None)
with Environment.temporary_environment() as env:
notifier = notifierType(env, None)
localGui = self.add_default_fields(
notifier.gui_description(), ['name', 'comments', 'tags']
)
local_gui = self.add_default_fields(
notifier.gui_description(), ['name', 'comments', 'tags']
)
for field in [
{
'name': 'level',
'choices': [gui.choice_item(i[0], i[1]) for i in LogLevel.interesting()],
'label': gettext('Level'),
'tooltip': gettext('Level of notifications'),
'type': types.ui.FieldType.CHOICE,
'order': 102,
},
{
'name': 'enabled',
'label': gettext('Enabled'),
'tooltip': gettext('If checked, this notifier will be used'),
'type': types.ui.FieldType.CHECKBOX,
'order': 103,
'default': True,
}
]:
self.add_field(localGui, field)
for field in [
{
'name': 'level',
'choices': [gui.choice_item(i[0], i[1]) for i in LogLevel.interesting()],
'label': gettext('Level'),
'tooltip': gettext('Level of notifications'),
'type': types.ui.FieldType.CHOICE,
'order': 102,
},
{
'name': 'enabled',
'label': gettext('Enabled'),
'tooltip': gettext('If checked, this notifier will be used'),
'type': types.ui.FieldType.CHECKBOX,
'order': 103,
'default': True,
}
]:
self.add_field(local_gui, field)
return localGui
return local_gui
def item_as_dict(self, item: 'Model') -> dict[str, typing.Any]:
item = ensure.is_instance(item, Notifier)

View File

@ -101,12 +101,12 @@ class OsManagers(ModelHandler):
if not osmanagerType:
raise exceptions.rest.NotFound('OS Manager type not found')
with Environment.temporary_environment() as env:
osmanager = osmanagerType(env, None)
osmanager = osmanagerType(Environment.get_temporary_environment(), None)
return self.add_default_fields(
osmanager.gui_description(), # type: ignore # may raise an exception if lookup fails
['name', 'comments', 'tags'],
)
return self.add_default_fields(
osmanager.gui_description(), # type: ignore # may raise an exception if lookup fails
['name', 'comments', 'tags'],
)
except:
raise exceptions.rest.NotFound('type not found')

View File

@ -124,8 +124,9 @@ class Providers(ModelHandler):
def get_gui(self, type_: str) -> list[typing.Any]:
providerType = services.factory().lookup(type_)
if providerType:
provider = providerType(Environment.get_temporary_environment(), None)
return self.add_default_fields(provider.gui_description(), ['name', 'comments', 'tags'])
with Environment.temporary_environment() as env:
provider = providerType(env, None)
return self.add_default_fields(provider.gui_description(), ['name', 'comments', 'tags'])
raise exceptions.rest.NotFound('Type not found!')
def allservices(self) -> typing.Generator[dict, None, None]:
@ -168,18 +169,18 @@ class Providers(ModelHandler):
from uds.core.environment import Environment
logger.debug('Type: %s', type_)
spType = services.factory().lookup(type_)
provider_type = services.factory().lookup(type_)
if not spType:
if not provider_type:
raise exceptions.rest.NotFound('Type not found!')
tmpEnvironment = Environment.get_temporary_environment()
logger.debug('spType: %s', spType)
with Environment.temporary_environment() as temp_environment:
logger.debug('spType: %s', provider_type)
dct = self._params.copy()
dct['_request'] = self._request
res = spType.test(tmpEnvironment, dct)
if res[0]:
return 'ok'
dct = self._params.copy()
dct['_request'] = self._request
res = provider_type.test(temp_environment, dct)
if res[0]:
return 'ok'
return res[1]
return res[1]

View File

@ -290,36 +290,36 @@ class Services(DetailHandler): # pylint: disable=too-many-public-methods
parent = ensure.is_instance(parent, models.Provider)
try:
logger.debug('getGui parameters: %s, %s', parent, forType)
parentInstance = parent.get_instance()
serviceType = parentInstance.get_service_by_type(forType)
if not serviceType:
parent_instance = parent.get_instance()
service_type = parent_instance.get_service_by_type(forType)
if not service_type:
raise self.invalid_item_response(f'Gui for {forType} not found')
with Environment.temporary_environment() as env:
service = service_type(
env, parent_instance
) # Instantiate it so it has the opportunity to alter gui description based on parent
local_gui = self.add_default_fields(
service.gui_description(), ['name', 'comments', 'tags']
)
self.add_field(
local_gui,
{
'name': 'max_services_count_type',
'choices': [
gui.choice_item('0', _('Standard')),
gui.choice_item('1', _('Conservative')),
],
'label': _('Service counting method'),
'tooltip': _(
'Kind of service counting for calculating if MAX is reached'
),
'type': types.ui.FieldType.CHOICE,
'readonly': False,
'order': 101,
},
)
service = serviceType(
Environment.get_temporary_environment(), parentInstance
) # Instantiate it so it has the opportunity to alter gui description based on parent
localGui = self.add_default_fields(
service.gui_description(), ['name', 'comments', 'tags']
)
self.add_field(
localGui,
{
'name': 'max_services_count_type',
'choices': [
gui.choice_item('0', _('Standard')),
gui.choice_item('1', _('Conservative')),
],
'label': _('Service counting method'),
'tooltip': _(
'Kind of service counting for calculating if MAX is reached'
),
'type': types.ui.FieldType.CHOICE,
'readonly': False,
'order': 101,
},
)
return localGui
return local_gui
except Exception as e:
logger.exception('getGui')

View File

@ -90,62 +90,63 @@ class Transports(ModelHandler):
if not transportType:
raise self.invalid_item_response()
transport = transportType(Environment.get_temporary_environment(), None)
with Environment.temporary_environment() as env:
transport = transportType(env, None)
field = self.add_default_fields(
transport.gui_description(), ['name', 'comments', 'tags', 'priority', 'networks']
)
field = self.add_field(
field,
{
'name': 'allowed_oss',
'value': [],
'choices': sorted(
[ui.gui.choice_item(x.name, x.name) for x in consts.os.KNOWN_OS_LIST],
key=lambda x: x['text'].lower(),
),
'label': gettext('Allowed Devices'),
'tooltip': gettext(
'If empty, any kind of device compatible with this transport will be allowed. Else, only devices compatible with selected values will be allowed'
),
'type': types.ui.FieldType.MULTICHOICE,
'tab': types.ui.Tab.ADVANCED,
'order': 102,
},
)
field = self.add_field(
field,
{
'name': 'pools',
'value': [],
'choices': [
ui.gui.choice_item(x.uuid, x.name)
for x in ServicePool.objects.filter(service__isnull=False)
.order_by('name')
.prefetch_related('service')
if transportType.protocol in x.service.get_type().allowed_protocols
],
'label': gettext('Service Pools'),
'tooltip': gettext('Currently assigned services pools'),
'type': types.ui.FieldType.MULTICHOICE,
'order': 103,
},
)
field = self.add_field(
field,
{
'name': 'label',
'length': 32,
'value': '',
'label': gettext('Label'),
'tooltip': gettext('Metapool transport label (only used on metapool transports grouping)'),
'type': types.ui.FieldType.TEXT,
'order': 201,
'tab': types.ui.Tab.ADVANCED,
},
)
field = self.add_default_fields(
transport.gui_description(), ['name', 'comments', 'tags', 'priority', 'networks']
)
field = self.add_field(
field,
{
'name': 'allowed_oss',
'value': [],
'choices': sorted(
[ui.gui.choice_item(x.name, x.name) for x in consts.os.KNOWN_OS_LIST],
key=lambda x: x['text'].lower(),
),
'label': gettext('Allowed Devices'),
'tooltip': gettext(
'If empty, any kind of device compatible with this transport will be allowed. Else, only devices compatible with selected values will be allowed'
),
'type': types.ui.FieldType.MULTICHOICE,
'tab': types.ui.Tab.ADVANCED,
'order': 102,
},
)
field = self.add_field(
field,
{
'name': 'pools',
'value': [],
'choices': [
ui.gui.choice_item(x.uuid, x.name)
for x in ServicePool.objects.filter(service__isnull=False)
.order_by('name')
.prefetch_related('service')
if transportType.protocol in x.service.get_type().allowed_protocols
],
'label': gettext('Service Pools'),
'tooltip': gettext('Currently assigned services pools'),
'type': types.ui.FieldType.MULTICHOICE,
'order': 103,
},
)
field = self.add_field(
field,
{
'name': 'label',
'length': 32,
'value': '',
'label': gettext('Label'),
'tooltip': gettext('Metapool transport label (only used on metapool transports grouping)'),
'type': types.ui.FieldType.TEXT,
'order': 201,
'tab': types.ui.Tab.ADVANCED,
},
)
return field
return field
def item_as_dict(self, item: 'Model') -> dict[str, typing.Any]:
item = ensure.is_instance(item, Transport)
@ -176,7 +177,9 @@ class Transports(ModelHandler):
fields['label'] = fields['label'].strip().replace(' ', '-')
# And ensure small_name chars are valid [ a-zA-Z0-9:-]+
if fields['label'] and not re.match(r'^[a-zA-Z0-9:-]+$', fields['label']):
raise self.invalid_request_response(gettext('Label must contain only letters, numbers, ":" and "-"'))
raise self.invalid_request_response(
gettext('Label must contain only letters, numbers, ":" and "-"')
)
def post_save(self, item: 'Model') -> None:
item = ensure.is_instance(item, Transport)

View File

@ -41,8 +41,8 @@ if typing.TYPE_CHECKING:
from uds.core.util.unique_id_generator import UniqueIDGenerator
TEMP_ENV = 'temporary'
GLOBAL_ENV = 'global'
TEST_ENV = 'testing_env'
COMMON_ENV = 'global'
class Environment:
@ -62,7 +62,7 @@ class Environment:
def __init__(
self,
uniqueKey: str,
unique_key: str,
id_generators: typing.Optional[dict[str, 'UniqueIDGenerator']] = None,
):
"""
@ -76,9 +76,9 @@ class Environment:
from uds.core.util.cache import Cache # pylint: disable=import-outside-toplevel
from uds.core.util.storage import Storage # pylint: disable=import-outside-toplevel
self._key = uniqueKey
self._cache = Cache(uniqueKey)
self._storage = Storage(uniqueKey)
self._key = unique_key
self._cache = Cache(unique_key)
self._storage = Storage(unique_key)
self._id_generators = id_generators or {}
@property
@ -126,7 +126,7 @@ class Environment:
v.release()
@staticmethod
def get_environment_for_table_record(
def environment_for_table_record(
table_name: str,
record_id: 'str|int|None' = None,
id_generator_types: typing.Optional[dict[str, typing.Any]] = None,
@ -153,7 +153,7 @@ class Environment:
return Environment(name, id_generators)
@staticmethod
def get_environment_for_type(type_) -> 'Environment':
def environment_for_type(type_) -> 'Environment':
"""
Obtains an environment associated with a type instead of a record
@param type_: Type
@ -162,7 +162,7 @@ class Environment:
return Environment('type-' + str(type_))
@staticmethod
def get_unique_environment() -> 'Environment':
def temporary_environment() -> 'Environment':
"""
Obtains an enviromnet with an unique identifier
@ -172,34 +172,33 @@ class Environment:
Note:
Use this with "with" statement to ensure that environment is cleared after use
"""
return Environment(secrets.token_hex(16))
return Environment(
'#_#' + secrets.token_hex(16) + '#^#'
) # Weird enough name to be unique, unless you try hard :)
@staticmethod
def get_temporary_environment() -> 'Environment':
def testing_environment() -> 'Environment':
"""
Provides a temporary environment needed in some calls (test provider, for example)
It will not make environment persistent
"""
env = Environment(TEMP_ENV)
env.storage.clear()
env.cache.clear()
env = Environment(TEST_ENV)
env.clean_related_data()
return env
@staticmethod
def get_common_environment() -> 'Environment':
def ommon_environment() -> 'Environment':
"""
Provides global environment
"""
return Environment(GLOBAL_ENV) # This environment is a global environment for general utility.
return Environment(COMMON_ENV) # This environment is a global environment for general utility.
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self._cache.clear()
self._storage.clear()
for _, v in self._id_generators.items():
v.release()
if self._key == TEST_ENV or (self._key.startswith('#_#') and self._key.endswith('#^#')):
self.clean_related_data()
class Environmentable:

View File

@ -50,7 +50,7 @@ class DelayedTask(Environmentable):
"""
Remember to invoke parent init in derived clases using super(myClass,self).__init__() to let this initialize its own variables
"""
super().__init__(environment or Environment.get_environment_for_type(self.__class__))
super().__init__(environment or Environment.environment_for_type(self.__class__))
def execute(self) -> None:
"""

View File

@ -137,7 +137,7 @@ class DelayedTaskRunner(metaclass=singleton.Singleton):
if taskInstance:
logger.debug('Executing delayedTask:>%s<', task)
# Re-create environment data
taskInstance.env = Environment.get_environment_for_type(taskInstance.__class__)
taskInstance.env = Environment.environment_for_type(taskInstance.__class__)
DelayedTaskThread(taskInstance).start()
def _insert(self, instance: DelayedTask, delay: int, tag: str) -> None:

View File

@ -104,7 +104,7 @@ class Authenticator(ManagedObjectModel, TaggingMixin):
"""
if self.id is None:
# Return a fake authenticator
return auths.Authenticator(environment.Environment.get_temporary_environment(), values, uuid=self.uuid)
return auths.Authenticator(environment.Environment.environment_for_table_record('fake_auth'), values, uuid=self.uuid)
auType = self.get_type()
env = self.get_environment()

View File

@ -69,7 +69,7 @@ class ManagedObjectModel(UUIDModel):
"""
Returns an environment valid for the record this object represents
"""
return Environment.get_environment_for_table_record(self._meta.verbose_name, self.id) # type: ignore # pylint: disable=no-member
return Environment.environment_for_table_record(self._meta.verbose_name, self.id) # type: ignore # pylint: disable=no-member
def deserialize(self, obj: Module, values: typing.Optional[collections.abc.Mapping[str, str]]):
"""

View File

@ -87,7 +87,7 @@ class Scheduler(models.Model):
"""
Returns an environment valid for the record this object represents
"""
return Environment.get_environment_for_table_record(self._meta.verbose_name, self.id) # type: ignore # pylint: disable=no-member
return Environment.environment_for_table_record(self._meta.verbose_name, self.id) # type: ignore # pylint: disable=no-member
def get_instance(self) -> typing.Optional[jobs.Job]:
"""

View File

@ -103,7 +103,7 @@ class Service(ManagedObjectModel, TaggingMixin): # type: ignore
"""
Returns an environment valid for the record this object represents
"""
return Environment.get_environment_for_table_record(
return Environment.environment_for_table_record(
self._meta.verbose_name, # type: ignore
self.id,
{

View File

@ -165,7 +165,7 @@ class ServicePool(UUIDModel, TaggingMixin): # type: ignore
"""
Returns an environment valid for the record this object represents
"""
return Environment.get_environment_for_table_record(self._meta.verbose_name, self.id) # type: ignore
return Environment.environment_for_table_record(self._meta.verbose_name, self.id) # type: ignore
def active_publication(self) -> typing.Optional['ServicePoolPublication']:
"""

View File

@ -113,7 +113,7 @@ class ServicePoolPublication(UUIDModel):
"""
Returns an environment valid for the record this object represents
"""
return Environment.get_environment_for_table_record(self._meta.verbose_name, self.id) # type: ignore
return Environment.environment_for_table_record(self._meta.verbose_name, self.id) # type: ignore
def get_instance(self) -> 'services.Publication':
"""

View File

@ -177,7 +177,7 @@ class UserService(UUIDModel, properties.PropertiesMixin):
(see related classes uds.core.util.unique_name_generator and uds.core.util.unique_mac_generator)
"""
return Environment.get_environment_for_table_record(
return Environment.environment_for_table_record(
self._meta.verbose_name, # type: ignore # pylint: disable=no-member
self.id,
{

View File

@ -87,31 +87,34 @@ class RegexSerializationTest(UDSTestCase):
def test_unmarshall_all_versions(self) -> None:
for v in range(1, len(SERIALIZED_AUTH_DATA) + 1):
instance = authenticator.RegexLdap(environment=Environment.get_temporary_environment())
instance.unmarshal(SERIALIZED_AUTH_DATA['v{}'.format(v)])
self.check_provider(f'v{v}', instance)
with Environment.temporary_environment() as env:
instance = authenticator.RegexLdap(environment=env)
instance.unmarshal(SERIALIZED_AUTH_DATA['v{}'.format(v)])
self.check_provider(f'v{v}', instance)
def test_marshaling(self) -> None:
# Unmarshall last version, remarshall and check that is marshalled using new marshalling format
LAST_VERSION = 'v{}'.format(len(SERIALIZED_AUTH_DATA))
instance = authenticator.RegexLdap(
environment=Environment.get_temporary_environment()
)
instance.unmarshal(SERIALIZED_AUTH_DATA[LAST_VERSION])
marshaled_data = instance.marshal()
with Environment.temporary_environment() as env:
instance = authenticator.RegexLdap(
environment=env
)
instance.unmarshal(SERIALIZED_AUTH_DATA[LAST_VERSION])
marshaled_data = instance.marshal()
# Ensure remarshalled flag is set
self.assertTrue(instance.needs_upgrade())
instance.flag_for_upgrade(False) # reset flag
# Ensure remarshalled flag is set
self.assertTrue(instance.needs_upgrade())
instance.flag_for_upgrade(False) # reset flag
# Ensure fields has been marshalled using new format
self.assertFalse(marshaled_data.startswith(b'v'))
# Reunmarshall again and check that remarshalled flag is not set
instance = authenticator.RegexLdap(
environment=Environment.get_temporary_environment()
)
instance.unmarshal(marshaled_data)
self.assertFalse(instance.needs_upgrade())
# Ensure fields has been marshalled using new format
self.assertFalse(marshaled_data.startswith(b'v'))
# Reunmarshall again and check that remarshalled flag is not set
with Environment.temporary_environment() as env:
instance = authenticator.RegexLdap(
environment=env
)
instance.unmarshal(marshaled_data)
self.assertFalse(instance.needs_upgrade())
# Check that data is correct
self.check_provider(LAST_VERSION, instance)
# Check that data is correct
self.check_provider(LAST_VERSION, instance)

View File

@ -93,31 +93,35 @@ class SimpleLdapSerializationTest(UDSTestCase):
def test_unmarshall_all_versions(self):
for v in range(1, len(SERIALIZED_AUTH_DATA) + 1):
instance = authenticator.SimpleLDAPAuthenticator(environment=Environment.get_temporary_environment())
instance.unmarshal(SERIALIZED_AUTH_DATA['v{}'.format(v)])
self.check_provider(f'v{v}', instance)
with Environment.temporary_environment() as env:
instance = authenticator.SimpleLDAPAuthenticator(environment=env)
instance.unmarshal(SERIALIZED_AUTH_DATA['v{}'.format(v)])
self.check_provider(f'v{v}', instance)
def test_marshaling(self):
# Unmarshall last version, remarshall and check that is marshalled using new marshalling format
LAST_VERSION = 'v{}'.format(len(SERIALIZED_AUTH_DATA))
instance = authenticator.SimpleLDAPAuthenticator(
environment=Environment.get_temporary_environment()
)
instance.unmarshal(SERIALIZED_AUTH_DATA[LAST_VERSION])
marshaled_data = instance.marshal()
with Environment.temporary_environment() as env:
instance = authenticator.SimpleLDAPAuthenticator(
environment=env
)
instance.unmarshal(SERIALIZED_AUTH_DATA[LAST_VERSION])
marshaled_data = instance.marshal()
# Ensure remarshalled flag is set
self.assertTrue(instance.needs_upgrade())
instance.flag_for_upgrade(False) # reset flag
# Ensure remarshalled flag is set
self.assertTrue(instance.needs_upgrade())
instance.flag_for_upgrade(False) # reset flag
# Ensure fields has been marshalled using new format
self.assertFalse(marshaled_data.startswith(b'v'))
# Reunmarshall again and check that remarshalled flag is not set
instance = authenticator.SimpleLDAPAuthenticator(
environment=Environment.get_temporary_environment()
)
instance.unmarshal(marshaled_data)
self.assertFalse(instance.needs_upgrade())
# Ensure fields has been marshalled using new format
self.assertFalse(marshaled_data.startswith(b'v'))
# Check that data is correct
self.check_provider(LAST_VERSION, instance)
with Environment.temporary_environment() as env:
# Reunmarshall again and check that remarshalled flag is not set
instance = authenticator.SimpleLDAPAuthenticator(
environment=env
)
instance.unmarshal(marshaled_data)
self.assertFalse(instance.needs_upgrade())
# Check that data is correct
self.check_provider(LAST_VERSION, instance)

View File

@ -33,6 +33,7 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com
"""
import collections.abc
import dataclasses
from enum import unique
import typing
from uds.core import environment
@ -40,53 +41,89 @@ from uds.core.util.cache import Cache
from uds.core.util.storage import Storage
from uds.core.util.unique_id_generator import UniqueIDGenerator
from ..utils.test import UDSTestCase
from ..utils.test import UDSTransactionTestCase
class TestEnvironment(UDSTestCase):
def test_global_environment(self) -> None:
env = environment.Environment.get_common_environment()
class TestEnvironment(UDSTransactionTestCase):
def _check_environment(
self,
env: environment.Environment,
expected_key: 'str|None',
is_persistent: bool,
recreate_fnc: typing.Optional[typing.Callable[[], environment.Environment]] = None,
) -> None:
self.assertIsInstance(env, environment.Environment)
self.assertIsInstance(env.cache, Cache)
self.assertIsInstance(env.storage, Storage)
self.assertIsInstance(env._id_generators, UniqueIDGenerator)
self.assertEqual(env.key, environment.GLOBAL_ENV)
self.assertIsInstance(env._id_generators, dict)
if expected_key is not None:
self.assertEqual(env.key, expected_key)
env.storage.put('test', 'test')
self.assertEqual(env.storage.get('test'), 'test')
env.cache.put('test', 'test')
self.assertEqual(env.cache.get('test'), 'test')
# Recreate environment
env = environment.Environment(env.key) if not recreate_fnc else recreate_fnc()
self.assertIsInstance(env, environment.Environment)
self.assertIsInstance(env.cache, Cache)
self.assertIsInstance(env.storage, Storage)
self.assertIsInstance(env._id_generators, dict)
if expected_key is not None:
self.assertEqual(env.key, expected_key)
if is_persistent:
self.assertEqual(env.storage.get('test'), 'test')
self.assertEqual(env.cache.get('test'), 'test')
else:
self.assertEqual(env.storage.get('test'), None)
self.assertEqual(env.cache.get('test'), None)
def test_global_environment(self) -> None:
env = environment.Environment.ommon_environment()
self._check_environment(env, environment.COMMON_ENV, True)
def test_temporary_environment(self) -> None:
env = environment.Environment.get_temporary_environment()
self.assertIsInstance(env, environment.Environment)
self.assertIsInstance(env.cache, Cache)
self.assertIsInstance(env.storage, Storage)
self.assertIsInstance(env._id_generators, UniqueIDGenerator)
self.assertEqual(env.key, environment.TEMP_ENV)
env = environment.Environment.testing_environment()
self._check_environment(env, environment.TEST_ENV, False, recreate_fnc=environment.Environment.testing_environment)
def test_table_record_environment(self) -> None:
env = environment.Environment.get_environment_for_table_record('test_table')
self.assertIsInstance(env, environment.Environment)
self.assertIsInstance(env.cache, Cache)
self.assertIsInstance(env.storage, Storage)
self.assertIsInstance(env._id_generators, UniqueIDGenerator)
self.assertEqual(env.key, 't-test_table')
env = environment.Environment.environment_for_table_record('test_table')
self._check_environment(env, 't-test_table-', True)
env = environment.Environment.get_environment_for_table_record('test_table', 123)
self.assertIsInstance(env, environment.Environment)
self.assertIsInstance(env.cache, Cache)
self.assertIsInstance(env.storage, Storage)
self.assertIsInstance(env._id_generators, UniqueIDGenerator)
self.assertEqual(env.key, 't-test_table-123')
def test_table_record_environment_with_id(self) -> None:
env = environment.Environment.environment_for_table_record('test_table', 123)
self._check_environment(env, 't-test_table-123', True)
def test_environment_for_type(self) -> None:
env = environment.Environment.get_environment_for_type('test_type')
self.assertIsInstance(env, environment.Environment)
self.assertIsInstance(env.cache, Cache)
self.assertIsInstance(env.storage, Storage)
self.assertIsInstance(env._id_generators, UniqueIDGenerator)
self.assertEqual(env.key, 'type-test_type')
env = environment.Environment.environment_for_type('test_type')
self._check_environment(env, 'type-test_type', True)
def test_unique_environment(self) -> None:
env = environment.Environment.get_unique_environment()
self.assertIsInstance(env, environment.Environment)
self.assertIsInstance(env.cache, Cache)
self.assertIsInstance(env.storage, Storage)
self.assertIsInstance(env._id_generators, UniqueIDGenerator)
self
def test_exclusive_temporary_environment(self) -> None:
unique_key: str = ''
with environment.Environment.temporary_environment() as env:
self.assertIsInstance(env, environment.Environment)
self.assertIsInstance(env.cache, Cache)
self.assertIsInstance(env.storage, Storage)
self.assertIsInstance(env._id_generators, dict)
unique_key = env.key # store for later test
env.storage.put('test', 'test')
self.assertEqual(env.storage.get('test'), 'test')
env.cache.put('test', 'test')
self.assertEqual(env.cache.get('test'), 'test')
# Environment is cleared after exit, ensure it
env = environment.Environment(unique_key)
with env as env:
self.assertIsInstance(env, environment.Environment)
self.assertIsInstance(env.cache, Cache)
self.assertIsInstance(env.storage, Storage)
self.assertIsInstance(env._id_generators, dict)
self.assertEqual(env.key, unique_key)
self.assertEqual(env.storage.get('test'), None)
self.assertEqual(env.cache.get('test'), None)

View File

@ -57,8 +57,9 @@ class AssignedAndUnusedTest(UDSTestCase):
for us in self.userServices: # Update state date to now
us.set_state(State.USABLE)
# Set now, should not be removed
count = models.UserService.objects.filter(state=State.REMOVABLE).count()
cleaner = AssignedAndUnused(Environment.get_temporary_environment())
cleaner = AssignedAndUnused(Environment.testing_environment())
# since_state = util.sql_datetime() - datetime.timedelta(seconds=cleaner.frecuency)
cleaner.run()
self.assertEqual(models.UserService.objects.filter(state=State.REMOVABLE).count(), count)

View File

@ -84,7 +84,7 @@ class HangedCleanerTest(UDSTestCase):
def test_hanged_cleaner(self):
# At start, there is no "removable" user services
cleaner = HangedCleaner(Environment.get_temporary_environment())
cleaner = HangedCleaner(Environment.testing_environment())
cleaner.run()
one_fith = TEST_SERVICES // 5
self.assertEqual(

View File

@ -71,7 +71,7 @@ class ServiceCacheUpdaterTest(UDSTestCase):
def runCacheUpdater(self, times: int) -> int:
for _ in range(times):
updater = ServiceCacheUpdater(Environment.get_temporary_environment())
updater = ServiceCacheUpdater(Environment.testing_environment())
updater.run()
# Test user service will cancel automatically so it will not get in "removable" state (on remove start, it will tell it has been removed)
return self.servicePool.userServices.count() - self.removing_or_canceled_count()

View File

@ -93,7 +93,7 @@ class StatsAcummulatorTest(UDSTestCase):
total_base_stats = DAYS * 24 * NUMBER_PER_HOUR * NUMBER_OF_POOLS * len(COUNTERS_TYPES) # All stats
self.assertEqual(base_stats.count(), total_base_stats)
optimizer = stats_collector.StatsAccumulator(Environment.get_temporary_environment())
optimizer = stats_collector.StatsAccumulator(Environment.testing_environment())
optimizer.run()
# Shoul have DAYS // 2 + 1 stats
hour_stats = models.StatsCountersAccum.objects.filter(

View File

@ -80,7 +80,7 @@ class StuckCleanerTest(UDSTestCase):
def test_worker_outdated(self):
count = UserService.objects.count()
cleaner = StuckCleaner(Environment.get_temporary_environment())
cleaner = StuckCleaner(Environment.testing_environment())
cleaner.run()
self.assertEqual(
UserService.objects.count(), count // 4
@ -94,7 +94,7 @@ class StuckCleanerTest(UDSTestCase):
)
us.save(update_fields=['state_date'])
count = UserService.objects.count()
cleaner = StuckCleaner(Environment.get_temporary_environment())
cleaner = StuckCleaner(Environment.testing_environment())
cleaner.run()
self.assertEqual(
UserService.objects.count(), count

View File

@ -163,7 +163,7 @@ def create_test_transport() -> models.Transport:
from uds.transports.Test import TestTransport
values = TestTransport(
environment.Environment.get_temporary_environment(), None
environment.Environment.testing_environment(), None
).get_fields_as_dict()
transport: 'models.Transport' = models.Transport.objects.create(
name='Transport %d' % (glob['transport_id']),

View File

@ -54,7 +54,7 @@ AUTOMATIC_ID_MAPPING: typing.Final[bool] = True
class LinuxAdOsManagerSerialTest(UDSTestCase):
def test_marshaling(self) -> None:
instance = osmanager.LinuxOsADManager(environment=Environment.get_temporary_environment())
instance = osmanager.LinuxOsADManager(environment=Environment.testing_environment())
instance.domain.value = DOMAIN
instance.account.value = ACCOUNT
instance.password.value = PASSWORD
@ -74,7 +74,7 @@ class LinuxAdOsManagerSerialTest(UDSTestCase):
# Ensure fields has been marshalled using new format
self.assertFalse(marshaled_data.startswith(b'v'))
# Reunmarshall again and check that remarshalled flag is not set
instance = osmanager.LinuxOsADManager(environment=Environment.get_temporary_environment())
instance = osmanager.LinuxOsADManager(environment=Environment.testing_environment())
instance.unmarshal(marshaled_data)
self.assertFalse(instance.needs_upgrade())

View File

@ -76,7 +76,7 @@ class LinuxOsManagerTest(UDSTestCase):
def test_unmarshall_all_versions(self) -> None:
for v in range(1, len(SERIALIZED_OSMANAGER_DATA) + 1):
instance = osmanager.LinuxOsManager(environment=Environment.get_temporary_environment())
instance = osmanager.LinuxOsManager(environment=Environment.testing_environment())
instance.unmarshal(SERIALIZED_OSMANAGER_DATA['v{}'.format(v)])
self.check(f'v{v}', instance)
@ -84,7 +84,7 @@ class LinuxOsManagerTest(UDSTestCase):
# Unmarshall last version, remarshall and check that is marshalled using new marshalling format
LAST_VERSION = 'v{}'.format(len(SERIALIZED_OSMANAGER_DATA))
instance = osmanager.LinuxOsManager(
environment=Environment.get_temporary_environment()
environment=Environment.testing_environment()
)
instance.unmarshal(SERIALIZED_OSMANAGER_DATA[LAST_VERSION])
marshaled_data = instance.marshal()
@ -97,7 +97,7 @@ class LinuxOsManagerTest(UDSTestCase):
self.assertFalse(marshaled_data.startswith(b'v'))
# Reunmarshall again and check that remarshalled flag is not set
instance = osmanager.LinuxOsManager(
environment=Environment.get_temporary_environment()
environment=Environment.testing_environment()
)
instance.unmarshal(marshaled_data)
self.assertFalse(instance.needs_upgrade())

View File

@ -69,7 +69,7 @@ class LinuxOsManagerSerialTest(UDSTestCase):
def test_unmarshall_all_versions(self) -> None:
for v in range(1, len(SERIALIZED_OSMANAGER_DATA) + 1):
instance = osmanager.LinuxRandomPassManager(environment=Environment.get_temporary_environment())
instance = osmanager.LinuxRandomPassManager(environment=Environment.testing_environment())
instance.unmarshal(SERIALIZED_OSMANAGER_DATA['v{}'.format(v)])
self.check(f'v{v}', instance)
@ -77,7 +77,7 @@ class LinuxOsManagerSerialTest(UDSTestCase):
# Unmarshall last version, remarshall and check that is marshalled using new marshalling format
LAST_VERSION = 'v{}'.format(len(SERIALIZED_OSMANAGER_DATA))
instance = osmanager.LinuxRandomPassManager(
environment=Environment.get_temporary_environment()
environment=Environment.testing_environment()
)
instance.unmarshal(SERIALIZED_OSMANAGER_DATA[LAST_VERSION])
marshaled_data = instance.marshal()
@ -90,7 +90,7 @@ class LinuxOsManagerSerialTest(UDSTestCase):
self.assertFalse(marshaled_data.startswith(b'v'))
# Reunmarshall again and check that remarshalled flag is not set
instance = osmanager.LinuxRandomPassManager(
environment=Environment.get_temporary_environment()
environment=Environment.testing_environment()
)
instance.unmarshal(marshaled_data)
self.assertFalse(instance.needs_upgrade())

View File

@ -102,7 +102,7 @@ class WindowsOsManagerSerialTest(UDSTestCase):
def test_unmarshall_all_versions(self) -> None:
for v in range(1, len(SERIALIZED_OSMANAGER_DATA) + 1):
instance = osmanager.WinDomainOsManager(environment=Environment.get_temporary_environment())
instance = osmanager.WinDomainOsManager(environment=Environment.testing_environment())
instance.unmarshal(SERIALIZED_OSMANAGER_DATA['v{}'.format(v)])
self.check(f'v{v}', instance)
@ -110,7 +110,7 @@ class WindowsOsManagerSerialTest(UDSTestCase):
# Unmarshall last version, remarshall and check that is marshalled using new marshalling format
LAST_VERSION = 'v{}'.format(len(SERIALIZED_OSMANAGER_DATA))
instance = osmanager.WinDomainOsManager(
environment=Environment.get_temporary_environment()
environment=Environment.testing_environment()
)
instance.unmarshal(SERIALIZED_OSMANAGER_DATA[LAST_VERSION])
marshaled_data = instance.marshal()
@ -123,7 +123,7 @@ class WindowsOsManagerSerialTest(UDSTestCase):
self.assertFalse(marshaled_data.startswith(b'v'))
# Reunmarshall again and check that remarshalled flag is not set
instance = osmanager.WinDomainOsManager(
environment=Environment.get_temporary_environment()
environment=Environment.testing_environment()
)
instance.unmarshal(marshaled_data)
self.assertFalse(instance.needs_upgrade())

View File

@ -76,7 +76,7 @@ class WindowsOsManagerSerialTest(UDSTestCase):
def test_unmarshall_all_versions(self) -> None:
for v in range(1, len(SERIALIZED_OSMANAGER_DATA) + 1):
instance = osmanager.WindowsOsManager(environment=Environment.get_temporary_environment())
instance = osmanager.WindowsOsManager(environment=Environment.testing_environment())
instance.unmarshal(SERIALIZED_OSMANAGER_DATA['v{}'.format(v)])
self.check(f'v{v}', instance)
@ -84,7 +84,7 @@ class WindowsOsManagerSerialTest(UDSTestCase):
# Unmarshall last version, remarshall and check that is marshalled using new marshalling format
LAST_VERSION = 'v{}'.format(len(SERIALIZED_OSMANAGER_DATA))
instance = osmanager.WindowsOsManager(
environment=Environment.get_temporary_environment()
environment=Environment.testing_environment()
)
instance.unmarshal(SERIALIZED_OSMANAGER_DATA[LAST_VERSION])
marshaled_data = instance.marshal()
@ -97,7 +97,7 @@ class WindowsOsManagerSerialTest(UDSTestCase):
self.assertFalse(marshaled_data.startswith(b'v'))
# Reunmarshall again and check that remarshalled flag is not set
instance = osmanager.WindowsOsManager(
environment=Environment.get_temporary_environment()
environment=Environment.testing_environment()
)
instance.unmarshal(marshaled_data)
self.assertFalse(instance.needs_upgrade())

View File

@ -71,14 +71,14 @@ class WindowsOsManagerSerialTest(UDSTestCase):
def test_unmarshall_all_versions(self) -> None:
for v in range(1, len(SERIALIZED_OSMANAGER_DATA) + 1):
instance = osmanager.WinRandomPassManager(environment=Environment.get_temporary_environment())
instance = osmanager.WinRandomPassManager(environment=Environment.testing_environment())
instance.unmarshal(SERIALIZED_OSMANAGER_DATA['v{}'.format(v)])
self.check(f'v{v}', instance)
def test_marshaling(self) -> None:
# Unmarshall last version, remarshall and check that is marshalled using new marshalling format
LAST_VERSION = 'v{}'.format(len(SERIALIZED_OSMANAGER_DATA))
instance = osmanager.WinRandomPassManager(environment=Environment.get_temporary_environment())
instance = osmanager.WinRandomPassManager(environment=Environment.testing_environment())
instance.unmarshal(SERIALIZED_OSMANAGER_DATA[LAST_VERSION])
marshaled_data = instance.marshal()
@ -89,7 +89,7 @@ class WindowsOsManagerSerialTest(UDSTestCase):
# Ensure fields has been marshalled using new format
self.assertFalse(marshaled_data.startswith(b'v'))
# Reunmarshall again and check that remarshalled flag is not set
instance = osmanager.WinRandomPassManager(environment=Environment.get_temporary_environment())
instance = osmanager.WinRandomPassManager(environment=Environment.testing_environment())
instance.unmarshal(marshaled_data)
self.assertFalse(instance.needs_upgrade())

View File

@ -88,7 +88,7 @@ class PhysicalMachinesMultiSerializationTest(UDSTestCase):
environment: Environment
def setUp(self) -> None:
self.environment = Environment.get_environment_for_type('test')
self.environment = Environment.environment_for_type('test')
self.environment.storage.save_to_db('ips', pickle.dumps(STORED_IPS))
def check(self, version: str, instance: 'service_multi.IPMachinesService') -> None:
@ -113,7 +113,7 @@ class PhysicalMachinesMultiSerializationTest(UDSTestCase):
for v in range(1, len(SERIALIZED_DATA) + 1):
version = f'v{v}'
uninitialized_provider = provider.PhysicalMachinesProvider(
environment=Environment.get_temporary_environment()
environment=Environment.testing_environment()
)
instance = service_multi.IPMachinesService(

View File

@ -75,7 +75,7 @@ class TestProxmoxProviderSerialization(UDSTestCase):
return super().tearDown()
def test_provider_serialization(self) -> None:
provider = ProxmoxProvider(environment=Environment.get_temporary_environment())
provider = ProxmoxProvider(environment=Environment.testing_environment())
provider.deserialize(PROVIDER_SERIALIZE_DATA)
# Ensure values are ok

View File

@ -61,7 +61,7 @@ class XendDeploymentSerializationTest(UDSTestCase):
for v in range(1, len(SERIALIZED_LINKED_DEPLOYMENT_DATA) + 1):
version = f'v{v}'
instance = deployment.XenLinkedDeployment(
environment=Environment.get_temporary_environment(), service=None
environment=Environment.testing_environment(), service=None
)
instance.unmarshal(SERIALIZED_LINKED_DEPLOYMENT_DATA[version])
@ -73,7 +73,7 @@ class XendDeploymentSerializationTest(UDSTestCase):
VERSION = f'v{len(SERIALIZED_LINKED_DEPLOYMENT_DATA)}'
instance = deployment.XenLinkedDeployment(
environment=Environment.get_temporary_environment(), service=None
environment=Environment.testing_environment(), service=None
)
instance.unmarshal(SERIALIZED_LINKED_DEPLOYMENT_DATA[VERSION])
marshaled_data = instance.marshal()
@ -86,7 +86,7 @@ class XendDeploymentSerializationTest(UDSTestCase):
self.assertFalse(marshaled_data.startswith(b'v'))
# Reunmarshall again and check that remarshalled flag is not set
instance = deployment.XenLinkedDeployment(
environment=Environment.get_temporary_environment(), service=None
environment=Environment.testing_environment(), service=None
)
instance.unmarshal(marshaled_data)
self.assertFalse(instance.needs_upgrade())
@ -97,7 +97,7 @@ class XendDeploymentSerializationTest(UDSTestCase):
def test_marshaling_queue(self) -> None:
def _create_instance(unmarshal_data: 'bytes|None' = None) -> deployment.XenLinkedDeployment:
instance = deployment.XenLinkedDeployment(
environment=Environment.get_temporary_environment(), service=None
environment=Environment.testing_environment(), service=None
)
if unmarshal_data:
instance.unmarshal(unmarshal_data)