1
0
mirror of https://github.com/dkmstr/openuds.git synced 2025-03-12 04:58:34 +03:00

Added type to spice console connection, and fixed discovered errors :)

This commit is contained in:
Adolfo Gómez García 2024-02-18 23:46:27 +01:00
parent b58d3e210c
commit a3fa9f604a
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
21 changed files with 491 additions and 255 deletions

View File

@ -32,6 +32,7 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com
""" """
import typing import typing
import collections.abc import collections.abc
from uds.core import types
from uds.core.environment import Environmentable from uds.core.environment import Environmentable
from uds.core.serializable import Serializable from uds.core.serializable import Serializable
@ -599,9 +600,7 @@ class UserService(Environmentable, Serializable):
""" """
return None return None
def get_console_connection( def get_console_connection(self) -> typing.Optional[types.services.ConsoleConnectionInfo]:
self,
) -> typing.Optional[collections.abc.MutableMapping[str, typing.Any]]:
""" """
This method is invoked by any connection that needs special connection data This method is invoked by any connection that needs special connection data
to connenct to it using, for example, SPICE protocol. (that currently is the only one) to connenct to it using, for example, SPICE protocol. (that currently is the only one)

View File

@ -29,8 +29,13 @@
""" """
Author: Adolfo Gómez, dkmaster at dkmon dot com Author: Adolfo Gómez, dkmaster at dkmon dot com
""" """
import dataclasses
import typing
import collections.abc
import enum import enum
from attr import field
class ServiceType(enum.StrEnum): class ServiceType(enum.StrEnum):
VDI = 'VDI' VDI = 'VDI'
@ -66,3 +71,21 @@ class ServicesCountingType(enum.IntEnum):
return ServicesCountingType[value] return ServicesCountingType[value]
except KeyError: except KeyError:
return ServicesCountingType.STANDARD return ServicesCountingType.STANDARD
@dataclasses.dataclass
class ConsoleConnectionTicket:
value: str = ''
expires: str = ''
@dataclasses.dataclass
class ConsoleConnectionInfo:
type: str
address: str
port: int = -1
secure_port: int = -1
cert_subject: str = ''
ticket: ConsoleConnectionTicket = dataclasses.field(default_factory=ConsoleConnectionTicket)
ca: str = ''
proxy: str = ''
monitors: int = 0

View File

@ -35,6 +35,8 @@ import collections.abc
import ovirtsdk4 as ovirt import ovirtsdk4 as ovirt
from uds.core import types
# Sometimes, we import ovirtsdk4 but "types" does not get imported... event can't be found???? # Sometimes, we import ovirtsdk4 but "types" does not get imported... event can't be found????
# With this seems to work propertly # With this seems to work propertly
try: try:
@ -84,9 +86,7 @@ class Client:
Returns: Returns:
The cache key, taking into consideration the prefix The cache key, taking into consideration the prefix
""" """
return "{}{}{}{}{}".format( return "{}{}{}{}{}".format(prefix, self._host, self._username, self._password, self._timeout)
prefix, self._host, self._username, self._password, self._timeout
)
def _api(self) -> ovirt.Connection: def _api(self) -> ovirt.Connection:
""" """
@ -155,9 +155,7 @@ class Client:
""" """
return True, 'Test successfully passed' return True, 'Test successfully passed'
def list_machines( def list_machines(self, force: bool = False) -> list[collections.abc.MutableMapping[str, typing.Any]]:
self, force: bool = False
) -> list[collections.abc.MutableMapping[str, typing.Any]]:
""" """
Obtains the list of machines inside ovirt that do aren't part of uds Obtains the list of machines inside ovirt that do aren't part of uds
@ -210,9 +208,7 @@ class Client:
finally: finally:
lock.release() lock.release()
def list_clusters( def list_clusters(self, force: bool = False) -> list[collections.abc.MutableMapping[str, typing.Any]]:
self, force: bool = False
) -> list[collections.abc.MutableMapping[str, typing.Any]]:
""" """
Obtains the list of clusters inside ovirt Obtains the list of clusters inside ovirt
@ -347,9 +343,7 @@ class Client:
api = self._api() api = self._api()
datacenter_service = ( datacenter_service = api.system_service().data_centers_service().service(datacenterId)
api.system_service().data_centers_service().service(datacenterId)
)
d: typing.Any = datacenter_service.get() # type: ignore d: typing.Any = datacenter_service.get() # type: ignore
storage = [] storage = []
@ -496,9 +490,7 @@ class Client:
tvm = ovirt.types.Vm(id=vm.id) tvm = ovirt.types.Vm(id=vm.id)
tcluster = ovirt.types.Cluster(id=cluster.id) tcluster = ovirt.types.Cluster(id=cluster.id)
template = ovirt.types.Template( template = ovirt.types.Template(name=name, vm=tvm, cluster=tcluster, description=comments)
name=name, vm=tvm, cluster=tcluster, description=comments
)
# display=display) # display=display)
@ -591,9 +583,7 @@ class Client:
else: else:
usb = ovirt.types.Usb(enabled=False) usb = ovirt.types.Usb(enabled=False)
memoryPolicy = ovirt.types.MemoryPolicy( memoryPolicy = ovirt.types.MemoryPolicy(guaranteed=guaranteed_mb * 1024 * 1024)
guaranteed=guaranteed_mb * 1024 * 1024
)
par = ovirt.types.Vm( par = ovirt.types.Vm(
name=name, name=name,
cluster=cluster, cluster=cluster,
@ -676,9 +666,7 @@ class Client:
api = self._api() api = self._api()
vmService: typing.Any = ( vmService: typing.Any = api.system_service().vms_service().service(machineId)
api.system_service().vms_service().service(machineId)
)
if vmService.get() is None: if vmService.get() is None:
raise Exception('Machine not found') raise Exception('Machine not found')
@ -702,9 +690,7 @@ class Client:
api = self._api() api = self._api()
vmService: typing.Any = ( vmService: typing.Any = api.system_service().vms_service().service(machineId)
api.system_service().vms_service().service(machineId)
)
if vmService.get() is None: if vmService.get() is None:
raise Exception('Machine not found') raise Exception('Machine not found')
@ -728,9 +714,7 @@ class Client:
api = self._api() api = self._api()
vmService: typing.Any = ( vmService: typing.Any = api.system_service().vms_service().service(machineId)
api.system_service().vms_service().service(machineId)
)
if vmService.get() is None: if vmService.get() is None:
raise Exception('Machine not found') raise Exception('Machine not found')
@ -754,9 +738,7 @@ class Client:
api = self._api() api = self._api()
vmService: typing.Any = ( vmService: typing.Any = api.system_service().vms_service().service(machineId)
api.system_service().vms_service().service(machineId)
)
if vmService.get() is None: if vmService.get() is None:
raise Exception('Machine not found') raise Exception('Machine not found')
@ -775,16 +757,12 @@ class Client:
api = self._api() api = self._api()
vmService: typing.Any = ( vmService: typing.Any = api.system_service().vms_service().service(machineId)
api.system_service().vms_service().service(machineId)
)
if vmService.get() is None: if vmService.get() is None:
raise Exception('Machine not found') raise Exception('Machine not found')
nic = vmService.nics_service().list()[ nic = vmService.nics_service().list()[0] # If has no nic, will raise an exception (IndexError)
0
] # If has no nic, will raise an exception (IndexError)
nic.mac.address = macAddres nic.mac.address = macAddres
nicService = vmService.nics_service().service(nic.id) nicService = vmService.nics_service().service(nic.id)
nicService.update(nic) nicService.update(nic)
@ -808,9 +786,7 @@ class Client:
finally: finally:
lock.release() lock.release()
def get_console_connection( def get_console_connection(self, machine_id: str) -> typing.Optional[types.services.ConsoleConnectionInfo]:
self, machineId: str
) -> typing.Optional[collections.abc.MutableMapping[str, typing.Any]]:
""" """
Gets the connetion info for the specified machine Gets the connetion info for the specified machine
""" """
@ -818,9 +794,7 @@ class Client:
lock.acquire(True) lock.acquire(True)
api = self._api() api = self._api()
vmService: typing.Any = ( vmService: typing.Any = api.system_service().vms_service().service(machine_id)
api.system_service().vms_service().service(machineId)
)
vm = vmService.get() vm = vmService.get()
if vm is None: if vm is None:
@ -834,9 +808,7 @@ class Client:
if display.certificate is not None: if display.certificate is not None:
cert_subject = display.certificate.subject cert_subject = display.certificate.subject
else: else:
for i in typing.cast( for i in typing.cast(collections.abc.Iterable, api.system_service().hosts_service().list()):
collections.abc.Iterable, api.system_service().hosts_service().list()
):
for k in typing.cast( for k in typing.cast(
collections.abc.Iterable, collections.abc.Iterable,
api.system_service() api.system_service()
@ -852,15 +824,14 @@ class Client:
if cert_subject != '': if cert_subject != '':
break break
return { return types.services.ConsoleConnectionInfo(
'type': display.type.value, type=display.type.value,
'address': display.address, address=display.address,
'port': display.port, port=display.port,
'secure_port': display.secure_port, secure_port=display.secure_port,
'monitors': display.monitors, cert_subject=cert_subject,
'cert_subject': cert_subject, ticket=types.services.ConsoleConnectionTicket(value=ticket.value),
'ticket': {'value': ticket.value, 'expiry': ticket.expiry}, )
}
except Exception: except Exception:
return None return None

View File

@ -36,7 +36,7 @@ import logging
import pickle # nosec: not insecure, we are loading our own data import pickle # nosec: not insecure, we are loading our own data
import typing import typing
from uds.core import consts, services from uds.core import consts, services, types
from uds.core.managers.userservice import UserServiceManager from uds.core.managers.userservice import UserServiceManager
from uds.core.types.states import State from uds.core.types.states import State
from uds.core.util import autoserializable, log from uds.core.util import autoserializable, log
@ -227,8 +227,8 @@ class OVirtLinkedDeployment(services.UserService, autoserializable.AutoSerializa
def get_console_connection( def get_console_connection(
self, self,
) -> typing.Optional[collections.abc.MutableMapping[str, typing.Any]]: ) -> typing.Optional[types.services.ConsoleConnectionInfo]:
return self.service().getConsoleConnection(self._vmid) return self.service().get_console_connection(self._vmid)
def desktop_login( def desktop_login(
self, self,

View File

@ -484,10 +484,10 @@ class OVirtProvider(
def getMacRange(self) -> str: def getMacRange(self) -> str:
return self.macsRange.value return self.macsRange.value
def getConsoleConnection( def get_console_connection(
self, machineId: str self, machine_id: str
) -> typing.Optional[collections.abc.MutableMapping[str, typing.Any]]: ) -> typing.Optional[types.services.ConsoleConnectionInfo]:
return self.__getApi().get_console_connection(machineId) return self.__getApi().get_console_connection(machine_id)
@cached('reachable', consts.cache.SHORT_CACHE_TIMEOUT) @cached('reachable', consts.cache.SHORT_CACHE_TIMEOUT)
def is_available(self) -> bool: def is_available(self) -> bool:

View File

@ -460,10 +460,10 @@ class OVirtLinkedService(services.Service): # pylint: disable=too-many-public-m
""" """
return self.display.value return self.display.value
def getConsoleConnection( def get_console_connection(
self, machineId: str self, machineId: str
) -> typing.Optional[collections.abc.MutableMapping[str, typing.Any]]: ) -> typing.Optional[types.services.ConsoleConnectionInfo]:
return self.provider().getConsoleConnection(machineId) return self.provider().get_console_connection(machineId)
def is_avaliable(self) -> bool: def is_avaliable(self) -> bool:
return self.provider().is_available() return self.provider().is_available()

View File

@ -37,7 +37,7 @@ import logging
import typing import typing
import collections.abc import collections.abc
from uds.core import services, consts from uds.core import services, consts, types
from uds.core.types.states import State from uds.core.types.states import State
from uds.core.util import log, autoserializable from uds.core.util import log, autoserializable
@ -159,8 +159,8 @@ class OpenNebulaLiveDeployment(services.UserService, autoserializable.AutoSerial
if self._vmid != '': if self._vmid != '':
self.service().resetMachine(self._vmid) self.service().resetMachine(self._vmid)
def get_console_connection(self) -> dict[str, typing.Any]: def get_console_connection(self) -> typing.Optional[types.services.ConsoleConnectionInfo]:
return self.service().getConsoleConnection(self._vmid) return self.service().get_console_connection(self._vmid)
def desktop_login(self, username: str, password: str, domain: str = ''): def desktop_login(self, username: str, password: str, domain: str = ''):
return self.service().desktop_login(self._vmid, username, password, domain) return self.service().desktop_login(self._vmid, username, password, domain)

View File

@ -358,25 +358,25 @@ class OpenNebulaClient: # pylint: disable=too-many-public-methods
) )
@ensureConnected @ensureConnected
def deleteVM(self, vmId: str) -> str: def remove_machine(self, vmId: str) -> str:
""" """
Deletes an vm Deletes an vm
""" """
if self.version[0] == '4': # type: ignore if self.version[0] == '4': # type: ignore
return self.VMAction(vmId, 'delete') return self.set_machine_state(vmId, 'delete')
# Version 5 # Version 5
return self.VMAction(vmId, 'terminate-hard') return self.set_machine_state(vmId, 'terminate-hard')
@ensureConnected @ensureConnected
def getVMState(self, vmId: str) -> types.VmState: def get_machine_state(self, vmId: str) -> types.VmState:
""" """
Returns the VM State Returns the VM State
""" """
return self.VMInfo(vmId).state return self.VMInfo(vmId).state
@ensureConnected @ensureConnected
def getVMSubstate(self, vmId: str) -> int: def get_machine_substate(self, vmId: str) -> int:
""" """
Returns the VM State Returns the VM State
""" """
@ -392,6 +392,6 @@ class OpenNebulaClient: # pylint: disable=too-many-public-methods
return -1 return -1
@ensureConnected @ensureConnected
def VMAction(self, vmId: str, action: str) -> str: def set_machine_state(self, vmId: str, action: str) -> str:
result = self.connection.one.vm.action(self.sessionString, action, int(vmId)) result = self.connection.one.vm.action(self.sessionString, action, int(vmId))
return checkResultRaw(result) return checkResultRaw(result)

View File

@ -36,6 +36,8 @@ import collections.abc
from defusedxml import minidom from defusedxml import minidom
from uds.core import types as core_types
from . import types from . import types
# Not imported at runtime, just for type checking # Not imported at runtime, just for type checking
@ -45,7 +47,7 @@ if typing.TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def getMachineState(api: 'client.OpenNebulaClient', machineId: str) -> types.VmState: def get_machine_state(api: 'client.OpenNebulaClient', machine_id: str) -> types.VmState:
''' '''
Returns the state of the machine Returns the state of the machine
This method do not uses cache at all (it always tries to get machine state from OpenNebula server) This method do not uses cache at all (it always tries to get machine state from OpenNebula server)
@ -57,30 +59,26 @@ def getMachineState(api: 'client.OpenNebulaClient', machineId: str) -> types.VmS
one of the on.VmState Values one of the on.VmState Values
''' '''
try: try:
return api.getVMState(machineId) return api.get_machine_state(machine_id)
except Exception as e: except Exception as e:
logger.error( logger.error('Error obtaining machine state for %s on OpenNebula: %s', machine_id, e)
'Error obtaining machine state for %s on OpenNebula: %s', machineId, e
)
return types.VmState.UNKNOWN return types.VmState.UNKNOWN
def getMachineSubstate(api: 'client.OpenNebulaClient', machineId: str) -> int: def get_machine_substate(api: 'client.OpenNebulaClient', machineId: str) -> int:
''' '''
Returns the lcm_state Returns the lcm_state
''' '''
try: try:
return api.getVMSubstate(machineId) return api.get_machine_substate(machineId)
except Exception as e: except Exception as e:
logger.error( logger.error('Error obtaining machine substate for %s on OpenNebula: %s', machineId, e)
'Error obtaining machine substate for %s on OpenNebula: %s', machineId, e
)
return types.VmState.UNKNOWN.value return types.VmState.UNKNOWN.value
def startMachine(api: 'client.OpenNebulaClient', machineId: str) -> None: def start_machine(api: 'client.OpenNebulaClient', machine_id: str) -> None:
''' '''
Tries to start a machine. No check is done, it is simply requested to OpenNebula. Tries to start a machine. No check is done, it is simply requested to OpenNebula.
@ -92,13 +90,13 @@ def startMachine(api: 'client.OpenNebulaClient', machineId: str) -> None:
Returns: Returns:
''' '''
try: try:
api.VMAction(machineId, 'resume') api.set_machine_state(machine_id, 'resume')
except Exception: except Exception:
# MAybe the machine is already running. If we get error here, simply ignore it for now... # MAybe the machine is already running. If we get error here, simply ignore it for now...
pass pass
def stopMachine(api: 'client.OpenNebulaClient', machineId: str) -> None: def stop_machine(api: 'client.OpenNebulaClient', machine_id: str) -> None:
''' '''
Tries to start a machine. No check is done, it is simply requested to OpenNebula Tries to start a machine. No check is done, it is simply requested to OpenNebula
@ -108,12 +106,12 @@ def stopMachine(api: 'client.OpenNebulaClient', machineId: str) -> None:
Returns: Returns:
''' '''
try: try:
api.VMAction(machineId, 'poweroff-hard') api.set_machine_state(machine_id, 'poweroff-hard')
except Exception as e: except Exception as e:
logger.error('Error powering off %s on OpenNebula: %s', machineId, e) logger.error('Error powering off %s on OpenNebula: %s', machine_id, e)
def suspendMachine(api: 'client.OpenNebulaClient', machineId: str) -> None: def suspend_machine(api: 'client.OpenNebulaClient', machine_id: str) -> None:
''' '''
Tries to suspend a machine. No check is done, it is simply requested to OpenNebula Tries to suspend a machine. No check is done, it is simply requested to OpenNebula
@ -123,12 +121,12 @@ def suspendMachine(api: 'client.OpenNebulaClient', machineId: str) -> None:
Returns: Returns:
''' '''
try: try:
api.VMAction(machineId, 'suspend') api.set_machine_state(machine_id, 'suspend')
except Exception as e: except Exception as e:
logger.error('Error suspending %s on OpenNebula: %s', machineId, e) logger.error('Error suspending %s on OpenNebula: %s', machine_id, e)
def shutdownMachine(api: 'client.OpenNebulaClient', machineId: str) -> None: def shutdown_machine(api: 'client.OpenNebulaClient', machine_id: str) -> None:
''' '''
Tries to "gracefully" shutdown a machine. No check is done, it is simply requested to OpenNebula Tries to "gracefully" shutdown a machine. No check is done, it is simply requested to OpenNebula
@ -138,12 +136,12 @@ def shutdownMachine(api: 'client.OpenNebulaClient', machineId: str) -> None:
Returns: Returns:
''' '''
try: try:
api.VMAction(machineId, 'poweroff') api.set_machine_state(machine_id, 'poweroff')
except Exception as e: except Exception as e:
logger.error('Error shutting down %s on OpenNebula: %s', machineId, e) logger.error('Error shutting down %s on OpenNebula: %s', machine_id, e)
def resetMachine(api: 'client.OpenNebulaClient', machineId: str) -> None: def reset_machine(api: 'client.OpenNebulaClient', machineId: str) -> None:
''' '''
Tries to suspend a machine. No check is done, it is simply requested to OpenNebula Tries to suspend a machine. No check is done, it is simply requested to OpenNebula
@ -153,12 +151,12 @@ def resetMachine(api: 'client.OpenNebulaClient', machineId: str) -> None:
Returns: Returns:
''' '''
try: try:
api.VMAction(machineId, 'reboot-hard') api.set_machine_state(machineId, 'reboot-hard')
except Exception as e: except Exception as e:
logger.error('Error reseting %s on OpenNebula: %s', machineId, e) logger.error('Error reseting %s on OpenNebula: %s', machineId, e)
def removeMachine(api: 'client.OpenNebulaClient', machineId: str) -> None: def remove_machine(api: 'client.OpenNebulaClient', machineId: str) -> None:
''' '''
Tries to delete a machine. No check is done, it is simply requested to OpenNebula Tries to delete a machine. No check is done, it is simply requested to OpenNebula
@ -170,14 +168,14 @@ def removeMachine(api: 'client.OpenNebulaClient', machineId: str) -> None:
try: try:
# vm = oca.VirtualMachine.new_with_id(api, int(machineId)) # vm = oca.VirtualMachine.new_with_id(api, int(machineId))
# vm.delete() # vm.delete()
api.deleteVM(machineId) api.remove_machine(machineId)
except Exception as e: except Exception as e:
err = 'Error removing machine {} on OpenNebula: {}'.format(machineId, e) err = 'Error removing machine {} on OpenNebula: {}'.format(machineId, e)
logger.exception(err) logger.exception(err)
raise Exception(err) raise Exception(err)
def enumerateMachines( def enumerate_machines(
api: 'client.OpenNebulaClient', api: 'client.OpenNebulaClient',
) -> collections.abc.Iterable[types.VirtualMachineType]: ) -> collections.abc.Iterable[types.VirtualMachineType]:
''' '''
@ -197,7 +195,7 @@ def enumerateMachines(
yield from api.enumVMs() yield from api.enumVMs()
def getNetInfo( def get_network_info(
api: 'client.OpenNebulaClient', api: 'client.OpenNebulaClient',
machineId: str, machineId: str,
networkId: typing.Optional[str] = None, networkId: typing.Optional[str] = None,
@ -216,9 +214,7 @@ def getNetInfo(
node = nic node = nic
break break
except Exception: except Exception:
raise Exception( raise Exception('No network interface found on template. Please, add a network and republish.')
'No network interface found on template. Please, add a network and republish.'
)
logger.debug(node.toxml()) logger.debug(node.toxml())
@ -231,14 +227,12 @@ def getNetInfo(
return (node.getElementsByTagName('MAC')[0].childNodes[0].data, ip) return (node.getElementsByTagName('MAC')[0].childNodes[0].data, ip)
except Exception: except Exception:
raise Exception( raise Exception('No network interface found on template. Please, add a network and republish.')
'No network interface found on template. Please, add a network and republish.'
)
def getDisplayConnection( def get_console_connection(
api: 'client.OpenNebulaClient', machineId: str api: 'client.OpenNebulaClient', machineId: str
) -> typing.Optional[dict[str, typing.Any]]: ) -> typing.Optional[core_types.services.ConsoleConnectionInfo]:
''' '''
If machine is not running or there is not a display, will return NONE If machine is not running or there is not a display, will return NONE
SPICE connections should check that 'type' is 'SPICE' SPICE connections should check that 'type' is 'SPICE'
@ -255,13 +249,16 @@ def getDisplayConnection(
passwd = '' passwd = ''
lastChild: typing.Any = md.getElementsByTagName('HISTORY_RECORDS')[0].lastChild lastChild: typing.Any = md.getElementsByTagName('HISTORY_RECORDS')[0].lastChild
host = ( address = lastChild.getElementsByTagName('HOSTNAME')[0].childNodes[0].data if lastChild else ''
lastChild.getElementsByTagName('HOSTNAME')[0]
.childNodes[0] return core_types.services.ConsoleConnectionInfo(
.data type=type_,
if lastChild else '' address=address,
port=int(port),
secure_port=-1,
cert_subject='',
ticket=core_types.services.ConsoleConnectionTicket(value=passwd),
) )
return {'type': type_, 'host': host, 'port': int(port), 'passwd': passwd}
except Exception: except Exception:
return None # No SPICE connection return None # No SPICE connection

View File

@ -31,6 +31,7 @@
Author: Adolfo Gómez, dkmaster at dkmon dot com Author: Adolfo Gómez, dkmaster at dkmon dot com
''' '''
import collections.abc import collections.abc
import dis
import logging import logging
import typing import typing
@ -78,9 +79,7 @@ class OpenNebulaProvider(ServiceProvider): # pylint: disable=too-many-public-me
# but used for sample purposes # but used for sample purposes
# If we don't indicate an order, the output order of fields will be # If we don't indicate an order, the output order of fields will be
# "random" # "random"
host = gui.TextField( host = gui.TextField(length=64, label=_('Host'), order=1, tooltip=_('OpenNebula Host'), required=True)
length=64, label=_('Host'), order=1, tooltip=_('OpenNebula Host'), required=True
)
port = gui.NumericField( port = gui.NumericField(
length=5, length=5,
label=_('Port'), label=_('Port'),
@ -132,16 +131,12 @@ class OpenNebulaProvider(ServiceProvider): # pylint: disable=too-many-public-me
@property @property
def endpoint(self) -> str: def endpoint(self) -> str:
return 'http{}://{}:{}/RPC2'.format( return 'http{}://{}:{}/RPC2'.format('s' if self.ssl.as_bool() else '', self.host.value, self.port.value)
's' if self.ssl.as_bool() else '', self.host.value, self.port.value
)
@property @property
def api(self) -> on.client.OpenNebulaClient: def api(self) -> on.client.OpenNebulaClient:
if self._api is None: if self._api is None:
self._api = on.client.OpenNebulaClient( self._api = on.client.OpenNebulaClient(self.username.value, self.password.value, self.endpoint)
self.username.value, self.password.value, self.endpoint
)
return self._api return self._api
@ -171,14 +166,10 @@ class OpenNebulaProvider(ServiceProvider): # pylint: disable=too-many-public-me
return [True, _('Opennebula test connection passed')] return [True, _('Opennebula test connection passed')]
def getDatastores( def getDatastores(self, datastoreType: int = 0) -> collections.abc.Iterable[on.types.StorageType]:
self, datastoreType: int = 0
) -> collections.abc.Iterable[on.types.StorageType]:
yield from on.storage.enumerateDatastores(self.api, datastoreType) yield from on.storage.enumerateDatastores(self.api, datastoreType)
def getTemplates( def getTemplates(self, force: bool = False) -> collections.abc.Iterable[on.types.TemplateType]:
self, force: bool = False
) -> collections.abc.Iterable[on.types.TemplateType]:
yield from on.template.getTemplates(self.api, force) yield from on.template.getTemplates(self.api, force)
def make_template(self, fromTemplateId: str, name, toDataStore: str) -> str: def make_template(self, fromTemplateId: str, name, toDataStore: str) -> str:
@ -204,13 +195,13 @@ class OpenNebulaProvider(ServiceProvider): # pylint: disable=too-many-public-me
Returns: Returns:
one of the on.VmState Values one of the on.VmState Values
''' '''
return on.vm.getMachineState(self.api, machineId) return on.vm.get_machine_state(self.api, machineId)
def getMachineSubstate(self, machineId: str) -> int: def getMachineSubstate(self, machineId: str) -> int:
''' '''
Returns the LCM_STATE of a machine (STATE must be ready or this will return -1) Returns the LCM_STATE of a machine (STATE must be ready or this will return -1)
''' '''
return on.vm.getMachineSubstate(self.api, machineId) return on.vm.get_machine_substate(self.api, machineId)
def startMachine(self, machineId: str) -> None: def startMachine(self, machineId: str) -> None:
''' '''
@ -223,7 +214,7 @@ class OpenNebulaProvider(ServiceProvider): # pylint: disable=too-many-public-me
Returns: Returns:
''' '''
on.vm.startMachine(self.api, machineId) on.vm.start_machine(self.api, machineId)
def stopMachine(self, machineId: str) -> None: def stopMachine(self, machineId: str) -> None:
''' '''
@ -234,7 +225,7 @@ class OpenNebulaProvider(ServiceProvider): # pylint: disable=too-many-public-me
Returns: Returns:
''' '''
on.vm.stopMachine(self.api, machineId) on.vm.stop_machine(self.api, machineId)
def suspendMachine(self, machineId: str) -> None: def suspendMachine(self, machineId: str) -> None:
''' '''
@ -245,7 +236,7 @@ class OpenNebulaProvider(ServiceProvider): # pylint: disable=too-many-public-me
Returns: Returns:
''' '''
on.vm.suspendMachine(self.api, machineId) on.vm.suspend_machine(self.api, machineId)
def shutdownMachine(self, machineId: str) -> None: def shutdownMachine(self, machineId: str) -> None:
''' '''
@ -256,13 +247,13 @@ class OpenNebulaProvider(ServiceProvider): # pylint: disable=too-many-public-me
Returns: Returns:
''' '''
on.vm.shutdownMachine(self.api, machineId) on.vm.shutdown_machine(self.api, machineId)
def resetMachine(self, machineId: str) -> None: def resetMachine(self, machineId: str) -> None:
''' '''
Resets a machine (hard-reboot) Resets a machine (hard-reboot)
''' '''
on.vm.resetMachine(self.api, machineId) on.vm.reset_machine(self.api, machineId)
def removeMachine(self, machineId: str) -> None: def removeMachine(self, machineId: str) -> None:
''' '''
@ -273,31 +264,21 @@ class OpenNebulaProvider(ServiceProvider): # pylint: disable=too-many-public-me
Returns: Returns:
''' '''
on.vm.removeMachine(self.api, machineId) on.vm.remove_machine(self.api, machineId)
def getNetInfo( def getNetInfo(self, machineId: str, networkId: typing.Optional[str] = None) -> tuple[str, str]:
self, machineId: str, networkId: typing.Optional[str] = None
) -> tuple[str, str]:
''' '''
Changes the mac address of first nic of the machine to the one specified Changes the mac address of first nic of the machine to the one specified
''' '''
return on.vm.getNetInfo(self.api, machineId, networkId) return on.vm.get_network_info(self.api, machineId, networkId)
def getConsoleConnection(self, machineId: str) -> dict[str, typing.Any]: def get_console_connection(self, machine_id: str) -> typing.Optional[types.services.ConsoleConnectionInfo]:
display = on.vm.getDisplayConnection(self.api, machineId) console_connection_info = on.vm.get_console_connection(self.api, machine_id)
if display is None: if console_connection_info is None:
raise Exception('Invalid console connection on OpenNebula!!!') raise Exception('Invalid console connection on OpenNebula!!!')
return { return console_connection_info
'type': display['type'],
'address': display['host'],
'port': display['port'],
'secure_port': -1,
'monitors': 1,
'cert_subject': '',
'ticket': {'value': display['passwd'], 'expiry': ''},
}
def desktop_login(self, machineId: str, username: str, password: str, domain: str) -> dict[str, typing.Any]: def desktop_login(self, machineId: str, username: str, password: str, domain: str) -> dict[str, typing.Any]:
''' '''

View File

@ -309,8 +309,8 @@ class OpenNebulaLiveService(services.Service):
""" """
return self.lenName.as_int() return self.lenName.as_int()
def getConsoleConnection(self, machineId: str) -> dict[str, typing.Any]: def get_console_connection(self, machineId: str) -> typing.Optional[types.services.ConsoleConnectionInfo]:
return self.provider().getConsoleConnection(machineId) return self.provider().get_console_connection(machineId)
def desktop_login( def desktop_login(
self, machineId: str, username: str, password: str, domain: str self, machineId: str, username: str, password: str, domain: str

View File

@ -39,7 +39,7 @@ import urllib.parse
import requests import requests
from uds.core import consts from uds.core import consts, types as core_types
from uds.core.util import security from uds.core.util import security
from uds.core.util.decorators import cached, ensure_connected from uds.core.util.decorators import cached, ensure_connected
@ -256,8 +256,8 @@ class ProxmoxClient:
@ensure_connected @ensure_connected
@cached('cluster', CACHE_DURATION, key_fnc=caching_key_helper) @cached('cluster', CACHE_DURATION, key_fnc=caching_key_helper)
def get_cluster_info(self, **kwargs) -> types.ClusterStatus: def get_cluster_info(self, **kwargs) -> types.ClusterInfo:
return types.ClusterStatus.from_dict(self._get('cluster/status')) return types.ClusterInfo.from_dict(self._get('cluster/status'))
@ensure_connected @ensure_connected
def get_next_vmid(self) -> int: def get_next_vmid(self) -> int:
@ -273,31 +273,21 @@ class ProxmoxClient:
@ensure_connected @ensure_connected
@cached('nodeNets', CACHE_DURATION, args=1, kwargs=['node'], key_fnc=caching_key_helper) @cached('nodeNets', CACHE_DURATION, args=1, kwargs=['node'], key_fnc=caching_key_helper)
def get_node_netoworks(self, node: str, **kwargs): def get_node_networks(self, node: str, **kwargs) -> typing.Any:
return self._get('nodes/{}/network'.format(node))['data'] return self._get('nodes/{}/network'.format(node))['data']
# pylint: disable=unused-argument # pylint: disable=unused-argument
@ensure_connected @ensure_connected
@cached( @cached('nodeGpuDevices', CACHE_DURATION_LONG, key_fnc=caching_key_helper)
'nodeGpuDevices',
CACHE_DURATION_LONG,
key_fnc=caching_key_helper
)
def list_node_gpu_devices(self, node: str, **kwargs) -> list[str]: def list_node_gpu_devices(self, node: str, **kwargs) -> list[str]:
return [ return [
device['id'] for device in self._get(f'nodes/{node}/hardware/pci')['data'] if device.get('mdev') device['id'] for device in self._get(f'nodes/{node}/hardware/pci')['data'] if device.get('mdev')
] ]
@ensure_connected @ensure_connected
def list_node_vgpus(self, node: str, **kwargs) -> list[typing.Any]: def list_node_vgpus(self, node: str, **kwargs) -> list[types.VGPUInfo]:
return [ return [
{ types.VGPUInfo.from_dict(gpu)
'name': gpu['name'],
'description': gpu['description'],
'device': device,
'available': gpu['available'],
'type': gpu['type'],
}
for device in self.list_node_gpu_devices(node) for device in self.list_node_gpu_devices(node)
for gpu in self._get(f'nodes/{node}/hardware/pci/{device}/mdev')['data'] for gpu in self._get(f'nodes/{node}/hardware/pci/{device}/mdev')['data']
] ]
@ -305,7 +295,7 @@ class ProxmoxClient:
@ensure_connected @ensure_connected
def node_has_vgpus_available(self, node: str, vgpu_type: typing.Optional[str], **kwargs) -> bool: def node_has_vgpus_available(self, node: str, vgpu_type: typing.Optional[str], **kwargs) -> bool:
return any( return any(
gpu['available'] and vgpu_type and gpu['type'] == vgpu_type for gpu in self.list_node_vgpus(node) gpu.available and (vgpu_type is None or gpu.type == vgpu_type) for gpu in self.list_node_vgpus(node)
) )
@ensure_connected @ensure_connected
@ -591,11 +581,7 @@ class ProxmoxClient:
return self.get_machine_info(vmid, node, **kwargs) return self.get_machine_info(vmid, node, **kwargs)
@ensure_connected @ensure_connected
@cached( @cached('vmin', CACHE_INFO_DURATION, key_fnc=caching_key_helper)
'vmin',
CACHE_INFO_DURATION,
key_fnc=caching_key_helper
)
def get_machine_info(self, vmid: int, node: typing.Optional[str] = None, **kwargs) -> types.VMInfo: def get_machine_info(self, vmid: int, node: typing.Optional[str] = None, **kwargs) -> types.VMInfo:
nodes = [types.Node(node, False, False, 0, '', '', '')] if node else self.get_cluster_info().nodes nodes = [types.Node(node, False, False, 0, '', '', '')] if node else self.get_cluster_info().nodes
any_node_is_down = False any_node_is_down = False
@ -617,7 +603,7 @@ class ProxmoxClient:
raise ProxmoxNotFound() raise ProxmoxNotFound()
@ensure_connected @ensure_connected
def get_machine_configuration(self, vmid: int, node: typing.Optional[str] = None, **kwargs): def get_machine_configuration(self, vmid: int, node: typing.Optional[str] = None, **kwargs) -> types.VMConfiguration:
node = node or self.get_machine_info(vmid).node node = node or self.get_machine_info(vmid).node
return types.VMConfiguration.from_dict(self._get('nodes/{}/qemu/{}/config'.format(node, vmid))['data']) return types.VMConfiguration.from_dict(self._get('nodes/{}/qemu/{}/config'.format(node, vmid))['data'])
@ -686,7 +672,7 @@ class ProxmoxClient:
self.get_machine_info(vmid, force=True) self.get_machine_info(vmid, force=True)
# proxmox has a "resume", but start works for suspended vm so we use it # proxmox has a "resume", but start works for suspended vm so we use it
resumeVm = start_machine resume_machine = start_machine
@ensure_connected @ensure_connected
@cached('storage', CACHE_DURATION, key_fnc=caching_key_helper) @cached('storage', CACHE_DURATION, key_fnc=caching_key_helper)
@ -753,26 +739,24 @@ class ProxmoxClient:
@ensure_connected @ensure_connected
def get_console_connection( def get_console_connection(
self, vmId: int, node: typing.Optional[str] = None self, vmId: int, node: typing.Optional[str] = None
) -> typing.Optional[collections.abc.MutableMapping[str, typing.Any]]: ) -> typing.Optional[core_types.services.ConsoleConnectionInfo]:
""" """
Gets the connetion info for the specified machine Gets the connetion info for the specified machine
""" """
node = node or self.get_machine_info(vmId).node node = node or self.get_machine_info(vmId).node
res = self._post(f'nodes/{node}/qemu/{vmId}/spiceproxy')['data'] res: dict = self._post(f'nodes/{node}/qemu/{vmId}/spiceproxy')['data']
return core_types.services.ConsoleConnectionInfo(
return { type=res['type'],
'type': res['type'], proxy=res['proxy'],
'proxy': res['proxy'], address=res['host'],
'address': res['host'], port=res.get('port', None),
'port': res.get('port', None), secure_port=res['tls-port'],
'secure_port': res['tls-port'], cert_subject=res['host-subject'],
'cert_subject': res['host-subject'], ticket=core_types.services.ConsoleConnectionTicket(
'ticket': { value=res['password']
'value': res['password'], ),
'expiry': '', ca=res.get('ca', None),
}, )
'ca': res.get('ca', None),
}
# Sample data: # Sample data:
# 'data': {'proxy': 'http://pvealone.dkmon.com:3128', # 'data': {'proxy': 'http://pvealone.dkmon.com:3128',
# 'release-cursor': 'Ctrl+Alt+R', # 'release-cursor': 'Ctrl+Alt+R',

View File

@ -17,9 +17,9 @@ CONVERSORS: typing.Final[collections.abc.MutableMapping[typing.Any, collections.
float: lambda x: float(x or '0'), float: lambda x: float(x or '0'),
typing.Optional[float]: lambda x: float(x or '0') if x is not None else None, typing.Optional[float]: lambda x: float(x or '0') if x is not None else None,
datetime.datetime: lambda x: datetime.datetime.fromtimestamp(int(x)), datetime.datetime: lambda x: datetime.datetime.fromtimestamp(int(x)),
typing.Optional[datetime.datetime]: lambda x: datetime.datetime.fromtimestamp(int(x)) typing.Optional[datetime.datetime]: lambda x: (
if x is not None datetime.datetime.fromtimestamp(int(x)) if x is not None else None
else None, ),
} }
@ -101,12 +101,12 @@ class NodeStats(typing.NamedTuple):
) )
class ClusterStatus(typing.NamedTuple): class ClusterInfo(typing.NamedTuple):
cluster: typing.Optional[Cluster] cluster: typing.Optional[Cluster]
nodes: list[Node] nodes: list[Node]
@staticmethod @staticmethod
def from_dict(dictionary: collections.abc.MutableMapping[str, typing.Any]) -> 'ClusterStatus': def from_dict(dictionary: collections.abc.MutableMapping[str, typing.Any]) -> 'ClusterInfo':
nodes: list[Node] = [] nodes: list[Node] = []
cluster: typing.Optional[Cluster] = None cluster: typing.Optional[Cluster] = None
@ -116,7 +116,7 @@ class ClusterStatus(typing.NamedTuple):
else: else:
nodes.append(Node.from_dict(i)) nodes.append(Node.from_dict(i))
return ClusterStatus(cluster=cluster, nodes=nodes) return ClusterInfo(cluster=cluster, nodes=nodes)
class UPID(typing.NamedTuple): class UPID(typing.NamedTuple):
@ -311,8 +311,8 @@ class PoolInfo(typing.NamedTuple):
class SnapshotInfo(typing.NamedTuple): class SnapshotInfo(typing.NamedTuple):
description: str
name: str name: str
description: str
parent: typing.Optional[str] parent: typing.Optional[str]
snaptime: typing.Optional[int] snaptime: typing.Optional[int]
@ -321,3 +321,15 @@ class SnapshotInfo(typing.NamedTuple):
@staticmethod @staticmethod
def from_dict(dictionary: collections.abc.MutableMapping[str, typing.Any]) -> 'SnapshotInfo': def from_dict(dictionary: collections.abc.MutableMapping[str, typing.Any]) -> 'SnapshotInfo':
return _from_dict(SnapshotInfo, dictionary) return _from_dict(SnapshotInfo, dictionary)
class VGPUInfo(typing.NamedTuple):
name: str
description: str
device: str
available: bool
type: str
@staticmethod
def from_dict(dictionary: collections.abc.MutableMapping[str, typing.Any]) -> 'VGPUInfo':
return _from_dict(VGPUInfo, dictionary)

View File

@ -36,7 +36,7 @@ import logging
import typing import typing
import collections.abc import collections.abc
from uds.core import services, consts from uds.core import services, consts, types
from uds.core.managers.userservice import UserServiceManager from uds.core.managers.userservice import UserServiceManager
from uds.core.types.states import State from uds.core.types.states import State
from uds.core.util import log, autoserializable from uds.core.util import log, autoserializable
@ -213,7 +213,7 @@ class ProxmoxDeployment(services.UserService, autoserializable.AutoSerializable)
def get_console_connection( def get_console_connection(
self, self,
) -> typing.Optional[collections.abc.MutableMapping[str, typing.Any]]: ) -> typing.Optional[types.services.ConsoleConnectionInfo]:
return self.service().get_console_connection(self._vmid) return self.service().get_console_connection(self._vmid)
def desktop_login( def desktop_login(

View File

@ -247,10 +247,10 @@ class ProxmoxProvider(services.ServiceProvider):
def get_console_connection( def get_console_connection(
self, self,
vmid: str, machine_id: str,
node: typing.Optional[str] = None, node: typing.Optional[str] = None,
) -> typing.Optional[collections.abc.MutableMapping[str, typing.Any]]: ) -> typing.Optional[types.services.ConsoleConnectionInfo]:
return self._api().get_console_connection(int(vmid), node) return self._api().get_console_connection(int(machine_id), node)
def get_new_vmid(self) -> int: def get_new_vmid(self) -> int:
while True: # look for an unused VmId while True: # look for an unused VmId

View File

@ -302,9 +302,9 @@ class ProxmoxLinkedService(services.Service): # pylint: disable=too-many-public
return self.soft_shutdown_field.as_bool() return self.soft_shutdown_field.as_bool()
def get_console_connection( def get_console_connection(
self, machineId: str self, machine_id: str
) -> typing.Optional[collections.abc.MutableMapping[str, typing.Any]]: ) -> typing.Optional[types.services.ConsoleConnectionInfo]:
return self.provider().get_console_connection(machineId) return self.provider().get_console_connection(machine_id)
@cached('reachable', consts.cache.SHORT_CACHE_TIMEOUT) @cached('reachable', consts.cache.SHORT_CACHE_TIMEOUT)
def is_avaliable(self) -> bool: def is_avaliable(self) -> bool:

View File

@ -55,6 +55,7 @@ class SPICETransport(BaseSpiceTransport):
Provides access via SPICE to service. Provides access via SPICE to service.
This transport can use an domain. If username processed by authenticator contains '@', it will split it and left-@-part will be username, and right password This transport can use an domain. If username processed by authenticator contains '@', it will split it and left-@-part will be username, and right password
""" """
is_base = False is_base = False
type_name = _('SPICE') type_name = _('SPICE')
@ -83,29 +84,31 @@ class SPICETransport(BaseSpiceTransport):
request: 'ExtendedHttpRequestWithUser', request: 'ExtendedHttpRequestWithUser',
) -> 'types.transports.TransportScript': ) -> 'types.transports.TransportScript':
try: try:
userServiceInstance = userService.get_instance() userservice_instance = userService.get_instance()
con: typing.Optional[collections.abc.MutableMapping[str, typing.Any]] = userServiceInstance.get_console_connection() con: typing.Optional[types.services.ConsoleConnectionInfo] = (
userservice_instance.get_console_connection()
)
except Exception: except Exception:
logger.exception('Error getting console connection data') logger.exception('Error getting console connection data')
raise raise
logger.debug('Connection data: %s', con) logger.debug('Connection data: %s', con)
if not con: if not con:
raise exceptions.service.TransportError('No console connection data') raise exceptions.service.TransportError('No console connection data')
port: str = con['port'] or '-1' port: str = str(con.port) or '-1'
secure_port: str = con['secure_port'] or '-1' secure_port: str = str(con.secure_port) or '-1'
r = RemoteViewerFile( r = RemoteViewerFile(
con['address'], con.address,
port, port,
secure_port, secure_port,
con['ticket']['value'], con.ticket.value,
con.get('ca', self.server_certificate.value.strip()), con.ca or self.server_certificate.value.strip(),
con['cert_subject'], con.cert_subject,
fullscreen=self.fullscreen.as_bool(), fullscreen=self.fullscreen.as_bool(),
) )
r.proxy = self.overrided_proxy.value.strip() or con.get('proxy', None) r.proxy = self.overrided_proxy.value.strip() or con.proxy or ''
r.usb_auto_share = self.allow_usb_redirection.as_bool() r.usb_auto_share = self.allow_usb_redirection.as_bool()
r.new_usb_auto_share = self.allow_usb_redirection_new_plugs.as_bool() r.new_usb_auto_share = self.allow_usb_redirection_new_plugs.as_bool()
@ -122,7 +125,4 @@ class SPICETransport(BaseSpiceTransport):
try: try:
return self.get_script(os.os.os_name(), 'direct', sp) return self.get_script(os.os.os_name(), 'direct', sp)
except Exception: except Exception:
return super().get_transport_script( return super().get_transport_script(userService, transport, ip, os, user, password, request)
userService, transport, ip, os, user, password, request
)

View File

@ -142,13 +142,13 @@ class BaseSpiceTransport(transports.Transport):
old_field_name='overridedProxy', old_field_name='overridedProxy',
) )
def is_ip_allowed(self, userService: 'models.UserService', ip: str) -> bool: def is_ip_allowed(self, userservice: 'models.UserService', ip: str) -> bool:
""" """
Checks if the transport is available for the requested destination ip Checks if the transport is available for the requested destination ip
""" """
ready = self.cache.get(ip) ready = self.cache.get(ip)
if ready is None: if ready is None:
userServiceInstance = userService.get_instance() userServiceInstance = userservice.get_instance()
con = userServiceInstance.get_console_connection() con = userServiceInstance.get_console_connection()
logger.debug('Connection data: %s', con) logger.debug('Connection data: %s', con)
@ -156,36 +156,34 @@ class BaseSpiceTransport(transports.Transport):
if con is None: if con is None:
return False return False
if 'proxy' in con: # If we have a proxy, we can't check if it is available, return True if con.proxy is not None:
return True return True
port, secure_port = con['port'] or -1, con['secure_port'] or -1
# test ANY of the ports # test ANY of the ports
port_to_test = port if port != -1 else secure_port port_to_test = con.port if con.port != -1 else con.secure_port
if port_to_test == -1: if port_to_test == -1:
self.cache.put( self.cache.put(
'cachedMsg', 'Could not find the PORT for connection', 120 'cached_message', 'Could not find the PORT for connection', 120
) # Write a message, that will be used from getCustom ) # Write a message, that will be used from getCustom
logger.info('SPICE didn\'t find has any port: %s', con) logger.info('SPICE didn\'t find has any port: %s', con)
return False return False
self.cache.put( self.cache.put(
'cachedMsg', 'cached_message',
'Could not reach server "{}" on port "{}" from broker (prob. causes are name resolution & firewall rules)'.format( 'Could not reach server "{}" on port "{}" from broker (prob. causes are name resolution & firewall rules)'.format(
con['address'], port_to_test con.address, port_to_test
), ),
120, 120,
) )
if self.test_connectivity(userService, con['address'], port_to_test) is True: if self.test_connectivity(userservice, con.address, port_to_test) is True:
self.cache.put(ip, 'Y', READY_CACHE_TIMEOUT) self.cache.put(ip, 'Y', READY_CACHE_TIMEOUT)
ready = 'Y' ready = 'Y'
return ready == 'Y' return ready == 'Y'
def get_available_error_msg(self, userService: 'models.UserService', ip: str) -> str: def get_available_error_msg(self, userService: 'models.UserService', ip: str) -> str:
msg = self.cache.get('cachedMsg') msg = self.cache.get('cached_message')
if msg is None: if msg is None:
return transports.Transport.get_available_error_msg(self, userService, ip) return transports.Transport.get_available_error_msg(self, userService, ip)
return msg return msg

View File

@ -119,24 +119,24 @@ class TSPICETransport(BaseSpiceTransport):
ticket = '' ticket = ''
ticket_secure = '' ticket_secure = ''
if 'proxy' in con: if con.proxy:
logger.exception('Proxied SPICE tunnels are not suppoorted') logger.exception('Proxied SPICE tunnels are not suppoorted')
return super().get_transport_script( return super().get_transport_script(
userService, transport, ip, os, user, password, request userService, transport, ip, os, user, password, request
) )
if con['port']: if con.port:
ticket = TicketStore.create_for_tunnel( ticket = TicketStore.create_for_tunnel(
userService=userService, userService=userService,
port=int(con['port']), port=int(con.port),
validity=self.tunnel_wait.as_int() + 60, # Ticket overtime validity=self.tunnel_wait.as_int() + 60, # Ticket overtime
) )
if con['secure_port']: if con.secure_port:
ticket_secure = TicketStore.create_for_tunnel( ticket_secure = TicketStore.create_for_tunnel(
userService=userService, userService=userService,
port=int(con['secure_port']), port=int(con.secure_port),
host=con['address'], host=con.address,
validity=self.tunnel_wait.as_int() + 60, # Ticket overtime validity=self.tunnel_wait.as_int() + 60, # Ticket overtime
) )
@ -144,9 +144,9 @@ class TSPICETransport(BaseSpiceTransport):
'127.0.0.1', '127.0.0.1',
'{port}', '{port}',
'{secure_port}', '{secure_port}',
con['ticket']['value'], # This is secure ticket from kvm, not UDS ticket con.ticket.value, # This is secure ticket from kvm, not UDS ticket
con.get('ca', self.server_certificate.value.strip()), con.ca or self.server_certificate.value.strip(),
con['cert_subject'], con.cert_subject,
fullscreen=self.fullscreen.as_bool(), fullscreen=self.fullscreen.as_bool(),
) )

View File

@ -31,22 +31,289 @@
Author: Adolfo Gómez, dkmaster at dkmon dot com Author: Adolfo Gómez, dkmaster at dkmon dot com
""" """
import typing import typing
import datetime
import collections.abc import collections.abc
import itertools
from unittest import mock from unittest import mock
from ...utils.test import UDSTestCase
from ...utils.autospec import autospec, AutoSpecMethodInfo from ...utils.autospec import autospec, AutoSpecMethodInfo
from uds.services.Proxmox import provider, client from uds.services.Proxmox import provider, client as pc
METHODS_INFO: typing.Final[list[AutoSpecMethodInfo]] = [ NODES: typing.Final[list[pc.types.Node]] = [
AutoSpecMethodInfo('test', method=mock.Mock(return_value=True)), pc.types.Node(name='node0', online=True, local=True, nodeid=1, ip='0.0.0.1', level='level', id='id'),
pc.types.Node(name='node1', online=True, local=True, nodeid=2, ip='0.0.0.2', level='level', id='id'),
]
NODE_STATS: typing.Final[list[pc.types.NodeStats]] = [
pc.types.NodeStats(
name='name',
status='status',
uptime=1,
disk=1,
maxdisk=1,
level='level',
id='id',
mem=1,
maxmem=1,
cpu=1.0,
maxcpu=1,
),
pc.types.NodeStats(
name='name',
status='status',
uptime=1,
disk=1,
maxdisk=1,
level='level',
id='id',
mem=1,
maxmem=1,
cpu=1.0,
maxcpu=1,
),
] ]
class TestProxmovProvider: CLUSTER_INFO: typing.Final[pc.types.ClusterInfo] = pc.types.ClusterInfo(
cluster=pc.types.Cluster(name='name', version='version', id='id', nodes=2, quorate=1),
nodes=NODES,
)
STORAGES: typing.Final[list[pc.types.StorageInfo]] = [
pc.types.StorageInfo(
node=NODES[i%len(NODES)].name,
storage=f'storage_{i}',
content=(f'content{i}',) * (i % 3),
type='type',
shared=(i < 8), # First 8 are shared
active=(i % 5) != 0, # Every 5th is not active
used=1024*1024*1024*i*4,
avail=1024*1024*1024*i*8,
total=1024*1024*1024*i*12,
used_fraction=1.0,
) for i in range(10)
]
VGPUS: typing.Final[list[pc.types.VGPUInfo]] = [
pc.types.VGPUInfo(
name='name_1',
description='description_1',
device='device_1',
available=True,
type='gpu_type_1',
),
pc.types.VGPUInfo(
name='name_2',
description='description_2',
device='device_2',
available=False,
type='gpu_type_2',
),
pc.types.VGPUInfo(
name='name_3',
description='description_3',
device='device_3',
available=True,
type='gpu_type_3',
),
]
VMS_INFO: typing.Final[list[pc.types.VMInfo]] = [
pc.types.VMInfo(
status='status',
vmid=i,
node=NODES[i % len(NODES)].name,
template=True,
agent='agent',
cpus=1,
lock='lock',
disk=1,
maxdisk=1,
mem=1024*1024*1024*i,
maxmem=1024*1024*1024*i*2,
name='name',
pid=1000+i,
qmpstatus='qmpstatus',
tags='tags',
uptime=60*60*24*i,
netin=1,
netout=1,
diskread=1,
diskwrite=1,
vgpu_type=VGPUS[i % len(VGPUS)].type,
)
for i in range(10)
]
VMS_CONFIGURATION: typing.Final[list[pc.types.VMConfiguration]] = [
pc.types.VMConfiguration(
name=f'vm_name_{i}',
vga='cirrus',
sockets=1,
cores=1,
vmgenid='vmgenid',
digest='digest',
networks=[pc.types.NetworkConfiguration(net='net', type='type', mac='mac')],
tpmstate0='tpmstate0',
template=bool(i > 8), # Last two are templates
)
for i in range(10)
]
UPID: typing.Final[pc.types.UPID] = pc.types.UPID(
node=NODES[0].name,
pid=1,
pstart=1,
starttime=datetime.datetime.now(),
type='type',
vmid=VMS_INFO[0].vmid,
user='user',
upid='upid',
)
VM_CREATION_RESULT: typing.Final[pc.types.VmCreationResult] = pc.types.VmCreationResult(
node=NODES[0].name,
vmid=VMS_INFO[0].vmid,
upid=UPID,
)
SNAPSHOTS_INFO: typing.Final[list[pc.types.SnapshotInfo]] = [
pc.types.SnapshotInfo(
name=f'snap_name_{i}',
description=f'snap desription{i}',
parent=f'snap_parent_{i}',
snaptime=int(datetime.datetime.now().timestamp()),
vmstate=bool(i % 2),
)
for i in range(10)
]
TASK_STATUS = pc.types.TaskStatus(
node=NODES[0].name,
pid=1,
pstart=1,
starttime=datetime.datetime.now(),
type='type',
status='status',
exitstatus='exitstatus',
user='user',
upid='upid',
id='id',
)
POOL_MEMBERS: typing.Final[list[pc.types.PoolMemberInfo]] = [
pc.types.PoolMemberInfo(
id=f'id_{i}',
node=NODES[i % len(NODES)].name,
storage=STORAGES[i % len(STORAGES)].storage,
type='type',
vmid=VMS_INFO[i%len(VMS_INFO)].vmid,
vmname=VMS_INFO[i%len(VMS_INFO)].name or '',
) for i in range(10)
]
POOLS: typing.Final[list[pc.types.PoolInfo]] = [
pc.types.PoolInfo(
poolid=f'pool_{i}',
comments=f'comments_{i}',
members=POOL_MEMBERS,
) for i in range(10)
]
# Methods that returns None or "internal" methods are not tested
CLIENT_METHODS_INFO: typing.Final[list[AutoSpecMethodInfo]] = [
# connect returns None
# Test method
AutoSpecMethodInfo('test', method=mock.Mock(return_value=True)),
# get_cluster_info
AutoSpecMethodInfo('get_cluster_info', return_value=CLUSTER_INFO),
# get_next_vmid
AutoSpecMethodInfo('get_next_vmid', return_value=1),
# is_vmid_available
AutoSpecMethodInfo('is_vmid_available', return_value=True),
# get_node_networks, not called never (ensure it's not called by mistake)
# list_node_gpu_devices
AutoSpecMethodInfo('list_node_gpu_devices', return_value=['gpu_dev_1', 'gpu_dev_2']),
# list_node_vgpus
AutoSpecMethodInfo('list_node_vgpus', return_value=VGPUS),
# node_has_vgpus_available
AutoSpecMethodInfo('node_has_vgpus_available', return_value=True),
# get_best_node_for_machine
AutoSpecMethodInfo('get_best_node_for_machine', return_value=NODE_STATS[0]),
# clone_machine
AutoSpecMethodInfo('clone_machine', return_value=VM_CREATION_RESULT),
# list_ha_groups
AutoSpecMethodInfo('list_ha_groups', return_value=['ha_group_1', 'ha_group_2']),
# enable_machine_ha return None
# disable_machine_ha return None
# set_protection return None
# get_guest_ip_address
AutoSpecMethodInfo('get_guest_ip_address', return_value='1.0.0.1'),
# remove_machine
AutoSpecMethodInfo('remove_machine', return_value=UPID),
# list_snapshots
AutoSpecMethodInfo('list_snapshots', return_value=SNAPSHOTS_INFO),
# supports_snapshot
AutoSpecMethodInfo('supports_snapshot', return_value=True),
# create_snapshot
AutoSpecMethodInfo('create_snapshot', return_value=UPID),
# remove_snapshot
AutoSpecMethodInfo('remove_snapshot', return_value=UPID),
# get_current_snapshot
AutoSpecMethodInfo('get_current_snapshot', return_value=SNAPSHOTS_INFO[0].name),
# restore_snapshot
AutoSpecMethodInfo('restore_snapshot', return_value=UPID),
# get_task
AutoSpecMethodInfo('get_task', return_value=TASK_STATUS),
# list_machines
AutoSpecMethodInfo('list_machines', return_value=VMS_INFO),
# get_machines_pool_info
AutoSpecMethodInfo('get_machines_pool_info', return_value=VMS_INFO[0]),
# get_machine_info
AutoSpecMethodInfo('get_machine_info', return_value=VMS_INFO[0]),
# get_machine_configuration
AutoSpecMethodInfo('get_machine_configuration', method=lambda vmid: VMS_CONFIGURATION[vmid - 1]),
# set_machine_ha return None
# start_machine
AutoSpecMethodInfo('start_machine', return_value=UPID),
# stop_machine
AutoSpecMethodInfo('stop_machine', return_value=UPID),
# reset_machine
AutoSpecMethodInfo('reset_machine', return_value=UPID),
# suspend_machine
AutoSpecMethodInfo('suspend_machine', return_value=UPID),
# resume_machine
AutoSpecMethodInfo('resume_machine', return_value=UPID),
# shutdown_machine
AutoSpecMethodInfo('shutdown_machine', return_value=UPID),
# convert_to_template
AutoSpecMethodInfo('convert_to_template', return_value=UPID),
# get_storage
AutoSpecMethodInfo('get_storage', method=lambda storage, node: next(filter(lambda s: s.storage == storage, STORAGES))),
# list_storages
AutoSpecMethodInfo('list_storages', return_value=STORAGES),
# get_node_stats
AutoSpecMethodInfo('get_node_stats', method=lambda node: next(filter(lambda n: n.name == node, NODE_STATS))),
# list_pools
AutoSpecMethodInfo('list_pools', return_value=POOLS),
]
class TestProxmovProvider(UDSTestCase):
def test_provider(self) -> None: def test_provider(self) -> None:
""" """
Test the provider Test the provider
""" """
client = autospec(provider.ProxmoxProvider, METHODS_INFO) client = autospec(pc.ProxmoxClient, CLIENT_METHODS_INFO)
assert client.test() is True assert client.test() is True
assert client.get_cluster_info() == CLUSTER_INFO
assert client.get_next_vmid() == 1
assert client.is_vmid_available(1) is True
assert client.get_machine_configuration(1) == VMS_CONFIGURATION[0]

View File

@ -33,6 +33,8 @@ import typing
import dataclasses import dataclasses
from unittest import mock from unittest import mock
T = typing.TypeVar('T')
@dataclasses.dataclass @dataclasses.dataclass
class AutoSpecMethodInfo: class AutoSpecMethodInfo:
name: str name: str
@ -40,10 +42,12 @@ class AutoSpecMethodInfo:
method: 'typing.Callable|None' = None method: 'typing.Callable|None' = None
def autospec(cls: typing.Type, metods_info: collections.abc.Iterable, **kwargs: typing.Any) -> typing.Any: def autospec(cls: typing.Type[T], metods_info: collections.abc.Iterable, **kwargs: typing.Any) -> T:
""" """
This is a helper function that will create a mock object with the same methods as the class passed as parameter. This is a helper function that will create a mock object with the same methods as the class passed as parameter.
This is useful for testing purposes, where you want to mock a class and still have the same methods available. This is useful for testing purposes, where you want to mock a class and still have the same methods available.
The returned value is in fact a mock object, but with the same methods as the class passed as parameter.
""" """
obj = mock.create_autospec(cls, **kwargs) obj = mock.create_autospec(cls, **kwargs)
for method_info in metods_info: for method_info in metods_info: