From b0cf8c5ddfd609fd617eea0209ef41692fcae87d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adolfo=20G=C3=B3mez=20Garc=C3=ADa?= Date: Thu, 11 Jul 2024 05:55:32 +0200 Subject: [PATCH] Refactoring code and improved possible use of a "is_dirty". This method is to improve in the future the storage of data (Skipping it). Curently, no one uses is_dirty (base class always returns True), but there it is in case longs serializations requires some optimization. --- server/src/tests/core/ui/test_gui.py | 12 ++- .../tests/core/util/test_auto_serializable.py | 38 ++++++- server/src/tests/services/proxmox/fixtures.py | 6 +- server/src/uds/core/consts/__init__.py | 2 +- server/src/uds/core/serializable.py | 30 +++--- server/src/uds/core/services/user_service.py | 2 +- server/src/uds/core/ui/user_interface.py | 100 ++++++++---------- server/src/uds/core/util/autoserializable.py | 24 +++-- .../uds/models/service_pool_publication.py | 4 + server/src/uds/models/user_service.py | 5 +- server/src/uds/services/Test/deployment.py | 72 ++++++------- server/src/uds/services/Test/publication.py | 65 ++++++------ 12 files changed, 193 insertions(+), 167 deletions(-) diff --git a/server/src/tests/core/ui/test_gui.py b/server/src/tests/core/ui/test_gui.py index 80da87f58..4fe4e00a0 100644 --- a/server/src/tests/core/ui/test_gui.py +++ b/server/src/tests/core/ui/test_gui.py @@ -31,17 +31,19 @@ """ Author: Adolfo Gómez, dkmaster at dkmon dot com """ +from django.conf import settings + +from uds.core.util import ensure +from uds.core.ui.user_interface import gui, UDSK +from uds.core import consts + # We use commit/rollback from ...utils.test import UDSTestCase -from uds.core.ui.user_interface import gui, UDSB, UDSK - -from django.conf import settings -from uds.core.util import ensure class GuiTest(UDSTestCase): def test_globals(self) -> None: self.assertEqual(UDSK, settings.SECRET_KEY[8:24].encode()) - self.assertEqual(UDSB, b'udsprotect') + self.assertEqual(consts.ui.UDSB, b'udsprotect') def test_convert_to_choices(self) -> None: # Several cases diff --git a/server/src/tests/core/util/test_auto_serializable.py b/server/src/tests/core/util/test_auto_serializable.py index 54efdc3f0..afda92e56 100644 --- a/server/src/tests/core/util/test_auto_serializable.py +++ b/server/src/tests/core/util/test_auto_serializable.py @@ -79,7 +79,9 @@ class AutoSerializableClass(autoserializable.AutoSerializable): ) dict_field = autoserializable.DictField[str, int](default=lambda: {'a': 1, 'b': 2, 'c': 3}) # Note that due to the dict being serialized as json, the keys are always strings - dict_field_with_cast = autoserializable.DictField[EnumTest, EnumTest](cast=lambda k, v: (EnumTest(int(k)), EnumTest(v))) + dict_field_with_cast = autoserializable.DictField[EnumTest, EnumTest]( + cast=lambda k, v: (EnumTest(int(k)), EnumTest(v)) + ) obj_dc_field = autoserializable.ObjectField[SerializableDataclass]( SerializableDataclass, default=lambda: SerializableDataclass(1, '2', 3.0) ) @@ -99,7 +101,9 @@ class AutoSerializableCompressedClass(autoserializable.AutoSerializableCompresse list_field = autoserializable.ListField[int]() list_field_with_cast = autoserializable.ListField[EnumTest](cast=EnumTest.from_int) dict_field = autoserializable.DictField[str, int]() - dict_field_with_cast = autoserializable.DictField[EnumTest, EnumTest](cast=lambda k, v: (EnumTest(int(k)), EnumTest(v))) + dict_field_with_cast = autoserializable.DictField[EnumTest, EnumTest]( + cast=lambda k, v: (EnumTest(int(k)), EnumTest(v)) + ) obj_dc_field = autoserializable.ObjectField[SerializableDataclass](SerializableDataclass) obj_nt_field = autoserializable.ObjectField[SerializableNamedTuple](SerializableNamedTuple) @@ -115,7 +119,9 @@ class AutoSerializableEncryptedClass(autoserializable.AutoSerializableEncrypted) list_field = autoserializable.ListField[int]() list_field_with_cast = autoserializable.ListField[EnumTest](cast=EnumTest.from_int) dict_field = autoserializable.DictField[str, int]() - dict_field_with_cast = autoserializable.DictField[EnumTest, EnumTest](cast=lambda k, v: (EnumTest(int(k)), EnumTest(v))) + dict_field_with_cast = autoserializable.DictField[EnumTest, EnumTest]( + cast=lambda k, v: (EnumTest(int(k)), EnumTest(v)) + ) obj_dc_field = autoserializable.ObjectField[SerializableDataclass](SerializableDataclass) obj_nt_field = autoserializable.ObjectField[SerializableNamedTuple](SerializableNamedTuple) @@ -188,7 +194,9 @@ class AutoSerializable(UDSTestCase): for vv in i.list_field_with_cast: self.assertIsInstance(vv, EnumTest) self.assertEqual(i.dict_field, {'a': 1, 'b': 2, 'c': 3}) - self.assertEqual(i.dict_field_with_cast, {EnumTest.VALUE1: EnumTest.VALUE2, EnumTest.VALUE2: EnumTest.VALUE3}) + self.assertEqual( + i.dict_field_with_cast, {EnumTest.VALUE1: EnumTest.VALUE2, EnumTest.VALUE2: EnumTest.VALUE3} + ) for kk, vv in i.dict_field_with_cast.items(): self.assertIsInstance(kk, EnumTest) self.assertIsInstance(vv, EnumTest) @@ -332,3 +340,25 @@ class AutoSerializable(UDSTestCase): self.assertEqual(instance2.dict_field, {'a': 1, 'b': 2, 'c': 3}) # default value self.assertEqual(instance2.obj_dc_field, SerializableDataclass(1, '2', 3.0)) # default value self.assertEqual(instance2.obj_nt_field, SerializableNamedTuple(2, '3', 4.0)) # deserialized value + + def test_autoserializable_dirty(self) -> None: + instance = AutoSerializableClass() + self.assertFalse(instance.is_dirty()) + + instance.int_field = 1 + self.assertTrue(instance.is_dirty()) + + instance.marshal() # should reset dirty flag + self.assertFalse(instance.is_dirty()) + + instance.int_field = 1 + self.assertTrue(instance.is_dirty()) + + instance2 = AutoSerializableClass() + self.assertFalse(instance2.is_dirty()) + + instance2.int_field = 22 + self.assertTrue(instance2.is_dirty()) + + instance2.unmarshal(instance.marshal()) + self.assertFalse(instance2.is_dirty()) diff --git a/server/src/tests/services/proxmox/fixtures.py b/server/src/tests/services/proxmox/fixtures.py index 9c69b31c0..dc60376fa 100644 --- a/server/src/tests/services/proxmox/fixtures.py +++ b/server/src/tests/services/proxmox/fixtures.py @@ -34,6 +34,7 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com import contextlib import copy import functools +import random import typing import datetime @@ -615,13 +616,16 @@ def create_publication( Create a publication """ uuid_ = str(uuid.uuid4()) - return publication.ProxmoxPublication( + pub = publication.ProxmoxPublication( environment=environment.Environment.private_environment(uuid_), service=service or create_service_linked(**kwargs), revision=1, servicepool_name='servicepool_name', uuid=uuid_, ) + pub._vmid = str(random.choice(VMINFO_LIST).id) + return pub + def create_userservice_fixed( diff --git a/server/src/uds/core/consts/__init__.py b/server/src/uds/core/consts/__init__.py index af41f6cb6..367e73392 100644 --- a/server/src/uds/core/consts/__init__.py +++ b/server/src/uds/core/consts/__init__.py @@ -35,7 +35,7 @@ import time import typing from datetime import datetime -from . import actor, auth, cache, calendar, images, net, os, system, ticket, rest, services, transports +from . import actor, auth, cache, calendar, images, net, os, system, ticket, rest, services, transports, ui # Date related constants NEVER: typing.Final[datetime] = datetime(1972, 7, 1) diff --git a/server/src/uds/core/serializable.py b/server/src/uds/core/serializable.py index b3f9ccf32..b883fc1be 100644 --- a/server/src/uds/core/serializable.py +++ b/server/src/uds/core/serializable.py @@ -31,10 +31,10 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com """ import base64 -import pickle # nosec: Safe pickle usage +import abc -class Serializable: +class Serializable(abc.ABC): """ This class represents the interface that all serializable objects must provide. @@ -54,6 +54,7 @@ class Serializable: def __init__(self) -> None: self._needs_upgrade = False + @abc.abstractmethod def marshal(self) -> bytes: """ This is the method that must be overriden in order to serialize an object. @@ -62,17 +63,12 @@ class Serializable: only suitable methods to "codify" serialized values :note: This method can be overriden. - :note: if you provide a "data" member variable, and it has __dict__, then it will be used - to marshal that data variable """ # Default implementation will look for a member variable called "data" # This is an struct, and will be pickled by default + ... - if hasattr(self, 'data') and hasattr(getattr(self, 'data'), '__dict__'): - return pickle.dumps(getattr(self, 'data'), protocol=pickle.HIGHEST_PROTOCOL) - - raise NotImplementedError('You must override the marshal method or provide a data member') - + @abc.abstractmethod def unmarshal(self, data: bytes) -> None: """ This is the method that must be overriden in order to deserialize an object. @@ -87,14 +83,8 @@ class Serializable: data : String readed from persistent storage to deseralilize :note: This method can be overriden. - :note: if you provide a "data" member variable, and it has __dict__, then it will be used - to unmarshal that data variable """ - if hasattr(self, 'data') and hasattr(getattr(self, 'data'), '__dict__'): - setattr(self, 'data', pickle.loads(data)) # nosec: Safe pickle load - return - - raise NotImplementedError('You must override the unmarshal method or provide a data member') + ... def serialize(self) -> str: """ @@ -116,7 +106,7 @@ class Serializable: def mark_for_upgrade(self, value: bool = True) -> None: """ Flags this object for remarshalling - + Args: value: True if this object needs remarshalling, False if not @@ -132,3 +122,9 @@ class Serializable: Returns true if this object needs remarshalling """ return self._needs_upgrade + + def is_dirty(self) -> bool: + """ + Returns true if this object needs remarshalling + """ + return True diff --git a/server/src/uds/core/services/user_service.py b/server/src/uds/core/services/user_service.py index cefdbda18..4cc10bdb9 100644 --- a/server/src/uds/core/services/user_service.py +++ b/server/src/uds/core/services/user_service.py @@ -630,4 +630,4 @@ class UserService(Environmentable, Serializable, abc.ABC): """ Mainly used for debugging purposses """ - return f'{self.__class__.__name__}({self.get_unique_id()})' + return f'{self.__class__.__name__}' diff --git a/server/src/uds/core/ui/user_interface.py b/server/src/uds/core/ui/user_interface.py index c907a09b0..68f5b6d7e 100644 --- a/server/src/uds/core/ui/user_interface.py +++ b/server/src/uds/core/ui/user_interface.py @@ -53,11 +53,13 @@ from uds.core.util import serializer, validators, ensure logger = logging.getLogger(__name__) -# Old encryption key -UDSB: typing.Final[bytes] = b'udsprotect' - -SERIALIZATION_HEADER: typing.Final[bytes] = b'GUIZ' -SERIALIZATION_VERSION: typing.Final[bytes] = b'\001' +# To simplify choice parameters declaration of fields +_ChoicesParamType: typing.TypeAlias = typing.Union[ + collections.abc.Callable[[], list['types.ui.ChoiceItem']], + collections.abc.Iterable[str | types.ui.ChoiceItem], + dict[str, str], + None, +] class gui: @@ -138,12 +140,7 @@ class gui: # Helpers @staticmethod def as_choices( - vals: typing.Union[ - collections.abc.Callable[[], list['types.ui.ChoiceItem']], - collections.abc.Iterable[typing.Union[str, types.ui.ChoiceItem]], - dict[str, str], - None, - ] + vals: _ChoicesParamType, ) -> typing.Union[collections.abc.Callable[[], list['types.ui.ChoiceItem']], list['types.ui.ChoiceItem']]: """ Helper to convert from array of strings (or dictionaries) to the same dict used in choice, @@ -177,8 +174,11 @@ class gui: @staticmethod def sorted_choices( - choices: collections.abc.Iterable[types.ui.ChoiceItem], *, by_id: bool = False, reverse: bool = False, - key: typing.Optional[collections.abc.Callable[[types.ui.ChoiceItem], typing.Any]] = None + choices: collections.abc.Iterable[types.ui.ChoiceItem], + *, + by_id: bool = False, + reverse: bool = False, + key: typing.Optional[collections.abc.Callable[[types.ui.ChoiceItem], typing.Any]] = None, ) -> list[types.ui.ChoiceItem]: if by_id: key = lambda item: item['id'] @@ -279,7 +279,7 @@ class gui: so if you use both, the used one will be "value". This is valid for all form fields. (Anyway, default is part of the "value" property, so if you use "value", you will get the default value if not set) - + Note: Currently, old field name is only intended for 4.0 migration, so it has only one value. This means that only one rename can be donoe currently. If needed, we can add a list of old names @@ -613,12 +613,7 @@ class gui: tab: typing.Optional[typing.Union[str, types.ui.Tab]] = None, default: typing.Union[collections.abc.Callable[[], str], str] = '', value: typing.Optional[str] = None, - choices: typing.Union[ - collections.abc.Callable[[], list['types.ui.ChoiceItem']], - collections.abc.Iterable[typing.Union[str, types.ui.ChoiceItem]], - dict[str, str], - None, - ] = None, + choices: _ChoicesParamType = None, old_field_name: types.ui.OldFieldNameType = None, ) -> None: super().__init__( @@ -634,7 +629,9 @@ class gui: old_field_name=old_field_name, ) # Update parent type - self.field_type = types.ui.FieldType.TEXT_AUTOCOMPLETE # pyright: ignore[reportIncompatibleMethodOverride] + self.field_type = ( + types.ui.FieldType.TEXT_AUTOCOMPLETE + ) # pyright: ignore[reportIncompatibleMethodOverride] self._fields_info.choices = gui.as_choices(choices or []) def set_choices(self, values: collections.abc.Iterable[typing.Union[str, types.ui.ChoiceItem]]) -> None: @@ -784,7 +781,7 @@ class gui: return self.as_date() @value.setter - def value(self, value: datetime.date|str) -> None: + def value(self, value: datetime.date | str) -> None: self._set_value(value) def gui_description(self) -> dict[str, typing.Any]: @@ -1093,12 +1090,7 @@ class gui: order: int = 0, tooltip: str = '', required: typing.Optional[bool] = None, - choices: typing.Union[ - collections.abc.Callable[[], list['types.ui.ChoiceItem']], - collections.abc.Iterable[typing.Union[str, types.ui.ChoiceItem]], - dict[str, str], - None, - ] = None, + choices: _ChoicesParamType = None, fills: typing.Optional[types.ui.Filler] = None, tab: typing.Optional[typing.Union[str, types.ui.Tab]] = None, default: typing.Union[collections.abc.Callable[[], str], str, None] = None, @@ -1118,7 +1110,7 @@ class gui: type=types.ui.FieldType.CHOICE, ) - self._fields_info.choices = gui.as_choices(choices or []) + self._fields_info.choices = gui.as_choices(choices) # if has fillers, set them if fills: if 'function' not in fills or 'callback_name' not in fills: @@ -1161,12 +1153,7 @@ class gui: order: int = 0, tooltip: str = '', required: typing.Optional[bool] = None, - choices: typing.Union[ - collections.abc.Callable[[], list['types.ui.ChoiceItem']], - collections.abc.Iterable[typing.Union[str, types.ui.ChoiceItem]], - dict[str, str], - None, - ] = None, + choices: _ChoicesParamType = None, tab: typing.Optional[typing.Union[str, types.ui.Tab]] = None, default: typing.Union[collections.abc.Callable[[], str], str, None] = None, value: typing.Optional[str] = None, @@ -1252,12 +1239,7 @@ class gui: order: int = 0, tooltip: str = '', required: typing.Optional[bool] = None, - choices: typing.Union[ - collections.abc.Callable[[], list['types.ui.ChoiceItem']], - collections.abc.Iterable[typing.Union[str, types.ui.ChoiceItem]], - dict[str, str], - None, - ] = None, + choices: _ChoicesParamType = None, tab: typing.Optional[typing.Union[str, types.ui.Tab]] = None, default: typing.Union[ collections.abc.Callable[[], str], collections.abc.Callable[[], list[str]], list[str], str, None @@ -1506,7 +1488,7 @@ class UserInterface(metaclass=UserInterfaceType): else: logger.warning('Field %s.%s not found in values data, ', self.__class__.__name__, fld_name) if getattr(settings, 'DEBUG', False): - for caller in itertools.islice(inspect.stack(), 1, 8): + for caller in itertools.islice(inspect.stack(), 1, 8): logger.warning(' %s:%s:%s', caller.filename, caller.lineno, caller.function) def init_gui(self) -> None: @@ -1584,7 +1566,7 @@ class UserInterface(metaclass=UserInterfaceType): if FIELDS_ENCODERS[field.field_type](field) is not None ] - return SERIALIZATION_HEADER + SERIALIZATION_VERSION + serializer.serialize(fields) + return consts.ui.SERIALIZATION_HEADER + consts.ui.SERIALIZATION_VERSION + serializer.serialize(fields) def deserialize_fields( self, @@ -1619,16 +1601,19 @@ class UserInterface(metaclass=UserInterfaceType): if not values: return False - if not values.startswith(SERIALIZATION_HEADER): + if not values.startswith(consts.ui.SERIALIZATION_HEADER): # Unserialize with old method, and notify that we need to upgrade self.deserialize_from_old_format(values) return True # For future use, right now we only have one version # Prepared for a possible future versioning of data serialization - _version = values[len(SERIALIZATION_HEADER) : len(SERIALIZATION_HEADER) + len(SERIALIZATION_VERSION)] + _version = values[ + len(consts.ui.SERIALIZATION_HEADER) : len(consts.ui.SERIALIZATION_HEADER) + + len(consts.ui.SERIALIZATION_VERSION) + ] - values = values[len(SERIALIZATION_HEADER) + len(SERIALIZATION_VERSION) :] + values = values[len(consts.ui.SERIALIZATION_HEADER) + len(consts.ui.SERIALIZATION_VERSION) :] if not values: # Apart of the header, there is nothing... logger.info('Empty values on unserialize_fields') @@ -1656,7 +1641,12 @@ class UserInterface(metaclass=UserInterfaceType): logger.warning('Field %s has no decoder', field_name) continue if field_type != internal_field_type.name: - logger.warning('Field %s has different type than expected: %s != %s', field_name, field_type, internal_field_type.name) + logger.warning( + 'Field %s has different type than expected: %s != %s', + field_name, + field_type, + internal_field_type.name, + ) continue self._gui[field_name].value = FIELD_DECODERS[internal_field_type](field_value) @@ -1692,7 +1682,7 @@ class UserInterface(metaclass=UserInterfaceType): return field_names_translations: dict[str, str] = self._get_fieldname_translations() - + for txt in values.split(FIELD_SEPARATOR): kb, v = txt.split(NAME_VALUE_SEPARATOR) k = kb.decode('utf8') # Convert name to string @@ -1701,11 +1691,9 @@ class UserInterface(metaclass=UserInterfaceType): if k in self._gui: try: if v.startswith(MULTIVALUE_FIELD): - val = pickle.loads( # nosec: safe pickle, controlled - v[1:] - ) # nosec: secure pickled by us for sure + val = pickle.loads(v[1:]) elif v.startswith(OLD_PASSWORD_FIELD): - val = CryptoManager().aes_decrypt(v[1:], UDSB, True).decode() + val = CryptoManager().aes_decrypt(v[1:], consts.ui.UDSB, True).decode() elif v.startswith(PASSWORD_FIELD): val = CryptoManager().aes_decrypt(v[1:], UDSK, True).decode() else: @@ -1728,7 +1716,7 @@ class UserInterface(metaclass=UserInterfaceType): Args: skip_init_gui: If True, init_gui will not be called - + Note: skip_init_gui is used to avoid calling init_gui when we are not going to use the result This is used, for example, when exporting data, generating the tree, etc... @@ -1774,7 +1762,7 @@ class UserInterface(metaclass=UserInterfaceType): field_names_translations[fld_old_field_name] = fld_name return field_names_translations - + def has_field(self, field_name: str) -> bool: """ So we can check against field existence on "own" instance @@ -1811,7 +1799,9 @@ FIELD_DECODERS: typing.Final[ types.ui.FieldType.TEXT: lambda x: x, types.ui.FieldType.TEXT_AUTOCOMPLETE: lambda x: x, types.ui.FieldType.NUMERIC: int, - types.ui.FieldType.PASSWORD: lambda x: (CryptoManager.manager().aes_decrypt(x.encode(), UDSK, True).decode()), + types.ui.FieldType.PASSWORD: lambda x: ( + CryptoManager.manager().aes_decrypt(x.encode(), UDSK, True).decode() + ), types.ui.FieldType.HIDDEN: lambda x: x, types.ui.FieldType.CHOICE: lambda x: x, types.ui.FieldType.MULTICHOICE: lambda x: serializer.deserialize(base64.b64decode(x.encode())), diff --git a/server/src/uds/core/util/autoserializable.py b/server/src/uds/core/util/autoserializable.py index 54097eb8a..024575d56 100644 --- a/server/src/uds/core/util/autoserializable.py +++ b/server/src/uds/core/util/autoserializable.py @@ -310,7 +310,7 @@ class BoolField(_SerializableField[bool]): class ListField(_SerializableField[list[T]], list[T]): """List field - + Args: default: Default value for the field. Can be a list or a callable that returns a list. cast: Optional function to cast the values of the list to the desired type. If not provided, the values will be "deserialized" as they are. (see notes) @@ -346,7 +346,7 @@ class ListField(_SerializableField[list[T]], list[T]): class DictField(_SerializableField[dict[T, V]], dict[T, V]): """Dict field - + Args: default: Default value for the field. Can be a dict or a callable that returns a dict. cast: Optional function to cast the values of the dict to the desired type. If not provided, the values will be "deserialized" as they are. (see notes) @@ -376,7 +376,11 @@ class DictField(_SerializableField[dict[T, V]], dict[T, V]): raise ValueError('Invalid dict data') self.__set__( instance, - dict(self._cast(k, v) for k, v in json.loads(data[1:]).items()) if self._cast else json.loads(data[1:]), + ( + dict(self._cast(k, v) for k, v in json.loads(data[1:]).items()) + if self._cast + else json.loads(data[1:]) + ), ) @@ -500,10 +504,14 @@ class AutoSerializable(Serializable, metaclass=_FieldNameSetter): ... d = ListField(defalut=lambda: [1, 2, 3]) """ - _fields: dict[str, typing.Any] + _fields: dict[str, typing.Any] # Values for the fields (serializable fields only ofc) serialization_version: int = 0 # So autoserializable classes can keep their version if needed + def __init__(self): + super().__init__() + self._fields = {} + def _autoserializable_fields(self) -> collections.abc.Iterator[tuple[str, _SerializableField[typing.Any]]]: """Returns an iterator over all fields in the class, including inherited ones (that is, all fields that are instances of _SerializableField in the class and its bases) @@ -513,9 +521,11 @@ class AutoSerializable(Serializable, metaclass=_FieldNameSetter): """ cls = self.__class__ while True: + # Get own fields first for k, v in cls.__dict__.items(): if isinstance(v, _SerializableField): yield k, v + # and then look for the first base that is also an AutoSerializable for c in cls.__bases__: if issubclass(c, AutoSerializable) and c != AutoSerializable: cls = c @@ -624,6 +634,9 @@ class AutoSerializable(Serializable, metaclass=_FieldNameSetter): logger.debug('Field %s not found in unmarshalled data', v.name) v.__set__(self, v._default()) # Set default value + def as_dict(self) -> dict[str, typing.Any]: + return {k: v.__get__(self) for k, v in self._autoserializable_fields()} + def __eq__(self, other: typing.Any) -> bool: """ Basic equality check, checks if all _SerializableFields are equal @@ -649,9 +662,6 @@ class AutoSerializable(Serializable, metaclass=_FieldNameSetter): [f"{k}={v.obj_type.__name__}({v.__get__(self)})" for k, v in self._autoserializable_fields()] ) - def as_dict(self) -> dict[str, typing.Any]: - return {k: v.__get__(self) for k, v in self._autoserializable_fields()} - class AutoSerializableCompressed(AutoSerializable): """This class allows the automatic serialization of fields in a class compressed with zlib.""" diff --git a/server/src/uds/models/service_pool_publication.py b/server/src/uds/models/service_pool_publication.py index 878b2f51f..87e2d4d82 100644 --- a/server/src/uds/models/service_pool_publication.py +++ b/server/src/uds/models/service_pool_publication.py @@ -169,6 +169,10 @@ class ServicePoolPublication(UUIDModel): :note: This method do not saves the updated record, just updates the field """ + if not publication_instance.is_dirty(): + logger.debug('Skipping update of publication %s, no changes', self) + return # Nothing to do + self.data = publication_instance.serialize() self.save(update_fields=['data']) diff --git a/server/src/uds/models/user_service.py b/server/src/uds/models/user_service.py index 46ec4abde..e253c2878 100644 --- a/server/src/uds/models/user_service.py +++ b/server/src/uds/models/user_service.py @@ -258,6 +258,9 @@ class UserService(UUIDModel, properties.PropertiesMixin): :note: This method SAVES the updated record, just updates the field """ + if not userservice_instance.is_dirty(): + logger.debug('Skipping update of user service %s, no changes', self) + return # Nothing to do self.data = userservice_instance.serialize() self.save(update_fields=['data']) @@ -550,7 +553,7 @@ class UserService(UUIDModel, properties.PropertiesMixin): from uds.core.managers.userservice import UserServiceManager # Cancel is a "forced" operation, so they are not checked against limits - UserServiceManager().cancel(self) + UserServiceManager.manager().cancel(self) def remove_or_cancel(self) -> None: """ diff --git a/server/src/uds/services/Test/deployment.py b/server/src/uds/services/Test/deployment.py index fc140531d..c47665e2b 100644 --- a/server/src/uds/services/Test/deployment.py +++ b/server/src/uds/services/Test/deployment.py @@ -31,10 +31,10 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com """ import logging -import dataclasses import typing from uds.core import services, types +from uds.core.util import autoserializable # Not imported at runtime, just for type checking if typing.TYPE_CHECKING: @@ -44,28 +44,18 @@ if typing.TYPE_CHECKING: logger = logging.getLogger(__name__) -class TestUserService(services.UserService): +class TestUserService(services.UserService, autoserializable.AutoSerializable): """ Simple testing deployment, no cache """ - - @dataclasses.dataclass - class Data: - """ - This is the data we will store in the storage - """ - - count: int = -1 - ready: bool = False - name: str = '' - ip: str = '' - mac: str = '' - - data: Data + count = autoserializable.IntegerField(default=-1) + ready = autoserializable.BoolField(default=False) + name = autoserializable.StringField(default='') + ip = autoserializable.StringField(default='') + mac = autoserializable.StringField(default='') def initialize(self) -> None: super().initialize() - self.data = TestUserService.Data() # : Recheck every five seconds by default (for task methods) suggested_delay = 5 @@ -74,56 +64,56 @@ class TestUserService(services.UserService): return typing.cast('TestServiceNoCache', super().service()) def get_name(self) -> str: - if not self.data.name: - self.data.name = self.name_generator().get(self.service().get_basename(), 3) + if not self.name: + self.name = self.name_generator().get(self.service().get_basename(), 3) - logger.info('Getting name of deployment %s', self.data) + logger.info('Getting name of deployment %s', self) - return self.data.name + return self.name def set_ip(self, ip: str) -> None: - logger.info('Setting ip of deployment %s to %s', self.data, ip) - self.data.ip = ip + logger.info('Setting ip of deployment %s to %s', self, ip) + self.ip = ip def get_unique_id(self) -> str: - logger.info('Getting unique id of deployment %s', self.data) - if not self.data.mac: - self.data.mac = self.mac_generator().get('00:00:00:00:00:00-00:FF:FF:FF:FF:FF') - return self.data.mac + logger.info('Getting unique id of deployment %s', self) + if not self.mac: + self.mac = self.mac_generator().get('00:00:00:00:00:00-00:FF:FF:FF:FF:FF') + return self.mac def get_ip(self) -> str: - logger.info('Getting ip of deployment %s', self.data) + logger.info('Getting ip of deployment %s', self) ip = typing.cast(str, self.storage.read_from_db('ip')) if not ip: ip = '8.6.4.2' # Sample IP for testing purposses only return ip def set_ready(self) -> types.states.TaskState: - logger.info('Setting ready %s', self.data) - self.data.ready = True + logger.info('Setting ready %s', self) + self.ready = True return types.states.TaskState.FINISHED def deploy_for_user(self, user: 'models.User') -> types.states.TaskState: - logger.info('Deploying for user %s %s', user, self.data) - self.data.count = 3 + logger.info('Deploying for user %s %s', user, self) + self.count = 3 return types.states.TaskState.RUNNING def deploy_for_cache(self, level: types.services.CacheLevel) -> types.states.TaskState: - logger.info('Deploying for cache %s %s', level, self.data) - self.data.count = 3 + logger.info('Deploying for cache %s %s', level, self) + self.count = 3 return types.states.TaskState.RUNNING def check_state(self) -> types.states.TaskState: - logger.info('Checking state of deployment %s', self.data) - if self.data.count <= 0: + logger.info('Checking state of deployment %s', self) + if self.count <= 0: return types.states.TaskState.FINISHED - self.data.count -= 1 + self.count -= 1 return types.states.TaskState.RUNNING def finish(self) -> None: - logger.info('Finishing deployment %s', self.data) - self.data.count = -1 + logger.info('Finishing deployment %s', self) + self.count = -1 def user_logged_in(self, username: str) -> None: logger.info('User %s has logged in', username) @@ -135,8 +125,8 @@ class TestUserService(services.UserService): return 'No error' def destroy(self) -> types.states.TaskState: - logger.info('Destroying deployment %s', self.data) - self.data.count = -1 + logger.info('Destroying deployment %s', self) + self.count = -1 return types.states.TaskState.FINISHED def cancel(self) -> types.states.TaskState: diff --git a/server/src/uds/services/Test/publication.py b/server/src/uds/services/Test/publication.py index dc9041023..6e26259a6 100644 --- a/server/src/uds/services/Test/publication.py +++ b/server/src/uds/services/Test/publication.py @@ -33,11 +33,11 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com import random import string import logging -import dataclasses import typing from django.utils.translation import gettext as _ from uds.core import services, types +from uds.core.util import autoserializable logger = logging.getLogger(__name__) @@ -46,26 +46,20 @@ if typing.TYPE_CHECKING: pass -class TestPublication(services.Publication): +class TestPublication(services.Publication, autoserializable.AutoSerializable): """ - Simple test publication + Simple test publication """ - suggested_delay = ( - 5 # : Suggested recheck time if publication is unfinished in seconds - ) - + + suggested_delay = 5 # : Suggested recheck time if publication is unfinished in seconds + # Data to store - @dataclasses.dataclass - class Data: - name: str = '' - state: str = '' - reason: str = '' - number: int = -1 - other: str = '' - other2: str = 'other2' - - - data: Data = Data() + name = autoserializable.StringField() + state = autoserializable.StringField() + reason = autoserializable.StringField() + number = autoserializable.IntegerField(default=-1) + other = autoserializable.StringField() + other2 = autoserializable.StringField(default='other2') def initialize(self) -> None: """ @@ -77,34 +71,37 @@ class TestPublication(services.Publication): # We do not check anything at marshal method, so we ensure that # default values are correctly handled by marshal. - self.data.name = ''.join(random.choices(string.ascii_letters, k=8)) - self.data.state = types.states.TaskState.RUNNING - self.data.reason = 'none' - self.data.number = 10 + self.name = ''.join(random.choices(string.ascii_letters, k=8)) + self.state = types.states.TaskState.RUNNING + self.reason = 'none' + self.number = 10 def publish(self) -> types.states.TaskState: - logger.info('Publishing publication %s: %s remaining',self.data.name, self.data.number) - self.data.number -= 1 + logger.info('Publishing publication %s: %s remaining', self.name, self.number) + self.number -= 1 - if self.data.number <= 0: - self.data.state = types.states.TaskState.FINISHED - return types.states.TaskState.from_str(self.data.state) + if self.number <= 0: + self.state = types.states.TaskState.FINISHED + return types.states.TaskState.from_str(self.state) + + def check_state(self) -> types.states.TaskState: + return types.states.TaskState.from_str(self.state) def finish(self) -> None: # Make simply a random string - logger.info('Finishing publication %s', self.data.name) - self.data.number = 0 - self.data.state = types.states.TaskState.FINISHED + logger.info('Finishing publication %s', self.name) + self.number = 0 + self.state = types.states.TaskState.FINISHED def error_reason(self) -> str: - return self.data.reason + return self.reason def destroy(self) -> types.states.TaskState: - logger.info('Destroying publication %s', self.data.name) + logger.info('Destroying publication %s', self.name) return types.states.TaskState.FINISHED def cancel(self) -> types.states.TaskState: - logger.info('Canceling publication %s', self.data.name) + logger.info('Canceling publication %s', self.name) return self.destroy() # Here ends the publication needed methods. @@ -117,4 +114,4 @@ class TestPublication(services.Publication): the name generater for this publication. This is just a sample, and this will do the work """ - return self.data.name + return self.name