1
0
mirror of https://github.com/dkmstr/openuds.git synced 2025-02-02 09:47:13 +03:00

Updated SAML

This commit is contained in:
Adolfo Gómez García 2024-01-24 04:15:53 +01:00
parent 3afeb4869c
commit 81abe1d99f
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
5 changed files with 189 additions and 160 deletions

View File

@ -380,14 +380,14 @@ class OAuth2Authenticator(auths.Authenticator):
# After this point, we don't mind about the token, we only need to authenticate user
# and get some basic info from it
username = ''.join(auth_utils.processRegexField(self.userNameAttr.value, userInfo)).replace(' ', '_')
username = ''.join(auth_utils.process_regex_field(self.userNameAttr.value, userInfo)).replace(' ', '_')
if len(username) == 0:
raise Exception('No username received')
realName = ''.join(auth_utils.processRegexField(self.realNameAttr.value, userInfo))
realName = ''.join(auth_utils.process_regex_field(self.realNameAttr.value, userInfo))
# Get groups
groups = auth_utils.processRegexField(self.groupNameAttr.value, userInfo)
groups = auth_utils.process_regex_field(self.groupNameAttr.value, userInfo)
# Append common groups
groups.extend(self.commonGroups.value.split(','))
@ -445,8 +445,8 @@ class OAuth2Authenticator(auths.Authenticator):
gettext('This kind of Authenticator does not support white spaces on field NAME')
)
auth_utils.validateRegexField(self.userNameAttr)
auth_utils.validateRegexField(self.userNameAttr)
auth_utils.validate_regex_field(self.userNameAttr)
auth_utils.validate_regex_field(self.userNameAttr)
if self.responseType.value in ('code', 'pkce', 'openid+code'):
if self.commonGroups.value.strip() == '':

View File

