1
0
mirror of https://github.com/dkmstr/openuds.git synced 2024-12-22 13:34:04 +03:00

snake_case fixin before stabilizing... :)

This commit is contained in:
Adolfo Gómez García 2024-10-12 17:55:14 +02:00
parent 25aa09309b
commit bd53926c81
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
19 changed files with 194 additions and 170 deletions

View File

@ -156,9 +156,9 @@ class InternalDBAuth(auths.Authenticator):
request: 'ExtendedHttpRequest', request: 'ExtendedHttpRequest',
) -> types.auth.AuthenticationResult: ) -> types.auth.AuthenticationResult:
username = username.lower() username = username.lower()
dbAuth = self.db_obj() auth_db = self.db_obj()
try: try:
user: 'models.User' = dbAuth.users.get(name=username, state=State.ACTIVE) user: 'models.User' = auth_db.users.get(name=username, state=State.ACTIVE)
except Exception: except Exception:
log_login(request, self.db_obj(), username, 'Invalid user', as_error=True) log_login(request, self.db_obj(), username, 'Invalid user', as_error=True)
return types.auth.FAILED_AUTH return types.auth.FAILED_AUTH
@ -175,15 +175,15 @@ class InternalDBAuth(auths.Authenticator):
return types.auth.FAILED_AUTH return types.auth.FAILED_AUTH
def get_groups(self, username: str, groups_manager: 'auths.GroupsManager') -> None: def get_groups(self, username: str, groups_manager: 'auths.GroupsManager') -> None:
dbAuth = self.db_obj() auth_db = self.db_obj()
try: try:
user: 'models.User' = dbAuth.users.get(name=username.lower(), state=State.ACTIVE) user: 'models.User' = auth_db.users.get(name=username.lower(), state=State.ACTIVE)
except Exception: except Exception:
return return
grps = [g.name for g in user.groups.all()] grps = [g.name for g in user.groups.all()]
if user.parent: if user.parent:
try: try:
parent = dbAuth.users.get(uuid=user.parent, state=State.ACTIVE) parent = auth_db.users.get(uuid=user.parent, state=State.ACTIVE)
grps.extend([g.name for g in parent.groups.all()]) grps.extend([g.name for g in parent.groups.all()])
except Exception: except Exception:
pass pass

View File

@ -84,7 +84,7 @@ class RadiusAuth(auths.Authenticator):
required=True, required=True,
) )
nasIdentifier = gui.TextField( nas_identifier = gui.TextField(
length=64, length=64,
label=_('NAS Identifier'), label=_('NAS Identifier'),
default='uds-server', default='uds-server',
@ -92,26 +92,29 @@ class RadiusAuth(auths.Authenticator):
tooltip=_('NAS Identifier for Radius Server'), tooltip=_('NAS Identifier for Radius Server'),
required=True, required=True,
tab=types.ui.Tab.ADVANCED, tab=types.ui.Tab.ADVANCED,
old_field_name='nasIdentifier',
) )
appClassPrefix = gui.TextField( app_class_prefix = gui.TextField(
length=64, length=64,
label=_('App Prefix for Class Attributes'), label=_('App Prefix for Class Attributes'),
default='', default='',
order=11, order=11,
tooltip=_('Application prefix for filtering groups from "Class" attribute'), tooltip=_('Application prefix for filtering groups from "Class" attribute'),
tab=types.ui.Tab.ADVANCED, tab=types.ui.Tab.ADVANCED,
old_field_name='appClassPrefix',
) )
globalGroup = gui.TextField( global_group = gui.TextField(
length=64, length=64,
label=_('Global group'), label=_('Global group'),
default='', default='',
order=12, order=12,
tooltip=_('If set, this value will be added as group for all radius users'), tooltip=_('If set, this value will be added as group for all radius users'),
tab=types.ui.Tab.ADVANCED, tab=types.ui.Tab.ADVANCED,
old_field_name='globalGroup',
) )
mfaAttr = gui.TextField( mfa_attr = gui.TextField(
length=2048, length=2048,
lines=2, lines=2,
label=_('MFA attribute'), label=_('MFA attribute'),
@ -119,6 +122,7 @@ class RadiusAuth(auths.Authenticator):
tooltip=_('Attribute from where to extract the MFA code'), tooltip=_('Attribute from where to extract the MFA code'),
required=False, required=False,
tab=types.ui.Tab.MFA, tab=types.ui.Tab.MFA,
old_field_name='mfaAttr',
) )
def initialize(self, values: typing.Optional[dict[str, typing.Any]]) -> None: def initialize(self, values: typing.Optional[dict[str, typing.Any]]) -> None:
@ -129,9 +133,9 @@ class RadiusAuth(auths.Authenticator):
return client.RadiusClient( return client.RadiusClient(
self.server.value, self.server.value,
self.secret.value.encode(), self.secret.value.encode(),
authPort=self.port.as_int(), auth_port=self.port.as_int(),
nasIdentifier=self.nasIdentifier.value, nas_identifier=self.nas_identifier.value,
appClassPrefix=self.appClassPrefix.value, appclass_prefix=self.app_class_prefix.value,
) )
def mfa_storage_key(self, username: str) -> str: def mfa_storage_key(self, username: str) -> str:
@ -149,17 +153,17 @@ class RadiusAuth(auths.Authenticator):
) -> types.auth.AuthenticationResult: ) -> types.auth.AuthenticationResult:
try: try:
connection = self.radius_client() connection = self.radius_client()
groups, mfaCode, state = connection.authenticate( groups, mfa_code, state = connection.authenticate(
username=username, password=credentials, mfaField=self.mfaAttr.value.strip() username=username, password=credentials, mfa_field=self.mfa_attr.value.strip()
) )
# If state, store in session # If state, store in session
if state: if state:
request.session[client.STATE_VAR_NAME] = state.decode() request.session[client.STATE_VAR_NAME] = state.decode()
# store the user mfa attribute if it is set # store the user mfa attribute if it is set
if mfaCode: if mfa_code:
self.storage.save_pickled( self.storage.save_pickled(
self.mfa_storage_key(username), self.mfa_storage_key(username),
mfaCode, mfa_code,
) )
except Exception: except Exception:
@ -172,8 +176,8 @@ class RadiusAuth(auths.Authenticator):
) )
return types.auth.FAILED_AUTH return types.auth.FAILED_AUTH
if self.globalGroup.value.strip(): if self.global_group.value.strip():
groups.append(self.globalGroup.value.strip()) groups.append(self.global_group.value.strip())
# Cache groups for "getGroups", because radius will not send us those # Cache groups for "getGroups", because radius will not send us those
with self.storage.as_dict() as storage: with self.storage.as_dict() as storage:

