mirror of
https://github.com/dkmstr/openuds.git
synced 2025-11-13 20:24:27 +03:00
Compare commits
34 Commits
alert-auto
...
before-ser
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
67a58d57cb | ||
|
|
39a046bb23 | ||
|
|
ae16e78a4a | ||
|
|
5c3d7281fa | ||
|
|
72f0c85f75 | ||
|
|
fc39a96850 | ||
|
|
820ba7790d | ||
|
|
a1119a6cc7 | ||
|
|
027be9b680 | ||
|
|
1e93bb702e | ||
|
|
57cfb0d98e | ||
|
|
c88510133a | ||
|
|
f5d4640cb1 | ||
|
|
dfd5cd4206 | ||
|
|
32b5b29ae5 | ||
|
|
eaff4aeb80 | ||
|
|
b63e82dbdb | ||
|
|
144de7122b | ||
|
|
53bd9ed75a | ||
|
|
31325aa194 | ||
|
|
52d34ed303 | ||
|
|
790ac8063e | ||
|
|
bce487168b | ||
|
|
295c820c7c | ||
|
|
a4f6214ed9 | ||
|
|
4ec687567d | ||
|
|
7d16ae03e5 | ||
|
|
046130c77b | ||
|
|
26c9dd0dec | ||
|
|
27de5e065f | ||
|
|
519436176a | ||
|
|
3bbbc9d5dd | ||
|
|
e6549c17d1 | ||
|
|
1970bb89dd |
5
.github/workflows/test.yml
vendored
5
.github/workflows/test.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.13'
|
||||
python-version: '3.x'
|
||||
|
||||
- name: Install system dependencies
|
||||
run: |
|
||||
@@ -41,6 +41,9 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
# Install lxmlsec with local libraries to avoid binary wheel issues
|
||||
pip install --upgrade --no-binary lxml --no-binary xmlsec lxml xmlsec
|
||||
# Install other requirements
|
||||
pip install -r requirements.txt
|
||||
|
||||
- name: Set PYTHONPATH
|
||||
|
||||
@@ -17,3 +17,4 @@ Notes
|
||||
* From `v4.0` onwards (current master), OpenUDS has been splitted in several repositories and contains submodules. Remember to use "git clone --resursive ..." to fetch it ;-).
|
||||
* `v4.0` version needs Python 3.11 (may work fine on newer versions). It uses new features only available on 3.10 or later, and is tested against 3.11. It will probably work on 3.10 too.
|
||||
|
||||
[](https://deepwiki.com/VirtualCable/openuds)
|
||||
|
||||
2
actor
2
actor
Submodule actor updated: 79a7e8bbc2...10b407ced9
22700
server/doc/api/rest.yaml
Normal file
22700
server/doc/api/rest.yaml
Normal file
File diff suppressed because it is too large
Load Diff
@@ -15,7 +15,7 @@ cryptography
|
||||
python3-saml
|
||||
six
|
||||
dnspython
|
||||
lxml
|
||||
# lxml must be installed source to avoid conflicts
|
||||
ovirt-engine-sdk-python
|
||||
pycurl
|
||||
matplotlib
|
||||
|
||||
@@ -38,6 +38,7 @@ import traceback
|
||||
|
||||
from django import http
|
||||
from django.utils.decorators import method_decorator
|
||||
from django.core.exceptions import ObjectDoesNotExist
|
||||
from django.views.decorators.csrf import csrf_exempt
|
||||
from django.views.generic.base import View
|
||||
|
||||
@@ -203,6 +204,14 @@ class Dispatcher(View):
|
||||
except exceptions.rest.HandlerError as e:
|
||||
log.log_operation(handler, 500, types.log.LogLevel.ERROR)
|
||||
return http.HttpResponseBadRequest(f'{{"error": "{e}"}}'.encode(), content_type="application/json")
|
||||
except exceptions.services.generics.Error as e:
|
||||
log.log_operation(handler, 503, types.log.LogLevel.ERROR)
|
||||
return http.HttpResponseServerError(
|
||||
f'{{"error": "{e}"}}'.encode(), content_type="application/json", status=503
|
||||
)
|
||||
except ObjectDoesNotExist as e: # All DoesNotExist exceptions are not found
|
||||
log.log_operation(handler, 404, types.log.LogLevel.ERROR)
|
||||
return http.HttpResponseNotFound(f'{{"error": "{e}"}}'.encode(), content_type="application/json")
|
||||
except Exception as e:
|
||||
log.log_operation(handler, 500, types.log.LogLevel.ERROR)
|
||||
# Get ecxeption backtrace
|
||||
|
||||
@@ -58,6 +58,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
T = typing.TypeVar('T')
|
||||
|
||||
|
||||
class Handler(abc.ABC):
|
||||
"""
|
||||
REST requests handler base class
|
||||
@@ -390,11 +391,48 @@ class Handler(abc.ABC):
|
||||
if name in self._params:
|
||||
return self._params[name]
|
||||
return ''
|
||||
|
||||
def filter_queryset(self, qs: QuerySet[typing.Any]) -> list[typing.Any]:
|
||||
|
||||
def get_sort_field_info(self, *args: str) -> tuple[str, bool] | None:
|
||||
"""
|
||||
Returns sorting information for the first sorting if it is contained in the odata orderby list.
|
||||
|
||||
Args:
|
||||
args: The possible name of the field name to check for sorting information.
|
||||
|
||||
Returns:
|
||||
A tuple containing the clean field name found and a boolean indicating if the sorting is descending,
|
||||
|
||||
Note:
|
||||
We only use the first in case of table sort translations, so this only returns info for the first field
|
||||
"""
|
||||
if self.odata.orderby:
|
||||
order_field = self.odata.orderby[0]
|
||||
clean_field = order_field.lstrip('-')
|
||||
for field_name in args:
|
||||
if clean_field == field_name:
|
||||
is_descending = order_field.startswith('-')
|
||||
return (clean_field, is_descending)
|
||||
return None
|
||||
|
||||
def apply_sort(self, qs: QuerySet[typing.Any]) -> list[typing.Any] | QuerySet[typing.Any]:
|
||||
"""
|
||||
Custom sorting function to apply to querysets.
|
||||
Override this method in subclasses to provide custom sorting logic.
|
||||
|
||||
Args:
|
||||
qs: The queryset to sort.
|
||||
order_by: The field name to sort by.
|
||||
|
||||
Returns:
|
||||
The sorted queryset.
|
||||
"""
|
||||
return qs.order_by(*self.odata.orderby)
|
||||
|
||||
@typing.final
|
||||
def filter_odata_queryset(self, qs: QuerySet[typing.Any]) -> list[typing.Any]:
|
||||
"""
|
||||
Filters the queryset based on odata
|
||||
|
||||
|
||||
Note: We return a list, because after applying slicing, querysets may be evaluated
|
||||
by using _result_cache, so we force evaluation here to avoid issues later.
|
||||
"""
|
||||
@@ -405,28 +443,30 @@ class Handler(abc.ABC):
|
||||
except ValueError as e:
|
||||
raise exceptions.rest.RequestError(f'Invalid odata filter: {e}') from e
|
||||
|
||||
# Store total count before slicing
|
||||
self.add_header('X-Total-Count', str(qs.count()))
|
||||
|
||||
# order_by must be unique and all fields are summited by once
|
||||
# As after slicing we can have a list, we may use list result from sorting
|
||||
if self.odata.orderby:
|
||||
qs = qs.order_by(*self.odata.orderby)
|
||||
|
||||
result = self.apply_sort(qs)
|
||||
else:
|
||||
result = qs
|
||||
|
||||
# If odata start/limit are set, apply them
|
||||
if self.odata.start is not None:
|
||||
qs = qs[self.odata.start :]
|
||||
result = result[self.odata.start :]
|
||||
# Note that limit is AFTER start because of previous line
|
||||
if self.odata.limit is not None:
|
||||
qs = qs[: self.odata.limit]
|
||||
|
||||
result = list(qs)
|
||||
result = result[: self.odata.limit]
|
||||
|
||||
# Get total items and set it on X-Total-Count
|
||||
try:
|
||||
total_items = len(result)
|
||||
self.add_header('X-Total-Count', total_items)
|
||||
except Exception as e:
|
||||
raise exceptions.rest.RequestError(f'Invalid odata: {e}')
|
||||
# After slicing, the qs may be a list, so we ensure it's a list
|
||||
# to avoid issues later
|
||||
result = list(result)
|
||||
|
||||
return result
|
||||
|
||||
def filter_data(self, data: collections.abc.Iterable[T]) -> list[T]:
|
||||
def filter_odata_data(self, data: collections.abc.Iterable[T]) -> list[T]:
|
||||
"""
|
||||
Filters the dict base on the currnet odata
|
||||
"""
|
||||
@@ -446,8 +486,7 @@ class Handler(abc.ABC):
|
||||
raise exceptions.rest.RequestError(f'Invalid odata: {e}')
|
||||
|
||||
return data
|
||||
|
||||
|
||||
|
||||
@classmethod
|
||||
def api_components(cls: type[typing.Self]) -> types.rest.api.Components:
|
||||
"""
|
||||
@@ -456,7 +495,9 @@ class Handler(abc.ABC):
|
||||
return types.rest.api.Components()
|
||||
|
||||
@classmethod
|
||||
def api_paths(cls: type[typing.Self], path: str, tags: list[str], security: str) -> dict[str, types.rest.api.PathItem]:
|
||||
def api_paths(
|
||||
cls: type[typing.Self], path: str, tags: list[str], security: str
|
||||
) -> dict[str, types.rest.api.PathItem]:
|
||||
"""
|
||||
Returns the API operations that should be registered
|
||||
"""
|
||||
|
||||
@@ -70,7 +70,7 @@ class AccountsUsage(DetailHandler[AccountItem]): # pylint: disable=too-many-pub
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def usage_to_dict(item: 'AccountUsage', perm: int) -> AccountItem:
|
||||
def as_dict(item: 'AccountUsage', perm: int) -> AccountItem:
|
||||
"""
|
||||
Convert an account usage to a dictionary
|
||||
:param item: Account usage item (db)
|
||||
@@ -89,19 +89,23 @@ class AccountsUsage(DetailHandler[AccountItem]): # pylint: disable=too-many-pub
|
||||
elapsed_timemark=item.elapsed_timemark,
|
||||
permission=perm,
|
||||
)
|
||||
|
||||
def get_item_position(self, parent: 'Model', item_uuid: str) -> int:
|
||||
parent = ensure.is_instance(parent, Account)
|
||||
return self.calc_item_position(item_uuid, parent.usages.all())
|
||||
|
||||
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.ItemsResult[AccountItem]:
|
||||
def get_items(self, parent: 'Model') -> types.rest.ItemsResult[AccountItem]:
|
||||
parent = ensure.is_instance(parent, Account)
|
||||
# Check what kind of access do we have to parent provider
|
||||
perm = permissions.effective_permissions(self._user, parent)
|
||||
try:
|
||||
if not item:
|
||||
return [AccountsUsage.usage_to_dict(k, perm) for k in self.filter_queryset(parent.usages.all())]
|
||||
k = parent.usages.get(uuid=process_uuid(item))
|
||||
return AccountsUsage.usage_to_dict(k, perm)
|
||||
except Exception:
|
||||
logger.exception('itemId %s', item)
|
||||
raise exceptions.rest.NotFound(_('Account usage not found: {}').format(item)) from None
|
||||
return [AccountsUsage.as_dict(k, perm) for k in self.odata_filter(parent.usages.all())]
|
||||
|
||||
def get_item(self, parent: 'Model', item: str) -> AccountItem:
|
||||
parent = ensure.is_instance(parent, Account)
|
||||
# Check what kind of access do we have to parent provider
|
||||
return AccountsUsage.as_dict(
|
||||
parent.usages.get(uuid=process_uuid(item)), permissions.effective_permissions(self._user, parent)
|
||||
)
|
||||
|
||||
def get_table(self, parent: 'Model') -> TableInfo:
|
||||
parent = ensure.is_instance(parent, Account)
|
||||
|
||||
@@ -54,6 +54,9 @@ from .users_groups import Groups, Users
|
||||
|
||||
from uds.core.module import Module
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from django.db.models.query import QuerySet
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -110,9 +113,9 @@ class Authenticators(ModelHandler[AuthenticatorItem]):
|
||||
.icon(name='name', title=_('Name'), visible=True)
|
||||
.text_column(name='type_name', title=_('Type'))
|
||||
.text_column(name='comments', title=_('Comments'))
|
||||
.numeric_column(name='priority', title=_('Priority'), width='5rem')
|
||||
.numeric_column(name='priority', title=_('Priority'), width='8rem')
|
||||
.text_column(name='small_name', title=_('Label'))
|
||||
.numeric_column(name='users_count', title=_('Users'), width='1rem')
|
||||
.numeric_column(name='users_count', title=_('Users'), width='6rem')
|
||||
.text_column(name='mfa_name', title=_('MFA'))
|
||||
.text_column(name='tags', title=_('tags'), visible=False)
|
||||
.row_style(prefix='row-state-', field='state')
|
||||
@@ -218,6 +221,29 @@ class Authenticators(ModelHandler[AuthenticatorItem]):
|
||||
type_info=type(self).as_typeinfo(item.get_type()),
|
||||
)
|
||||
|
||||
def apply_sort(self, qs: 'QuerySet[typing.Any]') -> 'list[typing.Any] | QuerySet[typing.Any]':
|
||||
if field_info := self.get_sort_field_info('users_count'):
|
||||
field_name, is_descending = field_info
|
||||
order_by_field = f"-{field_name}" if is_descending else field_name
|
||||
return qs.annotate(users_count=models.Count('users')).order_by(order_by_field)
|
||||
|
||||
if field_info := self.get_sort_field_info('type_name'):
|
||||
_, is_descending = field_info
|
||||
order_by_field = f'-data_type' if is_descending else 'data_type'
|
||||
return qs.order_by(order_by_field)
|
||||
|
||||
if field_info := self.get_sort_field_info('numeric_id'):
|
||||
_, is_descending = field_info
|
||||
order_by_field = f'-pk' if is_descending else 'pk'
|
||||
return qs.order_by(order_by_field)
|
||||
|
||||
if field_info := self.get_sort_field_info('mfa_name'):
|
||||
_, is_descending = field_info
|
||||
order_by_field = f'-mfa__name' if is_descending else 'mfa__name'
|
||||
return qs.order_by(order_by_field)
|
||||
|
||||
return super().apply_sort(qs)
|
||||
|
||||
def post_save(self, item: 'models.Model') -> None:
|
||||
item = ensure.is_instance(item, Authenticator)
|
||||
try:
|
||||
|
||||
@@ -82,30 +82,34 @@ class CalendarRules(DetailHandler[CalendarRuleItem]): # pylint: disable=too-man
|
||||
name=item.name,
|
||||
comments=item.comments,
|
||||
start=item.start,
|
||||
end=timezone.make_aware(datetime.datetime.combine(item.end, datetime.time.max)) if item.end else None,
|
||||
end=(
|
||||
timezone.make_aware(datetime.datetime.combine(item.end, datetime.time.max))
|
||||
if item.end
|
||||
else None
|
||||
),
|
||||
frequency=item.frequency,
|
||||
interval=item.interval,
|
||||
duration=item.duration,
|
||||
duration_unit=item.duration_unit,
|
||||
permission=perm,
|
||||
)
|
||||
|
||||
def get_item_position(self, parent: 'models.Model', item_uuid: str) -> int:
|
||||
parent = ensure.is_instance(parent, Calendar)
|
||||
return self.calc_item_position(item_uuid, parent.rules.all())
|
||||
|
||||
def get_items(
|
||||
self, parent: 'models.Model', item: typing.Optional[str]
|
||||
) -> types.rest.ItemsResult[CalendarRuleItem]:
|
||||
def get_items(self, parent: 'models.Model') -> types.rest.ItemsResult[CalendarRuleItem]:
|
||||
parent = ensure.is_instance(parent, Calendar)
|
||||
# Check what kind of access do we have to parent provider
|
||||
perm = permissions.effective_permissions(self._user, parent)
|
||||
try:
|
||||
if item is None:
|
||||
return [CalendarRules.rule_as_dict(k, perm) for k in self.filter_queryset(parent.rules.all())]
|
||||
k = parent.rules.get(uuid=process_uuid(item))
|
||||
return CalendarRules.rule_as_dict(k, perm)
|
||||
except CalendarRule.DoesNotExist:
|
||||
raise exceptions.rest.NotFound(_('Calendar rule not found: {}').format(item)) from None
|
||||
except Exception as e:
|
||||
logger.exception('itemId %s', item)
|
||||
raise exceptions.rest.RequestError(f'Error retrieving calendar rule: {e}') from e
|
||||
return [CalendarRules.rule_as_dict(k, perm) for k in self.filter_odata_queryset(parent.rules.all())]
|
||||
|
||||
def get_item(self, parent: 'models.Model', item: str) -> CalendarRuleItem:
|
||||
parent = ensure.is_instance(parent, Calendar)
|
||||
# Check what kind of access do we have to parent provider
|
||||
return CalendarRules.rule_as_dict(
|
||||
parent.rules.get(uuid=process_uuid(item)), permissions.effective_permissions(self._user, parent)
|
||||
)
|
||||
|
||||
def get_table(self, parent: 'models.Model') -> types.rest.TableInfo:
|
||||
parent = ensure.is_instance(parent, Calendar)
|
||||
|
||||
@@ -51,7 +51,7 @@ class Config(Handler):
|
||||
ROLE = consts.UserRole.ADMIN
|
||||
|
||||
def get(self) -> typing.Any:
|
||||
return self.filter_data(CfgConfig.get_config_values(self.is_admin()))
|
||||
return CfgConfig.get_config_values(self.is_admin())
|
||||
|
||||
def put(self) -> typing.Any:
|
||||
for section, section_dict in typing.cast(dict[str, dict[str, dict[str, str]]], self._params).items():
|
||||
|
||||
@@ -86,7 +86,7 @@ class Connection(Handler):
|
||||
# Ensure user is present on request, used by web views methods
|
||||
self._request.user = self._user
|
||||
|
||||
return Connection.result(result=self.filter_data(services.get_services_info_dict(self._request)))
|
||||
return Connection.result(result=self.filter_odata_data(services.get_services_info_dict(self._request)))
|
||||
|
||||
def connection(self, id_service: str, id_transport: str, skip: str = '') -> dict[str, typing.Any]:
|
||||
skip_check = skip in ('doNotCheck', 'do_not_check', 'no_check', 'nocheck', 'skip_check')
|
||||
|
||||
@@ -91,19 +91,21 @@ class MetaServicesPool(DetailHandler[MetaItem]):
|
||||
user_services_count=item.pool.userServices.exclude(state__in=State.INFO_STATES).count(),
|
||||
user_services_in_preparation=item.pool.userServices.filter(state=State.PREPARING).count(),
|
||||
)
|
||||
|
||||
def get_item_position(self, parent: 'Model', item_uuid: str) -> int:
|
||||
parent = ensure.is_instance(parent, models.MetaPool)
|
||||
return self.calc_item_position(item_uuid, parent.members.all())
|
||||
|
||||
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.ItemsResult['MetaItem']:
|
||||
def get_items(self, parent: 'Model') -> types.rest.ItemsResult['MetaItem']:
|
||||
parent = ensure.is_instance(parent, models.MetaPool)
|
||||
return [MetaServicesPool.as_dict(i) for i in self.filter_odata_queryset(parent.members.all())]
|
||||
|
||||
def get_item(self, parent: 'Model', item: str) -> 'MetaItem':
|
||||
parent = ensure.is_instance(parent, models.MetaPool)
|
||||
try:
|
||||
if not item:
|
||||
return [MetaServicesPool.as_dict(i) for i in self.filter_queryset(parent.members.all())]
|
||||
i = parent.members.get(uuid=process_uuid(item))
|
||||
return MetaServicesPool.as_dict(i)
|
||||
return MetaServicesPool.as_dict(parent.members.get(uuid=process_uuid(item)))
|
||||
except models.MetaPoolMember.DoesNotExist:
|
||||
raise exceptions.rest.NotFound(_('Meta pool member not found: {}').format(item)) from None
|
||||
except Exception as e:
|
||||
logger.exception('err: %s', item)
|
||||
raise exceptions.rest.RequestError(f'Error retrieving meta pool member: {e}') from e
|
||||
|
||||
def get_table(self, parent: 'Model') -> types.rest.TableInfo:
|
||||
parent = ensure.is_instance(parent, models.MetaPool)
|
||||
@@ -171,67 +173,62 @@ class MetaAssignedService(DetailHandler[UserServiceItem]):
|
||||
element.pool_name = item.deployed_service.name
|
||||
return element
|
||||
|
||||
def _get_assigned_userservice(self, metapool: models.MetaPool, userservice_id: str) -> models.UserService:
|
||||
@staticmethod
|
||||
def _get_assigned_userservice(metapool: models.MetaPool, userservice_id: str) -> models.UserService:
|
||||
"""
|
||||
Gets an assigned service and checks that it belongs to this metapool
|
||||
If not found, raises InvalidItemException
|
||||
"""
|
||||
try:
|
||||
return models.UserService.objects.filter(
|
||||
uuid=process_uuid(userservice_id),
|
||||
cache_level=0,
|
||||
deployed_service__in=[i.pool for i in metapool.members.all()],
|
||||
)[0]
|
||||
except IndexError:
|
||||
found = models.UserService.objects.filter(
|
||||
uuid=process_uuid(userservice_id),
|
||||
cache_level=0,
|
||||
deployed_service__in=[i.pool for i in metapool.members.all()],
|
||||
).first()
|
||||
if found is None:
|
||||
raise exceptions.rest.NotFound(_('User service not found: {}').format(userservice_id)) from None
|
||||
except Exception:
|
||||
logger.error('Error getting assigned userservice %s for metapool %s', userservice_id, metapool.uuid)
|
||||
raise exceptions.rest.RequestError(
|
||||
_('Error retrieving assigned service: {}').format(userservice_id)
|
||||
) from None
|
||||
return found
|
||||
|
||||
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.ItemsResult[UserServiceItem]:
|
||||
def _assigned_userservices_for_pools(
|
||||
self, parent: 'models.MetaPool'
|
||||
) -> typing.Generator[tuple[models.UserService, typing.Optional[dict[str, typing.Any]]], None, None]:
|
||||
for m in self.odata_filter(parent.members.filter(enabled=True)):
|
||||
properties: dict[str, typing.Any] = {
|
||||
k: v
|
||||
for k, v in models.Properties.objects.filter(
|
||||
owner_type='userservice',
|
||||
owner_id__in=m.pool.assigned_user_services().values_list('uuid', flat=True),
|
||||
).values_list('key', 'value')
|
||||
}
|
||||
for u in (
|
||||
m.pool.assigned_user_services()
|
||||
.filter(state__in=State.VALID_STATES)
|
||||
.prefetch_related('deployed_service', 'publication')
|
||||
):
|
||||
yield u, properties.get(u.uuid, {})
|
||||
|
||||
def get_items(self, parent: 'Model') -> types.rest.ItemsResult[UserServiceItem]:
|
||||
parent = ensure.is_instance(parent, models.MetaPool)
|
||||
|
||||
def _assigned_userservices_for_pools() -> (
|
||||
typing.Generator[tuple[models.UserService, typing.Optional[dict[str, typing.Any]]], None, None]
|
||||
):
|
||||
for m in self.filter_queryset(parent.members.filter(enabled=True)):
|
||||
properties: dict[str, typing.Any] = {
|
||||
k: v
|
||||
for k, v in models.Properties.objects.filter(
|
||||
owner_type='userservice',
|
||||
owner_id__in=m.pool.assigned_user_services().values_list('uuid', flat=True),
|
||||
).values_list('key', 'value')
|
||||
}
|
||||
for u in (
|
||||
m.pool.assigned_user_services()
|
||||
.filter(state__in=State.VALID_STATES)
|
||||
.prefetch_related('deployed_service', 'publication')
|
||||
):
|
||||
yield u, properties.get(u.uuid, {})
|
||||
return list(
|
||||
{
|
||||
k.uuid: MetaAssignedService.item_as_dict(parent, k, props)
|
||||
for k, props in self._assigned_userservices_for_pools(parent)
|
||||
}.values()
|
||||
)
|
||||
|
||||
try:
|
||||
if not item: # All items
|
||||
result: dict[str, typing.Any] = {}
|
||||
def get_item(self, parent: 'Model', item: str) -> UserServiceItem:
|
||||
parent = ensure.is_instance(parent, models.MetaPool)
|
||||
|
||||
for k, props in _assigned_userservices_for_pools():
|
||||
result[k.uuid] = MetaAssignedService.item_as_dict(parent, k, props)
|
||||
return list(result.values())
|
||||
|
||||
return MetaAssignedService.item_as_dict(
|
||||
parent,
|
||||
self._get_assigned_userservice(parent, item),
|
||||
props={
|
||||
k: v
|
||||
for k, v in models.Properties.objects.filter(
|
||||
owner_type='userservice', owner_id=process_uuid(item)
|
||||
).values_list('key', 'value')
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception('get_items')
|
||||
raise exceptions.rest.RequestError(f'Error retrieving meta pool member: {e}') from e
|
||||
return MetaAssignedService.item_as_dict(
|
||||
parent,
|
||||
self._get_assigned_userservice(parent, item),
|
||||
props={
|
||||
k: v
|
||||
for k, v in models.Properties.objects.filter(
|
||||
owner_type='userservice', owner_id=process_uuid(item)
|
||||
).values_list('key', 'value')
|
||||
},
|
||||
)
|
||||
|
||||
def get_table(self, parent: 'Model') -> TableInfo:
|
||||
parent = ensure.is_instance(parent, models.MetaPool)
|
||||
|
||||
@@ -73,22 +73,27 @@ class AccessCalendars(DetailHandler[AccessCalendarItem]):
|
||||
access=item.access,
|
||||
priority=item.priority,
|
||||
)
|
||||
|
||||
def get_item_position(self, parent: 'Model', item_uuid: str) -> int:
|
||||
# parent can be a ServicePool or a metaPool
|
||||
if isinstance(parent, models.ServicePool):
|
||||
parent = ensure.is_instance(parent, models.ServicePool)
|
||||
return self.calc_item_position(item_uuid, parent.calendarAccess.all())
|
||||
|
||||
parent = ensure.is_instance(parent, models.MetaPool)
|
||||
return self.calc_item_position(item_uuid, parent.calendarAccess.all())
|
||||
|
||||
def get_items(
|
||||
self, parent: 'Model', item: typing.Optional[str]
|
||||
) -> types.rest.ItemsResult[AccessCalendarItem]:
|
||||
def get_items(self, parent: 'Model') -> types.rest.ItemsResult[AccessCalendarItem]:
|
||||
# parent can be a ServicePool or a metaPool
|
||||
parent = typing.cast(typing.Union['models.ServicePool', 'models.MetaPool'], parent)
|
||||
|
||||
try:
|
||||
if not item:
|
||||
return [AccessCalendars.as_item(i) for i in self.filter_queryset(parent.calendarAccess.all())]
|
||||
return AccessCalendars.as_item(parent.calendarAccess.get(uuid=process_uuid(item)))
|
||||
except models.CalendarAccess.DoesNotExist:
|
||||
raise exceptions.rest.NotFound(_('Access calendar not found: {}').format(item)) from None
|
||||
except Exception as e:
|
||||
logger.exception('err: %s', item)
|
||||
raise exceptions.rest.RequestError(f'Error retrieving access calendar: {e}') from e
|
||||
return [AccessCalendars.as_item(i) for i in self.filter_odata_queryset(parent.calendarAccess.all())]
|
||||
|
||||
def get_item(self, parent: 'Model', item: str) -> AccessCalendarItem:
|
||||
# parent can be a ServicePool or a metaPool
|
||||
parent = typing.cast(typing.Union['models.ServicePool', 'models.MetaPool'], parent)
|
||||
|
||||
return AccessCalendars.as_item(parent.calendarAccess.get(uuid=process_uuid(item)))
|
||||
|
||||
def get_table(self, parent: 'Model') -> types.rest.TableInfo:
|
||||
return (
|
||||
@@ -190,21 +195,20 @@ class ActionsCalendars(DetailHandler[ActionCalendarItem]):
|
||||
next_execution=item.next_execution,
|
||||
last_execution=item.last_execution,
|
||||
)
|
||||
|
||||
def get_items(
|
||||
self, parent: 'Model', item: typing.Optional[str]
|
||||
) -> types.rest.ItemsResult[ActionCalendarItem]:
|
||||
|
||||
def get_item_position(self, parent: 'Model', item_uuid: str) -> int:
|
||||
parent = ensure.is_instance(parent, models.ServicePool)
|
||||
try:
|
||||
if item is None:
|
||||
return [ActionsCalendars.as_dict(i) for i in self.filter_queryset(parent.calendaraction_set.all())]
|
||||
i = parent.calendaraction_set.get(uuid=process_uuid(item))
|
||||
return ActionsCalendars.as_dict(i)
|
||||
except models.CalendarAction.DoesNotExist:
|
||||
raise exceptions.rest.NotFound(_('Scheduled action not found: {}').format(item)) from None
|
||||
except Exception as e:
|
||||
logger.error('Error retrieving scheduled action %s: %s', item, e)
|
||||
raise exceptions.rest.RequestError(f'Error retrieving scheduled action: {e}') from e
|
||||
return self.calc_item_position(item_uuid, parent.calendaraction_set.all())
|
||||
|
||||
def get_items(self, parent: 'Model') -> types.rest.ItemsResult[ActionCalendarItem]:
|
||||
parent = ensure.is_instance(parent, models.ServicePool)
|
||||
return [
|
||||
ActionsCalendars.as_dict(i) for i in self.filter_odata_queryset(parent.calendaraction_set.all())
|
||||
]
|
||||
|
||||
def get_item(self, parent: 'Model', item: str) -> ActionCalendarItem:
|
||||
parent = ensure.is_instance(parent, models.ServicePool)
|
||||
return ActionsCalendars.as_dict(parent.calendaraction_set.get(uuid=process_uuid(item)))
|
||||
|
||||
def get_table(self, parent: 'Model') -> TableInfo:
|
||||
return (
|
||||
|
||||
@@ -36,7 +36,7 @@ import logging
|
||||
import typing
|
||||
|
||||
from django.utils.translation import gettext, gettext_lazy as _
|
||||
from django.db.models import Model
|
||||
from django.db.models import Model, Count
|
||||
|
||||
import uds.core.types.permissions
|
||||
from uds.core import exceptions, services, types
|
||||
@@ -51,6 +51,9 @@ from .services_usage import ServicesUsage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from django.db.models.query import QuerySet
|
||||
|
||||
|
||||
# Helper class for Provider offers
|
||||
@dataclasses.dataclass
|
||||
@@ -96,6 +99,8 @@ class Providers(ModelHandler[ProviderItem]):
|
||||
.numeric_column(name='user_services_count', title=_('User Services'))
|
||||
.text_column(name='tags', title=_('Tags'), visible=False)
|
||||
.row_style(prefix='row-maintenance-', field='maintenance_mode')
|
||||
.with_field_mappings(type_name='data_type')
|
||||
.with_filter_fields('name', 'data_type', 'comments', 'maintenance_mode')
|
||||
).build()
|
||||
|
||||
# Rest api related information to complete the auto-generated API
|
||||
@@ -103,6 +108,14 @@ class Providers(ModelHandler[ProviderItem]):
|
||||
typed=types.rest.api.RestApiInfoGuiType.MULTIPLE_TYPES,
|
||||
)
|
||||
|
||||
def apply_sort(self, qs: 'QuerySet[typing.Any]') -> 'list[typing.Any] | QuerySet[typing.Any]':
|
||||
if field_info := self.get_sort_field_info('services_count'):
|
||||
field_name, is_descending = field_info
|
||||
order_by_field = f"-{field_name}" if is_descending else field_name
|
||||
return qs.annotate(services_count=Count('services')).order_by(order_by_field)
|
||||
|
||||
return super().apply_sort(qs)
|
||||
|
||||
def get_item(self, item: 'Model') -> ProviderItem:
|
||||
item = ensure.is_instance(item, Provider)
|
||||
type_ = item.get_type()
|
||||
|
||||
@@ -120,7 +120,7 @@ class Reports(model.BaseModelHandler[ReportItem]):
|
||||
return match_args(
|
||||
self._args,
|
||||
error,
|
||||
((), lambda: list(self.filter_data(self.get_items()))),
|
||||
((), lambda: list(self.filter_odata_data(self.get_items()))),
|
||||
((consts.rest.OVERVIEW,), lambda: list(self.get_items())),
|
||||
(
|
||||
(consts.rest.TABLEINFO,),
|
||||
|
||||
@@ -148,38 +148,25 @@ class ServersServers(DetailHandler[ServerItem]):
|
||||
typed=types.rest.api.RestApiInfoGuiType.SINGLE_TYPE,
|
||||
)
|
||||
|
||||
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.ItemsResult[ServerItem]:
|
||||
def as_server_item(self, item: 'models.Server') -> ServerItem:
|
||||
return ServerItem(
|
||||
id=item.uuid,
|
||||
hostname=item.hostname,
|
||||
ip=item.ip,
|
||||
listen_port=item.listen_port,
|
||||
mac=item.mac if item.mac != consts.NULL_MAC else '',
|
||||
maintenance_mode=item.maintenance_mode,
|
||||
register_username=item.register_username,
|
||||
stamp=item.stamp,
|
||||
)
|
||||
|
||||
def get_items(self, parent: 'Model') -> types.rest.ItemsResult[ServerItem]:
|
||||
parent = typing.cast('models.ServerGroup', parent) # We will receive for sure
|
||||
try:
|
||||
if item is None:
|
||||
q = self.filter_queryset(parent.servers.all())
|
||||
else:
|
||||
q = parent.servers.filter(uuid=process_uuid(item))
|
||||
res: list[ServerItem] = []
|
||||
i = None
|
||||
for i in q:
|
||||
res.append(
|
||||
ServerItem(
|
||||
id=i.uuid,
|
||||
hostname=i.hostname,
|
||||
ip=i.ip,
|
||||
listen_port=i.listen_port,
|
||||
mac=i.mac if i.mac != consts.NULL_MAC else '',
|
||||
maintenance_mode=i.maintenance_mode,
|
||||
register_username=i.register_username,
|
||||
stamp=i.stamp,
|
||||
)
|
||||
)
|
||||
if item is None:
|
||||
return res
|
||||
if not i:
|
||||
raise exceptions.rest.NotFound(f'Server not found: {item}')
|
||||
return res[0]
|
||||
except exceptions.rest.HandlerError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception('Error getting server')
|
||||
raise exceptions.rest.ResponseError(_('Error getting server')) from None
|
||||
return [self.as_server_item(i) for i in self.filter_odata_queryset(parent.servers.all())]
|
||||
|
||||
def get_item(self, parent: 'Model', item: str) -> ServerItem:
|
||||
parent = typing.cast('models.ServerGroup', parent) # We will receive for sure
|
||||
return self.as_server_item(parent.servers.get(uuid=process_uuid(item)))
|
||||
|
||||
def get_table(self, parent: 'Model') -> TableInfo:
|
||||
parent = ensure.is_instance(parent, models.ServerGroup)
|
||||
|
||||
@@ -155,23 +155,25 @@ class Services(DetailHandler[ServiceItem]): # pylint: disable=too-many-public-m
|
||||
ret_value.info = Services.service_info(item)
|
||||
|
||||
return ret_value
|
||||
|
||||
def get_item_position(self, parent: Model, item_uuid: str) -> int:
|
||||
parent = ensure.is_instance(parent, models.Provider)
|
||||
return self.calc_item_position(item_uuid, parent.services.all())
|
||||
|
||||
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.ItemsResult[ServiceItem]:
|
||||
def get_items(self, parent: 'Model') -> types.rest.ItemsResult[ServiceItem]:
|
||||
parent = ensure.is_instance(parent, models.Provider)
|
||||
# Check what kind of access do we have to parent provider
|
||||
perm = permissions.effective_permissions(self._user, parent)
|
||||
try:
|
||||
if item is None:
|
||||
return [Services.service_item(k, perm) for k in self.filter_queryset(parent.services.all())]
|
||||
k = parent.services.get(uuid=process_uuid(item))
|
||||
val = Services.service_item(k, perm, full=True)
|
||||
# On detail, ne wee to fill the instance fields by hand
|
||||
return val
|
||||
except models.Service.DoesNotExist:
|
||||
raise exceptions.rest.NotFound(_('Service not found')) from None
|
||||
except Exception as e:
|
||||
logger.error('Error getting services for %s: %s', parent, e)
|
||||
raise exceptions.rest.ResponseError(_('Error getting services')) from None
|
||||
return [Services.service_item(k, perm) for k in self.odata_filter(parent.services.all())]
|
||||
|
||||
def get_item(self, parent: 'Model', item: str) -> ServiceItem:
|
||||
parent = ensure.is_instance(parent, models.Provider)
|
||||
# Check what kind of access do we have to parent provider
|
||||
return Services.service_item(
|
||||
parent.services.get(uuid=process_uuid(item)),
|
||||
permissions.effective_permissions(self._user, parent),
|
||||
full=True,
|
||||
)
|
||||
|
||||
def _delete_incomplete_service(self, service: models.Service) -> None:
|
||||
"""
|
||||
|
||||
@@ -56,6 +56,9 @@ from .user_services import AssignedUserService, CachedService, Changelog, Groups
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from django.db.models import QuerySet
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ServicePoolItem(types.rest.BaseRestItem):
|
||||
@@ -158,6 +161,7 @@ class ServicesPools(ModelHandler[ServicePoolItem]):
|
||||
.text_column(name='parent', title=_('Parent service'))
|
||||
.text_column(name='tags', title=_('tags'), visible=False)
|
||||
.row_style(prefix='row-state-', field='state')
|
||||
.with_filter_fields('name', 'state')
|
||||
.build()
|
||||
)
|
||||
|
||||
@@ -175,13 +179,21 @@ class ServicesPools(ModelHandler[ServicePoolItem]):
|
||||
typed=types.rest.api.RestApiInfoGuiType.SINGLE_TYPE,
|
||||
)
|
||||
|
||||
def apply_sort(self, qs: 'QuerySet[typing.Any]') -> 'list[typing.Any] | QuerySet[typing.Any]':
|
||||
if field_info := self.get_sort_field_info('state'):
|
||||
field_name, is_descending = field_info
|
||||
order_by_field = f"-{field_name}" if is_descending else field_name
|
||||
return qs.order_by(order_by_field)
|
||||
|
||||
return super().apply_sort(qs)
|
||||
|
||||
def get_items(
|
||||
self, *args: typing.Any, **kwargs: typing.Any
|
||||
) -> typing.Generator[ServicePoolItem, None, None]:
|
||||
# Optimized query, due that there is a lot of info needed for theee
|
||||
d = sql_now() - datetime.timedelta(seconds=GlobalConfig.RESTRAINT_TIME.as_int())
|
||||
return super().get_items(
|
||||
overview=kwargs.get('overview', True),
|
||||
sumarize=kwargs.get('overview', True),
|
||||
query=(
|
||||
ServicePool.objects.prefetch_related(
|
||||
'service',
|
||||
|
||||
@@ -33,7 +33,6 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
import typing
|
||||
import datetime
|
||||
|
||||
from django.utils.translation import gettext as _
|
||||
@@ -110,32 +109,34 @@ class ServicesUsage(DetailHandler[ServicesUsageItem]):
|
||||
source_ip=item.src_ip,
|
||||
in_use=item.in_use,
|
||||
)
|
||||
|
||||
def get_item_position(self, parent: 'Model', item_uuid: str) -> int:
|
||||
parent = ensure.is_instance(parent, Provider)
|
||||
return self.calc_item_position(
|
||||
item_uuid,
|
||||
UserService.objects.filter(deployed_service__service__provider=parent).order_by('creation_date'),
|
||||
)
|
||||
|
||||
def get_items(
|
||||
self, parent: 'Model', item: typing.Optional[str]
|
||||
) -> types.rest.ItemsResult[ServicesUsageItem]:
|
||||
def get_items(self, parent: 'Model') -> types.rest.ItemsResult[ServicesUsageItem]:
|
||||
parent = ensure.is_instance(parent, Provider)
|
||||
try:
|
||||
if item is None:
|
||||
userservices_query = self.filter_queryset(
|
||||
UserService.objects.filter(deployed_service__service__provider=parent)
|
||||
)
|
||||
else:
|
||||
userservices_query = UserService.objects.filter(
|
||||
deployed_service__service_uuid=process_uuid(item)
|
||||
)
|
||||
|
||||
return [
|
||||
ServicesUsage.item_as_dict(k)
|
||||
for k in userservices_query.filter(state=State.USABLE)
|
||||
userservices = self.odata_filter(
|
||||
UserService.objects.filter(deployed_service__service__provider=parent)
|
||||
.order_by('creation_date')
|
||||
.prefetch_related('deployed_service', 'deployed_service__service', 'user', 'user__manager')
|
||||
]
|
||||
)
|
||||
return [ServicesUsage.item_as_dict(k) for k in userservices]
|
||||
|
||||
except Exception as e:
|
||||
logger.error('Error getting services usage for %s: %s', parent.uuid, e)
|
||||
raise exceptions.rest.ResponseError(_('Error getting services usage')) from None
|
||||
|
||||
def get_item(self, parent: 'Model', item: str) -> ServicesUsageItem:
|
||||
parent = ensure.is_instance(parent, Provider)
|
||||
return ServicesUsage.item_as_dict(
|
||||
UserService.objects.filter(deployed_service__service_uuid=process_uuid(item)).get()
|
||||
)
|
||||
|
||||
def get_table(self, parent: 'Model') -> types.rest.TableInfo:
|
||||
parent = ensure.is_instance(parent, Provider)
|
||||
return (
|
||||
|
||||
@@ -63,39 +63,35 @@ class TunnelServers(DetailHandler[TunnelServerItem]):
|
||||
REST_API_INFO = types.rest.api.RestApiInfo(
|
||||
name='TunnelServers', description='Tunnel servers assigned to a tunnel'
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def as_tunnel_server_item(item: models.Server) -> TunnelServerItem:
|
||||
return TunnelServerItem(
|
||||
id=item.uuid,
|
||||
hostname=item.hostname,
|
||||
ip=item.ip,
|
||||
mac=item.mac if item.mac != consts.NULL_MAC else '',
|
||||
maintenance=item.maintenance_mode,
|
||||
)
|
||||
|
||||
def get_item_position(self, parent: Model, item_uuid: str) -> int:
|
||||
parent = ensure.is_instance(parent, models.ServerGroup)
|
||||
return self.calc_item_position(item_uuid, parent.servers.all())
|
||||
|
||||
def get_items(
|
||||
self, parent: 'Model', item: typing.Optional[str]
|
||||
self, parent: 'Model'
|
||||
) -> types.rest.ItemsResult[TunnelServerItem]:
|
||||
parent = ensure.is_instance(parent, models.ServerGroup)
|
||||
try:
|
||||
multi = False
|
||||
if item is None:
|
||||
multi = True
|
||||
q = self.filter_queryset(parent.servers.all())
|
||||
else:
|
||||
q = parent.servers.filter(uuid=process_uuid(item))
|
||||
res: list[TunnelServerItem] = [
|
||||
TunnelServerItem(
|
||||
id=i.uuid,
|
||||
hostname=i.hostname,
|
||||
ip=i.ip,
|
||||
mac=i.mac if i.mac != consts.NULL_MAC else '',
|
||||
maintenance=i.maintenance_mode,
|
||||
)
|
||||
for i in q
|
||||
return [
|
||||
TunnelServers.as_tunnel_server_item(i)
|
||||
for i in self.odata_filter(parent.servers.all())
|
||||
]
|
||||
|
||||
if multi:
|
||||
return res
|
||||
if not res:
|
||||
raise exceptions.rest.NotFound(f'Tunnel server {item} not found')
|
||||
return res[0]
|
||||
except exceptions.rest.HandlerError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error('Error getting tunnel servers for %s: %s', parent, e)
|
||||
raise exceptions.rest.ResponseError(_('Error getting tunnel servers')) from e
|
||||
def get_item(
|
||||
self, parent: 'Model', item: str
|
||||
) -> TunnelServerItem:
|
||||
parent = ensure.is_instance(parent, models.ServerGroup)
|
||||
return TunnelServers.as_tunnel_server_item(parent.servers.get(uuid=process_uuid(item)))
|
||||
|
||||
def get_table(self, parent: 'Model') -> TableInfo:
|
||||
parent = ensure.is_instance(parent, models.ServerGroup)
|
||||
|
||||
@@ -37,7 +37,7 @@ import logging
|
||||
import typing
|
||||
|
||||
from django.utils.translation import gettext as _
|
||||
from django.db.models import Model
|
||||
from django.db.models import Model, QuerySet, OuterRef, Subquery
|
||||
|
||||
import uds.core.types.permissions
|
||||
from uds import models
|
||||
@@ -145,42 +145,90 @@ class AssignedUserService(DetailHandler[UserServiceItem]):
|
||||
|
||||
return val
|
||||
|
||||
def get_items(
|
||||
self, parent: 'Model', item: typing.Optional[str]
|
||||
def apply_sort(self, qs: QuerySet[typing.Any]) -> list[typing.Any] | QuerySet[typing.Any]:
|
||||
def annotated_sort(field: str, descending: bool) -> QuerySet[typing.Any]:
|
||||
prop_value_subquery = models.Properties.objects.filter(
|
||||
owner_id=OuterRef('uuid'), owner_type='userservice', key=field
|
||||
).values('value')[:1]
|
||||
return qs.annotate(prop_value=Subquery(prop_value_subquery)).order_by(
|
||||
f'{"-" if descending else ""}prop_value'
|
||||
)
|
||||
|
||||
if sort_info := self.get_sort_field_info('ip', 'actor_version'):
|
||||
return annotated_sort(*sort_info)
|
||||
first_order_by_field, is_descending = sort_info
|
||||
|
||||
return super().apply_sort(qs)
|
||||
|
||||
def get_qs(self, for_cached: bool) -> QuerySet[models.UserService]:
|
||||
parent = ensure.is_instance(self._parent, models.ServicePool)
|
||||
if for_cached:
|
||||
return parent.cached_users_services()
|
||||
return parent.assigned_user_services()
|
||||
|
||||
def do_get_item_position(self, for_cached: bool, parent: 'Model', item_uuid: str) -> int:
|
||||
parent = ensure.is_instance(parent, models.ServicePool)
|
||||
return self.calc_item_position(item_uuid, self.get_qs(for_cached).all())
|
||||
|
||||
def get_item_position(self, parent: Model, item_uuid: str) -> int:
|
||||
return self.do_get_item_position(for_cached=False, parent=parent, item_uuid=item_uuid)
|
||||
|
||||
def do_get_item(
|
||||
self,
|
||||
parent: 'Model',
|
||||
item: str,
|
||||
for_cached: bool,
|
||||
) -> 'UserServiceItem':
|
||||
parent = ensure.is_instance(parent, models.ServicePool)
|
||||
|
||||
return AssignedUserService.userservice_item(
|
||||
self.get_qs(for_cached).get(uuid=process_uuid(item)),
|
||||
props={
|
||||
k: v
|
||||
for k, v in models.Properties.objects.filter(
|
||||
owner_type='userservice', owner_id=process_uuid(item)
|
||||
).values_list('key', 'value')
|
||||
},
|
||||
is_cache=for_cached,
|
||||
)
|
||||
|
||||
def do_get_items(
|
||||
self,
|
||||
parent: 'Model',
|
||||
for_cached: bool,
|
||||
) -> types.rest.ItemsResult['UserServiceItem']:
|
||||
parent = ensure.is_instance(parent, models.ServicePool)
|
||||
|
||||
try:
|
||||
if not item:
|
||||
# First, fetch all properties for all assigned services on this pool
|
||||
# We can cache them, because they are going to be readed anyway...
|
||||
properties: dict[str, typing.Any] = collections.defaultdict(dict)
|
||||
for id, key, value in self.filter_queryset(
|
||||
models.Properties.objects.filter(
|
||||
owner_type='userservice',
|
||||
owner_id__in=parent.assigned_user_services().values_list('uuid', flat=True),
|
||||
)
|
||||
).values_list('owner_id', 'key', 'value'):
|
||||
properties[id][key] = value
|
||||
def get_qs() -> QuerySet[models.UserService]:
|
||||
if for_cached:
|
||||
return parent.cached_users_services()
|
||||
return parent.assigned_user_services()
|
||||
|
||||
return [
|
||||
AssignedUserService.userservice_item(k, properties.get(k.uuid, {}))
|
||||
for k in parent.assigned_user_services()
|
||||
.all()
|
||||
.prefetch_related('deployed_service', 'publication', 'user')
|
||||
]
|
||||
return AssignedUserService.userservice_item(
|
||||
parent.assigned_user_services().get(process_uuid(uuid=process_uuid(item))),
|
||||
props={
|
||||
k: v
|
||||
for k, v in models.Properties.objects.filter(
|
||||
owner_type='userservice', owner_id=process_uuid(item)
|
||||
).values_list('key', 'value')
|
||||
},
|
||||
# First, fetch all properties for all assigned services on this pool
|
||||
# We can cache them, because they are going to be readed anyway...
|
||||
properties: dict[str, typing.Any] = collections.defaultdict(dict)
|
||||
for id, key, value in models.Properties.objects.filter(
|
||||
owner_type='userservice',
|
||||
owner_id__in=get_qs().values_list('uuid', flat=True),
|
||||
).values_list('owner_id', 'key', 'value'):
|
||||
properties[id][key] = value
|
||||
|
||||
return [
|
||||
AssignedUserService.userservice_item(k, properties.get(k.uuid, {}))
|
||||
for k in self.odata_filter(
|
||||
get_qs().all().prefetch_related('deployed_service', 'publication', 'user')
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error('Error getting user service %s: %s', item, e)
|
||||
raise exceptions.rest.ResponseError(_('Error getting user service')) from e
|
||||
]
|
||||
|
||||
def get_items(self, parent: 'Model') -> types.rest.ItemsResult['UserServiceItem']:
|
||||
return self.do_get_items(parent, for_cached=False)
|
||||
|
||||
def get_item(
|
||||
self,
|
||||
parent: 'Model',
|
||||
item: str,
|
||||
) -> 'UserServiceItem':
|
||||
return self.do_get_item(parent, item, for_cached=False)
|
||||
|
||||
def get_table(self, parent: 'Model') -> types.rest.TableInfo:
|
||||
parent = ensure.is_instance(parent, models.ServicePool)
|
||||
@@ -202,6 +250,8 @@ class AssignedUserService(DetailHandler[UserServiceItem]):
|
||||
.text_column(name='owner', title=_('Owner'))
|
||||
.text_column(name='actor_version', title=_('Actor version'))
|
||||
.row_style(prefix='row-state-', field='state')
|
||||
.with_field_mappings(revision='deployed_service.publications.revision')
|
||||
.with_filter_fields('creation_date', 'unique_id', 'friendly_name', 'state', 'in_use')
|
||||
).build()
|
||||
|
||||
def get_logs(self, parent: 'Model', item: str) -> list[typing.Any]:
|
||||
@@ -295,27 +345,19 @@ class CachedService(AssignedUserService):
|
||||
"""
|
||||
|
||||
CUSTOM_METHODS = [] # Remove custom methods from assigned services
|
||||
|
||||
def get_item_position(self, parent: Model, item_uuid: str) -> int:
|
||||
return self.do_get_item_position(for_cached=True, parent=parent, item_uuid=item_uuid)
|
||||
|
||||
def get_items(
|
||||
self, parent: 'Model', item: typing.Optional[str]
|
||||
) -> types.rest.ItemsResult['UserServiceItem']:
|
||||
parent = ensure.is_instance(parent, models.ServicePool)
|
||||
def get_items(self, parent: 'Model') -> types.rest.ItemsResult['UserServiceItem']:
|
||||
return self.do_get_items(parent, for_cached=True)
|
||||
|
||||
try:
|
||||
if not item:
|
||||
return [
|
||||
AssignedUserService.userservice_item(k, is_cache=True)
|
||||
for k in self.filter_queryset(parent.cached_users_services().all()).prefetch_related(
|
||||
'deployed_service', 'publication'
|
||||
)
|
||||
]
|
||||
cached_userservice: models.UserService = parent.cached_users_services().get(uuid=process_uuid(item))
|
||||
return AssignedUserService.userservice_item(cached_userservice, is_cache=True)
|
||||
except models.UserService.DoesNotExist:
|
||||
raise exceptions.rest.NotFound(_('User service not found')) from None
|
||||
except Exception as e:
|
||||
logger.error('Error getting user service %s: %s', item, e)
|
||||
raise exceptions.rest.ResponseError(_('Error getting user service')) from e
|
||||
def get_item(
|
||||
self,
|
||||
parent: 'Model',
|
||||
item: str,
|
||||
) -> 'UserServiceItem':
|
||||
return self.do_get_item(parent, item, for_cached=True)
|
||||
|
||||
def get_table(self, parent: 'Model') -> types.rest.TableInfo:
|
||||
parent = ensure.is_instance(parent, models.ServicePool)
|
||||
@@ -327,6 +369,8 @@ class CachedService(AssignedUserService):
|
||||
.text_column(name='ip', title=_('IP'))
|
||||
.text_column(name='friendly_name', title=_('Friendly name'))
|
||||
.dict_column(name='state', title=_('State'), dct=State.literals_dict())
|
||||
.with_field_mappings(revision='deployed_service.publications.revision')
|
||||
.with_filter_fields('creation_date', 'unique_id', 'friendly_name', 'state')
|
||||
)
|
||||
if parent.state != State.LOCKED:
|
||||
table_info = table_info.text_column(name='cache_level', title=_('Cache level')).text_column(
|
||||
@@ -366,7 +410,7 @@ class Groups(DetailHandler[GroupItem]):
|
||||
Processes the groups detail requests of a Service Pool
|
||||
"""
|
||||
|
||||
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> list['GroupItem']:
|
||||
def get_items(self, parent: 'Model') -> types.rest.ItemsResult['GroupItem']:
|
||||
parent = typing.cast(typing.Union['models.ServicePool', 'models.MetaPool'], parent)
|
||||
|
||||
return [
|
||||
@@ -381,10 +425,13 @@ class Groups(DetailHandler[GroupItem]):
|
||||
auth_name=group.manager.name,
|
||||
)
|
||||
for group in typing.cast(
|
||||
collections.abc.Iterable[models.Group], self.filter_queryset(parent.assignedGroups.all())
|
||||
collections.abc.Iterable[models.Group], self.filter_odata_queryset(parent.assignedGroups.all())
|
||||
)
|
||||
]
|
||||
|
||||
def get_item(self, parent: Model, item: str) -> GroupItem:
|
||||
raise exceptions.rest.NotSupportedError('Single group retrieval not implemented inside assigned groups')
|
||||
|
||||
def get_table(self, parent: 'Model') -> TableInfo:
|
||||
parent = typing.cast(typing.Union['models.ServicePool', 'models.MetaPool'], parent)
|
||||
return (
|
||||
@@ -437,7 +484,7 @@ class Transports(DetailHandler[TransportItem]):
|
||||
Processes the transports detail requests of a Service Pool
|
||||
"""
|
||||
|
||||
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> list['TransportItem']:
|
||||
def get_items(self, parent: 'Model') -> types.rest.ItemsResult['TransportItem']:
|
||||
parent = ensure.is_instance(parent, models.ServicePool)
|
||||
|
||||
return [
|
||||
@@ -449,9 +496,14 @@ class Transports(DetailHandler[TransportItem]):
|
||||
priority=trans.priority,
|
||||
trans_type=trans.get_type().mod_name(),
|
||||
)
|
||||
for trans in self.filter_queryset(parent.transports.all())
|
||||
for trans in self.filter_odata_queryset(parent.transports.all())
|
||||
]
|
||||
|
||||
def get_item(self, parent: 'Model', item: str) -> TransportItem:
|
||||
raise exceptions.rest.NotSupportedError(
|
||||
'Single transport retrieval not implemented inside assigned transports'
|
||||
)
|
||||
|
||||
def get_table(self, parent: 'Model') -> TableInfo:
|
||||
parent = ensure.is_instance(parent, models.ServicePool)
|
||||
return (
|
||||
@@ -562,7 +614,7 @@ class Publications(DetailHandler[PublicationItem]):
|
||||
|
||||
return self.success()
|
||||
|
||||
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> list['PublicationItem']:
|
||||
def get_items(self, parent: 'Model') -> types.rest.ItemsResult['PublicationItem']:
|
||||
parent = ensure.is_instance(parent, models.ServicePool)
|
||||
return [
|
||||
PublicationItem(
|
||||
@@ -573,9 +625,14 @@ class Publications(DetailHandler[PublicationItem]):
|
||||
reason=State.from_str(i.state).is_errored() and i.get_instance().error_reason() or '',
|
||||
state_date=i.state_date,
|
||||
)
|
||||
for i in self.filter_queryset(parent.publications.all())
|
||||
for i in self.filter_odata_queryset(parent.publications.all())
|
||||
]
|
||||
|
||||
def get_item(self, parent: 'Model', item: str) -> PublicationItem:
|
||||
raise exceptions.rest.NotSupportedError(
|
||||
'Single publication retrieval not implemented inside assigned publications'
|
||||
)
|
||||
|
||||
def get_table(self, parent: 'Model') -> TableInfo:
|
||||
parent = ensure.is_instance(parent, models.ServicePool)
|
||||
return (
|
||||
@@ -600,7 +657,7 @@ class Changelog(DetailHandler[ChangelogItem]):
|
||||
Processes the transports detail requests of a Service Pool
|
||||
"""
|
||||
|
||||
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> list['ChangelogItem']:
|
||||
def get_items(self, parent: 'Model') -> types.rest.ItemsResult['ChangelogItem']:
|
||||
parent = ensure.is_instance(parent, models.ServicePool)
|
||||
return [
|
||||
ChangelogItem(
|
||||
@@ -608,9 +665,12 @@ class Changelog(DetailHandler[ChangelogItem]):
|
||||
stamp=i.stamp,
|
||||
log=i.log,
|
||||
)
|
||||
for i in self.filter_queryset(parent.changelog.all())
|
||||
for i in self.filter_odata_queryset(parent.changelog.all())
|
||||
]
|
||||
|
||||
def get_item(self, parent: 'Model', item: str) -> ChangelogItem:
|
||||
raise exceptions.rest.NotSupportedError('Single changelog retrieval not implemented inside changelog')
|
||||
|
||||
def get_table(self, parent: 'Model') -> types.rest.TableInfo:
|
||||
parent = ensure.is_instance(parent, models.ServicePool)
|
||||
return (
|
||||
|
||||
@@ -53,6 +53,8 @@ from uds.REST.model import DetailHandler
|
||||
|
||||
from .user_services import AssignedUserService, UserServiceItem
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from django.db.models.query import QuerySet
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -100,41 +102,49 @@ class Users(DetailHandler[UserItem]):
|
||||
'enable_client_logging',
|
||||
]
|
||||
|
||||
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.ItemsResult[UserItem]:
|
||||
@staticmethod
|
||||
def as_user_item(user: 'User') -> UserItem:
|
||||
return UserItem(
|
||||
id=user.uuid,
|
||||
name=user.name,
|
||||
real_name=user.real_name,
|
||||
comments=user.comments,
|
||||
state=user.state,
|
||||
staff_member=user.staff_member,
|
||||
is_admin=user.is_admin,
|
||||
last_access=user.last_access,
|
||||
mfa_data=user.mfa_data,
|
||||
parent=user.parent,
|
||||
groups=[i.uuid for i in user.get_groups()],
|
||||
role=user.get_role().as_str(),
|
||||
)
|
||||
|
||||
def apply_sort(self, qs: 'QuerySet[typing.Any]') -> 'list[typing.Any] | QuerySet[typing.Any]':
|
||||
if field_info := self.get_sort_field_info('role'):
|
||||
descending = '-' if field_info[1] else ''
|
||||
return qs.order_by(f'{descending}is_admin', f'{descending}staff_member')
|
||||
|
||||
return super().apply_sort(qs)
|
||||
|
||||
def get_item_position(self, parent: 'Model', item_uuid: str) -> int:
|
||||
parent = ensure.is_instance(parent, Authenticator)
|
||||
|
||||
def as_user_item(user: 'User') -> UserItem:
|
||||
return UserItem(
|
||||
id=user.uuid,
|
||||
name=user.name,
|
||||
real_name=user.real_name,
|
||||
comments=user.comments,
|
||||
state=user.state,
|
||||
staff_member=user.staff_member,
|
||||
is_admin=user.is_admin,
|
||||
last_access=user.last_access,
|
||||
mfa_data=user.mfa_data,
|
||||
parent=user.parent,
|
||||
groups=[i.uuid for i in user.get_groups()],
|
||||
role=user.get_role().as_str(),
|
||||
)
|
||||
return self.calc_item_position(item_uuid, parent.users.all())
|
||||
|
||||
def get_items(self, parent: 'Model') -> types.rest.ItemsResult[UserItem]:
|
||||
parent = ensure.is_instance(parent, Authenticator)
|
||||
|
||||
# Extract authenticator
|
||||
try:
|
||||
if item is None: # All users
|
||||
return [as_user_item(i) for i in self.filter_queryset(parent.users.all())]
|
||||
return [self.as_user_item(i) for i in self.odata_filter(parent.users.all())]
|
||||
|
||||
u = parent.users.get(uuid__iexact=process_uuid(item))
|
||||
res = as_user_item(u)
|
||||
usr = AUser(u)
|
||||
res.groups = [g.db_obj().uuid for g in usr.groups()]
|
||||
logger.debug('Item: %s', res)
|
||||
return res
|
||||
except User.DoesNotExist:
|
||||
raise exceptions.rest.NotFound(_('User not found')) from None
|
||||
except Exception as e:
|
||||
logger.error('Error getting user %s: %s', item, e)
|
||||
raise exceptions.rest.ResponseError(_('Error getting user')) from e
|
||||
def get_item(self, parent: 'Model', item: str) -> UserItem:
|
||||
parent = ensure.is_instance(parent, Authenticator)
|
||||
|
||||
db_usr = parent.users.get(uuid__iexact=process_uuid(item))
|
||||
user_item = self.as_user_item(db_usr)
|
||||
auth_usr = AUser(db_usr)
|
||||
user_item.groups = [g.db_obj().uuid for g in auth_usr.groups()]
|
||||
return user_item
|
||||
|
||||
def get_table(self, parent: 'Model') -> types.rest.TableInfo:
|
||||
parent = ensure.is_instance(parent, Authenticator)
|
||||
@@ -350,44 +360,38 @@ class GroupItem(types.rest.BaseRestItem):
|
||||
class Groups(DetailHandler[GroupItem]):
|
||||
CUSTOM_METHODS = ['services_pools', 'users']
|
||||
|
||||
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.ItemsResult['GroupItem']:
|
||||
@staticmethod
|
||||
def as_group_item(group: 'Group') -> GroupItem:
|
||||
val = GroupItem(
|
||||
id=group.uuid,
|
||||
name=group.name,
|
||||
comments=group.comments,
|
||||
state=group.state,
|
||||
type=group.is_meta and 'meta' or 'group',
|
||||
meta_if_any=group.meta_if_any,
|
||||
skip_mfa=group.skip_mfa,
|
||||
)
|
||||
if group.is_meta:
|
||||
val.groups = list(x.uuid for x in group.groups.all().order_by('name'))
|
||||
return val
|
||||
|
||||
def get_item_position(self, parent: 'Model', item_uuid: str) -> int:
|
||||
parent = ensure.is_instance(parent, Authenticator)
|
||||
try:
|
||||
multi = False
|
||||
if item is None:
|
||||
multi = True
|
||||
q = self.filter_queryset(parent.groups.all())
|
||||
else:
|
||||
q = parent.groups.filter(uuid=process_uuid(item))
|
||||
res: list[GroupItem] = []
|
||||
i = None
|
||||
for i in q:
|
||||
val = GroupItem(
|
||||
id=i.uuid,
|
||||
name=i.name,
|
||||
comments=i.comments,
|
||||
state=i.state,
|
||||
type=i.is_meta and 'meta' or 'group',
|
||||
meta_if_any=i.meta_if_any,
|
||||
skip_mfa=i.skip_mfa,
|
||||
)
|
||||
if i.is_meta:
|
||||
val.groups = list(x.uuid for x in i.groups.all().order_by('name'))
|
||||
res.append(val)
|
||||
return self.calc_item_position(item_uuid, parent.groups.all())
|
||||
|
||||
if multi:
|
||||
return res
|
||||
def get_items(self, parent: 'Model') -> types.rest.ItemsResult['GroupItem']:
|
||||
parent = ensure.is_instance(parent, Authenticator)
|
||||
q = self.odata_filter(parent.groups.all())
|
||||
return [self.as_group_item(i) for i in q]
|
||||
|
||||
if not i:
|
||||
raise exceptions.rest.NotFound(_('Group not found')) from None
|
||||
# Add pools field if 1 item only
|
||||
res[0].pools = [v.uuid for v in get_service_pools_for_groups([i])]
|
||||
return res[0]
|
||||
except exceptions.rest.HandlerError:
|
||||
raise # Re-raise
|
||||
except Exception as e:
|
||||
logger.error('Group item not found: %s.%s: %s', parent.name, item, e)
|
||||
raise exceptions.rest.ResponseError(_('Error getting group')) from e
|
||||
def get_item(self, parent: 'Model', item: str) -> 'GroupItem':
|
||||
parent = ensure.is_instance(parent, Authenticator)
|
||||
db_grp = parent.groups.filter(uuid=process_uuid(item)).first()
|
||||
if not db_grp:
|
||||
raise exceptions.rest.NotFound(_('Group not found')) from None
|
||||
grp = self.as_group_item(db_grp)
|
||||
grp.pools = [v.uuid for v in get_service_pools_for_groups([db_grp])]
|
||||
return grp
|
||||
|
||||
def get_table(self, parent: 'Model') -> types.rest.TableInfo:
|
||||
parent = ensure.is_instance(parent, Authenticator)
|
||||
|
||||
@@ -51,6 +51,9 @@ from ..handlers import Handler
|
||||
if typing.TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
T = typing.TypeVar('T', bound=models.Model)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -132,6 +135,22 @@ class BaseModelHandler(Handler, abc.ABC, typing.Generic[types.rest.T_Item]):
|
||||
|
||||
return args
|
||||
|
||||
def odata_filter(self, qs: models.QuerySet[T]) -> list[T]:
|
||||
"""
|
||||
Invoked to filter the queryset according to parameters received
|
||||
Default implementation does not filter anything
|
||||
|
||||
Args:
|
||||
qs: Queryset to filter
|
||||
|
||||
Returns:
|
||||
Filtered queryset as a list
|
||||
|
||||
Note:
|
||||
This is not final, so we can override it in subclasses if needed
|
||||
"""
|
||||
return self.filter_odata_queryset(qs)
|
||||
|
||||
# Success methods
|
||||
def success(self) -> str:
|
||||
"""
|
||||
|
||||
@@ -34,30 +34,37 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
import logging
|
||||
import typing
|
||||
import collections.abc
|
||||
import abc
|
||||
|
||||
from django.db import models
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
from uds.core import consts, exceptions, types, module
|
||||
from uds.core.util.model import process_uuid
|
||||
from uds.core.util import api as api_utils
|
||||
from uds.core.util import api as api_utils, model as model_utils
|
||||
from uds.REST.utils import rest_result
|
||||
|
||||
from uds.REST.model.base import BaseModelHandler
|
||||
from uds.REST.utils import camel_and_snake_case_from
|
||||
|
||||
T = typing.TypeVar('T', bound=models.Model)
|
||||
|
||||
# Not imported at runtime, just for type checking
|
||||
if typing.TYPE_CHECKING:
|
||||
from django.db.models.query import QuerySet
|
||||
|
||||
from uds.models import User
|
||||
from uds.REST.model.master import ModelHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Details do not have types at all
|
||||
# so, right now, we only process details petitions for Handling & tables info
|
||||
# noinspection PyMissingConstructor
|
||||
class DetailHandler(BaseModelHandler[types.rest.T_Item]):
|
||||
|
||||
|
||||
class DetailHandler(BaseModelHandler[types.rest.T_Item], abc.ABC):
|
||||
"""
|
||||
Detail handler (for relations such as provider-->services, authenticators-->users,groups, deployed services-->cache,assigned, groups, transports
|
||||
Urls recognized for GET are:
|
||||
@@ -138,21 +145,19 @@ class DetailHandler(BaseModelHandler[types.rest.T_Item]):
|
||||
"""
|
||||
# Process args
|
||||
logger.debug('Detail args for GET: %s', self._args)
|
||||
num_args = len(self._args)
|
||||
|
||||
parent: models.Model = self._parent_item
|
||||
|
||||
if num_args == 0:
|
||||
return self.get_items(parent, None)
|
||||
|
||||
# if has custom methods, look for if this request matches any of them
|
||||
r = self._check_is_custom_method(self._args[0], parent)
|
||||
if r is not consts.rest.NOT_FOUND:
|
||||
return r
|
||||
|
||||
match self._args:
|
||||
case []: # same as overview
|
||||
return self.get_items(parent)
|
||||
case [consts.rest.OVERVIEW]:
|
||||
return self.get_items(parent, None)
|
||||
return self.get_items(parent)
|
||||
case [consts.rest.OVERVIEW, *_fails]:
|
||||
raise exceptions.rest.RequestError('Invalid overview request') from None
|
||||
case [consts.rest.TYPES]:
|
||||
@@ -177,8 +182,10 @@ class DetailHandler(BaseModelHandler[types.rest.T_Item]):
|
||||
return self.get_logs(parent, item_id)
|
||||
case [consts.rest.LOG, *_fails]:
|
||||
raise exceptions.rest.RequestError('Invalid log request') from None
|
||||
case [consts.rest.POSITION, item_uuid]:
|
||||
return self.get_item_position(parent, item_uuid)
|
||||
case [one_arg]:
|
||||
return self.get_items(parent, process_uuid(one_arg))
|
||||
return self.get_item(parent, process_uuid(one_arg))
|
||||
case _:
|
||||
# Maybe a custom method?
|
||||
r = self._check_is_custom_method(self._args[1], parent, self._args[0])
|
||||
@@ -247,9 +254,8 @@ class DetailHandler(BaseModelHandler[types.rest.T_Item]):
|
||||
|
||||
# Override this to provide functionality
|
||||
# Default (as sample) get_items
|
||||
def get_items(
|
||||
self, parent: models.Model, item: typing.Optional[str]
|
||||
) -> types.rest.ItemsResult[types.rest.T_Item]:
|
||||
@abc.abstractmethod
|
||||
def get_items(self, parent: models.Model) -> types.rest.ItemsResult[types.rest.T_Item]:
|
||||
"""
|
||||
This MUST be overridden by derived classes
|
||||
Excepts to return a list of dictionaries or a single dictionary, depending on "item" param
|
||||
@@ -261,6 +267,16 @@ class DetailHandler(BaseModelHandler[types.rest.T_Item]):
|
||||
# return {} # Returns one item
|
||||
raise NotImplementedError(f'Must provide an get_items method for {self.__class__} class')
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_item(self, parent: models.Model, item: str) -> types.rest.T_Item:
|
||||
"""
|
||||
Utility method to get a single item by uuid
|
||||
:param parent: Parent model
|
||||
:param item: Item uuid
|
||||
:return: Item as dictionary
|
||||
"""
|
||||
raise NotImplementedError(f'Must provide an get_item method for {self.__class__} class')
|
||||
|
||||
# Default save
|
||||
def save_item(self, parent: models.Model, item: typing.Optional[str]) -> types.rest.T_Item:
|
||||
"""
|
||||
@@ -328,6 +344,51 @@ class DetailHandler(BaseModelHandler[types.rest.T_Item]):
|
||||
"""
|
||||
return [] # Default is that details do not have types
|
||||
|
||||
def get_logs(self, parent: models.Model, item: str) -> list[typing.Any]:
|
||||
"""
|
||||
If the detail has any log associated with it items, provide it overriding this method
|
||||
|
||||
Args:
|
||||
parent: Parent model
|
||||
item: Item id (uuid)
|
||||
|
||||
Returns:
|
||||
A list of log elements (normally got using "uds.core.util.log.get_logs" method)
|
||||
"""
|
||||
raise exceptions.rest.InvalidMethodError('Object does not support logs')
|
||||
|
||||
def calc_item_position(self, item_uuid: str, qs: 'QuerySet[T]') -> int:
|
||||
"""
|
||||
Helper method to get the position of an item in a queryset
|
||||
|
||||
Args:
|
||||
item_uuid (str): UUID of the item to find
|
||||
qs (QuerySet[T]): Queryset to search into
|
||||
|
||||
Returns:
|
||||
int: Position of the item in the default ordering, -1 if not found
|
||||
"""
|
||||
# Find item in qs, may be none, then return -1
|
||||
obj = qs.filter(uuid__iexact=item_uuid).first()
|
||||
if obj:
|
||||
return model_utils.get_position_in_queryset(obj, qs)
|
||||
return -1
|
||||
|
||||
|
||||
def get_item_position(self, parent: models.Model, item_uuid: str) -> int:
|
||||
"""
|
||||
Tries to get the position of an item in the default ordering of the detail items
|
||||
|
||||
Args:
|
||||
item_uuid (str): UUID of the item to find
|
||||
Returns:
|
||||
int: Position of the item in the default ordering, -1 if not found
|
||||
|
||||
Note:
|
||||
Override this method if the detail can provide item position
|
||||
"""
|
||||
return -1
|
||||
|
||||
@classmethod
|
||||
def possible_types(cls: type[typing.Self]) -> collections.abc.Iterable[type[module.Module]]:
|
||||
"""
|
||||
@@ -337,15 +398,6 @@ class DetailHandler(BaseModelHandler[types.rest.T_Item]):
|
||||
"""
|
||||
return []
|
||||
|
||||
def get_logs(self, parent: models.Model, item: str) -> list[typing.Any]:
|
||||
"""
|
||||
If the detail has any log associated with it items, provide it overriding this method
|
||||
:param parent:
|
||||
:param item:
|
||||
:return: a list of log elements (normally got using "uds.core.util.log.get_logs" method)
|
||||
"""
|
||||
raise exceptions.rest.InvalidMethodError('Object does not support logs')
|
||||
|
||||
@classmethod
|
||||
def api_components(cls: type[typing.Self]) -> types.rest.api.Components:
|
||||
"""
|
||||
@@ -355,10 +407,12 @@ class DetailHandler(BaseModelHandler[types.rest.T_Item]):
|
||||
return api_utils.get_component_from_type(cls)
|
||||
|
||||
@classmethod
|
||||
def api_paths(cls: type[typing.Self], path: str, tags: list[str], security: str) -> dict[str, types.rest.api.PathItem]:
|
||||
def api_paths(
|
||||
cls: type[typing.Self], path: str, tags: list[str], security: str
|
||||
) -> dict[str, types.rest.api.PathItem]:
|
||||
"""
|
||||
Returns the API operations that should be registered
|
||||
"""
|
||||
from .api_helpers import api_paths
|
||||
|
||||
|
||||
return api_paths(cls, path, tags=tags, security=security)
|
||||
|
||||
@@ -44,7 +44,7 @@ from uds.core import consts
|
||||
from uds.core import exceptions
|
||||
from uds.core import types
|
||||
from uds.core.module import Module
|
||||
from uds.core.util import log, permissions, api as api_utils
|
||||
from uds.core.util import log, permissions, model as model_utils, api as api_utils
|
||||
from uds.models import ManagedObjectModel, Tag, TaggingMixin
|
||||
|
||||
from uds.REST.model.base import BaseModelHandler
|
||||
@@ -198,15 +198,10 @@ class ModelHandler(BaseModelHandler[types.rest.T_Item], abc.ABC):
|
||||
method = getattr(detail_handler, self._operation)
|
||||
|
||||
return method()
|
||||
except self.MODEL.DoesNotExist:
|
||||
raise exceptions.rest.NotFound('Item not found on model {self.MODEL.__name__}')
|
||||
except (KeyError, AttributeError) as e:
|
||||
raise exceptions.rest.InvalidMethodError(f'Invalid method {self._operation}') from e
|
||||
except exceptions.rest.HandlerError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error('Exception processing detail: %s', e)
|
||||
raise exceptions.rest.RequestError(f'Error processing detail: {e}') from e
|
||||
|
||||
# Data related
|
||||
def get_item(self, item: models.Model) -> types.rest.T_Item:
|
||||
@@ -222,30 +217,44 @@ class ModelHandler(BaseModelHandler[types.rest.T_Item], abc.ABC):
|
||||
default behavior is return item_as_dict
|
||||
"""
|
||||
return self.get_item(item)
|
||||
|
||||
def filter_model_queryset(self, qs: QuerySet[T]|None = None) -> QuerySet[T]:
|
||||
qs = typing.cast('QuerySet[T]', self.MODEL.objects.all()) if qs is None else qs
|
||||
|
||||
if self.FILTER is not None:
|
||||
qs = qs.filter(**self.FILTER)
|
||||
if self.EXCLUDE is not None:
|
||||
qs = qs.exclude(**self.EXCLUDE)
|
||||
|
||||
return qs
|
||||
|
||||
def get_item_position(self, item_uuid: str, query: QuerySet[T] | None = None) -> int:
|
||||
qs = self.filter_model_queryset(query)
|
||||
|
||||
# Find item in qs, may be none, then return -1
|
||||
obj = qs.filter(uuid__iexact=item_uuid).first()
|
||||
if obj:
|
||||
return model_utils.get_position_in_queryset(obj, qs)
|
||||
return -1
|
||||
|
||||
|
||||
def get_items(
|
||||
self, *, overview: bool = False, query: QuerySet[T] | None = None
|
||||
self, *, sumarize: bool = False, query: QuerySet[T] | None = None
|
||||
) -> typing.Generator[types.rest.T_Item, None, None]:
|
||||
"""
|
||||
Get items from the model.
|
||||
Args:
|
||||
overview: If True, return a summary of the items.
|
||||
sumarize: If True, return a summary of the items.
|
||||
query: Optional queryset to filter the items. Used to optimize the process for some models
|
||||
(such as ServicePools)
|
||||
|
||||
"""
|
||||
|
||||
# Basic model filter
|
||||
if query:
|
||||
qs = query
|
||||
else:
|
||||
qs = self.MODEL.objects.all()
|
||||
if self.FILTER is not None:
|
||||
qs = qs.filter(**self.FILTER)
|
||||
if self.EXCLUDE is not None:
|
||||
qs = qs.exclude(**self.EXCLUDE)
|
||||
qs = self.filter_model_queryset(query)
|
||||
|
||||
qs = self.filter_queryset(qs)
|
||||
# Custom filtering from params (odata, etc)
|
||||
qs = self.odata_filter(qs)
|
||||
|
||||
for item in qs:
|
||||
try:
|
||||
@@ -259,7 +268,7 @@ class ModelHandler(BaseModelHandler[types.rest.T_Item], abc.ABC):
|
||||
is False
|
||||
):
|
||||
continue
|
||||
yield self.get_item_summary(item) if overview else self.get_item(item)
|
||||
yield self.get_item_summary(item) if sumarize else self.get_item(item)
|
||||
except Exception as e: # maybe an exception is thrown to skip an item
|
||||
logger.debug('Got exception processing item from model: %s', e)
|
||||
# logger.exception('Exception getting item from {0}'.format(self.model))
|
||||
@@ -268,9 +277,6 @@ class ModelHandler(BaseModelHandler[types.rest.T_Item], abc.ABC):
|
||||
logger.debug('method GET for %s, %s', self.__class__.__name__, self._args)
|
||||
number_of_args = len(self._args)
|
||||
|
||||
if number_of_args == 0:
|
||||
return list(self.get_items(overview=False))
|
||||
|
||||
# if has custom methods, look for if this request matches any of them
|
||||
for cm in self.CUSTOM_METHODS:
|
||||
# Convert to snake case
|
||||
@@ -309,7 +315,7 @@ class ModelHandler(BaseModelHandler[types.rest.T_Item], abc.ABC):
|
||||
|
||||
match self._args:
|
||||
case []: # Same as overview, but with all data
|
||||
return [i.as_dict() for i in self.get_items(overview=False)]
|
||||
return [i.as_dict() for i in self.get_items(sumarize=False)]
|
||||
case [consts.rest.OVERVIEW]:
|
||||
return [i.as_dict() for i in self.get_items()]
|
||||
case [consts.rest.OVERVIEW, *_fails]:
|
||||
@@ -330,6 +336,8 @@ class ModelHandler(BaseModelHandler[types.rest.T_Item], abc.ABC):
|
||||
return self.get_processed_gui(for_type)
|
||||
case [consts.rest.GUI, for_type, *_fails]:
|
||||
raise exceptions.rest.RequestError('Invalid GUI request') from None
|
||||
case [consts.rest.POSITION, item_uuid]:
|
||||
return self.get_item_position(item_uuid)
|
||||
case _: # Maybe an item or a detail
|
||||
if number_of_args == 1:
|
||||
try:
|
||||
|
||||
@@ -38,6 +38,7 @@ TYPES: typing.Final[str] = 'types'
|
||||
TABLEINFO: typing.Final[str] = 'tableinfo'
|
||||
GUI: typing.Final[str] = 'gui'
|
||||
LOG: typing.Final[str] = 'log'
|
||||
POSITION: typing.Final[str] = 'position'
|
||||
|
||||
SYSTEM: typing.Final[str] = 'system' # Defined on system class, here for reference
|
||||
|
||||
|
||||
@@ -96,7 +96,7 @@ class NotificationsManager(metaclass=singleton.Singleton):
|
||||
message = message % args
|
||||
except Exception:
|
||||
message = message + ' ' + str(args) + ' (format error)'
|
||||
message = message[:4096] # Max length of message
|
||||
message = message[:4000] # Max length of message, fixed to ensure it also supports sqlserver
|
||||
# Store the notification on local persistent storage
|
||||
# Will be processed by UDS backend
|
||||
try:
|
||||
|
||||
@@ -315,10 +315,14 @@ class Transport(Module):
|
||||
osname: str,
|
||||
type: typing.Literal['tunnel', 'direct'],
|
||||
params: collections.abc.Mapping[str, typing.Any],
|
||||
client_version: str | None = None,
|
||||
) -> types.transports.TransportScript:
|
||||
"""
|
||||
Returns a script for the given os and type
|
||||
"""
|
||||
if (client_version or '0.0') >= '5.0.0':
|
||||
return self.get_relative_script(f'scripts/{osname.lower()}/{type}.js', params)
|
||||
|
||||
return self.get_relative_script(f'scripts/{osname.lower()}/{type}.py', params)
|
||||
|
||||
def get_link(
|
||||
|
||||
@@ -162,7 +162,7 @@ class ManagedObjectItem(BaseRestItem, typing.Generic[T_Model]):
|
||||
|
||||
|
||||
# Alias for get_items return type
|
||||
ItemsResult: typing.TypeAlias = list[T_Item] | BaseRestItem | typing.Iterator[T_Item]
|
||||
ItemsResult: typing.TypeAlias = list[T_Item] | typing.Iterator[T_Item]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@@ -279,6 +279,8 @@ class TableInfo:
|
||||
fields: list[TableField] # List of fields in the table
|
||||
row_style: 'RowStyleInfo'
|
||||
subtitle: typing.Optional[str] = None
|
||||
filter_fields: list[str] = dataclasses.field(default_factory=list[str])
|
||||
field_mappings: dict[str, str] = dataclasses.field(default_factory=dict[str, str])
|
||||
|
||||
def as_dict(self) -> dict[str, typing.Any]:
|
||||
return {
|
||||
@@ -286,6 +288,8 @@ class TableInfo:
|
||||
'fields': [field.as_dict() for field in self.fields],
|
||||
'row_style': self.row_style.as_dict(),
|
||||
'subtitle': self.subtitle or '',
|
||||
'filter_fields': self.filter_fields,
|
||||
'field_mappings': self.field_mappings,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -31,18 +31,20 @@ def _as_dict_without_none(v: typing.Any) -> typing.Any:
|
||||
# (handler, model, detail, etc.)
|
||||
# So we can override names or whatever we need
|
||||
|
||||
|
||||
# Types of GUI info that can be provided
|
||||
class RestApiInfoGuiType(enum.Enum):
|
||||
SINGLE_TYPE = 0
|
||||
MULTIPLE_TYPES = 1
|
||||
UNTYPED = 3
|
||||
|
||||
|
||||
def is_single_type(self) -> bool:
|
||||
return self == RestApiInfoGuiType.SINGLE_TYPE
|
||||
|
||||
|
||||
def supports_multiple_types(self) -> bool:
|
||||
return self == RestApiInfoGuiType.MULTIPLE_TYPES
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RestApiInfo:
|
||||
|
||||
@@ -429,7 +431,9 @@ class ODataParams:
|
||||
filter=data.get('$filter'),
|
||||
start=start,
|
||||
limit=limit,
|
||||
orderby=order_by,
|
||||
orderby=[
|
||||
o.replace('.', '__') for o in order_by
|
||||
], # Allow order by related fields with dot or __
|
||||
select=select,
|
||||
)
|
||||
except (ValueError, TypeError):
|
||||
|
||||
@@ -181,3 +181,18 @@ def get_my_ip_from_db() -> str:
|
||||
logger.error('Error getting my IP: %s', e)
|
||||
|
||||
return '0.0.0.0'
|
||||
|
||||
|
||||
# Tries to get the position of an object in the default queryset ordering
|
||||
def get_position_in_queryset(obj: typing.Any, queryset: typing.Any) -> int:
|
||||
"""
|
||||
Tries to get the position of an object in the default queryset ordering
|
||||
:param obj: Object to find
|
||||
:param queryset: Queryset to search in
|
||||
:return: Position in the queryset (0 based). -1 if not found
|
||||
"""
|
||||
try:
|
||||
lst = list(queryset.values_list('pk', flat=True))
|
||||
return lst.index(obj.pk)
|
||||
except ValueError:
|
||||
return -1
|
||||
@@ -176,7 +176,7 @@ class DjangoQueryTransformer(lark.Transformer[typing.Any, Q | AnnotatedField]):
|
||||
if func_name not in _FUNCTIONS_PARAMS_NUM:
|
||||
raise ValueError(f"Unknown function: {func_name}")
|
||||
|
||||
if func_name in ('substringof', 'startswith', 'endswith'):
|
||||
if func_name in ('substringof', 'contains', 'startswith', 'endswith'):
|
||||
if len(func_args) != 2:
|
||||
raise ValueError(f"{func_name} requires 2 arguments")
|
||||
field, value = func_args
|
||||
@@ -186,6 +186,8 @@ class DjangoQueryTransformer(lark.Transformer[typing.Any, Q | AnnotatedField]):
|
||||
raise ValueError(f"Function '{func_name}' does not support field-to-field comparison")
|
||||
match func_name:
|
||||
case 'substringof':
|
||||
return Q(**{f"{value}__icontains": field}) # Note the order swap
|
||||
case 'contains':
|
||||
return Q(**{f"{field}__icontains": value})
|
||||
case 'startswith':
|
||||
return Q(**{f"{field}__istartswith": value})
|
||||
|
||||
@@ -399,6 +399,8 @@ class TableBuilder:
|
||||
subtitle: str | None
|
||||
fields: list[types.rest.TableField]
|
||||
style_info: types.rest.RowStyleInfo
|
||||
filter_fields: list[str]
|
||||
field_mappings: dict[str, str]
|
||||
|
||||
def __init__(self, title: str, subtitle: str | None = None) -> None:
|
||||
# TODO: USe table_name on a later iteration of the code
|
||||
@@ -406,6 +408,8 @@ class TableBuilder:
|
||||
self.subtitle = subtitle
|
||||
self.fields = []
|
||||
self.style_info = types.rest.RowStyleInfo.null()
|
||||
self.filter_fields = []
|
||||
self.field_mappings = {}
|
||||
|
||||
def _add_field(
|
||||
self,
|
||||
@@ -513,6 +517,20 @@ class TableBuilder:
|
||||
self.style_info = types.rest.RowStyleInfo(prefix=prefix, field=field)
|
||||
return self
|
||||
|
||||
def with_filter_fields(self, *fields: str) -> typing.Self:
|
||||
"""
|
||||
Sets the sorting fields for the table fields.
|
||||
"""
|
||||
self.filter_fields = list(fields)
|
||||
return self
|
||||
|
||||
def with_field_mappings(self, **kwargs: str) -> typing.Self:
|
||||
"""
|
||||
Sets the filter fields translations for the table fields.
|
||||
"""
|
||||
self.field_mappings = kwargs
|
||||
return self
|
||||
|
||||
def build(self) -> types.rest.TableInfo:
|
||||
"""
|
||||
Returns the table info for the table fields.
|
||||
@@ -522,4 +540,6 @@ class TableBuilder:
|
||||
fields=self.fields,
|
||||
row_style=self.style_info,
|
||||
subtitle=self.subtitle,
|
||||
filter_fields=self.filter_fields,
|
||||
field_mappings=self.field_mappings,
|
||||
)
|
||||
|
||||
@@ -163,7 +163,6 @@ class UniqueGenerator:
|
||||
logger.debug('Last: %s', last)
|
||||
seq = last.seq + 1
|
||||
except Exception:
|
||||
# logger.exception('Error here')
|
||||
seq = 0
|
||||
with transaction.atomic():
|
||||
self._range_filter(seq).delete() # Clean ups all unassigned after last assigned in this range
|
||||
|
||||
@@ -34,7 +34,6 @@ import urllib.parse
|
||||
import logging
|
||||
import requests
|
||||
import time
|
||||
import token
|
||||
|
||||
from uds.core.util import security
|
||||
from uds.core.util.cache import Cache
|
||||
@@ -42,10 +41,8 @@ from uds.core.util.decorators import cached
|
||||
|
||||
from . import types, consts, exceptions
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenshiftClient:
|
||||
cluster_url: str
|
||||
api_url: str
|
||||
@@ -104,7 +101,7 @@ class OpenshiftClient:
|
||||
except Exception as ex:
|
||||
logging.error(f"Could not obtain token: {ex}")
|
||||
raise
|
||||
|
||||
|
||||
def connect(self, force: bool = False) -> requests.Session:
|
||||
# For testing, always use the fixed token
|
||||
session = self._session = security.secure_requests_session(verify=self._verify_ssl)
|
||||
@@ -296,33 +293,28 @@ class OpenshiftClient:
|
||||
"""
|
||||
Monitor the clone process of a virtual machine.
|
||||
"""
|
||||
clone_url = f"{api_url}/apis/clone.kubevirt.io/v1alpha1/namespaces/{namespace}/virtualmachineclones/{clone_name}"
|
||||
headers = {"Authorization": f"Bearer {self.get_token()}", "Accept": "application/json"}
|
||||
path = f"/apis/clone.kubevirt.io/v1alpha1/namespaces/{namespace}/virtualmachineclones/{clone_name}"
|
||||
logging.info("Monitoring clone process for '%s'...", clone_name)
|
||||
while True:
|
||||
try:
|
||||
response = requests.get(clone_url, headers=headers, verify=False)
|
||||
if response.status_code == 200:
|
||||
clone_data = response.json()
|
||||
status = clone_data.get('status', {})
|
||||
phase = status.get('phase', 'Unknown')
|
||||
logging.info("Phase: %s", phase)
|
||||
for condition in status.get('conditions', []):
|
||||
ctype = condition.get('type', '')
|
||||
cstatus = condition.get('status', '')
|
||||
cmsg = condition.get('message', '')
|
||||
logging.info(" %s: %s - %s", ctype, cstatus, cmsg)
|
||||
if phase == 'Succeeded':
|
||||
logging.info("Clone '%s' completed successfully!", clone_name)
|
||||
break
|
||||
elif phase == 'Failed':
|
||||
logging.error("Clone '%s' failed!", clone_name)
|
||||
break
|
||||
elif response.status_code == 404:
|
||||
logging.warning("Clone resource '%s' not found. May have been cleaned up.", clone_name)
|
||||
response = self.do_request('GET', path)
|
||||
status = response.get('status', {})
|
||||
phase = status.get('phase', 'Unknown')
|
||||
logging.info("Phase: %s", phase)
|
||||
for condition in status.get('conditions', []):
|
||||
ctype = condition.get('type', '')
|
||||
cstatus = condition.get('status', '')
|
||||
cmsg = condition.get('message', '')
|
||||
logging.info(" %s: %s - %s", ctype, cstatus, cmsg)
|
||||
if phase == 'Succeeded':
|
||||
logging.info("Clone '%s' completed successfully!", clone_name)
|
||||
break
|
||||
else:
|
||||
logging.error("Error monitoring clone: %d", response.status_code)
|
||||
elif phase == 'Failed':
|
||||
logging.error("Clone '%s' failed!", clone_name)
|
||||
break
|
||||
except exceptions.OpenshiftNotFoundError:
|
||||
logging.warning("Clone resource '%s' not found. May have been cleaned up.", clone_name)
|
||||
break
|
||||
except Exception as e:
|
||||
logging.error("Monitoring exception: %s", e)
|
||||
logging.info("Waiting %d seconds before next check...", polling_interval)
|
||||
@@ -332,19 +324,16 @@ class OpenshiftClient:
|
||||
"""
|
||||
Returns the name of the PVC or DataVolume used by the VM.
|
||||
"""
|
||||
vm_url = f"{api_url}/apis/kubevirt.io/v1/namespaces/{namespace}/virtualmachines/{vm_name}"
|
||||
headers = {"Authorization": f"Bearer {self.get_token()}", "Accept": "application/json"}
|
||||
response = requests.get(vm_url, headers=headers, verify=False)
|
||||
if response.status_code == 200:
|
||||
vm_obj = response.json()
|
||||
volumes = vm_obj.get("spec", {}).get("template", {}).get("spec", {}).get("volumes", [])
|
||||
for vol in volumes:
|
||||
pvc = vol.get("persistentVolumeClaim")
|
||||
if pvc:
|
||||
return pvc.get("claimName"), "pvc"
|
||||
dv = vol.get("dataVolume")
|
||||
if dv:
|
||||
return dv.get("name"), "dv"
|
||||
path = f"/apis/kubevirt.io/v1/namespaces/{namespace}/virtualmachines/{vm_name}"
|
||||
response = self.do_request('GET', path)
|
||||
volumes = response.get("spec", {}).get("template", {}).get("spec", {}).get("volumes", [])
|
||||
for vol in volumes:
|
||||
pvc = vol.get("persistentVolumeClaim")
|
||||
if pvc:
|
||||
return pvc.get("claimName"), "pvc"
|
||||
dv = vol.get("dataVolume")
|
||||
if dv:
|
||||
return dv.get("name"), "dv"
|
||||
raise Exception(f"No PVC or DataVolume found in VM {vm_name}")
|
||||
|
||||
def get_datavolume_phase(self, datavolume_name: str) -> str:
|
||||
@@ -352,13 +341,10 @@ class OpenshiftClient:
|
||||
Get the phase of a DataVolume.
|
||||
Returns the phase as a string.
|
||||
"""
|
||||
url = f"{self.api_url}/apis/cdi.kubevirt.io/v1beta1/namespaces/{self.namespace}/datavolumes/{datavolume_name}"
|
||||
headers = {'Authorization': f'Bearer {self.get_token()}', 'Accept': 'application/json'}
|
||||
path = f"/apis/cdi.kubevirt.io/v1beta1/namespaces/{self.namespace}/datavolumes/{datavolume_name}"
|
||||
try:
|
||||
response = requests.get(url, headers=headers, verify=self._verify_ssl, timeout=self._timeout)
|
||||
if response.status_code == 200:
|
||||
dv = response.json()
|
||||
return dv.get('status', {}).get('phase', '')
|
||||
response = self.do_request('GET', path)
|
||||
return response.get('status', {}).get('phase', '')
|
||||
except Exception:
|
||||
pass
|
||||
return ''
|
||||
@@ -368,17 +354,14 @@ class OpenshiftClient:
|
||||
Get the size of a DataVolume.
|
||||
Returns the size as a string.
|
||||
"""
|
||||
url = f"{api_url}/apis/cdi.kubevirt.io/v1beta1/namespaces/{namespace}/datavolumes/{dv_name}"
|
||||
headers = {"Authorization": f"Bearer {self.get_token()}", "Accept": "application/json"}
|
||||
response = requests.get(url, headers=headers, verify=False)
|
||||
if response.status_code == 200:
|
||||
dv = response.json()
|
||||
size = dv.get("status", {}).get("amount", None)
|
||||
if size:
|
||||
return size
|
||||
return (
|
||||
dv.get("spec", {}).get("pvc", {}).get("resources", {}).get("requests", {}).get("storage") or ""
|
||||
)
|
||||
path = f"/apis/cdi.kubevirt.io/v1beta1/namespaces/{namespace}/datavolumes/{dv_name}"
|
||||
response = self.do_request('GET', path)
|
||||
size = response.get("status", {}).get("amount", None)
|
||||
if size:
|
||||
return size
|
||||
return (
|
||||
response.get("spec", {}).get("pvc", {}).get("resources", {}).get("requests", {}).get("storage") or ""
|
||||
)
|
||||
raise Exception(f"Could not get the size of DataVolume {dv_name}")
|
||||
|
||||
def get_pvc_size(self, api_url: str, namespace: str, pvc_name: str) -> str:
|
||||
@@ -386,14 +369,11 @@ class OpenshiftClient:
|
||||
Get the size of a PVC.
|
||||
Returns the size as a string.
|
||||
"""
|
||||
url = f"{api_url}/api/v1/namespaces/{namespace}/persistentvolumeclaims/{pvc_name}"
|
||||
headers = {"Authorization": f"Bearer {self.get_token()}", "Accept": "application/json"}
|
||||
response = requests.get(url, headers=headers, verify=False)
|
||||
if response.status_code == 200:
|
||||
pvc = response.json()
|
||||
capacity = pvc.get("status", {}).get("capacity", {}).get("storage")
|
||||
if capacity:
|
||||
return capacity
|
||||
path = f"/api/v1/namespaces/{namespace}/persistentvolumeclaims/{pvc_name}"
|
||||
response = self.do_request('GET', path)
|
||||
capacity = response.get("status", {}).get("capacity", {}).get("storage")
|
||||
if capacity:
|
||||
return capacity
|
||||
raise Exception(f"Could not get the size of PVC {pvc_name}")
|
||||
|
||||
def clone_pvc_with_datavolume(
|
||||
@@ -409,12 +389,7 @@ class OpenshiftClient:
|
||||
Clone a PVC using a DataVolume.
|
||||
Returns True if the DataVolume was created successfully, else False.
|
||||
"""
|
||||
dv_url = f"{api_url}/apis/cdi.kubevirt.io/v1beta1/namespaces/{namespace}/datavolumes"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.get_token()}",
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
path = f"/apis/cdi.kubevirt.io/v1beta1/namespaces/{namespace}/datavolumes"
|
||||
body: dict[str, typing.Any] = {
|
||||
"apiVersion": "cdi.kubevirt.io/v1beta1",
|
||||
"kind": "DataVolume",
|
||||
@@ -428,12 +403,13 @@ class OpenshiftClient:
|
||||
},
|
||||
},
|
||||
}
|
||||
response = requests.post(dv_url, headers=headers, json=body, verify=False)
|
||||
if response.status_code == 201:
|
||||
try:
|
||||
self.do_request('POST', path, data=body)
|
||||
logging.info(f"DataVolume '{cloned_pvc_name}' created successfully")
|
||||
return True
|
||||
logging.error(f"Failed to create DataVolume: {response.status_code} {response.text}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to create DataVolume: {e}")
|
||||
return False
|
||||
|
||||
def create_vm_from_pvc(
|
||||
self,
|
||||
@@ -448,16 +424,13 @@ class OpenshiftClient:
|
||||
Create a new VM from a cloned PVC using DataVolumeTemplates.
|
||||
Returns True if the VM was created successfully, else False.
|
||||
"""
|
||||
original_vm_url = (
|
||||
f"{api_url}/apis/kubevirt.io/v1/namespaces/{namespace}/virtualmachines/{source_vm_name}"
|
||||
)
|
||||
headers = {"Authorization": f"Bearer {self.get_token()}", "Accept": "application/json"}
|
||||
resp = requests.get(original_vm_url, headers=headers, verify=False)
|
||||
if resp.status_code != 200:
|
||||
logging.error(f"Could not get source VM: {resp.status_code} {resp.text}")
|
||||
path = f"/apis/kubevirt.io/v1/namespaces/{namespace}/virtualmachines/{source_vm_name}"
|
||||
try:
|
||||
vm_obj = self.do_request('GET', path)
|
||||
except Exception as e:
|
||||
logging.error(f"Could not get source VM: {e}")
|
||||
return False
|
||||
|
||||
vm_obj = resp.json()
|
||||
vm_obj['metadata']['name'] = new_vm_name
|
||||
|
||||
for k in ['resourceVersion', 'uid', 'selfLink']:
|
||||
@@ -503,28 +476,27 @@ class OpenshiftClient:
|
||||
logger.info(f"Creating VM '{new_vm_name}' from cloned PVC '{new_dv_name}'.")
|
||||
#logger.info(f"VM Object: {vm_obj}")
|
||||
|
||||
create_url = f"{api_url}/apis/kubevirt.io/v1/namespaces/{namespace}/virtualmachines"
|
||||
headers["Content-Type"] = "application/json"
|
||||
resp = requests.post(create_url, headers=headers, json=vm_obj, verify=False)
|
||||
if resp.status_code == 201:
|
||||
create_path = f"/apis/kubevirt.io/v1/namespaces/{namespace}/virtualmachines"
|
||||
try:
|
||||
self.do_request('POST', create_path, data=vm_obj)
|
||||
logging.info(f"VM '{new_vm_name}' created successfully with DataVolumeTemplate.")
|
||||
return True
|
||||
logging.error(f"Error creating VM: {resp.status_code} {resp.text}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logging.error(f"Error creating VM: {e}")
|
||||
return False
|
||||
|
||||
def delete_vm(self, api_url: str, namespace: str, vm_name: str) -> bool:
|
||||
"""
|
||||
Delete a VM by name.
|
||||
Returns True if the VM was deleted successfully, else False.
|
||||
"""
|
||||
url = f"{api_url}/apis/kubevirt.io/v1/namespaces/{namespace}/virtualmachines/{vm_name}"
|
||||
headers = {"Authorization": f"Bearer {self.get_token()}", "Accept": "application/json"}
|
||||
response = requests.delete(url, headers=headers, verify=False)
|
||||
if response.status_code in [200, 202]:
|
||||
path = f"/apis/kubevirt.io/v1/namespaces/{namespace}/virtualmachines/{vm_name}"
|
||||
try:
|
||||
self.do_request('DELETE', path)
|
||||
logging.info(f"VM {vm_name} deleted successfully.")
|
||||
return True
|
||||
else:
|
||||
logging.error(f"Error deleting VM {vm_name}: {response.status_code} - {response.text}")
|
||||
except Exception as e:
|
||||
logging.error(f"Error deleting VM {vm_name}: {e}")
|
||||
return False
|
||||
|
||||
def wait_for_datavolume_clone_progress(
|
||||
@@ -534,14 +506,12 @@ class OpenshiftClient:
|
||||
Wait for a DataVolume clone to complete.
|
||||
Returns True if the clone completed successfully, else False.
|
||||
"""
|
||||
url = f"{api_url}/apis/cdi.kubevirt.io/v1beta1/namespaces/{namespace}/datavolumes/{datavolume_name}"
|
||||
headers = {"Authorization": f"Bearer {self.get_token()}", "Accept": "application/json"}
|
||||
path = f"/apis/cdi.kubevirt.io/v1beta1/namespaces/{namespace}/datavolumes/{datavolume_name}"
|
||||
start = time.time()
|
||||
while time.time() - start < timeout:
|
||||
response = requests.get(url, headers=headers, verify=False)
|
||||
if response.status_code == 200:
|
||||
dv = response.json()
|
||||
status = dv.get('status', {})
|
||||
try:
|
||||
response = self.do_request('GET', path)
|
||||
status = response.get('status', {})
|
||||
phase = status.get('phase')
|
||||
progress = status.get('progress', 'N/A')
|
||||
logging.info(f"DataVolume {datavolume_name} status: {phase}, progress: {progress}")
|
||||
@@ -551,8 +521,8 @@ class OpenshiftClient:
|
||||
elif phase == 'Failed':
|
||||
logging.error(f"DataVolume {datavolume_name} clone failed")
|
||||
return False
|
||||
else:
|
||||
logging.error(f"Error querying DataVolume {datavolume_name}: {response.status_code}")
|
||||
except Exception as e:
|
||||
logging.error(f"Error querying DataVolume {datavolume_name}: {e}")
|
||||
time.sleep(polling_interval)
|
||||
logging.error(f"Timeout waiting for DataVolume {datavolume_name} clone")
|
||||
return False
|
||||
@@ -562,40 +532,47 @@ class OpenshiftClient:
|
||||
Start a VM by name.
|
||||
Returns True if the VM was started successfully, else False.
|
||||
"""
|
||||
url = f"{api_url}/apis/kubevirt.io/v1/namespaces/{namespace}/virtualmachines/{vm_name}"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.get_token()}",
|
||||
"Content-Type": "application/merge-patch+json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
body: dict[str, typing.Any] = {"spec": {"runStrategy": "Always"}}
|
||||
response = requests.patch(url, headers=headers, json=body, verify=False)
|
||||
if response.status_code in [200, 201]:
|
||||
logging.info(f"VM {vm_name} started.")
|
||||
return True
|
||||
else:
|
||||
logging.info(f"Error starting VM {vm_name}: {response.status_code} - {response.text}")
|
||||
|
||||
# Get Vm info
|
||||
path = f"/apis/kubevirt.io/v1/namespaces/{namespace}/virtualmachines/{vm_name}"
|
||||
try:
|
||||
vm_obj = self.do_request('GET', path)
|
||||
except Exception as e:
|
||||
logging.error(f"Could not get source VM: {e}")
|
||||
return False
|
||||
|
||||
# Update runStrategy to Always
|
||||
vm_obj['spec']['runStrategy'] = 'Always'
|
||||
try:
|
||||
self.do_request('PUT', path, data=vm_obj)
|
||||
logging.info(f"VM {vm_name} will be started.")
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.info(f"Error starting VM {vm_name}: {e}")
|
||||
return False
|
||||
|
||||
def stop_vm(self, api_url: str, namespace: str, vm_name: str) -> bool:
|
||||
"""
|
||||
Stop a VM by name.
|
||||
Returns True if the VM was stopped successfully, else False.
|
||||
"""
|
||||
url = f"{api_url}/apis/kubevirt.io/v1/namespaces/{namespace}/virtualmachines/{vm_name}"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.get_token()}",
|
||||
"Content-Type": "application/merge-patch+json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
body: dict[str, typing.Any] = {"spec": {"runStrategy": "Halted"}}
|
||||
response = requests.patch(url, headers=headers, json=body, verify=False)
|
||||
if response.status_code in [200, 201]:
|
||||
# Get Vm info
|
||||
path = f"/apis/kubevirt.io/v1/namespaces/{namespace}/virtualmachines/{vm_name}"
|
||||
try:
|
||||
vm_obj = self.do_request('GET', path)
|
||||
except Exception as e:
|
||||
logging.error(f"Could not get source VM: {e}")
|
||||
return False
|
||||
|
||||
# Update runStrategy to Halted
|
||||
vm_obj['spec']['runStrategy'] = 'Halted'
|
||||
try:
|
||||
self.do_request('PUT', path, data=vm_obj)
|
||||
logging.info(f"VM {vm_name} will be stopped.")
|
||||
return True
|
||||
else:
|
||||
logging.info(f"Error stopping VM {vm_name}: {response.status_code} - {response.text}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logging.info(f"Error starting VM {vm_name}: {e}")
|
||||
return False
|
||||
|
||||
def copy_vm_same_size(
|
||||
self, api_url: str, namespace: str, source_vm_name: str, new_vm_name: str, storage_class: str
|
||||
@@ -687,22 +664,3 @@ class OpenshiftClient:
|
||||
except Exception as e:
|
||||
logging.error(f"Error deleting VM: {e}")
|
||||
return False
|
||||
|
||||
def clone_vm_instance(self, source_vm_name: str, new_vm_name: str, storage_class: str) -> bool:
|
||||
"""
|
||||
Clone a VM by name, creating a new VM with the same size.
|
||||
Returns True if clone succeeded, False otherwise.
|
||||
"""
|
||||
try:
|
||||
self.copy_vm_same_size(self.api_url, self.namespace, source_vm_name, new_vm_name, storage_class)
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.error(f"Error cloning VM: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def validate_vm_id(vm_id: str | int) -> None:
|
||||
try:
|
||||
int(vm_id)
|
||||
except ValueError:
|
||||
raise exceptions.OpenshiftNotFoundError(f'VM {vm_id} not found')
|
||||
|
||||
@@ -126,10 +126,17 @@ class OpenshiftProvider(ServiceProvider):
|
||||
def sanitized_name(self, name: str) -> str:
|
||||
"""
|
||||
Sanitizes the VM name to comply with RFC 1123:
|
||||
- Lowercase
|
||||
- Alphanumeric, '-', '.'
|
||||
- Starts/ends with alphanumeric
|
||||
- Max length 63 chars
|
||||
- Converts to lowercase
|
||||
- Replaces any character not in [a-z0-9.-] with '-'
|
||||
- Collapses multiple '-' into one
|
||||
- Removes leading/trailing non-alphanumeric characters
|
||||
- Limits length to 63 characters
|
||||
"""
|
||||
name = re.sub(r'^[^a-z0-9]+|[^a-z0-9.-]|-{2,}|[^a-z0-9]+$', '-', name.lower())
|
||||
return name[:63]
|
||||
name = name.lower()
|
||||
# Replace any character not allowed with '-'
|
||||
name = re.sub(r'[^a-z0-9.-]', '-', name)
|
||||
# Collapse multiple '-' into one
|
||||
name = re.sub(r'-{2,}', '-', name)
|
||||
# Remove leading/trailing non-alphanumeric characters
|
||||
name = re.sub(r'^[^a-z0-9]+|[^a-z0-9]+$', '', name)
|
||||
return name[:63]
|
||||
@@ -52,7 +52,7 @@ class TestServiceNoCache(services.Service):
|
||||
Basic testing service without cache and no publication OFC
|
||||
"""
|
||||
type_name = _('Testing Service no cache')
|
||||
type_type = 'TestService1'
|
||||
type_type = 'TestService2'
|
||||
type_description = _('Testing (and dummy) service with no cache')
|
||||
icon_file = 'service.png'
|
||||
|
||||
@@ -84,7 +84,7 @@ class TestServiceCache(services.Service):
|
||||
"""
|
||||
|
||||
type_name = _('Testing Service WITH cache')
|
||||
type_type = 'TestService2'
|
||||
type_type = 'TestService1'
|
||||
type_description = _('Testing (and dummy) service with CACHE and PUBLICATION')
|
||||
icon_file = 'provider.png' # : We reuse provider icon here :-), it's just for testing purpuoses
|
||||
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -66,6 +66,8 @@ gettext("Change owner");
|
||||
gettext("Assign service");
|
||||
gettext("Cancel");
|
||||
gettext("Changelog");
|
||||
gettext("Fallback: Allow");
|
||||
gettext("Fallback: Deny");
|
||||
gettext("Delete assigned service");
|
||||
gettext("Delete cached service");
|
||||
gettext("Service pool is locked");
|
||||
|
||||
@@ -102,6 +102,6 @@
|
||||
</svg>
|
||||
</div>
|
||||
</uds-root>
|
||||
<link rel="modulepreload" href="/uds/res/admin/chunk-2F3F2YC2.js?stamp=1759963950" integrity="sha384-VVOra5xy5Xg9fYkBmK9MLhX7vif/MexRAaLIDBsQ4ZlkF31s/U6uWWrj+LAnvX/q"><script src="/uds/res/admin/polyfills.js?stamp=1759963950" type="module" crossorigin="anonymous" integrity="sha384-TVRkn44wOGJBeCKWJBHWLvXubZ+Julj/yA0OoEFa3LgJHVHaPeeATX6NcjuNgsIA"></script><script src="/uds/res/admin/main.js?stamp=1759963950" type="module" crossorigin="anonymous" integrity="sha384-U5YvgqT/kogPnBfS+4CA2tP+ICQhuQPuL6EnLHKwmcFpPQBI4Ndzwqiuv4gm5YCr"></script></body>
|
||||
<link rel="modulepreload" href="/uds/res/admin/chunk-2F3F2YC2.js?stamp=1762813175" integrity="sha384-VVOra5xy5Xg9fYkBmK9MLhX7vif/MexRAaLIDBsQ4ZlkF31s/U6uWWrj+LAnvX/q"><script src="/uds/res/admin/polyfills.js?stamp=1762813175" type="module" crossorigin="anonymous" integrity="sha384-TVRkn44wOGJBeCKWJBHWLvXubZ+Julj/yA0OoEFa3LgJHVHaPeeATX6NcjuNgsIA"></script><script src="/uds/res/admin/main.js?stamp=1762813175" type="module" crossorigin="anonymous" integrity="sha384-4RLNNtXgmAW+M5Uzni5JSekmZJ6MDtfZR2d7Cw0nQECcN+XAYbAUnTZrGx/sEc8N"></script></body>
|
||||
|
||||
</html>
|
||||
|
||||
54
server/src/uds/transports/RDP/scripts/linux/direct.js
Normal file
54
server/src/uds/transports/RDP/scripts/linux/direct.js
Normal file
@@ -0,0 +1,54 @@
|
||||
'use strict';
|
||||
|
||||
// Info for lintering tools about the variables provided by uds client
|
||||
var Process, Logger, File, Utils, Tasks;
|
||||
|
||||
// We receive data in "data" variable, which is an object from json readonly
|
||||
var data;
|
||||
|
||||
|
||||
const errorString = `You need to have xfreerdp or Thincast installed and in path for this to work.</p>
|
||||
<p>Please, install the proper package for your system.</p>
|
||||
<ul>
|
||||
<li>xfreerdp: <a href="https://github.com/FreeRDP/FreeRDP">Download</a></li>
|
||||
<li>Thincast: <a href="https://thincast.com/en/products/client">Download</a></li>
|
||||
</ul>`;
|
||||
|
||||
// Try, in order of preference, to find other RDP clients
|
||||
const executablePath =
|
||||
Process.findExecutable('udsrdp') ||
|
||||
Process.findExecutable('thincast-remote-desktop-client') ||
|
||||
Process.findExecutable('thincast-client') ||
|
||||
Process.findExecutable('thincast') ||
|
||||
Process.findExecutable('xfreerdp3') ||
|
||||
Process.findExecutable('xfreerdp') ||
|
||||
Process.findExecutable('xfreerdp');
|
||||
|
||||
if (!executablePath) {
|
||||
Logger.error('No RDP client found on system');
|
||||
throw new Error(errorString);
|
||||
}
|
||||
|
||||
// using Utils.expandVars, expand variables of data.freerdp_params (that is an array of strings)
|
||||
let parameters = data.freerdp_params.map((param) => Utils.expandVars(param));
|
||||
|
||||
let process = null;
|
||||
|
||||
// If has the as_file property, create the temp file on home folder and use it
|
||||
if (data.as_file) {
|
||||
Logger.debug('Has as_file property, creating temp RDP file');
|
||||
// Create and save the temp file
|
||||
rdpFilePath = File.createTempFile(File.getHomeDirectory(), data.as_file, '.rdp');
|
||||
Logger.debug(`RDP temp file created at ${rdpFilePath}`);
|
||||
|
||||
// Append to removable task to delete the file later
|
||||
Tasks.addEarlyUnlinkableFile(rdpFilePath);
|
||||
let password = data.password ? `/p:${data.password}` : '/p:';
|
||||
// Launch the RDP client with the temp file
|
||||
process = Process.launch(executablePath, [rdpFilePath, password]); // the addres in INSIDE the file is already set to
|
||||
} else {
|
||||
// Launch the RDP client with the parameters
|
||||
process = Process.launch(executablePath, [...parameters, `/v:${data.address}`]);
|
||||
}
|
||||
|
||||
Tasks.addWaitableApp(process);
|
||||
64
server/src/uds/transports/RDP/scripts/linux/tunnel.js
Normal file
64
server/src/uds/transports/RDP/scripts/linux/tunnel.js
Normal file
@@ -0,0 +1,64 @@
|
||||
'use strict';
|
||||
|
||||
// Info for lintering tools about the variables provided by uds client
|
||||
var Process, Logger, File, Utils, Tasks;
|
||||
|
||||
// We receive data in "data" variable, which is an object from json readonly
|
||||
var data;
|
||||
|
||||
|
||||
const errorString = `You need to have xfreerdp or Thincast installed and in path for this to work.</p>
|
||||
<p>Please, install the proper package for your system.</p>
|
||||
<ul>
|
||||
<li>xfreerdp: <a href="https://github.com/FreeRDP/FreeRDP">Download</a></li>
|
||||
<li>Thincast: <a href="https://thincast.com/en/products/client">Download</a></li>
|
||||
</ul>`;
|
||||
|
||||
// Try, in order of preference, to find other RDP clients
|
||||
const executablePath =
|
||||
Process.findExecutable('udsrdp') ||
|
||||
Process.findExecutable('thincast-remote-desktop-client') ||
|
||||
Process.findExecutable('thincast-client') ||
|
||||
Process.findExecutable('thincast') ||
|
||||
Process.findExecutable('xfreerdp3') ||
|
||||
Process.findExecutable('xfreerdp') ||
|
||||
Process.findExecutable('xfreerdp');
|
||||
|
||||
if (!executablePath) {
|
||||
Logger.error('No RDP client found on system');
|
||||
throw new Error('No RDP client found on system');
|
||||
}
|
||||
|
||||
// using Utils.expandVars, expand variables of data.freerdp_params (that is an array of strings)
|
||||
let parameters = data.freerdp_params.map((param) => Utils.expandVars(param));
|
||||
|
||||
let tunnel = null;
|
||||
try {
|
||||
tunnel = await Tasks.startTunnel(data.tunHost, data.tunPort, data.ticket, null, data.tunChk);
|
||||
} catch (error) {
|
||||
Logger.error(`Failed to start tunnel: ${error.message}`);
|
||||
throw new Error(`Failed to start tunnel: ${error.message}`);
|
||||
}
|
||||
|
||||
let process = null;
|
||||
|
||||
// If has the as_file property, create the temp file on home folder and use it
|
||||
if (data.as_file) {
|
||||
Logger.debug('Has as_file property, creating temp RDP file');
|
||||
// Replace "{address}" with data.address in the as_file content
|
||||
let content = data.as_file.replace(/\{address\}/g, `127.0.0.1:${tunnel.port}`);
|
||||
// Create and save the temp file
|
||||
rdpFilePath = File.createTempFile(File.getHomeDtartirectory(), content, '.rdp');
|
||||
Logger.debug(`RDP temp file created at ${rdpFilePath}`);
|
||||
|
||||
// Append to removable task to delete the file later
|
||||
Tasks.addEarlyUnlinkableFile(rdpFilePath);
|
||||
let password = data.password ? `/p:${data.password}` : '';
|
||||
// Launch the RDP client with the temp file, the addres in INSIDE the file is already set to
|
||||
process = Process.launch(executablePath, [rdpFilePath, password]);
|
||||
} else {
|
||||
// Launch the RDP client with the parameters
|
||||
process = Process.launch(executablePath, [...parameters, `/v:127.0.0.1:${tunnel.port}`]);
|
||||
}
|
||||
|
||||
Tasks.addWaitableApp(process);
|
||||
@@ -124,7 +124,7 @@ class ServerEventsLoginLogoutTest(rest.test.RESTTestCase):
|
||||
# logoutData = {
|
||||
# 'token': 'server token', # Must be present on all events
|
||||
# 'type': 'login', # MUST BE PRESENT
|
||||
# 'user_service': 'uuid', # MUST BE PRESENT
|
||||
# 'userservice_uuid': 'uuid', # MUST BE PRESENT
|
||||
# 'username': 'username', # Optional
|
||||
# }
|
||||
response = self.client.rest_post(
|
||||
@@ -132,7 +132,7 @@ class ServerEventsLoginLogoutTest(rest.test.RESTTestCase):
|
||||
data={
|
||||
'token': self.server.token,
|
||||
'type': 'logout',
|
||||
'user_service': self.user_service_managed.uuid,
|
||||
'userservice_uuid': self.user_service_managed.uuid,
|
||||
'username': 'local_user_name',
|
||||
'session_id': '',
|
||||
},
|
||||
@@ -146,7 +146,7 @@ class ServerEventsLoginLogoutTest(rest.test.RESTTestCase):
|
||||
data={
|
||||
'token': self.server.token,
|
||||
'type': 'login',
|
||||
'user_service': 'invalid uuid',
|
||||
'userservice_uuid': 'invalid uuid',
|
||||
'username': 'local_user_name',
|
||||
},
|
||||
)
|
||||
|
||||
0
server/tests/services/openshift/__init__.py
Normal file
0
server/tests/services/openshift/__init__.py
Normal file
348
server/tests/services/openshift/fixtures.py
Normal file
348
server/tests/services/openshift/fixtures.py
Normal file
@@ -0,0 +1,348 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Test fixtures for OpenShift service tests.
|
||||
Provides reusable functions and mock objects for unit testing OpenShift provider, service, deployment, publication, and user service logic.
|
||||
All functions are designed to be used across multiple test modules for consistency and maintainability.
|
||||
"""
|
||||
|
||||
#
|
||||
# Copyright (c) 2024 Virtual Cable S.L.U.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification,
|
||||
# are permitted provided that the following conditions are met:
|
||||
#
|
||||
# * Redistributions of source code must retain the above copyright notice,
|
||||
# this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
# * Neither the name of Virtual Cable S.L.U. nor the names of its contributors
|
||||
# may be used to endorse or promote products derived from this software
|
||||
# without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
"""
|
||||
Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
"""
|
||||
import contextlib
|
||||
import copy
|
||||
import functools
|
||||
import random
|
||||
import typing
|
||||
|
||||
from unittest import mock
|
||||
import uuid
|
||||
|
||||
|
||||
from uds.core import environment
|
||||
from uds.core.ui.user_interface import gui
|
||||
from uds.models.user import User
|
||||
|
||||
from uds.services.OpenShift import service, service_fixed, provider, publication, deployment, deployment_fixed
|
||||
from uds.services.OpenShift.openshift import types as openshift_types, exceptions as openshift_exceptions
|
||||
|
||||
DEF_VMS: list[openshift_types.VM] = [
|
||||
openshift_types.VM(
|
||||
name=f'vm-{i}',
|
||||
namespace='default',
|
||||
uid=f'uid-{i}',
|
||||
status=openshift_types.VMStatus.STOPPED if i % 2 == 0 else openshift_types.VMStatus.RUNNING,
|
||||
volume_template=openshift_types.VolumeTemplate(name=f'volume-{i}', storage='10Gi'),
|
||||
disks=[openshift_types.DeviceDisk(name=f'disk-{i}', boot_order=1)],
|
||||
volumes=[openshift_types.Volume(name=f'volume-{i}', data_volume=f'dv-{i}')],
|
||||
)
|
||||
for i in range(1, 11)
|
||||
]
|
||||
DEF_VM_INSTANCES: list[openshift_types.VMInstance] = [
|
||||
openshift_types.VMInstance(
|
||||
name=f'vm-{i}',
|
||||
namespace='default',
|
||||
uid=f'uid-instance-{i}',
|
||||
interfaces=[
|
||||
openshift_types.Interface(
|
||||
name='eth0',
|
||||
mac_address=f'00:11:22:33:44:{i:02x}',
|
||||
ip_address=f'192.168.1.{i}',
|
||||
)
|
||||
],
|
||||
status=openshift_types.VMStatus.STOPPED if i % 2 == 0 else openshift_types.VMStatus.RUNNING,
|
||||
phase=openshift_types.VMStatus.STOPPED if i % 2 == 0 else openshift_types.VMStatus.RUNNING,
|
||||
)
|
||||
for i in range(1, 11)
|
||||
]
|
||||
|
||||
# clone values to avoid modifying the original ones
|
||||
VMS: list[openshift_types.VM] = copy.deepcopy(DEF_VMS)
|
||||
VM_INSTANCES: list[openshift_types.VMInstance] = copy.deepcopy(DEF_VM_INSTANCES)
|
||||
|
||||
|
||||
def clear() -> None:
|
||||
"""
|
||||
Reset all VM and VMInstance values to their default state.
|
||||
Use this before each test to ensure a clean environment.
|
||||
"""
|
||||
VMS[:] = copy.deepcopy(DEF_VMS)
|
||||
VM_INSTANCES[:] = copy.deepcopy(DEF_VM_INSTANCES)
|
||||
|
||||
|
||||
def replace_vm_info(vm_name: str, **kwargs: typing.Any) -> None:
|
||||
"""
|
||||
Update attributes of a VM in VMS by name.
|
||||
Raises OpenshiftNotFoundError if VM is not found.
|
||||
"""
|
||||
try:
|
||||
vm = next(vm for vm in VMS if vm.name == vm_name)
|
||||
for k, v in kwargs.items():
|
||||
setattr(vm, k, v)
|
||||
except Exception:
|
||||
raise openshift_exceptions.OpenshiftNotFoundError(f'VM {vm_name} not found')
|
||||
|
||||
|
||||
def replacer_vm_info(**kwargs: typing.Any) -> typing.Callable[..., None]:
|
||||
"""
|
||||
Returns a partial function to update VM info with preset kwargs.
|
||||
Useful for patching or repeated updates in tests.
|
||||
"""
|
||||
return functools.partial(replace_vm_info, **kwargs)
|
||||
|
||||
|
||||
T = typing.TypeVar('T')
|
||||
|
||||
|
||||
def returner(value: T, *args: typing.Any, **kwargs: typing.Any) -> typing.Callable[..., T]:
|
||||
"""
|
||||
Returns a function that always returns the given value.
|
||||
Useful for mocking return values in tests.
|
||||
"""
|
||||
def inner(*args: typing.Any, **kwargs: typing.Any) -> T:
|
||||
return value
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
# Provider values
|
||||
PROVIDER_VALUES_DICT: gui.ValuesDictType = {
|
||||
'cluster_url': 'https://oauth-openshift.apps-crc.testing',
|
||||
'api_url': 'https://api.crc.testing:6443',
|
||||
'username': 'kubeadmin',
|
||||
'password': 'test-password',
|
||||
'namespace': 'default',
|
||||
'verify_ssl': False,
|
||||
'concurrent_creation_limit': 1,
|
||||
'concurrent_removal_limit': 1,
|
||||
'timeout': 10,
|
||||
}
|
||||
|
||||
# Service values
|
||||
SERVICE_VALUES_DICT: gui.ValuesDictType = {
|
||||
'template': VMS[0].name,
|
||||
'basename': 'base',
|
||||
'lenname': 4,
|
||||
'publication_timeout': 120,
|
||||
'prov_uuid': '',
|
||||
}
|
||||
|
||||
# Service fixed values
|
||||
SERVICE_FIXED_VALUES_DICT: gui.ValuesDictType = {
|
||||
'token': '',
|
||||
'machines': [VMS[2].name, VMS[3].name, VMS[4].name],
|
||||
'on_logout': 'no',
|
||||
'randomize': False,
|
||||
'maintain_on_error': False,
|
||||
'prov_uuid': '',
|
||||
}
|
||||
|
||||
|
||||
def create_client_mock() -> mock.Mock:
|
||||
"""
|
||||
Create a MagicMock for OpenshiftClient with default behaviors and side effects.
|
||||
Used to simulate API responses in provider/service tests.
|
||||
"""
|
||||
client = mock.MagicMock()
|
||||
|
||||
# Prepare deep copies of default data
|
||||
client.test.return_value = True
|
||||
client.list_vms.return_value = copy.deepcopy(DEF_VMS)
|
||||
client.start_vm_instance.return_value = True
|
||||
client.stop_vm_instance.return_value = True
|
||||
client.delete_vm_instance.return_value = True
|
||||
client.get_datavolume_phase.return_value = "Succeeded"
|
||||
client.get_vm_pvc_or_dv_name.return_value = ("test-pvc", "pvc")
|
||||
client.get_pvc_size.return_value = "10Gi"
|
||||
client.create_vm_from_pvc.return_value = True
|
||||
client.wait_for_datavolume_clone_progress.return_value = True
|
||||
|
||||
def get_vm_info_side_effect(vm_name: str, **kwargs: typing.Any) -> openshift_types.VM | None:
|
||||
for vm in VMS:
|
||||
if vm.name == vm_name:
|
||||
return vm
|
||||
return None
|
||||
|
||||
def get_vm_instance_info_side_effect(vm_name: str, **kwargs: typing.Any) -> openshift_types.VMInstance | None:
|
||||
for inst in VM_INSTANCES:
|
||||
if inst.name == vm_name:
|
||||
return inst
|
||||
return None
|
||||
|
||||
client.get_vm_info.side_effect = get_vm_info_side_effect
|
||||
client.get_vm_instance_info.side_effect = get_vm_instance_info_side_effect
|
||||
|
||||
return client
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def patched_provider(**kwargs: typing.Any) -> typing.Generator[provider.OpenshiftProvider, None, None]:
|
||||
"""
|
||||
Context manager that yields a provider with a patched OpenshiftClient mock.
|
||||
Use this to ensure all API calls are intercepted and controlled in tests.
|
||||
"""
|
||||
client = create_client_mock()
|
||||
prov = create_provider(**kwargs)
|
||||
prov._cached_api = client
|
||||
yield prov
|
||||
|
||||
|
||||
def create_provider(**kwargs: typing.Any) -> provider.OpenshiftProvider:
|
||||
"""
|
||||
Create an OpenshiftProvider instance with default or overridden values.
|
||||
Used for provider-level tests and as a dependency for other fixtures.
|
||||
"""
|
||||
values = PROVIDER_VALUES_DICT.copy()
|
||||
values.update(kwargs)
|
||||
|
||||
uuid_ = str(uuid.uuid4())
|
||||
return provider.OpenshiftProvider(
|
||||
environment=environment.Environment.private_environment(uuid_), values=values, uuid=uuid_
|
||||
)
|
||||
|
||||
|
||||
def create_service(
|
||||
provider: typing.Optional[provider.OpenshiftProvider] = None, **kwargs: typing.Any
|
||||
) -> service.OpenshiftService:
|
||||
"""
|
||||
Create an OpenshiftService instance (dynamic service).
|
||||
Used for service-level tests and as a dependency for user services and publications.
|
||||
"""
|
||||
uuid_ = str(uuid.uuid4())
|
||||
values = SERVICE_VALUES_DICT.copy()
|
||||
values.update(kwargs)
|
||||
srvc = service.OpenshiftService(
|
||||
environment=environment.Environment.private_environment(uuid_),
|
||||
provider=provider or create_provider(),
|
||||
values=values,
|
||||
uuid=uuid_,
|
||||
)
|
||||
return srvc
|
||||
|
||||
|
||||
def create_service_fixed(
|
||||
provider: typing.Optional[provider.OpenshiftProvider] = None, **kwargs: typing.Any
|
||||
) -> service_fixed.OpenshiftServiceFixed:
|
||||
"""
|
||||
Create an OpenshiftServiceFixed instance (fixed service).
|
||||
Used for fixed service tests and as a dependency for fixed user services.
|
||||
"""
|
||||
uuid_ = str(uuid.uuid4())
|
||||
values = SERVICE_FIXED_VALUES_DICT.copy()
|
||||
values.update(kwargs)
|
||||
return service_fixed.OpenshiftServiceFixed(
|
||||
environment=environment.Environment.private_environment(uuid_),
|
||||
provider=provider or create_provider(),
|
||||
values=values,
|
||||
uuid=uuid_,
|
||||
)
|
||||
|
||||
|
||||
def create_publication(
|
||||
service: typing.Optional[service.OpenshiftService] = None,
|
||||
**kwargs: typing.Any,
|
||||
) -> publication.OpenshiftTemplatePublication:
|
||||
"""
|
||||
Create an OpenshiftTemplatePublication instance.
|
||||
Used for publication-level tests and as a dependency for user services.
|
||||
"""
|
||||
uuid_ = str(uuid.uuid4())
|
||||
pub = publication.OpenshiftTemplatePublication(
|
||||
environment=environment.Environment.private_environment(uuid_),
|
||||
service=service or create_service(**kwargs),
|
||||
revision=1,
|
||||
servicepool_name='servicepool_name',
|
||||
uuid=uuid_,
|
||||
)
|
||||
pub._name = f"pub-{random.randint(1000, 9999)}"
|
||||
return pub
|
||||
|
||||
|
||||
def create_userservice(
|
||||
service: typing.Optional[service.OpenshiftService] = None,
|
||||
publication: typing.Optional[publication.OpenshiftTemplatePublication] = None,
|
||||
) -> deployment.OpenshiftUserService:
|
||||
"""
|
||||
Create an OpenshiftUserService instance (dynamic user service).
|
||||
Used for user service tests that require a publication and service.
|
||||
"""
|
||||
uuid_ = str(uuid.uuid4())
|
||||
return deployment.OpenshiftUserService(
|
||||
environment=environment.Environment.private_environment(uuid_),
|
||||
service=service or create_service(),
|
||||
publication=publication or create_publication(),
|
||||
uuid=uuid_,
|
||||
)
|
||||
|
||||
|
||||
def create_userservice_fixed(
|
||||
service: typing.Optional[service_fixed.OpenshiftServiceFixed] = None,
|
||||
) -> deployment_fixed.OpenshiftUserServiceFixed:
|
||||
"""
|
||||
Create an OpenshiftUserServiceFixed instance (fixed user service).
|
||||
Used for tests of fixed user service logic and lifecycle.
|
||||
"""
|
||||
uuid_ = str(uuid.uuid4().hex)
|
||||
return deployment_fixed.OpenshiftUserServiceFixed(
|
||||
environment=environment.Environment.private_environment(uuid_),
|
||||
service=service or create_service_fixed(),
|
||||
publication=None,
|
||||
uuid=uuid_,
|
||||
)
|
||||
|
||||
|
||||
def create_user(
|
||||
name: str = "testuser",
|
||||
real_name: str = "Test User",
|
||||
is_admin: bool = False,
|
||||
state: str = 'A',
|
||||
password: str = 'password',
|
||||
mfa_data: str = '',
|
||||
staff_member: bool = False,
|
||||
last_access: typing.Optional[str] = None,
|
||||
parent: typing.Optional[User] = None,
|
||||
created: typing.Optional[str] = None,
|
||||
comments: str = '',
|
||||
) -> User:
|
||||
"""
|
||||
Create a mock User instance for testing.
|
||||
All fields can be customized for specific test scenarios.
|
||||
"""
|
||||
user = mock.Mock(spec=User)
|
||||
user.name = name
|
||||
user.real_name = real_name
|
||||
user.is_admin = is_admin
|
||||
user.state = state
|
||||
user.password = password
|
||||
user.mfa_data = mfa_data
|
||||
user.staff_member = staff_member
|
||||
user.last_access = last_access
|
||||
user.parent = parent
|
||||
user.created = created
|
||||
user.comments = comments
|
||||
return user
|
||||
189
server/tests/services/openshift/test_client.py
Normal file
189
server/tests/services/openshift/test_client.py
Normal file
@@ -0,0 +1,189 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
#
|
||||
# Copyright (c) 2024 Virtual Cable S.L.U.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification,
|
||||
# are permitted provided that the following conditions are met:
|
||||
#
|
||||
# * Redistributions of source code must retain the above copyright notice,
|
||||
# this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
# * Neither the name of Virtual Cable S.L.U. nor the names of its contributors
|
||||
# may be used to endorse or promote products derived from this software
|
||||
# without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
"""
|
||||
Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from uds.services.OpenShift.openshift import client as openshift_client
|
||||
from tests.utils.test import UDSTransactionTestCase
|
||||
from tests.utils import vars
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TestOpenshiftClient(UDSTransactionTestCase):
|
||||
"""Tests for operations with OpenShiftClient."""
|
||||
|
||||
os_client: openshift_client.OpenshiftClient
|
||||
test_vm: str = ''
|
||||
test_pool: str = ''
|
||||
test_storage: str = ''
|
||||
|
||||
def setUp(self) -> None:
|
||||
"""
|
||||
Set up OpenShift client and test variables for each test.
|
||||
Skips tests if required variables are missing.
|
||||
"""
|
||||
v = vars.get_vars('openshift')
|
||||
if not v:
|
||||
self.skipTest('No OpenShift test variables found')
|
||||
self.os_client = openshift_client.OpenshiftClient(
|
||||
cluster_url=v['cluster_url'],
|
||||
api_url=v['api_url'],
|
||||
username=v['username'],
|
||||
password=v['password'],
|
||||
namespace=v['namespace'],
|
||||
timeout=int(v['timeout']),
|
||||
verify_ssl=v['verify_ssl'] == 'true',
|
||||
)
|
||||
self.test_vm = v.get('test_vm', '')
|
||||
self.test_pool = v.get('test_pool', '')
|
||||
self.test_storage = v.get('test_storage', '')
|
||||
|
||||
# --- Token/API Tests ---
|
||||
def test_get_token(self) -> None:
|
||||
"""
|
||||
Test that get_token returns a valid token string.
|
||||
"""
|
||||
token = self.os_client.get_token()
|
||||
self.assertIsNotNone(token)
|
||||
|
||||
def test_get_api_url(self) -> None:
|
||||
"""
|
||||
Test that get_api_url constructs a valid URL with path and parameters.
|
||||
"""
|
||||
url = self.os_client.get_api_url('/test/path', ('param1', 'value1'))
|
||||
self.assertIn('/test/path', url)
|
||||
self.assertIn('param1=value1', url)
|
||||
|
||||
def test_get_api_url_invalid(self):
|
||||
"""
|
||||
Test that get_api_url works with an invalid path.
|
||||
"""
|
||||
url = self.os_client.get_api_url('/invalid/path', ('param', 'value'))
|
||||
self.assertIn('/invalid/path', url)
|
||||
|
||||
# --- VM Listing/Info Tests ---
|
||||
def test_list_vms(self) -> None:
|
||||
"""
|
||||
Test that list_vms returns a list and get_vm_info works for listed VMs.
|
||||
"""
|
||||
vms = self.os_client.list_vms()
|
||||
self.assertIsInstance(vms, list)
|
||||
if vms:
|
||||
info = self.os_client.get_vm_info(vms[0].name)
|
||||
self.assertIsNotNone(info)
|
||||
|
||||
def test_list_vms_and_check_fields(self):
|
||||
"""
|
||||
Test that all VMs returned by list_vms have required fields.
|
||||
"""
|
||||
vms = self.os_client.list_vms()
|
||||
self.assertIsInstance(vms, list)
|
||||
for vm in vms:
|
||||
self.assertTrue(hasattr(vm, 'name'))
|
||||
self.assertTrue(hasattr(vm, 'namespace'))
|
||||
|
||||
def test_get_vm_info(self):
|
||||
"""
|
||||
Test that get_vm_info returns info for a valid VM name.
|
||||
"""
|
||||
if not self.test_vm:
|
||||
self.skipTest('No test_vm specified')
|
||||
info = self.os_client.get_vm_info(self.test_vm)
|
||||
self.assertIsNotNone(info)
|
||||
|
||||
def test_get_vm_info_invalid(self):
|
||||
"""
|
||||
Test that get_vm_info returns None for an invalid VM name.
|
||||
"""
|
||||
info = self.os_client.get_vm_info('nonexistent-vm')
|
||||
self.assertIsNone(info)
|
||||
|
||||
def test_get_vm_instance_info(self):
|
||||
"""
|
||||
Test that get_vm_instance_info returns info or None for a valid VM name.
|
||||
"""
|
||||
if not self.test_vm:
|
||||
self.skipTest('No test_vm specified')
|
||||
info = self.os_client.get_vm_instance_info(self.test_vm)
|
||||
self.assertTrue(info is None or hasattr(info, 'name'))
|
||||
|
||||
def test_get_vm_instance_info_invalid(self):
|
||||
"""
|
||||
Test that get_vm_instance_info returns None for an invalid VM name.
|
||||
"""
|
||||
info = self.os_client.get_vm_instance_info('nonexistent-vm')
|
||||
self.assertIsNone(info)
|
||||
|
||||
# --- VM Lifecycle and Actions ---
|
||||
def test_vm_lifecycle(self) -> None:
|
||||
"""
|
||||
Test VM lifecycle actions: start, stop, delete (skipped in shared environments).
|
||||
"""
|
||||
self.skipTest('Skip this test to avoid issues in shared environments')
|
||||
if not self.test_vm:
|
||||
self.skipTest('No test_vm specified in test-vars.ini')
|
||||
self.assertTrue(self.os_client.start_vm_instance(self.test_vm))
|
||||
self.assertTrue(self.os_client.stop_vm_instance(self.test_vm))
|
||||
self.assertTrue(self.os_client.delete_vm_instance(self.test_vm))
|
||||
|
||||
def test_start_stop_suspend_resume_vm(self):
|
||||
"""
|
||||
Test stop (and optionally start) VM instance. Suspend/resume skipped if not supported.
|
||||
"""
|
||||
if not self.test_vm:
|
||||
self.skipTest('No test_vm specified')
|
||||
#self.assertTrue(self.os_client.start_vm_instance(self.test_vm))
|
||||
self.assertTrue(self.os_client.stop_vm_instance(self.test_vm))
|
||||
# Suspend/resume skipped if not supported
|
||||
|
||||
def test_delete_vm_invalid(self):
|
||||
"""
|
||||
Test that delete_vm_instance returns False for an invalid VM name.
|
||||
"""
|
||||
self.assertFalse(self.os_client.delete_vm_instance('nonexistent-vm'))
|
||||
|
||||
# --- DataVolume Tests ---
|
||||
# --- DataVolume Tests ---
|
||||
def test_datavolume_phase(self) -> None:
|
||||
"""
|
||||
Test that get_datavolume_phase returns a string for a valid datavolume.
|
||||
"""
|
||||
phase = self.os_client.get_datavolume_phase('test-dv')
|
||||
self.assertIsInstance(phase, str)
|
||||
|
||||
def test_datavolume_phase_invalid(self):
|
||||
"""
|
||||
Test that get_datavolume_phase returns a string for an invalid datavolume.
|
||||
"""
|
||||
phase = self.os_client.get_datavolume_phase('nonexistent-dv')
|
||||
self.assertIsInstance(phase, str)
|
||||
163
server/tests/services/openshift/test_deployment.py
Normal file
163
server/tests/services/openshift/test_deployment.py
Normal file
@@ -0,0 +1,163 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
#
|
||||
# Copyright (c) 2024 Virtual Cable S.L.U.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification,
|
||||
# are permitted provided that the following conditions are met:
|
||||
#
|
||||
# * Redistributions of source code must retain the above copyright notice,
|
||||
# this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
# * Neither the name of Virtual Cable S.L.U. nor the names of its contributors
|
||||
# may be used to endorse or promote products derived from this software
|
||||
# without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
"""
|
||||
Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
"""
|
||||
|
||||
from unittest import mock
|
||||
|
||||
from tests.services.openshift import fixtures
|
||||
from uds.core.types.states import TaskState
|
||||
from tests.utils.test import UDSTransactionTestCase
|
||||
|
||||
|
||||
class TestOpenshiftDeployment(UDSTransactionTestCase):
|
||||
def _create_userservice(self):
|
||||
"""
|
||||
Helper to create a userservice instance with a preset name for deployment operation tests.
|
||||
Returns:
|
||||
userservice: A userservice object with name 'test-vm'.
|
||||
"""
|
||||
userservice = fixtures.create_userservice()
|
||||
userservice._name = 'test-vm'
|
||||
return userservice
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
fixtures.clear()
|
||||
|
||||
# --- Create operation tests ---
|
||||
def test_op_create_success(self) -> None:
|
||||
"""
|
||||
Test successful VM creation operation.
|
||||
Should clear the waiting_name flag after creation.
|
||||
"""
|
||||
userservice = self._create_userservice()
|
||||
userservice._waiting_name = False
|
||||
api = userservice.service().api
|
||||
with mock.patch.object(api, 'get_vm_pvc_or_dv_name', return_value=('test-pvc', 'pvc')), \
|
||||
mock.patch.object(api, 'get_pvc_size', return_value='10Gi'), \
|
||||
mock.patch.object(api, 'create_vm_from_pvc', return_value=True), \
|
||||
mock.patch.object(api, 'wait_for_datavolume_clone_progress', return_value=True):
|
||||
userservice.op_create()
|
||||
self.assertFalse(userservice._waiting_name)
|
||||
|
||||
def test_op_create_failure(self) -> None:
|
||||
"""
|
||||
Test failed VM creation operation.
|
||||
Should set the waiting_name flag if creation fails.
|
||||
"""
|
||||
userservice = self._create_userservice()
|
||||
api = userservice.service().api
|
||||
userservice._waiting_name = False
|
||||
with mock.patch.object(api, 'get_vm_pvc_or_dv_name', return_value=('test-pvc', 'pvc')), \
|
||||
mock.patch.object(api, 'get_pvc_size', return_value='10Gi'), \
|
||||
mock.patch.object(api, 'create_vm_from_pvc', return_value=False):
|
||||
userservice.op_create()
|
||||
self.assertTrue(userservice._waiting_name)
|
||||
|
||||
def test_op_create_checker_running(self) -> None:
|
||||
"""
|
||||
Test create checker returns RUNNING when datavolume phase is pending.
|
||||
"""
|
||||
userservice = self._create_userservice()
|
||||
api = userservice.service().api
|
||||
with mock.patch.object(api, 'get_datavolume_phase', return_value='Pending'):
|
||||
state = userservice.op_create_checker()
|
||||
self.assertEqual(state, TaskState.RUNNING)
|
||||
|
||||
def test_op_create_checker_finished(self) -> None:
|
||||
"""
|
||||
Test create checker returns FINISHED when datavolume phase is succeeded and VM info is available.
|
||||
"""
|
||||
userservice = self._create_userservice()
|
||||
api = userservice.service().api
|
||||
with mock.patch.object(api, 'get_datavolume_phase', return_value='Succeeded'), \
|
||||
mock.patch.object(api, 'get_vm_info', return_value=fixtures.VMS[0]), \
|
||||
mock.patch.object(api, 'get_vm_instance_info', return_value=fixtures.VM_INSTANCES[0]):
|
||||
state = userservice.op_create_checker()
|
||||
self.assertEqual(state, TaskState.FINISHED)
|
||||
|
||||
# --- Delete operation tests ---
|
||||
def test_op_delete_checker_finished(self) -> None:
|
||||
"""
|
||||
Test delete checker returns FINISHED when VM info is None (deleted).
|
||||
"""
|
||||
userservice = self._create_userservice()
|
||||
api = userservice.service().api
|
||||
with mock.patch.object(api, 'get_vm_info', return_value=None):
|
||||
state = userservice.op_delete_checker()
|
||||
self.assertEqual(state, TaskState.FINISHED)
|
||||
|
||||
def test_op_delete_checker_running(self) -> None:
|
||||
"""
|
||||
Test delete checker returns RUNNING when VM info still exists.
|
||||
"""
|
||||
userservice = self._create_userservice()
|
||||
api = userservice.service().api
|
||||
with mock.patch.object(api, 'get_vm_info', return_value=fixtures.VMS[0]):
|
||||
state = userservice.op_delete_checker()
|
||||
self.assertEqual(state, TaskState.RUNNING)
|
||||
|
||||
def test_op_delete_completed_checker(self) -> None:
|
||||
"""
|
||||
Test delete completed checker always returns FINISHED.
|
||||
"""
|
||||
userservice = self._create_userservice()
|
||||
state = userservice.op_delete_completed_checker()
|
||||
self.assertEqual(state, TaskState.FINISHED)
|
||||
|
||||
# --- Cancel operation tests ---
|
||||
def test_op_cancel_checker_finished(self) -> None:
|
||||
"""
|
||||
Test cancel checker returns FINISHED when VM info is None (cancelled).
|
||||
"""
|
||||
userservice = self._create_userservice()
|
||||
api = userservice.service().api
|
||||
with mock.patch.object(api, 'get_vm_info', return_value=None):
|
||||
state = userservice.op_cancel_checker()
|
||||
self.assertEqual(state, TaskState.FINISHED)
|
||||
|
||||
def test_op_cancel_checker_running(self) -> None:
|
||||
"""
|
||||
Test cancel checker returns RUNNING when VM info still exists.
|
||||
"""
|
||||
userservice = self._create_userservice()
|
||||
api = userservice.service().api
|
||||
with mock.patch.object(api, 'get_vm_info', return_value=fixtures.VMS[0]):
|
||||
state = userservice.op_cancel_checker()
|
||||
self.assertEqual(state, TaskState.RUNNING)
|
||||
|
||||
def test_op_cancel_completed_checker(self) -> None:
|
||||
"""
|
||||
Test cancel completed checker always returns FINISHED.
|
||||
"""
|
||||
userservice = self._create_userservice()
|
||||
state = userservice.op_cancel_completed_checker()
|
||||
self.assertEqual(state, TaskState.FINISHED)
|
||||
163
server/tests/services/openshift/test_deployment_fixed.py
Normal file
163
server/tests/services/openshift/test_deployment_fixed.py
Normal file
@@ -0,0 +1,163 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
#
|
||||
# Copyright (c) 2024 Virtual Cable S.L.U.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification,
|
||||
# are permitted provided that the following conditions are met:
|
||||
#
|
||||
# * Redistributions of source code must retain the above copyright notice,
|
||||
# this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
# * Neither the name of Virtual Cable S.L.U. nor the names of its contributors
|
||||
# may be used to endorse or promote products derived from this software
|
||||
# without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
"""
|
||||
Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
"""
|
||||
|
||||
from unittest import mock
|
||||
from uds.core.types.states import TaskState
|
||||
from tests.services.openshift import fixtures
|
||||
from tests.utils.test import UDSTransactionTestCase
|
||||
|
||||
class TestOpenshiftUserServiceFixed(UDSTransactionTestCase):
|
||||
def _create_userservice_fixed(self):
|
||||
"""
|
||||
Helper to create a fixed userservice instance for deployment_fixed operation tests.
|
||||
Returns:
|
||||
userservice: A fixed userservice object with name 'fixed-vm'.
|
||||
"""
|
||||
userservice = fixtures.create_userservice_fixed()
|
||||
userservice._name = 'fixed-vm'
|
||||
return userservice
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
fixtures.clear()
|
||||
|
||||
# --- Start operation tests ---
|
||||
def test_op_start_vm_running(self) -> None:
|
||||
"""
|
||||
Test that op_start does not start VM if it is already running.
|
||||
"""
|
||||
userservice = self._create_userservice_fixed()
|
||||
api = userservice.service().provider().api
|
||||
vm_mock = mock.Mock()
|
||||
vm_mock.status.is_off.return_value = False
|
||||
with mock.patch.object(api, 'get_vm_info', return_value=vm_mock):
|
||||
with mock.patch.object(api, 'start_vm_instance') as start_mock:
|
||||
userservice.op_start()
|
||||
start_mock.assert_not_called()
|
||||
|
||||
def test_op_start_vm_off(self) -> None:
|
||||
"""
|
||||
Test that op_start starts VM if it is off.
|
||||
"""
|
||||
userservice = self._create_userservice_fixed()
|
||||
api = userservice.service().provider().api
|
||||
vm_mock = mock.Mock()
|
||||
vm_mock.status.is_off.return_value = True
|
||||
with mock.patch.object(api, 'get_vm_info', return_value=vm_mock):
|
||||
with mock.patch.object(api, 'start_vm_instance') as start_mock:
|
||||
userservice.op_start()
|
||||
start_mock.assert_called_once_with('fixed-vm')
|
||||
|
||||
# --- Stop operation tests ---
|
||||
def test_op_stop_vm_off(self) -> None:
|
||||
"""
|
||||
Test that op_stop does not stop VM if it is already off.
|
||||
"""
|
||||
userservice = self._create_userservice_fixed()
|
||||
api = userservice.service().provider().api
|
||||
vm_mock = mock.Mock()
|
||||
vm_mock.status.is_off.return_value = True
|
||||
with mock.patch.object(api, 'get_vm_info', return_value=vm_mock):
|
||||
with mock.patch.object(api, 'stop_vm_instance') as stop_mock:
|
||||
userservice.op_stop()
|
||||
stop_mock.assert_not_called()
|
||||
|
||||
def test_op_stop_vm_running(self) -> None:
|
||||
"""
|
||||
Test that op_stop stops VM if it is running.
|
||||
"""
|
||||
userservice = self._create_userservice_fixed()
|
||||
api = userservice.service().provider().api
|
||||
vm_mock = mock.Mock()
|
||||
vm_mock.status.is_off.return_value = False
|
||||
with mock.patch.object(api, 'get_vm_info', return_value=vm_mock):
|
||||
with mock.patch.object(api, 'stop_vm_instance') as stop_mock:
|
||||
userservice.op_stop()
|
||||
stop_mock.assert_called_once_with('fixed-vm')
|
||||
|
||||
# --- Start checker tests ---
|
||||
def test_op_start_checker_running(self) -> None:
|
||||
"""
|
||||
Test that op_start_checker returns RUNNING if VM status is not error.
|
||||
"""
|
||||
userservice = self._create_userservice_fixed()
|
||||
api = userservice.service().provider().api
|
||||
status_mock = mock.Mock()
|
||||
status_mock.is_error.return_value = False
|
||||
vm_mock = mock.Mock()
|
||||
vm_mock.status = status_mock
|
||||
with mock.patch.object(api, 'get_vm_info', return_value=vm_mock):
|
||||
state = userservice.op_start_checker()
|
||||
self.assertEqual(state, TaskState.RUNNING)
|
||||
|
||||
def test_op_start_checker_finished(self) -> None:
|
||||
"""
|
||||
Test that op_start_checker returns FINISHED if VM status is RUNNING.
|
||||
"""
|
||||
userservice = self._create_userservice_fixed()
|
||||
api = userservice.service().provider().api
|
||||
vm_mock = mock.Mock()
|
||||
from uds.services.OpenShift.openshift import types as opensh_types
|
||||
vm_mock.status = opensh_types.VMStatus.RUNNING
|
||||
with mock.patch.object(api, 'get_vm_info', return_value=vm_mock):
|
||||
state = userservice.op_start_checker()
|
||||
self.assertEqual(state, TaskState.FINISHED)
|
||||
|
||||
# --- Stop checker tests ---
|
||||
def test_op_stop_checker_running(self) -> None:
|
||||
"""
|
||||
Test that op_stop_checker returns RUNNING if VM status is not error.
|
||||
"""
|
||||
userservice = self._create_userservice_fixed()
|
||||
api = userservice.service().provider().api
|
||||
vm_mock = mock.Mock()
|
||||
status_mock = mock.Mock()
|
||||
status_mock.is_error.return_value = False
|
||||
vm_mock.status = status_mock
|
||||
with mock.patch.object(api, 'get_vm_info', return_value=vm_mock):
|
||||
state = userservice.op_stop_checker()
|
||||
self.assertEqual(state, TaskState.RUNNING)
|
||||
|
||||
def test_op_stop_checker_finished(self) -> None:
|
||||
"""
|
||||
Test that op_stop_checker returns FINISHED if VM status is STOPPED.
|
||||
"""
|
||||
userservice = self._create_userservice_fixed()
|
||||
api = userservice.service().provider().api
|
||||
vm_mock = mock.Mock()
|
||||
from uds.services.OpenShift.openshift import types as opensh_types
|
||||
vm_mock.status = opensh_types.VMStatus.STOPPED
|
||||
with mock.patch.object(api, 'get_vm_info', return_value=vm_mock):
|
||||
state = userservice.op_stop_checker()
|
||||
self.assertEqual(state, TaskState.FINISHED)
|
||||
|
||||
|
||||
143
server/tests/services/openshift/test_provider.py
Normal file
143
server/tests/services/openshift/test_provider.py
Normal file
@@ -0,0 +1,143 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
#
|
||||
# Copyright (c) 2024 Virtual Cable S.L.U.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification,
|
||||
# are permitted provided that the following conditions are met:
|
||||
#
|
||||
# * Redistributions of source code must retain the above copyright notice,
|
||||
# this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
# * Neither the name of Virtual Cable S.L.U. nor the names of its contributors
|
||||
# may be used to endorse or promote products derived from this software
|
||||
# without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
"""
|
||||
Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
"""
|
||||
import typing
|
||||
from unittest import mock
|
||||
|
||||
from uds.core import types, ui, environment
|
||||
from uds.services.OpenShift.provider import OpenshiftProvider
|
||||
|
||||
from . import fixtures
|
||||
|
||||
from tests.utils.test import UDSTransactionTestCase
|
||||
|
||||
|
||||
class TestOpenshiftProvider(UDSTransactionTestCase):
|
||||
def setUp(self) -> None:
|
||||
"""
|
||||
Set up test environment and clear fixtures before each test.
|
||||
"""
|
||||
super().setUp()
|
||||
fixtures.clear()
|
||||
|
||||
# --- Provider Data Tests ---
|
||||
def test_provider_data(self) -> None:
|
||||
"""
|
||||
Test provider data fields and types for correct initialization.
|
||||
"""
|
||||
provider = fixtures.create_provider()
|
||||
self.assertEqual(provider.cluster_url.value, fixtures.PROVIDER_VALUES_DICT['cluster_url'])
|
||||
self.assertEqual(provider.api_url.value, fixtures.PROVIDER_VALUES_DICT['api_url'])
|
||||
self.assertEqual(provider.username.value, fixtures.PROVIDER_VALUES_DICT['username'])
|
||||
self.assertEqual(provider.password.value, fixtures.PROVIDER_VALUES_DICT['password'])
|
||||
self.assertEqual(provider.namespace.value, fixtures.PROVIDER_VALUES_DICT['namespace'])
|
||||
self.assertEqual(provider.verify_ssl.value, fixtures.PROVIDER_VALUES_DICT['verify_ssl'])
|
||||
if not isinstance(provider.concurrent_creation_limit, ui.gui.NumericField):
|
||||
self.fail('concurrent_creation_limit is not a NumericField')
|
||||
self.assertEqual(provider.concurrent_creation_limit.as_int(), fixtures.PROVIDER_VALUES_DICT['concurrent_creation_limit'])
|
||||
if not isinstance(provider.concurrent_removal_limit, ui.gui.NumericField):
|
||||
self.fail('concurrent_removal_limit is not a NumericField')
|
||||
self.assertEqual(provider.concurrent_removal_limit.as_int(), fixtures.PROVIDER_VALUES_DICT['concurrent_removal_limit'])
|
||||
self.assertEqual(provider.timeout.as_int(), fixtures.PROVIDER_VALUES_DICT['timeout'])
|
||||
|
||||
# --- Provider Test Method ---
|
||||
def test_provider_test(self) -> None:
|
||||
"""
|
||||
Test the static provider test method and test_connection logic.
|
||||
"""
|
||||
with fixtures.patched_provider() as provider:
|
||||
api = typing.cast(mock.MagicMock, provider.api)
|
||||
for ret_val in [True, False]:
|
||||
api.test.reset_mock()
|
||||
api.test.return_value = ret_val
|
||||
# Patch test_connection to return ret_val for static test
|
||||
with mock.patch('uds.services.OpenShift.provider.OpenshiftProvider.test_connection', return_value=ret_val):
|
||||
result = OpenshiftProvider.test(environment.Environment.temporary_environment(), fixtures.PROVIDER_VALUES_DICT)
|
||||
self.assertIsInstance(result, types.core.TestResult)
|
||||
self.assertEqual(result.success, ret_val)
|
||||
self.assertIsInstance(result.error, str)
|
||||
# Ensure test_connection calls api.test
|
||||
provider.test_connection()
|
||||
api.test.assert_called_once_with()
|
||||
|
||||
# --- Provider Availability ---
|
||||
def test_provider_is_available(self) -> None:
|
||||
"""
|
||||
Test the provider is_available method and cache behavior.
|
||||
"""
|
||||
with fixtures.patched_provider() as provider:
|
||||
api = typing.cast(mock.MagicMock, provider.api)
|
||||
# First, true result
|
||||
self.assertEqual(provider.is_available(), True)
|
||||
api.test.assert_called_once_with()
|
||||
api.test.reset_mock()
|
||||
# Now, even if set test to false, should return true due to cache
|
||||
api.test.return_value = False
|
||||
self.assertEqual(provider.is_available(), True)
|
||||
api.test.assert_not_called()
|
||||
# clear cache of method
|
||||
provider.is_available.cache_clear() # type: ignore # cache_clear() is added by decorator
|
||||
self.assertEqual(provider.is_available(), False)
|
||||
api.test.assert_called_once_with()
|
||||
|
||||
# --- Provider API Methods ---
|
||||
def test_provider_api_methods(self) -> None:
|
||||
"""
|
||||
Test provider API methods for VM operations and info retrieval.
|
||||
"""
|
||||
with fixtures.patched_provider() as provider:
|
||||
api = typing.cast(mock.MagicMock, provider.api)
|
||||
self.assertEqual(provider.test_connection(), True)
|
||||
api.test.assert_called_once_with()
|
||||
self.assertEqual(provider.api.list_vms(), fixtures.VMS)
|
||||
self.assertEqual(provider.api.get_vm_info('vm-1'), fixtures.VMS[0])
|
||||
self.assertEqual(provider.api.get_vm_instance_info('vm-1'), fixtures.VM_INSTANCES[0])
|
||||
self.assertTrue(provider.api.start_vm_instance('vm-1'))
|
||||
self.assertTrue(provider.api.stop_vm_instance('vm-1'))
|
||||
self.assertTrue(provider.api.delete_vm_instance('vm-1'))
|
||||
|
||||
# --- Name Sanitization ---
|
||||
def test_sanitized_name(self) -> None:
|
||||
"""
|
||||
Test name sanitization utility for various input cases.
|
||||
"""
|
||||
provider = fixtures.create_provider()
|
||||
test_cases = [
|
||||
('Test-VM-1', 'test-vm-1'),
|
||||
('Test_VM@2', 'test-vm-2'),
|
||||
('My Test VM!!!', 'my-test-vm'),
|
||||
('Test !!! this is', 'test-this-is'),
|
||||
('UDS-Pub-Hello World!!--2025065122-v1', 'uds-pub-hello-world-2025065122-v1'),
|
||||
('a' * 100, 'a' * 63), # Test truncation
|
||||
]
|
||||
for input_name, expected in test_cases:
|
||||
self.assertEqual(provider.sanitized_name(input_name), expected)
|
||||
171
server/tests/services/openshift/test_publication.py
Normal file
171
server/tests/services/openshift/test_publication.py
Normal file
@@ -0,0 +1,171 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
#
|
||||
# Copyright (c) 2024 Virtual Cable S.L.U.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification,
|
||||
# are permitted provided that the following conditions are met:
|
||||
#
|
||||
# * Redistributions of source code must retain the above copyright notice,
|
||||
# this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
# * Neither the name of Virtual Cable S.L.U. nor the names of its contributors
|
||||
# may be used to endorse or promote products derived from this software
|
||||
# without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
"""
|
||||
Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
"""
|
||||
|
||||
import typing
|
||||
from unittest import mock
|
||||
|
||||
from uds.core import types
|
||||
from tests.services.openshift import fixtures
|
||||
from tests.utils.test import UDSTransactionTestCase
|
||||
|
||||
|
||||
class TestOpenshiftPublication(UDSTransactionTestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
fixtures.clear()
|
||||
|
||||
def test_op_create_and_checker(self) -> None:
|
||||
"""
|
||||
Test op_create and op_create_checker flow
|
||||
"""
|
||||
with fixtures.patched_provider() as provider:
|
||||
api = typing.cast(mock.MagicMock, provider.api)
|
||||
service = fixtures.create_service(provider=provider)
|
||||
publication = fixtures.create_publication(service=service)
|
||||
|
||||
api.get_vm_pvc_or_dv_name.return_value = ('test-pvc', 'pvc')
|
||||
api.get_pvc_size.return_value = '10Gi'
|
||||
api.create_vm_from_pvc.return_value = True
|
||||
api.wait_for_datavolume_clone_progress.return_value = True
|
||||
api.get_vm_info.return_value = None
|
||||
|
||||
publication.op_create()
|
||||
api.get_vm_info.return_value = None
|
||||
state = publication.op_create_checker()
|
||||
self.assertEqual(state, types.states.TaskState.RUNNING)
|
||||
|
||||
def get_vm_info_side_effect(name: str) -> mock.Mock | None:
|
||||
return mock.Mock(status=mock.Mock()) if name == publication._name else None
|
||||
|
||||
api.get_vm_info.side_effect = get_vm_info_side_effect
|
||||
state = publication.op_create_checker()
|
||||
self.assertEqual(state, types.states.TaskState.FINISHED)
|
||||
|
||||
def test_op_create_completed_and_checker(self) -> None:
|
||||
"""
|
||||
Test op_create_completed and op_create_completed_checker flow
|
||||
"""
|
||||
with fixtures.patched_provider() as provider:
|
||||
api = typing.cast(mock.MagicMock, provider.api)
|
||||
service = fixtures.create_service(provider=provider)
|
||||
publication = fixtures.create_publication(service=service)
|
||||
|
||||
# VM running
|
||||
running_status = mock.Mock()
|
||||
running_status.is_running.return_value = True
|
||||
running_vm = mock.Mock(status=running_status)
|
||||
|
||||
def get_vm_info_side_effect(name: str, **kwargs: dict[str, typing.Any]) -> mock.Mock | None:
|
||||
return running_vm if name == 'test-vm' else None
|
||||
|
||||
api.get_vm_info.side_effect = get_vm_info_side_effect
|
||||
publication._name = 'test-vm'
|
||||
publication.op_create_completed()
|
||||
api.stop_vm_instance.assert_called_with('test-vm')
|
||||
|
||||
# VM stopped
|
||||
stopped_status = mock.Mock()
|
||||
stopped_status.is_running.return_value = False
|
||||
stopped_vm = mock.Mock(status=stopped_status)
|
||||
|
||||
api.get_vm_info.side_effect = None
|
||||
api.get_vm_info.return_value = stopped_vm
|
||||
api.stop_vm_instance.reset_mock()
|
||||
publication.op_create_completed()
|
||||
api.stop_vm_instance.assert_not_called()
|
||||
|
||||
# Checker: VM not found
|
||||
api.get_vm_info.return_value = None
|
||||
state = publication.op_create_completed_checker()
|
||||
self.assertEqual(state, types.states.TaskState.FINISHED)
|
||||
|
||||
# Checker: VM stopped
|
||||
api.get_vm_info.return_value = stopped_vm
|
||||
state = publication.op_create_completed_checker()
|
||||
self.assertEqual(state, types.states.TaskState.FINISHED)
|
||||
|
||||
# Checker: VM running
|
||||
api.get_vm_info.return_value = running_vm
|
||||
state = publication.op_create_completed_checker()
|
||||
self.assertEqual(state, types.states.TaskState.RUNNING)
|
||||
|
||||
def test_publication_create(self) -> None:
|
||||
"""
|
||||
Test publication creation (publish)
|
||||
"""
|
||||
with fixtures.patched_provider() as provider:
|
||||
api = typing.cast(mock.MagicMock, provider.api)
|
||||
service = fixtures.create_service(provider=provider)
|
||||
publication = fixtures.create_publication(service=service)
|
||||
|
||||
api.get_vm_pvc_or_dv_name.return_value = ('test-pvc', 'pvc')
|
||||
api.get_pvc_size.return_value = '10Gi'
|
||||
api.create_vm_from_pvc.return_value = True
|
||||
api.wait_for_datavolume_clone_progress.return_value = True
|
||||
|
||||
call_count = {"count": 0}
|
||||
def vm_info_side_effect(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
|
||||
if call_count["count"] < 2:
|
||||
call_count["count"] += 1
|
||||
return fixtures.VMS[0]
|
||||
|
||||
ready_vm = mock.Mock()
|
||||
ready_vm.status = mock.Mock()
|
||||
ready_vm.name = publication._name
|
||||
return ready_vm
|
||||
api.get_vm_info.side_effect = vm_info_side_effect
|
||||
|
||||
state = publication.publish()
|
||||
self.assertEqual(state, types.states.State.RUNNING)
|
||||
|
||||
state = publication.check_state()
|
||||
api.get_vm_pvc_or_dv_name.assert_called()
|
||||
api.get_pvc_size.assert_called()
|
||||
api.create_vm_from_pvc.assert_called()
|
||||
|
||||
for _ in range(10):
|
||||
state = publication.check_state()
|
||||
if state == types.states.TaskState.FINISHED:
|
||||
break
|
||||
self.assertEqual(state, types.states.TaskState.RUNNING)
|
||||
self.assertEqual(publication.get_template_id(), publication._name)
|
||||
|
||||
def test_get_template_id(self) -> None:
|
||||
"""
|
||||
Test template ID retrieval (get_template_id)
|
||||
"""
|
||||
service = fixtures.create_service()
|
||||
publication = fixtures.create_publication(service=service)
|
||||
publication._name = 'test-template'
|
||||
template_id = publication.get_template_id()
|
||||
self.assertEqual(template_id, 'test-template')
|
||||
@@ -0,0 +1,64 @@
|
||||
import pickle
|
||||
|
||||
from tests.services.openshift import fixtures
|
||||
from tests.utils.test import UDSTransactionTestCase
|
||||
|
||||
|
||||
class TestOpenshiftDeploymentSerialization(UDSTransactionTestCase):
|
||||
def setUp(self) -> None:
|
||||
"""
|
||||
Set up test environment and clear fixtures before each test.
|
||||
"""
|
||||
super().setUp()
|
||||
fixtures.clear()
|
||||
|
||||
def _make_userservice(self):
|
||||
"""
|
||||
Helper to create a userservice with all fields set for serialization tests.
|
||||
"""
|
||||
userservice = fixtures.create_userservice()
|
||||
userservice._name = 'test-vm'
|
||||
userservice._ip = '192.168.1.100'
|
||||
userservice._mac = '00:11:22:33:44:55'
|
||||
userservice._vmid = 'test-vm-id'
|
||||
userservice._reason = 'test-reason'
|
||||
userservice._waiting_name = True
|
||||
return userservice
|
||||
|
||||
# --- Serialization Tests ---
|
||||
def test_userservice_serialization(self) -> None:
|
||||
"""
|
||||
Test that userservice object is correctly serialized and deserialized with all fields preserved.
|
||||
"""
|
||||
userservice = self._make_userservice()
|
||||
data = pickle.dumps(userservice)
|
||||
userservice2 = pickle.loads(data)
|
||||
|
||||
self.assertEqual(userservice2._name, 'test-vm')
|
||||
self.assertEqual(userservice2._ip, '192.168.1.100')
|
||||
self.assertEqual(userservice2._mac, '00:11:22:33:44:55')
|
||||
self.assertEqual(userservice2._vmid, 'test-vm-id')
|
||||
self.assertEqual(userservice2._reason, 'test-reason')
|
||||
self.assertTrue(userservice2._waiting_name)
|
||||
|
||||
def test_userservice_methods_after_serialization(self) -> None:
|
||||
"""
|
||||
Test that userservice methods return correct values after serialization and deserialization.
|
||||
"""
|
||||
userservice = self._make_userservice()
|
||||
data = pickle.dumps(userservice)
|
||||
userservice2 = pickle.loads(data)
|
||||
|
||||
self.assertEqual(userservice2.get_name(), 'test-vm')
|
||||
self.assertEqual(userservice2.get_ip(), '192.168.1.100')
|
||||
self.assertEqual(userservice2._mac, '00:11:22:33:44:55')
|
||||
|
||||
# --- Field Presence Tests ---
|
||||
def test_autoserializable_fields(self) -> None:
|
||||
"""
|
||||
Test that all expected autoserializable fields are present in userservice object.
|
||||
"""
|
||||
userservice = self._make_userservice()
|
||||
expected = ['_name', '_ip', '_mac', '_vmid', '_reason', '_waiting_name']
|
||||
for field in expected:
|
||||
self.assertTrue(hasattr(userservice, field), f"Missing field: {field}")
|
||||
@@ -0,0 +1,24 @@
|
||||
import pickle
|
||||
from tests.services.openshift import fixtures
|
||||
from tests.utils.test import UDSTransactionTestCase
|
||||
|
||||
class TestOpenshiftUserServiceFixed(UDSTransactionTestCase):
|
||||
def setUp(self) -> None:
|
||||
"""
|
||||
Set up test environment and clear fixtures before each test.
|
||||
"""
|
||||
super().setUp()
|
||||
fixtures.clear()
|
||||
|
||||
# --- Serialization Tests ---
|
||||
def test_userservice_fixed_serialization(self) -> None:
|
||||
"""
|
||||
Test that userservice_fixed object is correctly serialized and deserialized with all fields preserved.
|
||||
"""
|
||||
userservice = fixtures.create_userservice_fixed()
|
||||
userservice._name = 'fixed-vm'
|
||||
userservice._reason = 'fixed-reason'
|
||||
data = pickle.dumps(userservice)
|
||||
userservice2 = pickle.loads(data)
|
||||
self.assertEqual(userservice2._name, 'fixed-vm')
|
||||
self.assertEqual(userservice2._reason, 'fixed-reason')
|
||||
@@ -0,0 +1,83 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
#
|
||||
# Copyright (c) 2024 Virtual Cable S.L.U.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification,
|
||||
# are permitted provided that the following conditions are met:
|
||||
#
|
||||
# * Redistributions of source code must retain the above copyright notice,
|
||||
# this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
# * Neither the name of Virtual Cable S.L.U. nor the names of its contributors
|
||||
# may be used to endorse or promote products derived from this software
|
||||
# without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
"""
|
||||
Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
"""
|
||||
from tests.services.openshift import fixtures
|
||||
from tests.utils.test import UDSTransactionTestCase
|
||||
from uds.services.OpenShift.provider import OpenshiftProvider
|
||||
|
||||
PROVIDER_SERIALIZE_DATA = (
|
||||
'{'
|
||||
'"cluster_url": "https://oauth-openshift.apps-crc.testing", '
|
||||
'"api_url": "https://api.crc.testing:6443", '
|
||||
'"username": "kubeadmin", '
|
||||
'"password": "test-password", '
|
||||
'"namespace": "default", '
|
||||
'"verify_ssl": false, '
|
||||
'"concurrent_creation_limit": 1, '
|
||||
'"concurrent_removal_limit": 1, '
|
||||
'"timeout": 10'
|
||||
'}'
|
||||
)
|
||||
|
||||
class TestOpenshiftProviderSerialization(UDSTransactionTestCase):
|
||||
# --- Serialization Tests ---
|
||||
def test_provider_methods_after_serialization(self) -> None:
|
||||
"""
|
||||
Test that provider methods return correct values after serialization and deserialization.
|
||||
"""
|
||||
from uds.core import environment
|
||||
|
||||
provider = fixtures.create_provider()
|
||||
data = provider.serialize()
|
||||
|
||||
provider2 = OpenshiftProvider(environment=environment.Environment.testing_environment())
|
||||
provider2.deserialize(data)
|
||||
|
||||
self.assertEqual(str(provider2.type_name), 'Openshift Provider')
|
||||
self.assertEqual(str(provider2.type_description), 'Openshift based VMs provider')
|
||||
self.assertEqual(provider2.cluster_url.value, fixtures.PROVIDER_VALUES_DICT['cluster_url'])
|
||||
self.assertEqual(provider2.api_url.value, fixtures.PROVIDER_VALUES_DICT['api_url'])
|
||||
|
||||
def test_provider_serialization(self) -> None:
|
||||
"""
|
||||
Test that all provider fields are correctly serialized and deserialized.
|
||||
"""
|
||||
from uds.core import environment
|
||||
|
||||
provider = fixtures.create_provider()
|
||||
data = provider.serialize()
|
||||
|
||||
provider2 = OpenshiftProvider(environment=environment.Environment.testing_environment())
|
||||
provider2.deserialize(data)
|
||||
|
||||
for field in fixtures.PROVIDER_VALUES_DICT:
|
||||
self.assertEqual(getattr(provider2, field).value, fixtures.PROVIDER_VALUES_DICT[field])
|
||||
@@ -0,0 +1,105 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
#
|
||||
# Copyright (c) 2024 Virtual Cable S.L.U.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification,
|
||||
# are permitted provided that the following conditions are met:
|
||||
#
|
||||
# * Redistributions of source code must retain the above copyright notice,
|
||||
# this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
# * Neither the name of Virtual Cable S.L.U. nor the names of its contributors
|
||||
# may be used to endorse or promote products derived from this software
|
||||
# without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
"""
|
||||
Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
"""
|
||||
import pickle
|
||||
|
||||
from tests.services.openshift import fixtures
|
||||
|
||||
from tests.utils.test import UDSTransactionTestCase
|
||||
|
||||
|
||||
class TestOpenshiftPublicationSerialization(UDSTransactionTestCase):
|
||||
EXPECTED_FIELDS = {'_name', '_waiting_name', '_reason', '_queue', '_vmid', '_is_flagged_for_destroy'}
|
||||
|
||||
def setUp(self) -> None:
|
||||
"""
|
||||
Set up test environment and clear fixtures before each test.
|
||||
"""
|
||||
super().setUp()
|
||||
fixtures.clear()
|
||||
|
||||
def _make_publication(self):
|
||||
"""
|
||||
Helper to create a publication with all fields set for serialization tests.
|
||||
"""
|
||||
publication = fixtures.create_publication()
|
||||
publication._name = 'test-template'
|
||||
publication._reason = 'test-reason'
|
||||
publication._waiting_name = True
|
||||
return publication
|
||||
|
||||
# --- Field Check Helper ---
|
||||
def check_fields(self, instance: 'fixtures.publication.OpenshiftTemplatePublication') -> None:
|
||||
"""
|
||||
Helper to check expected field values in a publication instance.
|
||||
"""
|
||||
self.assertEqual(instance._name, 'test-template')
|
||||
self.assertEqual(instance._reason, 'test-reason')
|
||||
self.assertTrue(instance._waiting_name)
|
||||
|
||||
# --- Serialization Tests ---
|
||||
def test_autoserialization_fields(self) -> None:
|
||||
"""
|
||||
Test that autoserializable fields match the expected set.
|
||||
"""
|
||||
publication = fixtures.create_publication()
|
||||
fields = set(f[0] for f in publication._autoserializable_fields())
|
||||
self.assertSetEqual(fields, self.EXPECTED_FIELDS)
|
||||
|
||||
def test_pickle_serialization(self) -> None:
|
||||
"""
|
||||
Test that publication object is correctly serialized and deserialized using pickle.
|
||||
"""
|
||||
publication = self._make_publication()
|
||||
data = pickle.dumps(publication)
|
||||
publication2 = pickle.loads(data)
|
||||
self.check_fields(publication2)
|
||||
|
||||
def test_marshal_unmarshal(self) -> None:
|
||||
"""
|
||||
Test that publication object is correctly marshaled and unmarshaled.
|
||||
"""
|
||||
publication = self._make_publication()
|
||||
marshaled = publication.marshal()
|
||||
publication2 = fixtures.create_publication()
|
||||
publication2.unmarshal(marshaled)
|
||||
self.check_fields(publication2)
|
||||
|
||||
def test_methods_after_serialization(self) -> None:
|
||||
"""
|
||||
Test that publication methods return correct values after serialization and deserialization.
|
||||
"""
|
||||
publication = self._make_publication()
|
||||
data = pickle.dumps(publication)
|
||||
publication2 = pickle.loads(data)
|
||||
self.assertEqual(publication2._name, 'test-template')
|
||||
self.assertEqual(publication2.get_template_id(), 'test-template')
|
||||
190
server/tests/services/openshift/test_service.py
Normal file
190
server/tests/services/openshift/test_service.py
Normal file
@@ -0,0 +1,190 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Unit tests for OpenshiftService logic.
|
||||
All tests use fixtures for setup and mock dependencies.
|
||||
Tests are grouped by functionality: configuration, utility methods, availability, VM operations, exception handling, and deletion.
|
||||
"""
|
||||
|
||||
#
|
||||
# Copyright (c) 2024 Virtual Cable S.L.U.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification,
|
||||
# are permitted provided that the following conditions are met:
|
||||
|
||||
"""
|
||||
Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
Reorganizado y corregido por GitHub Copilot
|
||||
"""
|
||||
|
||||
import typing
|
||||
from unittest import mock
|
||||
from uds.services.OpenShift.openshift import exceptions as morph_exceptions
|
||||
|
||||
from tests.services.openshift import fixtures
|
||||
from tests.utils.test import UDSTransactionTestCase
|
||||
|
||||
|
||||
class TestOpenshiftService(UDSTransactionTestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
fixtures.clear()
|
||||
|
||||
def _create_service_with_provider(self):
|
||||
"""
|
||||
Helper to create a service with a patched provider.
|
||||
"""
|
||||
provider_ctx = fixtures.patched_provider()
|
||||
provider = provider_ctx.__enter__()
|
||||
service = fixtures.create_service(provider=provider)
|
||||
return service, provider, provider_ctx
|
||||
|
||||
# --- Configuration and initial data ---
|
||||
def test_service_data(self) -> None:
|
||||
"""
|
||||
Check initial service data values.
|
||||
"""
|
||||
service = fixtures.create_service()
|
||||
self.assertEqual(service.template.value, fixtures.SERVICE_VALUES_DICT['template'])
|
||||
self.assertEqual(service.basename.value, fixtures.SERVICE_VALUES_DICT['basename'])
|
||||
self.assertEqual(service.lenname.value, fixtures.SERVICE_VALUES_DICT['lenname'])
|
||||
self.assertEqual(service.publication_timeout.value, fixtures.SERVICE_VALUES_DICT['publication_timeout'])
|
||||
|
||||
def test_initialize_sets_basename(self) -> None:
|
||||
"""
|
||||
Check that initialize sets basename and lenname correctly.
|
||||
"""
|
||||
service = fixtures.create_service()
|
||||
service.basename.value = 'testname'
|
||||
service.lenname.value = 6
|
||||
service.initialize({'basename': 'testname', 'lenname': 6})
|
||||
self.assertEqual(service.basename.value, 'testname')
|
||||
self.assertEqual(service.lenname.value, 6)
|
||||
|
||||
def test_init_gui_sets_choices(self) -> None:
|
||||
"""
|
||||
Check that init_gui sets template choices.
|
||||
"""
|
||||
service, _, provider_ctx = self._create_service_with_provider()
|
||||
with mock.patch.object(service.template, 'set_choices') as set_choices_mock:
|
||||
service.init_gui()
|
||||
set_choices_mock.assert_called()
|
||||
provider_ctx.__exit__(None, None, None)
|
||||
|
||||
# --- Utility and accessor methods ---
|
||||
def test_provider_returns_correct_type(self) -> None:
|
||||
"""
|
||||
Check that provider() returns the correct provider instance.
|
||||
"""
|
||||
service, provider, provider_ctx = self._create_service_with_provider()
|
||||
self.assertEqual(service.provider(), provider)
|
||||
provider_ctx.__exit__(None, None, None)
|
||||
|
||||
def test_api_property_caching(self) -> None:
|
||||
"""
|
||||
Check that the api property is cached.
|
||||
"""
|
||||
service, _, provider_ctx = self._create_service_with_provider()
|
||||
api1 = service.api
|
||||
api2 = service.api
|
||||
self.assertIs(api1, api2)
|
||||
provider_ctx.__exit__(None, None, None)
|
||||
|
||||
def test_service_methods(self) -> None:
|
||||
"""
|
||||
Check utility methods of the service.
|
||||
"""
|
||||
service, _, provider_ctx = self._create_service_with_provider()
|
||||
self.assertEqual(service.get_basename(), service.basename.value)
|
||||
self.assertEqual(service.get_lenname(), service.lenname.value)
|
||||
self.assertEqual(service.sanitized_name('Test VM 1'), 'test-vm-1')
|
||||
duplicates = list(service.find_duplicates('vm-1', '00:11:22:33:44:55'))
|
||||
self.assertEqual(len(duplicates), 1)
|
||||
provider_ctx.__exit__(None, None, None)
|
||||
|
||||
# --- Availability and cache ---
|
||||
def test_service_is_available(self) -> None:
|
||||
"""
|
||||
Check service availability and cache handling.
|
||||
"""
|
||||
service, provider, provider_ctx = self._create_service_with_provider()
|
||||
api = typing.cast(mock.MagicMock, provider.api)
|
||||
self.assertTrue(service.is_available())
|
||||
api.test.assert_called_with()
|
||||
api.test.return_value = False
|
||||
self.assertTrue(service.is_available())
|
||||
service.provider().is_available.cache_clear() # type: ignore
|
||||
self.assertFalse(service.is_available())
|
||||
api.test.assert_called_with()
|
||||
provider_ctx.__exit__(None, None, None)
|
||||
|
||||
# --- VM operations ---
|
||||
def test_vm_operations(self) -> None:
|
||||
"""
|
||||
Check VM operations: get_ip, get_mac, is_running, start, stop, shutdown.
|
||||
"""
|
||||
service, _, provider_ctx = self._create_service_with_provider()
|
||||
api = typing.cast(mock.MagicMock, service.api)
|
||||
ip = service.get_ip(None, 'vm-1')
|
||||
self.assertEqual(ip, '192.168.1.1')
|
||||
mac = service.get_mac(None, 'vm-1')
|
||||
self.assertEqual(mac, '00:11:22:33:44:01')
|
||||
self.assertTrue(service.is_running(None, 'vm-1'))
|
||||
service.start(None, 'vm-1')
|
||||
api.start_vm_instance.assert_called_with('vm-1')
|
||||
service.stop(None, 'vm-1')
|
||||
api.stop_vm_instance.assert_called_with('vm-1')
|
||||
service.shutdown(None, 'vm-1')
|
||||
api.stop_vm_instance.assert_called_with('vm-1')
|
||||
provider_ctx.__exit__(None, None, None)
|
||||
|
||||
# --- Exception handling ---
|
||||
def test_get_ip_raises_exception_if_no_interfaces(self) -> None:
|
||||
"""
|
||||
Check that get_ip raises an exception if there are no interfaces.
|
||||
"""
|
||||
service, _, provider_ctx = self._create_service_with_provider()
|
||||
def no_interfaces(_vmid: str):
|
||||
mock_vm = mock.Mock()
|
||||
mock_vm.interfaces = []
|
||||
return mock_vm
|
||||
with mock.patch.object(service.api, 'get_vm_instance_info', side_effect=no_interfaces):
|
||||
with self.assertRaises(Exception):
|
||||
service.get_ip(None, 'vm-1')
|
||||
provider_ctx.__exit__(None, None, None)
|
||||
|
||||
def test_get_mac_raises_exception_if_no_interfaces(self) -> None:
|
||||
"""
|
||||
Check that get_mac raises an exception if there are no interfaces.
|
||||
"""
|
||||
service, _, provider_ctx = self._create_service_with_provider()
|
||||
def no_interfaces(_vmid: str):
|
||||
mock_vm = mock.Mock()
|
||||
mock_vm.interfaces = []
|
||||
return mock_vm
|
||||
with mock.patch.object(service.api, 'get_vm_instance_info', side_effect=no_interfaces):
|
||||
with self.assertRaises(Exception):
|
||||
service.get_mac(None, 'vm-1')
|
||||
provider_ctx.__exit__(None, None, None)
|
||||
|
||||
# --- VM deletion ---
|
||||
def test_vm_deletion(self) -> None:
|
||||
"""
|
||||
Check VM deletion logic and is_deleted method.
|
||||
"""
|
||||
service, provider, provider_ctx = self._create_service_with_provider()
|
||||
api = typing.cast(mock.MagicMock, provider.api)
|
||||
|
||||
# Execute deletion
|
||||
service.execute_delete('vm-1')
|
||||
api.delete_vm_instance.assert_called_with('vm-1')
|
||||
|
||||
# Check if deleted
|
||||
api.get_vm_info.side_effect = morph_exceptions.OpenshiftNotFoundError('not found')
|
||||
self.assertTrue(service.is_deleted('vm-1'))
|
||||
|
||||
# Simulate VM exists
|
||||
api.get_vm_info.side_effect = None
|
||||
api.get_vm_info.return_value = fixtures.VMS[0]
|
||||
self.assertFalse(service.is_deleted('vm-1'))
|
||||
provider_ctx.__exit__(None, None, None)
|
||||
122
server/tests/services/openshift/test_service_fixed.py
Normal file
122
server/tests/services/openshift/test_service_fixed.py
Normal file
@@ -0,0 +1,122 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
#
|
||||
# Copyright (c) 2024 Virtual Cable S.L.U.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification,
|
||||
# are permitted provided that the following conditions are met:
|
||||
#
|
||||
# * Redistributions of source code must retain the above copyright notice,
|
||||
# this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
# * Neither the name of Virtual Cable S.L.U. nor the names of its contributors
|
||||
# may be used to endorse or promote products derived from this software
|
||||
# without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
"""
|
||||
Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
"""
|
||||
import typing
|
||||
from unittest import mock
|
||||
|
||||
from tests.services.openshift import fixtures
|
||||
|
||||
from tests.utils.test import UDSTransactionTestCase
|
||||
|
||||
|
||||
|
||||
class TestOpenshiftServiceFixed(UDSTransactionTestCase):
|
||||
def _create_service_fixed_with_provider(self):
|
||||
"""
|
||||
Helper to create a fixed service with a patched provider.
|
||||
"""
|
||||
provider_ctx = fixtures.patched_provider()
|
||||
provider = provider_ctx.__enter__()
|
||||
service = fixtures.create_service_fixed(provider=provider)
|
||||
return service, provider, provider_ctx
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
fixtures.clear()
|
||||
|
||||
# --- Availability ---
|
||||
def test_service_is_available(self) -> None:
|
||||
"""
|
||||
Test provider availability and cache logic.
|
||||
"""
|
||||
service, provider, provider_ctx = self._create_service_fixed_with_provider()
|
||||
api = typing.cast(mock.MagicMock, provider.api)
|
||||
self.assertTrue(service.is_available())
|
||||
api.test.assert_called_with()
|
||||
# With cached data, even if test fails, it will return True
|
||||
api.test.return_value = False
|
||||
self.assertTrue(service.is_available())
|
||||
# Clear cache and test again
|
||||
service.provider().is_available.cache_clear() # type: ignore
|
||||
self.assertFalse(service.is_available())
|
||||
api.test.assert_called_with()
|
||||
provider_ctx.__exit__(None, None, None)
|
||||
|
||||
# --- Service methods ---
|
||||
def test_service_methods(self) -> None:
|
||||
"""
|
||||
Test service methods: enumerate_assignables, get_name, get_ip, get_mac, sanitized_name.
|
||||
"""
|
||||
service, _, provider_ctx = self._create_service_fixed_with_provider()
|
||||
# Enumerate assignables
|
||||
machines = list(service.enumerate_assignables())
|
||||
self.assertEqual(len(machines), 3)
|
||||
self.assertEqual(machines[0].id, 'vm-3')
|
||||
self.assertEqual(machines[1].id, 'vm-4')
|
||||
self.assertEqual(machines[2].id, 'vm-5')
|
||||
# Get machine name
|
||||
machine_name = service.get_name('uid-3')
|
||||
self.assertEqual(machine_name, 'vm-3')
|
||||
# Get IP
|
||||
ip = service.get_ip('uid-3')
|
||||
self.assertTrue(ip.startswith('192.168.1.'))
|
||||
# Get MAC
|
||||
mac = service.get_mac('uid-3')
|
||||
self.assertTrue(mac.startswith('00:11:22:33:44:'))
|
||||
# Sanitized name
|
||||
sanitized = service.sanitized_name('Test VM 1')
|
||||
self.assertIsInstance(sanitized, str)
|
||||
provider_ctx.__exit__(None, None, None)
|
||||
|
||||
# --- Assignment logic ---
|
||||
def test_get_and_assign(self) -> None:
|
||||
"""
|
||||
Test get_and_assign logic for fixed service.
|
||||
"""
|
||||
service, _, provider_ctx = self._create_service_fixed_with_provider()
|
||||
vmid = service.get_and_assign()
|
||||
self.assertIn(vmid, ['vm-3', 'vm-4', 'vm-5'])
|
||||
# Should not assign the same again
|
||||
with service._assigned_access() as assigned:
|
||||
self.assertIn(vmid, assigned)
|
||||
provider_ctx.__exit__(None, None, None)
|
||||
|
||||
def test_remove_and_free(self) -> None:
|
||||
"""
|
||||
Test remove_and_free logic for fixed service.
|
||||
"""
|
||||
service, _, provider_ctx = self._create_service_fixed_with_provider()
|
||||
vmid = service.get_and_assign()
|
||||
result = service.remove_and_free(vmid)
|
||||
self.assertEqual(result.name, 'FINISHED')
|
||||
with service._assigned_access() as assigned:
|
||||
self.assertNotIn(vmid, assigned)
|
||||
provider_ctx.__exit__(None, None, None)
|
||||
Reference in New Issue
Block a user