@ -229,9 +229,9 @@ class RegexLdap(auths.Authenticator):
def initialize(self, values: typing.Optional[dict[str, typing.Any]]) -> None:
if values:
auth_utils.validateRegexField(self.userNameAttr, values['userNameAttr'])
auth_utils.validateRegexField(self.userIdAttr, values['userIdAttr'])
auth_utils.validateRegexField(self.groupNameAttr, values['groupNameAttr'])
auth_utils.validate_regex_field(self.userNameAttr, values['userNameAttr'])
auth_utils.validate_regex_field(self.userIdAttr, values['userIdAttr'])
auth_utils.validate_regex_field(self.groupNameAttr, values['groupNameAttr'])
self._host = values['host']
self._port = values['port']
@ -464,13 +464,13 @@ class RegexLdap(auths.Authenticator):
return user
def __getGroups(self, user: ldaputil.LDAPResultType):
grps = auth_utils.processRegexField(self._groupNameAttr, user)
grps = auth_utils.process_regex_field(self._groupNameAttr, user)
if extra:
grps += extra.getGroups(self, user)
return grps
def __getUserRealName(self, user: ldaputil.LDAPResultType):
return ' '.join(auth_utils.processRegexField(self._userNameAttr, user))
return ' '.join(auth_utils.process_regex_field(self._userNameAttr, user))
def authenticate(
self,

View File

@ -36,7 +36,7 @@ import re
import typing
import collections.abc
import xml.sax # nosec: used to parse trusted xml provided only by administrators
from urllib.parse import urlparse
from urllib import parse
import requests
from django.utils.translation import gettext
@ -63,7 +63,7 @@ logger = logging.getLogger(__name__)
def CACHING_KEY_FNC(auth: 'SAMLAuthenticator') -> str:
return str(hash(auth.idpMetadata.value))
return auth.entity_id.as_str()
class SAMLAuthenticator(auths.Authenticator):
@ -114,7 +114,7 @@ class SAMLAuthenticator(auths.Authenticator):
# : We will define a simple form where we will use a simple
# : list editor to allow entering a few group names
privateKey = gui.TextField(
private_key = gui.TextField(
length=4096,
lines=10,
label=_('Private key'),
@ -122,19 +122,19 @@ class SAMLAuthenticator(auths.Authenticator):
tooltip=_('Private key used for sign and encription, as generated in base 64 from openssl'),
required=True,
tab=_('Certificates'),
old_field_name='privateKey',
)
serverCertificate = gui.TextField(
server_certificate = gui.TextField(
length=4096,
lines=10,
label=_('Certificate'),
order=2,
tooltip=_(
'Public key used for sign and encription (public part of previous private key), as generated in base 64 from openssl'
),
tooltip=_('Server certificate used in SAML, as generated in base 64 from openssl'),
required=True,
tab=_('Certificates'),
old_field_name='serverCertificate',
)
idpMetadata = gui.TextField(
idp_metadata = gui.TextField(
length=8192,
lines=4,
label=_('IDP Metadata'),
@ -142,16 +142,18 @@ class SAMLAuthenticator(auths.Authenticator):
tooltip=_('You can enter here the URL or the IDP metadata or the metadata itself (xml)'),
required=True,
tab=_('Metadata'),
old_field_name='idpMetadata',
)
entityID = gui.TextField(
entity_id = gui.TextField(
length=256,
label=_('Entity ID'),
order=4,
tooltip=_('ID of the SP. If left blank, this will be autogenerated from server URL'),
tab=_('Metadata'),
old_field_name='entityID',
)
userNameAttr = gui.TextField(
attrs_username = gui.TextField(
length=2048,
lines=2,
label=_('User name attrs'),
@ -159,9 +161,10 @@ class SAMLAuthenticator(auths.Authenticator):
tooltip=_('Fields from where to extract user name'),
required=True,
tab=_('Attributes'),
old_field_name='userNameAttr',
)
groupNameAttr = gui.TextField(
attrs_groupname = gui.TextField(
length=2048,
lines=2,
label=_('Group name attrs'),
@ -169,9 +172,10 @@ class SAMLAuthenticator(auths.Authenticator):
tooltip=_('Fields from where to extract the groups'),
required=True,
tab=_('Attributes'),
old_field_name='groupNameAttr',
)
realNameAttr = gui.TextField(
attrs_realname = gui.TextField(
length=2048,
lines=2,
label=_('Real name attrs'),
@ -179,24 +183,27 @@ class SAMLAuthenticator(auths.Authenticator):
tooltip=_('Fields from where to extract the real name'),
required=True,
tab=_('Attributes'),
old_field_name='realNameAttr',
)
globalLogout = gui.CheckBoxField(
use_global_logout = gui.CheckBoxField(
label=_('Global logout'),
default=False,
order=10,
tooltip=_('If set, logout from UDS will trigger SAML logout'),
tab=types.ui.Tab.ADVANCED,
old_field_name='globalLogout',
)
adFS = gui.CheckBoxField(
adfs = gui.CheckBoxField(
label=_('ADFS compatibility'),
default=False,
order=11,
tooltip=_('If set, enable lowercase url encoding so ADFS can work correctly'),
tab=types.ui.Tab.ADVANCED,
old_field_name='adFS',
)
mfaAttr = gui.TextField(
mfa_attr = gui.TextField(
length=2048,
lines=2,
label=_('MFA attribute'),
@ -204,120 +211,135 @@ class SAMLAuthenticator(auths.Authenticator):
tooltip=_('Attribute from where to extract the MFA code'),
required=False,
tab=types.ui.Tab.ADVANCED,
old_field_name='mfaAttr',
)
checkSSLCertificate = gui.CheckBoxField(
check_https_certificate = gui.CheckBoxField(
label=_('Check SSL certificate'),
default=False, # For compatibility with previous versions
order=23,
tooltip=_('If set, check SSL certificate on requests for IDP Metadata'),
tab=_('Security'),
old_field_name='checkSSLCertificate',
)
nameIdEncrypted = gui.CheckBoxField(
use_name_id_encrypted = gui.CheckBoxField(
label=_('Encripted nameID'),
default=False,
order=12,
tooltip=_('If set, nameID will be encripted'),
tab=_('Security'),
old_field_name='nameIdEncrypted',
)
authnRequestsSigned = gui.CheckBoxField(
use_authn_requests_signed = gui.CheckBoxField(
label=_('Authn requests signed'),
default=False,
order=13,
tooltip=_('If set, authn requests will be signed'),
tab=_('Security'),
old_field_name='authnRequestsSigned',
)
logoutRequestSigned = gui.CheckBoxField(
logout_request_signed = gui.CheckBoxField(
label=_('Logout requests signed'),
default=False,
order=14,
tooltip=_('If set, logout requests will be signed'),
tab=_('Security'),
old_field_name='logoutRequestSigned',
)
logoutResponseSigned = gui.CheckBoxField(
use_signed_logout_response = gui.CheckBoxField(
label=_('Logout responses signed'),
default=False,
order=15,
tooltip=_('If set, logout responses will be signed'),
tab=_('Security'),
old_field_name='logoutResponseSigned',
)
signMetadata = gui.CheckBoxField(
use_signed_metadata = gui.CheckBoxField(
label=_('Sign metadata'),
default=False,
order=16,
tooltip=_('If set, metadata will be signed'),
tab=_('Security'),
old_field_name='signMetadata',
)
wantMessagesSigned = gui.CheckBoxField(
want_messages_signed = gui.CheckBoxField(
label=_('Want messages signed'),
default=False,
order=17,
tooltip=_('If set, messages will be signed'),
tab=_('Security'),
old_field_name='wantMessagesSigned',
)
wantAssertionsSigned = gui.CheckBoxField(
want_assertions_signed = gui.CheckBoxField(
label=_('Want assertions signed'),
default=False,
order=18,
tooltip=_('If set, assertions will be signed'),
tab=_('Security'),
old_field_name='wantAssertionsSigned',
)
wantAssertionsEncrypted = gui.CheckBoxField(
want_assertions_encrypted = gui.CheckBoxField(
label=_('Want assertions encrypted'),
default=False,
order=19,
tooltip=_('If set, assertions will be encrypted'),
tab=_('Security'),
old_field_name='wantAssertionsEncrypted',
)
wantNameIdEncrypted = gui.CheckBoxField(
want_name_id_encrypted = gui.CheckBoxField(
label=_('Want nameID encrypted'),
default=False,
order=20,
tooltip=_('If set, nameID will be encrypted'),
tab=_('Security'),
old_field_name='wantNameIdEncrypted',
)
requestedAuthnContext = gui.CheckBoxField(
use_requested_authn_context = gui.CheckBoxField(
label=_('Requested authn context'),
default=False,
order=21,
tooltip=_('If set, requested authn context will be sent'),
tab=_('Security'),
old_field_name='requestedAuthnContext',
)
allowDeprecatedSignatureAlgorithms = gui.CheckBoxField(
allow_deprecated_signature_algorithms = gui.CheckBoxField(
label=_('Allow deprecated signature algorithms'),
default=True,
order=23,
tooltip=_('If set, deprecated signature algorithms will be allowed (as SHA1, MD5, etc...)'),
tab=_('Security'),
old_field_name='allowDeprecatedSignatureAlgorithms',
)
metadataCacheDuration = gui.NumericField(
metadata_cache_duration = gui.NumericField(
label=_('Metadata cache duration'),
default=0,
order=22,
tooltip=_('Duration of metadata cache in days. 0 means default (ten years)'),
tab=_('Metadata'),
old_field_name='metadataCacheDuration',
)
metadataValidityDuration = gui.NumericField(
metadata_validity_duration = gui.NumericField(
label=_('Metadata validity duration'),
default=0,
order=22,
tooltip=_('Duration of metadata validity in days. 0 means default (ten years)'),
tab=_('Metadata'),
old_field_name='metadataValidityDuration',
)
organization_name = gui.TextField(
length=256,
label=_('Organization Name'),
@ -326,7 +348,7 @@ class SAMLAuthenticator(auths.Authenticator):
tooltip=_('Organization name to use on SAML SP Metadata'),
tab=_('Organization'),
)
organization_display_name = gui.TextField(
length=256,
label=_('Organization Display Name'),
@ -335,7 +357,7 @@ class SAMLAuthenticator(auths.Authenticator):
tooltip=_('Organization Display name to use on SAML SP Metadata'),
tab=_('Organization'),
)
organization_url = gui.TextField(
length=256,
label=_('Organization URL'),
@ -344,8 +366,8 @@ class SAMLAuthenticator(auths.Authenticator):
tooltip=_('Organization url to use on SAML SP Metadata'),
tab=_('Organization'),
)
manageUrl = gui.HiddenField(serializable=True)
manage_url = gui.HiddenField(serializable=True, old_field_name='manageUrl')
def initialize(self, values: typing.Optional[dict[str, typing.Any]]) -> None:
"""
@ -364,11 +386,13 @@ class SAMLAuthenticator(auths.Authenticator):
gettext('This kind of Authenticator does not support white spaces on field NAME')
)
# First, validate certificates
# Ensure ipdMetadata cache is empty, to regenerate it
self.cache.remove('idpMetadata')
# First, validate certificates
# This is in fact not needed, but we may say something useful to user if we check this
if self.serverCertificate.value.startswith('-----BEGIN CERTIFICATE-----\n') is False:
if self.server_certificate.value.startswith('-----BEGIN CERTIFICATE-----\n') is False:
raise exceptions.validation.ValidationError(
gettext(
'Server certificate should be a valid PEM (PEM certificates starts with -----BEGIN CERTIFICATE-----)'
@ -376,13 +400,13 @@ class SAMLAuthenticator(auths.Authenticator):
)
try:
CryptoManager().load_certificate(self.serverCertificate.value)
CryptoManager().load_certificate(self.server_certificate.value)
except Exception as e:
raise exceptions.validation.ValidationError(gettext('Invalid server certificate. ') + str(e))
if (
self.privateKey.value.startswith('-----BEGIN RSA PRIVATE KEY-----\n') is False
and self.privateKey.value.startswith('-----BEGIN PRIVATE KEY-----\n') is False
self.private_key.value.startswith('-----BEGIN RSA PRIVATE KEY-----\n') is False
and self.private_key.value.startswith('-----BEGIN PRIVATE KEY-----\n') is False
):
raise exceptions.validation.ValidationError(
gettext(
@ -391,64 +415,66 @@ class SAMLAuthenticator(auths.Authenticator):
)
try:
CryptoManager().load_private_key(self.privateKey.value)
CryptoManager().load_private_key(self.private_key.value)
except Exception as e:
raise exceptions.validation.ValidationError(gettext('Invalid private key. ') + str(e))
if not security.check_certificate_matches_private_key(
cert=self.serverCertificate.value, key=self.privateKey.value
cert=self.server_certificate.value, key=self.private_key.value
):
raise exceptions.validation.ValidationError(gettext('Certificate and private key do not match'))
request: 'ExtendedHttpRequest' = values['_request']
if self.entityID.value == '':
self.entityID.value = request.build_absolute_uri(self.info_url())
if self.entity_id.value == '':
self.entity_id.value = request.build_absolute_uri(self.info_url())
self.manageUrl.value = request.build_absolute_uri(self.callback_url())
self.manage_url.value = request.build_absolute_uri(self.callback_url())
idpMetadata: str = self.idpMetadata.value
fromUrl: bool = False
if idpMetadata.startswith('http://') or idpMetadata.startswith('https://'):
logger.debug('idp Metadata is an URL: %s', idpMetadata)
idp_metadata: str = self.idp_metadata.value
from_url: bool = False
if idp_metadata.startswith('http://') or idp_metadata.startswith('https://'):
logger.debug('idp Metadata is an URL: %s', idp_metadata)
try:
resp = requests.get(
idpMetadata.split('\n')[0],
verify=self.checkSSLCertificate.as_bool(),
idp_metadata.split('\n')[0],
verify=self.check_https_certificate.as_bool(),
timeout=10,
)
idpMetadata = resp.content.decode()
idp_metadata = resp.content.decode()
except Exception as e:
raise exceptions.validation.ValidationError(
gettext('Can\'t fetch url {0}: {1}').format(self.idpMetadata.value, str(e))
gettext('Can\'t fetch url {0}: {1}').format(self.idp_metadata.value, str(e))
)
fromUrl = True
from_url = True
# Try to parse it so we can check it is valid. Right now, it checks just that this is XML, will
# correct it to check that is is valid idp metadata
try:
xml.sax.parseString(idpMetadata, xml.sax.ContentHandler()) # type: ignore # nosec: url provided by admin
xml.sax.parseString(idp_metadata, xml.sax.ContentHandler()) # type: ignore # nosec: url provided by admin
except Exception as e:
msg = (gettext(' (obtained from URL)') if fromUrl else '') + str(e)
raise exceptions.validation.ValidationError(gettext('XML does not seem valid for IDP Metadata ') + msg)
msg = (gettext(' (obtained from URL)') if from_url else '') + str(e)
raise exceptions.validation.ValidationError(
gettext('XML does not seem valid for IDP Metadata ') + msg
)
# Now validate regular expressions, if they exists
auth_utils.validateRegexField(self.userNameAttr)
auth_utils.validateRegexField(self.groupNameAttr)
auth_utils.validateRegexField(self.realNameAttr)
auth_utils.validate_regex_field(self.attrs_username)
auth_utils.validate_regex_field(self.attrs_groupname)
auth_utils.validate_regex_field(self.attrs_realname)
def getReqFromRequest(
def build_req_from_request(
self,
request: 'ExtendedHttpRequest',
params: typing.Optional['types.auth.AuthCallbackParams'] = None,
) -> dict[str, typing.Any]:
manageUrlObj = urlparse(self.manageUrl.value)
script_path = manageUrlObj.path
host = manageUrlObj.netloc
manage_url_obj: 'parse.ParseResult' = parse.urlparse(self.manage_url.value)
script_path: str = manage_url_obj.path
host: str = manage_url_obj.netloc
if ':' in host:
host, port = host.split(':')
else:
if manageUrlObj.scheme == 'http':
if manage_url_obj.scheme == 'http':
port = '80'
else:
port = '443'
@ -463,7 +489,7 @@ class SAMLAuthenticator(auths.Authenticator):
'server_port': port, # params['server_port'],
'get_data': params.get_params.copy(),
'post_data': params.post_params.copy(),
'lowercase_urlencoding': self.adFS.as_bool(),
'lowercase_urlencoding': self.adfs.as_bool(),
'query_string': params.query_string,
}
# No callback parameters, we use the request
@ -474,72 +500,74 @@ class SAMLAuthenticator(auths.Authenticator):
'server_port': port, # request.META['SERVER_PORT'],
'get_data': request.GET.copy(),
'post_data': request.POST.copy(),
'lowercase_urlencoding': self.adFS.as_bool(),
'lowercase_urlencoding': self.adfs.as_bool(),
'query_string': request.META['QUERY_STRING'],
}
@decorators.cached(
prefix='idpm',
key_fnc=CACHING_KEY_FNC,
timeout=3600 * 24 * 365, # 1 year
)
def getIdpMetadataDict(self) -> dict[str, typing.Any]:
if self.idpMetadata.value.startswith('http'):
def get_idp_metadata_dict(self) -> dict[str, typing.Any]:
if self.idp_metadata.value.startswith('http'):
resp = self.cache.get('idpMetadata')
if resp:
return resp
try:
resp = requests.get(
self.idpMetadata.value.split('\n')[0],
verify=self.checkSSLCertificate.as_bool(),
self.idp_metadata.value.split('\n')[0],
verify=self.check_https_certificate.as_bool(),
timeout=10,
)
val = resp.content.decode()
# 10 years, unless edited the metadata will be kept
self.cache.put('idpMetadata', val, 86400 * 365 * 10)
except Exception as e:
logger.error('Error fetching idp metadata: %s', e)
raise exceptions.auth.AuthenticatorException(gettext('Can\'t access idp metadata'))
else:
val = self.idpMetadata.value
val = self.idp_metadata.value
return OneLogin_Saml2_IdPMetadataParser.parse(val)
def oneLoginSettings(self) -> dict[str, typing.Any]:
def build_onelogin_settings(self) -> dict[str, typing.Any]:
return {
'strict': True,
'debug': True,
'sp': {
'entityId': self.entityID.value,
'entityId': self.entity_id.value,
'assertionConsumerService': {
'url': self.manageUrl.value,
'url': self.manage_url.value,
'binding': 'urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST',
},
'singleLogoutService': {
'url': self.manageUrl.value + '?logout=true',
'url': self.manage_url.value + '?logout=true',
'binding': 'urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect',
},
'x509cert': self.serverCertificate.value,
'privateKey': self.privateKey.value,
'x509cert': self.server_certificate.value,
'privateKey': self.private_key.value,
'NameIDFormat': 'urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified',
},
'idp': self.getIdpMetadataDict()['idp'],
'idp': self.get_idp_metadata_dict()['idp'],
'security': {
'metadataCacheDuration': self.metadataCacheDuration.int_value
if self.metadataCacheDuration.int_value > 0
# in days, converted to seconds, this is a duration
'metadataCacheDuration': self.metadata_cache_duration.as_int() * 86400
if self.metadata_cache_duration.int_value > 0
else 86400 * 365 * 10,
# This is a date of end of validity
'metadataValidUntil': sql_datetime()
+ datetime.timedelta(seconds=self.metadataValidityDuration.int_value)
if self.metadataCacheDuration.int_value > 0
+ datetime.timedelta(days=self.metadata_validity_duration.as_int())
if self.metadata_cache_duration.int_value > 0
else sql_datetime() + datetime.timedelta(days=365 * 10),
'nameIdEncrypted': self.nameIdEncrypted.as_bool(),
'authnRequestsSigned': self.authnRequestsSigned.as_bool(),
'logoutRequestSigned': self.logoutRequestSigned.as_bool(),
'logoutResponseSigned': self.logoutResponseSigned.as_bool(),
'signMetadata': self.signMetadata.as_bool(),
'wantMessagesSigned': self.wantMessagesSigned.as_bool(),
'wantAssertionsSigned': self.wantAssertionsSigned.as_bool(),
'wantAssertionsEncrypted': self.wantAssertionsEncrypted.as_bool(),
'wantNameIdEncrypted': self.wantNameIdEncrypted.as_bool(),
'requestedAuthnContext': self.requestedAuthnContext.as_bool(),
'nameIdEncrypted': self.use_name_id_encrypted.as_bool(),
'authnRequestsSigned': self.use_authn_requests_signed.as_bool(),
'logoutRequestSigned': self.logout_request_signed.as_bool(),
'logoutResponseSigned': self.use_signed_logout_response.as_bool(),
'signMetadata': self.use_signed_metadata.as_bool(),
'wantMessagesSigned': self.want_messages_signed.as_bool(),
'wantAssertionsSigned': self.want_assertions_signed.as_bool(),
'wantAssertionsEncrypted': self.want_assertions_encrypted.as_bool(),
'wantNameIdEncrypted': self.want_name_id_encrypted.as_bool(),
'requestedAuthnContext': self.use_requested_authn_context.as_bool(),
"signatureAlgorithm": "http://www.w3.org/2001/04/xmldsig-more#rsa-sha256",
"digestAlgorithm": "http://www.w3.org/2001/04/xmlenc#sha256",
"rejectDeprecatedAlgorithm": not self.allowDeprecatedSignatureAlgorithms.as_bool(),
"rejectDeprecatedAlgorithm": not self.allow_deprecated_signature_algorithms.as_bool(),
},
'organization': {
'en-US': {
@ -555,8 +583,8 @@ class SAMLAuthenticator(auths.Authenticator):
key_fnc=CACHING_KEY_FNC,
timeout=3600, # 1 hour
)
def getSpMetadata(self) -> str:
saml_settings = OneLogin_Saml2_Settings(settings=self.oneLoginSettings())
def get_sp_metadata(self) -> str:
saml_settings = OneLogin_Saml2_Settings(settings=self.build_onelogin_settings())
metadata = saml_settings.get_sp_metadata()
errors = saml_settings.validate_metadata(metadata)
if len(errors) > 0:
@ -567,7 +595,6 @@ class SAMLAuthenticator(auths.Authenticator):
return metadata
return typing.cast(bytes, metadata).decode()
def get_info(
self, parameters: collections.abc.Mapping[str, str]
) -> typing.Optional[tuple[str, typing.Optional[str]]]:
@ -575,7 +602,7 @@ class SAMLAuthenticator(auths.Authenticator):
Althought this is mainly a get info callback, this can be used for any other purpuse we like.
In this case, we use it to provide logout callback also
"""
info = self.getSpMetadata()
info = self.get_sp_metadata()
wantsHtml = parameters.get('format') == 'html'
content_type = 'text/html' if wantsHtml else 'application/samlmetadata+xml'
@ -584,16 +611,16 @@ class SAMLAuthenticator(auths.Authenticator):
)
return info, content_type # 'application/samlmetadata+xml')
def mfaStorageKey(self, username: str) -> str:
def mfa_storage_key(self, username: str) -> str:
return 'mfa_' + self.db_obj().uuid + username # type: ignore
def mfaClean(self, username: str):
self.storage.remove(self.mfaStorageKey(username))
def mfa_clean(self, username: str):
self.storage.remove(self.mfa_storage_key(username))
def mfa_identifier(self, username: str) -> str:
return self.storage.get_unpickle(self.mfaStorageKey(username)) or ''
return self.storage.get_unpickle(self.mfa_storage_key(username)) or ''
def logoutFromCallback(
def logout_callback(
self,
req: dict[str, typing.Any],
request: 'ExtendedHttpRequestWithUser',
@ -605,15 +632,15 @@ class SAMLAuthenticator(auths.Authenticator):
else:
req['get_data']['SAMLResponse'] = req['post_data']['SAMLResponse']
logoutRequestId = request.session.get('samlLogoutRequestId', None)
logout_req_id = request.session.get('samlLogoutRequestId', None)
# Cleanup session & session cookie
request.session.flush()
settings = OneLogin_Saml2_Settings(settings=self.oneLoginSettings())
settings = OneLogin_Saml2_Settings(settings=self.build_onelogin_settings())
auth = OneLogin_Saml2_Auth(req, settings)
url = auth.process_slo(request_id=logoutRequestId)
url = auth.process_slo(request_id=logout_req_id)
errors = auth.get_errors()
@ -625,7 +652,7 @@ class SAMLAuthenticator(auths.Authenticator):
# Remove MFA related data
if request.user:
self.mfaClean(request.user.name)
self.mfa_clean(request.user.name)
return types.auth.AuthenticationResult(
success=types.auth.AuthenticationState.REDIRECT,
@ -639,13 +666,13 @@ class SAMLAuthenticator(auths.Authenticator):
gm: 'auths.GroupsManager',
request: 'ExtendedHttpRequestWithUser',
) -> types.auth.AuthenticationResult:
req = self.getReqFromRequest(request, params=parameters)
req = self.build_req_from_request(request, params=parameters)
if 'logout' in parameters.get_params:
return self.logoutFromCallback(req, request)
return self.logout_callback(req, request)
try:
settings = OneLogin_Saml2_Settings(settings=self.oneLoginSettings())
settings = OneLogin_Saml2_Settings(settings=self.build_onelogin_settings())
auth = OneLogin_Saml2_Auth(req, settings)
auth.process_response()
except Exception as e:
@ -688,41 +715,43 @@ class SAMLAuthenticator(auths.Authenticator):
# Now that we have attributes, we can extract values from this, map groups, etc...
username = ''.join(
auth_utils.processRegexField(self.userNameAttr.value, attributes)
auth_utils.process_regex_field(self.attrs_username.value, attributes)
) # in case of multiple values is returned, join them
logger.debug('Username: %s', username)
groups = auth_utils.processRegexField(self.groupNameAttr.value, attributes)
groups = auth_utils.process_regex_field(self.attrs_groupname.value, attributes)
logger.debug('Groups: %s', groups)
realName = ' '.join(auth_utils.processRegexField(self.realNameAttr.value, attributes))
realName = ' '.join(auth_utils.process_regex_field(self.attrs_realname.value, attributes))
logger.debug('Real name: %s', realName)
# store groups for this username at storage, so we can check it at a later stage
self.storage.put_pickle(username, [realName, groups])
# store also the mfa identifier field value, in case we have provided it
if self.mfaAttr.value.strip():
if self.mfa_attr.value.strip():
self.storage.put_pickle(
self.mfaStorageKey(username),
''.join(auth_utils.processRegexField(self.mfaAttr.value, attributes)),
self.mfa_storage_key(username),
''.join(auth_utils.process_regex_field(self.mfa_attr.value, attributes)),
) # in case multipel values is returned, join them
else:
self.storage.remove(self.mfaStorageKey(username))
self.storage.remove(self.mfa_storage_key(username))
# Now we check validity of user
gm.validate(groups)
return types.auth.AuthenticationResult(success=types.auth.AuthenticationState.SUCCESS, username=username)
return types.auth.AuthenticationResult(
success=types.auth.AuthenticationState.SUCCESS, username=username
)
def logout(self, request: 'ExtendedHttpRequest', username: str) -> types.auth.AuthenticationResult:
if not self.globalLogout.as_bool():
if not self.use_global_logout.as_bool():
return types.auth.SUCCESS_AUTH
req = self.getReqFromRequest(request)
req = self.build_req_from_request(request)
settings = OneLogin_Saml2_Settings(settings=self.oneLoginSettings())
settings = OneLogin_Saml2_Settings(settings=self.build_onelogin_settings())
auth = OneLogin_Saml2_Auth(req, settings)
@ -732,7 +761,7 @@ class SAMLAuthenticator(auths.Authenticator):
request.session.clear()
# Remove MFA related data
self.mfaClean(username)
self.mfa_clean(username)
if not saml:
return types.auth.SUCCESS_AUTH
@ -764,8 +793,8 @@ class SAMLAuthenticator(auths.Authenticator):
"""
We will here compose the saml request and send it via http-redirect
"""
req = self.getReqFromRequest(request)
auth = OneLogin_Saml2_Auth(req, self.oneLoginSettings())
req = self.build_req_from_request(request)
auth = OneLogin_Saml2_Auth(req, self.build_onelogin_settings())
return f'window.location="{auth.login()}";'

View File

@ -1488,13 +1488,13 @@ class UserInterface(metaclass=UserInterfaceAbstract):
# Any unexpected type will raise an exception
# Note that currently, we will store old field name on db
# to allow "backwards" migration if needed, but will be removed on a future version
arr = [
fields = [
(field.old_field_name() or field_name, field.type.name, FIELDS_ENCODERS[field.type](field))
for field_name, field in self._gui.items()
if FIELDS_ENCODERS[field.type](field) is not None
]
return SERIALIZATION_HEADER + SERIALIZATION_VERSION + serializer.serialize(arr)
return SERIALIZATION_HEADER + SERIALIZATION_VERSION + serializer.serialize(fields)
def deserialize_fields(
self,
@ -1528,36 +1528,36 @@ class UserInterface(metaclass=UserInterfaceAbstract):
logger.info('Empty values on unserialize_fields')
return
arr = serializer.deserialize(values) or []
fields = serializer.deserialize(values) or []
# Dict of translations from old_field_name to field_name
field_names_translations: dict[str, str] = self._get_fieldname_translations()
# Set all values to defaults ones
for fld_name in self._gui:
fld = self._gui[fld_name]
for field_name in self._gui:
field = self._gui[field_name]
if (
self._gui[fld_name].is_type(types.ui.FieldType.HIDDEN)
and self._gui[fld_name].is_serializable() is False
field.is_type(types.ui.FieldType.HIDDEN)
and field.is_serializable() is False
):
# logger.debug('Field {0} is not unserializable'.format(k))
continue
self._gui[fld_name].value = self._gui[fld_name].default
field.value = field.default
for fld_name, fld_type, fld_value in arr:
if fld_name in field_names_translations:
fld_name = field_names_translations[fld_name] # Convert old field name to new one if needed
if fld_name not in self._gui:
logger.warning('Field %s not found in form', fld_name)
for field_name, field_type, field_value in fields:
if field_name in field_names_translations:
field_name = field_names_translations[field_name] # Convert old field name to new one if needed
if field_name not in self._gui:
logger.warning('Field %s not found in form', field_name)
continue
field_type = self._gui[fld_name].type
if field_type not in FIELD_DECODERS:
logger.warning('Field %s has no converter', fld_name)
internal_field_type = self._gui[field_name].type
if internal_field_type not in FIELD_DECODERS:
logger.warning('Field %s has no converter', field_name)
continue
if fld_type != field_type.name:
logger.warning('Field %s has different type than expected', fld_name)
if field_type != internal_field_type.name:
logger.warning('Field %s has different type than expected', field_name)
continue
self._gui[fld_name].value = FIELD_DECODERS[field_type](fld_value)
self._gui[field_name].value = FIELD_DECODERS[internal_field_type](field_value)
def deserialize_from_old_format(self, values: bytes) -> None:
"""
@ -1689,7 +1689,7 @@ FIELD_DECODERS: typing.Final[collections.abc.Mapping[types.ui.FieldType, collect
types.ui.FieldType.TEXT_AUTOCOMPLETE: lambda x: x,
types.ui.FieldType.NUMERIC: int,
types.ui.FieldType.PASSWORD: lambda x: (CryptoManager().aes_decrypt(x.encode(), UDSK, True).decode()),
types.ui.FieldType.HIDDEN: lambda x: None,
types.ui.FieldType.HIDDEN: lambda x: x,
types.ui.FieldType.CHOICE: lambda x: x,
types.ui.FieldType.MULTICHOICE: lambda x: serializer.deserialize(base64.b64decode(x.encode())),
types.ui.FieldType.EDITABLELIST: lambda x: serializer.deserialize(base64.b64decode(x.encode())),

View File

@ -40,7 +40,7 @@ from uds.core.util import ensure
logger = logging.getLogger(__name__)
def validateRegexField(field: ui.gui.TextField, field_value: typing.Optional[str] = None):
def validate_regex_field(field: ui.gui.TextField, field_value: typing.Optional[str] = None):
"""
Validates the multi line fields refering to attributes
"""
@ -59,7 +59,7 @@ def validateRegexField(field: ui.gui.TextField, field_value: typing.Optional[str
raise exceptions.validation.ValidationError(f'Invalid pattern at {field.label}: {line}') from e
def processRegexField(
def process_regex_field(
field: str, attributes: collections.abc.Mapping[str, typing.Union[str, list[str]]]
) -> list[str]:
"""Proccesses a field, that can be a multiline field, and returns a list of values