1
0
mirror of https://github.com/dkmstr/openuds.git synced 2025-02-02 09:47:13 +03:00

Done cancel/destroy test and user assignation for proxmox

This commit is contained in:
Adolfo Gómez García 2024-02-28 20:31:08 +01:00
parent 09858a165c
commit 05d26c732e
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
9 changed files with 150 additions and 35 deletions

View File

@ -56,7 +56,7 @@ class OSManager(ManagedObjectModel, TaggingMixin):
# objects: 'models.manager.Manager[OSManager]' # objects: 'models.manager.Manager[OSManager]'
deployedServices: 'models.manager.RelatedManager[ServicePool]' deployedServices: 'models.manager.RelatedManager[ServicePool]'
class Meta(ManagedObjectModel.Meta): # pylint: disable=too-few-public-methods class Meta(ManagedObjectModel.Meta): # pyright: ignore
""" """
Meta class to declare default order Meta class to declare default order
""" """

View File

@ -31,6 +31,7 @@
Author: Adolfo Gómez, dkmaster at dkmon dot com Author: Adolfo Gómez, dkmaster at dkmon dot com
""" """
import logging import logging
import typing
from django.db import models from django.db import models
@ -46,7 +47,7 @@ class Properties(models.Model):
owner_id = models.CharField(max_length=128, db_index=True) owner_id = models.CharField(max_length=128, db_index=True)
owner_type = models.CharField(max_length=64, db_index=True) owner_type = models.CharField(max_length=64, db_index=True)
key = models.CharField(max_length=64, db_index=True) key = models.CharField(max_length=64, db_index=True)
value = models.JSONField(default=dict) value: typing.Any = models.JSONField(default=dict)
class Meta: # pylint: disable=too-few-public-methods class Meta: # pylint: disable=too-few-public-methods
""" """

View File

@ -52,7 +52,7 @@ class StatsCounters(models.Model):
value = models.IntegerField(db_index=True, default=0) value = models.IntegerField(db_index=True, default=0)
# "fake" declarations for type checking # "fake" declarations for type checking
objects: 'models.manager.Manager[StatsCounters]' # objects: 'models.manager.Manager[StatsCounters]'
class Meta: # pylint: disable=too-few-public-methods class Meta: # pylint: disable=too-few-public-methods
""" """

View File

@ -81,7 +81,7 @@ class StatsCountersAccum(models.Model):
v_min = models.IntegerField(default=0) v_min = models.IntegerField(default=0)
# "fake" declarations for type checking # "fake" declarations for type checking
objects: 'models.manager.Manager[StatsCountersAccum]' # objects: 'models.manager.Manager[StatsCountersAccum]'
class Meta: # pylint: disable=too-few-public-methods class Meta: # pylint: disable=too-few-public-methods
""" """

View File

@ -61,13 +61,13 @@ class PoolsUsageSummary(UsageByPool):
) -> tuple[ ) -> tuple[
typing.ValuesView[collections.abc.MutableMapping[str, typing.Any]], int, int, int typing.ValuesView[collections.abc.MutableMapping[str, typing.Any]], int, int, int
]: ]:
orig, poolNames = super().get_data() # pylint: disable=unused-variable # Keep name for reference orig, _pool_names = super().get_data() # pylint: disable=unused-variable # Keep name for reference
pools: dict[str, dict[str, typing.Any]] = {} pools: dict[str, dict[str, typing.Any]] = {}
totalTime: int = 0 totalTime: int = 0
totalCount: int = 0 totalCount: int = 0
uniqueUsers = set() unique_users: set[str] = set()
for v in orig: for v in orig:
uuid = v['pool'] uuid = v['pool']
@ -82,7 +82,7 @@ class PoolsUsageSummary(UsageByPool):
pools[uuid]['count'] += 1 pools[uuid]['count'] += 1
# Now add user id to pool # Now add user id to pool
pools[uuid]['users'].add(v['name']) pools[uuid]['users'].add(v['name'])
uniqueUsers.add(v['name']) unique_users.add(v['name'])
totalTime += v['time'] totalTime += v['time']
totalCount += 1 totalCount += 1
@ -92,7 +92,7 @@ class PoolsUsageSummary(UsageByPool):
for _, pn in pools.items(): for _, pn in pools.items():
pn['users'] = len(pn['users']) pn['users'] = len(pn['users'])
return pools.values(), totalTime, totalCount or 1, len(uniqueUsers) return pools.values(), totalTime, totalCount or 1, len(unique_users)
def generate(self) -> bytes: def generate(self) -> bytes:
pools, totalTime, totalCount, uniqueUsers = self.processedData() pools, totalTime, totalCount, uniqueUsers = self.processedData()

View File

@ -691,7 +691,7 @@ if sys.platform == 'win32':
Operation.CREATE: 'create', Operation.CREATE: 'create',
Operation.START: 'start', Operation.START: 'start',
Operation.STOP: 'stop', Operation.STOP: 'stop',
Operation.SHUTDOWN: 'suspend', Operation.SHUTDOWN: 'shutdown',
Operation.GRACEFUL_STOP: 'gracely stop', Operation.GRACEFUL_STOP: 'gracely stop',
Operation.REMOVE: 'remove', Operation.REMOVE: 'remove',
Operation.WAIT: 'wait', Operation.WAIT: 'wait',

View File

@ -90,7 +90,7 @@ def uds_js(request: 'ExtendedHttpRequest') -> str:
# Tag will also include non visible authenticators # Tag will also include non visible authenticators
# tag, later will remove "auth_host" # tag, later will remove "auth_host"
authenticators = list(auths.filter(small_name__in=[auth_host, tag])) authenticators = list(auths.filter(small_name__in=[auth_host, tag]))
except Exception as e: except Exception:
authenticators = [] authenticators = []
else: else:
if not tag: # If no tag, remove hidden auths if not tag: # If no tag, remove hidden auths
@ -145,7 +145,7 @@ def uds_js(request: 'ExtendedHttpRequest') -> str:
'is_custom': theType.is_custom(), 'is_custom': theType.is_custom(),
} }
config = { config: dict[str, typing.Any] = {
'version': consts.system.VERSION, 'version': consts.system.VERSION,
'version_stamp': consts.system.VERSION_STAMP, 'version_stamp': consts.system.VERSION_STAMP,
'language': get_language(), 'language': get_language(),

View File

@ -67,8 +67,6 @@ if typing.TYPE_CHECKING:
def index(request: HttpRequest) -> HttpResponse: def index(request: HttpRequest) -> HttpResponse:
# Gets csrf token # Gets csrf token
csrf_token = csrf.get_token(request) csrf_token = csrf.get_token(request)
if csrf_token is not None:
csrf_token = str(csrf_token)
response = render( response = render(
request=request, request=request,
@ -175,7 +173,7 @@ def mfa(
store: 'storage.Storage' = storage.Storage('mfs') store: 'storage.Storage' = storage.Storage('mfs')
mfa_provider = typing.cast('None|models.MFA', request.user.manager.mfa) mfa_provider = request.user.manager.mfa # typing.cast('None|models.MFA',
if not mfa_provider: if not mfa_provider:
logger.warning('MFA: No MFA provider for user') logger.warning('MFA: No MFA provider for user')
return HttpResponseRedirect(reverse('page.index')) return HttpResponseRedirect(reverse('page.index'))

View File

@ -30,6 +30,7 @@
""" """
Author: Adolfo Gómez, dkmaster at dkmon dot com Author: Adolfo Gómez, dkmaster at dkmon dot com
""" """
import stat
import typing import typing
import datetime import datetime
import collections.abc import collections.abc
@ -39,17 +40,34 @@ from tests.web import user
from uds import models from uds import models
from uds.core import types, ui, environment from uds.core import types, ui, environment
from uds.services.Proxmox.deployment_linked import ProxmoxUserserviceLinked from uds.models import service
from uds.services.Proxmox.deployment_linked import ProxmoxUserserviceLinked, Operation
from . import fixtures from . import fixtures
from ...utils.test import UDSTransactionTestCase from ...utils.test import UDSTransactionTestCase
def limit_iter(check: typing.Callable[[], bool], limit: int = 128) -> typing.Generator[int, None, None]:
"""
Limit an iterator to a number of elements
"""
current = 0
while current < limit and check():
yield current
current += 1
if current < limit:
return
# Limit reached, raise an exception
raise Exception(f'Limit reached: {current}/{limit}: {check()}')
# We use transactions on some related methods (storage access, etc...) # We use transactions on some related methods (storage access, etc...)
class TestProxmovLinkedService(UDSTransactionTestCase): class TestProxmovLinkedService(UDSTransactionTestCase):
def test_userservice_fixed_cache_l1(self) -> None: def test_userservice_linked_cache_l1(self) -> None:
""" """
Test the user service Test the user service
""" """
@ -63,16 +81,17 @@ class TestProxmovLinkedService(UDSTransactionTestCase):
self.assertEqual(state, types.states.TaskState.RUNNING) self.assertEqual(state, types.states.TaskState.RUNNING)
while state == types.states.TaskState.RUNNING: # Ensure that in the event of failure, we don't loop forever
for _ in limit_iter(lambda: state == types.states.TaskState.RUNNING, limit=128):
state = userservice.check_state() state = userservice.check_state()
self.assertEqual(state, types.states.TaskState.FINISHED) self.assertEqual(state, types.states.TaskState.FINISHED)
self.assertEqual(userservice._name[: len(service.get_basename())], service.get_basename()) self.assertEqual(userservice._name[: len(service.get_basename())], service.get_basename())
self.assertEqual(len(userservice._name), len(service.get_basename()) + service.get_lenname()) self.assertEqual(len(userservice._name), len(service.get_basename()) + service.get_lenname())
vmid = int(userservice._vmid) vmid = int(userservice._vmid)
api.clone_machine.assert_called_with( api.clone_machine.assert_called_with(
publication.machine(), publication.machine(),
mock.ANY, mock.ANY,
@ -84,18 +103,17 @@ class TestProxmovLinkedService(UDSTransactionTestCase):
service.pool.value, service.pool.value,
None, None,
) )
# api.get_task should have been invoked at least once # api.get_task should have been invoked at least once
self.assertTrue(api.get_task.called) self.assertTrue(api.get_task.called)
api.enable_machine_ha.assert_called() api.enable_machine_ha.assert_called()
api.set_machine_mac.assert_called_with(vmid, userservice._mac) api.set_machine_mac.assert_called_with(vmid, userservice._mac)
api.get_machine_pool_info.assert_called_with(vmid, service.pool.value, force=True) api.get_machine_pool_info.assert_called_with(vmid, service.pool.value, force=True)
api.start_machine.assert_called_with(vmid) api.start_machine.assert_called_with(vmid)
def test_userservice_linked_cache_l2_no_ha(self) -> None:
def test_userservice_fixed_cache_l2_no_ha(self) -> None:
""" """
Test the user service Test the user service
""" """
@ -103,10 +121,12 @@ class TestProxmovLinkedService(UDSTransactionTestCase):
userservice = fixtures.create_userservice_linked() userservice = fixtures.create_userservice_linked()
service = userservice.service() service = userservice.service()
service.ha.value = '__' # Disabled service.ha.value = '__' # Disabled
# Set machine state for fixture to started # Set machine state for fixture to started
fixtures.VMS_INFO = [fixtures.VMS_INFO[i]._replace(status = 'running') for i in range(len(fixtures.VMS_INFO))] fixtures.VMS_INFO = [
fixtures.VMS_INFO[i]._replace(status='running') for i in range(len(fixtures.VMS_INFO))
]
publication = userservice.publication() publication = userservice.publication()
publication._vmid = '1' publication._vmid = '1'
@ -114,16 +134,16 @@ class TestProxmovLinkedService(UDSTransactionTestCase):
self.assertEqual(state, types.states.TaskState.RUNNING) self.assertEqual(state, types.states.TaskState.RUNNING)
while state == types.states.TaskState.RUNNING: for _ in limit_iter(lambda: state == types.states.TaskState.RUNNING, limit=128):
state = userservice.check_state() state = userservice.check_state()
self.assertEqual(state, types.states.TaskState.FINISHED) self.assertEqual(state, types.states.TaskState.FINISHED)
self.assertEqual(userservice._name[: len(service.get_basename())], service.get_basename()) self.assertEqual(userservice._name[: len(service.get_basename())], service.get_basename())
self.assertEqual(len(userservice._name), len(service.get_basename()) + service.get_lenname()) self.assertEqual(len(userservice._name), len(service.get_basename()) + service.get_lenname())
vmid = int(userservice._vmid) vmid = int(userservice._vmid)
api.clone_machine.assert_called_with( api.clone_machine.assert_called_with(
publication.machine(), publication.machine(),
mock.ANY, mock.ANY,
@ -135,17 +155,113 @@ class TestProxmovLinkedService(UDSTransactionTestCase):
service.pool.value, service.pool.value,
None, None,
) )
# api.get_task should have been invoked at least once # api.get_task should have been invoked at least once
self.assertTrue(api.get_task.called) self.assertTrue(api.get_task.called)
# Shoud not have been called since HA is disabled # Shoud not have been called since HA is disabled
api.enable_machine_ha.assert_not_called() api.enable_machine_ha.assert_not_called()
api.set_machine_mac.assert_called_with(vmid, userservice._mac) api.set_machine_mac.assert_called_with(vmid, userservice._mac)
api.get_machine_pool_info.assert_called_with(vmid, service.pool.value, force=True) api.get_machine_pool_info.assert_called_with(vmid, service.pool.value, force=True)
# Now, called should not have been called because machine is running # Now, called should not have been called because machine is running
# api.start_machine.assert_called_with(vmid) # api.start_machine.assert_called_with(vmid)
# Stop machine should have been called # Stop machine should have been called
api.shutdown_machine.assert_called_with(vmid) api.shutdown_machine.assert_called_with(vmid)
def test_userservice_linked_user(self) -> None:
"""
Test the user service
"""
with fixtures.patch_provider_api() as api:
userservice = fixtures.create_userservice_linked()
service = userservice.service()
publication = userservice.publication()
publication._vmid = '1'
state = userservice.deploy_for_user(models.User())
self.assertEqual(state, types.states.TaskState.RUNNING)
for _ in limit_iter(lambda: state == types.states.TaskState.RUNNING, limit=128):
state = userservice.check_state()
self.assertEqual(state, types.states.TaskState.FINISHED)
self.assertEqual(userservice._name[: len(service.get_basename())], service.get_basename())
self.assertEqual(len(userservice._name), len(service.get_basename()) + service.get_lenname())
vmid = int(userservice._vmid)
api.clone_machine.assert_called_with(
publication.machine(),
mock.ANY,
userservice._name,
mock.ANY,
True,
None,
service.datastore.value,
service.pool.value,
None,
)
# api.get_task should have been invoked at least once
self.assertTrue(api.get_task.called)
api.enable_machine_ha.assert_called()
api.set_machine_mac.assert_called_with(vmid, userservice._mac)
api.get_machine_pool_info.assert_called_with(vmid, service.pool.value, force=True)
api.start_machine.assert_called_with(vmid)
def test_userservice_cancel(self) -> None:
"""
Test the user service
"""
with fixtures.patch_provider_api() as _api:
for graceful in [True, False]:
userservice = fixtures.create_userservice_linked()
service = userservice.service()
service.soft_shutdown_field.value = graceful
publication = userservice.publication()
publication._vmid = '1'
# Set machine state for fixture to started
fixtures.VMS_INFO = [
fixtures.VMS_INFO[i]._replace(status='running') for i in range(len(fixtures.VMS_INFO))
]
state = userservice.deploy_for_user(models.User())
self.assertEqual(state, types.states.TaskState.RUNNING)
current_op = userservice._get_current_op()
# Invoke cancel
state = userservice.cancel()
self.assertEqual(state, types.states.TaskState.RUNNING)
self.assertEqual(
userservice._queue,
[current_op]
+ ([Operation.GRACEFUL_STOP] if graceful else [])
+ [Operation.STOP, Operation.REMOVE, Operation.FINISH],
)
for counter in limit_iter(lambda: state == types.states.TaskState.RUNNING, limit=128):
state = userservice.check_state()
if counter > 5:
# Set machine state for fixture to stopped
fixtures.VMS_INFO = [
fixtures.VMS_INFO[i]._replace(status='stopped')
for i in range(len(fixtures.VMS_INFO))
]
self.assertEqual(state, types.states.TaskState.FINISHED)
if graceful:
_api.shutdown_machine.assert_called()
else:
_api.stop_machine.assert_called()