View File

@ -110,49 +110,49 @@ class RadiusResult:
""" """
pwd: RadiusStates = RadiusStates.INCORRECT pwd: RadiusStates = RadiusStates.INCORRECT
replyMessage: typing.Optional[bytes] = None reply_message: typing.Optional[bytes] = None
state: typing.Optional[bytes] = None state: typing.Optional[bytes] = None
otp: RadiusStates = RadiusStates.NOT_CHECKED otp: RadiusStates = RadiusStates.NOT_CHECKED
otp_needed: RadiusStates = RadiusStates.NOT_CHECKED otp_needed: RadiusStates = RadiusStates.NOT_CHECKED
class RadiusClient: class RadiusClient:
radiusServer: Client server: Client
nasIdentifier: str nas_identifier: str
appClassPrefix: str appclass_prefix: str
def __init__( def __init__(
self, self,
server: str, server: str,
secret: bytes, secret: bytes,
*, *,
authPort: int = 1812, auth_port: int = 1812,
nasIdentifier: str = 'uds-server', nas_identifier: str = 'uds-server',
appClassPrefix: str = '', appclass_prefix: str = '',
dictionary: str = RADDICT, dictionary: str = RADDICT,
) -> None: ) -> None:
self.radiusServer = Client( self.server = Client(
server=server, server=server,
authport=authPort, authport=auth_port,
secret=secret, secret=secret,
dict=Dictionary(io.StringIO(dictionary)), dict=Dictionary(io.StringIO(dictionary)),
) )
self.nasIdentifier = nasIdentifier self.nas_identifier = nas_identifier
self.appClassPrefix = appClassPrefix self.appclass_prefix = appclass_prefix
def extractAccessChallenge(self, reply: pyrad.packet.AuthPacket) -> RadiusResult: def extract_access_challenge(self, reply: pyrad.packet.AuthPacket) -> RadiusResult:
return RadiusResult( return RadiusResult(
pwd=RadiusStates.CORRECT, pwd=RadiusStates.CORRECT,
replyMessage=typing.cast(list[bytes], reply.get('Reply-Message') or [''])[0], reply_message=typing.cast(list[bytes], reply.get('Reply-Message') or [''])[0],
state=typing.cast(list[bytes], reply.get('State') or [b''])[0], state=typing.cast(list[bytes], reply.get('State') or [b''])[0],
otp_needed=RadiusStates.NEEDED, otp_needed=RadiusStates.NEEDED,
) )
def sendAccessRequest(self, username: str, password: str, **kwargs: typing.Any) -> pyrad.packet.AuthPacket: def send_access_request(self, username: str, password: str, **kwargs: typing.Any) -> pyrad.packet.AuthPacket:
req: pyrad.packet.AuthPacket = self.radiusServer.CreateAuthPacket( req: pyrad.packet.AuthPacket = self.server.CreateAuthPacket(
code=pyrad.packet.AccessRequest, code=pyrad.packet.AccessRequest,
User_Name=username, User_Name=username,
NAS_Identifier=self.nasIdentifier, NAS_Identifier=self.nas_identifier,
) )
req["User-Password"] = req.PwCrypt(password) req["User-Password"] = req.PwCrypt(password)
@ -161,46 +161,46 @@ class RadiusClient:
for k, v in kwargs.items(): for k, v in kwargs.items():
req[k] = v req[k] = v
return typing.cast(pyrad.packet.AuthPacket, self.radiusServer.SendPacket(req)) return typing.cast(pyrad.packet.AuthPacket, self.server.SendPacket(req))
# Second element of return value is the mfa code from field # Second element of return value is the mfa code from field
def authenticate( def authenticate(
self, username: str, password: str, mfaField: str = '' self, username: str, password: str, mfa_field: str = ''
) -> tuple[list[str], str, bytes]: ) -> tuple[list[str], str, bytes]:
reply = self.sendAccessRequest(username, password) reply = self.send_access_request(username, password)
if reply.code not in (pyrad.packet.AccessAccept, pyrad.packet.AccessChallenge): if reply.code not in (pyrad.packet.AccessAccept, pyrad.packet.AccessChallenge):
raise RadiusAuthenticationError('Access denied') raise RadiusAuthenticationError('Access denied')
# User accepted, extract groups... # User accepted, extract groups...
# All radius users belongs to, at least, 'uds-users' group # All radius users belongs to, at least, 'uds-users' group
groupClassPrefix = (self.appClassPrefix + 'group=').encode() groupclass_prefix = (self.appclass_prefix + 'group=').encode()
groupClassPrefixLen = len(groupClassPrefix) groupclass_prefix_len = len(groupclass_prefix)
if 'Class' in reply: if 'Class' in reply:
groups = [ groups = [
i[groupClassPrefixLen:].decode() i[groupclass_prefix_len:].decode()
for i in typing.cast(collections.abc.Iterable[bytes], reply['Class']) for i in typing.cast(collections.abc.Iterable[bytes], reply['Class'])
if i.startswith(groupClassPrefix) if i.startswith(groupclass_prefix)
] ]
else: else:
logger.info('No "Class (25)" attribute found') logger.info('No "Class (25)" attribute found')
return ([], '', b'') return ([], '', b'')
# ...and mfa code # ...and mfa code
mfaCode = '' mfa_code = ''
if mfaField and mfaField in reply: if mfa_field and mfa_field in reply:
mfaCode = ''.join( mfa_code = ''.join(
i[groupClassPrefixLen:].decode() i[groupclass_prefix_len:].decode()
for i in typing.cast(collections.abc.Iterable[bytes], reply['Class']) for i in typing.cast(collections.abc.Iterable[bytes], reply['Class'])
if i.startswith(groupClassPrefix) if i.startswith(groupclass_prefix)
) )
return (groups, mfaCode, typing.cast(list[bytes], reply.get('State') or [b''])[0]) return (groups, mfa_code, typing.cast(list[bytes], reply.get('State') or [b''])[0])
def authenticate_only(self, username: str, password: str) -> RadiusResult: def authenticate_only(self, username: str, password: str) -> RadiusResult:
reply = self.sendAccessRequest(username, password) reply = self.send_access_request(username, password)
if reply.code == pyrad.packet.AccessChallenge: if reply.code == pyrad.packet.AccessChallenge:
return self.extractAccessChallenge(reply) return self.extract_access_challenge(reply)
# user/pwd accepted: this user does not have challenge data # user/pwd accepted: this user does not have challenge data
if reply.code == pyrad.packet.AccessAccept: if reply.code == pyrad.packet.AccessAccept:
@ -221,7 +221,7 @@ class RadiusClient:
logger.debug('Sending AccessChallenge request wit otp [%s]', otp) logger.debug('Sending AccessChallenge request wit otp [%s]', otp)
reply = self.sendAccessRequest(username, otp, State=state) reply = self.send_access_request(username, otp, State=state)
logger.debug('Received AccessChallenge reply: %s', reply) logger.debug('Received AccessChallenge reply: %s', reply)
@ -238,7 +238,7 @@ class RadiusClient:
) )
def authenticate_and_challenge(self, username: str, password: str, otp: str) -> RadiusResult: def authenticate_and_challenge(self, username: str, password: str, otp: str) -> RadiusResult:
reply = self.sendAccessRequest(username, password) reply = self.send_access_request(username, password)
if reply.code == pyrad.packet.AccessChallenge: if reply.code == pyrad.packet.AccessChallenge:
state = typing.cast(list[bytes], reply.get('State') or [b''])[0] state = typing.cast(list[bytes], reply.get('State') or [b''])[0]

View File

@ -283,11 +283,11 @@ class RegexLdap(auths.Authenticator):
user = ldaputil.first( user = ldaputil.first(
con=self._stablish_connection(), con=self._stablish_connection(),
base=self.ldap_base.as_str(), base=self.ldap_base.as_str(),
objectClass=self.user_class.as_str(), object_class=self.user_class.as_str(),
field=self.userid_attr.as_str(), field=self.userid_attr.as_str(),
value=username, value=username,
attributes=attributes, attributes=attributes,
sizeLimit=LDAP_RESULT_LIMIT, max_entries=LDAP_RESULT_LIMIT,
) )
# If user attributes is split, that is, it has more than one "ldap entry", get a second entry filtering by a new attribute # If user attributes is split, that is, it has more than one "ldap entry", get a second entry filtering by a new attribute

View File

@ -707,15 +707,15 @@ class SAMLAuthenticator(auths.Authenticator):
) )
logger.debug('Groups: %s', groups) logger.debug('Groups: %s', groups)
realName = ' '.join( realname = ' '.join(
auth_utils.process_regex_field( auth_utils.process_regex_field(
self.attrs_realname.value, attributes # pyright: ignore reportUnknownVariableType self.attrs_realname.value, attributes # pyright: ignore reportUnknownVariableType
) )
) )
logger.debug('Real name: %s', realName) logger.debug('Real name: %s', realname)
# store groups for this username at storage, so we can check it at a later stage # store groups for this username at storage, so we can check it at a later stage
self.storage.save_pickled(username, [realName, groups]) self.storage.save_pickled(username, [realname, groups])
# store also the mfa identifier field value, in case we have provided it # store also the mfa identifier field value, in case we have provided it
if self.mfa_attr.value.strip(): if self.mfa_attr.value.strip():

View File

@ -280,11 +280,11 @@ class SimpleLDAPAuthenticator(auths.Authenticator):
return ldaputil.first( return ldaputil.first(
con=self._get_connection(), con=self._get_connection(),
base=self.ldap_base.as_str(), base=self.ldap_base.as_str(),
objectClass=self.user_class.as_str(), object_class=self.user_class.as_str(),
field=self.user_id_attr.as_str(), field=self.user_id_attr.as_str(),
value=username, value=username,
attributes=attributes, attributes=attributes,
sizeLimit=LDAP_RESULT_LIMIT, max_entries=LDAP_RESULT_LIMIT,
) )
def _get_group(self, groupName: str) -> typing.Optional[ldaputil.LDAPResultType]: def _get_group(self, groupName: str) -> typing.Optional[ldaputil.LDAPResultType]:
@ -296,11 +296,11 @@ class SimpleLDAPAuthenticator(auths.Authenticator):
return ldaputil.first( return ldaputil.first(
con=self._get_connection(), con=self._get_connection(),
base=self.ldap_base.as_str(), base=self.ldap_base.as_str(),
objectClass=self.group_class.as_str(), object_class=self.group_class.as_str(),
field=self.group_id_attr.as_str(), field=self.group_id_attr.as_str(),
value=groupName, value=groupName,
attributes=[self.member_attr.as_str()], attributes=[self.member_attr.as_str()],
sizeLimit=LDAP_RESULT_LIMIT, max_entries=LDAP_RESULT_LIMIT,
) )
def _get_groups(self, user: ldaputil.LDAPResultType) -> list[str]: def _get_groups(self, user: ldaputil.LDAPResultType) -> list[str]:

View File

@ -69,7 +69,7 @@ class Authenticator(Module):
As always, if you override __init__, do not forget to invoke base __init__ as this:: As always, if you override __init__, do not forget to invoke base __init__ as this::
super(self.__class__, self).__init__(self, environment, values, dbAuth) super(self.__class__, self).__init__(self, environment, values, uuid)
This is a MUST, so internal structured gets filled correctly, so don't forget it!. This is a MUST, so internal structured gets filled correctly, so don't forget it!.
@ -85,7 +85,7 @@ class Authenticator(Module):
so if an user do not exist at UDS database, it will not be valid. so if an user do not exist at UDS database, it will not be valid.
In other words, if you have an authenticator where you must create users, In other words, if you have an authenticator where you must create users,
you can modify them, you must assign passwords manually, and group membership you can modify them, you must assign passwords manually, and group membership
also must be assigned manually, the authenticator is not an externalSource. also must be assigned manually, the authenticator is not an external_source.
As you can notice, almost avery authenticator except internal db will be As you can notice, almost avery authenticator except internal db will be
external source, so, by default, attribute that indicates that is an external external source, so, by default, attribute that indicates that is an external

