mirror of
https://github.com/dkmstr/openuds.git
synced 2025-01-12 09:17:56 +03:00
Some improvements to basic types and minor cosmetic fixes
This commit is contained in:
parent
499c5f8ec4
commit
150d8c4197
2
actor
2
actor
@ -1 +1 @@
|
||||
Subproject commit 46088958fb5e883bda4c516244e755e02f83b862
|
||||
Subproject commit e8a3b41cade6001bedb8ecdb8ef34654528f4f70
|
@ -159,7 +159,7 @@ class Services(DetailHandler): # pylint: disable=too-many-public-methods
|
||||
)
|
||||
# Fix max_services_count_type to ServicesCountingType enum or ServicesCountingType.STANDARD if not found
|
||||
try:
|
||||
fields['max_services_count_type'] = types.services.ServicesCountingType.fromInt(int(fields['max_services_count_type']))
|
||||
fields['max_services_count_type'] = types.services.ServicesCountingType.from_int(int(fields['max_services_count_type']))
|
||||
except Exception:
|
||||
fields['max_services_count_type'] = types.services.ServicesCountingType.STANDARD
|
||||
tags = fields['tags']
|
||||
|
@ -35,6 +35,7 @@ import hashlib
|
||||
import secrets
|
||||
import string
|
||||
import typing
|
||||
import dataclasses
|
||||
import collections.abc
|
||||
import datetime
|
||||
import urllib.parse
|
||||
@ -43,7 +44,6 @@ from base64 import b64decode
|
||||
import defusedxml.ElementTree as etree
|
||||
import jwt
|
||||
import requests
|
||||
from cryptography.x509 import load_pem_x509_certificate
|
||||
from django.utils.translation import gettext
|
||||
from django.utils.translation import gettext_noop as _
|
||||
|
||||
@ -64,8 +64,8 @@ PKCE_ALPHABET: typing.Final[str] = string.ascii_letters + string.digits + '-._~'
|
||||
# Length of the State parameter
|
||||
STATE_LENGTH: typing.Final[int] = 16
|
||||
|
||||
|
||||
class TokenInfo(typing.NamedTuple):
|
||||
@dataclasses.dataclass
|
||||
class TokenInfo:
|
||||
access_token: str
|
||||
token_type: str
|
||||
expires: datetime.datetime
|
||||
@ -75,16 +75,16 @@ class TokenInfo(typing.NamedTuple):
|
||||
id_token: typing.Optional[str]
|
||||
|
||||
@staticmethod
|
||||
def fromJson(json: dict[str, typing.Any]) -> 'TokenInfo':
|
||||
def from_dict(dct: dict[str, typing.Any]) -> 'TokenInfo':
|
||||
# expires is -10 to avoid problems with clock sync
|
||||
return TokenInfo(
|
||||
access_token=json['access_token'],
|
||||
token_type=json['token_type'],
|
||||
expires=model.getSqlDatetime() + datetime.timedelta(seconds=json['expires_in'] - 10),
|
||||
refresh_token=json['refresh_token'],
|
||||
scope=json['scope'],
|
||||
info=json.get('info', {}),
|
||||
id_token=json.get('id_token', None),
|
||||
access_token=dct['access_token'],
|
||||
token_type=dct['token_type'],
|
||||
expires=model.getSqlDatetime() + datetime.timedelta(seconds=dct['expires_in'] - 10),
|
||||
refresh_token=dct['refresh_token'],
|
||||
scope=dct['scope'],
|
||||
info=dct.get('info', {}),
|
||||
id_token=dct.get('id_token', None),
|
||||
)
|
||||
|
||||
|
||||
@ -340,7 +340,7 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
if not req.ok:
|
||||
raise Exception('Error requesting token: {}'.format(req.text))
|
||||
|
||||
return TokenInfo.fromJson(req.json())
|
||||
return TokenInfo.from_dict(req.json())
|
||||
|
||||
def _requestInfo(self, token: 'TokenInfo') -> dict[str, typing.Any]:
|
||||
"""Request user info from the info endpoint using the token received from the token endpoint
|
||||
|
@ -1,3 +1,4 @@
|
||||
import dataclasses
|
||||
import io
|
||||
import logging
|
||||
import enum
|
||||
@ -69,8 +70,8 @@ class RadiusStates(enum.IntEnum):
|
||||
NOT_NEEDED = INCORRECT
|
||||
NEEDED = CORRECT
|
||||
|
||||
|
||||
class RadiusResult(typing.NamedTuple):
|
||||
@dataclasses.dataclass
|
||||
class RadiusResult:
|
||||
"""
|
||||
Result of an AccessChallenge request.
|
||||
"""
|
||||
|
@ -278,7 +278,7 @@ class ServerManager(metaclass=singleton.Singleton):
|
||||
except exceptions.UDSException: # No more servers
|
||||
return None
|
||||
elif lockTime: # If lockTime is set, update it
|
||||
models.Server.objects.filter(uuid=info[0]).update(locked_until=now + lockTime)
|
||||
models.Server.objects.filter(uuid=info.server_uuid).update(locked_until=now + lockTime)
|
||||
|
||||
# Notify to server
|
||||
# Update counter
|
||||
@ -339,7 +339,7 @@ class ServerManager(metaclass=singleton.Singleton):
|
||||
else: # Not last one, just decrement counter
|
||||
props[prop_name] = (serverCounter.server_uuid, serverCounter.counter - 1)
|
||||
|
||||
server = models.Server.objects.get(uuid=serverCounter[0])
|
||||
server = models.Server.objects.get(uuid=serverCounter.server_uuid)
|
||||
|
||||
if unlock or serverCounter.counter == 1:
|
||||
server.locked_until = None # Ensure server is unlocked if no more users are assigned to it
|
||||
|
@ -75,7 +75,7 @@ class AuthenticationResult:
|
||||
FAILED_AUTH = AuthenticationResult(success=AuthenticationState.FAIL)
|
||||
SUCCESS_AUTH = AuthenticationResult(success=AuthenticationState.SUCCESS)
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
@dataclasses.dataclass
|
||||
class AuthCallbackParams:
|
||||
'''Parameters passed to auth callback stage2
|
||||
|
||||
@ -91,7 +91,7 @@ class AuthCallbackParams:
|
||||
query_string: str
|
||||
|
||||
@staticmethod
|
||||
def fromRequest(request: 'HttpRequest') -> 'AuthCallbackParams':
|
||||
def from_request(request: 'HttpRequest') -> 'AuthCallbackParams':
|
||||
return AuthCallbackParams(
|
||||
https=request.is_secure(),
|
||||
host=request.META['HTTP_HOST'],
|
||||
|
@ -37,7 +37,7 @@ from .services import ServiceType
|
||||
|
||||
|
||||
# For requests to actors/servers
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
@dataclasses.dataclass
|
||||
class PreconnectRequest:
|
||||
"""Information sent on a preconnect request"""
|
||||
|
||||
@ -56,7 +56,7 @@ class PreconnectRequest:
|
||||
|
||||
|
||||
# For requests to actors/servers
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
@dataclasses.dataclass
|
||||
class AssignRequest:
|
||||
"""Information sent on a assign request"""
|
||||
|
||||
@ -71,7 +71,7 @@ class AssignRequest:
|
||||
return dataclasses.asdict(self)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
@dataclasses.dataclass
|
||||
class ReleaseRequest:
|
||||
"""Information sent on a release request"""
|
||||
|
||||
@ -81,7 +81,7 @@ class ReleaseRequest:
|
||||
return dataclasses.asdict(self)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
@dataclasses.dataclass
|
||||
class ConnectionData:
|
||||
"""
|
||||
Connection data provided by transports, and contains all the "transformable" information needed to connect to a service
|
||||
@ -102,7 +102,7 @@ class ConnectionData:
|
||||
def as_dict(self) -> dict[str, str]:
|
||||
return dataclasses.asdict(self)
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
@dataclasses.dataclass
|
||||
class ConnectionSource:
|
||||
"""
|
||||
Connection source from where the connection is being done
|
||||
|
@ -35,14 +35,11 @@ import enum
|
||||
import typing
|
||||
import collections.abc
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
@dataclasses.dataclass
|
||||
class DetectedOsInfo:
|
||||
os: 'KnownOS'
|
||||
browser: 'KnownBrowser'
|
||||
version: str
|
||||
|
||||
def replace(self, **kwargs):
|
||||
return dataclasses.replace(self, **kwargs)
|
||||
version: str
|
||||
|
||||
|
||||
class KnownOS(enum.Enum):
|
||||
|
@ -59,7 +59,7 @@ class TransportSelectionPolicy(enum.IntEnum):
|
||||
COMMON = 1
|
||||
LABEL = 2
|
||||
|
||||
def asStr(self) -> str:
|
||||
def as_str(self) -> str:
|
||||
return self.name.lower()
|
||||
|
||||
@staticmethod
|
||||
@ -75,7 +75,7 @@ class HighAvailabilityPolicy(enum.IntEnum):
|
||||
DISABLED = 0
|
||||
ENABLED = 1
|
||||
|
||||
def asStr(self) -> str:
|
||||
def as_str(self) -> str:
|
||||
return str(self)
|
||||
|
||||
@staticmethod
|
||||
|
@ -53,7 +53,7 @@ class CommonPrefs:
|
||||
BYPASS_PREF = 'bypassPluginDetection'
|
||||
|
||||
@staticmethod
|
||||
def getWidthHeight(size: str) -> tuple[int, int]:
|
||||
def get_wh(size: str) -> tuple[int, int]:
|
||||
"""
|
||||
Get width based on screenSizePref value
|
||||
"""
|
||||
@ -67,7 +67,7 @@ class CommonPrefs:
|
||||
}.get(size, (1024, 768))
|
||||
|
||||
@staticmethod
|
||||
def getDepth(depth: str) -> int:
|
||||
def get_depth(depth: str) -> int:
|
||||
"""
|
||||
Get depth based on depthPref value
|
||||
"""
|
||||
|
@ -33,7 +33,7 @@ import typing
|
||||
import dataclasses
|
||||
import collections.abc
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
@dataclasses.dataclass
|
||||
class TypeInfo:
|
||||
name: str
|
||||
type: str
|
||||
|
@ -113,6 +113,7 @@ ServerSubtype.manager().register(
|
||||
ServerType.UNMANAGED, 'ip', 'Unmanaged IP Server', consts.images.DEFAULT_IMAGE_BASE64, False
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class ServerDiskInfo:
|
||||
mountpoint: str
|
||||
@ -248,7 +249,8 @@ class ServerStats:
|
||||
# Human readable
|
||||
return f'memory: {self.memused//(1024*1024)}/{self.memtotal//(1024*1024)} cpu: {self.cpuused*100} users: {self.current_users}, weight: {self.weight()}, valid: {self.is_valid}'
|
||||
|
||||
|
||||
# ServerCounter must be serializable by json, so
|
||||
# we keep it as a NamedTuple instead of a dataclass
|
||||
class ServerCounter(typing.NamedTuple):
|
||||
server_uuid: str
|
||||
counter: int
|
||||
|
@ -36,7 +36,7 @@ class ServiceType(enum.StrEnum):
|
||||
VDI = 'VDI'
|
||||
VAPP = 'VAPP'
|
||||
|
||||
def fromStr(self, value: str) -> 'ServiceType':
|
||||
def from_str(self, value: str) -> 'ServiceType':
|
||||
"""Returns the service type from a string"""
|
||||
return ServiceType(value.upper())
|
||||
|
||||
@ -48,7 +48,7 @@ class ServicesCountingType(enum.IntEnum):
|
||||
CONSERVATIVE = 1
|
||||
|
||||
@staticmethod
|
||||
def fromInt(value: int) -> 'ServicesCountingType':
|
||||
def from_int(value: int) -> 'ServicesCountingType':
|
||||
"""Returns the MaxServiceCountingMethodType from an int
|
||||
If the int is not a valid value, returns STANDARD
|
||||
"""
|
||||
@ -58,7 +58,7 @@ class ServicesCountingType(enum.IntEnum):
|
||||
return ServicesCountingType.STANDARD
|
||||
|
||||
@staticmethod
|
||||
def fromStr(value: str) -> 'ServicesCountingType':
|
||||
def from_str(value: str) -> 'ServicesCountingType':
|
||||
"""Returns the MaxServiceCountingMethodType from an str
|
||||
If the str is not a valid value, returns STANDARD
|
||||
"""
|
||||
|
@ -42,40 +42,41 @@ from uds import models
|
||||
if typing.TYPE_CHECKING:
|
||||
from django.db.models import Model
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class ObjTypeInfo:
|
||||
@dataclasses.dataclass
|
||||
class _ObjTypeInfo:
|
||||
obj_type: int
|
||||
model: 'type[Model]'
|
||||
|
||||
@enum.unique
|
||||
class ObjectType(enum.Enum):
|
||||
PROVIDER = ObjTypeInfo(1, models.Provider)
|
||||
SERVICE = ObjTypeInfo(2, models.Service)
|
||||
OSMANAGER = ObjTypeInfo(3, models.OSManager)
|
||||
TRANSPORT = ObjTypeInfo(4, models.Transport)
|
||||
NETWORK = ObjTypeInfo(5, models.Network)
|
||||
POOL = ObjTypeInfo(6, models.ServicePool)
|
||||
USER_SERVICE = ObjTypeInfo(7, models.UserService)
|
||||
AUTHENTICATOR = ObjTypeInfo(8, models.Authenticator)
|
||||
USER = ObjTypeInfo(9, models.User)
|
||||
GROUP = ObjTypeInfo(10, models.Group)
|
||||
STATS_COUNTER = ObjTypeInfo(11, models.StatsCounters)
|
||||
STATS_EVENTS = ObjTypeInfo(12, models.StatsEvents)
|
||||
CALENDAR = ObjTypeInfo(13, models.Calendar)
|
||||
CALENDAR_RULE = ObjTypeInfo(14, models.CalendarRule)
|
||||
METAPOOL = ObjTypeInfo(15, models.MetaPool)
|
||||
ACCOUNT = ObjTypeInfo(16, models.Account)
|
||||
PROVIDER = _ObjTypeInfo(1, models.Provider)
|
||||
SERVICE = _ObjTypeInfo(2, models.Service)
|
||||
OSMANAGER = _ObjTypeInfo(3, models.OSManager)
|
||||
TRANSPORT = _ObjTypeInfo(4, models.Transport)
|
||||
NETWORK = _ObjTypeInfo(5, models.Network)
|
||||
POOL = _ObjTypeInfo(6, models.ServicePool)
|
||||
USER_SERVICE = _ObjTypeInfo(7, models.UserService)
|
||||
AUTHENTICATOR = _ObjTypeInfo(8, models.Authenticator)
|
||||
USER = _ObjTypeInfo(9, models.User)
|
||||
GROUP = _ObjTypeInfo(10, models.Group)
|
||||
STATS_COUNTER = _ObjTypeInfo(11, models.StatsCounters)
|
||||
STATS_EVENTS = _ObjTypeInfo(12, models.StatsEvents)
|
||||
CALENDAR = _ObjTypeInfo(13, models.Calendar)
|
||||
CALENDAR_RULE = _ObjTypeInfo(14, models.CalendarRule)
|
||||
METAPOOL = _ObjTypeInfo(15, models.MetaPool)
|
||||
ACCOUNT = _ObjTypeInfo(16, models.Account)
|
||||
# Actor and Tunnel tokens are now on REGISTERED_SERVER, so removed
|
||||
MFA = ObjTypeInfo(19, models.MFA)
|
||||
REGISTERED_SERVER = ObjTypeInfo(20, models.Server)
|
||||
REGISTERED_SERVER_GROUP = ObjTypeInfo(21, models.ServerGroup)
|
||||
ACCOUNT_USAGE = ObjTypeInfo(22, models.AccountUsage)
|
||||
IMAGE = ObjTypeInfo(23, models.Image)
|
||||
LOG = ObjTypeInfo(24, models.Log)
|
||||
NOTIFICATION = ObjTypeInfo(25, models.Notification)
|
||||
TICKET_STORE = ObjTypeInfo(26, models.TicketStore)
|
||||
MFA = _ObjTypeInfo(19, models.MFA)
|
||||
REGISTERED_SERVER = _ObjTypeInfo(20, models.Server)
|
||||
REGISTERED_SERVER_GROUP = _ObjTypeInfo(21, models.ServerGroup)
|
||||
ACCOUNT_USAGE = _ObjTypeInfo(22, models.AccountUsage)
|
||||
IMAGE = _ObjTypeInfo(23, models.Image)
|
||||
LOG = _ObjTypeInfo(24, models.Log)
|
||||
NOTIFICATION = _ObjTypeInfo(25, models.Notification)
|
||||
TICKET_STORE = _ObjTypeInfo(26, models.TicketStore)
|
||||
|
||||
@property
|
||||
def model(self) -> type['Model']:
|
||||
|
@ -74,14 +74,14 @@ def detect_os(
|
||||
|
||||
# If we found a known OS, store it
|
||||
if found != types.os.KnownOS.UNKNOWN:
|
||||
res = res.replace(os=found)
|
||||
res.os = found
|
||||
|
||||
# Try to detect browser from Sec-Ch-Ua first
|
||||
secChUa = headers.get('Sec-Ch-Ua')
|
||||
if secChUa is not None:
|
||||
for browser in consts.os.knownBrowsers:
|
||||
if browser in secChUa:
|
||||
res = res.replace(browser=browser)
|
||||
res.browser = browser
|
||||
break
|
||||
else:
|
||||
# Try to detect browser from User-Agent
|
||||
@ -109,7 +109,8 @@ def detect_os(
|
||||
break
|
||||
|
||||
if match is not None:
|
||||
res = res.replace(browser=ruleKey, version=match.groups(1)[0])
|
||||
res.browser = ruleKey or types.os.KnownBrowser.OTHER
|
||||
res.version = match.groups(1)[0]
|
||||
|
||||
logger.debug('Detected: %s %s', res.os, res.browser)
|
||||
|
||||
|
@ -163,7 +163,7 @@ class Service(ManagedObjectModel, TaggingMixin): # type: ignore
|
||||
|
||||
@property
|
||||
def maxServicesCountType(self) -> ServicesCountingType:
|
||||
return ServicesCountingType.fromInt(self.max_services_count_type)
|
||||
return ServicesCountingType.from_int(self.max_services_count_type)
|
||||
|
||||
def isInMaintenance(self) -> bool:
|
||||
# orphaned services?
|
||||
|
@ -217,7 +217,7 @@ class BaseX2GOTransport(transports.Transport):
|
||||
return ready == 'Y'
|
||||
|
||||
def getScreenSize(self) -> tuple[int, int]:
|
||||
return CommonPrefs.getWidthHeight(self.screenSize.value)
|
||||
return CommonPrefs.get_wh(self.screenSize.value)
|
||||
|
||||
def processedUser(self, userService: 'models.UserService', user: 'models.User') -> str:
|
||||
v = self.processUserPassword(userService, user, '')
|
||||
|
@ -87,7 +87,7 @@ def authCallback(request: HttpRequest, authName: str) -> HttpResponse:
|
||||
if not authenticator:
|
||||
raise Exception('Authenticator not found')
|
||||
|
||||
params = types.auth.AuthCallbackParams.fromRequest(request)
|
||||
params = types.auth.AuthCallbackParams.from_request(request)
|
||||
|
||||
logger.debug('Auth callback for %s with params %s', authenticator, params)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user