1
0
mirror of https://github.com/dkmstr/openuds.git synced 2024-12-22 13:34:04 +03:00

snake_case fixin before stabilizing... :)

This commit is contained in:
Adolfo Gómez García 2024-10-12 17:55:14 +02:00
parent 25aa09309b
commit bd53926c81
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
19 changed files with 194 additions and 170 deletions

View File

@ -156,9 +156,9 @@ class InternalDBAuth(auths.Authenticator):
request: 'ExtendedHttpRequest',
) -> types.auth.AuthenticationResult:
username = username.lower()
dbAuth = self.db_obj()
auth_db = self.db_obj()
try:
user: 'models.User' = dbAuth.users.get(name=username, state=State.ACTIVE)
user: 'models.User' = auth_db.users.get(name=username, state=State.ACTIVE)
except Exception:
log_login(request, self.db_obj(), username, 'Invalid user', as_error=True)
return types.auth.FAILED_AUTH
@ -175,15 +175,15 @@ class InternalDBAuth(auths.Authenticator):
return types.auth.FAILED_AUTH
def get_groups(self, username: str, groups_manager: 'auths.GroupsManager') -> None:
dbAuth = self.db_obj()
auth_db = self.db_obj()
try:
user: 'models.User' = dbAuth.users.get(name=username.lower(), state=State.ACTIVE)
user: 'models.User' = auth_db.users.get(name=username.lower(), state=State.ACTIVE)
except Exception:
return
grps = [g.name for g in user.groups.all()]
if user.parent:
try:
parent = dbAuth.users.get(uuid=user.parent, state=State.ACTIVE)
parent = auth_db.users.get(uuid=user.parent, state=State.ACTIVE)
grps.extend([g.name for g in parent.groups.all()])
except Exception:
pass

View File

@ -84,7 +84,7 @@ class RadiusAuth(auths.Authenticator):
required=True,
)
nasIdentifier = gui.TextField(
nas_identifier = gui.TextField(
length=64,
label=_('NAS Identifier'),
default='uds-server',
@ -92,26 +92,29 @@ class RadiusAuth(auths.Authenticator):
tooltip=_('NAS Identifier for Radius Server'),
required=True,
tab=types.ui.Tab.ADVANCED,
old_field_name='nasIdentifier',
)
appClassPrefix = gui.TextField(
app_class_prefix = gui.TextField(
length=64,
label=_('App Prefix for Class Attributes'),
default='',
order=11,
tooltip=_('Application prefix for filtering groups from "Class" attribute'),
tab=types.ui.Tab.ADVANCED,
old_field_name='appClassPrefix',
)
globalGroup = gui.TextField(
global_group = gui.TextField(
length=64,
label=_('Global group'),
default='',
order=12,
tooltip=_('If set, this value will be added as group for all radius users'),
tab=types.ui.Tab.ADVANCED,
old_field_name='globalGroup',
)
mfaAttr = gui.TextField(
mfa_attr = gui.TextField(
length=2048,
lines=2,
label=_('MFA attribute'),
@ -119,6 +122,7 @@ class RadiusAuth(auths.Authenticator):
tooltip=_('Attribute from where to extract the MFA code'),
required=False,
tab=types.ui.Tab.MFA,
old_field_name='mfaAttr',
)
def initialize(self, values: typing.Optional[dict[str, typing.Any]]) -> None:
@ -129,9 +133,9 @@ class RadiusAuth(auths.Authenticator):
return client.RadiusClient(
self.server.value,
self.secret.value.encode(),
authPort=self.port.as_int(),
nasIdentifier=self.nasIdentifier.value,
appClassPrefix=self.appClassPrefix.value,
auth_port=self.port.as_int(),
nas_identifier=self.nas_identifier.value,
appclass_prefix=self.app_class_prefix.value,
)
def mfa_storage_key(self, username: str) -> str:
@ -149,17 +153,17 @@ class RadiusAuth(auths.Authenticator):
) -> types.auth.AuthenticationResult:
try:
connection = self.radius_client()
groups, mfaCode, state = connection.authenticate(
username=username, password=credentials, mfaField=self.mfaAttr.value.strip()
groups, mfa_code, state = connection.authenticate(
username=username, password=credentials, mfa_field=self.mfa_attr.value.strip()
)
# If state, store in session
if state:
request.session[client.STATE_VAR_NAME] = state.decode()
# store the user mfa attribute if it is set
if mfaCode:
if mfa_code:
self.storage.save_pickled(
self.mfa_storage_key(username),
mfaCode,
mfa_code,
)
except Exception:
@ -172,8 +176,8 @@ class RadiusAuth(auths.Authenticator):
)
return types.auth.FAILED_AUTH
if self.globalGroup.value.strip():
groups.append(self.globalGroup.value.strip())
if self.global_group.value.strip():
groups.append(self.global_group.value.strip())
# Cache groups for "getGroups", because radius will not send us those
with self.storage.as_dict() as storage:

View File

@ -110,49 +110,49 @@ class RadiusResult:
"""
pwd: RadiusStates = RadiusStates.INCORRECT
replyMessage: typing.Optional[bytes] = None
reply_message: typing.Optional[bytes] = None
state: typing.Optional[bytes] = None
otp: RadiusStates = RadiusStates.NOT_CHECKED
otp_needed: RadiusStates = RadiusStates.NOT_CHECKED
class RadiusClient:
radiusServer: Client
nasIdentifier: str
appClassPrefix: str
server: Client
nas_identifier: str
appclass_prefix: str
def __init__(
self,
server: str,
secret: bytes,
*,
authPort: int = 1812,
nasIdentifier: str = 'uds-server',
appClassPrefix: str = '',
auth_port: int = 1812,
nas_identifier: str = 'uds-server',
appclass_prefix: str = '',
dictionary: str = RADDICT,
) -> None:
self.radiusServer = Client(
self.server = Client(
server=server,
authport=authPort,
authport=auth_port,
secret=secret,
dict=Dictionary(io.StringIO(dictionary)),
)
self.nasIdentifier = nasIdentifier
self.appClassPrefix = appClassPrefix
self.nas_identifier = nas_identifier
self.appclass_prefix = appclass_prefix
def extractAccessChallenge(self, reply: pyrad.packet.AuthPacket) -> RadiusResult:
def extract_access_challenge(self, reply: pyrad.packet.AuthPacket) -> RadiusResult:
return RadiusResult(
pwd=RadiusStates.CORRECT,
replyMessage=typing.cast(list[bytes], reply.get('Reply-Message') or [''])[0],
reply_message=typing.cast(list[bytes], reply.get('Reply-Message') or [''])[0],
state=typing.cast(list[bytes], reply.get('State') or [b''])[0],
otp_needed=RadiusStates.NEEDED,
)
def sendAccessRequest(self, username: str, password: str, **kwargs: typing.Any) -> pyrad.packet.AuthPacket:
req: pyrad.packet.AuthPacket = self.radiusServer.CreateAuthPacket(
def send_access_request(self, username: str, password: str, **kwargs: typing.Any) -> pyrad.packet.AuthPacket:
req: pyrad.packet.AuthPacket = self.server.CreateAuthPacket(
code=pyrad.packet.AccessRequest,
User_Name=username,
NAS_Identifier=self.nasIdentifier,
NAS_Identifier=self.nas_identifier,
)
req["User-Password"] = req.PwCrypt(password)
@ -161,46 +161,46 @@ class RadiusClient:
for k, v in kwargs.items():
req[k] = v
return typing.cast(pyrad.packet.AuthPacket, self.radiusServer.SendPacket(req))
return typing.cast(pyrad.packet.AuthPacket, self.server.SendPacket(req))
# Second element of return value is the mfa code from field
def authenticate(
self, username: str, password: str, mfaField: str = ''
self, username: str, password: str, mfa_field: str = ''
) -> tuple[list[str], str, bytes]:
reply = self.sendAccessRequest(username, password)
reply = self.send_access_request(username, password)
if reply.code not in (pyrad.packet.AccessAccept, pyrad.packet.AccessChallenge):
raise RadiusAuthenticationError('Access denied')
# User accepted, extract groups...
# All radius users belongs to, at least, 'uds-users' group
groupClassPrefix = (self.appClassPrefix + 'group=').encode()
groupClassPrefixLen = len(groupClassPrefix)
groupclass_prefix = (self.appclass_prefix + 'group=').encode()
groupclass_prefix_len = len(groupclass_prefix)
if 'Class' in reply:
groups = [
i[groupClassPrefixLen:].decode()
i[groupclass_prefix_len:].decode()
for i in typing.cast(collections.abc.Iterable[bytes], reply['Class'])
if i.startswith(groupClassPrefix)
if i.startswith(groupclass_prefix)
]
else:
logger.info('No "Class (25)" attribute found')
return ([], '', b'')
# ...and mfa code
mfaCode = ''
if mfaField and mfaField in reply:
mfaCode = ''.join(
i[groupClassPrefixLen:].decode()
mfa_code = ''
if mfa_field and mfa_field in reply:
mfa_code = ''.join(
i[groupclass_prefix_len:].decode()
for i in typing.cast(collections.abc.Iterable[bytes], reply['Class'])
if i.startswith(groupClassPrefix)
if i.startswith(groupclass_prefix)
)
return (groups, mfaCode, typing.cast(list[bytes], reply.get('State') or [b''])[0])
return (groups, mfa_code, typing.cast(list[bytes], reply.get('State') or [b''])[0])
def authenticate_only(self, username: str, password: str) -> RadiusResult:
reply = self.sendAccessRequest(username, password)
reply = self.send_access_request(username, password)
if reply.code == pyrad.packet.AccessChallenge:
return self.extractAccessChallenge(reply)
return self.extract_access_challenge(reply)
# user/pwd accepted: this user does not have challenge data
if reply.code == pyrad.packet.AccessAccept:
@ -221,7 +221,7 @@ class RadiusClient:
logger.debug('Sending AccessChallenge request wit otp [%s]', otp)
reply = self.sendAccessRequest(username, otp, State=state)
reply = self.send_access_request(username, otp, State=state)
logger.debug('Received AccessChallenge reply: %s', reply)
@ -238,7 +238,7 @@ class RadiusClient:
)
def authenticate_and_challenge(self, username: str, password: str, otp: str) -> RadiusResult:
reply = self.sendAccessRequest(username, password)
reply = self.send_access_request(username, password)
if reply.code == pyrad.packet.AccessChallenge:
state = typing.cast(list[bytes], reply.get('State') or [b''])[0]

View File

@ -283,11 +283,11 @@ class RegexLdap(auths.Authenticator):
user = ldaputil.first(
con=self._stablish_connection(),
base=self.ldap_base.as_str(),
objectClass=self.user_class.as_str(),
object_class=self.user_class.as_str(),
field=self.userid_attr.as_str(),
value=username,
attributes=attributes,
sizeLimit=LDAP_RESULT_LIMIT,
max_entries=LDAP_RESULT_LIMIT,
)
# If user attributes is split, that is, it has more than one "ldap entry", get a second entry filtering by a new attribute

View File

@ -707,15 +707,15 @@ class SAMLAuthenticator(auths.Authenticator):
)
logger.debug('Groups: %s', groups)
realName = ' '.join(
realname = ' '.join(
auth_utils.process_regex_field(
self.attrs_realname.value, attributes # pyright: ignore reportUnknownVariableType
)
)
logger.debug('Real name: %s', realName)
logger.debug('Real name: %s', realname)
# store groups for this username at storage, so we can check it at a later stage
self.storage.save_pickled(username, [realName, groups])
self.storage.save_pickled(username, [realname, groups])
# store also the mfa identifier field value, in case we have provided it
if self.mfa_attr.value.strip():

View File

@ -280,11 +280,11 @@ class SimpleLDAPAuthenticator(auths.Authenticator):
return ldaputil.first(
con=self._get_connection(),
base=self.ldap_base.as_str(),
objectClass=self.user_class.as_str(),
object_class=self.user_class.as_str(),
field=self.user_id_attr.as_str(),
value=username,
attributes=attributes,
sizeLimit=LDAP_RESULT_LIMIT,
max_entries=LDAP_RESULT_LIMIT,
)
def _get_group(self, groupName: str) -> typing.Optional[ldaputil.LDAPResultType]:
@ -296,11 +296,11 @@ class SimpleLDAPAuthenticator(auths.Authenticator):
return ldaputil.first(
con=self._get_connection(),
base=self.ldap_base.as_str(),
objectClass=self.group_class.as_str(),
object_class=self.group_class.as_str(),
field=self.group_id_attr.as_str(),
value=groupName,
attributes=[self.member_attr.as_str()],
sizeLimit=LDAP_RESULT_LIMIT,
max_entries=LDAP_RESULT_LIMIT,
)
def _get_groups(self, user: ldaputil.LDAPResultType) -> list[str]:

View File

@ -69,7 +69,7 @@ class Authenticator(Module):
As always, if you override __init__, do not forget to invoke base __init__ as this::
super(self.__class__, self).__init__(self, environment, values, dbAuth)
super(self.__class__, self).__init__(self, environment, values, uuid)
This is a MUST, so internal structured gets filled correctly, so don't forget it!.
@ -85,7 +85,7 @@ class Authenticator(Module):
so if an user do not exist at UDS database, it will not be valid.
In other words, if you have an authenticator where you must create users,
you can modify them, you must assign passwords manually, and group membership
also must be assigned manually, the authenticator is not an externalSource.
also must be assigned manually, the authenticator is not an external_source.
As you can notice, almost avery authenticator except internal db will be
external source, so, by default, attribute that indicates that is an external

View File

@ -83,27 +83,28 @@ class GroupsManager:
"""
_groups: list[_LocalGrp]
_db_auth: 'DBAuthenticator'
def __init__(self, dbAuthenticator: 'DBAuthenticator'):
def __init__(self, db_auth: 'DBAuthenticator'):
"""
Initializes the groups manager.
The dbAuthenticator is the database record of the authenticator
to which this groupsManager will be associated
"""
self._dbAuthenticator = dbAuthenticator
self._db_auth = db_auth
# We just get active groups, inactive aren't visible to this class
self._groups = []
if (
dbAuthenticator.id
db_auth.id
): # If "fake" authenticator (that is, root user with no authenticator in fact)
for g in dbAuthenticator.groups.filter(state=State.ACTIVE, is_meta=False):
for g in db_auth.groups.filter(state=State.ACTIVE, is_meta=False):
name = g.name.lower()
isPattern = name.find('pat:') == 0 # Is a pattern?
is_pattern_group = name.startswith('pat:') # Is a pattern?
self._groups.append(
_LocalGrp(
name=name[4:] if isPattern else name,
name=name[4:] if is_pattern_group else name,
group=Group(g),
is_pattern=isPattern,
is_pattern=is_pattern_group,
)
)
@ -147,7 +148,7 @@ class GroupsManager:
# Now, get metagroups and also return them
for db_group in DBGroup.objects.filter(
manager__id=self._dbAuthenticator.id, is_meta=True
manager__id=self._db_auth.id, is_meta=True
): # @UndefinedVariable
gn = db_group.groups.filter(
id__in=valid_id_list, state=State.ACTIVE

View File

@ -74,10 +74,9 @@ class Environment:
):
"""
Initialized the Environment for the specified id
@param uniqueKey: Key for this environment
@param idGenerators: Hash of generators of ids for this environment. This "generators of ids" feature
is used basically at User Services to auto-create ids for macs or names, using
{'mac' : UniqueMacGenerator, 'name' : UniqueNameGenerator } as argument.
Args:
unique_key: Unique key for the environment
"""
# Avoid circular imports
from uds.core.util.cache import Cache # pylint: disable=import-outside-toplevel

View File

@ -135,7 +135,7 @@ class DelayedTaskRunner(metaclass=singleton.Singleton):
return
if task_instance:
logger.debug('Executing delayedTask:>%s<', task)
logger.debug('Executing delayedtask:>%s<', task)
# Re-create environment data
task_instance.env = Environment.type_environment(task_instance.__class__)
DelayedTaskThread(task_instance).start()

View File

@ -79,7 +79,7 @@ class JobsFactory(factory.Factory['Job']):
job.save()
except Exception as e:
logger.debug(
'Exception at ensureJobsInDatabase in JobsFactory: %s, %s',
'Exception at ensure_jobs_registered in JobsFactory: %s, %s',
e.__class__,
e,
)

View File

@ -64,11 +64,11 @@ class JobThread(threading.Thread):
_db_job_id: int
_freq: int
def __init__(self, jobInstance: 'Job', dbJob: DBScheduler) -> None:
def __init__(self, job_instance: 'Job', db_job: DBScheduler) -> None:
super().__init__()
self._job_instance = jobInstance
self._db_job_id = dbJob.id
self._freq = dbJob.frecuency
self._job_instance = job_instance
self._db_job_id = db_job.id
self._freq = db_job.frecuency
def run(self) -> None:
try:
@ -148,7 +148,7 @@ class Scheduler:
"""
Looks for the best waiting job and executes it
"""
jobInstance = None
job_instance = None
try:
now = sql_now() # Datetimes are based on database server times
fltr = Q(state=State.FOR_EXECUTE) & (
@ -172,14 +172,14 @@ class Scheduler:
job.last_execution = now
job.save(update_fields=['state', 'owner_server', 'last_execution'])
jobInstance = job.get_instance()
job_instance = job.get_instance()
if jobInstance is None:
if job_instance is None:
logger.error('Job instance can\'t be resolved for %s, removing it', job)
job.delete()
return
logger.debug('Executing job:>%s<', job.name)
JobThread(jobInstance, job).start() # Do not instatiate thread, just run it
JobThread(job_instance, job).start() # Do not instatiate thread, just run it
except IndexError:
# Do nothing, there is no jobs for execution
return

View File

@ -156,11 +156,11 @@ class CryptoManager(metaclass=singleton.Singleton):
modes.CBC(b'udsinitvectoruds'),
backend=default_backend(),
)
rndStr = secrets.token_bytes(16) # Same as block size of CBC (that is 16 here)
paddedLength = ((len(text) + 4 + 15) // 16) * 16
toEncode = struct.pack('>i', len(text)) + text + rndStr[: paddedLength - len(text) - 4]
rnd_string = secrets.token_bytes(16) # Same as block size of CBC (that is 16 here)
padded_length = ((len(text) + 4 + 15) // 16) * 16 # calculate padding length, 4 is for length of text
to_encode = struct.pack('>i', len(text)) + text + rnd_string[: padded_length - len(text) - 4]
encryptor = cipher.encryptor()
encoded = encryptor.update(toEncode) + encryptor.finalize()
encoded = encryptor.update(to_encode) + encryptor.finalize()
if base64:
encoded = codecs.encode(encoded, 'base64').strip() # Return as bytes
@ -178,8 +178,8 @@ class CryptoManager(metaclass=singleton.Singleton):
)
decryptor = cipher.decryptor()
toDecode = decryptor.update(text) + decryptor.finalize()
return toDecode[4 : 4 + struct.unpack('>i', toDecode[:4])[0]]
to_decode = decryptor.update(text) + decryptor.finalize()
return to_decode[4 : 4 + struct.unpack('>i', to_decode[:4])[0]]
# Fast encription using django SECRET_KEY as key
def fast_crypt(self, data: bytes) -> bytes:
@ -212,28 +212,28 @@ class CryptoManager(metaclass=singleton.Singleton):
return self.aes_crypt(text, key)
def symmetric_decrypt(self, cryptText: typing.Union[str, bytes], key: typing.Union[str, bytes]) -> str:
if isinstance(cryptText, str):
cryptText = cryptText.encode()
def symmetric_decrypt(self, encrypted_text: typing.Union[str, bytes], key: typing.Union[str, bytes]) -> str:
if isinstance(encrypted_text, str):
encrypted_text = encrypted_text.encode()
if isinstance(key, str):
key = key.encode()
if not cryptText or not key:
if not encrypted_text or not key:
return ''
try:
return self.aes_decrypt(cryptText, key).decode('utf-8')
return self.aes_decrypt(encrypted_text, key).decode('utf-8')
except Exception: # Error decoding crypted element, return empty one
return ''
def load_private_key(
self, rsaKey: str
self, rsa_key: str
) -> typing.Union['RSAPrivateKey', 'DSAPrivateKey', 'DHPrivateKey', 'EllipticCurvePrivateKey']:
try:
return typing.cast(
typing.Union['RSAPrivateKey', 'DSAPrivateKey', 'DHPrivateKey', 'EllipticCurvePrivateKey'],
serialization.load_pem_private_key(rsaKey.encode(), password=None, backend=default_backend()),
serialization.load_pem_private_key(rsa_key.encode(), password=None, backend=default_backend()),
)
except Exception as e:
raise e
@ -271,32 +271,32 @@ class CryptoManager(metaclass=singleton.Singleton):
# Argon2
return '{ARGON2}' + PasswordHasher(type=ArgonType.ID).hash(value)
def check_hash(self, value: typing.Union[str, bytes], hashValue: str) -> bool:
def check_hash(self, value: typing.Union[str, bytes], hash_value: str) -> bool:
if isinstance(value, str):
value = value.encode()
if not value:
return not hashValue
return not hash_value
if hashValue[:8] == '{SHA256}':
return secrets.compare_digest(hashlib.sha3_256(value).hexdigest(), hashValue[8:])
if hashValue[:12] == '{SHA256SALT}':
if hash_value[:8] == '{SHA256}':
return secrets.compare_digest(hashlib.sha3_256(value).hexdigest(), hash_value[8:])
if hash_value[:12] == '{SHA256SALT}':
# Extract 16 chars salt and hash
salt = hashValue[12:28].encode()
salt = hash_value[12:28].encode()
value = salt + value
return secrets.compare_digest(hashlib.sha3_256(value).hexdigest(), hashValue[28:])
return secrets.compare_digest(hashlib.sha3_256(value).hexdigest(), hash_value[28:])
# Argon2
if hashValue[:8] == '{ARGON2}':
if hash_value[:8] == '{ARGON2}':
ph = PasswordHasher() # Type is implicit in hash
try:
ph.verify(hashValue[8:], value)
ph.verify(hash_value[8:], value)
return True
except Exception:
return False # Verify will raise an exception if not valid
# Old sha1
return secrets.compare_digest(
hashValue,
hash_value,
str(
hashlib.sha1(
value

View File

@ -51,7 +51,7 @@ class DownloadsManager(metaclass=singleton.Singleton):
For registering, use at __init__.py of the conecto something like this:
from uds.core.managers import DownloadsManager
import os.path, sys
downloadsManager().registerDownloadable('test.exe',
DownloadsManager.manager().registerDownloadable('test.exe',
_('comments for test'),
os.path.join(os.path.dirname(sys.modules[__package__].__file__), 'files/test.exe'),
'application/x-msdos-program')

View File

@ -63,26 +63,26 @@ class PublicationOldMachinesCleaner(DelayedTask):
This delayed task is for removing a pending "removable" publication
"""
def __init__(self, publicationId: int):
def __init__(self, publication_id: int):
super().__init__()
self._id = publicationId
self._id = publication_id
def run(self) -> None:
try:
servicePoolPub: ServicePoolPublication = ServicePoolPublication.objects.get(pk=self._id)
if servicePoolPub.state != State.REMOVABLE:
servicepool_publication: ServicePoolPublication = ServicePoolPublication.objects.get(pk=self._id)
if servicepool_publication.state != State.REMOVABLE:
logger.info('Already removed')
now = sql_now()
current_publication: typing.Optional[ServicePoolPublication] = (
servicePoolPub.deployed_service.active_publication()
servicepool_publication.deployed_service.active_publication()
)
if current_publication:
servicePoolPub.deployed_service.userServices.filter(in_use=True).exclude(
servicepool_publication.deployed_service.userServices.filter(in_use=True).exclude(
publication=current_publication
).update(in_use=False, state_date=now)
servicePoolPub.deployed_service.mark_old_userservices_as_removable(current_publication)
servicepool_publication.deployed_service.mark_old_userservices_as_removable(current_publication)
except Exception: # nosec: Removed publication, no problem at all, just continue
pass
@ -102,7 +102,9 @@ class PublicationLauncher(DelayedTask):
try:
now = sql_now()
with transaction.atomic():
servicepool_publication = ServicePoolPublication.objects.select_for_update().get(pk=self._publication_id)
servicepool_publication = ServicePoolPublication.objects.select_for_update().get(
pk=self._publication_id
)
if not servicepool_publication:
raise ServicePool.DoesNotExist()
if (
@ -113,13 +115,13 @@ class PublicationLauncher(DelayedTask):
servicepool_publication.save()
pi = servicepool_publication.get_instance()
state = pi.publish()
servicePool: ServicePool = servicepool_publication.deployed_service
servicePool.current_pub_revision += 1
servicePool.set_value(
servicepool: ServicePool = servicepool_publication.deployed_service
servicepool.current_pub_revision += 1
servicepool.set_value(
'toBeReplacedIn',
serialize(now + datetime.timedelta(hours=GlobalConfig.SESSION_EXPIRE_TIME.as_int(True))),
)
servicePool.save()
servicepool.save()
PublicationFinishChecker.state_updater(servicepool_publication, pi, state)
except (
ServicePoolPublication.DoesNotExist
@ -196,21 +198,21 @@ class PublicationFinishChecker(DelayedTask):
publication.update_data(publication_instance)
if check_later:
PublicationFinishChecker.check_later(publication, publication_instance)
PublicationFinishChecker.check_later(publication)
except Exception:
logger.exception('At checkAndUpdate for publication')
PublicationFinishChecker.check_later(publication, publication_instance)
logger.exception('At check_and_update for publication')
PublicationFinishChecker.check_later(publication)
@staticmethod
def check_later(publication: ServicePoolPublication, publicationInstance: 'services.Publication') -> None:
def check_later(publication: ServicePoolPublication) -> None:
"""
Inserts a task in the delayedTaskRunner so we can check the state of this publication
Inserts a task in the delayed_task_runner so we can check the state of this publication
@param dps: Database object for ServicePoolPublication
@param pi: Instance of Publication manager for the object
"""
DelayedTaskRunner.runner().insert(
PublicationFinishChecker(publication),
publicationInstance.suggested_delay,
publication.get_instance().suggested_delay,
PUBTAG + str(publication.id),
)
@ -221,13 +223,13 @@ class PublicationFinishChecker(DelayedTask):
if publication.state != self._state:
logger.debug('Task overrided by another task (state of item changed)')
else:
publicationInstance = publication.get_instance()
logger.debug("publication instance class: %s", publicationInstance.__class__)
publication_instance = publication.get_instance()
logger.debug("publication instance class: %s", publication_instance.__class__)
try:
state = publicationInstance.check_state()
state = publication_instance.check_state()
except Exception:
state = types.states.TaskState.ERROR
PublicationFinishChecker.state_updater(publication, publicationInstance, state)
PublicationFinishChecker.state_updater(publication, publication_instance, state)
except Exception as e:
logger.debug(
'Deployed service not found (erased from database) %s : %s',
@ -251,11 +253,13 @@ class PublicationManager(metaclass=singleton.Singleton):
"""
return PublicationManager() # Singleton pattern will return always the same instance
def publish(self, servicepool: ServicePool, changeLog: typing.Optional[str] = None) -> None:
def publish(self, servicepool: ServicePool, changelog: typing.Optional[str] = None) -> None:
"""
Initiates the publication of a service pool, or raises an exception if this cannot be done
:param servicePool: Service pool object (db object)
:param changeLog: if not None, store change log string on "change log" table
Args:
servicepool: Service pool to publish
changelog: Optional changelog to store
"""
if servicepool.publications.filter(state__in=State.PUBLISH_STATES).count() > 0:
raise PublishException(
@ -274,9 +278,9 @@ class PublicationManager(metaclass=singleton.Singleton):
publish_date=now,
revision=servicepool.current_pub_revision,
)
if changeLog:
if changelog:
servicepool.changelog.create(
revision=servicepool.current_pub_revision, log=changeLog, stamp=now
revision=servicepool.current_pub_revision, log=changelog, stamp=now
)
if publication:
DelayedTaskRunner.runner().insert(
@ -295,7 +299,10 @@ class PublicationManager(metaclass=singleton.Singleton):
"""
Invoked to cancel a publication.
Double invokation (i.e. invokation over a "cancelling" item) will lead to a "forced" cancellation (unclean)
:param servicePoolPub: Service pool publication (db object for a publication)
Args:
publication: Publication to cancel
"""
publication = ServicePoolPublication.objects.get(pk=publication.id) # Reloads publication from db
if publication.state not in State.PUBLISH_STATES:
@ -330,7 +337,10 @@ class PublicationManager(metaclass=singleton.Singleton):
def unpublish(self, servicepool_publication: ServicePoolPublication) -> None:
"""
Unpublishes an active (usable) or removable publication
:param servicePoolPub: Publication to unpublish
Args:
servicepool_publication: Publication to unpublish
"""
if (
State.from_str(servicepool_publication.state).is_usable() is False
@ -340,9 +350,9 @@ class PublicationManager(metaclass=singleton.Singleton):
if servicepool_publication.userServices.exclude(state__in=State.INFO_STATES).count() > 0:
raise PublishException(_('Can\'t unpublish publications with services in process'))
try:
pubInstance = servicepool_publication.get_instance()
state = pubInstance.destroy()
publication_instance = servicepool_publication.get_instance()
state = publication_instance.destroy()
servicepool_publication.set_state(State.REMOVING)
PublicationFinishChecker.state_updater(servicepool_publication, pubInstance, state)
PublicationFinishChecker.state_updater(servicepool_publication, publication_instance, state)
except Exception as e:
raise PublishException(str(e)) from e

View File

@ -337,10 +337,10 @@ class ServerManager(metaclass=singleton.Singleton):
Unassigns a server from an user
Args:
userService: User service to unassign server from
serverGroups: Server group to unassign server from
userservice: User service to unassign server from
server_group: Server group to unassign server from
unlock: If True, unlock server, even if it has more users assigned to it
userUuid: If not None, use this uuid instead of userService.user.uuid
user_uuid: If not None, use this uuid instead of userservice.user.uuid
"""
user_uuid = user_uuid if user_uuid else userservice.user.uuid if userservice.user else None
@ -350,37 +350,37 @@ class ServerManager(metaclass=singleton.Singleton):
prop_name = self.property_name(userservice.user)
with server_group.properties as props:
with transaction.atomic():
resetCounter = False
reset_counter = False
# ServerCounterType
serverCounter: typing.Optional[types.servers.ServerCounter] = (
server_counter: typing.Optional[types.servers.ServerCounter] = (
types.servers.ServerCounter.from_iterable(props.get(prop_name))
)
# If no cached value, get server assignation
if serverCounter is None:
if server_counter is None:
return types.servers.ServerCounter.null()
# Ensure counter is at least 1
serverCounter = types.servers.ServerCounter(
serverCounter.server_uuid, max(1, serverCounter.counter)
server_counter = types.servers.ServerCounter(
server_counter.server_uuid, max(1, server_counter.counter)
)
if serverCounter.counter == 1 or unlock:
if server_counter.counter == 1 or unlock:
# Last one, remove it
del props[prop_name]
else: # Not last one, just decrement counter
props[prop_name] = (serverCounter.server_uuid, serverCounter.counter - 1)
props[prop_name] = (server_counter.server_uuid, server_counter.counter - 1)
server = models.Server.objects.get(uuid=serverCounter.server_uuid)
server = models.Server.objects.get(uuid=server_counter.server_uuid)
if unlock or serverCounter.counter == 1:
if unlock or server_counter.counter == 1:
server.locked_until = None # Ensure server is unlocked if no more users are assigned to it
server.save(update_fields=['locked_until'])
# Enure server counter is cleaned also, because server is considered "fully released"
resetCounter = True
reset_counter = True
# If unmanaged, decrease usage
if server.type == types.servers.ServerType.UNMANAGED:
self.decrement_unmanaged_usage(server.uuid, force_reset=resetCounter)
self.decrement_unmanaged_usage(server.uuid, force_reset=reset_counter)
# Ensure next assignation will have updated stats
# This is a simple simulation on cached stats, will be updated on next stats retrieval
@ -389,7 +389,7 @@ class ServerManager(metaclass=singleton.Singleton):
self.notify_release(server, userservice)
return types.servers.ServerCounter(serverCounter.server_uuid, serverCounter.counter - 1)
return types.servers.ServerCounter(server_counter.server_uuid, server_counter.counter - 1)
def notify_preconnect(
self,
@ -424,12 +424,12 @@ class ServerManager(metaclass=singleton.Singleton):
def notify_release(
self,
server: 'models.Server',
userService: 'models.UserService',
userservice: 'models.UserService',
) -> None:
"""
Notifies release to server
"""
requester.ServerApiRequester(server).notify_release(userService)
requester.ServerApiRequester(server).notify_release(userservice)
def assignation_info(self, server_group: 'models.ServerGroup') -> dict[str, int]:
"""

View File

@ -172,7 +172,8 @@ def ensure_connected(
# def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: ...
# def cache_clear(self) -> None: ...
# def cache_info(self) -> CacheInfo: ...
# Now, we could use this by creating two decorators, one for the class methods and one for the functions
# But the inheritance problem will still be there, so we will keep the current implementation
# Decorator for caching
# This decorator will cache the result of the function for a given time, and given parameters

View File

@ -169,13 +169,13 @@ def connection(
# Disable TLS1 and TLS1.1
# 0x304 = TLS1.3, 0x303 = TLS1.2, 0x302 = TLS1.1, 0x301 = TLS1.0, but use ldap module constants
# Ensure that libldap is compiled with TLS1.3 support
minVersion = getattr(settings, 'SECURE_MIN_TLS_VERSION', '1.2')
min_tls_version = getattr(settings, 'SECURE_MIN_TLS_VERSION', '1.2')
if hasattr(ldap, 'OPT_X_TLS_PROTOCOL_TLS1_3'):
tls_version: typing.Any = { # for pyright to ignore
'1.2': ldap.OPT_X_TLS_PROTOCOL_TLS1_2, # pyright: ignore
'1.3': ldap.OPT_X_TLS_PROTOCOL_TLS1_3, # pyright: ignore
}.get(
minVersion, ldap.OPT_X_TLS_PROTOCOL_TLS1_2 # pyright: ignore
min_tls_version, ldap.OPT_X_TLS_PROTOCOL_TLS1_2 # pyright: ignore
)
l.set_option(ldap.OPT_X_TLS_PROTOCOL_MIN, tls_version) # pyright: ignore
@ -271,26 +271,35 @@ def as_dict(
def first(
con: 'LDAPObject',
base: str,
objectClass: str,
object_class: str,
field: str,
value: str,
attributes: typing.Optional[collections.abc.Iterable[str]] = None,
sizeLimit: int = 50,
max_entries: int = 50,
) -> typing.Optional[LDAPResultType]:
"""
Searchs for the username and returns its LDAP entry
@param username: username to search, using user provided parameters at configuration to map search entries.
@param objectClass: Objectclass of the user mane username to search.
@return: None if username is not found, an dictionary of LDAP entry attributes if found (all in unicode on py2, str on py3).
Args:
con (LDAPObject): Connection to LDAP
base (str): Base to search
object_class (str): Object class to search
field (str): Field to search
value (str): Value to search
attributes (typing.Optional[collections.abc.Iterable[str]], optional): Attributes to return. Defaults to None.
max_entries (int, optional): Max entries to return. Defaults to 50.
Returns:
typing.Optional[LDAPResultType]: Result of the search
"""
value = ldap.filter.escape_filter_chars(value) # pyright: ignore reportGeneralTypeIssues
attrList = [field] + list(attributes) if attributes else []
attributes = [field] + list(attributes) if attributes else []
ldapFilter = f'(&(objectClass={objectClass})({field}={value}))'
ldap_filter = f'(&(objectClass={object_class})({field}={value}))'
try:
obj = next(as_dict(con, base, ldapFilter, attrList, sizeLimit))
obj = next(as_dict(con, base, ldap_filter, attributes, max_entries))
except StopIteration:
return None # None found

View File

@ -116,8 +116,8 @@ class RadiusOTP(mfas.MFA):
return client.RadiusClient(
self.server.value,
self.secret.value.encode(),
authPort=self.port.as_int(),
nasIdentifier=self.nas_identifier.value,
auth_port=self.port.as_int(),
nas_identifier=self.nas_identifier.value,
)
def check_result(self, action: str, request: 'ExtendedHttpRequest') -> mfas.MFA.RESULT: