1
0
mirror of https://github.com/dkmstr/openuds.git synced 2024-12-25 23:21:41 +03:00

Improvements:

* Moving some named tuples to dataclasses
* Added unique_id to service token alias (to avoid recreation...)
This commit is contained in:
Adolfo Gómez García 2024-01-01 20:08:50 +01:00
parent 48be614e20
commit 499c5f8ec4
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
21 changed files with 131 additions and 66 deletions

View File

@ -411,8 +411,17 @@ class Initialize(ActorV3Action):
if not service:
service = typing.cast('Service', Service.objects.get(token=token))
# If exists, create and alias for it
alias_token = CryptoManager().randomString(40) # fix alias with new token
service.aliases.create(alias=alias_token)
# Get first mac and, if not exists, get first ip
unique_id = self._params['id'][0].get('mac', self._params['id'][0].get('ip', ''))
if unique_id is None:
raise BlockAccess()
# If exists, do not create a new one (avoid creating for old 3.x actors lots of aliases...)
if not ServiceTokenAlias.objects.filter(service=service, unique_id=unique_id).exists():
alias_token = CryptoManager().randomString(40) # fix alias with new token
service.aliases.create(alias=alias_token, unique_id=unique_id)
else:
# If exists, get existing one
alias_token = ServiceTokenAlias.objects.get(service=service, unique_id=unique_id).alias
# Locate an userService that belongs to this service and which
# Build the possible ids and make initial filter to match service

View File

@ -331,7 +331,7 @@ class ServersGroups(ModelHandler):
for i in types.servers.ServerSubtype.manager().enum():
v = types.rest.TypeInfo(
name=i.description, type=f'{i.type.name}@{i.subtype}', description='', icon=i.icon
).asDict(group=gettext('Managed') if i.managed else gettext('Unmanaged'))
).as_dict(group=gettext('Managed') if i.managed else gettext('Unmanaged'))
yield v
def getGui(self, type_: str) -> list[typing.Any]:

View File

@ -269,7 +269,7 @@ class ServicesPools(ModelHandler):
val['pool_group_id'] = poolGroupId
val['pool_group_name'] = poolGroupName
val['pool_group_thumb'] = poolGroupThumb
val['usage'] = str(item.usage(usage_count)[0]) + '%'
val['usage'] = str(item.usage(usage_count).percent) + '%'
if item.osmanager:
val['osmanager_id'] = item.osmanager.uuid

View File

@ -287,7 +287,7 @@ class BaseModelHandler(Handler):
type=type_.getType(),
description=_(type_.description()),
icon=type_.icon64().replace('\n', ''),
).asDict(**self.typeInfo(type_))
).as_dict(**self.typeInfo(type_))
if hasattr(type_, 'group'):
res['group'] = _(type_.group) # Add group info is it is contained
return res

View File

