1
0
mirror of https://github.com/dkmstr/openuds.git synced 2025-10-07 15:33:51 +03:00

Refactor REST methods to use GetItemsResult type and improve type safety

- Updated return types of get_items methods in various REST handlers to use types.rest.GetItemsResult instead of types.rest.ManyItemsDictType for better clarity and type safety.
- Introduced UserServiceItem, GroupItem, TransportItem, PublicationItem, and ChangelogItem TypedDicts to standardize item representations across different handlers.
- Refactored AssignedService to AssignedUserService and updated related methods to reflect the new naming convention.
- Enhanced as_typed_dict utility function to convert models to TypedDicts, improving type safety in REST API responses.
- Cleaned up imports and ensured consistent use of typing annotations throughout the codebase.
This commit is contained in:
Adolfo Gómez García
2025-07-25 21:33:28 +02:00
parent 3d9bc55b1d
commit 6922d28537
22 changed files with 324 additions and 211 deletions

View File

@@ -1,7 +1,7 @@
[mypy]
#plugins =
# mypy_django_plugin.main
python_version = 3.11
python_version = 3.12
# Exclude all .*/transports/.*/scripts/.* directories and all tests
exclude = (.*/transports/.*/scripts/.*|.*/tests/.*)
@@ -17,4 +17,4 @@ django_settings_module = "server.settings"
# Disable some anoying reports, because pyright needs the redundant cast on some cases
# [mypy-tests.*]
# disable_error_code =
# disable_error_code =

View File

@@ -74,7 +74,7 @@ class AccountsUsage(DetailHandler): # pylint: disable=too-many-public-methods
'permission': perm,
}
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.ManyItemsDictType:
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.GetItemsResult:
parent = ensure.is_instance(parent, Account)
# Check what kind of access do we have to parent provider
perm = permissions.effective_permissions(self._user, parent)

View File

@@ -70,7 +70,7 @@ class Authenticators(ModelHandler):
tags: list[str]
comments: str
net_filtering: str
networks: list[dict[str, str]]
networks: list[str]
state: str
mfa_id: str
small_name: str
@@ -189,6 +189,8 @@ class Authenticators(ModelHandler):
'priority': item.priority,
}
type_ = item.get_type()
return {
'numeric_id': item.id,
'id': item.uuid,

View File

@@ -46,21 +46,35 @@ from uds.core.types.states import State
from uds.core.util.model import process_uuid
from uds.core.util import log, ensure
from uds.REST.model import DetailHandler
from .user_services import AssignedService
from .user_services import AssignedUserService, UserServiceItem
if typing.TYPE_CHECKING:
from django.db.models import Model
logger = logging.getLogger(__name__)
class MetaItem(types.rest.ItemDictType):
"""
Item type for a Meta Pool Member
"""
id: str
pool_id: str
pool_name: typing.NotRequired[str] # Optional, as it can be not present
name: str
comments: str
priority: int
enabled: bool
user_services_count: int
user_services_in_preparation: int
class MetaServicesPool(DetailHandler):
class MetaServicesPool(DetailHandler[MetaItem]):
"""
Processes the transports detail requests of a Service Pool
"""
@staticmethod
def as_dict(item: models.MetaPoolMember) -> dict[str, typing.Any]:
def as_dict(item: models.MetaPoolMember) -> 'MetaItem':
return {
'id': item.uuid,
'pool_id': item.pool.uuid,
@@ -72,7 +86,7 @@ class MetaServicesPool(DetailHandler):
'user_services_in_preparation': item.pool.userServices.filter(state=State.PREPARING).count(),
}
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.ManyItemsDictType:
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.GetItemsResult['MetaItem']:
parent = ensure.is_instance(parent, models.MetaPool)
try:
if not item:
@@ -133,7 +147,7 @@ class MetaServicesPool(DetailHandler):
log.log(parent, types.log.LogLevel.INFO, log_str, types.log.LogSource.ADMIN)
class MetaAssignedService(DetailHandler):
class MetaAssignedService(DetailHandler[UserServiceItem]):
"""
Rest handler for Assigned Services, wich parent is Service
"""
@@ -143,8 +157,8 @@ class MetaAssignedService(DetailHandler):
meta_pool: 'models.MetaPool',
item: 'models.UserService',
props: typing.Optional[dict[str, typing.Any]],
) -> dict[str, typing.Any]:
element = AssignedService.item_as_dict(item, props, False)
) -> 'UserServiceItem':
element = AssignedUserService.item_as_dict(item, props, False)
element['pool_id'] = item.deployed_service.uuid
element['pool_name'] = item.deployed_service.name
return element
@@ -163,7 +177,7 @@ class MetaAssignedService(DetailHandler):
except Exception:
raise self.invalid_item_response()
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.ManyItemsDictType:
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.GetItemsResult[UserServiceItem]:
parent = ensure.is_instance(parent, models.MetaPool)
def _assigned_userservices_for_pools() -> (
typing.Generator[

View File

@@ -63,7 +63,7 @@ class AccessCalendars(DetailHandler):
'priority': item.priority,
}
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.ManyItemsDictType:
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.GetItemsResult:
# parent can be a ServicePool or a metaPool
parent = typing.cast(typing.Union['models.ServicePool', 'models.MetaPool'], parent)
@@ -154,7 +154,7 @@ class ActionsCalendars(DetailHandler):
'last_execution': item.last_execution,
}
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.ManyItemsDictType:
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.GetItemsResult:
parent = ensure.is_instance(parent, models.ServicePool)
try:
if item is None:

View File

@@ -134,7 +134,7 @@ class ServersServers(DetailHandler):
custom_methods = ['maintenance', 'importcsv']
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.ManyItemsDictType:
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.GetItemsResult:
parent = typing.cast('models.ServerGroup', parent) # We will receive for sure
try:
if item is None:
@@ -157,10 +157,10 @@ class ServersServers(DetailHandler):
}
)
if item is None:
return typing.cast(types.rest.ManyItemsDictType, res)
return typing.cast(types.rest.GetItemsResult, res)
if not i:
raise Exception('Item not found')
return typing.cast(types.rest.ManyItemsDictType, res[0])
return typing.cast(types.rest.GetItemsResult, res[0])
except Exception as e:
logger.exception('REST servers')
raise self.invalid_item_response() from e

View File

@@ -114,7 +114,7 @@ class Services(DetailHandler): # pylint: disable=too-many-public-methods
return ret_value
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.ManyItemsDictType:
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.GetItemsResult:
parent = ensure.is_instance(parent, models.Provider)
# Check what kind of access do we have to parent provider
perm = permissions.effective_permissions(self._user, parent)
@@ -335,7 +335,7 @@ class Services(DetailHandler): # pylint: disable=too-many-public-methods
except Exception:
raise self.invalid_item_response() from None
def servicepools(self, parent: 'Model', item: str) -> types.rest.ManyItemsDictType:
def servicepools(self, parent: 'Model', item: str) -> types.rest.GetItemsResult:
parent = ensure.is_instance(parent, models.Provider)
service = parent.services.get(uuid=process_uuid(item))
logger.debug('Got parameters for servicepools: %s, %s', parent, item)

View File

@@ -51,7 +51,7 @@ from uds.REST.model import ModelHandler
from .op_calendars import AccessCalendars, ActionsCalendars
from .services import Services
from .user_services import AssignedService, CachedService, Changelog, Groups, Publications, Transports
from .user_services import AssignedUserService, CachedService, Changelog, Groups, Publications, Transports
if typing.TYPE_CHECKING:
from django.db.models import Model
@@ -66,7 +66,7 @@ class ServicesPools(ModelHandler):
model = ServicePool
detail = {
'services': AssignedService,
'services': AssignedUserService,
'cache': CachedService,
'servers': CachedService, # Alias for cache, but will change in a future release
'groups': Groups,

View File

@@ -90,7 +90,7 @@ class ServicesUsage(DetailHandler):
'in_use': item.in_use,
}
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.ManyItemsDictType:
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.GetItemsResult:
parent = ensure.is_instance(parent, Provider)
try:
if item is None:

View File

@@ -60,7 +60,7 @@ class TunnelServers(DetailHandler):
mac: str
maintenance: bool
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.ManyItemsDictType:
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.GetItemsResult:
parent = ensure.is_instance(parent, models.ServerGroup)
try:
multi = False
@@ -82,10 +82,10 @@ class TunnelServers(DetailHandler):
}
)
if multi:
return typing.cast(types.rest.ManyItemsDictType, res)
return typing.cast(types.rest.GetItemsResult, res)
if not i:
raise Exception('Item not found')
return typing.cast(types.rest.ManyItemsDictType, res[0])
return typing.cast(types.rest.GetItemsResult, res[0])
except Exception as e:
logger.exception('REST groups')
raise self.invalid_item_response() from e

View File

@@ -31,6 +31,7 @@
Author: Adolfo Gómez, dkmaster at dkmon dot com
"""
import collections.abc
import datetime
import logging
import typing
@@ -50,20 +51,48 @@ if typing.TYPE_CHECKING:
logger = logging.getLogger(__name__)
class UserServiceItem(types.rest.ItemDictType):
id: str
id_deployed_service: str
unique_id: str
friendly_name: str
state: str
os_state: str
state_date: datetime.datetime
creation_date: datetime.datetime
revision: str
ip: str
actor_version: str
class AssignedService(DetailHandler):
pool: typing.NotRequired[str]
pool_id: typing.NotRequired[str]
pool_name: typing.NotRequired[str]
# For cache
cache_level: typing.NotRequired[int]
# For assigned
owner: typing.NotRequired[str]
owner_info: typing.NotRequired[dict[str, str]]
in_use: typing.NotRequired[bool]
in_use_date: typing.NotRequired[datetime.datetime]
source_host: typing.NotRequired[str]
source_ip: typing.NotRequired[str]
class AssignedUserService(DetailHandler[UserServiceItem]):
"""
Rest handler for Assigned Services, wich parent is Service
"""
custom_methods = ['reset']
@staticmethod
def item_as_dict(
item: models.UserService,
props: typing.Optional[dict[str, typing.Any]] = None,
is_cache: bool = False,
) -> types.rest.ItemDictType:
) -> 'UserServiceItem':
"""
Converts an assigned/cached service db item to a dictionary for REST response
:param item: item to convert
@@ -72,7 +101,9 @@ class AssignedService(DetailHandler):
if props is None:
props = dict(item.properties)
val = {
val: (
UserServiceItem
) = {
'id': item.uuid,
'id_deployed_service': item.deployed_service.uuid,
'unique_id': item.unique_id,
@@ -85,7 +116,7 @@ class AssignedService(DetailHandler):
'os_state': item.os_state,
'state_date': item.state_date,
'creation_date': item.creation_date,
'revision': item.publication and item.publication.revision or '',
'revision': item.publication and str(item.publication.revision) or '',
'ip': props.get('ip', _('unknown')),
'actor_version': props.get('actor_version', _('unknown')),
}
@@ -113,10 +144,12 @@ class AssignedService(DetailHandler):
'source_ip': item.src_ip,
}
)
# ItemDictType is a TypedDict, but no members, so this is valid
return typing.cast(types.rest.ItemDictType, val)
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.ManyItemsDictType:
return val
def get_items(
self, parent: 'Model', item: typing.Optional[str]
) -> types.rest.GetItemsResult['UserServiceItem']:
parent = ensure.is_instance(parent, models.ServicePool)
try:
@@ -131,12 +164,12 @@ class AssignedService(DetailHandler):
properties[id][key] = value
return [
AssignedService.item_as_dict(k, properties.get(k.uuid, {}))
AssignedUserService.item_as_dict(k, properties.get(k.uuid, {}))
for k in parent.assigned_user_services()
.all()
.prefetch_related('deployed_service', 'publication', 'user')
]
return AssignedService.item_as_dict(
return AssignedUserService.item_as_dict(
parent.assigned_user_services().get(process_uuid(uuid=process_uuid(item))),
props={
k: v
@@ -271,26 +304,28 @@ class AssignedService(DetailHandler):
UserServiceManager.manager().reset(userservice)
class CachedService(AssignedService):
class CachedService(AssignedUserService):
"""
Rest handler for Cached Services, which parent is ServicePool
"""
custom_methods = [] # Remove custom methods from assigned services
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.ManyItemsDictType:
def get_items(
self, parent: 'Model', item: typing.Optional[str]
) -> types.rest.GetItemsResult['UserServiceItem']:
parent = ensure.is_instance(parent, models.ServicePool)
try:
if not item:
return [
AssignedService.item_as_dict(k, is_cache=True)
AssignedUserService.item_as_dict(k, is_cache=True)
for k in parent.cached_users_services()
.all()
.prefetch_related('deployed_service', 'publication')
]
cached_userservice: models.UserService = parent.cached_users_services().get(uuid=process_uuid(item))
return AssignedService.item_as_dict(cached_userservice, is_cache=True)
return AssignedUserService.item_as_dict(cached_userservice, is_cache=True)
except Exception as e:
logger.exception('get_items')
raise self.invalid_item_response() from e
@@ -334,29 +369,38 @@ class CachedService(AssignedService):
except Exception:
raise self.invalid_item_response() from None
class GroupItem(types.rest.ItemDictType):
id: str
auth_id: str
name: str
group_name: str
comments: str
state: str
type: str
auth_name: str
class Groups(DetailHandler):
class Groups(DetailHandler[GroupItem]):
"""
Processes the groups detail requests of a Service Pool
"""
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.ManyItemsDictType:
def get_items(
self, parent: 'Model', item: typing.Optional[str]
) -> list['GroupItem']:
parent = typing.cast(typing.Union['models.ServicePool', 'models.MetaPool'], parent)
return [
typing.cast(
types.rest.ItemDictType,
{
'id': group.uuid,
'auth_id': group.manager.uuid,
'name': group.name,
'group_name': group.pretty_name,
'comments': group.comments,
'state': group.state,
'type': 'meta' if group.is_meta else 'group',
'auth_name': group.manager.name,
},
)
{
'id': group.uuid,
'auth_id': group.manager.uuid,
'name': group.name,
'group_name': group.pretty_name,
'comments': group.comments,
'state': group.state,
'type': 'meta' if group.is_meta else 'group',
'auth_name': group.manager.name,
}
for group in typing.cast(collections.abc.Iterable[models.Group], parent.assignedGroups.all())
]
@@ -417,13 +461,24 @@ class Groups(DetailHandler):
types.log.LogSource.ADMIN,
)
class TransportItem(types.rest.ItemDictType):
id: str
name: str
type: types.rest.TypeInfoDict
comments: str
priority: int
trans_type: str
class Transports(DetailHandler):
class Transports(DetailHandler[TransportItem]):
"""
Processes the transports detail requests of a Service Pool
"""
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.ManyItemsDictType:
def get_items(
self, parent: 'Model', item: typing.Optional[str]
) -> list['TransportItem']:
parent = ensure.is_instance(parent, models.ServicePool)
def get_type(trans: 'models.Transport') -> types.rest.TypeInfoDict:
@@ -432,22 +487,21 @@ class Transports(DetailHandler):
except Exception: # No type found
raise self.invalid_item_response()
return [
typing.cast(
types.rest.ItemDictType,
{
'id': i.uuid,
'name': i.name,
'type': get_type(i),
'comments': i.comments,
'priority': i.priority,
'trans_type': _(i.get_type().mod_name()),
},
)
items: list[TransportItem] = [
{
'id': i.uuid,
'name': i.name,
'type': get_type(i),
'comments': i.comments,
'priority': i.priority,
'trans_type': i.get_type().mod_name(),
}
for i in parent.transports.all()
if get_type(i)
]
return items
def get_title(self, parent: 'Model') -> str:
parent = ensure.is_instance(parent, models.ServicePool)
return _('Assigned transports')
@@ -484,8 +538,16 @@ class Transports(DetailHandler):
types.log.LogSource.ADMIN,
)
class PublicationItem(types.rest.ItemDictType):
id: str
revision: int
publish_date: datetime.datetime
state: str
reason: str
state_date: datetime.datetime
class Publications(DetailHandler):
class Publications(DetailHandler[PublicationItem]):
"""
Processes the publications detail requests of a Service Pool
"""
@@ -549,20 +611,19 @@ class Publications(DetailHandler):
return self.success()
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.ManyItemsDictType:
def get_items(
self, parent: 'Model', item: typing.Optional[str]
) -> list['PublicationItem']:
parent = ensure.is_instance(parent, models.ServicePool)
return [
typing.cast(
types.rest.ItemDictType,
{
'id': i.uuid,
'revision': i.revision,
'publish_date': i.publish_date,
'state': i.state,
'reason': State.from_str(i.state).is_errored() and i.get_instance().error_reason() or '',
'state_date': i.state_date,
},
)
{
'id': i.uuid,
'revision': i.revision,
'publish_date': i.publish_date,
'state': i.state,
'reason': State.from_str(i.state).is_errored() and i.get_instance().error_reason() or '',
'state_date': i.state_date,
}
for i in parent.publications.all()
]
@@ -587,23 +648,28 @@ class Publications(DetailHandler):
def get_row_style(self, parent: 'Model') -> types.ui.RowStyleInfo:
return types.ui.RowStyleInfo(prefix='row-state-', field='state')
class ChangelogItem(types.rest.ItemDictType):
revision: int
stamp: datetime.datetime
log: str
class Changelog(DetailHandler):
class Changelog(DetailHandler['ChangelogItem']):
"""
Processes the transports detail requests of a Service Pool
"""
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.ManyItemsDictType:
def get_items(
self, parent: 'Model', item: typing.Optional[str]
) -> list['ChangelogItem']:
parent = ensure.is_instance(parent, models.ServicePool)
return [
typing.cast(
types.rest.ItemDictType,
{
'revision': i.revision,
'stamp': i.stamp,
'log': i.log,
},
)
{
'revision': i.revision,
'stamp': i.stamp,
'log': i.log,
}
for i in parent.changelog.all()
]

View File

@@ -29,12 +29,12 @@
"""
Author: Adolfo Gómez, dkmaster at dkmon dot com
"""
import datetime
import logging
import typing
import collections.abc
from django.utils.translation import gettext as _
from django.forms.models import model_to_dict
from django.db import IntegrityError, transaction
from django.core.exceptions import ValidationError
@@ -42,6 +42,7 @@ from uds.core.types.states import State
from uds.core.auths.user import User as AUser
from uds.core.util import log, ensure
from uds.core.util.rest.tools import as_typed_dict
from uds.core.util.model import process_uuid, sql_stamp_seconds
from uds.models import Authenticator, User, Group, ServicePool
from uds.core.managers.crypto import CryptoManager
@@ -49,7 +50,7 @@ from uds.core import consts, exceptions, types
from uds.REST.model import DetailHandler
from .user_services import AssignedService
from .user_services import AssignedUserService, UserServiceItem
if typing.TYPE_CHECKING:
from django.db.models import Model
@@ -77,7 +78,22 @@ def get_service_pools_for_groups(
yield servicepool
class Users(DetailHandler):
class UserItem(types.rest.ItemDictType):
id: str
name: str
real_name: str
comments: str
state: str
staff_member: bool
is_admin: bool
last_access: datetime.datetime
parent: typing.NotRequired[str]
mfa_data: str
role: str
groups: typing.NotRequired[list[str]]
class Users(DetailHandler[UserItem]):
custom_methods = [
'services_pools',
'user_services',
@@ -89,64 +105,25 @@ class Users(DetailHandler):
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> typing.Any:
parent = ensure.is_instance(parent, Authenticator)
# processes item to change uuid key for id
def uuid_to_id(
iterable: collections.abc.Iterable[typing.Any],
) -> collections.abc.Generator[typing.Any, None, None]:
for v in iterable:
v['id'] = v['uuid']
del v['uuid']
yield v
def as_user_item(model: 'User') -> UserItem:
base = as_typed_dict(
model,
UserItem,
)
# Convert uuid to id, that is what the frontend expects (instaad of the numeric id)
base['id'] = model.uuid
base['role'] = (
model.staff_member and (model.is_admin and _('Admin') or _('Staff member')) or _('User')
)
return base
logger.debug(item)
# Extract authenticator
try:
if item is None:
values = list(
uuid_to_id(
(
i
for i in parent.users.all().values(
'uuid',
'name',
'real_name',
'comments',
'state',
'staff_member',
'is_admin',
'last_access',
'parent',
'mfa_data',
)
)
)
)
for res in values:
res['role'] = (
res['staff_member']
and (res['is_admin'] and _('Admin') or _('Staff member'))
or _('User')
)
return values
return [as_user_item(i) for i in parent.users.all()]
u = parent.users.get(uuid__iexact=process_uuid(item))
res = model_to_dict(
u,
fields=(
'name',
'real_name',
'comments',
'state',
'staff_member',
'is_admin',
'last_access',
'parent',
'mfa_data',
),
)
res['id'] = u.uuid
res['role'] = (
res['staff_member'] and (res['is_admin'] and _('Admin') or _('Staff member')) or _('User')
)
res = as_user_item(u)
usr = AUser(u)
res['groups'] = [g.db_obj().uuid for g in usr.groups()]
logger.debug('Item: %s', res)
@@ -247,7 +224,7 @@ class Users(DetailHandler):
groups = self.fields_from_params(['groups'])['groups']
# Save but skip meta groups, they are not real groups, but just a way to group users based on rules
user.groups.set(g for g in parent.groups.filter(uuid__in=groups) if g.is_meta is False)
return {'id': user.uuid}
except User.DoesNotExist:
raise self.invalid_item_response() from None
@@ -319,19 +296,21 @@ class Users(DetailHandler):
return res
def user_services(self, parent: 'Authenticator', item: str) -> list[dict[str, typing.Any]]:
def user_services(self, parent: 'Authenticator', item: str) -> list[UserServiceItem]:
parent = ensure.is_instance(parent, Authenticator)
uuid = process_uuid(item)
user = parent.users.get(uuid=process_uuid(uuid))
res: list[dict[str, typing.Any]] = []
for i in user.userServices.all():
if i.state == State.USABLE:
v = AssignedService.item_as_dict(i)
v['pool'] = i.deployed_service.name
v['pool_id'] = i.deployed_service.uuid
res.append(v)
return res
def item_as_dict(assigned_user_service: 'UserService') -> UserServiceItem:
base = AssignedUserService.item_as_dict(assigned_user_service)
base['pool'] = assigned_user_service.deployed_service.name
base['pool_id'] = assigned_user_service.deployed_service.uuid
return base
return [
item_as_dict(i)
for i in user.userServices.all().prefetch_related('deployed_service').filter(state=State.USABLE)
]
def clean_related(self, parent: 'Authenticator', item: str) -> dict[str, str]:
uuid = process_uuid(item)
@@ -365,10 +344,22 @@ class Users(DetailHandler):
return {'status': 'ok'}
class Groups(DetailHandler):
class GroupItem(typing.TypedDict):
id: str
name: str
comments: str
state: str
type: str
meta_if_any: bool
skip_mfa: str
groups: typing.NotRequired[list[str]] # Only for meta groups
pools: typing.NotRequired[list[str]] # Only for single group items
class Groups(DetailHandler[GroupItem]):
custom_methods = ['services_pools', 'users']
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.ManyItemsDictType:
def get_items(self, parent: 'Model', item: typing.Optional[str]) -> types.rest.GetItemsResult['GroupItem']:
parent = ensure.is_instance(parent, Authenticator)
try:
multi = False
@@ -377,10 +368,10 @@ class Groups(DetailHandler):
q = parent.groups.all().order_by('name')
else:
q = parent.groups.filter(uuid=process_uuid(item))
res: list[dict[str, typing.Any]] = []
res: list[GroupItem] = []
i = None
for i in q:
val: dict[str, typing.Any] = {
val: GroupItem = {
'id': i.uuid,
'name': i.name,
'comments': i.comments,

View File

@@ -54,10 +54,12 @@ if typing.TYPE_CHECKING:
logger = logging.getLogger(__name__)
T = typing.TypeVar('T', bound=types.rest.ItemDictType)
# Details do not have types at all
# so, right now, we only process details petitions for Handling & tables info
# noinspection PyMissingConstructor
class DetailHandler(BaseModelHandler):
class DetailHandler(BaseModelHandler, typing.Generic[T]):
"""
Detail handler (for relations such as provider-->services, authenticators-->users,groups, deployed services-->cache,assigned, groups, transports
Urls recognized for GET are:
@@ -192,6 +194,7 @@ class DetailHandler(BaseModelHandler):
# Not understood, fallback, maybe the derived class can understand it
return self.fallback_get()
# For reference, this is the old code to be removed
if num_args == 1:
match self._args[0]:
case consts.rest.OVERVIEW:
@@ -290,7 +293,7 @@ class DetailHandler(BaseModelHandler):
# Override this to provide functionality
# Default (as sample) get_items
def get_items(self, parent: models.Model, item: typing.Optional[str]) -> types.rest.ManyItemsDictType:
def get_items(self, parent: models.Model, item: typing.Optional[str]) -> types.rest.GetItemsResult[T]:
"""
This MUST be overridden by derived classes
Excepts to return a list of dictionaries or a single dictionary, depending on "item" param

View File

@@ -91,7 +91,7 @@ class ModelHandler(BaseModelHandler):
[]
) # If this model respond to "custom" methods, we will declare them here
# If this model has details, which ones
detail: typing.ClassVar[typing.Optional[dict[str, type['DetailHandler']]]] = (
detail: typing.ClassVar[typing.Optional[dict[str, type['DetailHandler[typing.Any]']]]] = (
None # Dictionary containing detail routing
)
# Fields that are going to be saved directly

View File

@@ -112,11 +112,10 @@ class ItemDictType(typing.TypedDict):
# Alias for item type
# ItemDictType = dict[str, typing.Any]
ItemListType = list[ItemDictType]
ItemGeneratorType = typing.Generator[ItemDictType, None, None]
T_Item = typing.TypeVar("T_Item", bound=ItemDictType)
# Alias for get_items return type
ManyItemsDictType: typing.TypeAlias = ItemListType|ItemDictType|ItemGeneratorType
GetItemsResult: typing.TypeAlias = list[T_Item]|ItemDictType|typing.Iterator[T_Item]
#
FieldType = collections.abc.Mapping[str, typing.Any]
@@ -157,8 +156,8 @@ class HandlerNode:
ret += f'{" " * level} |- {method}\n'
# Add detail methods
if self.handler.detail:
for method in self.handler.detail.keys():
ret += f'{" " * level} |- {method}\n'
for method_name in self.handler.detail.keys():
ret += f'{" " * level} |- {method_name}\n'
return ret + ''.join(child.tree(level + 1) for child in self.children.values())

View File

@@ -71,7 +71,8 @@ def extract_doc(response: type[TypedResponse]) -> dict[str, typing.Any]:
def is_typed_response(t: type[TypedResponse]) -> bool:
return hasattr(t, '__orig_bases__') and TypedResponse in t.__orig_bases__
orig_bases = getattr(t, '__orig_bases__', None)
return orig_bases is not None and TypedResponse in orig_bases
# Regular expression to match the API: part of the docstring

View File

@@ -45,8 +45,7 @@ logger = logging.getLogger(__name__)
# FT = typing.TypeVar('FT', bound=collections.abc.Callable[..., typing.Any])
P = typing.ParamSpec('P')
R = typing.TypeVar('R')
R = typing.TypeVar('R', bound=typing.Any)
@dataclasses.dataclass
class CacheInfo:
@@ -147,16 +146,16 @@ class _HasConnect(typing.Protocol):
# Keep this, but mypy does not likes it... it's perfect with pyright
# We use pyright for type checking, so we will use this
HasConnect = typing.TypeVar('HasConnect', bound=_HasConnect)
HAS_CONNECT = typing.TypeVar('HAS_CONNECT', bound=_HasConnect)
def ensure_connected(
func: collections.abc.Callable[typing.Concatenate[HasConnect, P], R],
) -> collections.abc.Callable[typing.Concatenate[HasConnect, P], R]:
func: collections.abc.Callable[typing.Concatenate[HAS_CONNECT, P], R],
) -> collections.abc.Callable[typing.Concatenate[HAS_CONNECT, P], R]:
"""This decorator calls "connect" method of the class of the wrapped object"""
@functools.wraps(func)
def new_func(obj: HasConnect, /, *args: P.args, **kwargs: P.kwargs) -> R:
def new_func(obj: HAS_CONNECT, /, *args: P.args, **kwargs: P.kwargs) -> R:
# self = typing.cast(_HasConnect, args[0])
obj.connect()
return func(obj, *args, **kwargs)
@@ -181,11 +180,11 @@ def ensure_connected(
# Decorator for caching
# This decorator will cache the result of the function for a given time, and given parameters
def cached(
prefix: typing.Optional[str] = None,
timeout: typing.Union[collections.abc.Callable[[], int], int] = -1,
args: typing.Optional[typing.Union[collections.abc.Iterable[int], int]] = None,
kwargs: typing.Optional[typing.Union[collections.abc.Iterable[str], str]] = None,
key_helper: typing.Optional[collections.abc.Callable[[typing.Any], str]] = None,
prefix: str | None = None,
timeout: collections.abc.Callable[[], int] | int = -1,
args: collections.abc.Iterable[int] | int | None = None,
kwargs: collections.abc.Iterable[str] | str | None = None,
key_helper: collections.abc.Callable[[typing.Any], str] | None = None,
) -> collections.abc.Callable[[collections.abc.Callable[P, R]], collections.abc.Callable[P, R]]:
"""
Decorator that gives us a "quick & clean" caching feature on the database.
@@ -340,8 +339,8 @@ def threaded(func: collections.abc.Callable[P, None]) -> collections.abc.Callabl
def blocker(
request_attr: typing.Optional[str] = None,
max_failures: typing.Optional[int] = None,
request_attr: str | None = None,
max_failures: int | None = None,
ignore_block_config: bool = False,
) -> collections.abc.Callable[[collections.abc.Callable[P, R]], collections.abc.Callable[P, R]]:
"""
@@ -375,7 +374,7 @@ def blocker(
except uds.core.exceptions.rest.BlockAccess:
raise exceptions.rest.AccessDenied
request: typing.Optional[typing.Any] = getattr(args[0], request_attr or '_request', None)
request: typing.Any | None = getattr(args[0], request_attr or '_request', None)
# No request object, so we can't block
if request is None or not isinstance(request, types.requests.ExtendedHttpRequest):
@@ -410,7 +409,7 @@ def blocker(
def profiler(
log_file: typing.Optional[str] = None,
log_file: str | None = None,
) -> collections.abc.Callable[[collections.abc.Callable[P, R]], collections.abc.Callable[P, R]]:
"""
Decorator that will profile the wrapped function and log the results to the provided file
@@ -450,7 +449,7 @@ def retry_on_exception(
retries: int,
*,
wait_seconds: float = 2,
retryable_exceptions: typing.Optional[typing.List[typing.Type[Exception]]] = None,
retryable_exceptions: list[type[Exception]] | None = None,
do_log: bool = False,
) -> collections.abc.Callable[[collections.abc.Callable[P, R]], collections.abc.Callable[P, R]]:
to_retry = retryable_exceptions or [Exception]
@@ -473,7 +472,7 @@ def retry_on_exception(
raise e
time.sleep(wait_seconds * (2 ** min(i, 4))) # Exponential backoff until 16x
# retries == 0 allowed, but only use it for testing purposes
# because it's a nonsensical decorator otherwise
return fnc(*args, **kwargs)

View File

@@ -495,11 +495,11 @@ else:
class FuseContext(ctypes.Structure):
_fields_ = [
('fuse', ctypes.c_voidp), # type: ignore
('fuse', ctypes.c_voidp),
('uid', c_uid_t),
('gid', c_gid_t),
('pid', c_pid_t),
('private_data', ctypes.c_voidp), # type: ignore
('private_data', ctypes.c_voidp),
]
@@ -521,7 +521,7 @@ class FuseOperations(ctypes.Structure):
ctypes.c_size_t,
),
),
('getdir', ctypes.c_voidp), # type: ignore # Deprecated, use readdir
('getdir', ctypes.c_voidp), # Deprecated, use readdir
('mknod', ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_char_p, c_mode_t, c_dev_t)),
('mkdir', ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_char_p, c_mode_t)),
('unlink', ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_char_p)),
@@ -532,7 +532,7 @@ class FuseOperations(ctypes.Structure):
('chmod', ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_char_p, c_mode_t)),
('chown', ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_char_p, c_uid_t, c_gid_t)),
('truncate', ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_char_p, c_off_t)),
('utime', ctypes.c_voidp), # type: ignore # Deprecated, use utimens
('utime', ctypes.c_voidp), # Deprecated, use utimens
(
'open',
ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_char_p, ctypes.POINTER(fuse_file_info)),
@@ -604,10 +604,10 @@ class FuseOperations(ctypes.Structure):
ctypes.CFUNCTYPE(
ctypes.c_int,
ctypes.c_char_p,
ctypes.c_voidp, # type: ignore
ctypes.c_voidp,
ctypes.CFUNCTYPE(
ctypes.c_int,
ctypes.c_voidp, # type: ignore
ctypes.c_voidp,
ctypes.c_char_p,
ctypes.POINTER(c_stat),
c_off_t,
@@ -629,8 +629,8 @@ class FuseOperations(ctypes.Structure):
ctypes.POINTER(fuse_file_info),
),
),
('init', ctypes.CFUNCTYPE(ctypes.c_voidp, ctypes.c_voidp)), # type: ignore
('destroy', ctypes.CFUNCTYPE(ctypes.c_voidp, ctypes.c_voidp)), # type: ignore
('init', ctypes.CFUNCTYPE(ctypes.c_voidp, ctypes.c_voidp)),
('destroy', ctypes.CFUNCTYPE(ctypes.c_voidp, ctypes.c_voidp)),
('access', ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_char_p, ctypes.c_int)),
(
'create',
@@ -656,7 +656,7 @@ class FuseOperations(ctypes.Structure):
ctypes.c_char_p,
ctypes.POINTER(fuse_file_info),
ctypes.c_int,
ctypes.c_voidp, # type: ignore
ctypes.c_voidp,
),
),
(
@@ -798,7 +798,7 @@ class FUSE:
continue
if hasattr(typing.cast(typing.Any, prototype), 'argtypes'):
val = prototype(partial(FUSE._wrapper, getattr(self, name))) # type: ignore
val = prototype(partial(FUSE._wrapper, getattr(self, name)))
setattr(fuse_ops, name, val)
@@ -846,14 +846,14 @@ class FUSE:
return func(*args, **kwargs) or 0
except OSError as e:
if e.errno > 0: # pyright: ignore
if e.errno and e.errno > 0:
logger.debug(
"FUSE operation %s raised a %s, returning errno %s.",
func.__name__,
type(e),
e.errno,
)
return -e.errno # pyright: ignore
return -e.errno
logger.error(
"FUSE operation %s raised an OSError with negative " "errno %s, returning errno.EINVAL.",
func.__name__,

View File

@@ -32,28 +32,28 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com
"""
import typing
import logging
from threading import Lock
import threading
import datetime
from time import mktime
import time
from django.db import connection
from uds.core import consts
from uds.core.managers.crypto import CryptoManager
logger = logging.getLogger(__name__)
CACHE_TIME_TIMEOUT = 60 # Every 60 second, refresh the time from database (to avoid drifts)
CACHE_TIME_TIMEOUT: typing.Final[int] = 60 # Every 60 second, refresh the time from database (to avoid drifts)
# pylint: disable=too-few-public-methods
class TimeTrack:
"""
Reduces the queries to database to get the current time
keeping it cached for CACHE_TIME_TIMEOUT seconds (and adjusting it based on local time)
"""
lock: typing.ClassVar[Lock] = Lock()
lock: typing.ClassVar[threading.Lock] = threading.Lock()
last_check: typing.ClassVar[datetime.datetime] = consts.NEVER
cached_time: typing.ClassVar[datetime.datetime] = consts.NEVER
hits: typing.ClassVar[int] = 0
@@ -120,7 +120,7 @@ def sql_stamp_seconds() -> int:
Returns:
int: Unix timestamp
"""
return int(mktime(sql_now().timetuple()))
return int(time.mktime(sql_now().timetuple()))
def sql_stamp() -> float:
@@ -129,7 +129,7 @@ def sql_stamp() -> float:
Returns:
float: Unix timestamp
"""
return float(mktime(sql_now().timetuple())) + sql_now().microsecond / 1000000.0
return float(time.mktime(sql_now().timetuple())) + sql_now().microsecond / 1000000.0
def generate_uuid(obj: typing.Any = None) -> str:
@@ -167,9 +167,9 @@ def get_my_ip_from_db() -> str:
with connection.cursor() as cursor:
cursor.execute(query)
result = cursor.fetchone()
if result:
result = result[0] if isinstance(result[0], str) else result[0].decode('utf8')
result_row = cursor.fetchone()
if result_row:
result = result_row[0] if isinstance(result_row[0], str) else result_row[0].decode('utf8')
return result.split(':')[0]
except Exception as e:

View File

@@ -34,6 +34,10 @@ import typing
import collections.abc
import logging
if typing.TYPE_CHECKING:
from django.db import models
from uds.core import types
logger = logging.getLogger(__name__)
T = typing.TypeVar('T', bound=typing.Any)
@@ -96,3 +100,37 @@ def match_args(
# Invoke error callback
error()
return None # In fact, error is expected to raise an exception, so this is never reached
def as_typed_dict(
model: 'models.Model',
t: type['types.rest.T_Item'],
) -> 'types.rest.T_Item':
"""
Converts a model to a TypedDict of type T.
This is useful to convert models to TypedDicts for use in REST APIs.
"""
annotations = t.__annotations__
dct: dict[str, typing.Any] = {}
NOT_FOUND = object() # Sentinel for not found values
for field, field_type in annotations.items():
# Skip "typing.NotRequired" fields
if typing.get_origin(field_type) is typing.NotRequired:
continue
if hasattr(model, field):
value: typing.Any = getattr(model, field, NOT_FOUND)
if value is NOT_FOUND:
logger.warning(
'Field %s not found in model %s, using default value',
field,
model.__class__.__name__,
)
continue
# Note, currently we do no convert the types, and do not support complex types
dct[field] = value
# Ensure that the dictionary is compatible with the TypedDict
return typing.cast('types.rest.T_Item', dct)

View File

@@ -38,7 +38,6 @@ import ssl
import typing
import datetime
import certifi
import requests
import requests.adapters
import urllib3
@@ -210,7 +209,7 @@ def secure_requests_session(*, verify: 'str|bool' = True, proxies: 'dict[str, st
# See urllib3.poolmanager.SSL_KEYWORDS for all available keys.
self._ssl_context = kwargs['ssl_context'] = create_client_sslcontext(verify=verify is True)
return super().init_poolmanager(*args, **kwargs) # type: ignore
return super().init_poolmanager(*args, **kwargs) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
def cert_verify(self, conn: typing.Any, url: typing.Any, verify: 'str|bool', cert: typing.Any) -> None:
"""Verify a SSL certificate. This method should not be called from user
@@ -231,7 +230,7 @@ def secure_requests_session(*, verify: 'str|bool' = True, proxies: 'dict[str, st
# conn_kw = conn.__dict__['conn_kw']
# conn_kw['ssl_context'] = self.ssl_context
super().cert_verify(conn, url, verify, cert) # type: ignore
super().cert_verify(conn, url, verify, cert) # pyright: ignore[reportUnknownMemberType]
session = requests.Session()
session.mount("https://", UDSHTTPAdapter())

View File

@@ -25,6 +25,7 @@
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# mypy: disable-error-code="attr-defined"
"""
Author: Adolfo Gómez, dkmaster at dkmon dot com
"""