mirror of https://github.com/dkmstr/openuds.git synced 2025-11-13 20:24:27 +03:00

34 Commits

Author SHA1 Message Date
Adolfo Gómez García
67a58d57cb Remove redundant exception handling in ModelHandler 2025-11-11 19:09:20 +01:00
Adolfo Gómez García
39a046bb23 Refactor type hinting and clean up whitespace in sorting methods 2025-11-11 18:44:08 +01:00
Adolfo Gómez García
ae16e78a4a Refactor error handling and improve sorting methods in REST handlers 2025-11-11 18:40:07 +01:00
Adolfo Gómez García
5c3d7281fa Upgrading admin to new gen 2025-11-10 23:19:59 +01:00
Adolfo Gómez García
72f0c85f75 Update script and preload links in admin index.html with new version stamps and integrity hashes 2025-11-10 05:12:15 +01:00
Adolfo Gómez García
fc39a96850 Upgrading admin to new gen 2025-11-10 01:03:10 +01:00
Adolfo Gómez García
820ba7790d Refactor REST methods for improved clarity and efficiency
- Simplified `get_items` and `get_item` methods across various handlers to reduce complexity and improve readability.
- Introduced static methods for item conversion to streamline item creation.
- Enhanced error handling and logging for better debugging.
- Added `get_position_in_queryset` utility to determine the position of an item in a queryset.
- Updated `odata_filter` method to provide a default filtering mechanism in the base handler.
- Removed redundant exception handling and streamlined queryset filtering.
- Improved type hinting and annotations for better code clarity and type safety.
2025-11-09 18:38:47 +01:00
Adolfo Gómez García
a1119a6cc7 Add fallback options to translations and update script references in admin index
- Added "Fallback: Allow" and "Fallback: Deny" translations to translations-fakejs.js.
- Updated script and module preload links in index.html to new stamp values for cache busting.
2025-11-08 19:31:13 +01:00
Adolfo Gómez García
027be9b680 fix: update script references and improve tunnel.js functionality
- Updated the script references in the admin index.html to use new integrity hashes and timestamps.
- Modified tunnel.js to use top-level await for starting the tunnel, enhancing readability and performance.
- Fixed regex in tunnel.js to correctly replace the address placeholder in the RDP file content.
- Corrected a typo in the method call for creating the temporary RDP file.
- Improving odata functionality, and with that working, migrating to faster admin interface
2025-11-08 07:13:49 +01:00
Adolfo Gómez García
1e93bb702e Merge branch 'master' of github.com:/VirtualCable/openuds 2025-11-07 17:20:34 +01:00
Adolfo Gómez García
57cfb0d98e Enhance RDP transport scripts: add client version check and improve error handling 2025-11-07 17:20:23 +01:00
Adolfo Gómez
c88510133a Merge pull request #148 from VirtualCable/dev/andres/master-req
Dev/andres/master req
2025-11-07 14:52:22 +01:00
Adolfo Gómez
f5d4640cb1 Merge pull request #150 from VirtualCable/dkmstr-patch-1
Fix max message length to 4000 characters
2025-11-05 23:38:50 +01:00
Adolfo Gómez
dfd5cd4206 Fix max message length to 4000 characters
Adjust max message length to 4000 for SQL Server compatibility.
2025-11-05 23:38:38 +01:00
aschumann-virtualcable
32b5b29ae5 refactor: Improve VM name sanitization in OpenshiftProvider and add new test cases 2025-11-05 13:55:05 +01:00
aschumann-virtualcable
eaff4aeb80 refactor: Remove redundant code in OpenshiftClient.connect method 2025-11-05 13:16:50 +01:00
aschumann-virtualcable
b63e82dbdb Merge remote-tracking branch 'origin/master' into dev/andres/master-req 2025-11-05 13:07:09 +01:00
aschumann-virtualcable
144de7122b refactor: Update token handling in OpenshiftClient.connect method 2025-11-05 13:06:39 +01:00
aschumann-virtualcable
53bd9ed75a Merge remote-tracking branch 'origin/dev/janier/master' into dev/andres/master 2025-11-05 13:00:55 +01:00
Adolfo Gómez García
31325aa194 Updated actor 2025-11-05 12:28:09 +01:00
Adolfo Gómez García
52d34ed303 Merge branch 'master' of github.com:/VirtualCable/openuds 2025-11-05 12:27:41 +01:00
Adolfo Gómez García
790ac8063e Added doc folder to include some documentation related to server 2025-11-05 12:27:32 +01:00
Janier Rodríguez
bce487168b Refactor OpenShift service tests for improved clarity and organization
- Updated serialization tests for OpenshiftProvider to ensure correct method behavior after serialization.
- Enhanced publication serialization tests, adding checks for autoserializable fields and marshaling.
- Reorganized service tests to group by functionality, including configuration, utility methods, availability, VM operations, and exception handling.
- Added detailed tests for VM creation, deletion, and cancellation operations in the deployment context.
- Introduced fixed user service tests to validate lifecycle and operation behaviors.
- Removed outdated user service fixed tests and consolidated relevant functionality into new structured tests.
- Added serialization tests for fixed user service to ensure data integrity during serialization and deserialization.
2025-11-05 11:46:32 +01:00
Adolfo Gómez
295c820c7c Merge pull request #146 from VirtualCable/dkmstr-patch-1
Add DeepWiki badge to README
2025-11-05 04:03:41 +01:00
Adolfo Gómez
a4f6214ed9 Add DeepWiki badge to README
Added a badge for DeepWiki to the README.
2025-11-05 04:03:19 +01:00
Adolfo Gómez García
4ec687567d Refactor RDP client executable search order to prioritize 'udsrdp' 2025-11-05 03:57:02 +01:00
Adolfo Gómez García
7d16ae03e5 Add RDP client support with tunnel creation and error handling 2025-11-05 03:55:23 +01:00
Adolfo Gómez García
046130c77b Clarify lxmlsec installation comment and reorder dependency installation in workflow 2025-11-04 03:27:01 +01:00
Adolfo Gómez García
26c9dd0dec Update Python version in workflow and clarify lxml installation in requirements; fix user_service key in login/logout tests 2025-11-04 03:11:12 +01:00
Adolfo Gómez
27de5e065f Merge pull request #144 from VirtualCable/alert-autofix-390
Potential fix for code scanning alert no. 390: Workflow does not contain permissions
2025-11-04 00:41:41 +01:00
Janier Rodríguez
519436176a Merge branch 'dev/janier/master' of github.com:VirtualCable/openuds into dev/janier/master 2025-10-31 17:48:49 +01:00
Janier Rodríguez
3bbbc9d5dd refactor: Remove cloning functionality from OpenshiftClient and related tests 2025-10-31 17:45:52 +01:00
Janier Rodríguez
e6549c17d1 Add comprehensive tests for OpenShift service and user service functionality
- Implemented unit tests for OpenShift client, provider, publication, and service functionalities.
- Added serialization tests for user services and providers to ensure data integrity during serialization and deserialization.
- Created tests for VM lifecycle operations, including creation, deletion, and state checks.
- Enhanced test coverage for service availability and error handling scenarios.
- Introduced fixed user service tests to validate assignment and operational methods.
- Ensured all tests are structured to handle various edge cases and provide meaningful assertions.
2025-10-31 17:36:41 +01:00
aschumann-virtualcable
1970bb89dd Merge remote-tracking branch 'origin/master' into dev/andres/master 2025-10-28 15:34:04 +01:00
57 changed files with 25429 additions and 585 deletions

View File

@@ -23,7 +23,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.13'
python-version: '3.x'
- name: Install system dependencies
run: |
@@ -41,6 +41,9 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
# Install lxmlsec with local libraries to avoid binary wheel issues
pip install --upgrade --no-binary lxml --no-binary xmlsec lxml xmlsec
# Install other requirements
pip install -r requirements.txt
- name: Set PYTHONPATH

View File

@@ -17,3 +17,4 @@ Notes
* From `v4.0` onwards (current master), OpenUDS has been split into several repositories and contains submodules. Remember to use "git clone --recursive ..." to fetch it ;-).
* The `v4.0` version needs Python 3.11 (it may work fine on newer versions). It uses features only available on 3.10 or later and is tested against 3.11; it will probably work on 3.10 too.
[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/VirtualCable/openuds)

actor

Submodule actor updated: 79a7e8bbc2...10b407ced9

server/doc/api/rest.yaml Normal file

File diff suppressed because it is too large

View File

@@ -15,7 +15,7 @@ cryptography
python3-saml
six
dnspython
lxml
# lxml must be installed from source to avoid conflicts
ovirt-engine-sdk-python
pycurl
matplotlib

View File

@@ -38,6 +38,7 @@ import traceback
from django import http
from django.utils.decorators import method_decorator
from django.core.exceptions import ObjectDoesNotExist
from django.views.decorators.csrf import csrf_exempt
from django.views.generic.base import View
@@ -203,6 +204,14 @@ class Dispatcher(View):
except exceptions.rest.HandlerError as e:
log.log_operation(handler, 500, types.log.LogLevel.ERROR)
return http.HttpResponseBadRequest(f'{{"error": "{e}"}}'.encode(), content_type="application/json")
except exceptions.services.generics.Error as e:
log.log_operation(handler, 503, types.log.LogLevel.ERROR)
return http.HttpResponseServerError(
f'{{"error": "{e}"}}'.encode(), content_type="application/json", status=503
)
except ObjectDoesNotExist as e: # All DoesNotExist exceptions are not found
log.log_operation(handler, 404, types.log.LogLevel.ERROR)
return http.HttpResponseNotFound(f'{{"error": "{e}"}}'.encode(), content_type="application/json")
except Exception as e:
log.log_operation(handler, 500, types.log.LogLevel.ERROR)
# Get exception backtrace

View File

@@ -58,6 +58,7 @@ logger = logging.getLogger(__name__)
T = typing.TypeVar('T')
class Handler(abc.ABC):
"""
REST requests handler base class
@@ -390,11 +391,48 @@ class Handler(abc.ABC):
if name in self._params:
return self._params[name]
return ''
def filter_queryset(self, qs: QuerySet[typing.Any]) -> list[typing.Any]:
def get_sort_field_info(self, *args: str) -> tuple[str, bool] | None:
"""
Returns sorting information for the first sort field if it is contained in the odata orderby list.
Args:
args: The possible field names to check for sorting information.
Returns:
A tuple containing the clean field name found and a boolean indicating whether the sort is descending,
or None if no matching field is found.
Note:
We only use the first orderby entry for table sort translations, so this only returns info for the first field.
"""
if self.odata.orderby:
order_field = self.odata.orderby[0]
clean_field = order_field.lstrip('-')
for field_name in args:
if clean_field == field_name:
is_descending = order_field.startswith('-')
return (clean_field, is_descending)
return None
def apply_sort(self, qs: QuerySet[typing.Any]) -> list[typing.Any] | QuerySet[typing.Any]:
"""
Custom sorting function to apply to querysets.
Override this method in subclasses to provide custom sorting logic.
Args:
qs: The queryset to sort.
Returns:
The sorted queryset (or a sorted list).
"""
return qs.order_by(*self.odata.orderby)
@typing.final
def filter_odata_queryset(self, qs: QuerySet[typing.Any]) -> list[typing.Any]:
"""
Filters the queryset based on odata
Note: We return a list, because after applying slicing, querysets may be evaluated
by using _result_cache, so we force evaluation here to avoid issues later.
"""
@@ -405,28 +443,30 @@ class Handler(abc.ABC):
except ValueError as e:
raise exceptions.rest.RequestError(f'Invalid odata filter: {e}') from e
# Store total count before slicing
self.add_header('X-Total-Count', str(qs.count()))
# order_by fields must be unique; each field is submitted only once
# Since slicing may return a list, we may get a list result from sorting
if self.odata.orderby:
qs = qs.order_by(*self.odata.orderby)
result = self.apply_sort(qs)
else:
result = qs
# If odata start/limit are set, apply them
if self.odata.start is not None:
qs = qs[self.odata.start :]
result = result[self.odata.start :]
# Note that limit is AFTER start because of previous line
if self.odata.limit is not None:
qs = qs[: self.odata.limit]
result = list(qs)
result = result[: self.odata.limit]
# Get total items and set it on X-Total-Count
try:
total_items = len(result)
self.add_header('X-Total-Count', total_items)
except Exception as e:
raise exceptions.rest.RequestError(f'Invalid odata: {e}')
# After slicing, the qs may be a list, so we ensure it's a list
# to avoid issues later
result = list(result)
return result
def filter_data(self, data: collections.abc.Iterable[T]) -> list[T]:
def filter_odata_data(self, data: collections.abc.Iterable[T]) -> list[T]:
"""
Filters the data based on the current odata
"""
@@ -446,8 +486,7 @@ class Handler(abc.ABC):
raise exceptions.rest.RequestError(f'Invalid odata: {e}')
return data
@classmethod
def api_components(cls: type[typing.Self]) -> types.rest.api.Components:
"""
@@ -456,7 +495,9 @@ class Handler(abc.ABC):
return types.rest.api.Components()
@classmethod
def api_paths(cls: type[typing.Self], path: str, tags: list[str], security: str) -> dict[str, types.rest.api.PathItem]:
def api_paths(
cls: type[typing.Self], path: str, tags: list[str], security: str
) -> dict[str, types.rest.api.PathItem]:
"""
Returns the API operations that should be registered
"""

View File

@@ -70,7 +70,7 @@ class AccountsUsage(DetailHandler[AccountItem]): # pylint: disable=too-many-pub
"""
@staticmethod
def usage_to_dict(item: 'AccountUsage', perm: int) -> AccountItem:
def as_dict(item: 'AccountUsage', perm: int) -> AccountItem:
"""
Convert an account usage to a dictionary
:param item: Account usage item (db)
@@ -89,19 +89,23 @@ class AccountsUsage(DetailHandler[AccountItem]): # pylint: disable=too-many-pub
elapsed_timemark=item.elapsed_timemark,
permission=perm,
)
def get_item_position(self, parent: 'Model', item_uuid: str) -> int:
parent = ensure.is_instance(parent, Account)
return self.calc_item_position(item_uuid, parent.usages.all())
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.ItemsResult[AccountItem]:
def get_items(self, parent: 'Model') -> types.rest.ItemsResult[AccountItem]:
parent = ensure.is_instance(parent, Account)
# Check what kind of access we have to the parent provider
perm = permissions.effective_permissions(self._user, parent)
try:
if not item:
return [AccountsUsage.usage_to_dict(k, perm) for k in self.filter_queryset(parent.usages.all())]
k = parent.usages.get(uuid=process_uuid(item))
return AccountsUsage.usage_to_dict(k, perm)
except Exception:
logger.exception('itemId %s', item)
raise exceptions.rest.NotFound(_('Account usage not found: {}').format(item)) from None
return [AccountsUsage.as_dict(k, perm) for k in self.odata_filter(parent.usages.all())]
def get_item(self, parent: 'Model', item: str) -> AccountItem:
parent = ensure.is_instance(parent, Account)
# Check what kind of access we have to the parent provider
return AccountsUsage.as_dict(
parent.usages.get(uuid=process_uuid(item)), permissions.effective_permissions(self._user, parent)
)
def get_table(self, parent: 'Model') -> TableInfo:
parent = ensure.is_instance(parent, Account)

View File

@@ -54,6 +54,9 @@ from .users_groups import Groups, Users
from uds.core.module import Module
if typing.TYPE_CHECKING:
from django.db.models.query import QuerySet
logger = logging.getLogger(__name__)
@@ -110,9 +113,9 @@ class Authenticators(ModelHandler[AuthenticatorItem]):
.icon(name='name', title=_('Name'), visible=True)
.text_column(name='type_name', title=_('Type'))
.text_column(name='comments', title=_('Comments'))
.numeric_column(name='priority', title=_('Priority'), width='5rem')
.numeric_column(name='priority', title=_('Priority'), width='8rem')
.text_column(name='small_name', title=_('Label'))
.numeric_column(name='users_count', title=_('Users'), width='1rem')
.numeric_column(name='users_count', title=_('Users'), width='6rem')
.text_column(name='mfa_name', title=_('MFA'))
.text_column(name='tags', title=_('tags'), visible=False)
.row_style(prefix='row-state-', field='state')
@@ -218,6 +221,29 @@ class Authenticators(ModelHandler[AuthenticatorItem]):
type_info=type(self).as_typeinfo(item.get_type()),
)
def apply_sort(self, qs: 'QuerySet[typing.Any]') -> 'list[typing.Any] | QuerySet[typing.Any]':
if field_info := self.get_sort_field_info('users_count'):
field_name, is_descending = field_info
order_by_field = f"-{field_name}" if is_descending else field_name
return qs.annotate(users_count=models.Count('users')).order_by(order_by_field)
if field_info := self.get_sort_field_info('type_name'):
_, is_descending = field_info
order_by_field = f'-data_type' if is_descending else 'data_type'
return qs.order_by(order_by_field)
if field_info := self.get_sort_field_info('numeric_id'):
_, is_descending = field_info
order_by_field = f'-pk' if is_descending else 'pk'
return qs.order_by(order_by_field)
if field_info := self.get_sort_field_info('mfa_name'):
_, is_descending = field_info
order_by_field = f'-mfa__name' if is_descending else 'mfa__name'
return qs.order_by(order_by_field)
return super().apply_sort(qs)
def post_save(self, item: 'models.Model') -> None:
item = ensure.is_instance(item, Authenticator)
try:
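The leading-dash convention that `get_sort_field_info` and the `apply_sort` override above rely on (a `-` prefix marks descending order, as in Django's `order_by`) can be illustrated standalone; `parse_order_token` is a hypothetical helper, not part of the handlers shown:

```python
def parse_order_token(token: str) -> tuple[str, bool]:
    """Split an order-by token into (field, is_descending), Django-style.

    '-users_count' -> ('users_count', True); 'name' -> ('name', False).
    """
    clean = token.lstrip('-')  # drop the descending marker, if any
    return clean, token.startswith('-')
```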

View File

@@ -82,30 +82,34 @@ class CalendarRules(DetailHandler[CalendarRuleItem]): # pylint: disable=too-man
name=item.name,
comments=item.comments,
start=item.start,
end=timezone.make_aware(datetime.datetime.combine(item.end, datetime.time.max)) if item.end else None,
end=(
timezone.make_aware(datetime.datetime.combine(item.end, datetime.time.max))
if item.end
else None
),
frequency=item.frequency,
interval=item.interval,
duration=item.duration,
duration_unit=item.duration_unit,
permission=perm,
)
def get_item_position(self, parent: 'models.Model', item_uuid: str) -> int:
parent = ensure.is_instance(parent, Calendar)
return self.calc_item_position(item_uuid, parent.rules.all())
def get_items(
self, parent: 'models.Model', item: typing.Optional[str]
) -> types.rest.ItemsResult[CalendarRuleItem]:
def get_items(self, parent: 'models.Model') -> types.rest.ItemsResult[CalendarRuleItem]:
parent = ensure.is_instance(parent, Calendar)
# Check what kind of access we have to the parent provider
perm = permissions.effective_permissions(self._user, parent)
try:
if item is None:
return [CalendarRules.rule_as_dict(k, perm) for k in self.filter_queryset(parent.rules.all())]
k = parent.rules.get(uuid=process_uuid(item))
return CalendarRules.rule_as_dict(k, perm)
except CalendarRule.DoesNotExist:
raise exceptions.rest.NotFound(_('Calendar rule not found: {}').format(item)) from None
except Exception as e:
logger.exception('itemId %s', item)
raise exceptions.rest.RequestError(f'Error retrieving calendar rule: {e}') from e
return [CalendarRules.rule_as_dict(k, perm) for k in self.filter_odata_queryset(parent.rules.all())]
def get_item(self, parent: 'models.Model', item: str) -> CalendarRuleItem:
parent = ensure.is_instance(parent, Calendar)
# Check what kind of access we have to the parent provider
return CalendarRules.rule_as_dict(
parent.rules.get(uuid=process_uuid(item)), permissions.effective_permissions(self._user, parent)
)
def get_table(self, parent: 'models.Model') -> types.rest.TableInfo:
parent = ensure.is_instance(parent, Calendar)

View File

@@ -51,7 +51,7 @@ class Config(Handler):
ROLE = consts.UserRole.ADMIN
def get(self) -> typing.Any:
return self.filter_data(CfgConfig.get_config_values(self.is_admin()))
return CfgConfig.get_config_values(self.is_admin())
def put(self) -> typing.Any:
for section, section_dict in typing.cast(dict[str, dict[str, dict[str, str]]], self._params).items():

View File

@@ -86,7 +86,7 @@ class Connection(Handler):
# Ensure user is present on request, used by web views methods
self._request.user = self._user
return Connection.result(result=self.filter_data(services.get_services_info_dict(self._request)))
return Connection.result(result=self.filter_odata_data(services.get_services_info_dict(self._request)))
def connection(self, id_service: str, id_transport: str, skip: str = '') -> dict[str, typing.Any]:
skip_check = skip in ('doNotCheck', 'do_not_check', 'no_check', 'nocheck', 'skip_check')

View File

@@ -91,19 +91,21 @@ class MetaServicesPool(DetailHandler[MetaItem]):
user_services_count=item.pool.userServices.exclude(state__in=State.INFO_STATES).count(),
user_services_in_preparation=item.pool.userServices.filter(state=State.PREPARING).count(),
)
def get_item_position(self, parent: 'Model', item_uuid: str) -> int:
parent = ensure.is_instance(parent, models.MetaPool)
return self.calc_item_position(item_uuid, parent.members.all())
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.ItemsResult['MetaItem']:
def get_items(self, parent: 'Model') -> types.rest.ItemsResult['MetaItem']:
parent = ensure.is_instance(parent, models.MetaPool)
return [MetaServicesPool.as_dict(i) for i in self.filter_odata_queryset(parent.members.all())]
def get_item(self, parent: 'Model', item: str) -> 'MetaItem':
parent = ensure.is_instance(parent, models.MetaPool)
try:
if not item:
return [MetaServicesPool.as_dict(i) for i in self.filter_queryset(parent.members.all())]
i = parent.members.get(uuid=process_uuid(item))
return MetaServicesPool.as_dict(i)
return MetaServicesPool.as_dict(parent.members.get(uuid=process_uuid(item)))
except models.MetaPoolMember.DoesNotExist:
raise exceptions.rest.NotFound(_('Meta pool member not found: {}').format(item)) from None
except Exception as e:
logger.exception('err: %s', item)
raise exceptions.rest.RequestError(f'Error retrieving meta pool member: {e}') from e
def get_table(self, parent: 'Model') -> types.rest.TableInfo:
parent = ensure.is_instance(parent, models.MetaPool)
@@ -171,67 +173,62 @@ class MetaAssignedService(DetailHandler[UserServiceItem]):
element.pool_name = item.deployed_service.name
return element
def _get_assigned_userservice(self, metapool: models.MetaPool, userservice_id: str) -> models.UserService:
@staticmethod
def _get_assigned_userservice(metapool: models.MetaPool, userservice_id: str) -> models.UserService:
"""
Gets an assigned service and checks that it belongs to this metapool
If not found, raises NotFound
"""
try:
return models.UserService.objects.filter(
uuid=process_uuid(userservice_id),
cache_level=0,
deployed_service__in=[i.pool for i in metapool.members.all()],
)[0]
except IndexError:
found = models.UserService.objects.filter(
uuid=process_uuid(userservice_id),
cache_level=0,
deployed_service__in=[i.pool for i in metapool.members.all()],
).first()
if found is None:
raise exceptions.rest.NotFound(_('User service not found: {}').format(userservice_id)) from None
except Exception:
logger.error('Error getting assigned userservice %s for metapool %s', userservice_id, metapool.uuid)
raise exceptions.rest.RequestError(
_('Error retrieving assigned service: {}').format(userservice_id)
) from None
return found
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.ItemsResult[UserServiceItem]:
def _assigned_userservices_for_pools(
self, parent: 'models.MetaPool'
) -> typing.Generator[tuple[models.UserService, typing.Optional[dict[str, typing.Any]]], None, None]:
for m in self.odata_filter(parent.members.filter(enabled=True)):
properties: dict[str, typing.Any] = {
k: v
for k, v in models.Properties.objects.filter(
owner_type='userservice',
owner_id__in=m.pool.assigned_user_services().values_list('uuid', flat=True),
).values_list('key', 'value')
}
for u in (
m.pool.assigned_user_services()
.filter(state__in=State.VALID_STATES)
.prefetch_related('deployed_service', 'publication')
):
yield u, properties.get(u.uuid, {})
def get_items(self, parent: 'Model') -> types.rest.ItemsResult[UserServiceItem]:
parent = ensure.is_instance(parent, models.MetaPool)
def _assigned_userservices_for_pools() -> (
typing.Generator[tuple[models.UserService, typing.Optional[dict[str, typing.Any]]], None, None]
):
for m in self.filter_queryset(parent.members.filter(enabled=True)):
properties: dict[str, typing.Any] = {
k: v
for k, v in models.Properties.objects.filter(
owner_type='userservice',
owner_id__in=m.pool.assigned_user_services().values_list('uuid', flat=True),
).values_list('key', 'value')
}
for u in (
m.pool.assigned_user_services()
.filter(state__in=State.VALID_STATES)
.prefetch_related('deployed_service', 'publication')
):
yield u, properties.get(u.uuid, {})
return list(
{
k.uuid: MetaAssignedService.item_as_dict(parent, k, props)
for k, props in self._assigned_userservices_for_pools(parent)
}.values()
)
try:
if not item: # All items
result: dict[str, typing.Any] = {}
def get_item(self, parent: 'Model', item: str) -> UserServiceItem:
parent = ensure.is_instance(parent, models.MetaPool)
for k, props in _assigned_userservices_for_pools():
result[k.uuid] = MetaAssignedService.item_as_dict(parent, k, props)
return list(result.values())
return MetaAssignedService.item_as_dict(
parent,
self._get_assigned_userservice(parent, item),
props={
k: v
for k, v in models.Properties.objects.filter(
owner_type='userservice', owner_id=process_uuid(item)
).values_list('key', 'value')
},
)
except Exception as e:
logger.exception('get_items')
raise exceptions.rest.RequestError(f'Error retrieving meta pool member: {e}') from e
return MetaAssignedService.item_as_dict(
parent,
self._get_assigned_userservice(parent, item),
props={
k: v
for k, v in models.Properties.objects.filter(
owner_type='userservice', owner_id=process_uuid(item)
).values_list('key', 'value')
},
)
def get_table(self, parent: 'Model') -> TableInfo:
parent = ensure.is_instance(parent, models.MetaPool)
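The switch in `_get_assigned_userservice` above from indexing with `[0]` (and catching `IndexError`) to `QuerySet.first()` follows a common Django idiom: `first()` returns `None` on an empty result instead of raising. A plain-Python analogue (the function name is illustrative):

```python
def first_or_none(iterable):
    """Return the first element of an iterable, or None if it is empty
    (analogous to Django's QuerySet.first())."""
    return next(iter(iterable), None)
```

The caller then checks for `None` explicitly, which keeps the not-found path in one place instead of spreading it across exception handlers.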

View File

@@ -73,22 +73,27 @@ class AccessCalendars(DetailHandler[AccessCalendarItem]):
access=item.access,
priority=item.priority,
)
def get_item_position(self, parent: 'Model', item_uuid: str) -> int:
# parent can be a ServicePool or a metaPool
if isinstance(parent, models.ServicePool):
parent = ensure.is_instance(parent, models.ServicePool)
return self.calc_item_position(item_uuid, parent.calendarAccess.all())
parent = ensure.is_instance(parent, models.MetaPool)
return self.calc_item_position(item_uuid, parent.calendarAccess.all())
def get_items(
self, parent: 'Model', item: typing.Optional[str]
) -> types.rest.ItemsResult[AccessCalendarItem]:
def get_items(self, parent: 'Model') -> types.rest.ItemsResult[AccessCalendarItem]:
# parent can be a ServicePool or a metaPool
parent = typing.cast(typing.Union['models.ServicePool', 'models.MetaPool'], parent)
try:
if not item:
return [AccessCalendars.as_item(i) for i in self.filter_queryset(parent.calendarAccess.all())]
return AccessCalendars.as_item(parent.calendarAccess.get(uuid=process_uuid(item)))
except models.CalendarAccess.DoesNotExist:
raise exceptions.rest.NotFound(_('Access calendar not found: {}').format(item)) from None
except Exception as e:
logger.exception('err: %s', item)
raise exceptions.rest.RequestError(f'Error retrieving access calendar: {e}') from e
return [AccessCalendars.as_item(i) for i in self.filter_odata_queryset(parent.calendarAccess.all())]
def get_item(self, parent: 'Model', item: str) -> AccessCalendarItem:
# parent can be a ServicePool or a metaPool
parent = typing.cast(typing.Union['models.ServicePool', 'models.MetaPool'], parent)
return AccessCalendars.as_item(parent.calendarAccess.get(uuid=process_uuid(item)))
def get_table(self, parent: 'Model') -> types.rest.TableInfo:
return (
@@ -190,21 +195,20 @@ class ActionsCalendars(DetailHandler[ActionCalendarItem]):
next_execution=item.next_execution,
last_execution=item.last_execution,
)
def get_items(
self, parent: 'Model', item: typing.Optional[str]
) -> types.rest.ItemsResult[ActionCalendarItem]:
def get_item_position(self, parent: 'Model', item_uuid: str) -> int:
parent = ensure.is_instance(parent, models.ServicePool)
try:
if item is None:
return [ActionsCalendars.as_dict(i) for i in self.filter_queryset(parent.calendaraction_set.all())]
i = parent.calendaraction_set.get(uuid=process_uuid(item))
return ActionsCalendars.as_dict(i)
except models.CalendarAction.DoesNotExist:
raise exceptions.rest.NotFound(_('Scheduled action not found: {}').format(item)) from None
except Exception as e:
logger.error('Error retrieving scheduled action %s: %s', item, e)
raise exceptions.rest.RequestError(f'Error retrieving scheduled action: {e}') from e
return self.calc_item_position(item_uuid, parent.calendaraction_set.all())
def get_items(self, parent: 'Model') -> types.rest.ItemsResult[ActionCalendarItem]:
parent = ensure.is_instance(parent, models.ServicePool)
return [
ActionsCalendars.as_dict(i) for i in self.filter_odata_queryset(parent.calendaraction_set.all())
]
def get_item(self, parent: 'Model', item: str) -> ActionCalendarItem:
parent = ensure.is_instance(parent, models.ServicePool)
return ActionsCalendars.as_dict(parent.calendaraction_set.get(uuid=process_uuid(item)))
def get_table(self, parent: 'Model') -> TableInfo:
return (

View File

@@ -36,7 +36,7 @@ import logging
import typing
from django.utils.translation import gettext, gettext_lazy as _
from django.db.models import Model
from django.db.models import Model, Count
import uds.core.types.permissions
from uds.core import exceptions, services, types
@@ -51,6 +51,9 @@ from .services_usage import ServicesUsage
logger = logging.getLogger(__name__)
if typing.TYPE_CHECKING:
from django.db.models.query import QuerySet
# Helper class for Provider offers
@dataclasses.dataclass
@@ -96,6 +99,8 @@ class Providers(ModelHandler[ProviderItem]):
.numeric_column(name='user_services_count', title=_('User Services'))
.text_column(name='tags', title=_('Tags'), visible=False)
.row_style(prefix='row-maintenance-', field='maintenance_mode')
.with_field_mappings(type_name='data_type')
.with_filter_fields('name', 'data_type', 'comments', 'maintenance_mode')
).build()
# Rest api related information to complete the auto-generated API
@@ -103,6 +108,14 @@ class Providers(ModelHandler[ProviderItem]):
typed=types.rest.api.RestApiInfoGuiType.MULTIPLE_TYPES,
)
def apply_sort(self, qs: 'QuerySet[typing.Any]') -> 'list[typing.Any] | QuerySet[typing.Any]':
if field_info := self.get_sort_field_info('services_count'):
field_name, is_descending = field_info
order_by_field = f"-{field_name}" if is_descending else field_name
return qs.annotate(services_count=Count('services')).order_by(order_by_field)
return super().apply_sort(qs)
def get_item(self, item: 'Model') -> ProviderItem:
item = ensure.is_instance(item, Provider)
type_ = item.get_type()

View File

@@ -120,7 +120,7 @@ class Reports(model.BaseModelHandler[ReportItem]):
return match_args(
self._args,
error,
((), lambda: list(self.filter_data(self.get_items()))),
((), lambda: list(self.filter_odata_data(self.get_items()))),
((consts.rest.OVERVIEW,), lambda: list(self.get_items())),
(
(consts.rest.TABLEINFO,),

View File

@@ -148,38 +148,25 @@ class ServersServers(DetailHandler[ServerItem]):
typed=types.rest.api.RestApiInfoGuiType.SINGLE_TYPE,
)
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.ItemsResult[ServerItem]:
def as_server_item(self, item: 'models.Server') -> ServerItem:
return ServerItem(
id=item.uuid,
hostname=item.hostname,
ip=item.ip,
listen_port=item.listen_port,
mac=item.mac if item.mac != consts.NULL_MAC else '',
maintenance_mode=item.maintenance_mode,
register_username=item.register_username,
stamp=item.stamp,
)
def get_items(self, parent: 'Model') -> types.rest.ItemsResult[ServerItem]:
parent = typing.cast('models.ServerGroup', parent)  # We will always receive a ServerGroup here
try:
if item is None:
q = self.filter_queryset(parent.servers.all())
else:
q = parent.servers.filter(uuid=process_uuid(item))
res: list[ServerItem] = []
i = None
for i in q:
res.append(
ServerItem(
id=i.uuid,
hostname=i.hostname,
ip=i.ip,
listen_port=i.listen_port,
mac=i.mac if i.mac != consts.NULL_MAC else '',
maintenance_mode=i.maintenance_mode,
register_username=i.register_username,
stamp=i.stamp,
)
)
if item is None:
return res
if not i:
raise exceptions.rest.NotFound(f'Server not found: {item}')
return res[0]
except exceptions.rest.HandlerError:
raise
except Exception:
logger.exception('Error getting server')
raise exceptions.rest.ResponseError(_('Error getting server')) from None
return [self.as_server_item(i) for i in self.filter_odata_queryset(parent.servers.all())]
def get_item(self, parent: 'Model', item: str) -> ServerItem:
parent = typing.cast('models.ServerGroup', parent) # We will receive for sure
return self.as_server_item(parent.servers.get(uuid=process_uuid(item)))
def get_table(self, parent: 'Model') -> TableInfo:
parent = ensure.is_instance(parent, models.ServerGroup)
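The refactor above centralizes model-to-item conversion in a single helper (`as_server_item`), so the list endpoint and the single-item endpoint cannot drift apart. A minimal sketch of the pattern with hypothetical dataclasses (`Server`, `ServerItem`, and the `NULL_MAC` sentinel here are stand-ins, not the repository's real types):

```python
from dataclasses import dataclass

NULL_MAC = "00:00:00:00:00:00"  # assumed sentinel, mirroring consts.NULL_MAC


@dataclass
class Server:
    uuid: str
    hostname: str
    mac: str


@dataclass
class ServerItem:
    id: str
    hostname: str
    mac: str


def as_server_item(server: Server) -> ServerItem:
    # Single conversion point shared by the list and single-item endpoints
    return ServerItem(
        id=server.uuid,
        hostname=server.hostname,
        mac=server.mac if server.mac != NULL_MAC else "",
    )


servers = [Server("u1", "srv1", NULL_MAC), Server("u2", "srv2", "aa:bb:cc:dd:ee:ff")]
items = [as_server_item(s) for s in servers]
assert items[0].mac == "" and items[1].mac == "aa:bb:cc:dd:ee:ff"
```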

View File

@@ -155,23 +155,25 @@ class Services(DetailHandler[ServiceItem]):  # pylint: disable=too-many-public-m
         ret_value.info = Services.service_info(item)
         return ret_value

     def get_item_position(self, parent: Model, item_uuid: str) -> int:
         parent = ensure.is_instance(parent, models.Provider)
         return self.calc_item_position(item_uuid, parent.services.all())

-    def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.ItemsResult[ServiceItem]:
+    def get_items(self, parent: 'Model') -> types.rest.ItemsResult[ServiceItem]:
         parent = ensure.is_instance(parent, models.Provider)
         # Check what kind of access do we have to parent provider
         perm = permissions.effective_permissions(self._user, parent)
-        try:
-            if item is None:
-                return [Services.service_item(k, perm) for k in self.filter_queryset(parent.services.all())]
-            k = parent.services.get(uuid=process_uuid(item))
-            val = Services.service_item(k, perm, full=True)
-            # On detail, ne wee to fill the instance fields by hand
-            return val
-        except models.Service.DoesNotExist:
-            raise exceptions.rest.NotFound(_('Service not found')) from None
-        except Exception as e:
-            logger.error('Error getting services for %s: %s', parent, e)
-            raise exceptions.rest.ResponseError(_('Error getting services')) from None
+        return [Services.service_item(k, perm) for k in self.odata_filter(parent.services.all())]
+
+    def get_item(self, parent: 'Model', item: str) -> ServiceItem:
+        parent = ensure.is_instance(parent, models.Provider)
+        # Check what kind of access do we have to parent provider
+        return Services.service_item(
+            parent.services.get(uuid=process_uuid(item)),
+            permissions.effective_permissions(self._user, parent),
+            full=True,
+        )

     def _delete_incomplete_service(self, service: models.Service) -> None:
         """
View File

@@ -56,6 +56,9 @@ from .user_services import AssignedUserService, CachedService, Changelog, Groups

 logger = logging.getLogger(__name__)

+if typing.TYPE_CHECKING:
+    from django.db.models import QuerySet
+

 @dataclasses.dataclass
 class ServicePoolItem(types.rest.BaseRestItem):
@@ -158,6 +161,7 @@ class ServicesPools(ModelHandler[ServicePoolItem]):
         .text_column(name='parent', title=_('Parent service'))
         .text_column(name='tags', title=_('tags'), visible=False)
         .row_style(prefix='row-state-', field='state')
+        .with_filter_fields('name', 'state')
         .build()
     )
@@ -175,13 +179,21 @@ class ServicesPools(ModelHandler[ServicePoolItem]):
             typed=types.rest.api.RestApiInfoGuiType.SINGLE_TYPE,
         )

+    def apply_sort(self, qs: 'QuerySet[typing.Any]') -> 'list[typing.Any] | QuerySet[typing.Any]':
+        if field_info := self.get_sort_field_info('state'):
+            field_name, is_descending = field_info
+            order_by_field = f"-{field_name}" if is_descending else field_name
+            return qs.order_by(order_by_field)
+        return super().apply_sort(qs)
+
     def get_items(
         self, *args: typing.Any, **kwargs: typing.Any
     ) -> typing.Generator[ServicePoolItem, None, None]:
         # Optimized query, due to the amount of info needed for these
         d = sql_now() - datetime.timedelta(seconds=GlobalConfig.RESTRAINT_TIME.as_int())
         return super().get_items(
-            overview=kwargs.get('overview', True),
+            sumarize=kwargs.get('overview', True),
             query=(
                 ServicePool.objects.prefetch_related(
                     'service',
View File

@@ -33,7 +33,6 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com
 import dataclasses
 import logging
 import typing
-import datetime

 from django.utils.translation import gettext as _
@@ -110,32 +109,34 @@ class ServicesUsage(DetailHandler[ServicesUsageItem]):
             source_ip=item.src_ip,
             in_use=item.in_use,
         )

+    def get_item_position(self, parent: 'Model', item_uuid: str) -> int:
+        parent = ensure.is_instance(parent, Provider)
+        return self.calc_item_position(
+            item_uuid,
+            UserService.objects.filter(deployed_service__service__provider=parent).order_by('creation_date'),
+        )
+
-    def get_items(
-        self, parent: 'Model', item: typing.Optional[str]
-    ) -> types.rest.ItemsResult[ServicesUsageItem]:
+    def get_items(self, parent: 'Model') -> types.rest.ItemsResult[ServicesUsageItem]:
         parent = ensure.is_instance(parent, Provider)
         try:
-            if item is None:
-                userservices_query = self.filter_queryset(
-                    UserService.objects.filter(deployed_service__service__provider=parent)
-                )
-            else:
-                userservices_query = UserService.objects.filter(
-                    deployed_service__service_uuid=process_uuid(item)
-                )
-            return [
-                ServicesUsage.item_as_dict(k)
-                for k in userservices_query.filter(state=State.USABLE)
-                .order_by('creation_date')
-                .prefetch_related('deployed_service', 'deployed_service__service', 'user', 'user__manager')
-            ]
+            userservices = self.odata_filter(
+                UserService.objects.filter(deployed_service__service__provider=parent)
+                .order_by('creation_date')
+                .prefetch_related('deployed_service', 'deployed_service__service', 'user', 'user__manager')
+            )
+            return [ServicesUsage.item_as_dict(k) for k in userservices]
         except Exception as e:
             logger.error('Error getting services usage for %s: %s', parent.uuid, e)
             raise exceptions.rest.ResponseError(_('Error getting services usage')) from None

+    def get_item(self, parent: 'Model', item: str) -> ServicesUsageItem:
+        parent = ensure.is_instance(parent, Provider)
+        return ServicesUsage.item_as_dict(
+            UserService.objects.filter(deployed_service__service_uuid=process_uuid(item)).get()
+        )
+
     def get_table(self, parent: 'Model') -> types.rest.TableInfo:
         parent = ensure.is_instance(parent, Provider)
         return (
View File

@@ -63,39 +63,35 @@ class TunnelServers(DetailHandler[TunnelServerItem]):
     REST_API_INFO = types.rest.api.RestApiInfo(
         name='TunnelServers', description='Tunnel servers assigned to a tunnel'
     )

+    @staticmethod
+    def as_tunnel_server_item(item: models.Server) -> TunnelServerItem:
+        return TunnelServerItem(
+            id=item.uuid,
+            hostname=item.hostname,
+            ip=item.ip,
+            mac=item.mac if item.mac != consts.NULL_MAC else '',
+            maintenance=item.maintenance_mode,
+        )
+
+    def get_item_position(self, parent: Model, item_uuid: str) -> int:
+        parent = ensure.is_instance(parent, models.ServerGroup)
+        return self.calc_item_position(item_uuid, parent.servers.all())
+
-    def get_items(
-        self, parent: 'Model', item: typing.Optional[str]
-    ) -> types.rest.ItemsResult[TunnelServerItem]:
+    def get_items(self, parent: 'Model') -> types.rest.ItemsResult[TunnelServerItem]:
         parent = ensure.is_instance(parent, models.ServerGroup)
-        try:
-            multi = False
-            if item is None:
-                multi = True
-                q = self.filter_queryset(parent.servers.all())
-            else:
-                q = parent.servers.filter(uuid=process_uuid(item))
-            res: list[TunnelServerItem] = [
-                TunnelServerItem(
-                    id=i.uuid,
-                    hostname=i.hostname,
-                    ip=i.ip,
-                    mac=i.mac if i.mac != consts.NULL_MAC else '',
-                    maintenance=i.maintenance_mode,
-                )
-                for i in q
-            ]
-            if multi:
-                return res
-            if not res:
-                raise exceptions.rest.NotFound(f'Tunnel server {item} not found')
-            return res[0]
-        except exceptions.rest.HandlerError:
-            raise
-        except Exception as e:
-            logger.error('Error getting tunnel servers for %s: %s', parent, e)
-            raise exceptions.rest.ResponseError(_('Error getting tunnel servers')) from e
+        return [
+            TunnelServers.as_tunnel_server_item(i)
+            for i in self.odata_filter(parent.servers.all())
+        ]
+
+    def get_item(self, parent: 'Model', item: str) -> TunnelServerItem:
+        parent = ensure.is_instance(parent, models.ServerGroup)
+        return TunnelServers.as_tunnel_server_item(parent.servers.get(uuid=process_uuid(item)))

     def get_table(self, parent: 'Model') -> TableInfo:
         parent = ensure.is_instance(parent, models.ServerGroup)

View File

@@ -37,7 +37,7 @@ import logging
 import typing

 from django.utils.translation import gettext as _
-from django.db.models import Model
+from django.db.models import Model, QuerySet, OuterRef, Subquery

 import uds.core.types.permissions
 from uds import models
@@ -145,42 +145,90 @@ class AssignedUserService(DetailHandler[UserServiceItem]):
         return val

-    def get_items(
-        self, parent: 'Model', item: typing.Optional[str]
-    ) -> types.rest.ItemsResult['UserServiceItem']:
+    def apply_sort(self, qs: QuerySet[typing.Any]) -> list[typing.Any] | QuerySet[typing.Any]:
+        def annotated_sort(field: str, descending: bool) -> QuerySet[typing.Any]:
+            prop_value_subquery = models.Properties.objects.filter(
+                owner_id=OuterRef('uuid'), owner_type='userservice', key=field
+            ).values('value')[:1]
+            return qs.annotate(prop_value=Subquery(prop_value_subquery)).order_by(
+                f'{"-" if descending else ""}prop_value'
+            )
+
+        if sort_info := self.get_sort_field_info('ip', 'actor_version'):
+            return annotated_sort(*sort_info)
+        return super().apply_sort(qs)
+
+    def get_qs(self, for_cached: bool) -> QuerySet[models.UserService]:
+        parent = ensure.is_instance(self._parent, models.ServicePool)
+        if for_cached:
+            return parent.cached_users_services()
+        return parent.assigned_user_services()
+
+    def do_get_item_position(self, for_cached: bool, parent: 'Model', item_uuid: str) -> int:
+        parent = ensure.is_instance(parent, models.ServicePool)
+        return self.calc_item_position(item_uuid, self.get_qs(for_cached).all())
+
+    def get_item_position(self, parent: Model, item_uuid: str) -> int:
+        return self.do_get_item_position(for_cached=False, parent=parent, item_uuid=item_uuid)
+
+    def do_get_item(
+        self,
+        parent: 'Model',
+        item: str,
+        for_cached: bool,
+    ) -> 'UserServiceItem':
+        parent = ensure.is_instance(parent, models.ServicePool)
+        return AssignedUserService.userservice_item(
+            self.get_qs(for_cached).get(uuid=process_uuid(item)),
+            props={
+                k: v
+                for k, v in models.Properties.objects.filter(
+                    owner_type='userservice', owner_id=process_uuid(item)
+                ).values_list('key', 'value')
+            },
+            is_cache=for_cached,
+        )
+
+    def do_get_items(
+        self,
+        parent: 'Model',
+        for_cached: bool,
+    ) -> types.rest.ItemsResult['UserServiceItem']:
         parent = ensure.is_instance(parent, models.ServicePool)
-        try:
-            if not item:
-                # First, fetch all properties for all assigned services on this pool
-                # We can cache them, because they are going to be readed anyway...
-                properties: dict[str, typing.Any] = collections.defaultdict(dict)
-                for id, key, value in self.filter_queryset(
-                    models.Properties.objects.filter(
-                        owner_type='userservice',
-                        owner_id__in=parent.assigned_user_services().values_list('uuid', flat=True),
-                    )
-                ).values_list('owner_id', 'key', 'value'):
-                    properties[id][key] = value
-                return [
-                    AssignedUserService.userservice_item(k, properties.get(k.uuid, {}))
-                    for k in parent.assigned_user_services()
-                    .all()
-                    .prefetch_related('deployed_service', 'publication', 'user')
-                ]
-            return AssignedUserService.userservice_item(
-                parent.assigned_user_services().get(process_uuid(uuid=process_uuid(item))),
-                props={
-                    k: v
-                    for k, v in models.Properties.objects.filter(
-                        owner_type='userservice', owner_id=process_uuid(item)
-                    ).values_list('key', 'value')
-                },
-            )
-        except Exception as e:
-            logger.error('Error getting user service %s: %s', item, e)
-            raise exceptions.rest.ResponseError(_('Error getting user service')) from e
+
+        def get_qs() -> QuerySet[models.UserService]:
+            if for_cached:
+                return parent.cached_users_services()
+            return parent.assigned_user_services()
+
+        # First, fetch all properties for all services on this pool
+        # We can cache them, because they are going to be read anyway...
+        properties: dict[str, typing.Any] = collections.defaultdict(dict)
+        for id, key, value in models.Properties.objects.filter(
+            owner_type='userservice',
+            owner_id__in=get_qs().values_list('uuid', flat=True),
+        ).values_list('owner_id', 'key', 'value'):
+            properties[id][key] = value
+        return [
+            AssignedUserService.userservice_item(k, properties.get(k.uuid, {}))
+            for k in self.odata_filter(
+                get_qs().all().prefetch_related('deployed_service', 'publication', 'user')
+            )
+        ]
+
+    def get_items(self, parent: 'Model') -> types.rest.ItemsResult['UserServiceItem']:
+        return self.do_get_items(parent, for_cached=False)
+
+    def get_item(
+        self,
+        parent: 'Model',
+        item: str,
+    ) -> 'UserServiceItem':
+        return self.do_get_item(parent, item, for_cached=False)

     def get_table(self, parent: 'Model') -> types.rest.TableInfo:
         parent = ensure.is_instance(parent, models.ServicePool)
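`do_get_items` above avoids one properties query per user service by fetching all `(owner_id, key, value)` rows in a single query and grouping them per owner with a `defaultdict`. The grouping step is plain Python and can be shown in isolation:

```python
import collections

# (owner_id, key, value) triples, as values_list('owner_id', 'key', 'value') would yield
rows = [
    ("uuid-1", "ip", "10.0.0.5"),
    ("uuid-1", "actor_version", "4.0.0"),
    ("uuid-2", "ip", "10.0.0.6"),
]

properties: dict[str, dict[str, str]] = collections.defaultdict(dict)
for owner_id, key, value in rows:
    properties[owner_id][key] = value  # one dict of props per user service

assert properties["uuid-1"] == {"ip": "10.0.0.5", "actor_version": "4.0.0"}
assert properties.get("uuid-3", {}) == {}  # missing owners fall back to empty props
```

Using `properties.get(uuid, {})` at lookup time (rather than indexing) keeps the `defaultdict` from growing for items that simply have no properties.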
@@ -202,6 +250,8 @@ class AssignedUserService(DetailHandler[UserServiceItem]):
             .text_column(name='owner', title=_('Owner'))
             .text_column(name='actor_version', title=_('Actor version'))
             .row_style(prefix='row-state-', field='state')
+            .with_field_mappings(revision='deployed_service.publications.revision')
+            .with_filter_fields('creation_date', 'unique_id', 'friendly_name', 'state', 'in_use')
         ).build()

     def get_logs(self, parent: 'Model', item: str) -> list[typing.Any]:
@@ -295,27 +345,19 @@ class CachedService(AssignedUserService):
     """

     CUSTOM_METHODS = []  # Remove custom methods from assigned services

+    def get_item_position(self, parent: Model, item_uuid: str) -> int:
+        return self.do_get_item_position(for_cached=True, parent=parent, item_uuid=item_uuid)
+
-    def get_items(
-        self, parent: 'Model', item: typing.Optional[str]
-    ) -> types.rest.ItemsResult['UserServiceItem']:
-        parent = ensure.is_instance(parent, models.ServicePool)
-        try:
-            if not item:
-                return [
-                    AssignedUserService.userservice_item(k, is_cache=True)
-                    for k in self.filter_queryset(parent.cached_users_services().all()).prefetch_related(
-                        'deployed_service', 'publication'
-                    )
-                ]
-            cached_userservice: models.UserService = parent.cached_users_services().get(uuid=process_uuid(item))
-            return AssignedUserService.userservice_item(cached_userservice, is_cache=True)
-        except models.UserService.DoesNotExist:
-            raise exceptions.rest.NotFound(_('User service not found')) from None
-        except Exception as e:
-            logger.error('Error getting user service %s: %s', item, e)
-            raise exceptions.rest.ResponseError(_('Error getting user service')) from e
+    def get_items(self, parent: 'Model') -> types.rest.ItemsResult['UserServiceItem']:
+        return self.do_get_items(parent, for_cached=True)
+
+    def get_item(
+        self,
+        parent: 'Model',
+        item: str,
+    ) -> 'UserServiceItem':
+        return self.do_get_item(parent, item, for_cached=True)

     def get_table(self, parent: 'Model') -> types.rest.TableInfo:
         parent = ensure.is_instance(parent, models.ServicePool)
@@ -327,6 +369,8 @@ class CachedService(AssignedUserService):
             .text_column(name='ip', title=_('IP'))
             .text_column(name='friendly_name', title=_('Friendly name'))
             .dict_column(name='state', title=_('State'), dct=State.literals_dict())
+            .with_field_mappings(revision='deployed_service.publications.revision')
+            .with_filter_fields('creation_date', 'unique_id', 'friendly_name', 'state')
         )
         if parent.state != State.LOCKED:
             table_info = table_info.text_column(name='cache_level', title=_('Cache level')).text_column(
@@ -366,7 +410,7 @@ class Groups(DetailHandler[GroupItem]):
     Processes the groups detail requests of a Service Pool
     """

-    def get_items(self, parent: 'Model', item: typing.Optional[str]) -> list['GroupItem']:
+    def get_items(self, parent: 'Model') -> types.rest.ItemsResult['GroupItem']:
         parent = typing.cast(typing.Union['models.ServicePool', 'models.MetaPool'], parent)
         return [
@@ -381,10 +425,13 @@ class Groups(DetailHandler[GroupItem]):
                 auth_name=group.manager.name,
             )
             for group in typing.cast(
-                collections.abc.Iterable[models.Group], self.filter_queryset(parent.assignedGroups.all())
+                collections.abc.Iterable[models.Group], self.filter_odata_queryset(parent.assignedGroups.all())
             )
         ]

+    def get_item(self, parent: Model, item: str) -> GroupItem:
+        raise exceptions.rest.NotSupportedError('Single group retrieval not implemented inside assigned groups')
+
     def get_table(self, parent: 'Model') -> TableInfo:
         parent = typing.cast(typing.Union['models.ServicePool', 'models.MetaPool'], parent)
         return (
@@ -437,7 +484,7 @@ class Transports(DetailHandler[TransportItem]):
     Processes the transports detail requests of a Service Pool
     """

-    def get_items(self, parent: 'Model', item: typing.Optional[str]) -> list['TransportItem']:
+    def get_items(self, parent: 'Model') -> types.rest.ItemsResult['TransportItem']:
         parent = ensure.is_instance(parent, models.ServicePool)
         return [
@@ -449,9 +496,14 @@ class Transports(DetailHandler[TransportItem]):
                 priority=trans.priority,
                 trans_type=trans.get_type().mod_name(),
             )
-            for trans in self.filter_queryset(parent.transports.all())
+            for trans in self.filter_odata_queryset(parent.transports.all())
         ]

+    def get_item(self, parent: 'Model', item: str) -> TransportItem:
+        raise exceptions.rest.NotSupportedError(
+            'Single transport retrieval not implemented inside assigned transports'
+        )
+
     def get_table(self, parent: 'Model') -> TableInfo:
         parent = ensure.is_instance(parent, models.ServicePool)
         return (
@@ -562,7 +614,7 @@ class Publications(DetailHandler[PublicationItem]):
         return self.success()

-    def get_items(self, parent: 'Model', item: typing.Optional[str]) -> list['PublicationItem']:
+    def get_items(self, parent: 'Model') -> types.rest.ItemsResult['PublicationItem']:
         parent = ensure.is_instance(parent, models.ServicePool)
         return [
             PublicationItem(
@@ -573,9 +625,14 @@ class Publications(DetailHandler[PublicationItem]):
                 reason=State.from_str(i.state).is_errored() and i.get_instance().error_reason() or '',
                 state_date=i.state_date,
             )
-            for i in self.filter_queryset(parent.publications.all())
+            for i in self.filter_odata_queryset(parent.publications.all())
         ]

+    def get_item(self, parent: 'Model', item: str) -> PublicationItem:
+        raise exceptions.rest.NotSupportedError(
+            'Single publication retrieval not implemented inside assigned publications'
+        )
+
     def get_table(self, parent: 'Model') -> TableInfo:
         parent = ensure.is_instance(parent, models.ServicePool)
         return (
@@ -600,7 +657,7 @@ class Changelog(DetailHandler[ChangelogItem]):
     Processes the transports detail requests of a Service Pool
     """

-    def get_items(self, parent: 'Model', item: typing.Optional[str]) -> list['ChangelogItem']:
+    def get_items(self, parent: 'Model') -> types.rest.ItemsResult['ChangelogItem']:
         parent = ensure.is_instance(parent, models.ServicePool)
         return [
             ChangelogItem(
@@ -608,9 +665,12 @@ class Changelog(DetailHandler[ChangelogItem]):
                 stamp=i.stamp,
                 log=i.log,
             )
-            for i in self.filter_queryset(parent.changelog.all())
+            for i in self.filter_odata_queryset(parent.changelog.all())
         ]

+    def get_item(self, parent: 'Model', item: str) -> ChangelogItem:
+        raise exceptions.rest.NotSupportedError('Single changelog retrieval not implemented inside changelog')
+
     def get_table(self, parent: 'Model') -> types.rest.TableInfo:
         parent = ensure.is_instance(parent, models.ServicePool)
         return (

View File

@@ -53,6 +53,8 @@ from uds.REST.model import DetailHandler
 from .user_services import AssignedUserService, UserServiceItem

+if typing.TYPE_CHECKING:
+    from django.db.models.query import QuerySet

 logger = logging.getLogger(__name__)
@@ -100,41 +102,49 @@ class Users(DetailHandler[UserItem]):
         'enable_client_logging',
     ]

-    def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.ItemsResult[UserItem]:
+    @staticmethod
+    def as_user_item(user: 'User') -> UserItem:
+        return UserItem(
+            id=user.uuid,
+            name=user.name,
+            real_name=user.real_name,
+            comments=user.comments,
+            state=user.state,
+            staff_member=user.staff_member,
+            is_admin=user.is_admin,
+            last_access=user.last_access,
+            mfa_data=user.mfa_data,
+            parent=user.parent,
+            groups=[i.uuid for i in user.get_groups()],
+            role=user.get_role().as_str(),
+        )
+
+    def apply_sort(self, qs: 'QuerySet[typing.Any]') -> 'list[typing.Any] | QuerySet[typing.Any]':
+        if field_info := self.get_sort_field_info('role'):
+            descending = '-' if field_info[1] else ''
+            return qs.order_by(f'{descending}is_admin', f'{descending}staff_member')
+        return super().apply_sort(qs)
+
+    def get_item_position(self, parent: 'Model', item_uuid: str) -> int:
+        parent = ensure.is_instance(parent, Authenticator)
+        return self.calc_item_position(item_uuid, parent.users.all())
+
+    def get_items(self, parent: 'Model') -> types.rest.ItemsResult[UserItem]:
         parent = ensure.is_instance(parent, Authenticator)
-        # Extract authenticator
-        def as_user_item(user: 'User') -> UserItem:
-            return UserItem(
-                id=user.uuid,
-                name=user.name,
-                real_name=user.real_name,
-                comments=user.comments,
-                state=user.state,
-                staff_member=user.staff_member,
-                is_admin=user.is_admin,
-                last_access=user.last_access,
-                mfa_data=user.mfa_data,
-                parent=user.parent,
-                groups=[i.uuid for i in user.get_groups()],
-                role=user.get_role().as_str(),
-            )
-        try:
-            if item is None:  # All users
-                return [as_user_item(i) for i in self.filter_queryset(parent.users.all())]
-            u = parent.users.get(uuid__iexact=process_uuid(item))
-            res = as_user_item(u)
-            usr = AUser(u)
-            res.groups = [g.db_obj().uuid for g in usr.groups()]
-            logger.debug('Item: %s', res)
-            return res
-        except User.DoesNotExist:
-            raise exceptions.rest.NotFound(_('User not found')) from None
-        except Exception as e:
-            logger.error('Error getting user %s: %s', item, e)
-            raise exceptions.rest.ResponseError(_('Error getting user')) from e
+        return [self.as_user_item(i) for i in self.odata_filter(parent.users.all())]
+
+    def get_item(self, parent: 'Model', item: str) -> UserItem:
+        parent = ensure.is_instance(parent, Authenticator)
+        db_usr = parent.users.get(uuid__iexact=process_uuid(item))
+        user_item = self.as_user_item(db_usr)
+        auth_usr = AUser(db_usr)
+        user_item.groups = [g.db_obj().uuid for g in auth_usr.groups()]
+        return user_item

     def get_table(self, parent: 'Model') -> types.rest.TableInfo:
         parent = ensure.is_instance(parent, Authenticator)
@@ -350,44 +360,38 @@ class GroupItem(types.rest.BaseRestItem):
 class Groups(DetailHandler[GroupItem]):
     CUSTOM_METHODS = ['services_pools', 'users']

-    def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.ItemsResult['GroupItem']:
+    @staticmethod
+    def as_group_item(group: 'Group') -> GroupItem:
+        val = GroupItem(
+            id=group.uuid,
+            name=group.name,
+            comments=group.comments,
+            state=group.state,
+            type=group.is_meta and 'meta' or 'group',
+            meta_if_any=group.meta_if_any,
+            skip_mfa=group.skip_mfa,
+        )
+        if group.is_meta:
+            val.groups = list(x.uuid for x in group.groups.all().order_by('name'))
+        return val
+
+    def get_item_position(self, parent: 'Model', item_uuid: str) -> int:
+        parent = ensure.is_instance(parent, Authenticator)
+        return self.calc_item_position(item_uuid, parent.groups.all())
+
+    def get_items(self, parent: 'Model') -> types.rest.ItemsResult['GroupItem']:
         parent = ensure.is_instance(parent, Authenticator)
-        try:
-            multi = False
-            if item is None:
-                multi = True
-                q = self.filter_queryset(parent.groups.all())
-            else:
-                q = parent.groups.filter(uuid=process_uuid(item))
-            res: list[GroupItem] = []
-            i = None
-            for i in q:
-                val = GroupItem(
-                    id=i.uuid,
-                    name=i.name,
-                    comments=i.comments,
-                    state=i.state,
-                    type=i.is_meta and 'meta' or 'group',
-                    meta_if_any=i.meta_if_any,
-                    skip_mfa=i.skip_mfa,
-                )
-                if i.is_meta:
-                    val.groups = list(x.uuid for x in i.groups.all().order_by('name'))
-                res.append(val)
-            if multi:
-                return res
-            if not i:
-                raise exceptions.rest.NotFound(_('Group not found')) from None
-            # Add pools field if 1 item only
-            res[0].pools = [v.uuid for v in get_service_pools_for_groups([i])]
-            return res[0]
-        except exceptions.rest.HandlerError:
-            raise  # Re-raise
-        except Exception as e:
-            logger.error('Group item not found: %s.%s: %s', parent.name, item, e)
-            raise exceptions.rest.ResponseError(_('Error getting group')) from e
+        q = self.odata_filter(parent.groups.all())
+        return [self.as_group_item(i) for i in q]
+
+    def get_item(self, parent: 'Model', item: str) -> 'GroupItem':
+        parent = ensure.is_instance(parent, Authenticator)
+        db_grp = parent.groups.filter(uuid=process_uuid(item)).first()
+        if not db_grp:
+            raise exceptions.rest.NotFound(_('Group not found')) from None
+        grp = self.as_group_item(db_grp)
+        grp.pools = [v.uuid for v in get_service_pools_for_groups([db_grp])]
+        return grp

     def get_table(self, parent: 'Model') -> types.rest.TableInfo:
         parent = ensure.is_instance(parent, Authenticator)

View File

@@ -51,6 +51,9 @@ from ..handlers import Handler

 if typing.TYPE_CHECKING:
     pass

+T = typing.TypeVar('T', bound=models.Model)
+
 logger = logging.getLogger(__name__)
@@ -132,6 +135,22 @@ class BaseModelHandler(Handler, abc.ABC, typing.Generic[types.rest.T_Item]):
         return args

+    def odata_filter(self, qs: models.QuerySet[T]) -> list[T]:
+        """
+        Invoked to filter the queryset according to the parameters received.
+        The default implementation applies the standard OData filtering.
+
+        Args:
+            qs: Queryset to filter
+
+        Returns:
+            Filtered queryset as a list
+
+        Note:
+            This is not final, so subclasses can override it if needed.
+        """
+        return self.filter_odata_queryset(qs)
+
     # Success methods
     def success(self) -> str:
         """

View File

@@ -34,30 +34,37 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com
 import logging
 import typing
 import collections.abc
+import abc

 from django.db import models
 from django.utils.translation import gettext as _

 from uds.core import consts, exceptions, types, module
 from uds.core.util.model import process_uuid
-from uds.core.util import api as api_utils
+from uds.core.util import api as api_utils, model as model_utils
 from uds.REST.utils import rest_result
 from uds.REST.model.base import BaseModelHandler
 from uds.REST.utils import camel_and_snake_case_from

+T = typing.TypeVar('T', bound=models.Model)
+
 # Not imported at runtime, just for type checking
 if typing.TYPE_CHECKING:
+    from django.db.models.query import QuerySet
     from uds.models import User
     from uds.REST.model.master import ModelHandler

 logger = logging.getLogger(__name__)

 # Details do not have types at all
 # so, right now, we only process details petitions for Handling & tables info

 # noinspection PyMissingConstructor
-class DetailHandler(BaseModelHandler[types.rest.T_Item]):
+class DetailHandler(BaseModelHandler[types.rest.T_Item], abc.ABC):
     """
     Detail handler (for relations such as provider-->services, authenticators-->users,groups, deployed services-->cache,assigned, groups, transports
     Urls recognized for GET are:
@@ -138,21 +145,19 @@ class DetailHandler(BaseModelHandler[types.rest.T_Item]):
         """
         # Process args
         logger.debug('Detail args for GET: %s', self._args)
-        num_args = len(self._args)
         parent: models.Model = self._parent_item

-        if num_args == 0:
-            return self.get_items(parent, None)
-
         # if has custom methods, look for if this request matches any of them
         r = self._check_is_custom_method(self._args[0], parent)
         if r is not consts.rest.NOT_FOUND:
             return r

         match self._args:
+            case []:  # same as overview
+                return self.get_items(parent)
             case [consts.rest.OVERVIEW]:
-                return self.get_items(parent, None)
+                return self.get_items(parent)
             case [consts.rest.OVERVIEW, *_fails]:
                 raise exceptions.rest.RequestError('Invalid overview request') from None
             case [consts.rest.TYPES]:
@@ -177,8 +182,10 @@ class DetailHandler(BaseModelHandler[types.rest.T_Item]):
                 return self.get_logs(parent, item_id)
             case [consts.rest.LOG, *_fails]:
                 raise exceptions.rest.RequestError('Invalid log request') from None
+            case [consts.rest.POSITION, item_uuid]:
+                return self.get_item_position(parent, item_uuid)
             case [one_arg]:
-                return self.get_items(parent, process_uuid(one_arg))
+                return self.get_item(parent, process_uuid(one_arg))
             case _:
                 # Maybe a custom method?
                 r = self._check_is_custom_method(self._args[1], parent, self._args[0])
@@ -247,9 +254,8 @@ class DetailHandler(BaseModelHandler[types.rest.T_Item]):

     # Override this to provide functionality
     # Default (as sample) get_items
-    def get_items(
-        self, parent: models.Model, item: typing.Optional[str]
-    ) -> types.rest.ItemsResult[types.rest.T_Item]:
+    @abc.abstractmethod
+    def get_items(self, parent: models.Model) -> types.rest.ItemsResult[types.rest.T_Item]:
         """
         This MUST be overridden by derived classes.
         Expects to return the list of items.
@@ -261,6 +267,16 @@ class DetailHandler(BaseModelHandler[types.rest.T_Item]):
         # return {}  # Returns one item
         raise NotImplementedError(f'Must provide a get_items method for {self.__class__} class')

+    @abc.abstractmethod
+    def get_item(self, parent: models.Model, item: str) -> types.rest.T_Item:
+        """
+        Utility method to get a single item by uuid
+
+        :param parent: Parent model
+        :param item: Item uuid
+        :return: Item as dictionary
+        """
+        raise NotImplementedError(f'Must provide a get_item method for {self.__class__} class')
+
     # Default save
     def save_item(self, parent: models.Model, item: typing.Optional[str]) -> types.rest.T_Item:
         """
@@ -328,6 +344,51 @@ class DetailHandler(BaseModelHandler[types.rest.T_Item]):
         """
         return []  # Default is that details do not have types

+    def get_logs(self, parent: models.Model, item: str) -> list[typing.Any]:
+        """
+        If the detail has any log associated with its items, provide it by overriding this method
+
+        Args:
+            parent: Parent model
+            item: Item id (uuid)
+
+        Returns:
+            A list of log elements (normally got using the "uds.core.util.log.get_logs" method)
+        """
+        raise exceptions.rest.InvalidMethodError('Object does not support logs')
+
+    def calc_item_position(self, item_uuid: str, qs: 'QuerySet[T]') -> int:
+        """
+        Helper method to get the position of an item in a queryset
+
+        Args:
+            item_uuid (str): UUID of the item to find
+            qs (QuerySet[T]): Queryset to search into
+
+        Returns:
+            int: Position of the item in the default ordering, -1 if not found
+        """
+        # Find item in qs; it may be None, in which case return -1
+        obj = qs.filter(uuid__iexact=item_uuid).first()
+        if obj:
+            return model_utils.get_position_in_queryset(obj, qs)
+        return -1
+
+    def get_item_position(self, parent: models.Model, item_uuid: str) -> int:
+        """
+        Tries to get the position of an item in the default ordering of the detail items
+
+        Args:
+            item_uuid (str): UUID of the item to find
+
+        Returns:
+            int: Position of the item in the default ordering, -1 if not found
+
+        Note:
+            Override this method if the detail can provide item position
+        """
+        return -1
+
     @classmethod
     def possible_types(cls: type[typing.Self]) -> collections.abc.Iterable[type[module.Module]]:
         """
@@ -337,15 +398,6 @@ class DetailHandler(BaseModelHandler[types.rest.T_Item]):
         """
         return []

-    def get_logs(self, parent: models.Model, item: str) -> list[typing.Any]:
-        """
-        If the detail has any log associated with it items, provide it overriding this method
-
-        :param parent:
-        :param item:
-        :return: a list of log elements (normally got using "uds.core.util.log.get_logs" method)
-        """
-        raise exceptions.rest.InvalidMethodError('Object does not support logs')
-
     @classmethod
     def api_components(cls: type[typing.Self]) -> types.rest.api.Components:
         """
@@ -355,10 +407,12 @@ class DetailHandler(BaseModelHandler[types.rest.T_Item]):
         return api_utils.get_component_from_type(cls)

     @classmethod
-    def api_paths(cls: type[typing.Self], path: str, tags: list[str], security: str) -> dict[str, types.rest.api.PathItem]:
+    def api_paths(
+        cls: type[typing.Self], path: str, tags: list[str], security: str
+    ) -> dict[str, types.rest.api.PathItem]:
         """
         Returns the API operations that should be registered
         """
         from .api_helpers import api_paths

         return api_paths(cls, path, tags=tags, security=security)
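`calc_item_position` above resolves the item case-insensitively and then asks `get_position_in_queryset` for its index under the default ordering, returning -1 when absent. A list-based sketch of the same contract (the helper name mirrors `uds.core.util.model.get_position_in_queryset`, but this stand-in just scans a sequence):

```python
from typing import Optional, Sequence


def get_position_in_sequence(obj: str, seq: Sequence[str]) -> int:
    """Index of obj under the sequence's current ordering; stand-in for
    get_position_in_queryset operating on a plain sequence."""
    for pos, candidate in enumerate(seq):
        if candidate == obj:
            return pos
    return -1


def calc_item_position(item_uuid: str, ordered_uuids: Sequence[str]) -> int:
    # Mirror the handler contract: missing items yield -1 instead of raising
    found: Optional[str] = next(
        (u for u in ordered_uuids if u.lower() == item_uuid.lower()), None
    )
    if found is not None:
        return get_position_in_sequence(found, ordered_uuids)
    return -1


uuids = ["aaa", "bbb", "ccc"]
assert calc_item_position("BBB", uuids) == 1  # case-insensitive, like uuid__iexact
assert calc_item_position("zzz", uuids) == -1
```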

View File

@@ -44,7 +44,7 @@ from uds.core import consts
 from uds.core import exceptions
 from uds.core import types
 from uds.core.module import Module
-from uds.core.util import log, permissions, api as api_utils
+from uds.core.util import log, permissions, model as model_utils, api as api_utils
 from uds.models import ManagedObjectModel, Tag, TaggingMixin
 from uds.REST.model.base import BaseModelHandler
@@ -198,15 +198,10 @@ class ModelHandler(BaseModelHandler[types.rest.T_Item], abc.ABC):
             method = getattr(detail_handler, self._operation)
             return method()
         except self.MODEL.DoesNotExist:
             raise exceptions.rest.NotFound(f'Item not found on model {self.MODEL.__name__}')
         except (KeyError, AttributeError) as e:
             raise exceptions.rest.InvalidMethodError(f'Invalid method {self._operation}') from e
-        except exceptions.rest.HandlerError:
-            raise
-        except Exception as e:
-            logger.error('Exception processing detail: %s', e)
-            raise exceptions.rest.RequestError(f'Error processing detail: {e}') from e

     # Data related
     def get_item(self, item: models.Model) -> types.rest.T_Item:
@@ -222,30 +217,44 @@ class ModelHandler(BaseModelHandler[types.rest.T_Item], abc.ABC):
default behavior is return item_as_dict
"""
return self.get_item(item)
def filter_model_queryset(self, qs: QuerySet[T]|None = None) -> QuerySet[T]:
qs = typing.cast('QuerySet[T]', self.MODEL.objects.all()) if qs is None else qs
if self.FILTER is not None:
qs = qs.filter(**self.FILTER)
if self.EXCLUDE is not None:
qs = qs.exclude(**self.EXCLUDE)
return qs
def get_item_position(self, item_uuid: str, query: QuerySet[T] | None = None) -> int:
qs = self.filter_model_queryset(query)
# Find the item in qs; if not found, return -1
obj = qs.filter(uuid__iexact=item_uuid).first()
if obj:
return model_utils.get_position_in_queryset(obj, qs)
return -1
def get_items(
self, *, overview: bool = False, query: QuerySet[T] | None = None
self, *, sumarize: bool = False, query: QuerySet[T] | None = None
) -> typing.Generator[types.rest.T_Item, None, None]:
"""
Get items from the model.
Args:
overview: If True, return a summary of the items.
sumarize: If True, return a summary of the items.
query: Optional queryset to filter the items. Used to optimize the process for some models
(such as ServicePools)
"""
# Basic model filter
if query:
qs = query
else:
qs = self.MODEL.objects.all()
if self.FILTER is not None:
qs = qs.filter(**self.FILTER)
if self.EXCLUDE is not None:
qs = qs.exclude(**self.EXCLUDE)
qs = self.filter_model_queryset(query)
qs = self.filter_queryset(qs)
# Custom filtering from params (odata, etc)
qs = self.odata_filter(qs)
for item in qs:
try:
@@ -259,7 +268,7 @@ class ModelHandler(BaseModelHandler[types.rest.T_Item], abc.ABC):
is False
):
continue
yield self.get_item_summary(item) if overview else self.get_item(item)
yield self.get_item_summary(item) if sumarize else self.get_item(item)
except Exception as e: # maybe an exception is thrown to skip an item
logger.debug('Got exception processing item from model: %s', e)
# logger.exception('Exception getting item from {0}'.format(self.model))
@@ -268,9 +277,6 @@ class ModelHandler(BaseModelHandler[types.rest.T_Item], abc.ABC):
logger.debug('method GET for %s, %s', self.__class__.__name__, self._args)
number_of_args = len(self._args)
if number_of_args == 0:
return list(self.get_items(overview=False))
# if has custom methods, look for if this request matches any of them
for cm in self.CUSTOM_METHODS:
# Convert to snake case
@@ -309,7 +315,7 @@ class ModelHandler(BaseModelHandler[types.rest.T_Item], abc.ABC):
match self._args:
case []: # Same as overview, but with all data
return [i.as_dict() for i in self.get_items(overview=False)]
return [i.as_dict() for i in self.get_items(sumarize=False)]
case [consts.rest.OVERVIEW]:
return [i.as_dict() for i in self.get_items()]
case [consts.rest.OVERVIEW, *_fails]:
@@ -330,6 +336,8 @@ class ModelHandler(BaseModelHandler[types.rest.T_Item], abc.ABC):
return self.get_processed_gui(for_type)
case [consts.rest.GUI, for_type, *_fails]:
raise exceptions.rest.RequestError('Invalid GUI request') from None
case [consts.rest.POSITION, item_uuid]:
return self.get_item_position(item_uuid)
case _: # Maybe an item or a detail
if number_of_args == 1:
try:

View File

@@ -38,6 +38,7 @@ TYPES: typing.Final[str] = 'types'
TABLEINFO: typing.Final[str] = 'tableinfo'
GUI: typing.Final[str] = 'gui'
LOG: typing.Final[str] = 'log'
POSITION: typing.Final[str] = 'position'
SYSTEM: typing.Final[str] = 'system' # Defined on system class, here for reference

View File

@@ -96,7 +96,7 @@ class NotificationsManager(metaclass=singleton.Singleton):
message = message % args
except Exception:
message = message + ' ' + str(args) + ' (format error)'
message = message[:4096] # Max length of message
message = message[:4000]  # Max length of message; 4000 also fits SQL Server's nvarchar limit
# Store the notification on local persistent storage
# Will be processed by UDS backend
try:
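The new 4000-character cap fits SQL Server's `nvarchar(4000)` column limit as well as the other supported backends; a minimal sketch of the format-then-truncate logic (the function name is illustrative, not the actual manager API):

```python
MAX_MESSAGE_LEN = 4000  # nvarchar(4000) is the largest fixed non-max string type on SQL Server

def clamp_message(message: str, args: tuple) -> str:
    # Mirror the handler logic: try %-formatting, fall back to appending the raw args
    try:
        message = message % args
    except Exception:
        message = message + ' ' + str(args) + ' (format error)'
    return message[:MAX_MESSAGE_LEN]
```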

View File

@@ -315,10 +315,14 @@ class Transport(Module):
osname: str,
type: typing.Literal['tunnel', 'direct'],
params: collections.abc.Mapping[str, typing.Any],
client_version: str | None = None,
) -> types.transports.TransportScript:
"""
Returns a script for the given os and type
"""
if (client_version or '0.0') >= '5.0.0':
return self.get_relative_script(f'scripts/{osname.lower()}/{type}.js', params)
return self.get_relative_script(f'scripts/{osname.lower()}/{type}.py', params)
def get_link(

View File

@@ -162,7 +162,7 @@ class ManagedObjectItem(BaseRestItem, typing.Generic[T_Model]):
# Alias for get_items return type
ItemsResult: typing.TypeAlias = list[T_Item] | BaseRestItem | typing.Iterator[T_Item]
ItemsResult: typing.TypeAlias = list[T_Item] | typing.Iterator[T_Item]
@dataclasses.dataclass
@@ -279,6 +279,8 @@ class TableInfo:
fields: list[TableField] # List of fields in the table
row_style: 'RowStyleInfo'
subtitle: typing.Optional[str] = None
filter_fields: list[str] = dataclasses.field(default_factory=list[str])
field_mappings: dict[str, str] = dataclasses.field(default_factory=dict[str, str])
def as_dict(self) -> dict[str, typing.Any]:
return {
@@ -286,6 +288,8 @@ class TableInfo:
'fields': [field.as_dict() for field in self.fields],
'row_style': self.row_style.as_dict(),
'subtitle': self.subtitle or '',
'filter_fields': self.filter_fields,
'field_mappings': self.field_mappings,
}
@staticmethod

View File

@@ -31,18 +31,20 @@ def _as_dict_without_none(v: typing.Any) -> typing.Any:
# (handler, model, detail, etc.)
# So we can override names or whatever we need
# Types of GUI info that can be provided
class RestApiInfoGuiType(enum.Enum):
SINGLE_TYPE = 0
MULTIPLE_TYPES = 1
UNTYPED = 3
def is_single_type(self) -> bool:
return self == RestApiInfoGuiType.SINGLE_TYPE
def supports_multiple_types(self) -> bool:
return self == RestApiInfoGuiType.MULTIPLE_TYPES
@dataclasses.dataclass
class RestApiInfo:
@@ -429,7 +431,9 @@ class ODataParams:
filter=data.get('$filter'),
start=start,
limit=limit,
orderby=order_by,
orderby=[
o.replace('.', '__') for o in order_by
], # Allow order by related fields with dot or __
select=select,
)
except (ValueError, TypeError):
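The `.` to `__` rewrite above lets API clients sort by related fields using OData-style dotted paths; a small illustration (the field names are hypothetical):

```python
# Hypothetical $orderby values as a client might send them
order_by = ['provider.name', '-service.type_name', 'name']
# Same rewrite as above: dots become Django's '__' relation separator;
# a leading '-' (descending) prefix is left untouched
django_order = [o.replace('.', '__') for o in order_by]
```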

View File

@@ -181,3 +181,18 @@ def get_my_ip_from_db() -> str:
logger.error('Error getting my IP: %s', e)
return '0.0.0.0'
# Tries to get the position of an object in the default queryset ordering
def get_position_in_queryset(obj: typing.Any, queryset: typing.Any) -> int:
"""
Tries to get the position of an object in the default queryset ordering
:param obj: Object to find
:param queryset: Queryset to search in
:return: Position in the queryset (0 based). -1 if not found
"""
try:
lst = list(queryset.values_list('pk', flat=True))
return lst.index(obj.pk)
except ValueError:
return -1
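A sketch of how `get_position_in_queryset` behaves, using a queryset stand-in since the real function takes a Django `QuerySet` (the stand-in classes below are test doubles, not UDS code):

```python
import typing

class FakeQuerySet:
    """Stand-in exposing only the values_list('pk', flat=True) shape used here."""
    def __init__(self, pks: list[int]) -> None:
        self._pks = pks
    def values_list(self, field: str, flat: bool = False) -> list[int]:
        return list(self._pks)

class FakeObj:
    def __init__(self, pk: int) -> None:
        self.pk = pk

def get_position_in_queryset(obj: typing.Any, queryset: typing.Any) -> int:
    # Same logic as the utility above: position in the queryset's default ordering
    try:
        lst = list(queryset.values_list('pk', flat=True))
        return lst.index(obj.pk)
    except ValueError:
        return -1
```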

View File

@@ -176,7 +176,7 @@ class DjangoQueryTransformer(lark.Transformer[typing.Any, Q | AnnotatedField]):
if func_name not in _FUNCTIONS_PARAMS_NUM:
raise ValueError(f"Unknown function: {func_name}")
if func_name in ('substringof', 'startswith', 'endswith'):
if func_name in ('substringof', 'contains', 'startswith', 'endswith'):
if len(func_args) != 2:
raise ValueError(f"{func_name} requires 2 arguments")
field, value = func_args
@@ -186,6 +186,8 @@ class DjangoQueryTransformer(lark.Transformer[typing.Any, Q | AnnotatedField]):
raise ValueError(f"Function '{func_name}' does not support field-to-field comparison")
match func_name:
case 'substringof':
return Q(**{f"{value}__icontains": field}) # Note the order swap
case 'contains':
return Q(**{f"{field}__icontains": value})
case 'startswith':
return Q(**{f"{field}__istartswith": value})

View File

@@ -399,6 +399,8 @@ class TableBuilder:
subtitle: str | None
fields: list[types.rest.TableField]
style_info: types.rest.RowStyleInfo
filter_fields: list[str]
field_mappings: dict[str, str]
def __init__(self, title: str, subtitle: str | None = None) -> None:
# TODO: USe table_name on a later iteration of the code
@@ -406,6 +408,8 @@ class TableBuilder:
self.subtitle = subtitle
self.fields = []
self.style_info = types.rest.RowStyleInfo.null()
self.filter_fields = []
self.field_mappings = {}
def _add_field(
self,
@@ -513,6 +517,20 @@ class TableBuilder:
self.style_info = types.rest.RowStyleInfo(prefix=prefix, field=field)
return self
def with_filter_fields(self, *fields: str) -> typing.Self:
"""
Sets the filter fields for the table.
"""
self.filter_fields = list(fields)
return self
def with_field_mappings(self, **kwargs: str) -> typing.Self:
"""
Sets the field name mappings (translations) for the table fields.
"""
self.field_mappings = kwargs
return self
def build(self) -> types.rest.TableInfo:
"""
Returns the table info for the table fields.
@@ -522,4 +540,6 @@ class TableBuilder:
fields=self.fields,
row_style=self.style_info,
subtitle=self.subtitle,
filter_fields=self.filter_fields,
field_mappings=self.field_mappings,
)

View File

@@ -163,7 +163,6 @@ class UniqueGenerator:
logger.debug('Last: %s', last)
seq = last.seq + 1
except Exception:
# logger.exception('Error here')
seq = 0
with transaction.atomic():
self._range_filter(seq).delete() # Clean ups all unassigned after last assigned in this range

View File

@@ -34,7 +34,6 @@ import urllib.parse
import logging
import requests
import time
import token
from uds.core.util import security
from uds.core.util.cache import Cache
@@ -42,10 +41,8 @@ from uds.core.util.decorators import cached
from . import types, consts, exceptions
logger = logging.getLogger(__name__)
class OpenshiftClient:
cluster_url: str
api_url: str
@@ -104,7 +101,7 @@ class OpenshiftClient:
except Exception as ex:
logging.error(f"Could not obtain token: {ex}")
raise
def connect(self, force: bool = False) -> requests.Session:
# For testing, always use the fixed token
session = self._session = security.secure_requests_session(verify=self._verify_ssl)
@@ -296,33 +293,28 @@ class OpenshiftClient:
"""
Monitor the clone process of a virtual machine.
"""
clone_url = f"{api_url}/apis/clone.kubevirt.io/v1alpha1/namespaces/{namespace}/virtualmachineclones/{clone_name}"
headers = {"Authorization": f"Bearer {self.get_token()}", "Accept": "application/json"}
path = f"/apis/clone.kubevirt.io/v1alpha1/namespaces/{namespace}/virtualmachineclones/{clone_name}"
logging.info("Monitoring clone process for '%s'...", clone_name)
while True:
try:
response = requests.get(clone_url, headers=headers, verify=False)
if response.status_code == 200:
clone_data = response.json()
status = clone_data.get('status', {})
phase = status.get('phase', 'Unknown')
logging.info("Phase: %s", phase)
for condition in status.get('conditions', []):
ctype = condition.get('type', '')
cstatus = condition.get('status', '')
cmsg = condition.get('message', '')
logging.info(" %s: %s - %s", ctype, cstatus, cmsg)
if phase == 'Succeeded':
logging.info("Clone '%s' completed successfully!", clone_name)
break
elif phase == 'Failed':
logging.error("Clone '%s' failed!", clone_name)
break
elif response.status_code == 404:
logging.warning("Clone resource '%s' not found. May have been cleaned up.", clone_name)
response = self.do_request('GET', path)
status = response.get('status', {})
phase = status.get('phase', 'Unknown')
logging.info("Phase: %s", phase)
for condition in status.get('conditions', []):
ctype = condition.get('type', '')
cstatus = condition.get('status', '')
cmsg = condition.get('message', '')
logging.info(" %s: %s - %s", ctype, cstatus, cmsg)
if phase == 'Succeeded':
logging.info("Clone '%s' completed successfully!", clone_name)
break
else:
logging.error("Error monitoring clone: %d", response.status_code)
elif phase == 'Failed':
logging.error("Clone '%s' failed!", clone_name)
break
except exceptions.OpenshiftNotFoundError:
logging.warning("Clone resource '%s' not found. May have been cleaned up.", clone_name)
break
except Exception as e:
logging.error("Monitoring exception: %s", e)
logging.info("Waiting %d seconds before next check...", polling_interval)
@@ -332,19 +324,16 @@ class OpenshiftClient:
"""
Returns the name of the PVC or DataVolume used by the VM.
"""
vm_url = f"{api_url}/apis/kubevirt.io/v1/namespaces/{namespace}/virtualmachines/{vm_name}"
headers = {"Authorization": f"Bearer {self.get_token()}", "Accept": "application/json"}
response = requests.get(vm_url, headers=headers, verify=False)
if response.status_code == 200:
vm_obj = response.json()
volumes = vm_obj.get("spec", {}).get("template", {}).get("spec", {}).get("volumes", [])
for vol in volumes:
pvc = vol.get("persistentVolumeClaim")
if pvc:
return pvc.get("claimName"), "pvc"
dv = vol.get("dataVolume")
if dv:
return dv.get("name"), "dv"
path = f"/apis/kubevirt.io/v1/namespaces/{namespace}/virtualmachines/{vm_name}"
response = self.do_request('GET', path)
volumes = response.get("spec", {}).get("template", {}).get("spec", {}).get("volumes", [])
for vol in volumes:
pvc = vol.get("persistentVolumeClaim")
if pvc:
return pvc.get("claimName"), "pvc"
dv = vol.get("dataVolume")
if dv:
return dv.get("name"), "dv"
raise Exception(f"No PVC or DataVolume found in VM {vm_name}")
def get_datavolume_phase(self, datavolume_name: str) -> str:
@@ -352,13 +341,10 @@ class OpenshiftClient:
Get the phase of a DataVolume.
Returns the phase as a string.
"""
url = f"{self.api_url}/apis/cdi.kubevirt.io/v1beta1/namespaces/{self.namespace}/datavolumes/{datavolume_name}"
headers = {'Authorization': f'Bearer {self.get_token()}', 'Accept': 'application/json'}
path = f"/apis/cdi.kubevirt.io/v1beta1/namespaces/{self.namespace}/datavolumes/{datavolume_name}"
try:
response = requests.get(url, headers=headers, verify=self._verify_ssl, timeout=self._timeout)
if response.status_code == 200:
dv = response.json()
return dv.get('status', {}).get('phase', '')
response = self.do_request('GET', path)
return response.get('status', {}).get('phase', '')
except Exception:
pass
return ''
@@ -368,17 +354,14 @@ class OpenshiftClient:
Get the size of a DataVolume.
Returns the size as a string.
"""
url = f"{api_url}/apis/cdi.kubevirt.io/v1beta1/namespaces/{namespace}/datavolumes/{dv_name}"
headers = {"Authorization": f"Bearer {self.get_token()}", "Accept": "application/json"}
response = requests.get(url, headers=headers, verify=False)
if response.status_code == 200:
dv = response.json()
size = dv.get("status", {}).get("amount", None)
if size:
return size
return (
dv.get("spec", {}).get("pvc", {}).get("resources", {}).get("requests", {}).get("storage") or ""
)
path = f"/apis/cdi.kubevirt.io/v1beta1/namespaces/{namespace}/datavolumes/{dv_name}"
response = self.do_request('GET', path)
size = response.get("status", {}).get("amount", None)
if size:
return size
return (
response.get("spec", {}).get("pvc", {}).get("resources", {}).get("requests", {}).get("storage") or ""
)
raise Exception(f"Could not get the size of DataVolume {dv_name}")
def get_pvc_size(self, api_url: str, namespace: str, pvc_name: str) -> str:
@@ -386,14 +369,11 @@ class OpenshiftClient:
Get the size of a PVC.
Returns the size as a string.
"""
url = f"{api_url}/api/v1/namespaces/{namespace}/persistentvolumeclaims/{pvc_name}"
headers = {"Authorization": f"Bearer {self.get_token()}", "Accept": "application/json"}
response = requests.get(url, headers=headers, verify=False)
if response.status_code == 200:
pvc = response.json()
capacity = pvc.get("status", {}).get("capacity", {}).get("storage")
if capacity:
return capacity
path = f"/api/v1/namespaces/{namespace}/persistentvolumeclaims/{pvc_name}"
response = self.do_request('GET', path)
capacity = response.get("status", {}).get("capacity", {}).get("storage")
if capacity:
return capacity
raise Exception(f"Could not get the size of PVC {pvc_name}")
def clone_pvc_with_datavolume(
@@ -409,12 +389,7 @@ class OpenshiftClient:
Clone a PVC using a DataVolume.
Returns True if the DataVolume was created successfully, else False.
"""
dv_url = f"{api_url}/apis/cdi.kubevirt.io/v1beta1/namespaces/{namespace}/datavolumes"
headers = {
"Authorization": f"Bearer {self.get_token()}",
"Accept": "application/json",
"Content-Type": "application/json",
}
path = f"/apis/cdi.kubevirt.io/v1beta1/namespaces/{namespace}/datavolumes"
body: dict[str, typing.Any] = {
"apiVersion": "cdi.kubevirt.io/v1beta1",
"kind": "DataVolume",
@@ -428,12 +403,13 @@ class OpenshiftClient:
},
},
}
response = requests.post(dv_url, headers=headers, json=body, verify=False)
if response.status_code == 201:
try:
self.do_request('POST', path, data=body)
logging.info(f"DataVolume '{cloned_pvc_name}' created successfully")
return True
logging.error(f"Failed to create DataVolume: {response.status_code} {response.text}")
return False
except Exception as e:
logging.error(f"Failed to create DataVolume: {e}")
return False
def create_vm_from_pvc(
self,
@@ -448,16 +424,13 @@ class OpenshiftClient:
Create a new VM from a cloned PVC using DataVolumeTemplates.
Returns True if the VM was created successfully, else False.
"""
original_vm_url = (
f"{api_url}/apis/kubevirt.io/v1/namespaces/{namespace}/virtualmachines/{source_vm_name}"
)
headers = {"Authorization": f"Bearer {self.get_token()}", "Accept": "application/json"}
resp = requests.get(original_vm_url, headers=headers, verify=False)
if resp.status_code != 200:
logging.error(f"Could not get source VM: {resp.status_code} {resp.text}")
path = f"/apis/kubevirt.io/v1/namespaces/{namespace}/virtualmachines/{source_vm_name}"
try:
vm_obj = self.do_request('GET', path)
except Exception as e:
logging.error(f"Could not get source VM: {e}")
return False
vm_obj = resp.json()
vm_obj['metadata']['name'] = new_vm_name
for k in ['resourceVersion', 'uid', 'selfLink']:
@@ -503,28 +476,27 @@ class OpenshiftClient:
logger.info(f"Creating VM '{new_vm_name}' from cloned PVC '{new_dv_name}'.")
#logger.info(f"VM Object: {vm_obj}")
create_url = f"{api_url}/apis/kubevirt.io/v1/namespaces/{namespace}/virtualmachines"
headers["Content-Type"] = "application/json"
resp = requests.post(create_url, headers=headers, json=vm_obj, verify=False)
if resp.status_code == 201:
create_path = f"/apis/kubevirt.io/v1/namespaces/{namespace}/virtualmachines"
try:
self.do_request('POST', create_path, data=vm_obj)
logging.info(f"VM '{new_vm_name}' created successfully with DataVolumeTemplate.")
return True
logging.error(f"Error creating VM: {resp.status_code} {resp.text}")
return False
except Exception as e:
logging.error(f"Error creating VM: {e}")
return False
def delete_vm(self, api_url: str, namespace: str, vm_name: str) -> bool:
"""
Delete a VM by name.
Returns True if the VM was deleted successfully, else False.
"""
url = f"{api_url}/apis/kubevirt.io/v1/namespaces/{namespace}/virtualmachines/{vm_name}"
headers = {"Authorization": f"Bearer {self.get_token()}", "Accept": "application/json"}
response = requests.delete(url, headers=headers, verify=False)
if response.status_code in [200, 202]:
path = f"/apis/kubevirt.io/v1/namespaces/{namespace}/virtualmachines/{vm_name}"
try:
self.do_request('DELETE', path)
logging.info(f"VM {vm_name} deleted successfully.")
return True
else:
logging.error(f"Error deleting VM {vm_name}: {response.status_code} - {response.text}")
except Exception as e:
logging.error(f"Error deleting VM {vm_name}: {e}")
return False
def wait_for_datavolume_clone_progress(
@@ -534,14 +506,12 @@ class OpenshiftClient:
Wait for a DataVolume clone to complete.
Returns True if the clone completed successfully, else False.
"""
url = f"{api_url}/apis/cdi.kubevirt.io/v1beta1/namespaces/{namespace}/datavolumes/{datavolume_name}"
headers = {"Authorization": f"Bearer {self.get_token()}", "Accept": "application/json"}
path = f"/apis/cdi.kubevirt.io/v1beta1/namespaces/{namespace}/datavolumes/{datavolume_name}"
start = time.time()
while time.time() - start < timeout:
response = requests.get(url, headers=headers, verify=False)
if response.status_code == 200:
dv = response.json()
status = dv.get('status', {})
try:
response = self.do_request('GET', path)
status = response.get('status', {})
phase = status.get('phase')
progress = status.get('progress', 'N/A')
logging.info(f"DataVolume {datavolume_name} status: {phase}, progress: {progress}")
@@ -551,8 +521,8 @@ class OpenshiftClient:
elif phase == 'Failed':
logging.error(f"DataVolume {datavolume_name} clone failed")
return False
else:
logging.error(f"Error querying DataVolume {datavolume_name}: {response.status_code}")
except Exception as e:
logging.error(f"Error querying DataVolume {datavolume_name}: {e}")
time.sleep(polling_interval)
logging.error(f"Timeout waiting for DataVolume {datavolume_name} clone")
return False
@@ -562,40 +532,47 @@ class OpenshiftClient:
Start a VM by name.
Returns True if the VM was started successfully, else False.
"""
url = f"{api_url}/apis/kubevirt.io/v1/namespaces/{namespace}/virtualmachines/{vm_name}"
headers = {
"Authorization": f"Bearer {self.get_token()}",
"Content-Type": "application/merge-patch+json",
"Accept": "application/json",
}
body: dict[str, typing.Any] = {"spec": {"runStrategy": "Always"}}
response = requests.patch(url, headers=headers, json=body, verify=False)
if response.status_code in [200, 201]:
logging.info(f"VM {vm_name} started.")
return True
else:
logging.info(f"Error starting VM {vm_name}: {response.status_code} - {response.text}")
# Get VM info
path = f"/apis/kubevirt.io/v1/namespaces/{namespace}/virtualmachines/{vm_name}"
try:
vm_obj = self.do_request('GET', path)
except Exception as e:
logging.error(f"Could not get VM: {e}")
return False
# Update runStrategy to Always
vm_obj['spec']['runStrategy'] = 'Always'
try:
self.do_request('PUT', path, data=vm_obj)
logging.info(f"VM {vm_name} will be started.")
return True
except Exception as e:
logging.error(f"Error starting VM {vm_name}: {e}")
return False
def stop_vm(self, api_url: str, namespace: str, vm_name: str) -> bool:
"""
Stop a VM by name.
Returns True if the VM was stopped successfully, else False.
"""
url = f"{api_url}/apis/kubevirt.io/v1/namespaces/{namespace}/virtualmachines/{vm_name}"
headers = {
"Authorization": f"Bearer {self.get_token()}",
"Content-Type": "application/merge-patch+json",
"Accept": "application/json",
}
body: dict[str, typing.Any] = {"spec": {"runStrategy": "Halted"}}
response = requests.patch(url, headers=headers, json=body, verify=False)
if response.status_code in [200, 201]:
# Get VM info
path = f"/apis/kubevirt.io/v1/namespaces/{namespace}/virtualmachines/{vm_name}"
try:
vm_obj = self.do_request('GET', path)
except Exception as e:
logging.error(f"Could not get VM: {e}")
return False
# Update runStrategy to Halted
vm_obj['spec']['runStrategy'] = 'Halted'
try:
self.do_request('PUT', path, data=vm_obj)
logging.info(f"VM {vm_name} will be stopped.")
return True
else:
logging.info(f"Error stopping VM {vm_name}: {response.status_code} - {response.text}")
return False
except Exception as e:
logging.error(f"Error stopping VM {vm_name}: {e}")
return False
def copy_vm_same_size(
self, api_url: str, namespace: str, source_vm_name: str, new_vm_name: str, storage_class: str
@@ -687,22 +664,3 @@ class OpenshiftClient:
except Exception as e:
logging.error(f"Error deleting VM: {e}")
return False
def clone_vm_instance(self, source_vm_name: str, new_vm_name: str, storage_class: str) -> bool:
"""
Clone a VM by name, creating a new VM with the same size.
Returns True if clone succeeded, False otherwise.
"""
try:
self.copy_vm_same_size(self.api_url, self.namespace, source_vm_name, new_vm_name, storage_class)
return True
except Exception as e:
logging.error(f"Error cloning VM: {e}")
return False
@staticmethod
def validate_vm_id(vm_id: str | int) -> None:
try:
int(vm_id)
except ValueError:
raise exceptions.OpenshiftNotFoundError(f'VM {vm_id} not found')

View File

@@ -126,10 +126,17 @@ class OpenshiftProvider(ServiceProvider):
def sanitized_name(self, name: str) -> str:
"""
Sanitizes the VM name to comply with RFC 1123:
- Lowercase
- Alphanumeric, '-', '.'
- Starts/ends with alphanumeric
- Max length 63 chars
- Converts to lowercase
- Replaces any character not in [a-z0-9.-] with '-'
- Collapses multiple '-' into one
- Removes leading/trailing non-alphanumeric characters
- Limits length to 63 characters
"""
name = re.sub(r'^[^a-z0-9]+|[^a-z0-9.-]|-{2,}|[^a-z0-9]+$', '-', name.lower())
return name[:63]
name = name.lower()
# Replace any character not allowed with '-'
name = re.sub(r'[^a-z0-9.-]', '-', name)
# Collapse multiple '-' into one
name = re.sub(r'-{2,}', '-', name)
# Remove leading/trailing non-alphanumeric characters
name = re.sub(r'^[^a-z0-9]+|[^a-z0-9]+$', '', name)
return name[:63]
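The rewritten sanitizer, extracted as a standalone function to show each step (this mirrors the new implementation above; RFC 1123 label rules):

```python
import re

def sanitized_name(name: str) -> str:
    name = name.lower()
    # Replace any character not allowed with '-'
    name = re.sub(r'[^a-z0-9.-]', '-', name)
    # Collapse multiple '-' into one
    name = re.sub(r'-{2,}', '-', name)
    # Remove leading/trailing non-alphanumeric characters
    name = re.sub(r'^[^a-z0-9]+|[^a-z0-9]+$', '', name)
    return name[:63]
```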

View File

@@ -52,7 +52,7 @@ class TestServiceNoCache(services.Service):
Basic testing service without cache and no publication OFC
"""
type_name = _('Testing Service no cache')
type_type = 'TestService1'
type_type = 'TestService2'
type_description = _('Testing (and dummy) service with no cache')
icon_file = 'service.png'
@@ -84,7 +84,7 @@ class TestServiceCache(services.Service):
"""
type_name = _('Testing Service WITH cache')
type_type = 'TestService2'
type_type = 'TestService1'
type_description = _('Testing (and dummy) service with CACHE and PUBLICATION')
icon_file = 'provider.png' # : We reuse provider icon here :-), it's just for testing purposes

File diff suppressed because one or more lines are too long

View File

@@ -66,6 +66,8 @@ gettext("Change owner");
gettext("Assign service");
gettext("Cancel");
gettext("Changelog");
gettext("Fallback: Allow");
gettext("Fallback: Deny");
gettext("Delete assigned service");
gettext("Delete cached service");
gettext("Service pool is locked");

View File

@@ -102,6 +102,6 @@
</svg>
</div>
</uds-root>
<link rel="modulepreload" href="/uds/res/admin/chunk-2F3F2YC2.js?stamp=1759963950" integrity="sha384-VVOra5xy5Xg9fYkBmK9MLhX7vif/MexRAaLIDBsQ4ZlkF31s/U6uWWrj+LAnvX/q"><script src="/uds/res/admin/polyfills.js?stamp=1759963950" type="module" crossorigin="anonymous" integrity="sha384-TVRkn44wOGJBeCKWJBHWLvXubZ+Julj/yA0OoEFa3LgJHVHaPeeATX6NcjuNgsIA"></script><script src="/uds/res/admin/main.js?stamp=1759963950" type="module" crossorigin="anonymous" integrity="sha384-U5YvgqT/kogPnBfS+4CA2tP+ICQhuQPuL6EnLHKwmcFpPQBI4Ndzwqiuv4gm5YCr"></script></body>
<link rel="modulepreload" href="/uds/res/admin/chunk-2F3F2YC2.js?stamp=1762813175" integrity="sha384-VVOra5xy5Xg9fYkBmK9MLhX7vif/MexRAaLIDBsQ4ZlkF31s/U6uWWrj+LAnvX/q"><script src="/uds/res/admin/polyfills.js?stamp=1762813175" type="module" crossorigin="anonymous" integrity="sha384-TVRkn44wOGJBeCKWJBHWLvXubZ+Julj/yA0OoEFa3LgJHVHaPeeATX6NcjuNgsIA"></script><script src="/uds/res/admin/main.js?stamp=1762813175" type="module" crossorigin="anonymous" integrity="sha384-4RLNNtXgmAW+M5Uzni5JSekmZJ6MDtfZR2d7Cw0nQECcN+XAYbAUnTZrGx/sEc8N"></script></body>
</html>

View File

@@ -0,0 +1,54 @@
'use strict';
// Info for linting tools about the variables provided by uds client
var Process, Logger, File, Utils, Tasks;
// We receive data in "data" variable, which is an object from json readonly
var data;
const errorString = `<p>You need to have xfreerdp or Thincast installed and in path for this to work.</p>
<p>Please, install the proper package for your system.</p>
<ul>
<li>xfreerdp: <a href="https://github.com/FreeRDP/FreeRDP">Download</a></li>
<li>Thincast: <a href="https://thincast.com/en/products/client">Download</a></li>
</ul>`;
// Try, in order of preference, to find other RDP clients
const executablePath =
Process.findExecutable('udsrdp') ||
Process.findExecutable('thincast-remote-desktop-client') ||
Process.findExecutable('thincast-client') ||
Process.findExecutable('thincast') ||
Process.findExecutable('xfreerdp3') ||
Process.findExecutable('xfreerdp');
if (!executablePath) {
Logger.error('No RDP client found on system');
throw new Error(errorString);
}
// Expand variables in each entry of data.freerdp_params (an array of strings) using Utils.expandVars
let parameters = data.freerdp_params.map((param) => Utils.expandVars(param));
let process = null;
// If has the as_file property, create the temp file on home folder and use it
if (data.as_file) {
Logger.debug('Has as_file property, creating temp RDP file');
// Create and save the temp file
const rdpFilePath = File.createTempFile(File.getHomeDirectory(), data.as_file, '.rdp');
Logger.debug(`RDP temp file created at ${rdpFilePath}`);
// Append to removable task to delete the file later
Tasks.addEarlyUnlinkableFile(rdpFilePath);
let password = data.password ? `/p:${data.password}` : '/p:';
// Launch the RDP client with the temp file
process = Process.launch(executablePath, [rdpFilePath, password]); // the address inside the file is already set
} else {
// Launch the RDP client with the parameters
process = Process.launch(executablePath, [...parameters, `/v:${data.address}`]);
}
Tasks.addWaitableApp(process);

View File

@@ -0,0 +1,64 @@
'use strict';
// Info for linting tools about the variables provided by uds client
var Process, Logger, File, Utils, Tasks;
// We receive data in "data" variable, which is an object from json readonly
var data;
const errorString = `<p>You need to have xfreerdp or Thincast installed and in path for this to work.</p>
<p>Please, install the proper package for your system.</p>
<ul>
<li>xfreerdp: <a href="https://github.com/FreeRDP/FreeRDP">Download</a></li>
<li>Thincast: <a href="https://thincast.com/en/products/client">Download</a></li>
</ul>`;
// Try, in order of preference, to find other RDP clients
const executablePath =
Process.findExecutable('udsrdp') ||
Process.findExecutable('thincast-remote-desktop-client') ||
Process.findExecutable('thincast-client') ||
Process.findExecutable('thincast') ||
Process.findExecutable('xfreerdp3') ||
Process.findExecutable('xfreerdp');
if (!executablePath) {
Logger.error('No RDP client found on system');
throw new Error(errorString);
}
// Expand variables in each entry of data.freerdp_params (an array of strings) using Utils.expandVars
let parameters = data.freerdp_params.map((param) => Utils.expandVars(param));
let tunnel = null;
try {
tunnel = await Tasks.startTunnel(data.tunHost, data.tunPort, data.ticket, null, data.tunChk);
} catch (error) {
Logger.error(`Failed to start tunnel: ${error.message}`);
throw new Error(`Failed to start tunnel: ${error.message}`);
}
let process = null;
// If has the as_file property, create the temp file on home folder and use it
if (data.as_file) {
Logger.debug('Has as_file property, creating temp RDP file');
// Replace "{address}" with data.address in the as_file content
let content = data.as_file.replace(/\{address\}/g, `127.0.0.1:${tunnel.port}`);
// Create and save the temp file
const rdpFilePath = File.createTempFile(File.getHomeDirectory(), content, '.rdp');
Logger.debug(`RDP temp file created at ${rdpFilePath}`);
// Append to removable task to delete the file later
Tasks.addEarlyUnlinkableFile(rdpFilePath);
let password = data.password ? `/p:${data.password}` : '';
// Launch the RDP client with the temp file; the address inside the file is already set
process = Process.launch(executablePath, [rdpFilePath, password]);
} else {
// Launch the RDP client with the parameters
process = Process.launch(executablePath, [...parameters, `/v:127.0.0.1:${tunnel.port}`]);
}
Tasks.addWaitableApp(process);

View File

@@ -124,7 +124,7 @@ class ServerEventsLoginLogoutTest(rest.test.RESTTestCase):
# logoutData = {
# 'token': 'server token', # Must be present on all events
# 'type': 'login', # MUST BE PRESENT
# 'user_service': 'uuid', # MUST BE PRESENT
# 'userservice_uuid': 'uuid', # MUST BE PRESENT
# 'username': 'username', # Optional
# }
response = self.client.rest_post(
@@ -132,7 +132,7 @@ class ServerEventsLoginLogoutTest(rest.test.RESTTestCase):
data={
'token': self.server.token,
'type': 'logout',
'user_service': self.user_service_managed.uuid,
'userservice_uuid': self.user_service_managed.uuid,
'username': 'local_user_name',
'session_id': '',
},
@@ -146,7 +146,7 @@ class ServerEventsLoginLogoutTest(rest.test.RESTTestCase):
data={
'token': self.server.token,
'type': 'login',
'user_service': 'invalid uuid',
'userservice_uuid': 'invalid uuid',
'username': 'local_user_name',
},
)

View File

@@ -0,0 +1,348 @@
# -*- coding: utf-8 -*-
"""
Test fixtures for OpenShift service tests.
Provides reusable functions and mock objects for unit testing OpenShift provider, service, deployment, publication, and user service logic.
All functions are designed to be used across multiple test modules for consistency and maintainability.
"""
#
# Copyright (c) 2024 Virtual Cable S.L.U.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of Virtual Cable S.L.U. nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
Author: Adolfo Gómez, dkmaster at dkmon dot com
"""
import contextlib
import copy
import functools
import random
import typing
from unittest import mock
import uuid
from uds.core import environment
from uds.core.ui.user_interface import gui
from uds.models.user import User
from uds.services.OpenShift import service, service_fixed, provider, publication, deployment, deployment_fixed
from uds.services.OpenShift.openshift import types as openshift_types, exceptions as openshift_exceptions
DEF_VMS: list[openshift_types.VM] = [
openshift_types.VM(
name=f'vm-{i}',
namespace='default',
uid=f'uid-{i}',
status=openshift_types.VMStatus.STOPPED if i % 2 == 0 else openshift_types.VMStatus.RUNNING,
volume_template=openshift_types.VolumeTemplate(name=f'volume-{i}', storage='10Gi'),
disks=[openshift_types.DeviceDisk(name=f'disk-{i}', boot_order=1)],
volumes=[openshift_types.Volume(name=f'volume-{i}', data_volume=f'dv-{i}')],
)
for i in range(1, 11)
]
DEF_VM_INSTANCES: list[openshift_types.VMInstance] = [
openshift_types.VMInstance(
name=f'vm-{i}',
namespace='default',
uid=f'uid-instance-{i}',
interfaces=[
openshift_types.Interface(
name='eth0',
mac_address=f'00:11:22:33:44:{i:02x}',
ip_address=f'192.168.1.{i}',
)
],
status=openshift_types.VMStatus.STOPPED if i % 2 == 0 else openshift_types.VMStatus.RUNNING,
phase=openshift_types.VMStatus.STOPPED if i % 2 == 0 else openshift_types.VMStatus.RUNNING,
)
for i in range(1, 11)
]
# clone values to avoid modifying the original ones
VMS: list[openshift_types.VM] = copy.deepcopy(DEF_VMS)
VM_INSTANCES: list[openshift_types.VMInstance] = copy.deepcopy(DEF_VM_INSTANCES)
def clear() -> None:
"""
Reset all VM and VMInstance values to their default state.
Use this before each test to ensure a clean environment.
"""
VMS[:] = copy.deepcopy(DEF_VMS)
VM_INSTANCES[:] = copy.deepcopy(DEF_VM_INSTANCES)
def replace_vm_info(vm_name: str, **kwargs: typing.Any) -> None:
"""
Update attributes of a VM in VMS by name.
Raises OpenshiftNotFoundError if VM is not found.
"""
try:
    vm = next(vm for vm in VMS if vm.name == vm_name)
except StopIteration:
    raise openshift_exceptions.OpenshiftNotFoundError(f'VM {vm_name} not found') from None
for k, v in kwargs.items():
    setattr(vm, k, v)
def replacer_vm_info(**kwargs: typing.Any) -> typing.Callable[..., None]:
"""
Returns a partial function to update VM info with preset kwargs.
Useful for patching or repeated updates in tests.
"""
return functools.partial(replace_vm_info, **kwargs)
T = typing.TypeVar('T')
def returner(value: T, *args: typing.Any, **kwargs: typing.Any) -> typing.Callable[..., T]:
"""
Returns a function that always returns the given value.
Useful for mocking return values in tests.
"""
def inner(*args: typing.Any, **kwargs: typing.Any) -> T:
return value
return inner
# Provider values
PROVIDER_VALUES_DICT: gui.ValuesDictType = {
'cluster_url': 'https://oauth-openshift.apps-crc.testing',
'api_url': 'https://api.crc.testing:6443',
'username': 'kubeadmin',
'password': 'test-password',
'namespace': 'default',
'verify_ssl': False,
'concurrent_creation_limit': 1,
'concurrent_removal_limit': 1,
'timeout': 10,
}
# Service values
SERVICE_VALUES_DICT: gui.ValuesDictType = {
'template': VMS[0].name,
'basename': 'base',
'lenname': 4,
'publication_timeout': 120,
'prov_uuid': '',
}
# Service fixed values
SERVICE_FIXED_VALUES_DICT: gui.ValuesDictType = {
'token': '',
'machines': [VMS[2].name, VMS[3].name, VMS[4].name],
'on_logout': 'no',
'randomize': False,
'maintain_on_error': False,
'prov_uuid': '',
}
def create_client_mock() -> mock.Mock:
"""
Create a MagicMock for OpenshiftClient with default behaviors and side effects.
Used to simulate API responses in provider/service tests.
"""
client = mock.MagicMock()
# Prepare deep copies of default data
client.test.return_value = True
client.list_vms.return_value = copy.deepcopy(DEF_VMS)
client.start_vm_instance.return_value = True
client.stop_vm_instance.return_value = True
client.delete_vm_instance.return_value = True
client.get_datavolume_phase.return_value = "Succeeded"
client.get_vm_pvc_or_dv_name.return_value = ("test-pvc", "pvc")
client.get_pvc_size.return_value = "10Gi"
client.create_vm_from_pvc.return_value = True
client.wait_for_datavolume_clone_progress.return_value = True
def get_vm_info_side_effect(vm_name: str, **kwargs: typing.Any) -> openshift_types.VM | None:
for vm in VMS:
if vm.name == vm_name:
return vm
return None
def get_vm_instance_info_side_effect(vm_name: str, **kwargs: typing.Any) -> openshift_types.VMInstance | None:
for inst in VM_INSTANCES:
if inst.name == vm_name:
return inst
return None
client.get_vm_info.side_effect = get_vm_info_side_effect
client.get_vm_instance_info.side_effect = get_vm_instance_info_side_effect
return client
@contextlib.contextmanager
def patched_provider(**kwargs: typing.Any) -> typing.Generator[provider.OpenshiftProvider, None, None]:
"""
Context manager that yields a provider with a patched OpenshiftClient mock.
Use this to ensure all API calls are intercepted and controlled in tests.
"""
client = create_client_mock()
prov = create_provider(**kwargs)
prov._cached_api = client
yield prov
def create_provider(**kwargs: typing.Any) -> provider.OpenshiftProvider:
"""
Create an OpenshiftProvider instance with default or overridden values.
Used for provider-level tests and as a dependency for other fixtures.
"""
values = PROVIDER_VALUES_DICT.copy()
values.update(kwargs)
uuid_ = str(uuid.uuid4())
return provider.OpenshiftProvider(
environment=environment.Environment.private_environment(uuid_), values=values, uuid=uuid_
)
def create_service(
provider: typing.Optional[provider.OpenshiftProvider] = None, **kwargs: typing.Any
) -> service.OpenshiftService:
"""
Create an OpenshiftService instance (dynamic service).
Used for service-level tests and as a dependency for user services and publications.
"""
uuid_ = str(uuid.uuid4())
values = SERVICE_VALUES_DICT.copy()
values.update(kwargs)
srvc = service.OpenshiftService(
environment=environment.Environment.private_environment(uuid_),
provider=provider or create_provider(),
values=values,
uuid=uuid_,
)
return srvc
def create_service_fixed(
provider: typing.Optional[provider.OpenshiftProvider] = None, **kwargs: typing.Any
) -> service_fixed.OpenshiftServiceFixed:
"""
Create an OpenshiftServiceFixed instance (fixed service).
Used for fixed service tests and as a dependency for fixed user services.
"""
uuid_ = str(uuid.uuid4())
values = SERVICE_FIXED_VALUES_DICT.copy()
values.update(kwargs)
return service_fixed.OpenshiftServiceFixed(
environment=environment.Environment.private_environment(uuid_),
provider=provider or create_provider(),
values=values,
uuid=uuid_,
)
def create_publication(
service: typing.Optional[service.OpenshiftService] = None,
**kwargs: typing.Any,
) -> publication.OpenshiftTemplatePublication:
"""
Create an OpenshiftTemplatePublication instance.
Used for publication-level tests and as a dependency for user services.
"""
uuid_ = str(uuid.uuid4())
pub = publication.OpenshiftTemplatePublication(
environment=environment.Environment.private_environment(uuid_),
service=service or create_service(**kwargs),
revision=1,
servicepool_name='servicepool_name',
uuid=uuid_,
)
pub._name = f"pub-{random.randint(1000, 9999)}"
return pub
def create_userservice(
service: typing.Optional[service.OpenshiftService] = None,
publication: typing.Optional[publication.OpenshiftTemplatePublication] = None,
) -> deployment.OpenshiftUserService:
"""
Create an OpenshiftUserService instance (dynamic user service).
Used for user service tests that require a publication and service.
"""
uuid_ = str(uuid.uuid4())
return deployment.OpenshiftUserService(
environment=environment.Environment.private_environment(uuid_),
service=service or create_service(),
publication=publication or create_publication(),
uuid=uuid_,
)
def create_userservice_fixed(
service: typing.Optional[service_fixed.OpenshiftServiceFixed] = None,
) -> deployment_fixed.OpenshiftUserServiceFixed:
"""
Create an OpenshiftUserServiceFixed instance (fixed user service).
Used for tests of fixed user service logic and lifecycle.
"""
uuid_ = uuid.uuid4().hex
return deployment_fixed.OpenshiftUserServiceFixed(
environment=environment.Environment.private_environment(uuid_),
service=service or create_service_fixed(),
publication=None,
uuid=uuid_,
)
def create_user(
name: str = "testuser",
real_name: str = "Test User",
is_admin: bool = False,
state: str = 'A',
password: str = 'password',
mfa_data: str = '',
staff_member: bool = False,
last_access: typing.Optional[str] = None,
parent: typing.Optional[User] = None,
created: typing.Optional[str] = None,
comments: str = '',
) -> User:
"""
Create a mock User instance for testing.
All fields can be customized for specific test scenarios.
"""
user = mock.Mock(spec=User)
user.name = name
user.real_name = real_name
user.is_admin = is_admin
user.state = state
user.password = password
user.mfa_data = mfa_data
user.staff_member = staff_member
user.last_access = last_access
user.parent = parent
user.created = created
user.comments = comments
return user
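The `create_client_mock` lookup above (a `side_effect` that searches a module-level `VMS` list, which `clear()` resets from `DEF_VMS`) can be reduced to this stdlib-only sketch; `FakeVM` is an illustrative stand-in for `openshift_types.VM`, not the real type:

```python
import copy
from dataclasses import dataclass
from unittest import mock

# Hypothetical stand-in for openshift_types.VM; only the fields the lookup needs.
@dataclass
class FakeVM:
    name: str
    status: str

DEF_VMS = [FakeVM(name=f'vm-{i}', status='STOPPED') for i in range(1, 4)]
VMS = copy.deepcopy(DEF_VMS)

def get_vm_info_side_effect(vm_name, **kwargs):
    # Same shape as the fixture: first match by name, else None.
    for vm in VMS:
        if vm.name == vm_name:
            return vm
    return None

client = mock.MagicMock()
client.get_vm_info.side_effect = get_vm_info_side_effect

assert client.get_vm_info('vm-2').name == 'vm-2'
assert client.get_vm_info('missing') is None

# Mutations of the shared VMS list are visible through the mock, which is
# why the fixtures call clear() (re-copying DEF_VMS) before each test.
VMS[0].status = 'RUNNING'
assert client.get_vm_info('vm-1').status == 'RUNNING'
VMS[:] = copy.deepcopy(DEF_VMS)
assert client.get_vm_info('vm-1').status == 'STOPPED'
```

Using `side_effect` instead of a fixed `return_value` lets one mock answer per-VM queries while still tracking calls like any other `MagicMock`.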

View File

@@ -0,0 +1,189 @@
# -*- coding: utf-8 -*-
#
# Copyright (c) 2024 Virtual Cable S.L.U.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of Virtual Cable S.L.U. nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
Author: Adolfo Gómez, dkmaster at dkmon dot com
"""
import logging
from uds.services.OpenShift.openshift import client as openshift_client
from tests.utils.test import UDSTransactionTestCase
from tests.utils import vars
logger = logging.getLogger(__name__)
class TestOpenshiftClient(UDSTransactionTestCase):
"""Tests for operations with OpenShiftClient."""
os_client: openshift_client.OpenshiftClient
test_vm: str = ''
test_pool: str = ''
test_storage: str = ''
def setUp(self) -> None:
"""
Set up OpenShift client and test variables for each test.
Skips tests if required variables are missing.
"""
v = vars.get_vars('openshift')
if not v:
self.skipTest('No OpenShift test variables found')
self.os_client = openshift_client.OpenshiftClient(
cluster_url=v['cluster_url'],
api_url=v['api_url'],
username=v['username'],
password=v['password'],
namespace=v['namespace'],
timeout=int(v['timeout']),
verify_ssl=v['verify_ssl'] == 'true',
)
self.test_vm = v.get('test_vm', '')
self.test_pool = v.get('test_pool', '')
self.test_storage = v.get('test_storage', '')
# --- Token/API Tests ---
def test_get_token(self) -> None:
"""
Test that get_token returns a valid token string.
"""
token = self.os_client.get_token()
self.assertIsNotNone(token)
def test_get_api_url(self) -> None:
"""
Test that get_api_url constructs a valid URL with path and parameters.
"""
url = self.os_client.get_api_url('/test/path', ('param1', 'value1'))
self.assertIn('/test/path', url)
self.assertIn('param1=value1', url)
def test_get_api_url_invalid(self) -> None:
"""
Test that get_api_url works with an invalid path.
"""
url = self.os_client.get_api_url('/invalid/path', ('param', 'value'))
self.assertIn('/invalid/path', url)
# --- VM Listing/Info Tests ---
def test_list_vms(self) -> None:
"""
Test that list_vms returns a list and get_vm_info works for listed VMs.
"""
vms = self.os_client.list_vms()
self.assertIsInstance(vms, list)
if vms:
info = self.os_client.get_vm_info(vms[0].name)
self.assertIsNotNone(info)
def test_list_vms_and_check_fields(self) -> None:
"""
Test that all VMs returned by list_vms have required fields.
"""
vms = self.os_client.list_vms()
self.assertIsInstance(vms, list)
for vm in vms:
self.assertTrue(hasattr(vm, 'name'))
self.assertTrue(hasattr(vm, 'namespace'))
def test_get_vm_info(self) -> None:
"""
Test that get_vm_info returns info for a valid VM name.
"""
if not self.test_vm:
self.skipTest('No test_vm specified')
info = self.os_client.get_vm_info(self.test_vm)
self.assertIsNotNone(info)
def test_get_vm_info_invalid(self) -> None:
"""
Test that get_vm_info returns None for an invalid VM name.
"""
info = self.os_client.get_vm_info('nonexistent-vm')
self.assertIsNone(info)
def test_get_vm_instance_info(self) -> None:
"""
Test that get_vm_instance_info returns info or None for a valid VM name.
"""
if not self.test_vm:
self.skipTest('No test_vm specified')
info = self.os_client.get_vm_instance_info(self.test_vm)
self.assertTrue(info is None or hasattr(info, 'name'))
def test_get_vm_instance_info_invalid(self) -> None:
"""
Test that get_vm_instance_info returns None for an invalid VM name.
"""
info = self.os_client.get_vm_instance_info('nonexistent-vm')
self.assertIsNone(info)
# --- VM Lifecycle and Actions ---
def test_vm_lifecycle(self) -> None:
"""
Test VM lifecycle actions: start, stop, delete (skipped in shared environments).
"""
self.skipTest('Skip this test to avoid issues in shared environments')
if not self.test_vm:
self.skipTest('No test_vm specified in test-vars.ini')
self.assertTrue(self.os_client.start_vm_instance(self.test_vm))
self.assertTrue(self.os_client.stop_vm_instance(self.test_vm))
self.assertTrue(self.os_client.delete_vm_instance(self.test_vm))
def test_start_stop_suspend_resume_vm(self) -> None:
"""
Test stop (and optionally start) VM instance. Suspend/resume skipped if not supported.
"""
if not self.test_vm:
self.skipTest('No test_vm specified')
#self.assertTrue(self.os_client.start_vm_instance(self.test_vm))
self.assertTrue(self.os_client.stop_vm_instance(self.test_vm))
# Suspend/resume skipped if not supported
def test_delete_vm_invalid(self) -> None:
"""
Test that delete_vm_instance returns False for an invalid VM name.
"""
self.assertFalse(self.os_client.delete_vm_instance('nonexistent-vm'))
# --- DataVolume Tests ---
def test_datavolume_phase(self) -> None:
"""
Test that get_datavolume_phase returns a string for a valid datavolume.
"""
phase = self.os_client.get_datavolume_phase('test-dv')
self.assertIsInstance(phase, str)
def test_datavolume_phase_invalid(self) -> None:
"""
Test that get_datavolume_phase returns a string for an invalid datavolume.
"""
phase = self.os_client.get_datavolume_phase('nonexistent-dv')
self.assertIsInstance(phase, str)

View File

@@ -0,0 +1,163 @@
# -*- coding: utf-8 -*-
#
# Copyright (c) 2024 Virtual Cable S.L.U.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of Virtual Cable S.L.U. nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
Author: Adolfo Gómez, dkmaster at dkmon dot com
"""
from unittest import mock
from tests.services.openshift import fixtures
from uds.core.types.states import TaskState
from tests.utils.test import UDSTransactionTestCase
class TestOpenshiftDeployment(UDSTransactionTestCase):
def _create_userservice(self):
"""
Helper to create a userservice instance with a preset name for deployment operation tests.
Returns:
userservice: A userservice object with name 'test-vm'.
"""
userservice = fixtures.create_userservice()
userservice._name = 'test-vm'
return userservice
def setUp(self) -> None:
super().setUp()
fixtures.clear()
# --- Create operation tests ---
def test_op_create_success(self) -> None:
"""
Test successful VM creation operation.
Should clear the waiting_name flag after creation.
"""
userservice = self._create_userservice()
userservice._waiting_name = False
api = userservice.service().api
with mock.patch.object(api, 'get_vm_pvc_or_dv_name', return_value=('test-pvc', 'pvc')), \
mock.patch.object(api, 'get_pvc_size', return_value='10Gi'), \
mock.patch.object(api, 'create_vm_from_pvc', return_value=True), \
mock.patch.object(api, 'wait_for_datavolume_clone_progress', return_value=True):
userservice.op_create()
self.assertFalse(userservice._waiting_name)
def test_op_create_failure(self) -> None:
"""
Test failed VM creation operation.
Should set the waiting_name flag if creation fails.
"""
userservice = self._create_userservice()
api = userservice.service().api
userservice._waiting_name = False
with mock.patch.object(api, 'get_vm_pvc_or_dv_name', return_value=('test-pvc', 'pvc')), \
mock.patch.object(api, 'get_pvc_size', return_value='10Gi'), \
mock.patch.object(api, 'create_vm_from_pvc', return_value=False):
userservice.op_create()
self.assertTrue(userservice._waiting_name)
def test_op_create_checker_running(self) -> None:
"""
Test create checker returns RUNNING when datavolume phase is pending.
"""
userservice = self._create_userservice()
api = userservice.service().api
with mock.patch.object(api, 'get_datavolume_phase', return_value='Pending'):
state = userservice.op_create_checker()
self.assertEqual(state, TaskState.RUNNING)
def test_op_create_checker_finished(self) -> None:
"""
Test create checker returns FINISHED when datavolume phase is succeeded and VM info is available.
"""
userservice = self._create_userservice()
api = userservice.service().api
with mock.patch.object(api, 'get_datavolume_phase', return_value='Succeeded'), \
mock.patch.object(api, 'get_vm_info', return_value=fixtures.VMS[0]), \
mock.patch.object(api, 'get_vm_instance_info', return_value=fixtures.VM_INSTANCES[0]):
state = userservice.op_create_checker()
self.assertEqual(state, TaskState.FINISHED)
# --- Delete operation tests ---
def test_op_delete_checker_finished(self) -> None:
"""
Test delete checker returns FINISHED when VM info is None (deleted).
"""
userservice = self._create_userservice()
api = userservice.service().api
with mock.patch.object(api, 'get_vm_info', return_value=None):
state = userservice.op_delete_checker()
self.assertEqual(state, TaskState.FINISHED)
def test_op_delete_checker_running(self) -> None:
"""
Test delete checker returns RUNNING when VM info still exists.
"""
userservice = self._create_userservice()
api = userservice.service().api
with mock.patch.object(api, 'get_vm_info', return_value=fixtures.VMS[0]):
state = userservice.op_delete_checker()
self.assertEqual(state, TaskState.RUNNING)
def test_op_delete_completed_checker(self) -> None:
"""
Test delete completed checker always returns FINISHED.
"""
userservice = self._create_userservice()
state = userservice.op_delete_completed_checker()
self.assertEqual(state, TaskState.FINISHED)
# --- Cancel operation tests ---
def test_op_cancel_checker_finished(self) -> None:
"""
Test cancel checker returns FINISHED when VM info is None (cancelled).
"""
userservice = self._create_userservice()
api = userservice.service().api
with mock.patch.object(api, 'get_vm_info', return_value=None):
state = userservice.op_cancel_checker()
self.assertEqual(state, TaskState.FINISHED)
def test_op_cancel_checker_running(self) -> None:
"""
Test cancel checker returns RUNNING when VM info still exists.
"""
userservice = self._create_userservice()
api = userservice.service().api
with mock.patch.object(api, 'get_vm_info', return_value=fixtures.VMS[0]):
state = userservice.op_cancel_checker()
self.assertEqual(state, TaskState.RUNNING)
def test_op_cancel_completed_checker(self) -> None:
"""
Test cancel completed checker always returns FINISHED.
"""
userservice = self._create_userservice()
state = userservice.op_cancel_completed_checker()
self.assertEqual(state, TaskState.FINISHED)
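The stacked `mock.patch.object` idiom these tests lean on can be shown with a stdlib-only sketch; the `Api` class here is illustrative, not the real OpenShift client:

```python
from unittest import mock

class Api:
    # Illustrative stand-in for the client the service would call.
    def get_datavolume_phase(self, name):
        raise RuntimeError('would hit the network')
    def get_vm_info(self, name):
        raise RuntimeError('would hit the network')

api = Api()
with mock.patch.object(api, 'get_datavolume_phase', return_value='Succeeded'), \
     mock.patch.object(api, 'get_vm_info', return_value={'name': 'test-vm'}):
    # Inside the block the patched methods answer without touching the network.
    assert api.get_datavolume_phase('dv') == 'Succeeded'
    assert api.get_vm_info('test-vm') == {'name': 'test-vm'}

# On exit, patch.object restores the original (network-bound) methods.
restored = False
try:
    api.get_vm_info('test-vm')
except RuntimeError:
    restored = True
assert restored
```

Because each patch is scoped to the `with` block, the checkers under test see controlled API answers while the real object is left untouched for the next test.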

View File

@@ -0,0 +1,163 @@
# -*- coding: utf-8 -*-
#
# Copyright (c) 2024 Virtual Cable S.L.U.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of Virtual Cable S.L.U. nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
Author: Adolfo Gómez, dkmaster at dkmon dot com
"""
from unittest import mock
from uds.core.types.states import TaskState
from tests.services.openshift import fixtures
from tests.utils.test import UDSTransactionTestCase
class TestOpenshiftUserServiceFixed(UDSTransactionTestCase):
def _create_userservice_fixed(self):
"""
Helper to create a fixed userservice instance for deployment_fixed operation tests.
Returns:
userservice: A fixed userservice object with name 'fixed-vm'.
"""
userservice = fixtures.create_userservice_fixed()
userservice._name = 'fixed-vm'
return userservice
def setUp(self) -> None:
super().setUp()
fixtures.clear()
# --- Start operation tests ---
def test_op_start_vm_running(self) -> None:
"""
Test that op_start does not start VM if it is already running.
"""
userservice = self._create_userservice_fixed()
api = userservice.service().provider().api
vm_mock = mock.Mock()
vm_mock.status.is_off.return_value = False
with mock.patch.object(api, 'get_vm_info', return_value=vm_mock):
with mock.patch.object(api, 'start_vm_instance') as start_mock:
userservice.op_start()
start_mock.assert_not_called()
def test_op_start_vm_off(self) -> None:
"""
Test that op_start starts VM if it is off.
"""
userservice = self._create_userservice_fixed()
api = userservice.service().provider().api
vm_mock = mock.Mock()
vm_mock.status.is_off.return_value = True
with mock.patch.object(api, 'get_vm_info', return_value=vm_mock):
with mock.patch.object(api, 'start_vm_instance') as start_mock:
userservice.op_start()
start_mock.assert_called_once_with('fixed-vm')
# --- Stop operation tests ---
def test_op_stop_vm_off(self) -> None:
"""
Test that op_stop does not stop VM if it is already off.
"""
userservice = self._create_userservice_fixed()
api = userservice.service().provider().api
vm_mock = mock.Mock()
vm_mock.status.is_off.return_value = True
with mock.patch.object(api, 'get_vm_info', return_value=vm_mock):
with mock.patch.object(api, 'stop_vm_instance') as stop_mock:
userservice.op_stop()
stop_mock.assert_not_called()
def test_op_stop_vm_running(self) -> None:
"""
Test that op_stop stops VM if it is running.
"""
userservice = self._create_userservice_fixed()
api = userservice.service().provider().api
vm_mock = mock.Mock()
vm_mock.status.is_off.return_value = False
with mock.patch.object(api, 'get_vm_info', return_value=vm_mock):
with mock.patch.object(api, 'stop_vm_instance') as stop_mock:
userservice.op_stop()
stop_mock.assert_called_once_with('fixed-vm')
# --- Start checker tests ---
def test_op_start_checker_running(self) -> None:
"""
Test that op_start_checker returns RUNNING if VM status is not error.
"""
userservice = self._create_userservice_fixed()
api = userservice.service().provider().api
status_mock = mock.Mock()
status_mock.is_error.return_value = False
vm_mock = mock.Mock()
vm_mock.status = status_mock
with mock.patch.object(api, 'get_vm_info', return_value=vm_mock):
state = userservice.op_start_checker()
self.assertEqual(state, TaskState.RUNNING)
def test_op_start_checker_finished(self) -> None:
"""
Test that op_start_checker returns FINISHED if VM status is RUNNING.
"""
userservice = self._create_userservice_fixed()
api = userservice.service().provider().api
vm_mock = mock.Mock()
from uds.services.OpenShift.openshift import types as opensh_types
vm_mock.status = opensh_types.VMStatus.RUNNING
with mock.patch.object(api, 'get_vm_info', return_value=vm_mock):
state = userservice.op_start_checker()
self.assertEqual(state, TaskState.FINISHED)
# --- Stop checker tests ---
def test_op_stop_checker_running(self) -> None:
"""
Test that op_stop_checker returns RUNNING if VM status is not error.
"""
userservice = self._create_userservice_fixed()
api = userservice.service().provider().api
vm_mock = mock.Mock()
status_mock = mock.Mock()
status_mock.is_error.return_value = False
vm_mock.status = status_mock
with mock.patch.object(api, 'get_vm_info', return_value=vm_mock):
state = userservice.op_stop_checker()
self.assertEqual(state, TaskState.RUNNING)
def test_op_stop_checker_finished(self) -> None:
"""
Test that op_stop_checker returns FINISHED if VM status is STOPPED.
"""
userservice = self._create_userservice_fixed()
api = userservice.service().provider().api
vm_mock = mock.Mock()
from uds.services.OpenShift.openshift import types as opensh_types
vm_mock.status = opensh_types.VMStatus.STOPPED
with mock.patch.object(api, 'get_vm_info', return_value=vm_mock):
state = userservice.op_stop_checker()
self.assertEqual(state, TaskState.FINISHED)
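The `Mock` configuration used above relies on auto-created child attributes (`vm_mock.status.is_off.return_value = ...`) plus call assertions; a stdlib-only sketch of that pattern, with an illustrative guard mirroring `op_start`:

```python
from unittest import mock

vm = mock.Mock()
# Child attributes are created on first access, so nested configuration works:
vm.status.is_off.return_value = True

start_vm_instance = mock.Mock()
# Illustrative guard, like op_start: only start the VM when it reports being off.
if vm.status.is_off():
    start_vm_instance('fixed-vm')

# Both the configured child and the action mock record their calls.
start_vm_instance.assert_called_once_with('fixed-vm')
vm.status.is_off.assert_called_once_with()
```

Flipping `return_value` to `False` and asserting `start_vm_instance.assert_not_called()` gives the complementary case, exactly as `test_op_start_vm_running` does.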

View File

@@ -0,0 +1,143 @@
# -*- coding: utf-8 -*-
#
# Copyright (c) 2024 Virtual Cable S.L.U.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of Virtual Cable S.L.U. nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
Author: Adolfo Gómez, dkmaster at dkmon dot com
"""
import typing
from unittest import mock
from uds.core import types, ui, environment
from uds.services.OpenShift.provider import OpenshiftProvider
from . import fixtures
from tests.utils.test import UDSTransactionTestCase
class TestOpenshiftProvider(UDSTransactionTestCase):
def setUp(self) -> None:
"""
Set up test environment and clear fixtures before each test.
"""
super().setUp()
fixtures.clear()
# --- Provider Data Tests ---
def test_provider_data(self) -> None:
"""
Test provider data fields and types for correct initialization.
"""
provider = fixtures.create_provider()
self.assertEqual(provider.cluster_url.value, fixtures.PROVIDER_VALUES_DICT['cluster_url'])
self.assertEqual(provider.api_url.value, fixtures.PROVIDER_VALUES_DICT['api_url'])
self.assertEqual(provider.username.value, fixtures.PROVIDER_VALUES_DICT['username'])
self.assertEqual(provider.password.value, fixtures.PROVIDER_VALUES_DICT['password'])
self.assertEqual(provider.namespace.value, fixtures.PROVIDER_VALUES_DICT['namespace'])
self.assertEqual(provider.verify_ssl.value, fixtures.PROVIDER_VALUES_DICT['verify_ssl'])
if not isinstance(provider.concurrent_creation_limit, ui.gui.NumericField):
self.fail('concurrent_creation_limit is not a NumericField')
self.assertEqual(provider.concurrent_creation_limit.as_int(), fixtures.PROVIDER_VALUES_DICT['concurrent_creation_limit'])
if not isinstance(provider.concurrent_removal_limit, ui.gui.NumericField):
self.fail('concurrent_removal_limit is not a NumericField')
self.assertEqual(provider.concurrent_removal_limit.as_int(), fixtures.PROVIDER_VALUES_DICT['concurrent_removal_limit'])
self.assertEqual(provider.timeout.as_int(), fixtures.PROVIDER_VALUES_DICT['timeout'])
# --- Provider Test Method ---
def test_provider_test(self) -> None:
"""
Test the static provider test method and test_connection logic.
"""
with fixtures.patched_provider() as provider:
api = typing.cast(mock.MagicMock, provider.api)
for ret_val in [True, False]:
api.test.reset_mock()
api.test.return_value = ret_val
# Patch test_connection to return ret_val for static test
with mock.patch('uds.services.OpenShift.provider.OpenshiftProvider.test_connection', return_value=ret_val):
result = OpenshiftProvider.test(environment.Environment.temporary_environment(), fixtures.PROVIDER_VALUES_DICT)
self.assertIsInstance(result, types.core.TestResult)
self.assertEqual(result.success, ret_val)
self.assertIsInstance(result.error, str)
# Ensure test_connection calls api.test
provider.test_connection()
api.test.assert_called_once_with()
# --- Provider Availability ---
def test_provider_is_available(self) -> None:
"""
Test the provider is_available method and cache behavior.
"""
with fixtures.patched_provider() as provider:
api = typing.cast(mock.MagicMock, provider.api)
# First, true result
self.assertEqual(provider.is_available(), True)
api.test.assert_called_once_with()
api.test.reset_mock()
# Now, even if set test to false, should return true due to cache
api.test.return_value = False
self.assertEqual(provider.is_available(), True)
api.test.assert_not_called()
# clear cache of method
provider.is_available.cache_clear() # type: ignore # cache_clear() is added by decorator
self.assertEqual(provider.is_available(), False)
api.test.assert_called_once_with()
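The caching exercised here (the first `api.test()` result is reused until `cache_clear()` is called) can be sketched with `functools.lru_cache`; the real provider uses UDS's own caching decorator, so the names below are illustrative stand-ins only:

```python
import functools


class FakeApi:
    """Stand-in for the provider API mock used in the test above."""

    def __init__(self) -> None:
        self.calls = 0
        self.result = True

    def test(self) -> bool:
        self.calls += 1
        return self.result


@functools.lru_cache(maxsize=8)
def is_available(api: FakeApi) -> bool:
    # First call per api hits the backend; later calls come from the cache
    return api.test()


api = FakeApi()
assert is_available(api) is True and api.calls == 1
api.result = False
assert is_available(api) is True and api.calls == 1  # cached, backend not asked
is_available.cache_clear()  # same step as in the test above
assert is_available(api) is False and api.calls == 2
```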
# --- Provider API Methods ---
def test_provider_api_methods(self) -> None:
"""
Test provider API methods for VM operations and info retrieval.
"""
with fixtures.patched_provider() as provider:
api = typing.cast(mock.MagicMock, provider.api)
self.assertEqual(provider.test_connection(), True)
api.test.assert_called_once_with()
self.assertEqual(provider.api.list_vms(), fixtures.VMS)
self.assertEqual(provider.api.get_vm_info('vm-1'), fixtures.VMS[0])
self.assertEqual(provider.api.get_vm_instance_info('vm-1'), fixtures.VM_INSTANCES[0])
self.assertTrue(provider.api.start_vm_instance('vm-1'))
self.assertTrue(provider.api.stop_vm_instance('vm-1'))
self.assertTrue(provider.api.delete_vm_instance('vm-1'))
# --- Name Sanitization ---
def test_sanitized_name(self) -> None:
"""
Test name sanitization utility for various input cases.
"""
provider = fixtures.create_provider()
test_cases = [
('Test-VM-1', 'test-vm-1'),
('Test_VM@2', 'test-vm-2'),
('My Test VM!!!', 'my-test-vm'),
('Test !!! this is', 'test-this-is'),
('UDS-Pub-Hello World!!--2025065122-v1', 'uds-pub-hello-world-2025065122-v1'),
('a' * 100, 'a' * 63), # Test truncation
]
for input_name, expected in test_cases:
self.assertEqual(provider.sanitized_name(input_name), expected)
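The expected outputs in these cases follow Kubernetes' DNS-1123 label rules. A minimal sanitizer consistent with them (a sketch of the apparent rules, not the provider's actual implementation) is:

```python
import re


def sanitized_name(name: str, max_len: int = 63) -> str:
    # Lowercase, then collapse any run of non [a-z0-9] characters into one dash
    value = re.sub(r'[^a-z0-9]+', '-', name.lower())
    # Strip edge dashes and enforce the DNS-1123 label length limit
    return value.strip('-')[:max_len].rstrip('-')


assert sanitized_name('Test_VM@2') == 'test-vm-2'
assert sanitized_name('My Test VM!!!') == 'my-test-vm'
assert sanitized_name('a' * 100) == 'a' * 63
```

The 63-character truncation matches the DNS-1123 label length limit checked by the last test case.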


@@ -0,0 +1,171 @@
# -*- coding: utf-8 -*-
#
# Copyright (c) 2024 Virtual Cable S.L.U.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of Virtual Cable S.L.U. nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
Author: Adolfo Gómez, dkmaster at dkmon dot com
"""
import typing
from unittest import mock
from uds.core import types
from tests.services.openshift import fixtures
from tests.utils.test import UDSTransactionTestCase
class TestOpenshiftPublication(UDSTransactionTestCase):
def setUp(self) -> None:
super().setUp()
fixtures.clear()
def test_op_create_and_checker(self) -> None:
"""
Test op_create and op_create_checker flow
"""
with fixtures.patched_provider() as provider:
api = typing.cast(mock.MagicMock, provider.api)
service = fixtures.create_service(provider=provider)
publication = fixtures.create_publication(service=service)
api.get_vm_pvc_or_dv_name.return_value = ('test-pvc', 'pvc')
api.get_pvc_size.return_value = '10Gi'
api.create_vm_from_pvc.return_value = True
api.wait_for_datavolume_clone_progress.return_value = True
api.get_vm_info.return_value = None
publication.op_create()
api.get_vm_info.return_value = None
state = publication.op_create_checker()
self.assertEqual(state, types.states.TaskState.RUNNING)
def get_vm_info_side_effect(name: str) -> mock.Mock | None:
return mock.Mock(status=mock.Mock()) if name == publication._name else None
api.get_vm_info.side_effect = get_vm_info_side_effect
state = publication.op_create_checker()
self.assertEqual(state, types.states.TaskState.FINISHED)
def test_op_create_completed_and_checker(self) -> None:
"""
Test op_create_completed and op_create_completed_checker flow
"""
with fixtures.patched_provider() as provider:
api = typing.cast(mock.MagicMock, provider.api)
service = fixtures.create_service(provider=provider)
publication = fixtures.create_publication(service=service)
# VM running
running_status = mock.Mock()
running_status.is_running.return_value = True
running_vm = mock.Mock(status=running_status)
def get_vm_info_side_effect(name: str, **kwargs: typing.Any) -> mock.Mock | None:
return running_vm if name == 'test-vm' else None
api.get_vm_info.side_effect = get_vm_info_side_effect
publication._name = 'test-vm'
publication.op_create_completed()
api.stop_vm_instance.assert_called_with('test-vm')
# VM stopped
stopped_status = mock.Mock()
stopped_status.is_running.return_value = False
stopped_vm = mock.Mock(status=stopped_status)
api.get_vm_info.side_effect = None
api.get_vm_info.return_value = stopped_vm
api.stop_vm_instance.reset_mock()
publication.op_create_completed()
api.stop_vm_instance.assert_not_called()
# Checker: VM not found
api.get_vm_info.return_value = None
state = publication.op_create_completed_checker()
self.assertEqual(state, types.states.TaskState.FINISHED)
# Checker: VM stopped
api.get_vm_info.return_value = stopped_vm
state = publication.op_create_completed_checker()
self.assertEqual(state, types.states.TaskState.FINISHED)
# Checker: VM running
api.get_vm_info.return_value = running_vm
state = publication.op_create_completed_checker()
self.assertEqual(state, types.states.TaskState.RUNNING)
def test_publication_create(self) -> None:
"""
Test publication creation (publish)
"""
with fixtures.patched_provider() as provider:
api = typing.cast(mock.MagicMock, provider.api)
service = fixtures.create_service(provider=provider)
publication = fixtures.create_publication(service=service)
api.get_vm_pvc_or_dv_name.return_value = ('test-pvc', 'pvc')
api.get_pvc_size.return_value = '10Gi'
api.create_vm_from_pvc.return_value = True
api.wait_for_datavolume_clone_progress.return_value = True
call_count = {"count": 0}
def vm_info_side_effect(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
if call_count["count"] < 2:
call_count["count"] += 1
return fixtures.VMS[0]
ready_vm = mock.Mock()
ready_vm.status = mock.Mock()
ready_vm.name = publication._name
return ready_vm
api.get_vm_info.side_effect = vm_info_side_effect
state = publication.publish()
self.assertEqual(state, types.states.State.RUNNING)
state = publication.check_state()
api.get_vm_pvc_or_dv_name.assert_called()
api.get_pvc_size.assert_called()
api.create_vm_from_pvc.assert_called()
for _ in range(10):
state = publication.check_state()
if state == types.states.TaskState.FINISHED:
break
self.assertEqual(state, types.states.TaskState.RUNNING)
self.assertEqual(publication.get_template_id(), publication._name)
def test_get_template_id(self) -> None:
"""
Test template ID retrieval (get_template_id)
"""
service = fixtures.create_service()
publication = fixtures.create_publication(service=service)
publication._name = 'test-template'
template_id = publication.get_template_id()
self.assertEqual(template_id, 'test-template')
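The `for _ in range(10)` loop in `test_publication_create` is the usual polling pattern for UDS task checkers; a standalone sketch of it, with `TaskState` as a stand-in for `types.states.TaskState`:

```python
import enum
from typing import Callable


class TaskState(enum.Enum):
    RUNNING = 'running'
    FINISHED = 'finished'


def wait_for(check: Callable[[], TaskState], max_checks: int = 10) -> TaskState:
    # Poll check() until it reports FINISHED or the attempt budget runs out
    state = TaskState.RUNNING
    for _ in range(max_checks):
        state = check()
        if state is TaskState.FINISHED:
            break
    return state


states = iter([TaskState.RUNNING, TaskState.RUNNING, TaskState.FINISHED])
assert wait_for(lambda: next(states)) is TaskState.FINISHED
```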


@@ -0,0 +1,64 @@
import pickle
from tests.services.openshift import fixtures
from tests.utils.test import UDSTransactionTestCase
class TestOpenshiftDeploymentSerialization(UDSTransactionTestCase):
def setUp(self) -> None:
"""
Set up test environment and clear fixtures before each test.
"""
super().setUp()
fixtures.clear()
def _make_userservice(self):
"""
Helper to create a userservice with all fields set for serialization tests.
"""
userservice = fixtures.create_userservice()
userservice._name = 'test-vm'
userservice._ip = '192.168.1.100'
userservice._mac = '00:11:22:33:44:55'
userservice._vmid = 'test-vm-id'
userservice._reason = 'test-reason'
userservice._waiting_name = True
return userservice
# --- Serialization Tests ---
def test_userservice_serialization(self) -> None:
"""
Test that userservice object is correctly serialized and deserialized with all fields preserved.
"""
userservice = self._make_userservice()
data = pickle.dumps(userservice)
userservice2 = pickle.loads(data)
self.assertEqual(userservice2._name, 'test-vm')
self.assertEqual(userservice2._ip, '192.168.1.100')
self.assertEqual(userservice2._mac, '00:11:22:33:44:55')
self.assertEqual(userservice2._vmid, 'test-vm-id')
self.assertEqual(userservice2._reason, 'test-reason')
self.assertTrue(userservice2._waiting_name)
def test_userservice_methods_after_serialization(self) -> None:
"""
Test that userservice methods return correct values after serialization and deserialization.
"""
userservice = self._make_userservice()
data = pickle.dumps(userservice)
userservice2 = pickle.loads(data)
self.assertEqual(userservice2.get_name(), 'test-vm')
self.assertEqual(userservice2.get_ip(), '192.168.1.100')
self.assertEqual(userservice2._mac, '00:11:22:33:44:55')
# --- Field Presence Tests ---
def test_autoserializable_fields(self) -> None:
"""
Test that all expected autoserializable fields are present in userservice object.
"""
userservice = self._make_userservice()
expected = ['_name', '_ip', '_mac', '_vmid', '_reason', '_waiting_name']
for field in expected:
self.assertTrue(hasattr(userservice, field), f"Missing field: {field}")
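The dumps/loads pairs repeated in these tests can be factored into a small helper; a sketch (`pickle_roundtrip` is hypothetical, not part of the suite):

```python
import pickle
from typing import TypeVar

T = TypeVar('T')


def pickle_roundtrip(obj: T) -> T:
    # Serialize to bytes and immediately deserialize, as the tests do inline
    return pickle.loads(pickle.dumps(obj))


fields = {'_name': 'test-vm', '_ip': '192.168.1.100', '_waiting_name': True}
assert pickle_roundtrip(fields) == fields
```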


@@ -0,0 +1,24 @@
import pickle
from tests.services.openshift import fixtures
from tests.utils.test import UDSTransactionTestCase
class TestOpenshiftUserServiceFixed(UDSTransactionTestCase):
def setUp(self) -> None:
"""
Set up test environment and clear fixtures before each test.
"""
super().setUp()
fixtures.clear()
# --- Serialization Tests ---
def test_userservice_fixed_serialization(self) -> None:
"""
Test that userservice_fixed object is correctly serialized and deserialized with all fields preserved.
"""
userservice = fixtures.create_userservice_fixed()
userservice._name = 'fixed-vm'
userservice._reason = 'fixed-reason'
data = pickle.dumps(userservice)
userservice2 = pickle.loads(data)
self.assertEqual(userservice2._name, 'fixed-vm')
self.assertEqual(userservice2._reason, 'fixed-reason')


@@ -0,0 +1,83 @@
# -*- coding: utf-8 -*-
#
# Copyright (c) 2024 Virtual Cable S.L.U.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of Virtual Cable S.L.U. nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
Author: Adolfo Gómez, dkmaster at dkmon dot com
"""
from tests.services.openshift import fixtures
from tests.utils.test import UDSTransactionTestCase
from uds.services.OpenShift.provider import OpenshiftProvider
PROVIDER_SERIALIZE_DATA = (
'{'
'"cluster_url": "https://oauth-openshift.apps-crc.testing", '
'"api_url": "https://api.crc.testing:6443", '
'"username": "kubeadmin", '
'"password": "test-password", '
'"namespace": "default", '
'"verify_ssl": false, '
'"concurrent_creation_limit": 1, '
'"concurrent_removal_limit": 1, '
'"timeout": 10'
'}'
)
class TestOpenshiftProviderSerialization(UDSTransactionTestCase):
# --- Serialization Tests ---
def test_provider_methods_after_serialization(self) -> None:
"""
Test that provider methods return correct values after serialization and deserialization.
"""
from uds.core import environment
provider = fixtures.create_provider()
data = provider.serialize()
provider2 = OpenshiftProvider(environment=environment.Environment.testing_environment())
provider2.deserialize(data)
self.assertEqual(str(provider2.type_name), 'Openshift Provider')
self.assertEqual(str(provider2.type_description), 'Openshift based VMs provider')
self.assertEqual(provider2.cluster_url.value, fixtures.PROVIDER_VALUES_DICT['cluster_url'])
self.assertEqual(provider2.api_url.value, fixtures.PROVIDER_VALUES_DICT['api_url'])
def test_provider_serialization(self) -> None:
"""
Test that all provider fields are correctly serialized and deserialized.
"""
from uds.core import environment
provider = fixtures.create_provider()
data = provider.serialize()
provider2 = OpenshiftProvider(environment=environment.Environment.testing_environment())
provider2.deserialize(data)
for field in fixtures.PROVIDER_VALUES_DICT:
self.assertEqual(getattr(provider2, field).value, fixtures.PROVIDER_VALUES_DICT[field])
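PROVIDER_SERIALIZE_DATA above is plain JSON, so individual fields can be inspected without touching the provider; a quick check under the assumption that the serialized form stays JSON-compatible:

```python
import json

# Shortened copy of the fixture's serialized form (illustrative subset)
serialized = (
    '{"cluster_url": "https://oauth-openshift.apps-crc.testing", '
    '"username": "kubeadmin", "verify_ssl": false, "timeout": 10}'
)

fields = json.loads(serialized)
assert fields['verify_ssl'] is False  # JSON false maps to Python False
assert fields['timeout'] == 10
```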


@@ -0,0 +1,105 @@
# -*- coding: utf-8 -*-
#
# Copyright (c) 2024 Virtual Cable S.L.U.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of Virtual Cable S.L.U. nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
Author: Adolfo Gómez, dkmaster at dkmon dot com
"""
import pickle
from tests.services.openshift import fixtures
from tests.utils.test import UDSTransactionTestCase
class TestOpenshiftPublicationSerialization(UDSTransactionTestCase):
EXPECTED_FIELDS = {'_name', '_waiting_name', '_reason', '_queue', '_vmid', '_is_flagged_for_destroy'}
def setUp(self) -> None:
"""
Set up test environment and clear fixtures before each test.
"""
super().setUp()
fixtures.clear()
def _make_publication(self):
"""
Helper to create a publication with all fields set for serialization tests.
"""
publication = fixtures.create_publication()
publication._name = 'test-template'
publication._reason = 'test-reason'
publication._waiting_name = True
return publication
# --- Field Check Helper ---
def check_fields(self, instance: 'fixtures.publication.OpenshiftTemplatePublication') -> None:
"""
Helper to check expected field values in a publication instance.
"""
self.assertEqual(instance._name, 'test-template')
self.assertEqual(instance._reason, 'test-reason')
self.assertTrue(instance._waiting_name)
# --- Serialization Tests ---
def test_autoserialization_fields(self) -> None:
"""
Test that autoserializable fields match the expected set.
"""
publication = fixtures.create_publication()
fields = set(f[0] for f in publication._autoserializable_fields())
self.assertSetEqual(fields, self.EXPECTED_FIELDS)
def test_pickle_serialization(self) -> None:
"""
Test that publication object is correctly serialized and deserialized using pickle.
"""
publication = self._make_publication()
data = pickle.dumps(publication)
publication2 = pickle.loads(data)
self.check_fields(publication2)
def test_marshal_unmarshal(self) -> None:
"""
Test that publication object is correctly marshaled and unmarshaled.
"""
publication = self._make_publication()
marshaled = publication.marshal()
publication2 = fixtures.create_publication()
publication2.unmarshal(marshaled)
self.check_fields(publication2)
def test_methods_after_serialization(self) -> None:
"""
Test that publication methods return correct values after serialization and deserialization.
"""
publication = self._make_publication()
data = pickle.dumps(publication)
publication2 = pickle.loads(data)
self.assertEqual(publication2._name, 'test-template')
self.assertEqual(publication2.get_template_id(), 'test-template')


@@ -0,0 +1,190 @@
# -*- coding: utf-8 -*-
"""
Unit tests for OpenshiftService logic.
All tests use fixtures for setup and mock dependencies.
Tests are grouped by functionality: configuration, utility methods, availability, VM operations, exception handling, and deletion.
"""
#
# Copyright (c) 2024 Virtual Cable S.L.U.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of Virtual Cable S.L.U. nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
Author: Adolfo Gómez, dkmaster at dkmon dot com
Reorganized and corrected by GitHub Copilot
"""
import typing
from unittest import mock
from uds.services.OpenShift.openshift import exceptions as morph_exceptions
from tests.services.openshift import fixtures
from tests.utils.test import UDSTransactionTestCase
class TestOpenshiftService(UDSTransactionTestCase):
def setUp(self) -> None:
super().setUp()
fixtures.clear()
def _create_service_with_provider(self):
"""
Helper to create a service with a patched provider.
"""
provider_ctx = fixtures.patched_provider()
provider = provider_ctx.__enter__()
service = fixtures.create_service(provider=provider)
return service, provider, provider_ctx
# --- Configuration and initial data ---
def test_service_data(self) -> None:
"""
Check initial service data values.
"""
service = fixtures.create_service()
self.assertEqual(service.template.value, fixtures.SERVICE_VALUES_DICT['template'])
self.assertEqual(service.basename.value, fixtures.SERVICE_VALUES_DICT['basename'])
self.assertEqual(service.lenname.value, fixtures.SERVICE_VALUES_DICT['lenname'])
self.assertEqual(service.publication_timeout.value, fixtures.SERVICE_VALUES_DICT['publication_timeout'])
def test_initialize_sets_basename(self) -> None:
"""
Check that initialize sets basename and lenname correctly.
"""
service = fixtures.create_service()
service.basename.value = 'testname'
service.lenname.value = 6
service.initialize({'basename': 'testname', 'lenname': 6})
self.assertEqual(service.basename.value, 'testname')
self.assertEqual(service.lenname.value, 6)
def test_init_gui_sets_choices(self) -> None:
"""
Check that init_gui sets template choices.
"""
service, _, provider_ctx = self._create_service_with_provider()
with mock.patch.object(service.template, 'set_choices') as set_choices_mock:
service.init_gui()
set_choices_mock.assert_called()
provider_ctx.__exit__(None, None, None)
# --- Utility and accessor methods ---
def test_provider_returns_correct_type(self) -> None:
"""
Check that provider() returns the correct provider instance.
"""
service, provider, provider_ctx = self._create_service_with_provider()
self.assertEqual(service.provider(), provider)
provider_ctx.__exit__(None, None, None)
def test_api_property_caching(self) -> None:
"""
Check that the api property is cached.
"""
service, _, provider_ctx = self._create_service_with_provider()
api1 = service.api
api2 = service.api
self.assertIs(api1, api2)
provider_ctx.__exit__(None, None, None)
def test_service_methods(self) -> None:
"""
Check utility methods of the service.
"""
service, _, provider_ctx = self._create_service_with_provider()
self.assertEqual(service.get_basename(), service.basename.value)
self.assertEqual(service.get_lenname(), service.lenname.value)
self.assertEqual(service.sanitized_name('Test VM 1'), 'test-vm-1')
duplicates = list(service.find_duplicates('vm-1', '00:11:22:33:44:55'))
self.assertEqual(len(duplicates), 1)
provider_ctx.__exit__(None, None, None)
# --- Availability and cache ---
def test_service_is_available(self) -> None:
"""
Check service availability and cache handling.
"""
service, provider, provider_ctx = self._create_service_with_provider()
api = typing.cast(mock.MagicMock, provider.api)
self.assertTrue(service.is_available())
api.test.assert_called_with()
api.test.return_value = False
self.assertTrue(service.is_available())
service.provider().is_available.cache_clear() # type: ignore
self.assertFalse(service.is_available())
api.test.assert_called_with()
provider_ctx.__exit__(None, None, None)
# --- VM operations ---
def test_vm_operations(self) -> None:
"""
Check VM operations: get_ip, get_mac, is_running, start, stop, shutdown.
"""
service, _, provider_ctx = self._create_service_with_provider()
api = typing.cast(mock.MagicMock, service.api)
ip = service.get_ip(None, 'vm-1')
self.assertEqual(ip, '192.168.1.1')
mac = service.get_mac(None, 'vm-1')
self.assertEqual(mac, '00:11:22:33:44:01')
self.assertTrue(service.is_running(None, 'vm-1'))
service.start(None, 'vm-1')
api.start_vm_instance.assert_called_with('vm-1')
service.stop(None, 'vm-1')
api.stop_vm_instance.assert_called_with('vm-1')
service.shutdown(None, 'vm-1')
api.stop_vm_instance.assert_called_with('vm-1')
provider_ctx.__exit__(None, None, None)
# --- Exception handling ---
def test_get_ip_raises_exception_if_no_interfaces(self) -> None:
"""
Check that get_ip raises an exception if there are no interfaces.
"""
service, _, provider_ctx = self._create_service_with_provider()
def no_interfaces(_vmid: str):
mock_vm = mock.Mock()
mock_vm.interfaces = []
return mock_vm
with mock.patch.object(service.api, 'get_vm_instance_info', side_effect=no_interfaces):
with self.assertRaises(Exception):
service.get_ip(None, 'vm-1')
provider_ctx.__exit__(None, None, None)
def test_get_mac_raises_exception_if_no_interfaces(self) -> None:
"""
Check that get_mac raises an exception if there are no interfaces.
"""
service, _, provider_ctx = self._create_service_with_provider()
def no_interfaces(_vmid: str):
mock_vm = mock.Mock()
mock_vm.interfaces = []
return mock_vm
with mock.patch.object(service.api, 'get_vm_instance_info', side_effect=no_interfaces):
with self.assertRaises(Exception):
service.get_mac(None, 'vm-1')
provider_ctx.__exit__(None, None, None)
# --- VM deletion ---
def test_vm_deletion(self) -> None:
"""
Check VM deletion logic and is_deleted method.
"""
service, provider, provider_ctx = self._create_service_with_provider()
api = typing.cast(mock.MagicMock, provider.api)
# Execute deletion
service.execute_delete('vm-1')
api.delete_vm_instance.assert_called_with('vm-1')
# Check if deleted
api.get_vm_info.side_effect = morph_exceptions.OpenshiftNotFoundError('not found')
self.assertTrue(service.is_deleted('vm-1'))
# Simulate VM exists
api.get_vm_info.side_effect = None
api.get_vm_info.return_value = fixtures.VMS[0]
self.assertFalse(service.is_deleted('vm-1'))
provider_ctx.__exit__(None, None, None)
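Both exception tests rely on the service refusing to report an address for a VM without interfaces; a sketch of that guard (function name and exception type are illustrative, not the service's actual API):

```python
from types import SimpleNamespace
from typing import Any


def first_interface_ip(vm_info: Any) -> str:
    # Fail loudly instead of returning a bogus address for a NIC-less VM
    if not vm_info.interfaces:
        raise ValueError('VM has no network interfaces')
    return vm_info.interfaces[0].ip


vm = SimpleNamespace(interfaces=[SimpleNamespace(ip='192.168.1.1')])
assert first_interface_ip(vm) == '192.168.1.1'
```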


@@ -0,0 +1,122 @@
# -*- coding: utf-8 -*-
#
# Copyright (c) 2024 Virtual Cable S.L.U.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of Virtual Cable S.L.U. nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
Author: Adolfo Gómez, dkmaster at dkmon dot com
"""
import typing
from unittest import mock
from tests.services.openshift import fixtures
from tests.utils.test import UDSTransactionTestCase
class TestOpenshiftServiceFixed(UDSTransactionTestCase):
def _create_service_fixed_with_provider(self):
"""
Helper to create a fixed service with a patched provider.
"""
provider_ctx = fixtures.patched_provider()
provider = provider_ctx.__enter__()
service = fixtures.create_service_fixed(provider=provider)
return service, provider, provider_ctx
def setUp(self) -> None:
super().setUp()
fixtures.clear()
# --- Availability ---
def test_service_is_available(self) -> None:
"""
Test provider availability and cache logic.
"""
service, provider, provider_ctx = self._create_service_fixed_with_provider()
api = typing.cast(mock.MagicMock, provider.api)
self.assertTrue(service.is_available())
api.test.assert_called_with()
# With cached data, even if test fails, it will return True
api.test.return_value = False
self.assertTrue(service.is_available())
# Clear cache and test again
service.provider().is_available.cache_clear() # type: ignore
self.assertFalse(service.is_available())
api.test.assert_called_with()
provider_ctx.__exit__(None, None, None)
# --- Service methods ---
def test_service_methods(self) -> None:
"""
Test service methods: enumerate_assignables, get_name, get_ip, get_mac, sanitized_name.
"""
service, _, provider_ctx = self._create_service_fixed_with_provider()
# Enumerate assignables
machines = list(service.enumerate_assignables())
self.assertEqual(len(machines), 3)
self.assertEqual(machines[0].id, 'vm-3')
self.assertEqual(machines[1].id, 'vm-4')
self.assertEqual(machines[2].id, 'vm-5')
# Get machine name
machine_name = service.get_name('uid-3')
self.assertEqual(machine_name, 'vm-3')
# Get IP
ip = service.get_ip('uid-3')
self.assertTrue(ip.startswith('192.168.1.'))
# Get MAC
mac = service.get_mac('uid-3')
self.assertTrue(mac.startswith('00:11:22:33:44:'))
# Sanitized name
sanitized = service.sanitized_name('Test VM 1')
self.assertIsInstance(sanitized, str)
provider_ctx.__exit__(None, None, None)
# --- Assignment logic ---
def test_get_and_assign(self) -> None:
"""
Test get_and_assign logic for fixed service.
"""
service, _, provider_ctx = self._create_service_fixed_with_provider()
vmid = service.get_and_assign()
self.assertIn(vmid, ['vm-3', 'vm-4', 'vm-5'])
# Should not assign the same again
with service._assigned_access() as assigned:
self.assertIn(vmid, assigned)
provider_ctx.__exit__(None, None, None)
def test_remove_and_free(self) -> None:
"""
Test remove_and_free logic for fixed service.
"""
service, _, provider_ctx = self._create_service_fixed_with_provider()
vmid = service.get_and_assign()
result = service.remove_and_free(vmid)
self.assertEqual(result.name, 'FINISHED')
with service._assigned_access() as assigned:
self.assertNotIn(vmid, assigned)
provider_ctx.__exit__(None, None, None)
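The assign/free cycle these tests exercise can be sketched as a set guarded by a context manager, mirroring `_assigned_access`; the real service persists assignments in UDS storage, so this in-memory version is only illustrative:

```python
import contextlib
from typing import Iterator


class FixedPool:
    def __init__(self, machines: list[str]) -> None:
        self._machines = machines
        self._assigned: set[str] = set()

    @contextlib.contextmanager
    def _assigned_access(self) -> Iterator[set[str]]:
        # The real service would lock and load/save persistent storage here
        yield self._assigned

    def get_and_assign(self) -> str:
        with self._assigned_access() as assigned:
            for vmid in self._machines:
                if vmid not in assigned:
                    assigned.add(vmid)
                    return vmid
        raise RuntimeError('no unassigned machines left')

    def remove_and_free(self, vmid: str) -> None:
        with self._assigned_access() as assigned:
            assigned.discard(vmid)


pool = FixedPool(['vm-3', 'vm-4', 'vm-5'])
assert pool.get_and_assign() == 'vm-3'
```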