@ -224,7 +224,7 @@ class ServerManager(metaclass=singleton.Singleton):
excludeServersUUids = excludeServersUUids or set()
with serverGroup.properties as props:
info: typing.Optional[types.servers.ServerCounter] = types.servers.ServerCounter.fromIterable(
info: typing.Optional[types.servers.ServerCounter] = types.servers.ServerCounter.from_iterable(
props.get(prop_name)
)
# If server is forced, and server is part of the group, use it
@ -315,7 +315,7 @@ class ServerManager(metaclass=singleton.Singleton):
userUuid = userUuid if userUuid else userService.user.uuid if userService.user else None
if userUuid is None:
return types.servers.ServerCounter.empty() # No user is assigned to this service, nothing to do
return types.servers.ServerCounter.null() # No user is assigned to this service, nothing to do
prop_name = self.propertyName(userService.user)
with serverGroup.properties as props:
@ -325,10 +325,10 @@ class ServerManager(metaclass=singleton.Singleton):
serverCounter: typing.Optional[
types.servers.ServerCounter
] = types.servers.ServerCounter.fromIterable(props.get(prop_name))
] = types.servers.ServerCounter.from_iterable(props.get(prop_name))
# If no cached value, get server assignation
if serverCounter is None:
return types.servers.ServerCounter.empty()
return types.servers.ServerCounter.null()
# Ensure counter is at least 1
serverCounter = types.servers.ServerCounter(
serverCounter.server_uuid, max(1, serverCounter.counter)
@ -439,7 +439,7 @@ class ServerManager(metaclass=singleton.Singleton):
prop_name = self.propertyName(userService.user)
with serverGroup.properties as props:
info: typing.Optional[types.servers.ServerCounter] = types.servers.ServerCounter.fromIterable(
info: typing.Optional[types.servers.ServerCounter] = types.servers.ServerCounter.from_iterable(
props.get(prop_name)
)
if info is None:

View File

@ -148,7 +148,7 @@ def process_logout(server: 'models.Server', data: dict[str, typing.Any]) -> typi
def process_ping(server: 'models.Server', data: dict[str, typing.Any]) -> typing.Any:
if 'stats' in data:
server.stats = types.servers.ServerStats.fromDict(data['stats'])
server.stats = types.servers.ServerStats.from_dict(data['stats'])
# Set stats on server
server.last_ping = getSqlDatetime()

View File

@ -189,7 +189,7 @@ class ServerApiRequester:
userservice_uuid=userService.uuid,
service_type=service_type,
assignations=count,
).asDict(),
).as_dict(),
)
return True
@ -223,7 +223,7 @@ class ServerApiRequester:
udsuser_uuid=userService.user.uuid if userService.user else '',
userservice_uuid=userService.uuid,
service_type=info.service_type,
).asDict(),
).as_dict(),
)
return True
@ -233,7 +233,7 @@ class ServerApiRequester:
Notifies removal of user service to server
"""
logger.debug('Notifying release of service %s to server %s', userService.uuid, self.server.host)
self.post('release', types.connections.ReleaseRequest(userservice_uuid=userService.uuid).asDict())
self.post('release', types.connections.ReleaseRequest(userservice_uuid=userService.uuid).as_dict())
return True
@ -258,5 +258,5 @@ class ServerApiRequester:
return None
# Will store stats on property, so no save is needed
self.server.stats = types.servers.ServerStats.fromDict(data)
self.server.stats = types.servers.ServerStats.from_dict(data)
return self.server.stats

View File

@ -139,7 +139,7 @@ def notifyPreconnect(userService: 'UserService', info: types.connections.Connect
udsuser_uuid=userService.user.uuid if userService.user else '',
userservice_uuid=userService.uuid,
service_type=info.service_type,
).asDict(),
).as_dict(),
)

View File

@ -30,13 +30,15 @@
Author: Adolfo Gómez, dkmaster at dkmon dot com
"""
import typing
import dataclasses
import collections.abc
from .services import ServiceType
# For requests to actors/servers
class PreconnectRequest(typing.NamedTuple):
@dataclasses.dataclass(frozen=True)
class PreconnectRequest:
"""Information sent on a preconnect request"""
udsuser: str # UDS user name
@ -49,12 +51,13 @@ class PreconnectRequest(typing.NamedTuple):
ip: str # IP of the client
hostname: str # Hostname of the client
def asDict(self) -> dict[str, str]:
return self._asdict()
def as_dict(self) -> dict[str, str]:
return dataclasses.asdict(self)
# For requests to actors/servers
class AssignRequest(typing.NamedTuple):
@dataclasses.dataclass(frozen=True)
class AssignRequest:
"""Information sent on a assign request"""
udsuser: str
@ -64,17 +67,22 @@ class AssignRequest(typing.NamedTuple):
assignations: int # Number of times this service has been assigned
def asDict(self) -> dict[str, 'str|int']:
return self._asdict()
def as_dict(self) -> dict[str, 'str|int']:
return dataclasses.asdict(self)
class ReleaseRequest(typing.NamedTuple):
@dataclasses.dataclass(frozen=True)
class ReleaseRequest:
"""Information sent on a release request"""
userservice_uuid: str # UUID of userservice
def asDict(self) -> dict[str, str]:
return self._asdict()
def as_dict(self) -> dict[str, str]:
return dataclasses.asdict(self)
class ConnectionData(typing.NamedTuple):
@dataclasses.dataclass(frozen=True)
class ConnectionData:
"""
Connection data provided by transports, and contains all the "transformable" information needed to connect to a service
(such as username, password, domain, etc..)
@ -91,11 +99,11 @@ class ConnectionData(typing.NamedTuple):
# sso: bool = False # For future sso implementation
def asDict(self) -> dict[str, str]:
return self._asdict()
def as_dict(self) -> dict[str, str]:
return dataclasses.asdict(self)
class ConnectionSource(typing.NamedTuple):
@dataclasses.dataclass(frozen=True)
class ConnectionSource:
"""
Connection source from where the connection is being done
"""
@ -103,5 +111,5 @@ class ConnectionSource(typing.NamedTuple):
ip: str # IP of the client
hostname: str # Hostname of the client
def asDict(self) -> dict[str, str]:
return self._asdict()
def as_dict(self) -> dict[str, str]:
return dataclasses.asdict(self)

View File

@ -31,6 +31,7 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com
"""
import enum
import typing
import dataclasses
import collections.abc
from django.utils.translation import gettext as _
@ -84,8 +85,8 @@ class HighAvailabilityPolicy(enum.IntEnum):
(HighAvailabilityPolicy.ENABLED, _('Enabled')),
]
class UsageInfo(typing.NamedTuple):
@dataclasses.dataclass(frozen=True)
class UsageInfo:
used: int
total: int