View File

@ -83,27 +83,28 @@ class GroupsManager:
""" """
_groups: list[_LocalGrp] _groups: list[_LocalGrp]
_db_auth: 'DBAuthenticator'
def __init__(self, dbAuthenticator: 'DBAuthenticator'): def __init__(self, db_auth: 'DBAuthenticator'):
""" """
Initializes the groups manager. Initializes the groups manager.
The dbAuthenticator is the database record of the authenticator The dbAuthenticator is the database record of the authenticator
to which this groupsManager will be associated to which this groupsManager will be associated
""" """
self._dbAuthenticator = dbAuthenticator self._db_auth = db_auth
# We just get active groups, inactive aren't visible to this class # We just get active groups, inactive aren't visible to this class
self._groups = [] self._groups = []
if ( if (
dbAuthenticator.id db_auth.id
): # If "fake" authenticator (that is, root user with no authenticator in fact) ): # If "fake" authenticator (that is, root user with no authenticator in fact)
for g in dbAuthenticator.groups.filter(state=State.ACTIVE, is_meta=False): for g in db_auth.groups.filter(state=State.ACTIVE, is_meta=False):
name = g.name.lower() name = g.name.lower()
isPattern = name.find('pat:') == 0 # Is a pattern? is_pattern_group = name.startswith('pat:') # Is a pattern?
self._groups.append( self._groups.append(
_LocalGrp( _LocalGrp(
name=name[4:] if isPattern else name, name=name[4:] if is_pattern_group else name,
group=Group(g), group=Group(g),
is_pattern=isPattern, is_pattern=is_pattern_group,
) )
) )
@ -147,7 +148,7 @@ class GroupsManager:
# Now, get metagroups and also return them # Now, get metagroups and also return them
for db_group in DBGroup.objects.filter( for db_group in DBGroup.objects.filter(
manager__id=self._dbAuthenticator.id, is_meta=True manager__id=self._db_auth.id, is_meta=True
): # @UndefinedVariable ): # @UndefinedVariable
gn = db_group.groups.filter( gn = db_group.groups.filter(
id__in=valid_id_list, state=State.ACTIVE id__in=valid_id_list, state=State.ACTIVE

View File

@ -74,10 +74,9 @@ class Environment:
): ):
""" """
Initialized the Environment for the specified id Initialized the Environment for the specified id
@param uniqueKey: Key for this environment
@param idGenerators: Hash of generators of ids for this environment. This "generators of ids" feature Args:
is used basically at User Services to auto-create ids for macs or names, using unique_key: Unique key for the environment
{'mac' : UniqueMacGenerator, 'name' : UniqueNameGenerator } as argument.
""" """
# Avoid circular imports # Avoid circular imports
from uds.core.util.cache import Cache # pylint: disable=import-outside-toplevel from uds.core.util.cache import Cache # pylint: disable=import-outside-toplevel

View File

@ -135,7 +135,7 @@ class DelayedTaskRunner(metaclass=singleton.Singleton):
return return
if task_instance: if task_instance:
logger.debug('Executing delayedTask:>%s<', task) logger.debug('Executing delayedtask:>%s<', task)
# Re-create environment data # Re-create environment data
task_instance.env = Environment.type_environment(task_instance.__class__) task_instance.env = Environment.type_environment(task_instance.__class__)
DelayedTaskThread(task_instance).start() DelayedTaskThread(task_instance).start()

View File

@ -79,7 +79,7 @@ class JobsFactory(factory.Factory['Job']):
job.save() job.save()
except Exception as e: except Exception as e:
logger.debug( logger.debug(
'Exception at ensureJobsInDatabase in JobsFactory: %s, %s', 'Exception at ensure_jobs_registered in JobsFactory: %s, %s',
e.__class__, e.__class__,
e, e,
) )

View File

@ -64,11 +64,11 @@ class JobThread(threading.Thread):
_db_job_id: int _db_job_id: int
_freq: int _freq: int
def __init__(self, jobInstance: 'Job', dbJob: DBScheduler) -> None: def __init__(self, job_instance: 'Job', db_job: DBScheduler) -> None:
super().__init__() super().__init__()
self._job_instance = jobInstance self._job_instance = job_instance
self._db_job_id = dbJob.id self._db_job_id = db_job.id
self._freq = dbJob.frecuency self._freq = db_job.frecuency
def run(self) -> None: def run(self) -> None:
try: try:
@ -148,7 +148,7 @@ class Scheduler:
""" """
Looks for the best waiting job and executes it Looks for the best waiting job and executes it
""" """
jobInstance = None job_instance = None
try: try:
now = sql_now() # Datetimes are based on database server times now = sql_now() # Datetimes are based on database server times
fltr = Q(state=State.FOR_EXECUTE) & ( fltr = Q(state=State.FOR_EXECUTE) & (
@ -172,14 +172,14 @@ class Scheduler:
job.last_execution = now job.last_execution = now
job.save(update_fields=['state', 'owner_server', 'last_execution']) job.save(update_fields=['state', 'owner_server', 'last_execution'])
jobInstance = job.get_instance() job_instance = job.get_instance()
if jobInstance is None: if job_instance is None:
logger.error('Job instance can\'t be resolved for %s, removing it', job) logger.error('Job instance can\'t be resolved for %s, removing it', job)
job.delete() job.delete()
return return
logger.debug('Executing job:>%s<', job.name) logger.debug('Executing job:>%s<', job.name)
JobThread(jobInstance, job).start() # Do not instatiate thread, just run it JobThread(job_instance, job).start() # Do not instatiate thread, just run it
except IndexError: except IndexError:
# Do nothing, there is no jobs for execution # Do nothing, there is no jobs for execution
return return

View File

@ -156,11 +156,11 @@ class CryptoManager(metaclass=singleton.Singleton):
modes.CBC(b'udsinitvectoruds'), modes.CBC(b'udsinitvectoruds'),
backend=default_backend(), backend=default_backend(),
) )
rndStr = secrets.token_bytes(16) # Same as block size of CBC (that is 16 here) rnd_string = secrets.token_bytes(16) # Same as block size of CBC (that is 16 here)
paddedLength = ((len(text) + 4 + 15) // 16) * 16 padded_length = ((len(text) + 4 + 15) // 16) * 16 # calculate padding length, 4 is for length of text
toEncode = struct.pack('>i', len(text)) + text + rndStr[: paddedLength - len(text) - 4] to_encode = struct.pack('>i', len(text)) + text + rnd_string[: padded_length - len(text) - 4]
encryptor = cipher.encryptor() encryptor = cipher.encryptor()
encoded = encryptor.update(toEncode) + encryptor.finalize() encoded = encryptor.update(to_encode) + encryptor.finalize()
if base64: if base64:
encoded = codecs.encode(encoded, 'base64').strip() # Return as bytes encoded = codecs.encode(encoded, 'base64').strip() # Return as bytes
@ -178,8 +178,8 @@ class CryptoManager(metaclass=singleton.Singleton):
) )
decryptor = cipher.decryptor() decryptor = cipher.decryptor()
toDecode = decryptor.update(text) + decryptor.finalize() to_decode = decryptor.update(text) + decryptor.finalize()
return toDecode[4 : 4 + struct.unpack('>i', toDecode[:4])[0]] return to_decode[4 : 4 + struct.unpack('>i', to_decode[:4])[0]]
# Fast encription using django SECRET_KEY as key # Fast encription using django SECRET_KEY as key
def fast_crypt(self, data: bytes) -> bytes: def fast_crypt(self, data: bytes) -> bytes:
@ -212,28 +212,28 @@ class CryptoManager(metaclass=singleton.Singleton):
return self.aes_crypt(text, key) return self.aes_crypt(text, key)
def symmetric_decrypt(self, cryptText: typing.Union[str, bytes], key: typing.Union[str, bytes]) -> str: def symmetric_decrypt(self, encrypted_text: typing.Union[str, bytes], key: typing.Union[str, bytes]) -> str:
if isinstance(cryptText, str): if isinstance(encrypted_text, str):
cryptText = cryptText.encode() encrypted_text = encrypted_text.encode()
if isinstance(key, str): if isinstance(key, str):
key = key.encode() key = key.encode()
if not cryptText or not key: if not encrypted_text or not key:
return '' return ''
try: try:
return self.aes_decrypt(cryptText, key).decode('utf-8') return self.aes_decrypt(encrypted_text, key).decode('utf-8')
except Exception: # Error decoding crypted element, return empty one except Exception: # Error decoding crypted element, return empty one
return '' return ''
def load_private_key( def load_private_key(
self, rsaKey: str self, rsa_key: str
) -> typing.Union['RSAPrivateKey', 'DSAPrivateKey', 'DHPrivateKey', 'EllipticCurvePrivateKey']: ) -> typing.Union['RSAPrivateKey', 'DSAPrivateKey', 'DHPrivateKey', 'EllipticCurvePrivateKey']:
try: try:
return typing.cast( return typing.cast(
typing.Union['RSAPrivateKey', 'DSAPrivateKey', 'DHPrivateKey', 'EllipticCurvePrivateKey'], typing.Union['RSAPrivateKey', 'DSAPrivateKey', 'DHPrivateKey', 'EllipticCurvePrivateKey'],
serialization.load_pem_private_key(rsaKey.encode(), password=None, backend=default_backend()), serialization.load_pem_private_key(rsa_key.encode(), password=None, backend=default_backend()),
) )
except Exception as e: except Exception as e:
raise e raise e
@ -271,32 +271,32 @@ class CryptoManager(metaclass=singleton.Singleton):
# Argon2 # Argon2
return '{ARGON2}' + PasswordHasher(type=ArgonType.ID).hash(value) return '{ARGON2}' + PasswordHasher(type=ArgonType.ID).hash(value)
def check_hash(self, value: typing.Union[str, bytes], hashValue: str) -> bool: def check_hash(self, value: typing.Union[str, bytes], hash_value: str) -> bool:
if isinstance(value, str): if isinstance(value, str):
value = value.encode() value = value.encode()
if not value: if not value:
return not hashValue return not hash_value
if hashValue[:8] == '{SHA256}': if hash_value[:8] == '{SHA256}':
return secrets.compare_digest(hashlib.sha3_256(value).hexdigest(), hashValue[8:]) return secrets.compare_digest(hashlib.sha3_256(value).hexdigest(), hash_value[8:])
if hashValue[:12] == '{SHA256SALT}': if hash_value[:12] == '{SHA256SALT}':
# Extract 16 chars salt and hash # Extract 16 chars salt and hash
salt = hashValue[12:28].encode() salt = hash_value[12:28].encode()
value = salt + value value = salt + value
return secrets.compare_digest(hashlib.sha3_256(value).hexdigest(), hashValue[28:]) return secrets.compare_digest(hashlib.sha3_256(value).hexdigest(), hash_value[28:])
# Argon2 # Argon2
if hashValue[:8] == '{ARGON2}': if hash_value[:8] == '{ARGON2}':
ph = PasswordHasher() # Type is implicit in hash ph = PasswordHasher() # Type is implicit in hash
try: try:
ph.verify(hashValue[8:], value) ph.verify(hash_value[8:], value)
return True return True
except Exception: except Exception:
return False # Verify will raise an exception if not valid return False # Verify will raise an exception if not valid
# Old sha1 # Old sha1
return secrets.compare_digest( return secrets.compare_digest(
hashValue, hash_value,
str( str(
hashlib.sha1( hashlib.sha1(
value value

View File

@ -51,7 +51,7 @@ class DownloadsManager(metaclass=singleton.Singleton):
For registering, use at __init__.py of the conecto something like this: For registering, use at __init__.py of the conecto something like this:
from uds.core.managers import DownloadsManager from uds.core.managers import DownloadsManager
import os.path, sys import os.path, sys
downloadsManager().registerDownloadable('test.exe', DownloadsManager.manager().registerDownloadable('test.exe',
_('comments for test'), _('comments for test'),
os.path.join(os.path.dirname(sys.modules[__package__].__file__), 'files/test.exe'), os.path.join(os.path.dirname(sys.modules[__package__].__file__), 'files/test.exe'),
'application/x-msdos-program') 'application/x-msdos-program')

View File

@ -63,26 +63,26 @@ class PublicationOldMachinesCleaner(DelayedTask):
This delayed task is for removing a pending "removable" publication This delayed task is for removing a pending "removable" publication
""" """
def __init__(self, publicationId: int): def __init__(self, publication_id: int):
super().__init__() super().__init__()
self._id = publicationId self._id = publication_id
def run(self) -> None: def run(self) -> None:
try: try:
servicePoolPub: ServicePoolPublication = ServicePoolPublication.objects.get(pk=self._id) servicepool_publication: ServicePoolPublication = ServicePoolPublication.objects.get(pk=self._id)
if servicePoolPub.state != State.REMOVABLE: if servicepool_publication.state != State.REMOVABLE:
logger.info('Already removed') logger.info('Already removed')
now = sql_now() now = sql_now()
current_publication: typing.Optional[ServicePoolPublication] = ( current_publication: typing.Optional[ServicePoolPublication] = (
servicePoolPub.deployed_service.active_publication() servicepool_publication.deployed_service.active_publication()
) )
if current_publication: if current_publication:
servicePoolPub.deployed_service.userServices.filter(in_use=True).exclude( servicepool_publication.deployed_service.userServices.filter(in_use=True).exclude(
publication=current_publication publication=current_publication
).update(in_use=False, state_date=now) ).update(in_use=False, state_date=now)
servicePoolPub.deployed_service.mark_old_userservices_as_removable(current_publication) servicepool_publication.deployed_service.mark_old_userservices_as_removable(current_publication)
except Exception: # nosec: Removed publication, no problem at all, just continue except Exception: # nosec: Removed publication, no problem at all, just continue
pass pass
@ -102,7 +102,9 @@ class PublicationLauncher(DelayedTask):
try: try:
now = sql_now() now = sql_now()
with transaction.atomic(): with transaction.atomic():
servicepool_publication = ServicePoolPublication.objects.select_for_update().get(pk=self._publication_id) servicepool_publication = ServicePoolPublication.objects.select_for_update().get(
pk=self._publication_id
)
if not servicepool_publication: if not servicepool_publication:
raise ServicePool.DoesNotExist() raise ServicePool.DoesNotExist()
if ( if (
@ -113,13 +115,13 @@ class PublicationLauncher(DelayedTask):
servicepool_publication.save() servicepool_publication.save()
pi = servicepool_publication.get_instance() pi = servicepool_publication.get_instance()
state = pi.publish() state = pi.publish()
servicePool: ServicePool = servicepool_publication.deployed_service servicepool: ServicePool = servicepool_publication.deployed_service
servicePool.current_pub_revision += 1 servicepool.current_pub_revision += 1
servicePool.set_value( servicepool.set_value(
'toBeReplacedIn', 'toBeReplacedIn',
serialize(now + datetime.timedelta(hours=GlobalConfig.SESSION_EXPIRE_TIME.as_int(True))), serialize(now + datetime.timedelta(hours=GlobalConfig.SESSION_EXPIRE_TIME.as_int(True))),
) )
servicePool.save() servicepool.save()
PublicationFinishChecker.state_updater(servicepool_publication, pi, state) PublicationFinishChecker.state_updater(servicepool_publication, pi, state)
except ( except (
ServicePoolPublication.DoesNotExist ServicePoolPublication.DoesNotExist
@ -196,21 +198,21 @@ class PublicationFinishChecker(DelayedTask):
publication.update_data(publication_instance) publication.update_data(publication_instance)
if check_later: if check_later:
PublicationFinishChecker.check_later(publication, publication_instance) PublicationFinishChecker.check_later(publication)
except Exception: except Exception:
logger.exception('At checkAndUpdate for publication') logger.exception('At check_and_update for publication')
PublicationFinishChecker.check_later(publication, publication_instance) PublicationFinishChecker.check_later(publication)
@staticmethod @staticmethod
def check_later(publication: ServicePoolPublication, publicationInstance: 'services.Publication') -> None: def check_later(publication: ServicePoolPublication) -> None:
""" """
Inserts a task in the delayedTaskRunner so we can check the state of this publication Inserts a task in the delayed_task_runner so we can check the state of this publication
@param dps: Database object for ServicePoolPublication @param dps: Database object for ServicePoolPublication
@param pi: Instance of Publication manager for the object @param pi: Instance of Publication manager for the object
""" """
DelayedTaskRunner.runner().insert( DelayedTaskRunner.runner().insert(
PublicationFinishChecker(publication), PublicationFinishChecker(publication),
publicationInstance.suggested_delay, publication.get_instance().suggested_delay,
PUBTAG + str(publication.id), PUBTAG + str(publication.id),
) )
@ -221,13 +223,13 @@ class PublicationFinishChecker(DelayedTask):
if publication.state != self._state: if publication.state != self._state:
logger.debug('Task overrided by another task (state of item changed)') logger.debug('Task overrided by another task (state of item changed)')
else: else:
publicationInstance = publication.get_instance() publication_instance = publication.get_instance()
logger.debug("publication instance class: %s", publicationInstance.__class__) logger.debug("publication instance class: %s", publication_instance.__class__)
try: try:
state = publicationInstance.check_state() state = publication_instance.check_state()
except Exception: except Exception:
state = types.states.TaskState.ERROR state = types.states.TaskState.ERROR
PublicationFinishChecker.state_updater(publication, publicationInstance, state) PublicationFinishChecker.state_updater(publication, publication_instance, state)
except Exception as e: except Exception as e:
logger.debug( logger.debug(
'Deployed service not found (erased from database) %s : %s', 'Deployed service not found (erased from database) %s : %s',
@ -251,11 +253,13 @@ class PublicationManager(metaclass=singleton.Singleton):
""" """
return PublicationManager() # Singleton pattern will return always the same instance return PublicationManager() # Singleton pattern will return always the same instance
def publish(self, servicepool: ServicePool, changeLog: typing.Optional[str] = None) -> None: def publish(self, servicepool: ServicePool, changelog: typing.Optional[str] = None) -> None:
""" """
Initiates the publication of a service pool, or raises an exception if this cannot be done Initiates the publication of a service pool, or raises an exception if this cannot be done
:param servicePool: Service pool object (db object)
:param changeLog: if not None, store change log string on "change log" table Args:
servicepool: Service pool to publish
changelog: Optional changelog to store
""" """
if servicepool.publications.filter(state__in=State.PUBLISH_STATES).count() > 0: if servicepool.publications.filter(state__in=State.PUBLISH_STATES).count() > 0:
raise PublishException( raise PublishException(
@ -274,9 +278,9 @@ class PublicationManager(metaclass=singleton.Singleton):
publish_date=now, publish_date=now,
revision=servicepool.current_pub_revision, revision=servicepool.current_pub_revision,
) )
if changeLog: if changelog:
servicepool.changelog.create( servicepool.changelog.create(
revision=servicepool.current_pub_revision, log=changeLog, stamp=now revision=servicepool.current_pub_revision, log=changelog, stamp=now
) )
if publication: if publication:
DelayedTaskRunner.runner().insert( DelayedTaskRunner.runner().insert(
@ -295,7 +299,10 @@ class PublicationManager(metaclass=singleton.Singleton):
""" """
Invoked to cancel a publication. Invoked to cancel a publication.
Double invokation (i.e. invokation over a "cancelling" item) will lead to a "forced" cancellation (unclean) Double invokation (i.e. invokation over a "cancelling" item) will lead to a "forced" cancellation (unclean)
:param servicePoolPub: Service pool publication (db object for a publication)
Args:
publication: Publication to cancel
""" """
publication = ServicePoolPublication.objects.get(pk=publication.id) # Reloads publication from db publication = ServicePoolPublication.objects.get(pk=publication.id) # Reloads publication from db
if publication.state not in State.PUBLISH_STATES: if publication.state not in State.PUBLISH_STATES:
@ -330,7 +337,10 @@ class PublicationManager(metaclass=singleton.Singleton):
def unpublish(self, servicepool_publication: ServicePoolPublication) -> None: def unpublish(self, servicepool_publication: ServicePoolPublication) -> None:
""" """
Unpublishes an active (usable) or removable publication Unpublishes an active (usable) or removable publication
:param servicePoolPub: Publication to unpublish
Args:
servicepool_publication: Publication to unpublish
""" """
if ( if (
State.from_str(servicepool_publication.state).is_usable() is False State.from_str(servicepool_publication.state).is_usable() is False
@ -340,9 +350,9 @@ class PublicationManager(metaclass=singleton.Singleton):
if servicepool_publication.userServices.exclude(state__in=State.INFO_STATES).count() > 0: if servicepool_publication.userServices.exclude(state__in=State.INFO_STATES).count() > 0:
raise PublishException(_('Can\'t unpublish publications with services in process')) raise PublishException(_('Can\'t unpublish publications with services in process'))
try: try:
pubInstance = servicepool_publication.get_instance() publication_instance = servicepool_publication.get_instance()
state = pubInstance.destroy() state = publication_instance.destroy()
servicepool_publication.set_state(State.REMOVING) servicepool_publication.set_state(State.REMOVING)
PublicationFinishChecker.state_updater(servicepool_publication, pubInstance, state) PublicationFinishChecker.state_updater(servicepool_publication, publication_instance, state)
except Exception as e: except Exception as e:
raise PublishException(str(e)) from e raise PublishException(str(e)) from e

View File

@ -337,10 +337,10 @@ class ServerManager(metaclass=singleton.Singleton):
Unassigns a server from an user Unassigns a server from an user
Args: Args:
userService: User service to unassign server from userservice: User service to unassign server from
serverGroups: Server group to unassign server from server_group: Server group to unassign server from
unlock: If True, unlock server, even if it has more users assigned to it unlock: If True, unlock server, even if it has more users assigned to it
userUuid: If not None, use this uuid instead of userService.user.uuid user_uuid: If not None, use this uuid instead of userservice.user.uuid
""" """
user_uuid = user_uuid if user_uuid else userservice.user.uuid if userservice.user else None user_uuid = user_uuid if user_uuid else userservice.user.uuid if userservice.user else None
@ -350,37 +350,37 @@ class ServerManager(metaclass=singleton.Singleton):
prop_name = self.property_name(userservice.user) prop_name = self.property_name(userservice.user)
with server_group.properties as props: with server_group.properties as props:
with transaction.atomic(): with transaction.atomic():
resetCounter = False reset_counter = False
# ServerCounterType # ServerCounterType
serverCounter: typing.Optional[types.servers.ServerCounter] = ( server_counter: typing.Optional[types.servers.ServerCounter] = (
types.servers.ServerCounter.from_iterable(props.get(prop_name)) types.servers.ServerCounter.from_iterable(props.get(prop_name))
) )
# If no cached value, get server assignation # If no cached value, get server assignation
if serverCounter is None: if server_counter is None:
return types.servers.ServerCounter.null() return types.servers.ServerCounter.null()
# Ensure counter is at least 1 # Ensure counter is at least 1
serverCounter = types.servers.ServerCounter( server_counter = types.servers.ServerCounter(
serverCounter.server_uuid, max(1, serverCounter.counter) server_counter.server_uuid, max(1, server_counter.counter)
) )
if serverCounter.counter == 1 or unlock: if server_counter.counter == 1 or unlock:
# Last one, remove it # Last one, remove it
del props[prop_name] del props[prop_name]
else: # Not last one, just decrement counter else: # Not last one, just decrement counter
props[prop_name] = (serverCounter.server_uuid, serverCounter.counter - 1) props[prop_name] = (server_counter.server_uuid, server_counter.counter - 1)
server = models.Server.objects.get(uuid=serverCounter.server_uuid) server = models.Server.objects.get(uuid=server_counter.server_uuid)
if unlock or serverCounter.counter == 1: if unlock or server_counter.counter == 1:
server.locked_until = None # Ensure server is unlocked if no more users are assigned to it server.locked_until = None # Ensure server is unlocked if no more users are assigned to it
server.save(update_fields=['locked_until']) server.save(update_fields=['locked_until'])
# Enure server counter is cleaned also, because server is considered "fully released" # Enure server counter is cleaned also, because server is considered "fully released"
resetCounter = True reset_counter = True
# If unmanaged, decrease usage # If unmanaged, decrease usage
if server.type == types.servers.ServerType.UNMANAGED: if server.type == types.servers.ServerType.UNMANAGED:
self.decrement_unmanaged_usage(server.uuid, force_reset=resetCounter) self.decrement_unmanaged_usage(server.uuid, force_reset=reset_counter)
# Ensure next assignation will have updated stats # Ensure next assignation will have updated stats
# This is a simple simulation on cached stats, will be updated on next stats retrieval # This is a simple simulation on cached stats, will be updated on next stats retrieval
@ -389,7 +389,7 @@ class ServerManager(metaclass=singleton.Singleton):
self.notify_release(server, userservice) self.notify_release(server, userservice)
return types.servers.ServerCounter(serverCounter.server_uuid, serverCounter.counter - 1) return types.servers.ServerCounter(server_counter.server_uuid, server_counter.counter - 1)
def notify_preconnect( def notify_preconnect(
self, self,
@ -424,12 +424,12 @@ class ServerManager(metaclass=singleton.Singleton):
def notify_release( def notify_release(
self, self,
server: 'models.Server', server: 'models.Server',
userService: 'models.UserService', userservice: 'models.UserService',
) -> None: ) -> None:
""" """
Notifies release to server Notifies release to server
""" """
requester.ServerApiRequester(server).notify_release(userService) requester.ServerApiRequester(server).notify_release(userservice)
def assignation_info(self, server_group: 'models.ServerGroup') -> dict[str, int]: def assignation_info(self, server_group: 'models.ServerGroup') -> dict[str, int]:
""" """

View File

@ -172,7 +172,8 @@ def ensure_connected(
# def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: ... # def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: ...
# def cache_clear(self) -> None: ... # def cache_clear(self) -> None: ...
# def cache_info(self) -> CacheInfo: ... # def cache_info(self) -> CacheInfo: ...
# Now, we could use this by creating two decorators, one for the class methods and one for the functions
# But the inheritance problem will still be there, so we will keep the current implementation
# Decorator for caching # Decorator for caching
# This decorator will cache the result of the function for a given time, and given parameters # This decorator will cache the result of the function for a given time, and given parameters

View File

@ -169,13 +169,13 @@ def connection(
# Disable TLS1 and TLS1.1 # Disable TLS1 and TLS1.1
# 0x304 = TLS1.3, 0x303 = TLS1.2, 0x302 = TLS1.1, 0x301 = TLS1.0, but use ldap module constants # 0x304 = TLS1.3, 0x303 = TLS1.2, 0x302 = TLS1.1, 0x301 = TLS1.0, but use ldap module constants
# Ensure that libldap is compiled with TLS1.3 support # Ensure that libldap is compiled with TLS1.3 support
minVersion = getattr(settings, 'SECURE_MIN_TLS_VERSION', '1.2') min_tls_version = getattr(settings, 'SECURE_MIN_TLS_VERSION', '1.2')
if hasattr(ldap, 'OPT_X_TLS_PROTOCOL_TLS1_3'): if hasattr(ldap, 'OPT_X_TLS_PROTOCOL_TLS1_3'):
tls_version: typing.Any = { # for pyright to ignore tls_version: typing.Any = { # for pyright to ignore
'1.2': ldap.OPT_X_TLS_PROTOCOL_TLS1_2, # pyright: ignore '1.2': ldap.OPT_X_TLS_PROTOCOL_TLS1_2, # pyright: ignore
'1.3': ldap.OPT_X_TLS_PROTOCOL_TLS1_3, # pyright: ignore '1.3': ldap.OPT_X_TLS_PROTOCOL_TLS1_3, # pyright: ignore
}.get( }.get(
minVersion, ldap.OPT_X_TLS_PROTOCOL_TLS1_2 # pyright: ignore min_tls_version, ldap.OPT_X_TLS_PROTOCOL_TLS1_2 # pyright: ignore
) )
l.set_option(ldap.OPT_X_TLS_PROTOCOL_MIN, tls_version) # pyright: ignore l.set_option(ldap.OPT_X_TLS_PROTOCOL_MIN, tls_version) # pyright: ignore
@ -271,26 +271,35 @@ def as_dict(
def first( def first(
con: 'LDAPObject', con: 'LDAPObject',
base: str, base: str,
objectClass: str, object_class: str,
field: str, field: str,
value: str, value: str,
attributes: typing.Optional[collections.abc.Iterable[str]] = None, attributes: typing.Optional[collections.abc.Iterable[str]] = None,
sizeLimit: int = 50, max_entries: int = 50,
) -> typing.Optional[LDAPResultType]: ) -> typing.Optional[LDAPResultType]:
""" """
Searchs for the username and returns its LDAP entry Searchs for the username and returns its LDAP entry
@param username: username to search, using user provided parameters at configuration to map search entries.
@param objectClass: Objectclass of the user mane username to search. Args:
@return: None if username is not found, an dictionary of LDAP entry attributes if found (all in unicode on py2, str on py3). con (LDAPObject): Connection to LDAP
base (str): Base to search
object_class (str): Object class to search
field (str): Field to search
value (str): Value to search
attributes (typing.Optional[collections.abc.Iterable[str]], optional): Attributes to return. Defaults to None.
max_entries (int, optional): Max entries to return. Defaults to 50.
Returns:
typing.Optional[LDAPResultType]: Result of the search
""" """
value = ldap.filter.escape_filter_chars(value) # pyright: ignore reportGeneralTypeIssues value = ldap.filter.escape_filter_chars(value) # pyright: ignore reportGeneralTypeIssues
attrList = [field] + list(attributes) if attributes else [] attributes = [field] + list(attributes) if attributes else []
ldapFilter = f'(&(objectClass={objectClass})({field}={value}))' ldap_filter = f'(&(objectClass={object_class})({field}={value}))'
try: try:
obj = next(as_dict(con, base, ldapFilter, attrList, sizeLimit)) obj = next(as_dict(con, base, ldap_filter, attributes, max_entries))
except StopIteration: except StopIteration:
return None # None found return None # None found

View File

@ -116,8 +116,8 @@ class RadiusOTP(mfas.MFA):
return client.RadiusClient( return client.RadiusClient(
self.server.value, self.server.value,
self.secret.value.encode(), self.secret.value.encode(),
authPort=self.port.as_int(), auth_port=self.port.as_int(),
nasIdentifier=self.nas_identifier.value, nas_identifier=self.nas_identifier.value,
) )
def check_result(self, action: str, request: 'ExtendedHttpRequest') -> mfas.MFA.RESULT: def check_result(self, action: str, request: 'ExtendedHttpRequest') -> mfas.MFA.RESULT: