diff --git a/server/src/uds/core/managers/publication.py b/server/src/uds/core/managers/publication.py index da3ed2900..bee11174d 100644 --- a/server/src/uds/core/managers/publication.py +++ b/server/src/uds/core/managers/publication.py @@ -319,10 +319,10 @@ class PublicationManager(metaclass=singleton.Singleton): return publication try: - pubInstance = publication.get_instance() - state = pubInstance.cancel() + pub_instance = publication.get_instance() + state = pub_instance.cancel() publication.set_state(State.CANCELING) - PublicationFinishChecker.state_updater(publication, pubInstance, state) + PublicationFinishChecker.state_updater(publication, pub_instance, state) return publication except Exception as e: raise PublishException(str(e)) from e diff --git a/server/src/uds/core/managers/userservice.py b/server/src/uds/core/managers/userservice.py index dfa6a46db..8ec25ee4e 100644 --- a/server/src/uds/core/managers/userservice.py +++ b/server/src/uds/core/managers/userservice.py @@ -106,12 +106,12 @@ class UserServiceManager(metaclass=singleton.Singleton): """ Checks if the maximum number of user services for this service has been reached """ - serviceInstance = service.get_instance() + service_instance = service.get_instance() # Early return, so no database count is needed - if serviceInstance.userservices_limit == consts.UNLIMITED: + if service_instance.userservices_limit == consts.UNLIMITED: return False - if self.get_existing_user_services(service) >= serviceInstance.userservices_limit: + if self.get_existing_user_services(service) >= service_instance.userservices_limit: return True return False diff --git a/server/src/uds/migrations/fixers/providers_v4/_migrator.py b/server/src/uds/migrations/fixers/providers_v4/_migrator.py index 375d2a50c..dca418900 100644 --- a/server/src/uds/migrations/fixers/providers_v4/_migrator.py +++ b/server/src/uds/migrations/fixers/providers_v4/_migrator.py @@ -50,7 +50,7 @@ if typing.TYPE_CHECKING: class TypeTestingClass(Serializable): server_group: typing.Any - def post_migrate(self) -> None: + def post_migrate(self, apps: typing.Any, record: typing.Any) -> None: pass def _get_environment(record: typing.Any) -> Environment: @@ -137,7 +137,7 @@ def migrate( logger.info('Setting server group %s on provider %s', registeredServerGroup.name, record.name) obj.server_group.value = registeredServerGroup.uuid # Now, execute post_migrate of obj - obj.post_migrate() + obj.post_migrate(apps, record) # Save record record.data = obj.serialize() record.save(update_fields=['data']) diff --git a/server/src/uds/migrations/fixers/providers_v4/physical_machine_multiple.py b/server/src/uds/migrations/fixers/providers_v4/physical_machine_multiple.py index 5058c026a..6c64ce786 100644 --- a/server/src/uds/migrations/fixers/providers_v4/physical_machine_multiple.py +++ b/server/src/uds/migrations/fixers/providers_v4/physical_machine_multiple.py @@ -29,21 +29,47 @@ """ Author: Adolfo Gómez, dkmaster at dkmon dot com """ +import base64 import datetime import logging import typing import pickle # nosec: pickle is used for legacy data transition -from uds.core import services +from uds.core import services, types from uds.core.ui import gui +from uds.core.util import auto_attributes, autoserializable from . import _migrator logger = logging.getLogger(__name__) +if typing.TYPE_CHECKING: + import uds.models + IP_SUBTYPE: typing.Final[str] = 'ip' +class OldIPSerialData(auto_attributes.AutoAttributes): + _ip: str + _reason: str + _state: str + + def __init__(self) -> None: + auto_attributes.AutoAttributes.__init__(self, ip=str, reason=str, state=str) + self._ip = '' + self._reason = '' + self._state = types.states.TaskState.FINISHED + + +class NewIpSerialData(autoserializable.AutoSerializable): + suggested_delay = 10 + + _ip = autoserializable.StringField(default='') + _mac = autoserializable.StringField(default='') + _vmid = autoserializable.StringField(default='') + _reason = autoserializable.StringField(default='') # If != '', this is the error message and state is ERROR + + class IPMachinesService(services.Service): type_type = 'IPMachinesService' @@ -96,14 +122,15 @@ class IPMachinesService(services.Service): self.useRandomIp = gui.as_bool(values[6].decode()) # Note that will be marshalled as new format, so we don't need to care about old format in code anymore :) - def post_migrate(self) -> None: + def post_migrate(self, apps: typing.Any, record: typing.Any) -> None: from uds.core.util import fields FOREVER: typing.Final[datetime.timedelta] = datetime.timedelta(days=365 * 20) now = datetime.datetime.now() server_group = fields.get_server_group_from_field(self.server_group) + for server in server_group.servers.all(): - + locked = self.storage.read_pickled(server.ip) # print(f'Locked: {locked} for {server.ip}') if not locked: @@ -120,13 +147,11 @@ class IPMachinesService(services.Service): if bool(locked): # print(f'Locking {server.ip} forever due to maxSessionForMachine=0') server.lock(FOREVER) # Almost forever - server.save(update_fields=['locked_until']) continue # Not locked, continue if not isinstance(locked, int): # print(f'Locking {server.ip} due to not being an int (very old data)') server.lock(FOREVER) - server.save(update_fields=['locked_until']) continue if not bool(locked) or locked < now.timestamp() - self.maxSessionForMachine.value * 3600: @@ -137,6 +162,54 @@ class IPMachinesService(services.Service): # print(f'Locking {server.ip} until {datetime.datetime.fromtimestamp(locked)}') server.lock(datetime.timedelta(seconds=locked - now.timestamp())) + Service: 'type[uds.models.Service]' = apps.get_model('uds', 'Service') + ServicePool: 'type[uds.models.ServicePool]' = apps.get_model('uds', 'ServicePool') + + assigned_servers: set[str] = set() + for servicepool in ServicePool.objects.filter(service=Service.objects.get(uuid=record.uuid)): + for userservice in servicepool.userServices.all(): + new_data = NewIpSerialData() + try: + auto_data = OldIPSerialData() + auto_data.unmarshal(base64.b64decode(userservice.data)) + # Fill own data from restored data + ip_mac = auto_data._ip.split('~')[0] + if ';' in ip_mac: + new_data._ip, new_data._mac = ip_mac.split(';', 2)[:2] + else: + new_data._ip = ip_mac + new_data._mac = '' + new_data._reason = auto_data._reason + state = auto_data._state + # Ensure error is set if _reason is set + if state == types.states.TaskState.ERROR and new_data._reason == '': + new_data._reason = 'Unknown error' + + # Reget vmid if needed + if not new_data._reason and userservice.state == types.states.State.USABLE: + new_data._vmid = '' + for server in server_group.servers.all(): + if server.ip == new_data._ip and server.uuid not in assigned_servers: + new_data._vmid = server.uuid + assigned_servers.add(server.uuid) + # Ensure locked, relock if needed + if not server.locked_until or server.locked_until < now: + if self.maxSessionForMachine.value <= 0: + server.lock(FOREVER) + else: + server.lock(datetime.timedelta(hours=self.maxSessionForMachine.value)) + break + if not new_data._vmid: + new_data._reason = f'Migrated machine not found for {new_data._ip}' + except Exception as e: # Invalid serialized record, record new format with error + new_data._ip = '' + new_data._mac = '' + new_data._vmid = '' + new_data._reason = f'Error migrating: {e}'[:320] + + userservice.data = new_data.serialize() + userservice.save(update_fields=['data']) + def migrate(apps: typing.Any, schema_editor: typing.Any) -> None: _migrator.migrate( diff --git a/server/src/uds/migrations/fixers/providers_v4/rds.py b/server/src/uds/migrations/fixers/providers_v4/rds.py index 0c2d4a707..33d01e871 100644 --- a/server/src/uds/migrations/fixers/providers_v4/rds.py +++ b/server/src/uds/migrations/fixers/providers_v4/rds.py @@ -70,7 +70,7 @@ class RDSProvider(services.ServiceProvider): # This value is the new server group that contains the "ipList" server_group = gui.ChoiceField(label='') - def post_migrate(self) -> None: + def post_migrate(self, apps: typing.Any, record: typing.Any) -> None: pass def migrate(apps: typing.Any, schema_editor: typing.Any) -> None: diff --git a/server/src/uds/services/PhysicalMachines/deployment_multi.py b/server/src/uds/services/PhysicalMachines/deployment_multi.py index 334cb8fb9..030b5c04e 100644 --- a/server/src/uds/services/PhysicalMachines/deployment_multi.py +++ b/server/src/uds/services/PhysicalMachines/deployment_multi.py @@ -150,25 +150,4 @@ class IPMachinesUserService(services.UserService, autoserializable.AutoSerializa def cancel(self) -> types.states.TaskState: return self.destroy() - def unmarshal(self, data: bytes) -> None: - if autoserializable.is_autoserializable_data(data): - return super().unmarshal(data) - - _auto_data = OldIPSerialData() - _auto_data.unmarshal(data) - - # Fill own data from restored data - ip_mac = _auto_data._ip.split('~')[0] - if ';' in ip_mac: - self._ip, self._mac = ip_mac.split(';', 2)[:2] - else: - self._ip = ip_mac - self._mac = '' - self._reason = _auto_data._reason - state = _auto_data._state - # Ensure error is set if _reason is set - if state == types.states.TaskState.ERROR and self._reason == '': - self._reason = 'Unknown error' - - # Flag for upgrade - self.mark_for_upgrade(True) + # Data is migrated on migration 0046, so no unmarshall is needed \ No newline at end of file diff --git a/server/tests/services/physical_machines/test_serialization_deployment_multi.py b/server/tests/services/physical_machines/test_serialization_deployment_multi.py new file mode 100644 index 000000000..352add123 --- /dev/null +++ b/server/tests/services/physical_machines/test_serialization_deployment_multi.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- + +# +# Copyright (c) 2022 Virtual Cable S.L.U. +# All rights reserved. +# + +""" +Author: Adolfo Gómez, dkmaster at dkmon dot com +""" +# We use commit/rollback + +from tests.utils.test import UDSTestCase +from uds.core.util import autoserializable +from uds.core.environment import Environment + + +from uds.services.PhysicalMachines import deployment_multi + + +class IPMultipleMachinesUserServiceSerializationTest(UDSTestCase): + def test_marshalling(self) -> None: + obj = deployment_multi.OldIPSerialData() + obj._ip = '1.1.1.1' + obj._state = 'state' + obj._reason = '' + + def _check_fields(instance: deployment_multi.IPMachinesUserService) -> None: + self.assertEqual(instance._ip, '1.1.1.1') + # _state has been removed. If error, _reason is set + self.assertEqual(instance._reason, '') + + data = obj.marshal() + + instance = deployment_multi.IPMachinesUserService(environment=Environment.testing_environment(), service=None) # type: ignore # service is not used + instance.unmarshal(data) + + marshaled_data = instance.marshal() + + # Ensure remarshalled flag is set + self.assertTrue(instance.needs_upgrade()) + instance.mark_for_upgrade(False) # reset flag + + # Ensure fields has been marshalled using new format + self.assertTrue(autoserializable.is_autoserializable_data(marshaled_data)) + + # Check fields + _check_fields(instance) + + # Reunmarshall again and check that remarshalled flag is not set + instance = deployment_multi.IPMachinesUserService(environment=Environment.testing_environment(), service=None) # type: ignore # service is not used + instance.unmarshal(marshaled_data) + self.assertFalse(instance.needs_upgrade()) + + # Check fields again + _check_fields(instance)