1
0
mirror of https://github.com/dkmstr/openuds.git synced 2024-12-22 13:34:04 +03:00

Adding autoserializable and fixing up some tests

This commit is contained in:
Adolfo Gómez García 2022-11-12 21:47:48 +01:00
parent 57013ed1e1
commit 9ca30b5c30
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
29 changed files with 765 additions and 60 deletions

View File

@ -1,3 +1,3 @@
udsactor-unmanaged_3.6.0_all.deb admin optional
udsactor_3.6.0_all.deb admin optional
udsactor_3.6.0_amd64.buildinfo admin optional
udsactor-unmanaged_4.0.0_all.deb admin optional
udsactor_4.0.0_all.deb admin optional
udsactor_4.0.0_amd64.buildinfo admin optional

View File

@ -41,12 +41,12 @@ from ...fixtures import osmanagers as osmanagers_fixtures
from ...fixtures import notifiers as notifiers_fixtures
from ...fixtures import stats_counters as stats_counters_fixtures
from ...utils.test import UDSTransactionTestCase
from ...utils.test import UDSTestCase
if typing.TYPE_CHECKING:
pass
logger = logging.getLogger(__name__)
class ModelXXTest(UDSTransactionTestCase):
class ModelXXTest(UDSTestCase):
pass

View File

@ -37,7 +37,7 @@ from uds import models
from uds.models.account_usage import AccountUsage
from ...fixtures import services as services_fixtures
from ...utils.test import UDSTransactionTestCase
from ...utils.test import UDSTestCase
if typing.TYPE_CHECKING:
pass
@ -47,7 +47,7 @@ logger = logging.getLogger(__name__)
NUM_USERSERVICES = 8
class ModelAccountTest(UDSTransactionTestCase):
class ModelAccountTest(UDSTestCase):
user_services: typing.List['models.UserService']
def setUp(self) -> None:

View File

@ -35,7 +35,7 @@ import logging
from uds import models
from uds.models.calendar_rule import freqs, dunits, weekdays
from ...utils.test import UDSTransactionTestCase
from ...utils.test import UDSTestCase
if typing.TYPE_CHECKING:
pass
@ -43,7 +43,7 @@ if typing.TYPE_CHECKING:
logger = logging.getLogger(__name__)
class ModelCalendarTest(UDSTransactionTestCase):
class ModelCalendarTest(UDSTestCase):
def test_calendar(self) -> None:
# Ensure we can create some calendars
for i in range(32):

View File

@ -35,12 +35,12 @@ import logging
from uds.core import messaging
from ...utils.test import UDSTransactionTestCase
from ...utils.test import UDSTestCase
from ...fixtures.images import createImage
if typing.TYPE_CHECKING:
from uds import models
class ModelImageTest(UDSTransactionTestCase):
class ModelImageTest(UDSTestCase):
pass

View File

@ -32,14 +32,14 @@
import typing
import logging
from ...utils.test import UDSTransactionTestCase
from ...utils.test import UDSTestCase
from ...fixtures import authenticators as authenticators_fixtures
from uds import models
if typing.TYPE_CHECKING:
pass
class ModelUUIDTest(UDSTransactionTestCase):
class ModelUUIDTest(UDSTestCase):
auth: 'models.Authenticator'
user: 'models.User'
group: 'models.Group'

View File

@ -32,7 +32,7 @@
@author: Adolfo Gómez, dkmaster at dkmon dot com
"""
# We use commit/rollback
from ...utils.test import UDSTransactionTestCase
from ...utils.test import UDSTestCase
from uds.core.ui.user_interface import (
gui,
UserInterface
@ -42,7 +42,7 @@ import time
from django.conf import settings
class UserinterfaceTest(UDSTransactionTestCase):
class UserinterfaceTest(UDSTestCase):
def test_userinterface(self):
pass

View File

@ -0,0 +1,114 @@
# -*- coding: utf-8 -*-
#
# Copyright (c) 2012-2022 Virtual Cable S.L.U.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of Virtual Cable S.L. nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
@author: Adolfo Gómez, dkmaster at dkmon dot com
"""
# We use commit/rollback
import typing
from ...utils.test import UDSTestCase
from uds.core.util import auto_serializable
UNICODE_CHARS = 'ñöçóá^(pípè)'
UNICODE_CHARS_2 = 'ñöçóá^(€íöè)'
class AutoSerializableClass(auto_serializable.AutoSerializable):
int_field = auto_serializable.IntField()
str_field = auto_serializable.StringField()
float_field = auto_serializable.FloatField()
bool_field = auto_serializable.BoolField()
password_field = auto_serializable.PasswordField()
list_field = auto_serializable.ListField()
dict_field = auto_serializable.DictField()
class AutoSerializableCompressedClass(auto_serializable.AutoSerializableCompressed):
int_field = auto_serializable.IntField()
str_field = auto_serializable.StringField()
float_field = auto_serializable.FloatField()
bool_field = auto_serializable.BoolField()
password_field = auto_serializable.PasswordField()
list_field = auto_serializable.ListField()
dict_field = auto_serializable.DictField()
class AutoSerializableEncryptedClass(auto_serializable.AutoSerializableEncrypted):
int_field = auto_serializable.IntField()
str_field = auto_serializable.StringField()
float_field = auto_serializable.FloatField()
bool_field = auto_serializable.BoolField()
password_field = auto_serializable.PasswordField()
list_field = auto_serializable.ListField()
dict_field = auto_serializable.DictField()
class AutoSerializable(UDSTestCase):
def basic_check(
self,
cls1: typing.Type[
'AutoSerializableClass|AutoSerializableCompressedClass|AutoSerializableEncryptedClass'
],
cls2: typing.Type[
'AutoSerializableClass|AutoSerializableCompressedClass|AutoSerializableEncryptedClass'
],
) -> None:
# Test basic serialization
a = cls1()
a.int_field = 1
a.str_field = UNICODE_CHARS
a.float_field = 3.0
a.bool_field = True
a.password_field = UNICODE_CHARS_2 # nosec: test password
a.list_field = [1, 2, 3]
a.dict_field = {'a': 1, 'b': 2, 'c': 3}
data = a.marshal()
b = cls2()
b.unmarshal(data)
self.assertEqual(a, b)
def test_auto_serializable_base(self):
self.basic_check(AutoSerializableClass, AutoSerializableClass)
def test_auto_serializable_compressed(self):
self.basic_check(AutoSerializableCompressedClass, AutoSerializableCompressedClass)
def test_auto_serializable_encrypted(self):
self.basic_check(AutoSerializableEncryptedClass, AutoSerializableEncryptedClass)
def test_auto_serializable_base_compressed(self):
self.basic_check(AutoSerializableClass, AutoSerializableCompressedClass)
def test_auto_serializable_base_encrypted(self):
self.basic_check(AutoSerializableClass, AutoSerializableEncryptedClass)

View File

@ -32,40 +32,59 @@
@author: Adolfo Gómez, dkmaster at dkmon dot com
"""
# We use commit/rollback
from ...utils.test import UDSTransactionTestCase
from ...utils.test import UDSTestCase
from uds.core.util.cache import Cache
import time
import string
# Some random chars, that include unicode non-ascci chars
UNICODE_CHARS = 'ñöçóá^(pípè)'
UNICODE_CHARS_2 = 'ñöçóá^(€íöè)'
VALUE_1 = [u'únîcödè€', b'string', {'a': 1, 'b': 2.0}]
class CacheTest(UDSTransactionTestCase):
class CacheTest(UDSTestCase):
def test_cache(self):
cache = Cache(UNICODE_CHARS)
# Get default value, with unicode
self.assertEqual(cache.get(UNICODE_CHARS, UNICODE_CHARS_2), UNICODE_CHARS_2, 'Unicode unexisting key returns default unicode')
self.assertEqual(
cache.get(UNICODE_CHARS, UNICODE_CHARS_2),
UNICODE_CHARS_2,
'Unicode unexisting key returns default unicode',
)
# Remove unexisting key, not a problem
self.assertEqual(cache.remove('non-existing-1'), False, 'Removing unexisting key')
self.assertEqual(
cache.remove('non-existing-1'), False, 'Removing unexisting key'
)
# Add new key (non existing) with default duration (60 seconds probable)
cache.put(UNICODE_CHARS_2, VALUE_1)
# checks it
self.assertEqual(cache.get(UNICODE_CHARS_2), VALUE_1, 'Put a key and recover it')
self.assertEqual(
cache.get(UNICODE_CHARS_2), VALUE_1, 'Put a key and recover it'
)
# Add new "str" key, with 1 second life, wait 2 seconds and recover
cache.put(b'key', VALUE_1, 1)
time.sleep(1.1)
self.assertEqual(cache.get(b'key'), None, 'Put an "str" key and recover it once it has expired')
self.assertEqual(
cache.get(b'key'),
None,
'Put an "str" key and recover it once it has expired',
)
# Refresh previous key and will be again available
cache.refresh(b'key')
self.assertEqual(cache.get(b'key'), VALUE_1, 'Refreshed cache key is {} and should be {}'.format(cache.get(b'key'), VALUE_1))
self.assertEqual(
cache.get(b'key'),
VALUE_1,
'Refreshed cache key is {} and should be {}'.format(
cache.get(b'key'), VALUE_1
),
)
# Checks cache clean
cache.put('key', VALUE_1)
@ -82,5 +101,8 @@ class CacheTest(UDSTransactionTestCase):
time.sleep(0.1)
Cache.cleanUp()
cache.refresh('key')
self.assertEqual(cache.get('key'), None, 'Put a key and recover it once it has expired and has been cleaned')
self.assertEqual(
cache.get('key'),
None,
'Put a key and recover it once it has expired and has been cleaned',
)

View File

@ -30,14 +30,14 @@
@author: Adolfo Gómez, dkmaster at dkmon dot com
"""
from ...utils.test import UDSTransactionTestCase
from ...utils.test import UDSTestCase
from ...fixtures.calendars import createCalendars
from uds.core.util import calendar
from uds.models import Calendar
import datetime
class CalendarTest(UDSTransactionTestCase):
class CalendarTest(UDSTestCase):
def setUp(self) -> None:
createCalendars()

View File

@ -33,14 +33,14 @@ import typing
import logging
from ...utils.test import UDSTransactionTestCase
from ...utils.test import UDSTestCase
from uds.core.util import net
logger = logging.getLogger(__name__)
class NetTest(UDSTransactionTestCase):
class NetTest(UDSTestCase):
def testNetworkFromString(self):
for n in (

View File

@ -37,7 +37,7 @@ from uds.core.util import permissions
from uds.core.util import ot
from uds import models
from ...utils.test import UDSTransactionTestCase
from ...utils.test import UDSTestCase
from ...fixtures import (
authenticators as authenticators_fixtures,
services as services_fixtures,
@ -45,14 +45,18 @@ from ...fixtures import (
)
class PermissionsTest(UDSTransactionTestCase):
class PermissionsTest(UDSTestCase):
authenticator: models.Authenticator
groups: typing.List[models.Group]
users: typing.List[models.User]
admins: typing.List[models.User]
staffs: typing.List[models.User]
userService: models.UserService
servicePool: models.ServicePool
service: models.Service
provider: models.Provider
network: models.Network
def setUp(self) -> None:
self.authenticator = authenticators_fixtures.createAuthenticator()
@ -72,6 +76,10 @@ class PermissionsTest(UDSTransactionTestCase):
list(self.users[0].groups.all()),
'managed',
)
self.servicePool = self.userService.deployed_service
self.service = self.servicePool.service
self.provider = self.service.provider
self.network = network_fixtures.createNetwork()
def doTestUserPermissions(self, obj, user: models.User):

View File

@ -37,7 +37,7 @@ import typing
from ...fixtures.stats_counters import create_stats_counters
# We use commit/rollback
from ...utils.test import UDSTransactionTestCase
from ...utils.test import UDSTestCase
from uds.core.util.stats import counters
from uds import models
@ -47,7 +47,7 @@ END_DATE_DAY = datetime.datetime(2020, 1, 2, 0, 0, 0)
END_DATE_MONTH = datetime.datetime(2020, 2, 1, 0, 0, 0)
END_DATE_YEAR = datetime.datetime(2021, 1, 1, 0, 0, 0)
class StatsCountersTest(UDSTransactionTestCase):
class StatsCountersTest(UDSTestCase):
def setUp(self) -> None:
return super().setUp()

View File

@ -30,7 +30,7 @@
@author: Adolfo Gómez, dkmaster at dkmon dot com
"""
from ...utils.test import UDSTransactionTestCase
from ...utils.test import UDSTestCase
from uds.core.util.storage import Storage
UNICODE_CHARS = 'ñöçóá^(pípè)'
@ -38,7 +38,7 @@ UNICODE_CHARS_2 = 'ñöçóá^(€íöè)'
VALUE_1 = ['unicode', b'string', {'a': 1, 'b': 2.0}]
class StorageTest(UDSTransactionTestCase):
class StorageTest(UDSTestCase):
def test_storage(self):
storage = Storage(UNICODE_CHARS)

View File

@ -31,8 +31,6 @@
@author: Adolfo Gómez, dkmaster at dkmon dot com
"""
import time
import sys
import threading
from ...utils.test import UDSTestCase
from django.conf import settings

View File

@ -38,11 +38,11 @@ from uds.core.util import config
from uds.core.util.state import State
from uds.core.workers.assigned_unused import AssignedAndUnused
from ...utils.test import UDSTransactionTestCase
from ...utils.test import UDSTestCase
from ...fixtures import services as fixtures_services
class AssignedAndUnusedTest(UDSTransactionTestCase):
class AssignedAndUnusedTest(UDSTestCase):
userServices: typing.List[models.UserService]
def setUp(self):

View File

@ -38,14 +38,14 @@ from uds.core.util import config
from uds.core.util.state import State
from uds.core.workers.hanged_userservice_cleaner import HangedCleaner
from ...utils.test import UDSTransactionTestCase
from ...utils.test import UDSTestCase
from ...fixtures import services as fixtures_services
MAX_INIT = 300
TEST_SERVICES = 5 * 5 # Ensure multiple of 5 for testing
class HangedCleanerTest(UDSTransactionTestCase):
class HangedCleanerTest(UDSTestCase):
userServices: typing.List[models.UserService]
def setUp(self):

View File

@ -40,7 +40,7 @@ from uds.core.environment import Environment
from uds.services.Test.provider import TestProvider
from uds.services.Test.service import TestServiceCache, TestServiceNoCache
from ...utils.test import UDSTransactionTestCase
from ...utils.test import UDSTestCase
from ...fixtures import services as services_fixtures
if typing.TYPE_CHECKING:
@ -48,7 +48,7 @@ if typing.TYPE_CHECKING:
logger = logging.getLogger(__name__)
class ServiceCacheUpdaterTest(UDSTransactionTestCase):
class ServiceCacheUpdaterTest(UDSTestCase):
servicePool: 'models.ServicePool'
def setUp(self) -> None:

View File

@ -36,7 +36,7 @@ import datetime
from uds import models
from uds.core.util.stats import counters
from ...utils.test import UDSTransactionTestCase
from ...utils.test import UDSTestCase
from ...fixtures import stats_counters as fixtures_stats_counters
from uds.core.workers import stats_collector
@ -65,7 +65,7 @@ class StatsFunction:
return self.counter * 100
class StatsAcummulatorTest(UDSTransactionTestCase):
class StatsAcummulatorTest(UDSTestCase):
def setUp(self):
# In fact, real data will not be assigned to Userservices, but it's ok for testing
for pool_id in range(NUMBER_OF_POOLS):

View File

@ -33,7 +33,7 @@ import typing
import logging
import datetime
from ...utils.test import UDSTransactionTestCase
from ...utils.test import UDSTestCase
from ...fixtures import services as services_fixtures
from uds.models import UserService
@ -47,7 +47,7 @@ if typing.TYPE_CHECKING:
logger = logging.getLogger(__name__)
class StuckCleanerTest(UDSTransactionTestCase):
class StuckCleanerTest(UDSTestCase):
userServices: typing.List['models.UserService']
def setUp(self) -> None:

View File

@ -36,7 +36,7 @@ import logging
from uds.core import messaging
from ..fixtures import notifiers as notifiers_fixtures
from ..utils.test import UDSTransactionTestCase
from ..utils.test import UDSTestCase
if typing.TYPE_CHECKING:
from uds import models
@ -45,7 +45,7 @@ if typing.TYPE_CHECKING:
logger = logging.getLogger(__name__)
class EmailNotifierTest(UDSTransactionTestCase):
class EmailNotifierTest(UDSTestCase):
"""
Test Email Notifier
"""

View File

@ -45,7 +45,7 @@ from uds.REST.handlers import AUTH_TOKEN_HEADER
NUMBER_OF_ITEMS_TO_CREATE = 4
class RESTTestCase(test.UDSTransactionTestCase):
class RESTTestCase(test.UDSTestCase):
# Authenticators related
auth: models.Authenticator
groups: typing.List[models.Group]

View File

@ -113,7 +113,7 @@ class UDSTestCase(TestCase):
super().setUpClass()
setupClass(cls)
class UDSTransactionTestCase(TransactionTestCase):
class UDSTestCase(TransactionTestCase):
client_class: typing.Type = UDSClient
client: UDSClient
@ -123,6 +123,6 @@ class UDSTransactionTestCase(TransactionTestCase):
super().setUpClass()
setupClass(cls)
def setupClass(cls: typing.Union[typing.Type[UDSTestCase], typing.Type[UDSTransactionTestCase]]) -> None:
def setupClass(cls: typing.Union[typing.Type[UDSTestCase], typing.Type[UDSTestCase]]) -> None:
# Nothing right now
pass

View File

@ -42,7 +42,7 @@ from uds.REST.handlers import AUTH_TOKEN_HEADER
NUMBER_OF_ITEMS_TO_CREATE = 4
class WEBTestCase(test.UDSTransactionTestCase):
class WEBTestCase(test.UDSTestCase):
# Authenticators related
auth: models.Authenticator
groups: typing.List[models.Group]

View File

@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
#
# Copyright (c) 2012-2019 Virtual Cable S.L.
# Copyright (c) 2012-2022 Virtual Cable S.L.U.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
@ -12,7 +12,7 @@
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of Virtual Cable S.L. nor the names of its contributors
# * Neither the name of Virtual Cable S.L.U. nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
@ -28,5 +28,5 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
@author: Adolfo Gómez, dkmaster at dkmon dot com
Author: Adolfo Gómez, dkmaster at dkmon dot com
"""

View File

@ -0,0 +1,536 @@
# -*- coding: utf-8 -*-
#
# Copyright (c) 2022 Virtual Cable S.L.U.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of Virtual Cable S.L. nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
Author: Adolfo Gómez, dkmaster at dkmon dot com
Implements de AutoSerializable class, that allows to serialize/deserialize in a simple way
This class in incompatible with UserInterface derived classes, as it metaclass is not compatible
To use it, simple place it as first parent class, and follow by the rest of the classes to inherit from
Example:
from uds.core.util import AutoSerializable
from uds.core import services
class UserDeploymentService(AutoSerializable, services.UserDeployment):
...
"""
import itertools
import typing
import logging
import json
import zlib
import base64
import hashlib
import struct
# Import the cryptography library
from cryptography import fernet
from django.conf import settings
class _Unassigned:
pass
# means field has no default value
UNASSIGNED = _Unassigned()
T = typing.TypeVar('T')
DefaultValueType = typing.Union[T, typing.Callable[[], T], _Unassigned]
logger = logging.getLogger(__name__)
# Constants
# Headers for the serialized data
HEADER_BASE: typing.Final[bytes] = b'MGB1'
HEADER_COMPRESSED: typing.Final[bytes] = b'MGZ1'
HEADER_ENCRYPTED: typing.Final[bytes] = b'MGE1'
# Size of crc32 checksum
CRC_SIZE: typing.Final[int] = 4
# Packing data struct
pack_struct = struct.Struct('<HHI')
# Helper functions
def fernet_key(crypt_key: bytes) -> str:
"""Generate key from password and seed
Args:
seed: Seed to use (normally header)
Note: if password is not set, this will raise an exception
"""
# Generate an URL-Safe base64 encoded 32 bytes key for Fernet
# First, with seed + password, generate a 32 bytes key based on seed + password
return base64.b64encode(hashlib.sha256(crypt_key).digest()).decode()
class _SerializableField(typing.Generic[T]):
name: str
type: typing.Type[T]
default: DefaultValueType
def __init__(self, type: typing.Type[T], default: DefaultValueType = UNASSIGNED):
self.type = type
self.default = default
def _default(self) -> T:
if isinstance(self.default, _Unassigned):
return self.type()
elif callable(self.default):
return self.default()
else:
return self.default
def __get__(
self,
instance: 'AutoSerializable',
objtype: typing.Optional[typing.Type['AutoSerializable']] = None,
) -> T:
"""Get field value
Arguments:
instance {SerializableFields} -- Instance of class with field
"""
if hasattr(instance, '_fields'):
return getattr(instance, '_fields').get(self.name, self._default())
if self.default is None:
raise AttributeError(f"Field {self.name} is not set")
return self._default()
def __set__(self, instance: 'AutoSerializable', value: T) -> None:
if not isinstance(value, self.type):
raise TypeError(
f"Field {self.name} cannot be set to {value} (type {self.type.__name__})"
)
if not hasattr(instance, '_fields'):
setattr(instance, '_fields', {})
getattr(instance, '_fields')[self.name] = value
def marshal(self, instance: 'AutoSerializable') -> bytes:
"""Basic marshalling of field
Args:
instance: Instance of class with field
Returns:
Marshalled field
Note:
Only str, int, and float are supported in this base class.
"""
if self.type in (str, int, float):
return str(self.__get__(instance)).encode()
raise TypeError(f"Field {self.name} cannot be marshalled (type {self.type})")
def unmarshal(self, instance: 'AutoSerializable', data: bytes) -> None:
"""Basic unmarshalling of field
Args:
instance: Instance of class with field
data: Marshalled field string
Returns:
None. The data is loaded into the field.
Note:
Only str, int, and float are supported in this base class.
"""
if self.type in (str, int, float):
tp: typing.Type = self.type
self.__set__(instance, tp(data.decode()))
return
raise TypeError(f"Field {self.name} cannot be unmarshalled (type {self.type})")
# Integer field
class IntField(_SerializableField[int]):
def __init__(self, default: int = 0):
super().__init__(int, default)
class StringField(_SerializableField[str]):
def __init__(self, default: str = ''):
super().__init__(str, default)
class FloatField(_SerializableField[float]):
def __init__(self, default: float = 0.0):
super().__init__(float, default)
class BoolField(_SerializableField[bool]):
def __init__(self, default: bool = False):
super().__init__(bool, default)
def marshal(self, instance: 'AutoSerializable') -> bytes:
return b'1' if self.__get__(instance) else b'0'
def unmarshal(self, instance: 'AutoSerializable', data: bytes) -> None:
self.__set__(instance, data == b'1')
class ListField(_SerializableField[typing.List]):
"""List field
Note:
All elements in the list must be serializable.
"""
def __init__(
self,
default: typing.Union[
typing.List, typing.Callable[[], typing.List]
] = lambda: [],
):
super().__init__(list, default)
def marshal(self, instance: 'AutoSerializable') -> bytes:
return json.dumps(self.__get__(instance)).encode()
def unmarshal(self, instance: 'AutoSerializable', data: bytes) -> None:
self.__set__(instance, json.loads(data))
class DictField(_SerializableField[typing.Dict]):
"""Dict field
Note:
All elements in the dict must be serializable.
"""
def __init__(
self,
default: typing.Union[
typing.Dict, typing.Callable[[], typing.Dict]
] = lambda: {},
):
super().__init__(dict, default)
def marshal(self, instance: 'AutoSerializable') -> bytes:
return json.dumps(self.__get__(instance)).encode()
def unmarshal(self, instance: 'AutoSerializable', data: bytes) -> None:
self.__set__(instance, json.loads(data))
class PasswordField(StringField):
"""Password field
Note:
The password is stored as a compressed string.
"""
_crypt_key: str = ''
def __init__(self, default: str = '', crypt_key: str = ''):
super().__init__(default)
self._crypt_key = crypt_key
def _encrypt(self, value: str) -> bytes:
"""Encrypt a password
Args:
value: Password to encrypt
Returns:
Encrypted password
"""
if self._crypt_key:
# Generate a Fernet key from the password
f = fernet.Fernet(fernet_key(self._crypt_key.encode()))
return HEADER_ENCRYPTED + f.encrypt(value.encode())
logger.warning("Password encryption is not enabled")
return zlib.compress(value.encode())
def _decrypt(self, value: bytes) -> bytes:
"""Decrypt a password
Args:
value: Password to decrypt
Returns:
Decrypted password
"""
if self._crypt_key and value[: len(HEADER_ENCRYPTED)] == HEADER_ENCRYPTED:
try:
f = fernet.Fernet(fernet_key(self._crypt_key.encode()))
return f.decrypt(value[len(HEADER_ENCRYPTED) :])
except Exception: # nosec: Defaults to zlib compression out of the exception
pass # returns the unencrypted password
return zlib.decompress(value)
def marshal(self, instance: 'AutoSerializable') -> bytes:
return base64.b64encode(self._encrypt(self.__get__(instance)))
def unmarshal(self, instance: 'AutoSerializable', data: bytes) -> None:
self.__set__(instance, self._decrypt(base64.b64decode(data)).decode())
# ************************
# * Serializable classes *
# ************************
class _FieldNameSetter(type):
"""Simply adds the name of the field in the class to the field object"""
def __new__(cls, name, bases, attrs):
fields = dict()
for k, v in attrs.items():
if isinstance(v, _SerializableField):
v.name = k
return super().__new__(cls, name, bases, attrs)
class AutoSerializable(metaclass=_FieldNameSetter):
"""This class allows the automatic serialization of fields in a class.
Example:
>>> class Test(SerializableFields):
... a = IntField()
... b = StrField()
... c = FloatField()
... d = ListField(defalut=lambda: [1, 2, 3])
"""
_fields: typing.Dict[str, typing.Any]
def process_data(self, header: bytes, data: bytes) -> bytes:
"""Process data before marshalling
Args:
data: Data to process
Returns:
Processed data
Note:
We provide header as a way of some constant value to be used in the
processing. This is useful for encryption. (as salt, for example)
"""
return bytes(a ^ b for a, b in zip(data, itertools.cycle(header)))
return data
def unprocess_data(self, header: bytes, data: bytes) -> bytes:
"""Process data after unmarshalling
Args:
data: Data to process
Returns:
Processed data
Note:
We provide header as a way of some constant value to be used in the
processing. This is useful for encryption. (as salt, for example)
"""
return bytes(a ^ b for a, b in zip(data, itertools.cycle(header)))
@typing.final
def marshal(self) -> bytes:
# Iterate over own members and extract fields
fields = {}
for k, v in self.__class__.__dict__.items():
if isinstance(v, _SerializableField):
fields[v.name] = (str(v.__class__.__name__), v.marshal(self))
# Serialized data is:
# 2 bytes -> name length
# 2 bytes -> type name length
# 4 bytes -> data length
# n bytes -> name
# n bytes -> type name
# n bytes -> data
data = b''.join(
pack_struct.pack(len(name.encode()), len(type_name.encode()), len(value))
+ name.encode()
+ type_name.encode()
+ value
for name, (type_name, value) in fields.items()
)
# Calculate checksum
checksum = zlib.crc32(data)
# Compose header, that is V1_HEADER + checksum (4 bytes, big endian)
header = HEADER_BASE + checksum.to_bytes(CRC_SIZE, 'big')
# Return data processed with header
return header + self.process_data(header, data)
# final method, do not override
@typing.final
def unmarshal(self, data: bytes) -> None:
# Check header
if data[: len(HEADER_BASE)] != HEADER_BASE:
raise ValueError('Invalid header')
header = data[: len(HEADER_BASE) + CRC_SIZE]
# Extract checksum
checksum = int.from_bytes(
header[len(HEADER_BASE) : len(HEADER_BASE) + 4], 'big'
)
# Unprocess data
data = self.unprocess_data(header, data[len(header) :])
# Check checksum
if zlib.crc32(data) != checksum:
raise ValueError('Invalid checksum')
# Iterate over fields
fields = {}
while data:
# Extract name length, type name length and data length
name_len, type_name_len, data_len = pack_struct.unpack(data[:8])
# Extract name, type name and data
name, type_name, value = (
data[8 : 8 + name_len].decode(),
data[8 + name_len : 8 + name_len + type_name_len].decode(),
data[
8
+ name_len
+ type_name_len : 8
+ name_len
+ type_name_len
+ data_len
],
)
# Add to fields
fields[name] = (type_name, value)
# Remove from data
data = data[8 + name_len + type_name_len + data_len :]
for k, v in self.__class__.__dict__.items():
if isinstance(v, _SerializableField):
if v.name in fields and fields[v.name][0] == str(v.__class__.__name__):
v.unmarshal(self, fields[v.name][1])
else:
if not v.name in fields:
logger.warning(f"Field {v.name} not found in unmarshalled data")
else:
logger.warning(
f"Field {v.name} has wrong type in unmarshalled data (should be {fields[v.name][0]} and is {v.__class__.name}"
)
def __eq__(self, other: typing.Any) -> bool:
"""
Basic equality check, checks if all _SerializableFields are equal
"""
if not isinstance(other, AutoSerializable):
return False
for k, v in self.__class__.__dict__.items():
if isinstance(v, _SerializableField):
if getattr(self, k) != getattr(other, k):
return False
return True
def __str__(self) -> str:
return ', '.join(
[
f"{k}={v.type.__name__}({v.__get__(self)})"
for k, v in self.__class__.__dict__.items()
if isinstance(v, _SerializableField)
]
)
class AutoSerializableCompressed(AutoSerializable):
"""This class allows the automatic serialization of fields in a class compressed with zlib."""
def process_data(self, header: bytes, data: bytes) -> bytes:
return HEADER_COMPRESSED + zlib.compress(data)
def unprocess_data(self, header: bytes, data: bytes) -> bytes:
# if decompress fails, return data as is
try:
# Check header
if data[: len(HEADER_COMPRESSED)] != HEADER_COMPRESSED:
raise Exception() # Returns data as is
return zlib.decompress(data[len(HEADER_COMPRESSED) :])
except Exception:
return super().unprocess_data(header, data)
class AutoSerializableEncrypted(AutoSerializable):
"""This class allows the automatic serialization of fields in a class encrypted with AES."""
# Common key for all instances
_crypt_key: typing.ClassVar[str] = settings.SECRET_KEY[:16]
def key(self, seed: bytes) -> str:
"""Generate key from password and seed
Args:
seed: Seed to use (normally header)
Note: if password is not set, this will raise an exception
"""
if not self._crypt_key:
raise ValueError('Password not set')
return fernet_key(seed + (self._crypt_key.encode()))
def process_data(self, header: bytes, data: bytes) -> bytes:
f = fernet.Fernet(self.key(header))
return HEADER_ENCRYPTED + f.encrypt(data)
def unprocess_data(self, header: bytes, data: bytes) -> bytes:
# if decrypt fails, return data as is
try:
# Check if data is encrypted
if data[: len(HEADER_ENCRYPTED)] != HEADER_ENCRYPTED:
return super().unprocess_data(header, data)
f = fernet.Fernet(self.key(header))
return f.decrypt(data[len(HEADER_ENCRYPTED) :])
except fernet.InvalidToken:
return super().unprocess_data(header, data)
@staticmethod
def set_crypt_key(crypt_key: str) -> None:
"""Set the password for all instances of this class.
Args:
password: Password to set
Note:
On Django, this should be set preferably in settings.py,
so all instances of this class will use the same password from the start.
"""
AutoSerializableEncrypted._crypt_key = crypt_key[:16]

View File

@ -1,4 +1,5 @@
# Generated by Django 4.1 on 2022-10-01 06:23
# Generated by Django 4.1.3 on 2022-11-12 21:03
import logging
from django.db import migrations, models
import django.db.models.deletion
@ -7,6 +8,20 @@ import uds.models.notifications
import uds.models.user_service_session
import uds.models.util
logger = logging.getLogger('uds')
def remove_servicepool_with_null_service(apps, schema_editor):
""" In fact, there should be no Service Pool with no service, but we have found some in the wild"""
ServicePool = apps.get_model('uds', 'ServicePool')
# Log in the django.db.backends logger removed services
logger.info('Removing ServicePools with null service')
for i in ServicePool.objects.filter(service=None):
logger.info(' * Removing ServicePool %s - %s', i.uuid, i.name)
i.delete()
def null_backwards(apps, schema_editor):
# Remove null services backwards is not possible, we have deleted them
pass
class Migration(migrations.Migration):
@ -15,6 +30,8 @@ class Migration(migrations.Migration):
]
operations = [
# First, we remove all DeployedServices with null service because we have fixed foreign key
migrations.RunPython(remove_servicepool_with_null_service, null_backwards),
migrations.CreateModel(
name="Notification",
fields=[
@ -313,6 +330,15 @@ class Migration(migrations.Migration):
default=uds.core.util.model.generateUuid, max_length=50, unique=True
),
),
migrations.AlterField(
model_name="servicepool",
name="service",
field=models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="deployedServices",
to="uds.service",
),
),
migrations.AlterField(
model_name="servicepool",
name="uuid",

View File

@ -84,10 +84,8 @@ class ServicePool(UUIDModel, TaggingMixin): # type: ignore
name = models.CharField(max_length=128, default='')
short_name = models.CharField(max_length=32, default='')
comments = models.CharField(max_length=256, default='')
service: 'models.ForeignKey[Service | None]' = models.ForeignKey(
service: 'models.ForeignKey[Service]' = models.ForeignKey(
Service,
null=True,
blank=True,
related_name='deployedServices',
on_delete=models.CASCADE,
)

View File

@ -74,6 +74,9 @@ class UserService(UUIDModel): # pylint: disable=too-many-public-methods
deployed_service: 'models.ForeignKey["ServicePool"]' = models.ForeignKey(
ServicePool, on_delete=models.CASCADE, related_name='userServices'
)
# Althoug deployed_services has its publication, the user service is bound to a specific publication
# so we need to store the publication id here (or the revision, but we need to store something)
# storing the id simplifies the queries
publication: 'models.ForeignKey[ServicePoolPublication | None]' = models.ForeignKey(
ServicePoolPublication,
on_delete=models.CASCADE,