View File

@ -30,15 +30,17 @@
Author: Adolfo Gómez, dkmaster at dkmon dot com
"""
import typing
import dataclasses
import collections.abc
class TypeInfo(typing.NamedTuple):
@dataclasses.dataclass(frozen=True)
class TypeInfo:
name: str
type: str
description: str
icon: str
def asDict(self, **extra) -> dict[str, typing.Any]:
def as_dict(self, **extra) -> dict[str, typing.Any]:
return {
'name': self.name,
'type': self.type,
@ -46,7 +48,9 @@ class TypeInfo(typing.NamedTuple):
'icon': self.icon,
**extra
}
# This is a named tuple for convenience, and must be
# compatible with tuple[str, bool] (name, needs_parent)
class ModelCustomMethod(typing.NamedTuple):
name: str
needs_parent: bool = True

View File

@ -113,6 +113,23 @@ ServerSubtype.manager().register(
ServerType.UNMANAGED, 'ip', 'Unmanaged IP Server', consts.images.DEFAULT_IMAGE_BASE64, False
)
@dataclasses.dataclass(frozen=True)
class ServerDiskInfo:
mountpoint: str
used: int
total: int
@staticmethod
def from_dict(data: dict[str, typing.Any]) -> 'ServerDiskInfo':
return ServerDiskInfo(data['mountpoint'], data['used'], data['total'])
def as_dict(self) -> dict[str, typing.Any]:
return {
'mountpoint': self.mountpoint,
'used': self.used,
'total': self.total,
}
@dataclasses.dataclass
class ServerStats:
@ -120,7 +137,7 @@ class ServerStats:
memtotal: int = 0 # In bytes
cpuused: float = 0 # 0-1 (cpu usage)
uptime: int = 0 # In seconds
disks: list[tuple[str, int, int]] = dataclasses.field(
disks: list[ServerDiskInfo] = dataclasses.field(
default_factory=list
) # List of tuples (mountpoint, used, total)
connections: int = 0 # Number of connections
@ -192,14 +209,14 @@ class ServerStats:
)
@staticmethod
def fromDict(data: collections.abc.Mapping[str, typing.Any], **kwargs: typing.Any) -> 'ServerStats':
def from_dict(data: collections.abc.Mapping[str, typing.Any], **kwargs: typing.Any) -> 'ServerStats':
from uds.core.util.model import getSqlStamp # Avoid circular import
dct = {k: v for k, v in data.items()} # Make a copy
dct.update(kwargs) # and update with kwargs
disks: list[tuple[str, int, int]] = []
disks: list[ServerDiskInfo] = []
for disk in dct.get('disks', []):
disks.append((disk['mountpoint'], disk['used'], disk['total']))
disks.append(ServerDiskInfo.from_dict(disk))
return ServerStats(
memused=dct.get('memused', 1),
memtotal=dct.get('memtotal') or 1, # Avoid division by zero
@ -211,11 +228,17 @@ class ServerStats:
stamp=dct.get('stamp', getSqlStamp()),
)
def asDict(self) -> dict[str, typing.Any]:
data = self._asdict()
# Replace disk as dicts
data['disks'] = [{'mountpoint': d[0], 'used': d[1], 'total': d[2]} for d in self.disks]
return data
def as_dict(self) -> dict[str, typing.Any]:
return {
'memused': self.memused,
'memtotal': self.memtotal,
'cpuused': self.cpuused,
'uptime': self.uptime,
'disks': [d.as_dict() for d in self.disks],
'connections': self.connections,
'current_users': self.current_users,
'stamp': self.stamp,
}
@staticmethod
def empty() -> 'ServerStats':
@ -231,11 +254,11 @@ class ServerCounter(typing.NamedTuple):
counter: int
@staticmethod
def fromIterable(data: typing.Optional[collections.abc.Iterable]) -> typing.Optional['ServerCounter']:
def from_iterable(data: typing.Optional[collections.abc.Iterable]) -> typing.Optional['ServerCounter']:
if data is None:
return None
return ServerCounter(*data)
@staticmethod
def empty() -> 'ServerCounter':
def null() -> 'ServerCounter':
return ServerCounter('', 0)

View File

@ -138,6 +138,6 @@ class FieldInfo:
fills: typing.Optional[Filler] = None
rows: typing.Optional[int] = None
def asDict(self) -> dict[str, typing.Any]:
def as_dict(self) -> dict[str, typing.Any]:
"""Returns a dict with all fields that are not None"""
return {k: v for k, v in dataclasses.asdict(self).items() if v is not None}

View File

@ -375,7 +375,7 @@ class gui:
and don't want to
alter original values.
"""
data = typing.cast(dict, self._fieldsInfo.asDict())
data = typing.cast(dict, self._fieldsInfo.as_dict())
if 'value' in data:
del data['value'] # We don't want to send value on guiDescription
data['label'] = _(data['label']) if data['label'] else ''

View File

@ -32,6 +32,7 @@
"""
import logging
import typing
import dataclasses
import collections.abc
import enum
@ -43,9 +44,10 @@ if typing.TYPE_CHECKING:
logger = logging.getLogger(__name__)
class ObjTypeInfo(typing.NamedTuple):
type: int
model: type['Model']
@dataclasses.dataclass(frozen=True)
class ObjTypeInfo:
obj_type: int
model: 'type[Model]'
@enum.unique
class ObjectType(enum.Enum):
@ -82,7 +84,7 @@ class ObjectType(enum.Enum):
@property
def type(self) -> int:
"""Returns the integer value of this object type. (The "type" id)"""
return self.value.type
return self.value.obj_type
@staticmethod
def from_model(model: 'Model') -> 'ObjectType':
@ -110,4 +112,4 @@ class ObjectType(enum.Enum):
>>> ObjectType.PROVIDER == 2
False
"""
return super().__eq__(__o) or self.value.type == __o
return super().__eq__(__o) or self.value.obj_type == __o

View File

@ -341,4 +341,9 @@ class Migration(migrations.Migration):
max_length=32,
),
),
migrations.AddField(
model_name="servicetokenalias",
name="unique_id",
field=models.CharField(db_index=True, default="", max_length=128),
),
]

View File

@ -161,6 +161,8 @@ class MetaPool(UUIDModel, TaggingMixin): # type: ignore
Returns the % used services, then count and the max related to "maximum" user services
If no "maximum" number of services, will return 0% ofc
cachedValue is used to optimize (if known the number of assigned services, we can avoid to query the db)
Note:
No metapoools, cachedValue is ignored, but keep for consistency with servicePool
"""
# If no pools, return 0%
if self.members.count() == 0:

View File

@ -244,7 +244,7 @@ class Server(UUIDModel, TaggingMixin, properties.PropertiesMixin):
"""Returns the current stats of this server, or None if not available"""
statsDct = self.properties.get('stats', None)
if statsDct:
return types.servers.ServerStats.fromDict(statsDct)
return types.servers.ServerStats.from_dict(statsDct)
return None
@stats.setter
@ -254,7 +254,7 @@ class Server(UUIDModel, TaggingMixin, properties.PropertiesMixin):
del self.properties['stats']
else:
# Set stamp to current time and save it, overwriting existing stamp if any
statsDict = value.asDict()
statsDict = value.as_dict()
statsDict['stamp'] = getSqlStamp()
self.properties['stats'] = statsDict
@ -263,14 +263,14 @@ class Server(UUIDModel, TaggingMixin, properties.PropertiesMixin):
stats = self.stats
if stats and stats.is_valid: # If rae invalid, do not waste time recalculating
# Avoid replacing current "stamp" value, this is just a "simulation"
self.properties['stats'] = stats.adjust(users_increment=1).asDict()
self.properties['stats'] = stats.adjust(users_increment=1).as_dict()
def newRelease(self) -> None:
"""Simulates, with current stats, the release of a user"""
stats = self.stats
if stats and stats.is_valid:
# Avoid replacing current "stamp" value, this is just a "simulation"
self.properties['stats'] = stats.adjust(users_increment=-1).asDict()
self.properties['stats'] = stats.adjust(users_increment=-1).as_dict()
def isRestrained(self) -> bool:
"""Returns if this server is restrained or not

View File

@ -63,6 +63,7 @@ class ServiceTokenAlias(models.Model):
service = models.ForeignKey('Service', on_delete=models.CASCADE, related_name='aliases')
alias = models.CharField(max_length=64, unique=True)
unique_id = models.CharField(max_length=128, default='', db_index=True) # Used to locate an already created alias for a userService and service
def __str__(self) -> str:
return str(self.alias) # pylint complains about CharField

View File

@ -147,11 +147,13 @@ class ActorInitializeTest(rest.test.RESTActorTestCase):
success = functools.partial(self.invoke_success, 'unmanaged')
failure = functools.partial(self.invoke_failure, 'unmanaged')
TEST_MAC: typing.Final[str] = '00:00:00:00:00:00'
# This will succeed, but only alias token is returned because MAC is not registered by UDS
result = success(
actor_token,
mac='00:00:00:00:00:00',
mac=TEST_MAC,
)
# Unmanaged host is the response for initialization of unmanaged actor ALWAYS
@ -159,9 +161,17 @@ class ActorInitializeTest(rest.test.RESTActorTestCase):
self.assertEqual(result['token'], result['own_token'])
self.assertIsNone(result['unique_id'])
self.assertIsNone(result['os'])
# Store alias token for later tests
alias_token = result['token']
# If repeated, same token is returned
result = success(
actor_token,
mac=TEST_MAC,
)
self.assertEqual(result['token'], alias_token)
# Now, invoke a "nice" initialize
result = success(

View File

@ -81,12 +81,12 @@ class ServerEventsPingTest(rest.test.RESTTestCase):
cpuused=random.random(), # nosec: test data
uptime=random.randint(0, 1000000), # nosec: test data
disks=[
(
types.servers.ServerDiskInfo(
'c:\\',
random.randint(0, 100000000), # nosec: test data
random.randint(100000000, 1000000000), # nosec: test data
),
(
types.servers.ServerDiskInfo(
'd:\\',
random.randint(0, 100000000), # nosec: test data
random.randint(100000000, 1000000000), # nosec: test data
@ -101,7 +101,7 @@ class ServerEventsPingTest(rest.test.RESTTestCase):
data={
'token': self.server.token,
'type': 'ping',
'stats': stats.asDict(),
'stats': stats.as_dict(),
},
)
@ -110,15 +110,15 @@ class ServerEventsPingTest(rest.test.RESTTestCase):
server_stats = self.server.properties.get('stats', None)
self.assertIsNotNone(server_stats)
# Get stats, but clear stamp
statsResponse = types.servers.ServerStats.fromDict(server_stats, stamp=0)
statsResponse = types.servers.ServerStats.from_dict(server_stats, stamp=0)
self.assertEqual(statsResponse, stats)
# Ensure that stamp is not 0 on server_stats dict
self.assertNotEqual(server_stats['stamp'], 0)
# Ensure stat is valid right now
statsResponse = types.servers.ServerStats.fromDict(server_stats)
statsResponse = types.servers.ServerStats.from_dict(server_stats)
self.assertTrue(statsResponse.is_valid)
statsResponse = types.servers.ServerStats.fromDict(server_stats, stamp=getSqlStamp() - consts.system.DEFAULT_CACHE_TIMEOUT - 1)
statsResponse = types.servers.ServerStats.from_dict(server_stats, stamp=getSqlStamp() - consts.system.DEFAULT_CACHE_TIMEOUT - 1)
self.assertFalse(statsResponse.is_valid)
def test_event_ping_without_stats(self) -> None: