1
0
mirror of https://github.com/dkmstr/openuds.git synced 2025-01-12 09:17:56 +03:00

Some improvements to basic types and minor cosmetic fixes

This commit is contained in:
Adolfo Gómez García 2024-01-02 03:28:12 +01:00
parent 499c5f8ec4
commit 150d8c4197
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
18 changed files with 73 additions and 71 deletions

2
actor

@ -1 +1 @@
Subproject commit 46088958fb5e883bda4c516244e755e02f83b862
Subproject commit e8a3b41cade6001bedb8ecdb8ef34654528f4f70

View File

@ -159,7 +159,7 @@ class Services(DetailHandler): # pylint: disable=too-many-public-methods
)
# Fix max_services_count_type to ServicesCountingType enum or ServicesCountingType.STANDARD if not found
try:
fields['max_services_count_type'] = types.services.ServicesCountingType.fromInt(int(fields['max_services_count_type']))
fields['max_services_count_type'] = types.services.ServicesCountingType.from_int(int(fields['max_services_count_type']))
except Exception:
fields['max_services_count_type'] = types.services.ServicesCountingType.STANDARD
tags = fields['tags']

View File

@ -35,6 +35,7 @@ import hashlib
import secrets
import string
import typing
import dataclasses
import collections.abc
import datetime
import urllib.parse
@ -43,7 +44,6 @@ from base64 import b64decode
import defusedxml.ElementTree as etree
import jwt
import requests
from cryptography.x509 import load_pem_x509_certificate
from django.utils.translation import gettext
from django.utils.translation import gettext_noop as _
@ -64,8 +64,8 @@ PKCE_ALPHABET: typing.Final[str] = string.ascii_letters + string.digits + '-._~'
# Length of the State parameter
STATE_LENGTH: typing.Final[int] = 16
class TokenInfo(typing.NamedTuple):
@dataclasses.dataclass
class TokenInfo:
access_token: str
token_type: str
expires: datetime.datetime
@ -75,16 +75,16 @@ class TokenInfo(typing.NamedTuple):
id_token: typing.Optional[str]
@staticmethod
def fromJson(json: dict[str, typing.Any]) -> 'TokenInfo':
def from_dict(dct: dict[str, typing.Any]) -> 'TokenInfo':
# expires is -10 to avoid problems with clock sync
return TokenInfo(
access_token=json['access_token'],
token_type=json['token_type'],
expires=model.getSqlDatetime() + datetime.timedelta(seconds=json['expires_in'] - 10),
refresh_token=json['refresh_token'],
scope=json['scope'],
info=json.get('info', {}),
id_token=json.get('id_token', None),
access_token=dct['access_token'],
token_type=dct['token_type'],
expires=model.getSqlDatetime() + datetime.timedelta(seconds=dct['expires_in'] - 10),
refresh_token=dct['refresh_token'],
scope=dct['scope'],
info=dct.get('info', {}),
id_token=dct.get('id_token', None),
)
@ -340,7 +340,7 @@ class OAuth2Authenticator(auths.Authenticator):
if not req.ok:
raise Exception('Error requesting token: {}'.format(req.text))
return TokenInfo.fromJson(req.json())
return TokenInfo.from_dict(req.json())
def _requestInfo(self, token: 'TokenInfo') -> dict[str, typing.Any]:
"""Request user info from the info endpoint using the token received from the token endpoint

View File

@ -1,3 +1,4 @@
import dataclasses
import io
import logging
import enum
@ -69,8 +70,8 @@ class RadiusStates(enum.IntEnum):
NOT_NEEDED = INCORRECT
NEEDED = CORRECT
class RadiusResult(typing.NamedTuple):
@dataclasses.dataclass
class RadiusResult:
"""
Result of an AccessChallenge request.
"""

View File

@ -278,7 +278,7 @@ class ServerManager(metaclass=singleton.Singleton):
except exceptions.UDSException: # No more servers
return None
elif lockTime: # If lockTime is set, update it
models.Server.objects.filter(uuid=info[0]).update(locked_until=now + lockTime)
models.Server.objects.filter(uuid=info.server_uuid).update(locked_until=now + lockTime)
# Notify to server
# Update counter
@ -339,7 +339,7 @@ class ServerManager(metaclass=singleton.Singleton):
else: # Not last one, just decrement counter
props[prop_name] = (serverCounter.server_uuid, serverCounter.counter - 1)
server = models.Server.objects.get(uuid=serverCounter[0])
server = models.Server.objects.get(uuid=serverCounter.server_uuid)
if unlock or serverCounter.counter == 1:
server.locked_until = None # Ensure server is unlocked if no more users are assigned to it

View File

@ -75,7 +75,7 @@ class AuthenticationResult:
FAILED_AUTH = AuthenticationResult(success=AuthenticationState.FAIL)
SUCCESS_AUTH = AuthenticationResult(success=AuthenticationState.SUCCESS)
@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass
class AuthCallbackParams:
'''Parameters passed to auth callback stage2
@ -91,7 +91,7 @@ class AuthCallbackParams:
query_string: str
@staticmethod
def fromRequest(request: 'HttpRequest') -> 'AuthCallbackParams':
def from_request(request: 'HttpRequest') -> 'AuthCallbackParams':
return AuthCallbackParams(
https=request.is_secure(),
host=request.META['HTTP_HOST'],

View File

@ -37,7 +37,7 @@ from .services import ServiceType
# For requests to actors/servers
@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass
class PreconnectRequest:
"""Information sent on a preconnect request"""
@ -56,7 +56,7 @@ class PreconnectRequest:
# For requests to actors/servers
@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass
class AssignRequest:
"""Information sent on a assign request"""
@ -71,7 +71,7 @@ class AssignRequest:
return dataclasses.asdict(self)
@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass
class ReleaseRequest:
"""Information sent on a release request"""
@ -81,7 +81,7 @@ class ReleaseRequest:
return dataclasses.asdict(self)
@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass
class ConnectionData:
"""
Connection data provided by transports, and contains all the "transformable" information needed to connect to a service
@ -102,7 +102,7 @@ class ConnectionData:
def as_dict(self) -> dict[str, str]:
return dataclasses.asdict(self)
@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass
class ConnectionSource:
"""
Connection source from where the connection is being done

View File

@ -35,14 +35,11 @@ import enum
import typing
import collections.abc
@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass
class DetectedOsInfo:
os: 'KnownOS'
browser: 'KnownBrowser'
version: str
def replace(self, **kwargs):
return dataclasses.replace(self, **kwargs)
version: str
class KnownOS(enum.Enum):

View File

@ -59,7 +59,7 @@ class TransportSelectionPolicy(enum.IntEnum):
COMMON = 1
LABEL = 2
def asStr(self) -> str:
def as_str(self) -> str:
return self.name.lower()
@staticmethod
@ -75,7 +75,7 @@ class HighAvailabilityPolicy(enum.IntEnum):
DISABLED = 0
ENABLED = 1
def asStr(self) -> str:
def as_str(self) -> str:
return str(self)
@staticmethod

View File

@ -53,7 +53,7 @@ class CommonPrefs:
BYPASS_PREF = 'bypassPluginDetection'
@staticmethod
def getWidthHeight(size: str) -> tuple[int, int]:
def get_wh(size: str) -> tuple[int, int]:
"""
Get width based on screenSizePref value
"""
@ -67,7 +67,7 @@ class CommonPrefs:
}.get(size, (1024, 768))
@staticmethod
def getDepth(depth: str) -> int:
def get_depth(depth: str) -> int:
"""
Get depth based on depthPref value
"""

View File

@ -33,7 +33,7 @@ import typing
import dataclasses
import collections.abc
@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass
class TypeInfo:
name: str
type: str

View File

@ -113,6 +113,7 @@ ServerSubtype.manager().register(
ServerType.UNMANAGED, 'ip', 'Unmanaged IP Server', consts.images.DEFAULT_IMAGE_BASE64, False
)
@dataclasses.dataclass(frozen=True)
class ServerDiskInfo:
mountpoint: str
@ -248,7 +249,8 @@ class ServerStats:
# Human readable
return f'memory: {self.memused//(1024*1024)}/{self.memtotal//(1024*1024)} cpu: {self.cpuused*100} users: {self.current_users}, weight: {self.weight()}, valid: {self.is_valid}'
# ServerCounter must be serializable by json, so
# we keep it as a NamedTuple instead of a dataclass
class ServerCounter(typing.NamedTuple):
server_uuid: str
counter: int

View File

@ -36,7 +36,7 @@ class ServiceType(enum.StrEnum):
VDI = 'VDI'
VAPP = 'VAPP'
def fromStr(self, value: str) -> 'ServiceType':
def from_str(self, value: str) -> 'ServiceType':
"""Returns the service type from a string"""
return ServiceType(value.upper())
@ -48,7 +48,7 @@ class ServicesCountingType(enum.IntEnum):
CONSERVATIVE = 1
@staticmethod
def fromInt(value: int) -> 'ServicesCountingType':
def from_int(value: int) -> 'ServicesCountingType':
"""Returns the MaxServiceCountingMethodType from an int
If the int is not a valid value, returns STANDARD
"""
@ -58,7 +58,7 @@ class ServicesCountingType(enum.IntEnum):
return ServicesCountingType.STANDARD
@staticmethod
def fromStr(value: str) -> 'ServicesCountingType':
def from_str(value: str) -> 'ServicesCountingType':
"""Returns the MaxServiceCountingMethodType from an str
If the str is not a valid value, returns STANDARD
"""

View File

@ -42,40 +42,41 @@ from uds import models
if typing.TYPE_CHECKING:
from django.db.models import Model
logger = logging.getLogger(__name__)
@dataclasses.dataclass(frozen=True)
class ObjTypeInfo:
@dataclasses.dataclass
class _ObjTypeInfo:
obj_type: int
model: 'type[Model]'
@enum.unique
class ObjectType(enum.Enum):
PROVIDER = ObjTypeInfo(1, models.Provider)
SERVICE = ObjTypeInfo(2, models.Service)
OSMANAGER = ObjTypeInfo(3, models.OSManager)
TRANSPORT = ObjTypeInfo(4, models.Transport)
NETWORK = ObjTypeInfo(5, models.Network)
POOL = ObjTypeInfo(6, models.ServicePool)
USER_SERVICE = ObjTypeInfo(7, models.UserService)
AUTHENTICATOR = ObjTypeInfo(8, models.Authenticator)
USER = ObjTypeInfo(9, models.User)
GROUP = ObjTypeInfo(10, models.Group)
STATS_COUNTER = ObjTypeInfo(11, models.StatsCounters)
STATS_EVENTS = ObjTypeInfo(12, models.StatsEvents)
CALENDAR = ObjTypeInfo(13, models.Calendar)
CALENDAR_RULE = ObjTypeInfo(14, models.CalendarRule)
METAPOOL = ObjTypeInfo(15, models.MetaPool)
ACCOUNT = ObjTypeInfo(16, models.Account)
PROVIDER = _ObjTypeInfo(1, models.Provider)
SERVICE = _ObjTypeInfo(2, models.Service)
OSMANAGER = _ObjTypeInfo(3, models.OSManager)
TRANSPORT = _ObjTypeInfo(4, models.Transport)
NETWORK = _ObjTypeInfo(5, models.Network)
POOL = _ObjTypeInfo(6, models.ServicePool)
USER_SERVICE = _ObjTypeInfo(7, models.UserService)
AUTHENTICATOR = _ObjTypeInfo(8, models.Authenticator)
USER = _ObjTypeInfo(9, models.User)
GROUP = _ObjTypeInfo(10, models.Group)
STATS_COUNTER = _ObjTypeInfo(11, models.StatsCounters)
STATS_EVENTS = _ObjTypeInfo(12, models.StatsEvents)
CALENDAR = _ObjTypeInfo(13, models.Calendar)
CALENDAR_RULE = _ObjTypeInfo(14, models.CalendarRule)
METAPOOL = _ObjTypeInfo(15, models.MetaPool)
ACCOUNT = _ObjTypeInfo(16, models.Account)
# Actor and Tunnel tokens are now on REGISTERED_SERVER, so removed
MFA = ObjTypeInfo(19, models.MFA)
REGISTERED_SERVER = ObjTypeInfo(20, models.Server)
REGISTERED_SERVER_GROUP = ObjTypeInfo(21, models.ServerGroup)
ACCOUNT_USAGE = ObjTypeInfo(22, models.AccountUsage)
IMAGE = ObjTypeInfo(23, models.Image)
LOG = ObjTypeInfo(24, models.Log)
NOTIFICATION = ObjTypeInfo(25, models.Notification)
TICKET_STORE = ObjTypeInfo(26, models.TicketStore)
MFA = _ObjTypeInfo(19, models.MFA)
REGISTERED_SERVER = _ObjTypeInfo(20, models.Server)
REGISTERED_SERVER_GROUP = _ObjTypeInfo(21, models.ServerGroup)
ACCOUNT_USAGE = _ObjTypeInfo(22, models.AccountUsage)
IMAGE = _ObjTypeInfo(23, models.Image)
LOG = _ObjTypeInfo(24, models.Log)
NOTIFICATION = _ObjTypeInfo(25, models.Notification)
TICKET_STORE = _ObjTypeInfo(26, models.TicketStore)
@property
def model(self) -> type['Model']:

View File

@ -74,14 +74,14 @@ def detect_os(
# If we found a known OS, store it
if found != types.os.KnownOS.UNKNOWN:
res = res.replace(os=found)
res.os = found
# Try to detect browser from Sec-Ch-Ua first
secChUa = headers.get('Sec-Ch-Ua')
if secChUa is not None:
for browser in consts.os.knownBrowsers:
if browser in secChUa:
res = res.replace(browser=browser)
res.browser = browser
break
else:
# Try to detect browser from User-Agent
@ -109,7 +109,8 @@ def detect_os(
break
if match is not None:
res = res.replace(browser=ruleKey, version=match.groups(1)[0])
res.browser = ruleKey or types.os.KnownBrowser.OTHER
res.version = match.groups(1)[0]
logger.debug('Detected: %s %s', res.os, res.browser)

View File

@ -163,7 +163,7 @@ class Service(ManagedObjectModel, TaggingMixin): # type: ignore
@property
def maxServicesCountType(self) -> ServicesCountingType:
return ServicesCountingType.fromInt(self.max_services_count_type)
return ServicesCountingType.from_int(self.max_services_count_type)
def isInMaintenance(self) -> bool:
# orphaned services?

View File

@ -217,7 +217,7 @@ class BaseX2GOTransport(transports.Transport):
return ready == 'Y'
def getScreenSize(self) -> tuple[int, int]:
return CommonPrefs.getWidthHeight(self.screenSize.value)
return CommonPrefs.get_wh(self.screenSize.value)
def processedUser(self, userService: 'models.UserService', user: 'models.User') -> str:
v = self.processUserPassword(userService, user, '')

View File

@ -87,7 +87,7 @@ def authCallback(request: HttpRequest, authName: str) -> HttpResponse:
if not authenticator:
raise Exception('Authenticator not found')
params = types.auth.AuthCallbackParams.fromRequest(request)
params = types.auth.AuthCallbackParams.from_request(request)
logger.debug('Auth callback for %s with params %s', authenticator, params)