small fixes (typing)

This commit is contained in:
Adolfo Gómez García 2021-07-19 12:42:26 +02:00
parent 51407b54ee
commit 6fd307e86e
8 changed files with 107 additions and 42 deletions
server/src/uds

View File

@ -144,13 +144,17 @@ class Client(Handler):
logger.debug('Res: %s %s %s %s %s', ip, userService, userServiceInstance, transport, transportInstance) logger.debug('Res: %s %s %s %s %s', ip, userService, userServiceInstance, transport, transportInstance)
password = cryptoManager().symDecrpyt(data['password'], scrambler) password = cryptoManager().symDecrpyt(data['password'], scrambler)
# userService.setConnectionSource(srcIp, hostname) # Store where we are accessing from so we can notify Service
if not ip:
raise ServiceNotReadyError
# Set "accesedByClient" # Set "accesedByClient"
userService.setProperty('accessedByClient', '1') userService.setProperty('accessedByClient', '1')
# userService.setConnectionSource(srcIp, hostname) # Store where we are accessing from so we can notify Service
if not ip:
raise ServiceNotReadyError()
# This should never happen, but it's here just in case
if not transportInstance:
raise Exception('No transport instance!!!')
transportScript, signature, params = transportInstance.getEncodedTransportScript(userService, transport, ip, self._request.os, self._request.user, password, self._request) transportScript, signature, params = transportInstance.getEncodedTransportScript(userService, transport, ip, self._request.os, self._request.user, password, self._request)
logger.debug('Signature: %s', signature) logger.debug('Signature: %s', signature)

View File

@ -1,7 +1,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
# Copyright (c) 2014-2019 Virtual Cable S.L. # Copyright (c) 2014-2021 Virtual Cable S.L.U.
# All rights reserved. # All rights reserved.
# #
# Redistribution and use in source and binary forms, with or without modification, # Redistribution and use in source and binary forms, with or without modification,
@ -45,9 +45,18 @@ from uds.core.util import tools
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
VALID_PARAMS = ( VALID_PARAMS = (
'authId', 'authTag', 'authSmallName', 'auth', 'username', 'authId',
'realname', 'password', 'groups', 'servicePool', 'transport', 'authTag',
'force', 'userIp' 'authSmallName',
'auth',
'username',
'realname',
'password',
'groups',
'servicePool',
'transport',
'force',
'userIp',
) )
@ -72,10 +81,13 @@ class Tickets(Handler):
- Groups exists on authenticator - Groups exists on authenticator
- servicePool has these groups in it's allowed list - servicePool has these groups in it's allowed list
""" """
needs_admin = True # By default, staff is lower level needed needs_admin = True # By default, staff is lower level needed
@staticmethod @staticmethod
def result(result: str = '', error: typing.Optional[str] = None) -> typing.Dict[str, typing.Any]: def result(
result: str = '', error: typing.Optional[str] = None
) -> typing.Dict[str, typing.Any]:
""" """
Returns a result for a Ticket request Returns a result for a Ticket request
""" """
@ -112,7 +124,9 @@ class Tickets(Handler):
raise RequestError('Invalid parameters (no auth)') raise RequestError('Invalid parameters (no auth)')
# Must be invoked as '/rest/ticket/create, with "username", ("authId" or ("authSmallName" or "authTag"), "groups" (array) and optionally "time" (in seconds) as paramteres # Must be invoked as '/rest/ticket/create, with "username", ("authId" or ("authSmallName" or "authTag"), "groups" (array) and optionally "time" (in seconds) as paramteres
def put(self): # pylint: disable=too-many-locals,too-many-branches,too-many-statements def put(
self,
): # pylint: disable=too-many-locals,too-many-branches,too-many-statements
""" """
Processes put requests, currently only under "create" Processes put requests, currently only under "create"
""" """
@ -134,30 +148,48 @@ class Tickets(Handler):
authId = self._params.get('authId', None) authId = self._params.get('authId', None)
authName = self._params.get('auth', None) authName = self._params.get('auth', None)
authTag = self._params.get('authTag', self._params.get('authSmallName', None)) authTag = self._params.get(
'authTag', self._params.get('authSmallName', None)
)
# Will raise an exception if no auth found # Will raise an exception if no auth found
if authId: if authId:
auth = models.Authenticator.objects.get(uuid=processUuid(authId.lower())) auth = models.Authenticator.objects.get(
uuid=processUuid(authId.lower())
)
elif authName: elif authName:
auth = models.Authenticator.objects.get(name=authName) auth = models.Authenticator.objects.get(name=authName)
else: else:
auth = models.Authenticator.objects.get(small_name=authTag) auth = models.Authenticator.objects.get(small_name=authTag)
username: str = self._params['username'] username: str = self._params['username']
password: str = self._params.get('password', '') # Some machines needs password, depending on configuration password: str = self._params.get(
'password', ''
) # Some machines needs password, depending on configuration
groupIds: typing.List[str] = [] groupIds: typing.List[str] = []
for groupName in tools.asList(self._params['groups']): for groupName in tools.asList(self._params['groups']):
try: try:
groupIds.append(auth.groups.get(name=groupName).uuid) groupIds.append(auth.groups.get(name=groupName).uuid)
except Exception: except Exception:
logger.info('Group %s from ticket does not exists on auth %s, forced creation: %s', groupName, auth, force) logger.info(
'Group %s from ticket does not exists on auth %s, forced creation: %s',
groupName,
auth,
force,
)
if force: # Force creation by call if force: # Force creation by call
groupIds.append(auth.groups.create(name=groupName, comments='Autocreated form ticket by using force paratemeter').uuid) groupIds.append(
auth.groups.create(
name=groupName,
comments='Autocreated form ticket by using force paratemeter',
).uuid
)
if not groupIds: # No valid group in groups names if not groupIds: # No valid group in groups names
raise RequestError('Authenticator does not contain ANY of the requested groups and force is not used') raise RequestError(
'Authenticator does not contain ANY of the requested groups and force is not used'
)
time = int(self._params.get('time', 60)) time = int(self._params.get('time', 60))
time = 60 if time < 1 else time time = 60 if time < 1 else time
@ -166,19 +198,29 @@ class Tickets(Handler):
if 'servicePool' in self._params: if 'servicePool' in self._params:
# Check if is pool or metapool # Check if is pool or metapool
poolUuid = processUuid(self._params['servicePool']) poolUuid = processUuid(self._params['servicePool'])
pool : typing.Union[models.ServicePool, models.MetaPool] pool: typing.Union[models.ServicePool, models.MetaPool]
try: try:
pool = typing.cast(models.MetaPool, models.MetaPool.objects.get(uuid=poolUuid)) # If not an metapool uuid, will process it as a servicePool pool = typing.cast(
models.MetaPool, models.MetaPool.objects.get(uuid=poolUuid)
) # If not an metapool uuid, will process it as a servicePool
if force: if force:
# First, add groups to metapool # First, add groups to metapool
for addGrp in set(groupIds) - set(pool.assignedGroups.values_list('uuid', flat=True)): for addGrp in set(groupIds) - set(
pool.assignedGroups.values_list('uuid', flat=True)
):
pool.assignedGroups.add(auth.groups.get(uuid=addGrp)) pool.assignedGroups.add(auth.groups.get(uuid=addGrp))
# And now, to ALL metapool members # And now, to ALL metapool members
for metaMember in pool.members.all(): for metaMember in pool.members.all():
# First, add groups to metapool # Now add groups to pools
for addGrp in set(groupIds) - set(metaMember.pool.assignedGroups.values_list('uuid', flat=True)): for addGrp in set(groupIds) - set(
metaMember.assignedGroups.add(auth.groups.get(uuid=addGrp)) metaMember.pool.assignedGroups.values_list(
'uuid', flat=True
)
):
metaMember.pool.assignedGroups.add(
auth.groups.get(uuid=addGrp)
)
# For metapool, transport is ignored.. # For metapool, transport is ignored..
@ -186,19 +228,30 @@ class Tickets(Handler):
transportId = 'meta' transportId = 'meta'
except models.MetaPool.DoesNotExist: except models.MetaPool.DoesNotExist:
pool = typing.cast(models.ServicePool, models.ServicePool.objects.get(uuid=poolUuid)) pool = typing.cast(
models.ServicePool,
models.ServicePool.objects.get(uuid=poolUuid),
)
# If forced that servicePool must honor groups # If forced that servicePool must honor groups
if force: if force:
for addGrp in set(groupIds) - set(pool.assignedGroups.values_list('uuid', flat=True)): for addGrp in set(groupIds) - set(
pool.assignedGroups.values_list('uuid', flat=True)
):
pool.assignedGroups.add(auth.groups.get(uuid=addGrp)) pool.assignedGroups.add(auth.groups.get(uuid=addGrp))
if 'transport' in self._params: if 'transport' in self._params:
transport: models.Transport = models.Transport.objects.get(uuid=processUuid(self._params['transport'])) transport: models.Transport = models.Transport.objects.get(
uuid=processUuid(self._params['transport'])
)
try: try:
pool.validateTransport(transport) pool.validateTransport(transport)
except Exception: except Exception:
logger.error('Transport %s is not valid for Service Pool %s', transport.name, pool.name) logger.error(
'Transport %s is not valid for Service Pool %s',
transport.name,
pool.name,
)
raise Exception('Invalid transport for Service Pool') raise Exception('Invalid transport for Service Pool')
else: else:
transport = models.Transport(uuid=None) transport = models.Transport(uuid=None)
@ -209,8 +262,16 @@ class Tickets(Handler):
break break
if transport.uuid is None: if transport.uuid is None:
logger.error('Service pool %s does not has valid transports for ip %s', pool.name, userIp) logger.error(
raise Exception('Service pool does not has any valid transports for ip {}'.format(userIp)) 'Service pool %s does not has valid transports for ip %s',
pool.name,
userIp,
)
raise Exception(
'Service pool does not has any valid transports for ip {}'.format(
userIp
)
)
servicePoolId = 'F' + pool.uuid servicePoolId = 'F' + pool.uuid
transportId = transport.uuid transportId = transport.uuid

View File

@ -134,8 +134,8 @@ class TunnelTicket(Handler):
port=port, port=port,
host=host, host=host,
extra={ extra={
't': self._args[0], # ticket 't': self._args[0], # ticket
'b': models.getSqlDatetimeAsUnix(), # Begin time stamp 'b': models.getSqlDatetimeAsUnix(), # Begin time stamp
}, },
validity=MAX_SESSION_LENGTH, validity=MAX_SESSION_LENGTH,
) )

View File

@ -42,9 +42,6 @@ from uds.core.util import permissions
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Enclosed methods under /osm path
class TunnelTokens(ModelHandler): class TunnelTokens(ModelHandler):
model = TunnelToken model = TunnelToken

View File

@ -377,7 +377,7 @@ class GlobalConfig:
'maxInitTime', '3601', type=Config.NUMERIC_FIELD 'maxInitTime', '3601', type=Config.NUMERIC_FIELD
) )
MAX_REMOVAL_TIME: Config.Value = Config.section(GLOBAL_SECTION).value( MAX_REMOVAL_TIME: Config.Value = Config.section(GLOBAL_SECTION).value(
'maxRemovalTime', '86400', type=Config.NUMERIC_FIELD 'maxRemovalTime', '14400', type=Config.NUMERIC_FIELD
) )
# Maximum logs per user service # Maximum logs per user service
MAX_LOGS_PER_ELEMENT: Config.Value = Config.section(GLOBAL_SECTION).value( MAX_LOGS_PER_ELEMENT: Config.Value = Config.section(GLOBAL_SECTION).value(

View File

@ -34,9 +34,12 @@ import typing
from collections import defaultdict from collections import defaultdict
from xml.etree import cElementTree from xml.etree import cElementTree
if typing.TYPE_CHECKING:
from xml.etree.cElementTree import Element
def etree_to_dict(t):
d = {} def etree_to_dict(t: 'Element') -> typing.Mapping[str, typing.Any]:
d: typing.MutableMapping[str, typing.Any] = {}
if t.attrib: if t.attrib:
d.update({t.tag: {}}) d.update({t.tag: {}})
@ -59,5 +62,5 @@ def etree_to_dict(t):
return d return d
def parse(xml_string: str) -> typing.Dict: def parse(xml_string: str) -> typing.Mapping[str, typing.Any]:
return etree_to_dict(cElementTree.XML(xml_string)) return etree_to_dict(cElementTree.XML(xml_string))

View File

@ -66,7 +66,7 @@ def checkResultRaw(lst: typing.Any) -> str:
return str(lst[1]) return str(lst[1])
def checkResult(lst: typing.Any) -> typing.Tuple[typing.Dict, str]: def checkResult(lst: typing.Any) -> typing.Tuple[typing.Mapping[str, typing.Any], str]:
return xml2dict.parse(checkResultRaw(lst)), lst[1] return xml2dict.parse(checkResultRaw(lst)), lst[1]

View File

@ -39,7 +39,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# module = sys.modules[__name__] # module = sys.modules[__name__]
def sanitizeName(name): def sanitizeName(name: str) -> str:
""" """
machine names with [a-zA-Z0-9_-] machine names with [a-zA-Z0-9_-]
""" """