mirror of https://github.com/dkmstr/openuds.git

Updating deployment multi migration. Needed to preserve locked values correctly.

Adolfo Gómez García 2024-04-21 23:50:45 +02:00
parent 1dcc1f71df
commit 156608b6ae
7 changed files with 144 additions and 36 deletions

View File

@@ -319,10 +319,10 @@ class PublicationManager(metaclass=singleton.Singleton):
             return publication

         try:
-            pubInstance = publication.get_instance()
-            state = pubInstance.cancel()
+            pub_instance = publication.get_instance()
+            state = pub_instance.cancel()
             publication.set_state(State.CANCELING)
-            PublicationFinishChecker.state_updater(publication, pubInstance, state)
+            PublicationFinishChecker.state_updater(publication, pub_instance, state)
             return publication
         except Exception as e:
             raise PublishException(str(e)) from e

View File

@@ -106,12 +106,12 @@ class UserServiceManager(metaclass=singleton.Singleton):
         """
         Checks if the maximum number of user services for this service has been reached
         """
-        serviceInstance = service.get_instance()
+        service_instance = service.get_instance()
         # Early return, so no database count is needed
-        if serviceInstance.userservices_limit == consts.UNLIMITED:
+        if service_instance.userservices_limit == consts.UNLIMITED:
             return False

-        if self.get_existing_user_services(service) >= serviceInstance.userservices_limit:
+        if self.get_existing_user_services(service) >= service_instance.userservices_limit:
             return True

         return False
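
For context on the renamed helper above: the method is only a guard comparing an existing count against a per-service ceiling. Below is a minimal standalone sketch of that early-return pattern; the names and the sentinel value standing in for consts.UNLIMITED are illustrative, not taken from the codebase.

UNLIMITED = -1  # assumed sentinel mirroring consts.UNLIMITED

def limit_reached(existing_user_services: int, userservices_limit: int) -> bool:
    # Early return: an unlimited service never reaches its ceiling,
    # so no (potentially expensive) database count is needed
    if userservices_limit == UNLIMITED:
        return False
    return existing_user_services >= userservices_limit

assert limit_reached(5, UNLIMITED) is False
assert limit_reached(5, 5) is True
assert limit_reached(4, 5) is False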

View File

@@ -50,7 +50,7 @@ if typing.TYPE_CHECKING:
 class TypeTestingClass(Serializable):
     server_group: typing.Any

-    def post_migrate(self) -> None:
+    def post_migrate(self, apps: typing.Any, record: typing.Any) -> None:
         pass

 def _get_environment(record: typing.Any) -> Environment:
@@ -137,7 +137,7 @@ def migrate(
             logger.info('Setting server group %s on provider %s', registeredServerGroup.name, record.name)
             obj.server_group.value = registeredServerGroup.uuid
             # Now, execute post_migrate of obj
-            obj.post_migrate()
+            obj.post_migrate(apps, record)
             # Save record
             record.data = obj.serialize()
             record.save(update_fields=['data'])
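
The widened post_migrate signature follows the usual Django data-migration pattern: apps is the historical app registry handed to the migration, and record is the database row whose serialized data is being rewritten. A rough sketch of a hook using both, assuming the 'uds' app label and a ServicePool model with a service foreign key, as elsewhere in this commit:

import typing

class ExampleMigratable:
    # Hypothetical class illustrating the new post_migrate contract
    def post_migrate(self, apps: typing.Any, record: typing.Any) -> None:
        # apps.get_model returns the historical model state, which is the
        # safe way to touch the database from inside a migration
        ServicePool = apps.get_model('uds', 'ServicePool')
        for pool in ServicePool.objects.filter(service__uuid=record.uuid):
            ...  # fix up rows related to the service being migrated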

View File

@@ -29,21 +29,47 @@
 """
 Author: Adolfo Gómez, dkmaster at dkmon dot com
 """
+import base64
 import datetime
 import logging
 import typing
 import pickle  # nosec: pickle is used for legacy data transition

-from uds.core import services
+from uds.core import services, types
 from uds.core.ui import gui
+from uds.core.util import auto_attributes, autoserializable

 from . import _migrator

 logger = logging.getLogger(__name__)

 if typing.TYPE_CHECKING:
     import uds.models

 IP_SUBTYPE: typing.Final[str] = 'ip'

+
+class OldIPSerialData(auto_attributes.AutoAttributes):
+    _ip: str
+    _reason: str
+    _state: str
+
+    def __init__(self) -> None:
+        auto_attributes.AutoAttributes.__init__(self, ip=str, reason=str, state=str)
+        self._ip = ''
+        self._reason = ''
+        self._state = types.states.TaskState.FINISHED
+
+
+class NewIpSerialData(autoserializable.AutoSerializable):
+    suggested_delay = 10
+
+    _ip = autoserializable.StringField(default='')
+    _mac = autoserializable.StringField(default='')
+    _vmid = autoserializable.StringField(default='')
+    _reason = autoserializable.StringField(default='')  # If != '', this is the error message and state is ERROR
+
+
 class IPMachinesService(services.Service):
     type_type = 'IPMachinesService'
@@ -96,14 +96,15 @@ class IPMachinesService(services.Service):
         self.useRandomIp = gui.as_bool(values[6].decode())

     # Note that will be marshalled as new format, so we don't need to care about old format in code anymore :)

-    def post_migrate(self) -> None:
+    def post_migrate(self, apps: typing.Any, record: typing.Any) -> None:
         from uds.core.util import fields

         FOREVER: typing.Final[datetime.timedelta] = datetime.timedelta(days=365 * 20)
         now = datetime.datetime.now()

         server_group = fields.get_server_group_from_field(self.server_group)
         for server in server_group.servers.all():
             locked = self.storage.read_pickled(server.ip)
             # print(f'Locked: {locked} for {server.ip}')
             if not locked:
@@ -120,13 +147,11 @@ class IPMachinesService(services.Service):
                 if bool(locked):
                     # print(f'Locking {server.ip} forever due to maxSessionForMachine=0')
                     server.lock(FOREVER)  # Almost forever
-                    server.save(update_fields=['locked_until'])
                 continue  # Not locked, continue
             if not isinstance(locked, int):
                 # print(f'Locking {server.ip} due to not being an int (very old data)')
                 server.lock(FOREVER)
-                server.save(update_fields=['locked_until'])
                 continue
             if not bool(locked) or locked < now.timestamp() - self.maxSessionForMachine.value * 3600:
@@ -137,6 +162,54 @@ class IPMachinesService(services.Service):
                 # print(f'Locking {server.ip} until {datetime.datetime.fromtimestamp(locked)}')
                 server.lock(datetime.timedelta(seconds=locked - now.timestamp()))

+        Service: 'type[uds.models.Service]' = apps.get_model('uds', 'Service')
+        ServicePool: 'type[uds.models.ServicePool]' = apps.get_model('uds', 'ServicePool')
+
+        assigned_servers: set[str] = set()
+        for servicepool in ServicePool.objects.filter(service=Service.objects.get(uuid=record.uuid)):
+            for userservice in servicepool.userServices.all():
+                new_data = NewIpSerialData()
+                try:
+                    auto_data = OldIPSerialData()
+                    auto_data.unmarshal(base64.b64decode(userservice.data))
+                    # Fill own data from restored data
+                    ip_mac = auto_data._ip.split('~')[0]
+                    if ';' in ip_mac:
+                        new_data._ip, new_data._mac = ip_mac.split(';', 2)[:2]
+                    else:
+                        new_data._ip = ip_mac
+                        new_data._mac = ''
+                    new_data._reason = auto_data._reason
+                    state = auto_data._state
+                    # Ensure error is set if _reason is set
+                    if state == types.states.TaskState.ERROR and new_data._reason == '':
+                        new_data._reason = 'Unknown error'
+                    # Reget vmid if needed
+                    if not new_data._reason and userservice.state == types.states.State.USABLE:
+                        new_data._vmid = ''
+                        for server in server_group.servers.all():
+                            if server.ip == new_data._ip and server.uuid not in assigned_servers:
+                                new_data._vmid = server.uuid
+                                assigned_servers.add(server.uuid)
+                                # Ensure locked, relock if needed
+                                if not server.locked_until or server.locked_until < now:
+                                    if self.maxSessionForMachine.value <= 0:
+                                        server.lock(FOREVER)
+                                    else:
+                                        server.lock(datetime.timedelta(hours=self.maxSessionForMachine.value))
+                                break
+                        if not new_data._vmid:
+                            new_data._reason = f'Migrated machine not found for {new_data._ip}'
+                except Exception as e:  # Invalid serialized record, record new format with error
+                    new_data._ip = ''
+                    new_data._mac = ''
+                    new_data._vmid = ''
+                    new_data._reason = f'Error migrating: {e}'[:320]
+
+                userservice.data = new_data.serialize()
+                userservice.save(update_fields=['data'])

 def migrate(apps: typing.Any, schema_editor: typing.Any) -> None:
     _migrator.migrate(
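
The core of the added block is translating each legacy AutoAttributes value into the new fields: the stored _ip may carry a '~'-suffixed tail and an optional ';'-separated MAC, exactly as the unmarshal removed from deployment_multi further below used to handle. A small self-contained sketch of that parsing step; the helper name and sample values are illustrative, not part of the commit.

def split_legacy_address(legacy_ip: str) -> tuple[str, str]:
    """Return (ip, mac) from a legacy value such as '10.0.0.1;00:11:22:33:44:55~1'."""
    ip_mac = legacy_ip.split('~')[0]  # drop anything after the '~' separator
    if ';' in ip_mac:
        ip, mac = ip_mac.split(';', 2)[:2]
        return ip, mac
    return ip_mac, ''  # no MAC stored in the old data

assert split_legacy_address('10.0.0.1;00:11:22:33:44:55~1') == ('10.0.0.1', '00:11:22:33:44:55')
assert split_legacy_address('10.0.0.2~3') == ('10.0.0.2', '')
assert split_legacy_address('10.0.0.3') == ('10.0.0.3', '')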

View File

@@ -70,7 +70,7 @@ class RDSProvider(services.ServiceProvider):
     # This value is the new server group that contains the "ipList"
     server_group = gui.ChoiceField(label='')

-    def post_migrate(self) -> None:
+    def post_migrate(self, apps: typing.Any, record: typing.Any) -> None:
         pass

 def migrate(apps: typing.Any, schema_editor: typing.Any) -> None:

View File

@@ -150,25 +150,4 @@ class IPMachinesUserService(services.UserService, autoserializable.AutoSerializable):
     def cancel(self) -> types.states.TaskState:
         return self.destroy()

-    def unmarshal(self, data: bytes) -> None:
-        if autoserializable.is_autoserializable_data(data):
-            return super().unmarshal(data)
-
-        _auto_data = OldIPSerialData()
-        _auto_data.unmarshal(data)
-
-        # Fill own data from restored data
-        ip_mac = _auto_data._ip.split('~')[0]
-        if ';' in ip_mac:
-            self._ip, self._mac = ip_mac.split(';', 2)[:2]
-        else:
-            self._ip = ip_mac
-            self._mac = ''
-        self._reason = _auto_data._reason
-        state = _auto_data._state
-        # Ensure error is set if _reason is set
-        if state == types.states.TaskState.ERROR and self._reason == '':
-            self._reason = 'Unknown error'
-
-        # Flag for upgrade
-        self.mark_for_upgrade(True)
+    # Data is migrated on migration 0046, so no unmarshall is needed

View File

@@ -0,0 +1,56 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2022 Virtual Cable S.L.U.
+# All rights reserved.
+#
+"""
+Author: Adolfo Gómez, dkmaster at dkmon dot com
+"""
+# We use commit/rollback
+from tests.utils.test import UDSTestCase
+
+from uds.core.util import autoserializable
+from uds.core.environment import Environment
+
+from uds.services.PhysicalMachines import deployment_multi
+
+
+class IPMultipleMachinesUserServiceSerializationTest(UDSTestCase):
+    def test_marshalling(self) -> None:
+        obj = deployment_multi.OldIPSerialData()
+        obj._ip = '1.1.1.1'
+        obj._state = 'state'
+        obj._reason = ''
+
+        def _check_fields(instance: deployment_multi.IPMachinesUserService) -> None:
+            self.assertEqual(instance._ip, '1.1.1.1')
+            # _state has been removed. If error, _reason is set
+            self.assertEqual(instance._reason, '')
+
+        data = obj.marshal()
+
+        instance = deployment_multi.IPMachinesUserService(environment=Environment.testing_environment(), service=None)  # type: ignore  # service is not used
+        instance.unmarshal(data)
+        marshaled_data = instance.marshal()
+
+        # Ensure remarshalled flag is set
+        self.assertTrue(instance.needs_upgrade())
+        instance.mark_for_upgrade(False)  # reset flag
+
+        # Ensure fields has been marshalled using new format
+        self.assertTrue(autoserializable.is_autoserializable_data(marshaled_data))
+        # Check fields
+        _check_fields(instance)
+
+        # Reunmarshall again and check that remarshalled flag is not set
+        instance = deployment_multi.IPMachinesUserService(environment=Environment.testing_environment(), service=None)  # type: ignore  # service is not used
+        instance.unmarshal(marshaled_data)
+        self.assertFalse(instance.needs_upgrade())
+
+        # Check fields again
+        _check_fields(instance)