mirror of
https://github.com/dkmstr/openuds.git
synced 2024-12-22 13:34:04 +03:00
Adding autoserializable and fixing up some tests
This commit is contained in:
parent
57013ed1e1
commit
9ca30b5c30
@ -1,3 +1,3 @@
|
||||
udsactor-unmanaged_3.6.0_all.deb admin optional
|
||||
udsactor_3.6.0_all.deb admin optional
|
||||
udsactor_3.6.0_amd64.buildinfo admin optional
|
||||
udsactor-unmanaged_4.0.0_all.deb admin optional
|
||||
udsactor_4.0.0_all.deb admin optional
|
||||
udsactor_4.0.0_amd64.buildinfo admin optional
|
||||
|
@ -41,12 +41,12 @@ from ...fixtures import osmanagers as osmanagers_fixtures
|
||||
from ...fixtures import notifiers as notifiers_fixtures
|
||||
from ...fixtures import stats_counters as stats_counters_fixtures
|
||||
|
||||
from ...utils.test import UDSTransactionTestCase
|
||||
from ...utils.test import UDSTestCase
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ModelXXTest(UDSTransactionTestCase):
|
||||
class ModelXXTest(UDSTestCase):
|
||||
pass
|
||||
|
@ -37,7 +37,7 @@ from uds import models
|
||||
from uds.models.account_usage import AccountUsage
|
||||
from ...fixtures import services as services_fixtures
|
||||
|
||||
from ...utils.test import UDSTransactionTestCase
|
||||
from ...utils.test import UDSTestCase
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
pass
|
||||
@ -47,7 +47,7 @@ logger = logging.getLogger(__name__)
|
||||
NUM_USERSERVICES = 8
|
||||
|
||||
|
||||
class ModelAccountTest(UDSTransactionTestCase):
|
||||
class ModelAccountTest(UDSTestCase):
|
||||
user_services: typing.List['models.UserService']
|
||||
|
||||
def setUp(self) -> None:
|
||||
|
@ -35,7 +35,7 @@ import logging
|
||||
|
||||
from uds import models
|
||||
from uds.models.calendar_rule import freqs, dunits, weekdays
|
||||
from ...utils.test import UDSTransactionTestCase
|
||||
from ...utils.test import UDSTestCase
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
pass
|
||||
@ -43,7 +43,7 @@ if typing.TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelCalendarTest(UDSTransactionTestCase):
|
||||
class ModelCalendarTest(UDSTestCase):
|
||||
def test_calendar(self) -> None:
|
||||
# Ensure we can create some calendars
|
||||
for i in range(32):
|
||||
|
@ -35,12 +35,12 @@ import logging
|
||||
|
||||
from uds.core import messaging
|
||||
|
||||
from ...utils.test import UDSTransactionTestCase
|
||||
from ...utils.test import UDSTestCase
|
||||
from ...fixtures.images import createImage
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from uds import models
|
||||
|
||||
|
||||
class ModelImageTest(UDSTransactionTestCase):
|
||||
class ModelImageTest(UDSTestCase):
|
||||
pass
|
||||
|
@ -32,14 +32,14 @@
|
||||
import typing
|
||||
import logging
|
||||
|
||||
from ...utils.test import UDSTransactionTestCase
|
||||
from ...utils.test import UDSTestCase
|
||||
from ...fixtures import authenticators as authenticators_fixtures
|
||||
from uds import models
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
class ModelUUIDTest(UDSTransactionTestCase):
|
||||
class ModelUUIDTest(UDSTestCase):
|
||||
auth: 'models.Authenticator'
|
||||
user: 'models.User'
|
||||
group: 'models.Group'
|
||||
|
@ -32,7 +32,7 @@
|
||||
@author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
"""
|
||||
# We use commit/rollback
|
||||
from ...utils.test import UDSTransactionTestCase
|
||||
from ...utils.test import UDSTestCase
|
||||
from uds.core.ui.user_interface import (
|
||||
gui,
|
||||
UserInterface
|
||||
@ -42,7 +42,7 @@ import time
|
||||
from django.conf import settings
|
||||
|
||||
|
||||
class UserinterfaceTest(UDSTransactionTestCase):
|
||||
class UserinterfaceTest(UDSTestCase):
|
||||
|
||||
def test_userinterface(self):
|
||||
pass
|
||||
|
114
server/src/tests/core/util/test_auto_serializable.py
Normal file
114
server/src/tests/core/util/test_auto_serializable.py
Normal file
@ -0,0 +1,114 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
#
|
||||
# Copyright (c) 2012-2022 Virtual Cable S.L.U.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification,
|
||||
# are permitted provided that the following conditions are met:
|
||||
#
|
||||
# * Redistributions of source code must retain the above copyright notice,
|
||||
# this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
# * Neither the name of Virtual Cable S.L. nor the names of its contributors
|
||||
# may be used to endorse or promote products derived from this software
|
||||
# without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
"""
|
||||
@author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
"""
|
||||
# We use commit/rollback
|
||||
import typing
|
||||
|
||||
from ...utils.test import UDSTestCase
|
||||
from uds.core.util import auto_serializable
|
||||
|
||||
UNICODE_CHARS = 'ñöçóá^(pípè)'
|
||||
UNICODE_CHARS_2 = 'ñöçóá^(€íöè)'
|
||||
|
||||
|
||||
class AutoSerializableClass(auto_serializable.AutoSerializable):
|
||||
int_field = auto_serializable.IntField()
|
||||
str_field = auto_serializable.StringField()
|
||||
float_field = auto_serializable.FloatField()
|
||||
bool_field = auto_serializable.BoolField()
|
||||
password_field = auto_serializable.PasswordField()
|
||||
list_field = auto_serializable.ListField()
|
||||
dict_field = auto_serializable.DictField()
|
||||
|
||||
|
||||
class AutoSerializableCompressedClass(auto_serializable.AutoSerializableCompressed):
|
||||
int_field = auto_serializable.IntField()
|
||||
str_field = auto_serializable.StringField()
|
||||
float_field = auto_serializable.FloatField()
|
||||
bool_field = auto_serializable.BoolField()
|
||||
password_field = auto_serializable.PasswordField()
|
||||
list_field = auto_serializable.ListField()
|
||||
dict_field = auto_serializable.DictField()
|
||||
|
||||
|
||||
class AutoSerializableEncryptedClass(auto_serializable.AutoSerializableEncrypted):
|
||||
int_field = auto_serializable.IntField()
|
||||
str_field = auto_serializable.StringField()
|
||||
float_field = auto_serializable.FloatField()
|
||||
bool_field = auto_serializable.BoolField()
|
||||
password_field = auto_serializable.PasswordField()
|
||||
list_field = auto_serializable.ListField()
|
||||
dict_field = auto_serializable.DictField()
|
||||
|
||||
|
||||
class AutoSerializable(UDSTestCase):
|
||||
def basic_check(
|
||||
self,
|
||||
cls1: typing.Type[
|
||||
'AutoSerializableClass|AutoSerializableCompressedClass|AutoSerializableEncryptedClass'
|
||||
],
|
||||
cls2: typing.Type[
|
||||
'AutoSerializableClass|AutoSerializableCompressedClass|AutoSerializableEncryptedClass'
|
||||
],
|
||||
) -> None:
|
||||
# Test basic serialization
|
||||
a = cls1()
|
||||
a.int_field = 1
|
||||
a.str_field = UNICODE_CHARS
|
||||
a.float_field = 3.0
|
||||
a.bool_field = True
|
||||
a.password_field = UNICODE_CHARS_2 # nosec: test password
|
||||
a.list_field = [1, 2, 3]
|
||||
a.dict_field = {'a': 1, 'b': 2, 'c': 3}
|
||||
|
||||
data = a.marshal()
|
||||
|
||||
b = cls2()
|
||||
b.unmarshal(data)
|
||||
|
||||
self.assertEqual(a, b)
|
||||
|
||||
def test_auto_serializable_base(self):
|
||||
self.basic_check(AutoSerializableClass, AutoSerializableClass)
|
||||
|
||||
def test_auto_serializable_compressed(self):
|
||||
self.basic_check(AutoSerializableCompressedClass, AutoSerializableCompressedClass)
|
||||
|
||||
def test_auto_serializable_encrypted(self):
|
||||
self.basic_check(AutoSerializableEncryptedClass, AutoSerializableEncryptedClass)
|
||||
|
||||
def test_auto_serializable_base_compressed(self):
|
||||
self.basic_check(AutoSerializableClass, AutoSerializableCompressedClass)
|
||||
|
||||
def test_auto_serializable_base_encrypted(self):
|
||||
self.basic_check(AutoSerializableClass, AutoSerializableEncryptedClass)
|
@ -32,40 +32,59 @@
|
||||
@author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
"""
|
||||
# We use commit/rollback
|
||||
from ...utils.test import UDSTransactionTestCase
|
||||
from ...utils.test import UDSTestCase
|
||||
from uds.core.util.cache import Cache
|
||||
import time
|
||||
import string
|
||||
|
||||
# Some random chars, that include unicode non-ascci chars
|
||||
UNICODE_CHARS = 'ñöçóá^(pípè)'
|
||||
UNICODE_CHARS_2 = 'ñöçóá^(€íöè)'
|
||||
VALUE_1 = [u'únîcödè€', b'string', {'a': 1, 'b': 2.0}]
|
||||
|
||||
|
||||
class CacheTest(UDSTransactionTestCase):
|
||||
|
||||
class CacheTest(UDSTestCase):
|
||||
def test_cache(self):
|
||||
cache = Cache(UNICODE_CHARS)
|
||||
|
||||
# Get default value, with unicode
|
||||
self.assertEqual(cache.get(UNICODE_CHARS, UNICODE_CHARS_2), UNICODE_CHARS_2, 'Unicode unexisting key returns default unicode')
|
||||
self.assertEqual(
|
||||
cache.get(UNICODE_CHARS, UNICODE_CHARS_2),
|
||||
UNICODE_CHARS_2,
|
||||
'Unicode unexisting key returns default unicode',
|
||||
)
|
||||
|
||||
# Remove unexisting key, not a problem
|
||||
self.assertEqual(cache.remove('non-existing-1'), False, 'Removing unexisting key')
|
||||
self.assertEqual(
|
||||
cache.remove('non-existing-1'), False, 'Removing unexisting key'
|
||||
)
|
||||
|
||||
# Add new key (non existing) with default duration (60 seconds probable)
|
||||
cache.put(UNICODE_CHARS_2, VALUE_1)
|
||||
|
||||
# checks it
|
||||
self.assertEqual(cache.get(UNICODE_CHARS_2), VALUE_1, 'Put a key and recover it')
|
||||
self.assertEqual(
|
||||
cache.get(UNICODE_CHARS_2), VALUE_1, 'Put a key and recover it'
|
||||
)
|
||||
|
||||
# Add new "str" key, with 1 second life, wait 2 seconds and recover
|
||||
cache.put(b'key', VALUE_1, 1)
|
||||
time.sleep(1.1)
|
||||
self.assertEqual(cache.get(b'key'), None, 'Put an "str" key and recover it once it has expired')
|
||||
self.assertEqual(
|
||||
cache.get(b'key'),
|
||||
None,
|
||||
'Put an "str" key and recover it once it has expired',
|
||||
)
|
||||
|
||||
# Refresh previous key and will be again available
|
||||
cache.refresh(b'key')
|
||||
self.assertEqual(cache.get(b'key'), VALUE_1, 'Refreshed cache key is {} and should be {}'.format(cache.get(b'key'), VALUE_1))
|
||||
self.assertEqual(
|
||||
cache.get(b'key'),
|
||||
VALUE_1,
|
||||
'Refreshed cache key is {} and should be {}'.format(
|
||||
cache.get(b'key'), VALUE_1
|
||||
),
|
||||
)
|
||||
|
||||
# Checks cache clean
|
||||
cache.put('key', VALUE_1)
|
||||
@ -82,5 +101,8 @@ class CacheTest(UDSTransactionTestCase):
|
||||
time.sleep(0.1)
|
||||
Cache.cleanUp()
|
||||
cache.refresh('key')
|
||||
self.assertEqual(cache.get('key'), None, 'Put a key and recover it once it has expired and has been cleaned')
|
||||
|
||||
self.assertEqual(
|
||||
cache.get('key'),
|
||||
None,
|
||||
'Put a key and recover it once it has expired and has been cleaned',
|
||||
)
|
||||
|
@ -30,14 +30,14 @@
|
||||
@author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
"""
|
||||
|
||||
from ...utils.test import UDSTransactionTestCase
|
||||
from ...utils.test import UDSTestCase
|
||||
from ...fixtures.calendars import createCalendars
|
||||
from uds.core.util import calendar
|
||||
from uds.models import Calendar
|
||||
import datetime
|
||||
|
||||
|
||||
class CalendarTest(UDSTransactionTestCase):
|
||||
class CalendarTest(UDSTestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
createCalendars()
|
||||
|
@ -33,14 +33,14 @@ import typing
|
||||
import logging
|
||||
|
||||
|
||||
from ...utils.test import UDSTransactionTestCase
|
||||
from ...utils.test import UDSTestCase
|
||||
|
||||
from uds.core.util import net
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NetTest(UDSTransactionTestCase):
|
||||
class NetTest(UDSTestCase):
|
||||
|
||||
def testNetworkFromString(self):
|
||||
for n in (
|
||||
|
@ -37,7 +37,7 @@ from uds.core.util import permissions
|
||||
from uds.core.util import ot
|
||||
from uds import models
|
||||
|
||||
from ...utils.test import UDSTransactionTestCase
|
||||
from ...utils.test import UDSTestCase
|
||||
from ...fixtures import (
|
||||
authenticators as authenticators_fixtures,
|
||||
services as services_fixtures,
|
||||
@ -45,14 +45,18 @@ from ...fixtures import (
|
||||
)
|
||||
|
||||
|
||||
class PermissionsTest(UDSTransactionTestCase):
|
||||
class PermissionsTest(UDSTestCase):
|
||||
authenticator: models.Authenticator
|
||||
groups: typing.List[models.Group]
|
||||
users: typing.List[models.User]
|
||||
admins: typing.List[models.User]
|
||||
staffs: typing.List[models.User]
|
||||
userService: models.UserService
|
||||
servicePool: models.ServicePool
|
||||
service: models.Service
|
||||
provider: models.Provider
|
||||
network: models.Network
|
||||
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.authenticator = authenticators_fixtures.createAuthenticator()
|
||||
@ -72,6 +76,10 @@ class PermissionsTest(UDSTransactionTestCase):
|
||||
list(self.users[0].groups.all()),
|
||||
'managed',
|
||||
)
|
||||
self.servicePool = self.userService.deployed_service
|
||||
self.service = self.servicePool.service
|
||||
self.provider = self.service.provider
|
||||
|
||||
self.network = network_fixtures.createNetwork()
|
||||
|
||||
def doTestUserPermissions(self, obj, user: models.User):
|
||||
|
@ -37,7 +37,7 @@ import typing
|
||||
from ...fixtures.stats_counters import create_stats_counters
|
||||
|
||||
# We use commit/rollback
|
||||
from ...utils.test import UDSTransactionTestCase
|
||||
from ...utils.test import UDSTestCase
|
||||
|
||||
from uds.core.util.stats import counters
|
||||
from uds import models
|
||||
@ -47,7 +47,7 @@ END_DATE_DAY = datetime.datetime(2020, 1, 2, 0, 0, 0)
|
||||
END_DATE_MONTH = datetime.datetime(2020, 2, 1, 0, 0, 0)
|
||||
END_DATE_YEAR = datetime.datetime(2021, 1, 1, 0, 0, 0)
|
||||
|
||||
class StatsCountersTest(UDSTransactionTestCase):
|
||||
class StatsCountersTest(UDSTestCase):
|
||||
def setUp(self) -> None:
|
||||
return super().setUp()
|
||||
|
||||
|
@ -30,7 +30,7 @@
|
||||
@author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
"""
|
||||
|
||||
from ...utils.test import UDSTransactionTestCase
|
||||
from ...utils.test import UDSTestCase
|
||||
from uds.core.util.storage import Storage
|
||||
|
||||
UNICODE_CHARS = 'ñöçóá^(pípè)'
|
||||
@ -38,7 +38,7 @@ UNICODE_CHARS_2 = 'ñöçóá^(€íöè)'
|
||||
VALUE_1 = ['unicode', b'string', {'a': 1, 'b': 2.0}]
|
||||
|
||||
|
||||
class StorageTest(UDSTransactionTestCase):
|
||||
class StorageTest(UDSTestCase):
|
||||
def test_storage(self):
|
||||
storage = Storage(UNICODE_CHARS)
|
||||
|
||||
|
@ -31,8 +31,6 @@
|
||||
@author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
"""
|
||||
import time
|
||||
import sys
|
||||
import threading
|
||||
|
||||
from ...utils.test import UDSTestCase
|
||||
from django.conf import settings
|
||||
|
@ -38,11 +38,11 @@ from uds.core.util import config
|
||||
from uds.core.util.state import State
|
||||
from uds.core.workers.assigned_unused import AssignedAndUnused
|
||||
|
||||
from ...utils.test import UDSTransactionTestCase
|
||||
from ...utils.test import UDSTestCase
|
||||
from ...fixtures import services as fixtures_services
|
||||
|
||||
|
||||
class AssignedAndUnusedTest(UDSTransactionTestCase):
|
||||
class AssignedAndUnusedTest(UDSTestCase):
|
||||
userServices: typing.List[models.UserService]
|
||||
|
||||
def setUp(self):
|
||||
|
@ -38,14 +38,14 @@ from uds.core.util import config
|
||||
from uds.core.util.state import State
|
||||
from uds.core.workers.hanged_userservice_cleaner import HangedCleaner
|
||||
|
||||
from ...utils.test import UDSTransactionTestCase
|
||||
from ...utils.test import UDSTestCase
|
||||
from ...fixtures import services as fixtures_services
|
||||
|
||||
MAX_INIT = 300
|
||||
TEST_SERVICES = 5 * 5 # Ensure multiple of 5 for testing
|
||||
|
||||
|
||||
class HangedCleanerTest(UDSTransactionTestCase):
|
||||
class HangedCleanerTest(UDSTestCase):
|
||||
userServices: typing.List[models.UserService]
|
||||
|
||||
def setUp(self):
|
||||
|
@ -40,7 +40,7 @@ from uds.core.environment import Environment
|
||||
from uds.services.Test.provider import TestProvider
|
||||
from uds.services.Test.service import TestServiceCache, TestServiceNoCache
|
||||
|
||||
from ...utils.test import UDSTransactionTestCase
|
||||
from ...utils.test import UDSTestCase
|
||||
from ...fixtures import services as services_fixtures
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
@ -48,7 +48,7 @@ if typing.TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ServiceCacheUpdaterTest(UDSTransactionTestCase):
|
||||
class ServiceCacheUpdaterTest(UDSTestCase):
|
||||
servicePool: 'models.ServicePool'
|
||||
|
||||
def setUp(self) -> None:
|
||||
|
@ -36,7 +36,7 @@ import datetime
|
||||
from uds import models
|
||||
from uds.core.util.stats import counters
|
||||
|
||||
from ...utils.test import UDSTransactionTestCase
|
||||
from ...utils.test import UDSTestCase
|
||||
from ...fixtures import stats_counters as fixtures_stats_counters
|
||||
|
||||
from uds.core.workers import stats_collector
|
||||
@ -65,7 +65,7 @@ class StatsFunction:
|
||||
return self.counter * 100
|
||||
|
||||
|
||||
class StatsAcummulatorTest(UDSTransactionTestCase):
|
||||
class StatsAcummulatorTest(UDSTestCase):
|
||||
def setUp(self):
|
||||
# In fact, real data will not be assigned to Userservices, but it's ok for testing
|
||||
for pool_id in range(NUMBER_OF_POOLS):
|
||||
|
@ -33,7 +33,7 @@ import typing
|
||||
import logging
|
||||
import datetime
|
||||
|
||||
from ...utils.test import UDSTransactionTestCase
|
||||
from ...utils.test import UDSTestCase
|
||||
from ...fixtures import services as services_fixtures
|
||||
|
||||
from uds.models import UserService
|
||||
@ -47,7 +47,7 @@ if typing.TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StuckCleanerTest(UDSTransactionTestCase):
|
||||
class StuckCleanerTest(UDSTestCase):
|
||||
userServices: typing.List['models.UserService']
|
||||
|
||||
def setUp(self) -> None:
|
||||
|
@ -36,7 +36,7 @@ import logging
|
||||
from uds.core import messaging
|
||||
|
||||
from ..fixtures import notifiers as notifiers_fixtures
|
||||
from ..utils.test import UDSTransactionTestCase
|
||||
from ..utils.test import UDSTestCase
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from uds import models
|
||||
@ -45,7 +45,7 @@ if typing.TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmailNotifierTest(UDSTransactionTestCase):
|
||||
class EmailNotifierTest(UDSTestCase):
|
||||
"""
|
||||
Test Email Notifier
|
||||
"""
|
||||
|
@ -45,7 +45,7 @@ from uds.REST.handlers import AUTH_TOKEN_HEADER
|
||||
NUMBER_OF_ITEMS_TO_CREATE = 4
|
||||
|
||||
|
||||
class RESTTestCase(test.UDSTransactionTestCase):
|
||||
class RESTTestCase(test.UDSTestCase):
|
||||
# Authenticators related
|
||||
auth: models.Authenticator
|
||||
groups: typing.List[models.Group]
|
||||
|
@ -113,7 +113,7 @@ class UDSTestCase(TestCase):
|
||||
super().setUpClass()
|
||||
setupClass(cls)
|
||||
|
||||
class UDSTransactionTestCase(TransactionTestCase):
|
||||
class UDSTestCase(TransactionTestCase):
|
||||
client_class: typing.Type = UDSClient
|
||||
|
||||
client: UDSClient
|
||||
@ -123,6 +123,6 @@ class UDSTransactionTestCase(TransactionTestCase):
|
||||
super().setUpClass()
|
||||
setupClass(cls)
|
||||
|
||||
def setupClass(cls: typing.Union[typing.Type[UDSTestCase], typing.Type[UDSTransactionTestCase]]) -> None:
|
||||
def setupClass(cls: typing.Union[typing.Type[UDSTestCase], typing.Type[UDSTestCase]]) -> None:
|
||||
# Nothing right now
|
||||
pass
|
||||
|
@ -42,7 +42,7 @@ from uds.REST.handlers import AUTH_TOKEN_HEADER
|
||||
NUMBER_OF_ITEMS_TO_CREATE = 4
|
||||
|
||||
|
||||
class WEBTestCase(test.UDSTransactionTestCase):
|
||||
class WEBTestCase(test.UDSTestCase):
|
||||
# Authenticators related
|
||||
auth: models.Authenticator
|
||||
groups: typing.List[models.Group]
|
||||
|
@ -1,7 +1,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
#
|
||||
# Copyright (c) 2012-2019 Virtual Cable S.L.
|
||||
# Copyright (c) 2012-2022 Virtual Cable S.L.U.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification,
|
||||
@ -12,7 +12,7 @@
|
||||
# * Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
# * Neither the name of Virtual Cable S.L. nor the names of its contributors
|
||||
# * Neither the name of Virtual Cable S.L.U. nor the names of its contributors
|
||||
# may be used to endorse or promote products derived from this software
|
||||
# without specific prior written permission.
|
||||
#
|
||||
@ -28,5 +28,5 @@
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
"""
|
||||
@author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
"""
|
||||
|
536
server/src/uds/core/util/auto_serializable.py
Normal file
536
server/src/uds/core/util/auto_serializable.py
Normal file
@ -0,0 +1,536 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright (c) 2022 Virtual Cable S.L.U.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification,
|
||||
# are permitted provided that the following conditions are met:
|
||||
#
|
||||
# * Redistributions of source code must retain the above copyright notice,
|
||||
# this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
# * Neither the name of Virtual Cable S.L. nor the names of its contributors
|
||||
# may be used to endorse or promote products derived from this software
|
||||
# without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
"""
|
||||
Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
|
||||
Implements de AutoSerializable class, that allows to serialize/deserialize in a simple way
|
||||
This class in incompatible with UserInterface derived classes, as it metaclass is not compatible
|
||||
|
||||
To use it, simple place it as first parent class, and follow by the rest of the classes to inherit from
|
||||
|
||||
Example:
|
||||
from uds.core.util import AutoSerializable
|
||||
from uds.core import services
|
||||
|
||||
class UserDeploymentService(AutoSerializable, services.UserDeployment):
|
||||
...
|
||||
|
||||
|
||||
"""
|
||||
|
||||
import itertools
|
||||
import typing
|
||||
import logging
|
||||
import json
|
||||
import zlib
|
||||
import base64
|
||||
import hashlib
|
||||
import struct
|
||||
|
||||
# Import the cryptography library
|
||||
from cryptography import fernet
|
||||
|
||||
from django.conf import settings
|
||||
|
||||
|
||||
class _Unassigned:
|
||||
pass
|
||||
|
||||
|
||||
# means field has no default value
|
||||
UNASSIGNED = _Unassigned()
|
||||
|
||||
T = typing.TypeVar('T')
|
||||
DefaultValueType = typing.Union[T, typing.Callable[[], T], _Unassigned]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Constants
|
||||
|
||||
# Headers for the serialized data
|
||||
HEADER_BASE: typing.Final[bytes] = b'MGB1'
|
||||
HEADER_COMPRESSED: typing.Final[bytes] = b'MGZ1'
|
||||
HEADER_ENCRYPTED: typing.Final[bytes] = b'MGE1'
|
||||
|
||||
# Size of crc32 checksum
|
||||
CRC_SIZE: typing.Final[int] = 4
|
||||
|
||||
# Packing data struct
|
||||
pack_struct = struct.Struct('<HHI')
|
||||
|
||||
# Helper functions
|
||||
def fernet_key(crypt_key: bytes) -> str:
|
||||
"""Generate key from password and seed
|
||||
|
||||
Args:
|
||||
seed: Seed to use (normally header)
|
||||
|
||||
Note: if password is not set, this will raise an exception
|
||||
"""
|
||||
# Generate an URL-Safe base64 encoded 32 bytes key for Fernet
|
||||
# First, with seed + password, generate a 32 bytes key based on seed + password
|
||||
return base64.b64encode(hashlib.sha256(crypt_key).digest()).decode()
|
||||
|
||||
|
||||
class _SerializableField(typing.Generic[T]):
|
||||
name: str
|
||||
type: typing.Type[T]
|
||||
default: DefaultValueType
|
||||
|
||||
def __init__(self, type: typing.Type[T], default: DefaultValueType = UNASSIGNED):
|
||||
self.type = type
|
||||
self.default = default
|
||||
|
||||
def _default(self) -> T:
|
||||
if isinstance(self.default, _Unassigned):
|
||||
return self.type()
|
||||
elif callable(self.default):
|
||||
return self.default()
|
||||
else:
|
||||
return self.default
|
||||
|
||||
def __get__(
|
||||
self,
|
||||
instance: 'AutoSerializable',
|
||||
objtype: typing.Optional[typing.Type['AutoSerializable']] = None,
|
||||
) -> T:
|
||||
"""Get field value
|
||||
|
||||
Arguments:
|
||||
instance {SerializableFields} -- Instance of class with field
|
||||
|
||||
"""
|
||||
if hasattr(instance, '_fields'):
|
||||
return getattr(instance, '_fields').get(self.name, self._default())
|
||||
if self.default is None:
|
||||
raise AttributeError(f"Field {self.name} is not set")
|
||||
return self._default()
|
||||
|
||||
def __set__(self, instance: 'AutoSerializable', value: T) -> None:
|
||||
if not isinstance(value, self.type):
|
||||
raise TypeError(
|
||||
f"Field {self.name} cannot be set to {value} (type {self.type.__name__})"
|
||||
)
|
||||
if not hasattr(instance, '_fields'):
|
||||
setattr(instance, '_fields', {})
|
||||
getattr(instance, '_fields')[self.name] = value
|
||||
|
||||
def marshal(self, instance: 'AutoSerializable') -> bytes:
|
||||
"""Basic marshalling of field
|
||||
|
||||
Args:
|
||||
instance: Instance of class with field
|
||||
|
||||
Returns:
|
||||
Marshalled field
|
||||
|
||||
Note:
|
||||
Only str, int, and float are supported in this base class.
|
||||
"""
|
||||
if self.type in (str, int, float):
|
||||
return str(self.__get__(instance)).encode()
|
||||
raise TypeError(f"Field {self.name} cannot be marshalled (type {self.type})")
|
||||
|
||||
def unmarshal(self, instance: 'AutoSerializable', data: bytes) -> None:
|
||||
"""Basic unmarshalling of field
|
||||
|
||||
Args:
|
||||
instance: Instance of class with field
|
||||
data: Marshalled field string
|
||||
|
||||
Returns:
|
||||
None. The data is loaded into the field.
|
||||
|
||||
Note:
|
||||
Only str, int, and float are supported in this base class.
|
||||
"""
|
||||
if self.type in (str, int, float):
|
||||
tp: typing.Type = self.type
|
||||
self.__set__(instance, tp(data.decode()))
|
||||
return
|
||||
raise TypeError(f"Field {self.name} cannot be unmarshalled (type {self.type})")
|
||||
|
||||
|
||||
# Integer field
|
||||
class IntField(_SerializableField[int]):
|
||||
def __init__(self, default: int = 0):
|
||||
super().__init__(int, default)
|
||||
|
||||
|
||||
class StringField(_SerializableField[str]):
|
||||
def __init__(self, default: str = ''):
|
||||
super().__init__(str, default)
|
||||
|
||||
|
||||
class FloatField(_SerializableField[float]):
|
||||
def __init__(self, default: float = 0.0):
|
||||
super().__init__(float, default)
|
||||
|
||||
class BoolField(_SerializableField[bool]):
|
||||
def __init__(self, default: bool = False):
|
||||
super().__init__(bool, default)
|
||||
|
||||
def marshal(self, instance: 'AutoSerializable') -> bytes:
|
||||
return b'1' if self.__get__(instance) else b'0'
|
||||
|
||||
def unmarshal(self, instance: 'AutoSerializable', data: bytes) -> None:
|
||||
self.__set__(instance, data == b'1')
|
||||
|
||||
class ListField(_SerializableField[typing.List]):
|
||||
"""List field
|
||||
|
||||
Note:
|
||||
All elements in the list must be serializable.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
default: typing.Union[
|
||||
typing.List, typing.Callable[[], typing.List]
|
||||
] = lambda: [],
|
||||
):
|
||||
super().__init__(list, default)
|
||||
|
||||
def marshal(self, instance: 'AutoSerializable') -> bytes:
|
||||
return json.dumps(self.__get__(instance)).encode()
|
||||
|
||||
def unmarshal(self, instance: 'AutoSerializable', data: bytes) -> None:
|
||||
self.__set__(instance, json.loads(data))
|
||||
|
||||
|
||||
class DictField(_SerializableField[typing.Dict]):
|
||||
"""Dict field
|
||||
|
||||
Note:
|
||||
All elements in the dict must be serializable.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
default: typing.Union[
|
||||
typing.Dict, typing.Callable[[], typing.Dict]
|
||||
] = lambda: {},
|
||||
):
|
||||
super().__init__(dict, default)
|
||||
|
||||
def marshal(self, instance: 'AutoSerializable') -> bytes:
|
||||
return json.dumps(self.__get__(instance)).encode()
|
||||
|
||||
def unmarshal(self, instance: 'AutoSerializable', data: bytes) -> None:
|
||||
self.__set__(instance, json.loads(data))
|
||||
|
||||
|
||||
class PasswordField(StringField):
|
||||
"""Password field
|
||||
|
||||
Note:
|
||||
The password is stored as a compressed string.
|
||||
"""
|
||||
|
||||
_crypt_key: str = ''
|
||||
|
||||
def __init__(self, default: str = '', crypt_key: str = ''):
|
||||
super().__init__(default)
|
||||
self._crypt_key = crypt_key
|
||||
|
||||
def _encrypt(self, value: str) -> bytes:
|
||||
"""Encrypt a password
|
||||
|
||||
Args:
|
||||
value: Password to encrypt
|
||||
|
||||
Returns:
|
||||
Encrypted password
|
||||
"""
|
||||
if self._crypt_key:
|
||||
# Generate a Fernet key from the password
|
||||
f = fernet.Fernet(fernet_key(self._crypt_key.encode()))
|
||||
return HEADER_ENCRYPTED + f.encrypt(value.encode())
|
||||
|
||||
logger.warning("Password encryption is not enabled")
|
||||
return zlib.compress(value.encode())
|
||||
|
||||
def _decrypt(self, value: bytes) -> bytes:
|
||||
"""Decrypt a password
|
||||
|
||||
Args:
|
||||
value: Password to decrypt
|
||||
|
||||
Returns:
|
||||
Decrypted password
|
||||
"""
|
||||
if self._crypt_key and value[: len(HEADER_ENCRYPTED)] == HEADER_ENCRYPTED:
|
||||
try:
|
||||
f = fernet.Fernet(fernet_key(self._crypt_key.encode()))
|
||||
return f.decrypt(value[len(HEADER_ENCRYPTED) :])
|
||||
except Exception: # nosec: Defaults to zlib compression out of the exception
|
||||
pass # returns the unencrypted password
|
||||
|
||||
return zlib.decompress(value)
|
||||
|
||||
def marshal(self, instance: 'AutoSerializable') -> bytes:
|
||||
return base64.b64encode(self._encrypt(self.__get__(instance)))
|
||||
|
||||
def unmarshal(self, instance: 'AutoSerializable', data: bytes) -> None:
|
||||
self.__set__(instance, self._decrypt(base64.b64decode(data)).decode())
|
||||
|
||||
|
||||
# ************************
|
||||
# * Serializable classes *
|
||||
# ************************
|
||||
|
||||
|
||||
class _FieldNameSetter(type):
|
||||
"""Simply adds the name of the field in the class to the field object"""
|
||||
|
||||
def __new__(cls, name, bases, attrs):
|
||||
fields = dict()
|
||||
for k, v in attrs.items():
|
||||
if isinstance(v, _SerializableField):
|
||||
v.name = k
|
||||
|
||||
return super().__new__(cls, name, bases, attrs)
|
||||
|
||||
|
||||
class AutoSerializable(metaclass=_FieldNameSetter):
|
||||
"""This class allows the automatic serialization of fields in a class.
|
||||
|
||||
Example:
|
||||
>>> class Test(SerializableFields):
|
||||
... a = IntField()
|
||||
... b = StrField()
|
||||
... c = FloatField()
|
||||
... d = ListField(defalut=lambda: [1, 2, 3])
|
||||
"""
|
||||
|
||||
_fields: typing.Dict[str, typing.Any]
|
||||
|
||||
def process_data(self, header: bytes, data: bytes) -> bytes:
|
||||
"""Process data before marshalling
|
||||
|
||||
Args:
|
||||
data: Data to process
|
||||
|
||||
Returns:
|
||||
Processed data
|
||||
|
||||
Note:
|
||||
We provide header as a way of some constant value to be used in the
|
||||
processing. This is useful for encryption. (as salt, for example)
|
||||
"""
|
||||
return bytes(a ^ b for a, b in zip(data, itertools.cycle(header)))
|
||||
|
||||
return data
|
||||
|
||||
def unprocess_data(self, header: bytes, data: bytes) -> bytes:
|
||||
"""Process data after unmarshalling
|
||||
|
||||
Args:
|
||||
data: Data to process
|
||||
|
||||
Returns:
|
||||
Processed data
|
||||
|
||||
Note:
|
||||
We provide header as a way of some constant value to be used in the
|
||||
processing. This is useful for encryption. (as salt, for example)
|
||||
"""
|
||||
return bytes(a ^ b for a, b in zip(data, itertools.cycle(header)))
|
||||
|
||||
@typing.final
|
||||
def marshal(self) -> bytes:
|
||||
# Iterate over own members and extract fields
|
||||
fields = {}
|
||||
for k, v in self.__class__.__dict__.items():
|
||||
if isinstance(v, _SerializableField):
|
||||
fields[v.name] = (str(v.__class__.__name__), v.marshal(self))
|
||||
|
||||
# Serialized data is:
|
||||
# 2 bytes -> name length
|
||||
# 2 bytes -> type name length
|
||||
# 4 bytes -> data length
|
||||
# n bytes -> name
|
||||
# n bytes -> type name
|
||||
# n bytes -> data
|
||||
data = b''.join(
|
||||
pack_struct.pack(len(name.encode()), len(type_name.encode()), len(value))
|
||||
+ name.encode()
|
||||
+ type_name.encode()
|
||||
+ value
|
||||
for name, (type_name, value) in fields.items()
|
||||
)
|
||||
|
||||
# Calculate checksum
|
||||
checksum = zlib.crc32(data)
|
||||
# Compose header, that is V1_HEADER + checksum (4 bytes, big endian)
|
||||
header = HEADER_BASE + checksum.to_bytes(CRC_SIZE, 'big')
|
||||
# Return data processed with header
|
||||
return header + self.process_data(header, data)
|
||||
|
||||
# final method, do not override
|
||||
@typing.final
|
||||
def unmarshal(self, data: bytes) -> None:
|
||||
# Check header
|
||||
if data[: len(HEADER_BASE)] != HEADER_BASE:
|
||||
raise ValueError('Invalid header')
|
||||
|
||||
header = data[: len(HEADER_BASE) + CRC_SIZE]
|
||||
# Extract checksum
|
||||
checksum = int.from_bytes(
|
||||
header[len(HEADER_BASE) : len(HEADER_BASE) + 4], 'big'
|
||||
)
|
||||
# Unprocess data
|
||||
data = self.unprocess_data(header, data[len(header) :])
|
||||
|
||||
# Check checksum
|
||||
if zlib.crc32(data) != checksum:
|
||||
raise ValueError('Invalid checksum')
|
||||
|
||||
# Iterate over fields
|
||||
fields = {}
|
||||
while data:
|
||||
# Extract name length, type name length and data length
|
||||
name_len, type_name_len, data_len = pack_struct.unpack(data[:8])
|
||||
# Extract name, type name and data
|
||||
name, type_name, value = (
|
||||
data[8 : 8 + name_len].decode(),
|
||||
data[8 + name_len : 8 + name_len + type_name_len].decode(),
|
||||
data[
|
||||
8
|
||||
+ name_len
|
||||
+ type_name_len : 8
|
||||
+ name_len
|
||||
+ type_name_len
|
||||
+ data_len
|
||||
],
|
||||
)
|
||||
# Add to fields
|
||||
fields[name] = (type_name, value)
|
||||
# Remove from data
|
||||
data = data[8 + name_len + type_name_len + data_len :]
|
||||
|
||||
for k, v in self.__class__.__dict__.items():
|
||||
if isinstance(v, _SerializableField):
|
||||
if v.name in fields and fields[v.name][0] == str(v.__class__.__name__):
|
||||
v.unmarshal(self, fields[v.name][1])
|
||||
else:
|
||||
if not v.name in fields:
|
||||
logger.warning(f"Field {v.name} not found in unmarshalled data")
|
||||
else:
|
||||
logger.warning(
|
||||
f"Field {v.name} has wrong type in unmarshalled data (should be {fields[v.name][0]} and is {v.__class__.name}"
|
||||
)
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
"""
|
||||
Basic equality check, checks if all _SerializableFields are equal
|
||||
"""
|
||||
if not isinstance(other, AutoSerializable):
|
||||
return False
|
||||
|
||||
for k, v in self.__class__.__dict__.items():
|
||||
if isinstance(v, _SerializableField):
|
||||
if getattr(self, k) != getattr(other, k):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def __str__(self) -> str:
|
||||
return ', '.join(
|
||||
[
|
||||
f"{k}={v.type.__name__}({v.__get__(self)})"
|
||||
for k, v in self.__class__.__dict__.items()
|
||||
if isinstance(v, _SerializableField)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class AutoSerializableCompressed(AutoSerializable):
|
||||
"""This class allows the automatic serialization of fields in a class compressed with zlib."""
|
||||
|
||||
def process_data(self, header: bytes, data: bytes) -> bytes:
|
||||
return HEADER_COMPRESSED + zlib.compress(data)
|
||||
|
||||
def unprocess_data(self, header: bytes, data: bytes) -> bytes:
|
||||
# if decompress fails, return data as is
|
||||
try:
|
||||
# Check header
|
||||
if data[: len(HEADER_COMPRESSED)] != HEADER_COMPRESSED:
|
||||
raise Exception() # Returns data as is
|
||||
return zlib.decompress(data[len(HEADER_COMPRESSED) :])
|
||||
except Exception:
|
||||
return super().unprocess_data(header, data)
|
||||
|
||||
|
||||
class AutoSerializableEncrypted(AutoSerializable):
|
||||
"""This class allows the automatic serialization of fields in a class encrypted with AES."""
|
||||
|
||||
# Common key for all instances
|
||||
_crypt_key: typing.ClassVar[str] = settings.SECRET_KEY[:16]
|
||||
|
||||
def key(self, seed: bytes) -> str:
|
||||
"""Generate key from password and seed
|
||||
|
||||
Args:
|
||||
seed: Seed to use (normally header)
|
||||
|
||||
Note: if password is not set, this will raise an exception
|
||||
"""
|
||||
if not self._crypt_key:
|
||||
raise ValueError('Password not set')
|
||||
|
||||
return fernet_key(seed + (self._crypt_key.encode()))
|
||||
|
||||
def process_data(self, header: bytes, data: bytes) -> bytes:
|
||||
f = fernet.Fernet(self.key(header))
|
||||
return HEADER_ENCRYPTED + f.encrypt(data)
|
||||
|
||||
def unprocess_data(self, header: bytes, data: bytes) -> bytes:
|
||||
# if decrypt fails, return data as is
|
||||
try:
|
||||
# Check if data is encrypted
|
||||
if data[: len(HEADER_ENCRYPTED)] != HEADER_ENCRYPTED:
|
||||
return super().unprocess_data(header, data)
|
||||
f = fernet.Fernet(self.key(header))
|
||||
return f.decrypt(data[len(HEADER_ENCRYPTED) :])
|
||||
except fernet.InvalidToken:
|
||||
return super().unprocess_data(header, data)
|
||||
|
||||
@staticmethod
|
||||
def set_crypt_key(crypt_key: str) -> None:
|
||||
"""Set the password for all instances of this class.
|
||||
|
||||
Args:
|
||||
password: Password to set
|
||||
|
||||
Note:
|
||||
On Django, this should be set preferably in settings.py,
|
||||
so all instances of this class will use the same password from the start.
|
||||
"""
|
||||
AutoSerializableEncrypted._crypt_key = crypt_key[:16]
|
@ -1,4 +1,5 @@
|
||||
# Generated by Django 4.1 on 2022-10-01 06:23
|
||||
# Generated by Django 4.1.3 on 2022-11-12 21:03
|
||||
import logging
|
||||
|
||||
from django.db import migrations, models
|
||||
import django.db.models.deletion
|
||||
@ -7,6 +8,20 @@ import uds.models.notifications
|
||||
import uds.models.user_service_session
|
||||
import uds.models.util
|
||||
|
||||
logger = logging.getLogger('uds')
|
||||
|
||||
def remove_servicepool_with_null_service(apps, schema_editor):
|
||||
""" In fact, there should be no Service Pool with no service, but we have found some in the wild"""
|
||||
ServicePool = apps.get_model('uds', 'ServicePool')
|
||||
# Log in the django.db.backends logger removed services
|
||||
logger.info('Removing ServicePools with null service')
|
||||
for i in ServicePool.objects.filter(service=None):
|
||||
logger.info(' * Removing ServicePool %s - %s', i.uuid, i.name)
|
||||
i.delete()
|
||||
|
||||
def null_backwards(apps, schema_editor):
|
||||
# Remove null services backwards is not possible, we have deleted them
|
||||
pass
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
@ -15,6 +30,8 @@ class Migration(migrations.Migration):
|
||||
]
|
||||
|
||||
operations = [
|
||||
# First, we remove all DeployedServices with null service because we have fixed foreign key
|
||||
migrations.RunPython(remove_servicepool_with_null_service, null_backwards),
|
||||
migrations.CreateModel(
|
||||
name="Notification",
|
||||
fields=[
|
||||
@ -313,6 +330,15 @@ class Migration(migrations.Migration):
|
||||
default=uds.core.util.model.generateUuid, max_length=50, unique=True
|
||||
),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="servicepool",
|
||||
name="service",
|
||||
field=models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
related_name="deployedServices",
|
||||
to="uds.service",
|
||||
),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="servicepool",
|
||||
name="uuid",
|
||||
|
@ -84,10 +84,8 @@ class ServicePool(UUIDModel, TaggingMixin): # type: ignore
|
||||
name = models.CharField(max_length=128, default='')
|
||||
short_name = models.CharField(max_length=32, default='')
|
||||
comments = models.CharField(max_length=256, default='')
|
||||
service: 'models.ForeignKey[Service | None]' = models.ForeignKey(
|
||||
service: 'models.ForeignKey[Service]' = models.ForeignKey(
|
||||
Service,
|
||||
null=True,
|
||||
blank=True,
|
||||
related_name='deployedServices',
|
||||
on_delete=models.CASCADE,
|
||||
)
|
||||
|
@ -74,6 +74,9 @@ class UserService(UUIDModel): # pylint: disable=too-many-public-methods
|
||||
deployed_service: 'models.ForeignKey["ServicePool"]' = models.ForeignKey(
|
||||
ServicePool, on_delete=models.CASCADE, related_name='userServices'
|
||||
)
|
||||
# Althoug deployed_services has its publication, the user service is bound to a specific publication
|
||||
# so we need to store the publication id here (or the revision, but we need to store something)
|
||||
# storing the id simplifies the queries
|
||||
publication: 'models.ForeignKey[ServicePoolPublication | None]' = models.ForeignKey(
|
||||
ServicePoolPublication,
|
||||
on_delete=models.CASCADE,
|
||||
|
Loading…
Reference in New Issue
Block a user