diff --git a/server/src/uds/auths/SAML/saml.py b/server/src/uds/auths/SAML/saml.py index 7df6e164e..5b0cd48d3 100644 --- a/server/src/uds/auths/SAML/saml.py +++ b/server/src/uds/auths/SAML/saml.py @@ -580,7 +580,7 @@ class SAMLAuthenticator(auths.Authenticator): @decorators.cached( prefix='spm', - key_fnc=CACHING_KEY_FNC, + key_helper=CACHING_KEY_FNC, timeout=3600, # 1 hour ) def get_sp_metadata(self) -> str: diff --git a/server/src/uds/core/util/decorators.py b/server/src/uds/core/util/decorators.py index 862bac1fb..ff6fb8857 100644 --- a/server/src/uds/core/util/decorators.py +++ b/server/src/uds/core/util/decorators.py @@ -206,7 +206,7 @@ def cached( timeout: typing.Union[collections.abc.Callable[[], int], int] = -1, args: typing.Optional[typing.Union[collections.abc.Iterable[int], int]] = None, kwargs: typing.Optional[typing.Union[collections.abc.Iterable[str], str]] = None, - key_fnc: typing.Optional[collections.abc.Callable[[typing.Any], str]] = None, + key_helper: typing.Optional[collections.abc.Callable[[typing.Any], str]] = None, ) -> collections.abc.Callable[[FT], FT]: """Decorator that give us a "quick& clean" caching feature on db. The "cached" element must provide a "cache" variable, which is a cache object @@ -257,7 +257,7 @@ def cached( # Not inspectable, no caching possible, return original function return fnc - lkey_fnc = key_fnc or (lambda x: fnc.__name__) + lkey_fnc = key_helper or (lambda x: fnc.__name__) @functools.wraps(fnc) def wrapper(*args, **kwargs) -> typing.Any: diff --git a/server/src/uds/services/Proxmox/client/__init__.py b/server/src/uds/services/Proxmox/client/__init__.py index 3f1cbaaff..625056912 100644 --- a/server/src/uds/services/Proxmox/client/__init__.py +++ b/server/src/uds/services/Proxmox/client/__init__.py @@ -255,7 +255,7 @@ class ProxmoxClient: return True @ensure_connected - @cached('cluster', CACHE_DURATION, key_fnc=caching_key_helper) + @cached('cluster', CACHE_DURATION, key_helper=caching_key_helper) def get_cluster_info(self, **kwargs) -> types.ClusterInfo: return types.ClusterInfo.from_dict(self._get('cluster/status')) @@ -272,13 +272,13 @@ class ProxmoxClient: return True @ensure_connected - @cached('nodeNets', CACHE_DURATION, args=1, kwargs=['node'], key_fnc=caching_key_helper) + @cached('nodeNets', CACHE_DURATION, args=1, kwargs=['node'], key_helper=caching_key_helper) def get_node_networks(self, node: str, **kwargs) -> typing.Any: return self._get('nodes/{}/network'.format(node))['data'] # pylint: disable=unused-argument @ensure_connected - @cached('nodeGpuDevices', CACHE_DURATION_LONG, key_fnc=caching_key_helper) + @cached('nodeGpuDevices', CACHE_DURATION_LONG, key_helper=caching_key_helper) def list_node_gpu_devices(self, node: str, **kwargs) -> list[str]: return [ device['id'] for device in self._get(f'nodes/{node}/hardware/pci')['data'] if device.get('mdev') @@ -409,7 +409,7 @@ class ProxmoxClient: ) @ensure_connected - @cached('hagrps', CACHE_DURATION, key_fnc=caching_key_helper) + @cached('hagrps', CACHE_DURATION, key_helper=caching_key_helper) def list_ha_groups(self, **kwargs) -> list[str]: return [g['group'] for g in self._get('cluster/ha/groups')['data']] @@ -486,7 +486,7 @@ class ProxmoxClient: return [] # If we can't get snapshots, just return empty list @ensure_connected - @cached('snapshots', CACHE_DURATION, key_fnc=caching_key_helper) + @cached('snapshots', CACHE_DURATION, key_helper=caching_key_helper) def supports_snapshot(self, vmid: int, node: typing.Optional[str] = None) -> bool: # If machine uses tpm, snapshots are not supported return not self.get_machine_configuration(vmid, node).tpmstate0 @@ -541,7 +541,7 @@ class ProxmoxClient: ) @ensure_connected - @cached('vms', CACHE_DURATION, key_fnc=caching_key_helper) + @cached('vms', CACHE_DURATION, key_helper=caching_key_helper) def list_machines( self, node: typing.Union[None, str, collections.abc.Iterable[str]] = None, **kwargs ) -> list[types.VMInfo]: @@ -562,7 +562,7 @@ class ProxmoxClient: return sorted(result, key=lambda x: '{}{}'.format(x.node, x.name)) @ensure_connected - @cached('vmip', CACHE_INFO_DURATION, key_fnc=caching_key_helper) + @cached('vmip', CACHE_INFO_DURATION, key_helper=caching_key_helper) def get_machine_pool_info(self, vmid: int, poolid: typing.Optional[str], **kwargs) -> types.VMInfo: # try to locate machine in pool node = None @@ -581,7 +581,7 @@ class ProxmoxClient: return self.get_machine_info(vmid, node, **kwargs) @ensure_connected - @cached('vmin', CACHE_INFO_DURATION, key_fnc=caching_key_helper) + @cached('vmin', CACHE_INFO_DURATION, key_helper=caching_key_helper) def get_machine_info(self, vmid: int, node: typing.Optional[str] = None, **kwargs) -> types.VMInfo: nodes = [types.Node(node, False, False, 0, '', '', '')] if node else self.get_cluster_info().nodes any_node_is_down = False @@ -675,14 +675,14 @@ class ProxmoxClient: resume_machine = start_machine @ensure_connected - @cached('storage', CACHE_DURATION, key_fnc=caching_key_helper) + @cached('storage', CACHE_DURATION, key_helper=caching_key_helper) def get_storage(self, storage: str, node: str, **kwargs) -> types.StorageInfo: return types.StorageInfo.from_dict( self._get('nodes/{}/storage/{}/status'.format(node, urllib.parse.quote(storage)))['data'] ) @ensure_connected - @cached('storages', CACHE_DURATION, key_fnc=caching_key_helper) + @cached('storages', CACHE_DURATION, key_helper=caching_key_helper) def list_storages( self, node: typing.Union[None, str, collections.abc.Iterable[str]] = None, @@ -709,19 +709,19 @@ class ProxmoxClient: return result @ensure_connected - @cached('nodeStats', CACHE_INFO_DURATION, key_fnc=caching_key_helper) + @cached('nodeStats', CACHE_INFO_DURATION, key_helper=caching_key_helper) def get_node_stats(self, **kwargs) -> list[types.NodeStats]: return [ types.NodeStats.from_dict(nodeStat) for nodeStat in self._get('cluster/resources?type=node')['data'] ] @ensure_connected - @cached('pools', CACHE_DURATION // 6, key_fnc=caching_key_helper) + @cached('pools', CACHE_DURATION // 6, key_helper=caching_key_helper) def list_pools(self, **kwargs) -> list[types.PoolInfo]: return [types.PoolInfo.from_dict(poolInfo) for poolInfo in self._get('pools')['data']] @ensure_connected - @cached('pool', CACHE_DURATION, key_fnc=caching_key_helper) + @cached('pool', CACHE_DURATION, key_helper=caching_key_helper) def get_pool_info(self, pool_id: str, retrieve_vm_names: bool = False, **kwargs) -> types.PoolInfo: pool_info = types.PoolInfo.from_dict(self._get(f'pools/{pool_id}')['data']) if retrieve_vm_names: diff --git a/server/src/uds/services/Proxmox/provider.py b/server/src/uds/services/Proxmox/provider.py index fc456545f..5534d02a9 100644 --- a/server/src/uds/services/Proxmox/provider.py +++ b/server/src/uds/services/Proxmox/provider.py @@ -54,6 +54,13 @@ logger = logging.getLogger(__name__) MAX_VMID: typing.Final[int] = 999999999 +def cache_key_helper(self: 'ProxmoxProvider') -> str: + """ + Helper function to generate cache keys for the ProxmoxProvider class + """ + return f'{self.host.value}-{self.port.as_int()}' + + class ProxmoxProvider(services.ServiceProvider): type_name = _('Proxmox Platform Provider') type_type = 'ProxmoxPlatform' @@ -174,13 +181,17 @@ class ProxmoxProvider(services.ServiceProvider): def get_storage_info(self, storageid: str, node: str, force: bool = False) -> client.types.StorageInfo: return self._api().get_storage(storageid, node, force=force) - def list_storages(self, node: typing.Optional[str] = None, force: bool = False) -> list[client.types.StorageInfo]: + def list_storages( + self, node: typing.Optional[str] = None, force: bool = False + ) -> list[client.types.StorageInfo]: return self._api().list_storages(node=node, content='images', force=force) def list_pools(self, force: bool = False) -> list[client.types.PoolInfo]: return self._api().list_pools(force=force) - def get_pool_info(self, pool_id: str, retrieve_vm_names: bool = False, force: bool = False) -> client.types.PoolInfo: + def get_pool_info( + self, pool_id: str, retrieve_vm_names: bool = False, force: bool = False + ) -> client.types.PoolInfo: return self._api().get_pool_info(pool_id, retrieve_vm_names=retrieve_vm_names, force=force) def create_template(self, vmid: int) -> None: @@ -299,7 +310,7 @@ class ProxmoxProvider(services.ServiceProvider): """ return self._api().restore_snapshot(vmid, node, name) - @cached('reachable', consts.cache.SHORT_CACHE_TIMEOUT) + @cached('reachable', consts.cache.SHORT_CACHE_TIMEOUT, key_helper=cache_key_helper) def is_available(self) -> bool: return self._api().test() diff --git a/server/src/uds/services/Xen/provider.py b/server/src/uds/services/Xen/provider.py index a1f423363..522dca76a 100644 --- a/server/src/uds/services/Xen/provider.py +++ b/server/src/uds/services/Xen/provider.py @@ -448,7 +448,7 @@ class XenProvider(ServiceProvider): # pylint: disable=too-many-public-methods def get_macs_range(self) -> str: return self.macs_range.value - @cached('reachable', consts.cache.SHORT_CACHE_TIMEOUT, key_fnc=lambda x: x.host.as_str()) + @cached('reachable', consts.cache.SHORT_CACHE_TIMEOUT, key_helper=lambda x: x.host.as_str()) def is_available(self) -> bool: try: self.test_connection() diff --git a/server/src/uds/services/Xen/xen_client/__init__.py b/server/src/uds/services/Xen/xen_client/__init__.py index 32c097fdf..1b3834b37 100644 --- a/server/src/uds/services/Xen/xen_client/__init__.py +++ b/server/src/uds/services/Xen/xen_client/__init__.py @@ -212,7 +212,7 @@ class XenServer: # pylint: disable=too-many-public-methods def has_pool(self) -> bool: return self.check_login() and bool(self._pool_name) - @cached(prefix='xen_pool', timeout=consts.cache.LONG_CACHE_TIMEOUT, key_fnc=cache_key_helper) + @cached(prefix='xen_pool', timeout=consts.cache.LONG_CACHE_TIMEOUT, key_helper=cache_key_helper) def get_pool_name(self) -> str: pool = self.pool.get_all()[0] return self.pool.get_name_label(pool) @@ -325,7 +325,7 @@ class XenServer: # pylint: disable=too-many-public-methods return {'result': result, 'progress': progress, 'status': str(status), 'connection_error': True} - @cached(prefix='xen_srs', timeout=consts.cache.DEFAULT_CACHE_TIMEOUT, key_fnc=cache_key_helper) + @cached(prefix='xen_srs', timeout=consts.cache.DEFAULT_CACHE_TIMEOUT, key_helper=cache_key_helper) def list_srs(self) -> list[dict[str, typing.Any]]: return_list: list[dict[str, typing.Any]] = [] for srId in self.SR.get_all(): @@ -352,7 +352,7 @@ class XenServer: # pylint: disable=too-many-public-methods ) return return_list - @cached(prefix='xen_sr', timeout=consts.cache.SHORT_CACHE_TIMEOUT, key_fnc=cache_key_helper) + @cached(prefix='xen_sr', timeout=consts.cache.SHORT_CACHE_TIMEOUT, key_helper=cache_key_helper) def get_sr_info(self, srid: str) -> dict[str, typing.Any]: return { 'id': srid, @@ -361,7 +361,7 @@ class XenServer: # pylint: disable=too-many-public-methods 'used': XenServer.to_mb(self.SR.get_physical_utilisation(srid)), } - @cached(prefix='xen_nets', timeout=consts.cache.DEFAULT_CACHE_TIMEOUT, key_fnc=cache_key_helper) + @cached(prefix='xen_nets', timeout=consts.cache.DEFAULT_CACHE_TIMEOUT, key_helper=cache_key_helper) def list_networks(self, **kwargs) -> list[dict[str, typing.Any]]: return_list: list[dict[str, typing.Any]] = [] for netId in self.network.get_all(): @@ -375,11 +375,11 @@ class XenServer: # pylint: disable=too-many-public-methods return return_list - @cached(prefix='xen_net', timeout=consts.cache.SHORT_CACHE_TIMEOUT, key_fnc=cache_key_helper) + @cached(prefix='xen_net', timeout=consts.cache.SHORT_CACHE_TIMEOUT, key_helper=cache_key_helper) def get_network_info(self, net_id: str) -> dict[str, typing.Any]: return {'id': net_id, 'name': self.network.get_name_label(net_id)} - @cached(prefix='xen_vms', timeout=consts.cache.DEFAULT_CACHE_TIMEOUT, key_fnc=cache_key_helper) + @cached(prefix='xen_vms', timeout=consts.cache.DEFAULT_CACHE_TIMEOUT, key_helper=cache_key_helper) def list_machines(self) -> list[dict[str, typing.Any]]: return_list: list[dict[str, typing.Any]] = [] try: @@ -410,14 +410,14 @@ class XenServer: # pylint: disable=too-many-public-methods except XenAPI.Failure as e: raise XenFailure(e.details) - @cached(prefix='xen_vm', timeout=consts.cache.SHORT_CACHE_TIMEOUT, key_fnc=cache_key_helper) + @cached(prefix='xen_vm', timeout=consts.cache.SHORT_CACHE_TIMEOUT, key_helper=cache_key_helper) def get_machine_info(self, vmid: str, **kwargs) -> dict[str, typing.Any]: try: return self.VM.get_record(vmid) except XenAPI.Failure as e: raise XenFailure(e.details) - @cached(prefix='xen_vm_f', timeout=consts.cache.SHORT_CACHE_TIMEOUT, key_fnc=cache_key_helper) + @cached(prefix='xen_vm_f', timeout=consts.cache.SHORT_CACHE_TIMEOUT, key_helper=cache_key_helper) def get_machine_folder(self, vmid: str, **kwargs) -> str: try: other_config = self.VM.get_other_config(vmid) @@ -633,7 +633,7 @@ class XenServer: # pylint: disable=too-many-public-methods except XenAPI.Failure as e: raise XenFailure(e.details) - @cached(prefix='xen_snapshots', timeout=consts.cache.SHORT_CACHE_TIMEOUT, key_fnc=cache_key_helper) + @cached(prefix='xen_snapshots', timeout=consts.cache.SHORT_CACHE_TIMEOUT, key_helper=cache_key_helper) def list_snapshots(self, vmid: str, full_info: bool = False, **kwargs) -> list[dict[str, typing.Any]]: """Returns a list of snapshots for the specified VM, sorted by snapshot_time in descending order. (That is, the most recent snapshot is first in the list.) @@ -665,7 +665,7 @@ class XenServer: # pylint: disable=too-many-public-methods except XenAPI.Failure as e: raise XenFailure(e.details) - @cached(prefix='xen_folders', timeout=consts.cache.LONG_CACHE_TIMEOUT, key_fnc=cache_key_helper) + @cached(prefix='xen_folders', timeout=consts.cache.LONG_CACHE_TIMEOUT, key_helper=cache_key_helper) def list_folders(self, **kwargs) -> list[str]: """list "Folders" from the "Organizations View" of the XenServer diff --git a/server/tests/core/util/test_decorators.py b/server/tests/core/util/test_decorators.py index 8c43b1209..08fc341d1 100644 --- a/server/tests/core/util/test_decorators.py +++ b/server/tests/core/util/test_decorators.py @@ -99,7 +99,7 @@ class CacheTest(UDSTransactionTestCase): def __init__(self, value: str): self.value = [value] * 8 - @cached(prefix='test', timeout=1, key_fnc=cache_key) + @cached(prefix='test', timeout=1, key_helper=cache_key) def cached_test(self, **kwargs) -> list[str]: self.call_count += 1 return self.value diff --git a/server/tests/services/proxmox/fixtures.py b/server/tests/services/proxmox/fixtures.py index 6ba599d7f..172a6c058 100644 --- a/server/tests/services/proxmox/fixtures.py +++ b/server/tests/services/proxmox/fixtures.py @@ -260,7 +260,7 @@ CONSOLE_CONNECTION: typing.Final[types.services.ConsoleConnectionInfo] = types.s CLIENT_METHODS_INFO: typing.Final[list[AutoSpecMethodInfo]] = [ # connect returns None # Test method - AutoSpecMethodInfo('test', method=mock.Mock(return_value=True)), + AutoSpecMethodInfo('test', return_value=True), # get_cluster_info AutoSpecMethodInfo('get_cluster_info', return_value=CLUSTER_INFO), # get_next_vmid diff --git a/server/tests/services/proxmox/test_provider.py b/server/tests/services/proxmox/test_provider.py index b4bba921b..f0fe326cf 100644 --- a/server/tests/services/proxmox/test_provider.py +++ b/server/tests/services/proxmox/test_provider.py @@ -125,7 +125,7 @@ class TestProxmovProvider(UDSTestCase): self.assertEqual(provider.is_available(), False) api.test.assert_called_once_with() - def test_provider_methods(self) -> None: + def test_provider_methods_1(self) -> None: """ Test the provider methods """ @@ -154,7 +154,13 @@ class TestProxmovProvider(UDSTestCase): api.get_storage.assert_called_once_with( fixtures.STORAGES[2].storage, fixtures.STORAGES[2].node, force=True ) - api.get_storage.reset_mock() + + def test_provider_methods_2(self) -> None: + """ + Test the provider methods + """ + with fixtures.patch_provider_api() as api: + provider = fixtures.create_provider() self.assertEqual( provider.get_storage_info(fixtures.STORAGES[2].storage, fixtures.STORAGES[2].node), fixtures.STORAGES[2], @@ -180,6 +186,12 @@ class TestProxmovProvider(UDSTestCase): self.assertEqual(provider.list_pools(), fixtures.POOLS) api.list_pools.assert_called_once_with(force=False) + def test_provider_methods3(self) -> None: + """ + Test the provider methods + """ + with fixtures.patch_provider_api() as api: + provider = fixtures.create_provider() self.assertEqual( provider.get_pool_info(fixtures.POOLS[2].poolid, retrieve_vm_names=True, force=True), fixtures.POOLS[2], @@ -216,6 +228,12 @@ class TestProxmovProvider(UDSTestCase): self.assertEqual(provider.suspend_machine(1), fixtures.UPID) api.suspend_machine.assert_called_once_with(1) + def test_provider_methods_4(self) -> None: + """ + Test the provider methods + """ + with fixtures.patch_provider_api() as api: + provider = fixtures.create_provider() self.assertEqual(provider.shutdown_machine(1), fixtures.UPID) api.shutdown_machine.assert_called_once_with(1) @@ -240,6 +258,13 @@ class TestProxmovProvider(UDSTestCase): self.assertEqual(provider.list_ha_groups(), fixtures.HA_GROUPS) api.list_ha_groups.assert_called_once_with() + def test_provider_methods_5(self) -> None: + """ + Test the provider methods + """ + with fixtures.patch_provider_api() as api: + provider = fixtures.create_provider() + self.assertEqual(provider.get_console_connection('1'), fixtures.CONSOLE_CONNECTION) api.get_console_connection.assert_called_once_with(1, None)