1
0
mirror of https://github.com/dkmstr/openuds.git synced 2025-02-02 09:47:13 +03:00

Refactoring code and improved possible use of a "is_dirty".

This method is to improve in the future the storage of data (Skipping it).
Curently, no one uses is_dirty (base class always returns True), but there it is in case longs serializations requires some optimization.
This commit is contained in:
Adolfo Gómez García 2024-07-11 05:55:32 +02:00
parent e4d5bef48a
commit b0cf8c5ddf
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
12 changed files with 193 additions and 167 deletions

View File

@ -31,17 +31,19 @@
"""
Author: Adolfo Gómez, dkmaster at dkmon dot com
"""
from django.conf import settings
from uds.core.util import ensure
from uds.core.ui.user_interface import gui, UDSK
from uds.core import consts
# We use commit/rollback
from ...utils.test import UDSTestCase
from uds.core.ui.user_interface import gui, UDSB, UDSK
from django.conf import settings
from uds.core.util import ensure
class GuiTest(UDSTestCase):
def test_globals(self) -> None:
self.assertEqual(UDSK, settings.SECRET_KEY[8:24].encode())
self.assertEqual(UDSB, b'udsprotect')
self.assertEqual(consts.ui.UDSB, b'udsprotect')
def test_convert_to_choices(self) -> None:
# Several cases

View File

