1
0
mirror of https://github.com/dkmstr/openuds.git synced 2025-01-03 01:17:56 +03:00

Updating deployment multi migration. Needed to keep locked values correctly.

This commit is contained in:
Adolfo Gómez García 2024-04-21 23:50:45 +02:00
parent 1dcc1f71df
commit 156608b6ae
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
7 changed files with 144 additions and 36 deletions

View File

@ -319,10 +319,10 @@ class PublicationManager(metaclass=singleton.Singleton):
return publication return publication
try: try:
pubInstance = publication.get_instance() pub_instance = publication.get_instance()
state = pubInstance.cancel() state = pub_instance.cancel()
publication.set_state(State.CANCELING) publication.set_state(State.CANCELING)
PublicationFinishChecker.state_updater(publication, pubInstance, state) PublicationFinishChecker.state_updater(publication, pub_instance, state)
return publication return publication
except Exception as e: except Exception as e:
raise PublishException(str(e)) from e raise PublishException(str(e)) from e

View File

@ -106,12 +106,12 @@ class UserServiceManager(metaclass=singleton.Singleton):
""" """
Checks if the maximum number of user services for this service has been reached Checks if the maximum number of user services for this service has been reached
""" """
serviceInstance = service.get_instance() service_instance = service.get_instance()
# Early return, so no database count is needed # Early return, so no database count is needed
if serviceInstance.userservices_limit == consts.UNLIMITED: if service_instance.userservices_limit == consts.UNLIMITED:
return False return False
if self.get_existing_user_services(service) >= serviceInstance.userservices_limit: if self.get_existing_user_services(service) >= service_instance.userservices_limit:
return True return True
return False return False

View File

@ -50,7 +50,7 @@ if typing.TYPE_CHECKING:
class TypeTestingClass(Serializable): class TypeTestingClass(Serializable):
server_group: typing.Any server_group: typing.Any
def post_migrate(self) -> None: def post_migrate(self, apps: typing.Any, record: typing.Any) -> None:
pass pass
def _get_environment(record: typing.Any) -> Environment: def _get_environment(record: typing.Any) -> Environment:
@ -137,7 +137,7 @@ def migrate(
logger.info('Setting server group %s on provider %s', registeredServerGroup.name, record.name) logger.info('Setting server group %s on provider %s', registeredServerGroup.name, record.name)
obj.server_group.value = registeredServerGroup.uuid obj.server_group.value = registeredServerGroup.uuid
# Now, execute post_migrate of obj # Now, execute post_migrate of obj
obj.post_migrate() obj.post_migrate(apps, record)
# Save record # Save record
record.data = obj.serialize() record.data = obj.serialize()
record.save(update_fields=['data']) record.save(update_fields=['data'])

View File

@ -29,21 +29,47 @@
""" """
Author: Adolfo Gómez, dkmaster at dkmon dot com Author: Adolfo Gómez, dkmaster at dkmon dot com
""" """
import base64
import datetime import datetime
import logging import logging
import typing import typing
import pickle # nosec: pickle is used for legacy data transition import pickle # nosec: pickle is used for legacy data transition
from uds.core import services from uds.core import services, types
from uds.core.ui import gui from uds.core.ui import gui
from uds.core.util import auto_attributes, autoserializable
from . import _migrator from . import _migrator
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if typing.TYPE_CHECKING:
import uds.models
IP_SUBTYPE: typing.Final[str] = 'ip' IP_SUBTYPE: typing.Final[str] = 'ip'
class OldIPSerialData(auto_attributes.AutoAttributes):
_ip: str
_reason: str
_state: str
def __init__(self) -> None:
auto_attributes.AutoAttributes.__init__(self, ip=str, reason=str, state=str)
self._ip = ''
self._reason = ''
self._state = types.states.TaskState.FINISHED
class NewIpSerialData(autoserializable.AutoSerializable):
suggested_delay = 10
_ip = autoserializable.StringField(default='')
_mac = autoserializable.StringField(default='')
_vmid = autoserializable.StringField(default='')
_reason = autoserializable.StringField(default='') # If != '', this is the error message and state is ERROR
class IPMachinesService(services.Service): class IPMachinesService(services.Service):
type_type = 'IPMachinesService' type_type = 'IPMachinesService'
@ -96,12 +122,13 @@ class IPMachinesService(services.Service):
self.useRandomIp = gui.as_bool(values[6].decode()) self.useRandomIp = gui.as_bool(values[6].decode())
# Note that will be marshalled as new format, so we don't need to care about old format in code anymore :) # Note that will be marshalled as new format, so we don't need to care about old format in code anymore :)
def post_migrate(self) -> None: def post_migrate(self, apps: typing.Any, record: typing.Any) -> None:
from uds.core.util import fields from uds.core.util import fields
FOREVER: typing.Final[datetime.timedelta] = datetime.timedelta(days=365 * 20) FOREVER: typing.Final[datetime.timedelta] = datetime.timedelta(days=365 * 20)
now = datetime.datetime.now() now = datetime.datetime.now()
server_group = fields.get_server_group_from_field(self.server_group) server_group = fields.get_server_group_from_field(self.server_group)
for server in server_group.servers.all(): for server in server_group.servers.all():
locked = self.storage.read_pickled(server.ip) locked = self.storage.read_pickled(server.ip)
@ -120,13 +147,11 @@ class IPMachinesService(services.Service):
if bool(locked): if bool(locked):
# print(f'Locking {server.ip} forever due to maxSessionForMachine=0') # print(f'Locking {server.ip} forever due to maxSessionForMachine=0')
server.lock(FOREVER) # Almost forever server.lock(FOREVER) # Almost forever
server.save(update_fields=['locked_until'])
continue # Not locked, continue continue # Not locked, continue
if not isinstance(locked, int): if not isinstance(locked, int):
# print(f'Locking {server.ip} due to not being an int (very old data)') # print(f'Locking {server.ip} due to not being an int (very old data)')
server.lock(FOREVER) server.lock(FOREVER)
server.save(update_fields=['locked_until'])
continue continue
if not bool(locked) or locked < now.timestamp() - self.maxSessionForMachine.value * 3600: if not bool(locked) or locked < now.timestamp() - self.maxSessionForMachine.value * 3600:
@ -137,6 +162,54 @@ class IPMachinesService(services.Service):
# print(f'Locking {server.ip} until {datetime.datetime.fromtimestamp(locked)}') # print(f'Locking {server.ip} until {datetime.datetime.fromtimestamp(locked)}')
server.lock(datetime.timedelta(seconds=locked - now.timestamp())) server.lock(datetime.timedelta(seconds=locked - now.timestamp()))
Service: 'type[uds.models.Service]' = apps.get_model('uds', 'Service')
ServicePool: 'type[uds.models.ServicePool]' = apps.get_model('uds', 'ServicePool')
assigned_servers: set[str] = set()
for servicepool in ServicePool.objects.filter(service=Service.objects.get(uuid=record.uuid)):
for userservice in servicepool.userServices.all():
new_data = NewIpSerialData()
try:
auto_data = OldIPSerialData()
auto_data.unmarshal(base64.b64decode(userservice.data))
# Fill own data from restored data
ip_mac = auto_data._ip.split('~')[0]
if ';' in ip_mac:
new_data._ip, new_data._mac = ip_mac.split(';', 2)[:2]
else:
new_data._ip = ip_mac
new_data._mac = ''
new_data._reason = auto_data._reason
state = auto_data._state
# Ensure error is set if _reason is set
if state == types.states.TaskState.ERROR and new_data._reason == '':
new_data._reason = 'Unknown error'
# Reget vmid if needed
if not new_data._reason and userservice.state == types.states.State.USABLE:
new_data._vmid = ''
for server in server_group.servers.all():
if server.ip == new_data._ip and server.uuid not in assigned_servers:
new_data._vmid = server.uuid
assigned_servers.add(server.uuid)
# Ensure locked, relock if needed
if not server.locked_until or server.locked_until < now:
if self.maxSessionForMachine.value <= 0:
server.lock(FOREVER)
else:
server.lock(datetime.timedelta(hours=self.maxSessionForMachine.value))
break
if not new_data._vmid:
new_data._reason = f'Migrated machine not found for {new_data._ip}'
except Exception as e: # Invalid serialized record, record new format with error
new_data._ip = ''
new_data._mac = ''
new_data._vmid = ''
new_data._reason = f'Error migrating: {e}'[:320]
userservice.data = new_data.serialize()
userservice.save(update_fields=['data'])
def migrate(apps: typing.Any, schema_editor: typing.Any) -> None: def migrate(apps: typing.Any, schema_editor: typing.Any) -> None:
_migrator.migrate( _migrator.migrate(

View File

@ -70,7 +70,7 @@ class RDSProvider(services.ServiceProvider):
# This value is the new server group that contains the "ipList" # This value is the new server group that contains the "ipList"
server_group = gui.ChoiceField(label='') server_group = gui.ChoiceField(label='')
def post_migrate(self) -> None: def post_migrate(self, apps: typing.Any, record: typing.Any) -> None:
pass pass
def migrate(apps: typing.Any, schema_editor: typing.Any) -> None: def migrate(apps: typing.Any, schema_editor: typing.Any) -> None:

View File

@ -150,25 +150,4 @@ class IPMachinesUserService(services.UserService, autoserializable.AutoSerializa
def cancel(self) -> types.states.TaskState: def cancel(self) -> types.states.TaskState:
return self.destroy() return self.destroy()
def unmarshal(self, data: bytes) -> None: # Data is migrated on migration 0046, so no unmarshall is needed
if autoserializable.is_autoserializable_data(data):
return super().unmarshal(data)
_auto_data = OldIPSerialData()
_auto_data.unmarshal(data)
# Fill own data from restored data
ip_mac = _auto_data._ip.split('~')[0]
if ';' in ip_mac:
self._ip, self._mac = ip_mac.split(';', 2)[:2]
else:
self._ip = ip_mac
self._mac = ''
self._reason = _auto_data._reason
state = _auto_data._state
# Ensure error is set if _reason is set
if state == types.states.TaskState.ERROR and self._reason == '':
self._reason = 'Unknown error'
# Flag for upgrade
self.mark_for_upgrade(True)

View File

@ -0,0 +1,56 @@
# -*- coding: utf-8 -*-
#
# Copyright (c) 2022 Virtual Cable S.L.U.
# All rights reserved.
#
"""
Author: Adolfo Gómez, dkmaster at dkmon dot com
"""
# We use commit/rollback
from tests.utils.test import UDSTestCase
from uds.core.util import autoserializable
from uds.core.environment import Environment
from uds.services.PhysicalMachines import deployment_multi
class IPMultipleMachinesUserServiceSerializationTest(UDSTestCase):
def test_marshalling(self) -> None:
obj = deployment_multi.OldIPSerialData()
obj._ip = '1.1.1.1'
obj._state = 'state'
obj._reason = ''
def _check_fields(instance: deployment_multi.IPMachinesUserService) -> None:
self.assertEqual(instance._ip, '1.1.1.1')
# _state has been removed. If error, _reason is set
self.assertEqual(instance._reason, '')
data = obj.marshal()
instance = deployment_multi.IPMachinesUserService(environment=Environment.testing_environment(), service=None) # type: ignore # service is not used
instance.unmarshal(data)
marshaled_data = instance.marshal()
# Ensure remarshalled flag is set
self.assertTrue(instance.needs_upgrade())
instance.mark_for_upgrade(False) # reset flag
# Ensure fields has been marshalled using new format
self.assertTrue(autoserializable.is_autoserializable_data(marshaled_data))
# Check fields
_check_fields(instance)
# Reunmarshall again and check that remarshalled flag is not set
instance = deployment_multi.IPMachinesUserService(environment=Environment.testing_environment(), service=None) # type: ignore # service is not used
instance.unmarshal(marshaled_data)
self.assertFalse(instance.needs_upgrade())
# Check fields again
_check_fields(instance)