@ -79,7 +79,9 @@ class AutoSerializableClass(autoserializable.AutoSerializable):
)
dict_field = autoserializable.DictField[str, int](default=lambda: {'a': 1, 'b': 2, 'c': 3})
# Note that due to the dict being serialized as json, the keys are always strings
dict_field_with_cast = autoserializable.DictField[EnumTest, EnumTest](cast=lambda k, v: (EnumTest(int(k)), EnumTest(v)))
dict_field_with_cast = autoserializable.DictField[EnumTest, EnumTest](
cast=lambda k, v: (EnumTest(int(k)), EnumTest(v))
)
obj_dc_field = autoserializable.ObjectField[SerializableDataclass](
SerializableDataclass, default=lambda: SerializableDataclass(1, '2', 3.0)
)
@ -99,7 +101,9 @@ class AutoSerializableCompressedClass(autoserializable.AutoSerializableCompresse
list_field = autoserializable.ListField[int]()
list_field_with_cast = autoserializable.ListField[EnumTest](cast=EnumTest.from_int)
dict_field = autoserializable.DictField[str, int]()
dict_field_with_cast = autoserializable.DictField[EnumTest, EnumTest](cast=lambda k, v: (EnumTest(int(k)), EnumTest(v)))
dict_field_with_cast = autoserializable.DictField[EnumTest, EnumTest](
cast=lambda k, v: (EnumTest(int(k)), EnumTest(v))
)
obj_dc_field = autoserializable.ObjectField[SerializableDataclass](SerializableDataclass)
obj_nt_field = autoserializable.ObjectField[SerializableNamedTuple](SerializableNamedTuple)
@ -115,7 +119,9 @@ class AutoSerializableEncryptedClass(autoserializable.AutoSerializableEncrypted)
list_field = autoserializable.ListField[int]()
list_field_with_cast = autoserializable.ListField[EnumTest](cast=EnumTest.from_int)
dict_field = autoserializable.DictField[str, int]()
dict_field_with_cast = autoserializable.DictField[EnumTest, EnumTest](cast=lambda k, v: (EnumTest(int(k)), EnumTest(v)))
dict_field_with_cast = autoserializable.DictField[EnumTest, EnumTest](
cast=lambda k, v: (EnumTest(int(k)), EnumTest(v))
)
obj_dc_field = autoserializable.ObjectField[SerializableDataclass](SerializableDataclass)
obj_nt_field = autoserializable.ObjectField[SerializableNamedTuple](SerializableNamedTuple)
@ -188,7 +194,9 @@ class AutoSerializable(UDSTestCase):
for vv in i.list_field_with_cast:
self.assertIsInstance(vv, EnumTest)
self.assertEqual(i.dict_field, {'a': 1, 'b': 2, 'c': 3})
self.assertEqual(i.dict_field_with_cast, {EnumTest.VALUE1: EnumTest.VALUE2, EnumTest.VALUE2: EnumTest.VALUE3})
self.assertEqual(
i.dict_field_with_cast, {EnumTest.VALUE1: EnumTest.VALUE2, EnumTest.VALUE2: EnumTest.VALUE3}
)
for kk, vv in i.dict_field_with_cast.items():
self.assertIsInstance(kk, EnumTest)
self.assertIsInstance(vv, EnumTest)
@ -332,3 +340,25 @@ class AutoSerializable(UDSTestCase):
self.assertEqual(instance2.dict_field, {'a': 1, 'b': 2, 'c': 3}) # default value
self.assertEqual(instance2.obj_dc_field, SerializableDataclass(1, '2', 3.0)) # default value
self.assertEqual(instance2.obj_nt_field, SerializableNamedTuple(2, '3', 4.0)) # deserialized value
def test_autoserializable_dirty(self) -> None:
instance = AutoSerializableClass()
self.assertFalse(instance.is_dirty())
instance.int_field = 1
self.assertTrue(instance.is_dirty())
instance.marshal() # should reset dirty flag
self.assertFalse(instance.is_dirty())
instance.int_field = 1
self.assertTrue(instance.is_dirty())
instance2 = AutoSerializableClass()
self.assertFalse(instance2.is_dirty())
instance2.int_field = 22
self.assertTrue(instance2.is_dirty())
instance2.unmarshal(instance.marshal())
self.assertFalse(instance2.is_dirty())

View File

@ -34,6 +34,7 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com
import contextlib
import copy
import functools
import random
import typing
import datetime
@ -615,13 +616,16 @@ def create_publication(
Create a publication
"""
uuid_ = str(uuid.uuid4())
return publication.ProxmoxPublication(
pub = publication.ProxmoxPublication(
environment=environment.Environment.private_environment(uuid_),
service=service or create_service_linked(**kwargs),
revision=1,
servicepool_name='servicepool_name',
uuid=uuid_,
)
pub._vmid = str(random.choice(VMINFO_LIST).id)
return pub
def create_userservice_fixed(

View File

@ -35,7 +35,7 @@ import time
import typing
from datetime import datetime
from . import actor, auth, cache, calendar, images, net, os, system, ticket, rest, services, transports
from . import actor, auth, cache, calendar, images, net, os, system, ticket, rest, services, transports, ui
# Date related constants
NEVER: typing.Final[datetime] = datetime(1972, 7, 1)

View File

@ -31,10 +31,10 @@
Author: Adolfo Gómez, dkmaster at dkmon dot com
"""
import base64
import pickle # nosec: Safe pickle usage
import abc
class Serializable:
class Serializable(abc.ABC):
"""
This class represents the interface that all serializable objects must provide.
@ -54,6 +54,7 @@ class Serializable:
def __init__(self) -> None:
self._needs_upgrade = False
@abc.abstractmethod
def marshal(self) -> bytes:
"""
This is the method that must be overriden in order to serialize an object.
@ -62,17 +63,12 @@ class Serializable:
only suitable methods to "codify" serialized values
:note: This method can be overriden.
:note: if you provide a "data" member variable, and it has __dict__, then it will be used
to marshal that data variable
"""
# Default implementation will look for a member variable called "data"
# This is an struct, and will be pickled by default
...
if hasattr(self, 'data') and hasattr(getattr(self, 'data'), '__dict__'):
return pickle.dumps(getattr(self, 'data'), protocol=pickle.HIGHEST_PROTOCOL)
raise NotImplementedError('You must override the marshal method or provide a data member')
@abc.abstractmethod
def unmarshal(self, data: bytes) -> None:
"""
This is the method that must be overriden in order to deserialize an object.
@ -87,14 +83,8 @@ class Serializable:
data : String readed from persistent storage to deseralilize
:note: This method can be overriden.
:note: if you provide a "data" member variable, and it has __dict__, then it will be used
to unmarshal that data variable
"""
if hasattr(self, 'data') and hasattr(getattr(self, 'data'), '__dict__'):
setattr(self, 'data', pickle.loads(data)) # nosec: Safe pickle load
return
raise NotImplementedError('You must override the unmarshal method or provide a data member')
...
def serialize(self) -> str:
"""
@ -132,3 +122,9 @@ class Serializable:
Returns true if this object needs remarshalling
"""
return self._needs_upgrade
def is_dirty(self) -> bool:
"""
Returns true if this object needs remarshalling
"""
return True

View File

@ -630,4 +630,4 @@ class UserService(Environmentable, Serializable, abc.ABC):
"""
Mainly used for debugging purposses
"""
return f'{self.__class__.__name__}({self.get_unique_id()})'
return f'{self.__class__.__name__}'

View File

@ -53,11 +53,13 @@ from uds.core.util import serializer, validators, ensure
logger = logging.getLogger(__name__)
# Old encryption key
UDSB: typing.Final[bytes] = b'udsprotect'
SERIALIZATION_HEADER: typing.Final[bytes] = b'GUIZ'
SERIALIZATION_VERSION: typing.Final[bytes] = b'\001'
# To simplify choice parameters declaration of fields
_ChoicesParamType: typing.TypeAlias = typing.Union[
collections.abc.Callable[[], list['types.ui.ChoiceItem']],
collections.abc.Iterable[str | types.ui.ChoiceItem],
dict[str, str],
None,
]
class gui:
@ -138,12 +140,7 @@ class gui:
# Helpers
@staticmethod
def as_choices(
vals: typing.Union[
collections.abc.Callable[[], list['types.ui.ChoiceItem']],
collections.abc.Iterable[typing.Union[str, types.ui.ChoiceItem]],
dict[str, str],
None,
]
vals: _ChoicesParamType,
) -> typing.Union[collections.abc.Callable[[], list['types.ui.ChoiceItem']], list['types.ui.ChoiceItem']]:
"""
Helper to convert from array of strings (or dictionaries) to the same dict used in choice,
@ -177,8 +174,11 @@ class gui:
@staticmethod
def sorted_choices(
choices: collections.abc.Iterable[types.ui.ChoiceItem], *, by_id: bool = False, reverse: bool = False,
key: typing.Optional[collections.abc.Callable[[types.ui.ChoiceItem], typing.Any]] = None
choices: collections.abc.Iterable[types.ui.ChoiceItem],
*,
by_id: bool = False,
reverse: bool = False,
key: typing.Optional[collections.abc.Callable[[types.ui.ChoiceItem], typing.Any]] = None,
) -> list[types.ui.ChoiceItem]:
if by_id:
key = lambda item: item['id']
@ -613,12 +613,7 @@ class gui:
tab: typing.Optional[typing.Union[str, types.ui.Tab]] = None,
default: typing.Union[collections.abc.Callable[[], str], str] = '',
value: typing.Optional[str] = None,
choices: typing.Union[
collections.abc.Callable[[], list['types.ui.ChoiceItem']],
collections.abc.Iterable[typing.Union[str, types.ui.ChoiceItem]],
dict[str, str],
None,
] = None,
choices: _ChoicesParamType = None,
old_field_name: types.ui.OldFieldNameType = None,
) -> None:
super().__init__(
@ -634,7 +629,9 @@ class gui:
old_field_name=old_field_name,
)
# Update parent type
self.field_type = types.ui.FieldType.TEXT_AUTOCOMPLETE # pyright: ignore[reportIncompatibleMethodOverride]
self.field_type = (
types.ui.FieldType.TEXT_AUTOCOMPLETE
) # pyright: ignore[reportIncompatibleMethodOverride]
self._fields_info.choices = gui.as_choices(choices or [])
def set_choices(self, values: collections.abc.Iterable[typing.Union[str, types.ui.ChoiceItem]]) -> None:
@ -1093,12 +1090,7 @@ class gui:
order: int = 0,
tooltip: str = '',
required: typing.Optional[bool] = None,
choices: typing.Union[
collections.abc.Callable[[], list['types.ui.ChoiceItem']],
collections.abc.Iterable[typing.Union[str, types.ui.ChoiceItem]],
dict[str, str],
None,
] = None,
choices: _ChoicesParamType = None,
fills: typing.Optional[types.ui.Filler] = None,
tab: typing.Optional[typing.Union[str, types.ui.Tab]] = None,
default: typing.Union[collections.abc.Callable[[], str], str, None] = None,
@ -1118,7 +1110,7 @@ class gui:
type=types.ui.FieldType.CHOICE,
)
self._fields_info.choices = gui.as_choices(choices or [])
self._fields_info.choices = gui.as_choices(choices)
# if has fillers, set them
if fills:
if 'function' not in fills or 'callback_name' not in fills:
@ -1161,12 +1153,7 @@ class gui:
order: int = 0,
tooltip: str = '',
required: typing.Optional[bool] = None,
choices: typing.Union[
collections.abc.Callable[[], list['types.ui.ChoiceItem']],
collections.abc.Iterable[typing.Union[str, types.ui.ChoiceItem]],
dict[str, str],
None,
] = None,
choices: _ChoicesParamType = None,
tab: typing.Optional[typing.Union[str, types.ui.Tab]] = None,
default: typing.Union[collections.abc.Callable[[], str], str, None] = None,
value: typing.Optional[str] = None,
@ -1252,12 +1239,7 @@ class gui:
order: int = 0,
tooltip: str = '',
required: typing.Optional[bool] = None,
choices: typing.Union[
collections.abc.Callable[[], list['types.ui.ChoiceItem']],
collections.abc.Iterable[typing.Union[str, types.ui.ChoiceItem]],
dict[str, str],
None,
] = None,
choices: _ChoicesParamType = None,
tab: typing.Optional[typing.Union[str, types.ui.Tab]] = None,
default: typing.Union[
collections.abc.Callable[[], str], collections.abc.Callable[[], list[str]], list[str], str, None
@ -1584,7 +1566,7 @@ class UserInterface(metaclass=UserInterfaceType):
if FIELDS_ENCODERS[field.field_type](field) is not None
]
return SERIALIZATION_HEADER + SERIALIZATION_VERSION + serializer.serialize(fields)
return consts.ui.SERIALIZATION_HEADER + consts.ui.SERIALIZATION_VERSION + serializer.serialize(fields)
def deserialize_fields(
self,
@ -1619,16 +1601,19 @@ class UserInterface(metaclass=UserInterfaceType):
if not values:
return False
if not values.startswith(SERIALIZATION_HEADER):
if not values.startswith(consts.ui.SERIALIZATION_HEADER):
# Unserialize with old method, and notify that we need to upgrade
self.deserialize_from_old_format(values)
return True
# For future use, right now we only have one version
# Prepared for a possible future versioning of data serialization
_version = values[len(SERIALIZATION_HEADER) : len(SERIALIZATION_HEADER) + len(SERIALIZATION_VERSION)]
_version = values[
len(consts.ui.SERIALIZATION_HEADER) : len(consts.ui.SERIALIZATION_HEADER)
+ len(consts.ui.SERIALIZATION_VERSION)
]
values = values[len(SERIALIZATION_HEADER) + len(SERIALIZATION_VERSION) :]
values = values[len(consts.ui.SERIALIZATION_HEADER) + len(consts.ui.SERIALIZATION_VERSION) :]
if not values: # Apart of the header, there is nothing...
logger.info('Empty values on unserialize_fields')
@ -1656,7 +1641,12 @@ class UserInterface(metaclass=UserInterfaceType):
logger.warning('Field %s has no decoder', field_name)
continue
if field_type != internal_field_type.name:
logger.warning('Field %s has different type than expected: %s != %s', field_name, field_type, internal_field_type.name)
logger.warning(
'Field %s has different type than expected: %s != %s',
field_name,
field_type,
internal_field_type.name,
)
continue
self._gui[field_name].value = FIELD_DECODERS[internal_field_type](field_value)
@ -1701,11 +1691,9 @@ class UserInterface(metaclass=UserInterfaceType):
if k in self._gui:
try:
if v.startswith(MULTIVALUE_FIELD):
val = pickle.loads( # nosec: safe pickle, controlled
v[1:]
) # nosec: secure pickled by us for sure
val = pickle.loads(v[1:])
elif v.startswith(OLD_PASSWORD_FIELD):
val = CryptoManager().aes_decrypt(v[1:], UDSB, True).decode()
val = CryptoManager().aes_decrypt(v[1:], consts.ui.UDSB, True).decode()
elif v.startswith(PASSWORD_FIELD):
val = CryptoManager().aes_decrypt(v[1:], UDSK, True).decode()
else:
@ -1811,7 +1799,9 @@ FIELD_DECODERS: typing.Final[
types.ui.FieldType.TEXT: lambda x: x,
types.ui.FieldType.TEXT_AUTOCOMPLETE: lambda x: x,
types.ui.FieldType.NUMERIC: int,
types.ui.FieldType.PASSWORD: lambda x: (CryptoManager.manager().aes_decrypt(x.encode(), UDSK, True).decode()),
types.ui.FieldType.PASSWORD: lambda x: (
CryptoManager.manager().aes_decrypt(x.encode(), UDSK, True).decode()
),
types.ui.FieldType.HIDDEN: lambda x: x,
types.ui.FieldType.CHOICE: lambda x: x,
types.ui.FieldType.MULTICHOICE: lambda x: serializer.deserialize(base64.b64decode(x.encode())),

View File

@ -376,7 +376,11 @@ class DictField(_SerializableField[dict[T, V]], dict[T, V]):
raise ValueError('Invalid dict data')
self.__set__(
instance,
dict(self._cast(k, v) for k, v in json.loads(data[1:]).items()) if self._cast else json.loads(data[1:]),
(
dict(self._cast(k, v) for k, v in json.loads(data[1:]).items())
if self._cast
else json.loads(data[1:])
),
)
@ -500,10 +504,14 @@ class AutoSerializable(Serializable, metaclass=_FieldNameSetter):
... d = ListField(defalut=lambda: [1, 2, 3])
"""
_fields: dict[str, typing.Any]
_fields: dict[str, typing.Any] # Values for the fields (serializable fields only ofc)
serialization_version: int = 0 # So autoserializable classes can keep their version if needed
def __init__(self):
super().__init__()
self._fields = {}
def _autoserializable_fields(self) -> collections.abc.Iterator[tuple[str, _SerializableField[typing.Any]]]:
"""Returns an iterator over all fields in the class, including inherited ones
(that is, all fields that are instances of _SerializableField in the class and its bases)
@ -513,9 +521,11 @@ class AutoSerializable(Serializable, metaclass=_FieldNameSetter):
"""
cls = self.__class__
while True:
# Get own fields first
for k, v in cls.__dict__.items():
if isinstance(v, _SerializableField):
yield k, v
# and then look for the first base that is also an AutoSerializable
for c in cls.__bases__:
if issubclass(c, AutoSerializable) and c != AutoSerializable:
cls = c
@ -624,6 +634,9 @@ class AutoSerializable(Serializable, metaclass=_FieldNameSetter):
logger.debug('Field %s not found in unmarshalled data', v.name)
v.__set__(self, v._default()) # Set default value
def as_dict(self) -> dict[str, typing.Any]:
return {k: v.__get__(self) for k, v in self._autoserializable_fields()}
def __eq__(self, other: typing.Any) -> bool:
"""
Basic equality check, checks if all _SerializableFields are equal
@ -649,9 +662,6 @@ class AutoSerializable(Serializable, metaclass=_FieldNameSetter):
[f"{k}={v.obj_type.__name__}({v.__get__(self)})" for k, v in self._autoserializable_fields()]
)
def as_dict(self) -> dict[str, typing.Any]:
return {k: v.__get__(self) for k, v in self._autoserializable_fields()}
class AutoSerializableCompressed(AutoSerializable):
"""This class allows the automatic serialization of fields in a class compressed with zlib."""

View File

@ -169,6 +169,10 @@ class ServicePoolPublication(UUIDModel):
:note: This method do not saves the updated record, just updates the field
"""
if not publication_instance.is_dirty():
logger.debug('Skipping update of publication %s, no changes', self)
return # Nothing to do
self.data = publication_instance.serialize()
self.save(update_fields=['data'])

View File

@ -258,6 +258,9 @@ class UserService(UUIDModel, properties.PropertiesMixin):
:note: This method SAVES the updated record, just updates the field
"""
if not userservice_instance.is_dirty():
logger.debug('Skipping update of user service %s, no changes', self)
return # Nothing to do
self.data = userservice_instance.serialize()
self.save(update_fields=['data'])
@ -550,7 +553,7 @@ class UserService(UUIDModel, properties.PropertiesMixin):
from uds.core.managers.userservice import UserServiceManager
# Cancel is a "forced" operation, so they are not checked against limits
UserServiceManager().cancel(self)
UserServiceManager.manager().cancel(self)
def remove_or_cancel(self) -> None:
"""

View File

@ -31,10 +31,10 @@
Author: Adolfo Gómez, dkmaster at dkmon dot com
"""
import logging
import dataclasses
import typing
from uds.core import services, types
from uds.core.util import autoserializable
# Not imported at runtime, just for type checking
if typing.TYPE_CHECKING:
@ -44,28 +44,18 @@ if typing.TYPE_CHECKING:
logger = logging.getLogger(__name__)
class TestUserService(services.UserService):
class TestUserService(services.UserService, autoserializable.AutoSerializable):
"""
Simple testing deployment, no cache
"""
@dataclasses.dataclass
class Data:
"""
This is the data we will store in the storage
"""
count: int = -1
ready: bool = False
name: str = ''
ip: str = ''
mac: str = ''
data: Data
count = autoserializable.IntegerField(default=-1)
ready = autoserializable.BoolField(default=False)
name = autoserializable.StringField(default='')
ip = autoserializable.StringField(default='')
mac = autoserializable.StringField(default='')
def initialize(self) -> None:
super().initialize()
self.data = TestUserService.Data()
# : Recheck every five seconds by default (for task methods)
suggested_delay = 5
@ -74,56 +64,56 @@ class TestUserService(services.UserService):
return typing.cast('TestServiceNoCache', super().service())
def get_name(self) -> str:
if not self.data.name:
self.data.name = self.name_generator().get(self.service().get_basename(), 3)
if not self.name:
self.name = self.name_generator().get(self.service().get_basename(), 3)
logger.info('Getting name of deployment %s', self.data)
logger.info('Getting name of deployment %s', self)
return self.data.name
return self.name
def set_ip(self, ip: str) -> None:
logger.info('Setting ip of deployment %s to %s', self.data, ip)
self.data.ip = ip
logger.info('Setting ip of deployment %s to %s', self, ip)
self.ip = ip
def get_unique_id(self) -> str:
logger.info('Getting unique id of deployment %s', self.data)
if not self.data.mac:
self.data.mac = self.mac_generator().get('00:00:00:00:00:00-00:FF:FF:FF:FF:FF')
return self.data.mac
logger.info('Getting unique id of deployment %s', self)
if not self.mac:
self.mac = self.mac_generator().get('00:00:00:00:00:00-00:FF:FF:FF:FF:FF')
return self.mac
def get_ip(self) -> str:
logger.info('Getting ip of deployment %s', self.data)
logger.info('Getting ip of deployment %s', self)
ip = typing.cast(str, self.storage.read_from_db('ip'))
if not ip:
ip = '8.6.4.2' # Sample IP for testing purposses only
return ip
def set_ready(self) -> types.states.TaskState:
logger.info('Setting ready %s', self.data)
self.data.ready = True
logger.info('Setting ready %s', self)
self.ready = True
return types.states.TaskState.FINISHED
def deploy_for_user(self, user: 'models.User') -> types.states.TaskState:
logger.info('Deploying for user %s %s', user, self.data)
self.data.count = 3
logger.info('Deploying for user %s %s', user, self)
self.count = 3
return types.states.TaskState.RUNNING
def deploy_for_cache(self, level: types.services.CacheLevel) -> types.states.TaskState:
logger.info('Deploying for cache %s %s', level, self.data)
self.data.count = 3
logger.info('Deploying for cache %s %s', level, self)
self.count = 3
return types.states.TaskState.RUNNING
def check_state(self) -> types.states.TaskState:
logger.info('Checking state of deployment %s', self.data)
if self.data.count <= 0:
logger.info('Checking state of deployment %s', self)
if self.count <= 0:
return types.states.TaskState.FINISHED
self.data.count -= 1
self.count -= 1
return types.states.TaskState.RUNNING
def finish(self) -> None:
logger.info('Finishing deployment %s', self.data)
self.data.count = -1
logger.info('Finishing deployment %s', self)
self.count = -1
def user_logged_in(self, username: str) -> None:
logger.info('User %s has logged in', username)
@ -135,8 +125,8 @@ class TestUserService(services.UserService):
return 'No error'
def destroy(self) -> types.states.TaskState:
logger.info('Destroying deployment %s', self.data)
self.data.count = -1
logger.info('Destroying deployment %s', self)
self.count = -1
return types.states.TaskState.FINISHED
def cancel(self) -> types.states.TaskState:

View File

@ -33,11 +33,11 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com
import random
import string
import logging
import dataclasses
import typing
from django.utils.translation import gettext as _
from uds.core import services, types
from uds.core.util import autoserializable
logger = logging.getLogger(__name__)
@ -46,26 +46,20 @@ if typing.TYPE_CHECKING:
pass
class TestPublication(services.Publication):
class TestPublication(services.Publication, autoserializable.AutoSerializable):
"""
Simple test publication
"""
suggested_delay = (
5 # : Suggested recheck time if publication is unfinished in seconds
)
suggested_delay = 5 # : Suggested recheck time if publication is unfinished in seconds
# Data to store
@dataclasses.dataclass
class Data:
name: str = ''
state: str = ''
reason: str = ''
number: int = -1
other: str = ''
other2: str = 'other2'
data: Data = Data()
name = autoserializable.StringField()
state = autoserializable.StringField()
reason = autoserializable.StringField()
number = autoserializable.IntegerField(default=-1)
other = autoserializable.StringField()
other2 = autoserializable.StringField(default='other2')
def initialize(self) -> None:
"""
@ -77,34 +71,37 @@ class TestPublication(services.Publication):
# We do not check anything at marshal method, so we ensure that
# default values are correctly handled by marshal.
self.data.name = ''.join(random.choices(string.ascii_letters, k=8))
self.data.state = types.states.TaskState.RUNNING
self.data.reason = 'none'
self.data.number = 10
self.name = ''.join(random.choices(string.ascii_letters, k=8))
self.state = types.states.TaskState.RUNNING
self.reason = 'none'
self.number = 10
def publish(self) -> types.states.TaskState:
logger.info('Publishing publication %s: %s remaining',self.data.name, self.data.number)
self.data.number -= 1
logger.info('Publishing publication %s: %s remaining', self.name, self.number)
self.number -= 1
if self.data.number <= 0:
self.data.state = types.states.TaskState.FINISHED
return types.states.TaskState.from_str(self.data.state)
if self.number <= 0:
self.state = types.states.TaskState.FINISHED
return types.states.TaskState.from_str(self.state)
def check_state(self) -> types.states.TaskState:
return types.states.TaskState.from_str(self.state)
def finish(self) -> None:
# Make simply a random string
logger.info('Finishing publication %s', self.data.name)
self.data.number = 0
self.data.state = types.states.TaskState.FINISHED
logger.info('Finishing publication %s', self.name)
self.number = 0
self.state = types.states.TaskState.FINISHED
def error_reason(self) -> str:
return self.data.reason
return self.reason
def destroy(self) -> types.states.TaskState:
logger.info('Destroying publication %s', self.data.name)
logger.info('Destroying publication %s', self.name)
return types.states.TaskState.FINISHED
def cancel(self) -> types.states.TaskState:
logger.info('Canceling publication %s', self.data.name)
logger.info('Canceling publication %s', self.name)
return self.destroy()
# Here ends the publication needed methods.
@ -117,4 +114,4 @@ class TestPublication(services.Publication):
the name generater for this publication. This is just a sample, and
this will do the work
"""
return self.data.name
return self.name