mirror of
https://github.com/dkmstr/openuds.git
synced 2025-02-02 09:47:13 +03:00
Updated SAML
This commit is contained in:
parent
3afeb4869c
commit
81abe1d99f
@ -380,14 +380,14 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
# After this point, we don't mind about the token, we only need to authenticate user
|
||||
# and get some basic info from it
|
||||
|
||||
username = ''.join(auth_utils.processRegexField(self.userNameAttr.value, userInfo)).replace(' ', '_')
|
||||
username = ''.join(auth_utils.process_regex_field(self.userNameAttr.value, userInfo)).replace(' ', '_')
|
||||
if len(username) == 0:
|
||||
raise Exception('No username received')
|
||||
|
||||
realName = ''.join(auth_utils.processRegexField(self.realNameAttr.value, userInfo))
|
||||
realName = ''.join(auth_utils.process_regex_field(self.realNameAttr.value, userInfo))
|
||||
|
||||
# Get groups
|
||||
groups = auth_utils.processRegexField(self.groupNameAttr.value, userInfo)
|
||||
groups = auth_utils.process_regex_field(self.groupNameAttr.value, userInfo)
|
||||
# Append common groups
|
||||
groups.extend(self.commonGroups.value.split(','))
|
||||
|
||||
@ -445,8 +445,8 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
gettext('This kind of Authenticator does not support white spaces on field NAME')
|
||||
)
|
||||
|
||||
auth_utils.validateRegexField(self.userNameAttr)
|
||||
auth_utils.validateRegexField(self.userNameAttr)
|
||||
auth_utils.validate_regex_field(self.userNameAttr)
|
||||
auth_utils.validate_regex_field(self.userNameAttr)
|
||||
|
||||
if self.responseType.value in ('code', 'pkce', 'openid+code'):
|
||||
if self.commonGroups.value.strip() == '':
|
||||
|
@ -229,9 +229,9 @@ class RegexLdap(auths.Authenticator):
|
||||
|
||||
def initialize(self, values: typing.Optional[dict[str, typing.Any]]) -> None:
|
||||
if values:
|
||||
auth_utils.validateRegexField(self.userNameAttr, values['userNameAttr'])
|
||||
auth_utils.validateRegexField(self.userIdAttr, values['userIdAttr'])
|
||||
auth_utils.validateRegexField(self.groupNameAttr, values['groupNameAttr'])
|
||||
auth_utils.validate_regex_field(self.userNameAttr, values['userNameAttr'])
|
||||
auth_utils.validate_regex_field(self.userIdAttr, values['userIdAttr'])
|
||||
auth_utils.validate_regex_field(self.groupNameAttr, values['groupNameAttr'])
|
||||
|
||||
self._host = values['host']
|
||||
self._port = values['port']
|
||||
@ -464,13 +464,13 @@ class RegexLdap(auths.Authenticator):
|
||||
return user
|
||||
|
||||
def __getGroups(self, user: ldaputil.LDAPResultType):
|
||||
grps = auth_utils.processRegexField(self._groupNameAttr, user)
|
||||
grps = auth_utils.process_regex_field(self._groupNameAttr, user)
|
||||
if extra:
|
||||
grps += extra.getGroups(self, user)
|
||||
return grps
|
||||
|
||||
def __getUserRealName(self, user: ldaputil.LDAPResultType):
|
||||
return ' '.join(auth_utils.processRegexField(self._userNameAttr, user))
|
||||
return ' '.join(auth_utils.process_regex_field(self._userNameAttr, user))
|
||||
|
||||
def authenticate(
|
||||
self,
|
||||
|
@ -36,7 +36,7 @@ import re
|
||||
import typing
|
||||
import collections.abc
|
||||
import xml.sax # nosec: used to parse trusted xml provided only by administrators
|
||||
from urllib.parse import urlparse
|
||||
from urllib import parse
|
||||
|
||||
import requests
|
||||
from django.utils.translation import gettext
|
||||
@ -63,7 +63,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def CACHING_KEY_FNC(auth: 'SAMLAuthenticator') -> str:
|
||||
return str(hash(auth.idpMetadata.value))
|
||||
return auth.entity_id.as_str()
|
||||
|
||||
|
||||
class SAMLAuthenticator(auths.Authenticator):
|
||||
@ -114,7 +114,7 @@ class SAMLAuthenticator(auths.Authenticator):
|
||||
# : We will define a simple form where we will use a simple
|
||||
# : list editor to allow entering a few group names
|
||||
|
||||
privateKey = gui.TextField(
|
||||
private_key = gui.TextField(
|
||||
length=4096,
|
||||
lines=10,
|
||||
label=_('Private key'),
|
||||
@ -122,19 +122,19 @@ class SAMLAuthenticator(auths.Authenticator):
|
||||
tooltip=_('Private key used for sign and encription, as generated in base 64 from openssl'),
|
||||
required=True,
|
||||
tab=_('Certificates'),
|
||||
old_field_name='privateKey',
|
||||
)
|
||||
serverCertificate = gui.TextField(
|
||||
server_certificate = gui.TextField(
|
||||
length=4096,
|
||||
lines=10,
|
||||
label=_('Certificate'),
|
||||
order=2,
|
||||
tooltip=_(
|
||||
'Public key used for sign and encription (public part of previous private key), as generated in base 64 from openssl'
|
||||
),
|
||||
tooltip=_('Server certificate used in SAML, as generated in base 64 from openssl'),
|
||||
required=True,
|
||||
tab=_('Certificates'),
|
||||
old_field_name='serverCertificate',
|
||||
)
|
||||
idpMetadata = gui.TextField(
|
||||
idp_metadata = gui.TextField(
|
||||
length=8192,
|
||||
lines=4,
|
||||
label=_('IDP Metadata'),
|
||||
@ -142,16 +142,18 @@ class SAMLAuthenticator(auths.Authenticator):
|
||||
tooltip=_('You can enter here the URL or the IDP metadata or the metadata itself (xml)'),
|
||||
required=True,
|
||||
tab=_('Metadata'),
|
||||
old_field_name='idpMetadata',
|
||||
)
|
||||
entityID = gui.TextField(
|
||||
entity_id = gui.TextField(
|
||||
length=256,
|
||||
label=_('Entity ID'),
|
||||
order=4,
|
||||
tooltip=_('ID of the SP. If left blank, this will be autogenerated from server URL'),
|
||||
tab=_('Metadata'),
|
||||
old_field_name='entityID',
|
||||
)
|
||||
|
||||
userNameAttr = gui.TextField(
|
||||
attrs_username = gui.TextField(
|
||||
length=2048,
|
||||
lines=2,
|
||||
label=_('User name attrs'),
|
||||
@ -159,9 +161,10 @@ class SAMLAuthenticator(auths.Authenticator):
|
||||
tooltip=_('Fields from where to extract user name'),
|
||||
required=True,
|
||||
tab=_('Attributes'),
|
||||
old_field_name='userNameAttr',
|
||||
)
|
||||
|
||||
groupNameAttr = gui.TextField(
|
||||
attrs_groupname = gui.TextField(
|
||||
length=2048,
|
||||
lines=2,
|
||||
label=_('Group name attrs'),
|
||||
@ -169,9 +172,10 @@ class SAMLAuthenticator(auths.Authenticator):
|
||||
tooltip=_('Fields from where to extract the groups'),
|
||||
required=True,
|
||||
tab=_('Attributes'),
|
||||
old_field_name='groupNameAttr',
|
||||
)
|
||||
|
||||
realNameAttr = gui.TextField(
|
||||
attrs_realname = gui.TextField(
|
||||
length=2048,
|
||||
lines=2,
|
||||
label=_('Real name attrs'),
|
||||
@ -179,24 +183,27 @@ class SAMLAuthenticator(auths.Authenticator):
|
||||
tooltip=_('Fields from where to extract the real name'),
|
||||
required=True,
|
||||
tab=_('Attributes'),
|
||||
old_field_name='realNameAttr',
|
||||
)
|
||||
|
||||
globalLogout = gui.CheckBoxField(
|
||||
use_global_logout = gui.CheckBoxField(
|
||||
label=_('Global logout'),
|
||||
default=False,
|
||||
order=10,
|
||||
tooltip=_('If set, logout from UDS will trigger SAML logout'),
|
||||
tab=types.ui.Tab.ADVANCED,
|
||||
old_field_name='globalLogout',
|
||||
)
|
||||
|
||||
adFS = gui.CheckBoxField(
|
||||
adfs = gui.CheckBoxField(
|
||||
label=_('ADFS compatibility'),
|
||||
default=False,
|
||||
order=11,
|
||||
tooltip=_('If set, enable lowercase url encoding so ADFS can work correctly'),
|
||||
tab=types.ui.Tab.ADVANCED,
|
||||
old_field_name='adFS',
|
||||
)
|
||||
mfaAttr = gui.TextField(
|
||||
mfa_attr = gui.TextField(
|
||||
length=2048,
|
||||
lines=2,
|
||||
label=_('MFA attribute'),
|
||||
@ -204,120 +211,135 @@ class SAMLAuthenticator(auths.Authenticator):
|
||||
tooltip=_('Attribute from where to extract the MFA code'),
|
||||
required=False,
|
||||
tab=types.ui.Tab.ADVANCED,
|
||||
old_field_name='mfaAttr',
|
||||
)
|
||||
|
||||
checkSSLCertificate = gui.CheckBoxField(
|
||||
check_https_certificate = gui.CheckBoxField(
|
||||
label=_('Check SSL certificate'),
|
||||
default=False, # For compatibility with previous versions
|
||||
order=23,
|
||||
tooltip=_('If set, check SSL certificate on requests for IDP Metadata'),
|
||||
tab=_('Security'),
|
||||
old_field_name='checkSSLCertificate',
|
||||
)
|
||||
|
||||
nameIdEncrypted = gui.CheckBoxField(
|
||||
use_name_id_encrypted = gui.CheckBoxField(
|
||||
label=_('Encripted nameID'),
|
||||
default=False,
|
||||
order=12,
|
||||
tooltip=_('If set, nameID will be encripted'),
|
||||
tab=_('Security'),
|
||||
old_field_name='nameIdEncrypted',
|
||||
)
|
||||
|
||||
authnRequestsSigned = gui.CheckBoxField(
|
||||
use_authn_requests_signed = gui.CheckBoxField(
|
||||
label=_('Authn requests signed'),
|
||||
default=False,
|
||||
order=13,
|
||||
tooltip=_('If set, authn requests will be signed'),
|
||||
tab=_('Security'),
|
||||
old_field_name='authnRequestsSigned',
|
||||
)
|
||||
|
||||
logoutRequestSigned = gui.CheckBoxField(
|
||||
logout_request_signed = gui.CheckBoxField(
|
||||
label=_('Logout requests signed'),
|
||||
default=False,
|
||||
order=14,
|
||||
tooltip=_('If set, logout requests will be signed'),
|
||||
tab=_('Security'),
|
||||
old_field_name='logoutRequestSigned',
|
||||
)
|
||||
|
||||
logoutResponseSigned = gui.CheckBoxField(
|
||||
use_signed_logout_response = gui.CheckBoxField(
|
||||
label=_('Logout responses signed'),
|
||||
default=False,
|
||||
order=15,
|
||||
tooltip=_('If set, logout responses will be signed'),
|
||||
tab=_('Security'),
|
||||
old_field_name='logoutResponseSigned',
|
||||
)
|
||||
|
||||
signMetadata = gui.CheckBoxField(
|
||||
use_signed_metadata = gui.CheckBoxField(
|
||||
label=_('Sign metadata'),
|
||||
default=False,
|
||||
order=16,
|
||||
tooltip=_('If set, metadata will be signed'),
|
||||
tab=_('Security'),
|
||||
old_field_name='signMetadata',
|
||||
)
|
||||
|
||||
wantMessagesSigned = gui.CheckBoxField(
|
||||
want_messages_signed = gui.CheckBoxField(
|
||||
label=_('Want messages signed'),
|
||||
default=False,
|
||||
order=17,
|
||||
tooltip=_('If set, messages will be signed'),
|
||||
tab=_('Security'),
|
||||
old_field_name='wantMessagesSigned',
|
||||
)
|
||||
|
||||
wantAssertionsSigned = gui.CheckBoxField(
|
||||
want_assertions_signed = gui.CheckBoxField(
|
||||
label=_('Want assertions signed'),
|
||||
default=False,
|
||||
order=18,
|
||||
tooltip=_('If set, assertions will be signed'),
|
||||
tab=_('Security'),
|
||||
old_field_name='wantAssertionsSigned',
|
||||
)
|
||||
|
||||
wantAssertionsEncrypted = gui.CheckBoxField(
|
||||
want_assertions_encrypted = gui.CheckBoxField(
|
||||
label=_('Want assertions encrypted'),
|
||||
default=False,
|
||||
order=19,
|
||||
tooltip=_('If set, assertions will be encrypted'),
|
||||
tab=_('Security'),
|
||||
old_field_name='wantAssertionsEncrypted',
|
||||
)
|
||||
|
||||
wantNameIdEncrypted = gui.CheckBoxField(
|
||||
want_name_id_encrypted = gui.CheckBoxField(
|
||||
label=_('Want nameID encrypted'),
|
||||
default=False,
|
||||
order=20,
|
||||
tooltip=_('If set, nameID will be encrypted'),
|
||||
tab=_('Security'),
|
||||
old_field_name='wantNameIdEncrypted',
|
||||
)
|
||||
|
||||
requestedAuthnContext = gui.CheckBoxField(
|
||||
use_requested_authn_context = gui.CheckBoxField(
|
||||
label=_('Requested authn context'),
|
||||
default=False,
|
||||
order=21,
|
||||
tooltip=_('If set, requested authn context will be sent'),
|
||||
tab=_('Security'),
|
||||
old_field_name='requestedAuthnContext',
|
||||
)
|
||||
|
||||
allowDeprecatedSignatureAlgorithms = gui.CheckBoxField(
|
||||
allow_deprecated_signature_algorithms = gui.CheckBoxField(
|
||||
label=_('Allow deprecated signature algorithms'),
|
||||
default=True,
|
||||
order=23,
|
||||
tooltip=_('If set, deprecated signature algorithms will be allowed (as SHA1, MD5, etc...)'),
|
||||
tab=_('Security'),
|
||||
old_field_name='allowDeprecatedSignatureAlgorithms',
|
||||
)
|
||||
|
||||
metadataCacheDuration = gui.NumericField(
|
||||
metadata_cache_duration = gui.NumericField(
|
||||
label=_('Metadata cache duration'),
|
||||
default=0,
|
||||
order=22,
|
||||
tooltip=_('Duration of metadata cache in days. 0 means default (ten years)'),
|
||||
tab=_('Metadata'),
|
||||
old_field_name='metadataCacheDuration',
|
||||
)
|
||||
|
||||
metadataValidityDuration = gui.NumericField(
|
||||
metadata_validity_duration = gui.NumericField(
|
||||
label=_('Metadata validity duration'),
|
||||
default=0,
|
||||
order=22,
|
||||
tooltip=_('Duration of metadata validity in days. 0 means default (ten years)'),
|
||||
tab=_('Metadata'),
|
||||
old_field_name='metadataValidityDuration',
|
||||
)
|
||||
|
||||
|
||||
organization_name = gui.TextField(
|
||||
length=256,
|
||||
label=_('Organization Name'),
|
||||
@ -326,7 +348,7 @@ class SAMLAuthenticator(auths.Authenticator):
|
||||
tooltip=_('Organization name to use on SAML SP Metadata'),
|
||||
tab=_('Organization'),
|
||||
)
|
||||
|
||||
|
||||
organization_display_name = gui.TextField(
|
||||
length=256,
|
||||
label=_('Organization Display Name'),
|
||||
@ -335,7 +357,7 @@ class SAMLAuthenticator(auths.Authenticator):
|
||||
tooltip=_('Organization Display name to use on SAML SP Metadata'),
|
||||
tab=_('Organization'),
|
||||
)
|
||||
|
||||
|
||||
organization_url = gui.TextField(
|
||||
length=256,
|
||||
label=_('Organization URL'),
|
||||
@ -344,8 +366,8 @@ class SAMLAuthenticator(auths.Authenticator):
|
||||
tooltip=_('Organization url to use on SAML SP Metadata'),
|
||||
tab=_('Organization'),
|
||||
)
|
||||
|
||||
manageUrl = gui.HiddenField(serializable=True)
|
||||
|
||||
manage_url = gui.HiddenField(serializable=True, old_field_name='manageUrl')
|
||||
|
||||
def initialize(self, values: typing.Optional[dict[str, typing.Any]]) -> None:
|
||||
"""
|
||||
@ -364,11 +386,13 @@ class SAMLAuthenticator(auths.Authenticator):
|
||||
gettext('This kind of Authenticator does not support white spaces on field NAME')
|
||||
)
|
||||
|
||||
# First, validate certificates
|
||||
# Ensure ipdMetadata cache is empty, to regenerate it
|
||||
self.cache.remove('idpMetadata')
|
||||
|
||||
# First, validate certificates
|
||||
|
||||
# This is in fact not needed, but we may say something useful to user if we check this
|
||||
if self.serverCertificate.value.startswith('-----BEGIN CERTIFICATE-----\n') is False:
|
||||
if self.server_certificate.value.startswith('-----BEGIN CERTIFICATE-----\n') is False:
|
||||
raise exceptions.validation.ValidationError(
|
||||
gettext(
|
||||
'Server certificate should be a valid PEM (PEM certificates starts with -----BEGIN CERTIFICATE-----)'
|
||||
@ -376,13 +400,13 @@ class SAMLAuthenticator(auths.Authenticator):
|
||||
)
|
||||
|
||||
try:
|
||||
CryptoManager().load_certificate(self.serverCertificate.value)
|
||||
CryptoManager().load_certificate(self.server_certificate.value)
|
||||
except Exception as e:
|
||||
raise exceptions.validation.ValidationError(gettext('Invalid server certificate. ') + str(e))
|
||||
|
||||
if (
|
||||
self.privateKey.value.startswith('-----BEGIN RSA PRIVATE KEY-----\n') is False
|
||||
and self.privateKey.value.startswith('-----BEGIN PRIVATE KEY-----\n') is False
|
||||
self.private_key.value.startswith('-----BEGIN RSA PRIVATE KEY-----\n') is False
|
||||
and self.private_key.value.startswith('-----BEGIN PRIVATE KEY-----\n') is False
|
||||
):
|
||||
raise exceptions.validation.ValidationError(
|
||||
gettext(
|
||||
@ -391,64 +415,66 @@ class SAMLAuthenticator(auths.Authenticator):
|
||||
)
|
||||
|
||||
try:
|
||||
CryptoManager().load_private_key(self.privateKey.value)
|
||||
CryptoManager().load_private_key(self.private_key.value)
|
||||
except Exception as e:
|
||||
raise exceptions.validation.ValidationError(gettext('Invalid private key. ') + str(e))
|
||||
|
||||
if not security.check_certificate_matches_private_key(
|
||||
cert=self.serverCertificate.value, key=self.privateKey.value
|
||||
cert=self.server_certificate.value, key=self.private_key.value
|
||||
):
|
||||
raise exceptions.validation.ValidationError(gettext('Certificate and private key do not match'))
|
||||
|
||||
request: 'ExtendedHttpRequest' = values['_request']
|
||||
|
||||
if self.entityID.value == '':
|
||||
self.entityID.value = request.build_absolute_uri(self.info_url())
|
||||
if self.entity_id.value == '':
|
||||
self.entity_id.value = request.build_absolute_uri(self.info_url())
|
||||
|
||||
self.manageUrl.value = request.build_absolute_uri(self.callback_url())
|
||||
self.manage_url.value = request.build_absolute_uri(self.callback_url())
|
||||
|
||||
idpMetadata: str = self.idpMetadata.value
|
||||
fromUrl: bool = False
|
||||
if idpMetadata.startswith('http://') or idpMetadata.startswith('https://'):
|
||||
logger.debug('idp Metadata is an URL: %s', idpMetadata)
|
||||
idp_metadata: str = self.idp_metadata.value
|
||||
from_url: bool = False
|
||||
if idp_metadata.startswith('http://') or idp_metadata.startswith('https://'):
|
||||
logger.debug('idp Metadata is an URL: %s', idp_metadata)
|
||||
try:
|
||||
resp = requests.get(
|
||||
idpMetadata.split('\n')[0],
|
||||
verify=self.checkSSLCertificate.as_bool(),
|
||||
idp_metadata.split('\n')[0],
|
||||
verify=self.check_https_certificate.as_bool(),
|
||||
timeout=10,
|
||||
)
|
||||
idpMetadata = resp.content.decode()
|
||||
idp_metadata = resp.content.decode()
|
||||
except Exception as e:
|
||||
raise exceptions.validation.ValidationError(
|
||||
gettext('Can\'t fetch url {0}: {1}').format(self.idpMetadata.value, str(e))
|
||||
gettext('Can\'t fetch url {0}: {1}').format(self.idp_metadata.value, str(e))
|
||||
)
|
||||
fromUrl = True
|
||||
from_url = True
|
||||
|
||||
# Try to parse it so we can check it is valid. Right now, it checks just that this is XML, will
|
||||
# correct it to check that is is valid idp metadata
|
||||
try:
|
||||
xml.sax.parseString(idpMetadata, xml.sax.ContentHandler()) # type: ignore # nosec: url provided by admin
|
||||
xml.sax.parseString(idp_metadata, xml.sax.ContentHandler()) # type: ignore # nosec: url provided by admin
|
||||
except Exception as e:
|
||||
msg = (gettext(' (obtained from URL)') if fromUrl else '') + str(e)
|
||||
raise exceptions.validation.ValidationError(gettext('XML does not seem valid for IDP Metadata ') + msg)
|
||||
msg = (gettext(' (obtained from URL)') if from_url else '') + str(e)
|
||||
raise exceptions.validation.ValidationError(
|
||||
gettext('XML does not seem valid for IDP Metadata ') + msg
|
||||
)
|
||||
|
||||
# Now validate regular expressions, if they exists
|
||||
auth_utils.validateRegexField(self.userNameAttr)
|
||||
auth_utils.validateRegexField(self.groupNameAttr)
|
||||
auth_utils.validateRegexField(self.realNameAttr)
|
||||
auth_utils.validate_regex_field(self.attrs_username)
|
||||
auth_utils.validate_regex_field(self.attrs_groupname)
|
||||
auth_utils.validate_regex_field(self.attrs_realname)
|
||||
|
||||
def getReqFromRequest(
|
||||
def build_req_from_request(
|
||||
self,
|
||||
request: 'ExtendedHttpRequest',
|
||||
params: typing.Optional['types.auth.AuthCallbackParams'] = None,
|
||||
) -> dict[str, typing.Any]:
|
||||
manageUrlObj = urlparse(self.manageUrl.value)
|
||||
script_path = manageUrlObj.path
|
||||
host = manageUrlObj.netloc
|
||||
manage_url_obj: 'parse.ParseResult' = parse.urlparse(self.manage_url.value)
|
||||
script_path: str = manage_url_obj.path
|
||||
host: str = manage_url_obj.netloc
|
||||
if ':' in host:
|
||||
host, port = host.split(':')
|
||||
else:
|
||||
if manageUrlObj.scheme == 'http':
|
||||
if manage_url_obj.scheme == 'http':
|
||||
port = '80'
|
||||
else:
|
||||
port = '443'
|
||||
@ -463,7 +489,7 @@ class SAMLAuthenticator(auths.Authenticator):
|
||||
'server_port': port, # params['server_port'],
|
||||
'get_data': params.get_params.copy(),
|
||||
'post_data': params.post_params.copy(),
|
||||
'lowercase_urlencoding': self.adFS.as_bool(),
|
||||
'lowercase_urlencoding': self.adfs.as_bool(),
|
||||
'query_string': params.query_string,
|
||||
}
|
||||
# No callback parameters, we use the request
|
||||
@ -474,72 +500,74 @@ class SAMLAuthenticator(auths.Authenticator):
|
||||
'server_port': port, # request.META['SERVER_PORT'],
|
||||
'get_data': request.GET.copy(),
|
||||
'post_data': request.POST.copy(),
|
||||
'lowercase_urlencoding': self.adFS.as_bool(),
|
||||
'lowercase_urlencoding': self.adfs.as_bool(),
|
||||
'query_string': request.META['QUERY_STRING'],
|
||||
}
|
||||
|
||||
@decorators.cached(
|
||||
prefix='idpm',
|
||||
key_fnc=CACHING_KEY_FNC,
|
||||
timeout=3600 * 24 * 365, # 1 year
|
||||
)
|
||||
def getIdpMetadataDict(self) -> dict[str, typing.Any]:
|
||||
if self.idpMetadata.value.startswith('http'):
|
||||
def get_idp_metadata_dict(self) -> dict[str, typing.Any]:
|
||||
if self.idp_metadata.value.startswith('http'):
|
||||
resp = self.cache.get('idpMetadata')
|
||||
if resp:
|
||||
return resp
|
||||
try:
|
||||
resp = requests.get(
|
||||
self.idpMetadata.value.split('\n')[0],
|
||||
verify=self.checkSSLCertificate.as_bool(),
|
||||
self.idp_metadata.value.split('\n')[0],
|
||||
verify=self.check_https_certificate.as_bool(),
|
||||
timeout=10,
|
||||
)
|
||||
val = resp.content.decode()
|
||||
# 10 years, unless edited the metadata will be kept
|
||||
self.cache.put('idpMetadata', val, 86400 * 365 * 10)
|
||||
except Exception as e:
|
||||
logger.error('Error fetching idp metadata: %s', e)
|
||||
raise exceptions.auth.AuthenticatorException(gettext('Can\'t access idp metadata'))
|
||||
else:
|
||||
val = self.idpMetadata.value
|
||||
val = self.idp_metadata.value
|
||||
|
||||
return OneLogin_Saml2_IdPMetadataParser.parse(val)
|
||||
|
||||
def oneLoginSettings(self) -> dict[str, typing.Any]:
|
||||
def build_onelogin_settings(self) -> dict[str, typing.Any]:
|
||||
return {
|
||||
'strict': True,
|
||||
'debug': True,
|
||||
'sp': {
|
||||
'entityId': self.entityID.value,
|
||||
'entityId': self.entity_id.value,
|
||||
'assertionConsumerService': {
|
||||
'url': self.manageUrl.value,
|
||||
'url': self.manage_url.value,
|
||||
'binding': 'urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST',
|
||||
},
|
||||
'singleLogoutService': {
|
||||
'url': self.manageUrl.value + '?logout=true',
|
||||
'url': self.manage_url.value + '?logout=true',
|
||||
'binding': 'urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect',
|
||||
},
|
||||
'x509cert': self.serverCertificate.value,
|
||||
'privateKey': self.privateKey.value,
|
||||
'x509cert': self.server_certificate.value,
|
||||
'privateKey': self.private_key.value,
|
||||
'NameIDFormat': 'urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified',
|
||||
},
|
||||
'idp': self.getIdpMetadataDict()['idp'],
|
||||
'idp': self.get_idp_metadata_dict()['idp'],
|
||||
'security': {
|
||||
'metadataCacheDuration': self.metadataCacheDuration.int_value
|
||||
if self.metadataCacheDuration.int_value > 0
|
||||
# in days, converted to seconds, this is a duration
|
||||
'metadataCacheDuration': self.metadata_cache_duration.as_int() * 86400
|
||||
if self.metadata_cache_duration.int_value > 0
|
||||
else 86400 * 365 * 10,
|
||||
# This is a date of end of validity
|
||||
'metadataValidUntil': sql_datetime()
|
||||
+ datetime.timedelta(seconds=self.metadataValidityDuration.int_value)
|
||||
if self.metadataCacheDuration.int_value > 0
|
||||
+ datetime.timedelta(days=self.metadata_validity_duration.as_int())
|
||||
if self.metadata_cache_duration.int_value > 0
|
||||
else sql_datetime() + datetime.timedelta(days=365 * 10),
|
||||
'nameIdEncrypted': self.nameIdEncrypted.as_bool(),
|
||||
'authnRequestsSigned': self.authnRequestsSigned.as_bool(),
|
||||
'logoutRequestSigned': self.logoutRequestSigned.as_bool(),
|
||||
'logoutResponseSigned': self.logoutResponseSigned.as_bool(),
|
||||
'signMetadata': self.signMetadata.as_bool(),
|
||||
'wantMessagesSigned': self.wantMessagesSigned.as_bool(),
|
||||
'wantAssertionsSigned': self.wantAssertionsSigned.as_bool(),
|
||||
'wantAssertionsEncrypted': self.wantAssertionsEncrypted.as_bool(),
|
||||
'wantNameIdEncrypted': self.wantNameIdEncrypted.as_bool(),
|
||||
'requestedAuthnContext': self.requestedAuthnContext.as_bool(),
|
||||
'nameIdEncrypted': self.use_name_id_encrypted.as_bool(),
|
||||
'authnRequestsSigned': self.use_authn_requests_signed.as_bool(),
|
||||
'logoutRequestSigned': self.logout_request_signed.as_bool(),
|
||||
'logoutResponseSigned': self.use_signed_logout_response.as_bool(),
|
||||
'signMetadata': self.use_signed_metadata.as_bool(),
|
||||
'wantMessagesSigned': self.want_messages_signed.as_bool(),
|
||||
'wantAssertionsSigned': self.want_assertions_signed.as_bool(),
|
||||
'wantAssertionsEncrypted': self.want_assertions_encrypted.as_bool(),
|
||||
'wantNameIdEncrypted': self.want_name_id_encrypted.as_bool(),
|
||||
'requestedAuthnContext': self.use_requested_authn_context.as_bool(),
|
||||
"signatureAlgorithm": "http://www.w3.org/2001/04/xmldsig-more#rsa-sha256",
|
||||
"digestAlgorithm": "http://www.w3.org/2001/04/xmlenc#sha256",
|
||||
"rejectDeprecatedAlgorithm": not self.allowDeprecatedSignatureAlgorithms.as_bool(),
|
||||
"rejectDeprecatedAlgorithm": not self.allow_deprecated_signature_algorithms.as_bool(),
|
||||
},
|
||||
'organization': {
|
||||
'en-US': {
|
||||
@ -555,8 +583,8 @@ class SAMLAuthenticator(auths.Authenticator):
|
||||
key_fnc=CACHING_KEY_FNC,
|
||||
timeout=3600, # 1 hour
|
||||
)
|
||||
def getSpMetadata(self) -> str:
|
||||
saml_settings = OneLogin_Saml2_Settings(settings=self.oneLoginSettings())
|
||||
def get_sp_metadata(self) -> str:
|
||||
saml_settings = OneLogin_Saml2_Settings(settings=self.build_onelogin_settings())
|
||||
metadata = saml_settings.get_sp_metadata()
|
||||
errors = saml_settings.validate_metadata(metadata)
|
||||
if len(errors) > 0:
|
||||
@ -567,7 +595,6 @@ class SAMLAuthenticator(auths.Authenticator):
|
||||
return metadata
|
||||
return typing.cast(bytes, metadata).decode()
|
||||
|
||||
|
||||
def get_info(
|
||||
self, parameters: collections.abc.Mapping[str, str]
|
||||
) -> typing.Optional[tuple[str, typing.Optional[str]]]:
|
||||
@ -575,7 +602,7 @@ class SAMLAuthenticator(auths.Authenticator):
|
||||
Althought this is mainly a get info callback, this can be used for any other purpuse we like.
|
||||
In this case, we use it to provide logout callback also
|
||||
"""
|
||||
info = self.getSpMetadata()
|
||||
info = self.get_sp_metadata()
|
||||
wantsHtml = parameters.get('format') == 'html'
|
||||
|
||||
content_type = 'text/html' if wantsHtml else 'application/samlmetadata+xml'
|
||||
@ -584,16 +611,16 @@ class SAMLAuthenticator(auths.Authenticator):
|
||||
)
|
||||
return info, content_type # 'application/samlmetadata+xml')
|
||||
|
||||
def mfaStorageKey(self, username: str) -> str:
|
||||
def mfa_storage_key(self, username: str) -> str:
|
||||
return 'mfa_' + self.db_obj().uuid + username # type: ignore
|
||||
|
||||
def mfaClean(self, username: str):
|
||||
self.storage.remove(self.mfaStorageKey(username))
|
||||
def mfa_clean(self, username: str):
|
||||
self.storage.remove(self.mfa_storage_key(username))
|
||||
|
||||
def mfa_identifier(self, username: str) -> str:
|
||||
return self.storage.get_unpickle(self.mfaStorageKey(username)) or ''
|
||||
return self.storage.get_unpickle(self.mfa_storage_key(username)) or ''
|
||||
|
||||
def logoutFromCallback(
|
||||
def logout_callback(
|
||||
self,
|
||||
req: dict[str, typing.Any],
|
||||
request: 'ExtendedHttpRequestWithUser',
|
||||
@ -605,15 +632,15 @@ class SAMLAuthenticator(auths.Authenticator):
|
||||
else:
|
||||
req['get_data']['SAMLResponse'] = req['post_data']['SAMLResponse']
|
||||
|
||||
logoutRequestId = request.session.get('samlLogoutRequestId', None)
|
||||
logout_req_id = request.session.get('samlLogoutRequestId', None)
|
||||
|
||||
# Cleanup session & session cookie
|
||||
request.session.flush()
|
||||
|
||||
settings = OneLogin_Saml2_Settings(settings=self.oneLoginSettings())
|
||||
settings = OneLogin_Saml2_Settings(settings=self.build_onelogin_settings())
|
||||
auth = OneLogin_Saml2_Auth(req, settings)
|
||||
|
||||
url = auth.process_slo(request_id=logoutRequestId)
|
||||
url = auth.process_slo(request_id=logout_req_id)
|
||||
|
||||
errors = auth.get_errors()
|
||||
|
||||
@ -625,7 +652,7 @@ class SAMLAuthenticator(auths.Authenticator):
|
||||
|
||||
# Remove MFA related data
|
||||
if request.user:
|
||||
self.mfaClean(request.user.name)
|
||||
self.mfa_clean(request.user.name)
|
||||
|
||||
return types.auth.AuthenticationResult(
|
||||
success=types.auth.AuthenticationState.REDIRECT,
|
||||
@ -639,13 +666,13 @@ class SAMLAuthenticator(auths.Authenticator):
|
||||
gm: 'auths.GroupsManager',
|
||||
request: 'ExtendedHttpRequestWithUser',
|
||||
) -> types.auth.AuthenticationResult:
|
||||
req = self.getReqFromRequest(request, params=parameters)
|
||||
req = self.build_req_from_request(request, params=parameters)
|
||||
|
||||
if 'logout' in parameters.get_params:
|
||||
return self.logoutFromCallback(req, request)
|
||||
return self.logout_callback(req, request)
|
||||
|
||||
try:
|
||||
settings = OneLogin_Saml2_Settings(settings=self.oneLoginSettings())
|
||||
settings = OneLogin_Saml2_Settings(settings=self.build_onelogin_settings())
|
||||
auth = OneLogin_Saml2_Auth(req, settings)
|
||||
auth.process_response()
|
||||
except Exception as e:
|
||||
@ -688,41 +715,43 @@ class SAMLAuthenticator(auths.Authenticator):
|
||||
|
||||
# Now that we have attributes, we can extract values from this, map groups, etc...
|
||||
username = ''.join(
|
||||
auth_utils.processRegexField(self.userNameAttr.value, attributes)
|
||||
auth_utils.process_regex_field(self.attrs_username.value, attributes)
|
||||
) # in case of multiple values is returned, join them
|
||||
logger.debug('Username: %s', username)
|
||||
|
||||
groups = auth_utils.processRegexField(self.groupNameAttr.value, attributes)
|
||||
groups = auth_utils.process_regex_field(self.attrs_groupname.value, attributes)
|
||||
logger.debug('Groups: %s', groups)
|
||||
|
||||
realName = ' '.join(auth_utils.processRegexField(self.realNameAttr.value, attributes))
|
||||
realName = ' '.join(auth_utils.process_regex_field(self.attrs_realname.value, attributes))
|
||||
logger.debug('Real name: %s', realName)
|
||||
|
||||
# store groups for this username at storage, so we can check it at a later stage
|
||||
self.storage.put_pickle(username, [realName, groups])
|
||||
|
||||
# store also the mfa identifier field value, in case we have provided it
|
||||
if self.mfaAttr.value.strip():
|
||||
if self.mfa_attr.value.strip():
|
||||
self.storage.put_pickle(
|
||||
self.mfaStorageKey(username),
|
||||
''.join(auth_utils.processRegexField(self.mfaAttr.value, attributes)),
|
||||
self.mfa_storage_key(username),
|
||||
''.join(auth_utils.process_regex_field(self.mfa_attr.value, attributes)),
|
||||
) # in case multipel values is returned, join them
|
||||
else:
|
||||
self.storage.remove(self.mfaStorageKey(username))
|
||||
self.storage.remove(self.mfa_storage_key(username))
|
||||
|
||||
# Now we check validity of user
|
||||
|
||||
gm.validate(groups)
|
||||
|
||||
return types.auth.AuthenticationResult(success=types.auth.AuthenticationState.SUCCESS, username=username)
|
||||
return types.auth.AuthenticationResult(
|
||||
success=types.auth.AuthenticationState.SUCCESS, username=username
|
||||
)
|
||||
|
||||
def logout(self, request: 'ExtendedHttpRequest', username: str) -> types.auth.AuthenticationResult:
|
||||
if not self.globalLogout.as_bool():
|
||||
if not self.use_global_logout.as_bool():
|
||||
return types.auth.SUCCESS_AUTH
|
||||
|
||||
req = self.getReqFromRequest(request)
|
||||
req = self.build_req_from_request(request)
|
||||
|
||||
settings = OneLogin_Saml2_Settings(settings=self.oneLoginSettings())
|
||||
settings = OneLogin_Saml2_Settings(settings=self.build_onelogin_settings())
|
||||
|
||||
auth = OneLogin_Saml2_Auth(req, settings)
|
||||
|
||||
@ -732,7 +761,7 @@ class SAMLAuthenticator(auths.Authenticator):
|
||||
request.session.clear()
|
||||
|
||||
# Remove MFA related data
|
||||
self.mfaClean(username)
|
||||
self.mfa_clean(username)
|
||||
|
||||
if not saml:
|
||||
return types.auth.SUCCESS_AUTH
|
||||
@ -764,8 +793,8 @@ class SAMLAuthenticator(auths.Authenticator):
|
||||
"""
|
||||
We will here compose the saml request and send it via http-redirect
|
||||
"""
|
||||
req = self.getReqFromRequest(request)
|
||||
auth = OneLogin_Saml2_Auth(req, self.oneLoginSettings())
|
||||
req = self.build_req_from_request(request)
|
||||
auth = OneLogin_Saml2_Auth(req, self.build_onelogin_settings())
|
||||
|
||||
return f'window.location="{auth.login()}";'
|
||||
|
||||
|
@ -1488,13 +1488,13 @@ class UserInterface(metaclass=UserInterfaceAbstract):
|
||||
# Any unexpected type will raise an exception
|
||||
# Note that currently, we will store old field name on db
|
||||
# to allow "backwards" migration if needed, but will be removed on a future version
|
||||
arr = [
|
||||
fields = [
|
||||
(field.old_field_name() or field_name, field.type.name, FIELDS_ENCODERS[field.type](field))
|
||||
for field_name, field in self._gui.items()
|
||||
if FIELDS_ENCODERS[field.type](field) is not None
|
||||
]
|
||||
|
||||
return SERIALIZATION_HEADER + SERIALIZATION_VERSION + serializer.serialize(arr)
|
||||
return SERIALIZATION_HEADER + SERIALIZATION_VERSION + serializer.serialize(fields)
|
||||
|
||||
def deserialize_fields(
|
||||
self,
|
||||
@ -1528,36 +1528,36 @@ class UserInterface(metaclass=UserInterfaceAbstract):
|
||||
logger.info('Empty values on unserialize_fields')
|
||||
return
|
||||
|
||||
arr = serializer.deserialize(values) or []
|
||||
fields = serializer.deserialize(values) or []
|
||||
|
||||
# Dict of translations from old_field_name to field_name
|
||||
field_names_translations: dict[str, str] = self._get_fieldname_translations()
|
||||
|
||||
# Set all values to defaults ones
|
||||
for fld_name in self._gui:
|
||||
fld = self._gui[fld_name]
|
||||
for field_name in self._gui:
|
||||
field = self._gui[field_name]
|
||||
if (
|
||||
self._gui[fld_name].is_type(types.ui.FieldType.HIDDEN)
|
||||
and self._gui[fld_name].is_serializable() is False
|
||||
field.is_type(types.ui.FieldType.HIDDEN)
|
||||
and field.is_serializable() is False
|
||||
):
|
||||
# logger.debug('Field {0} is not unserializable'.format(k))
|
||||
continue
|
||||
self._gui[fld_name].value = self._gui[fld_name].default
|
||||
field.value = field.default
|
||||
|
||||
for fld_name, fld_type, fld_value in arr:
|
||||
if fld_name in field_names_translations:
|
||||
fld_name = field_names_translations[fld_name] # Convert old field name to new one if needed
|
||||
if fld_name not in self._gui:
|
||||
logger.warning('Field %s not found in form', fld_name)
|
||||
for field_name, field_type, field_value in fields:
|
||||
if field_name in field_names_translations:
|
||||
field_name = field_names_translations[field_name] # Convert old field name to new one if needed
|
||||
if field_name not in self._gui:
|
||||
logger.warning('Field %s not found in form', field_name)
|
||||
continue
|
||||
field_type = self._gui[fld_name].type
|
||||
if field_type not in FIELD_DECODERS:
|
||||
logger.warning('Field %s has no converter', fld_name)
|
||||
internal_field_type = self._gui[field_name].type
|
||||
if internal_field_type not in FIELD_DECODERS:
|
||||
logger.warning('Field %s has no converter', field_name)
|
||||
continue
|
||||
if fld_type != field_type.name:
|
||||
logger.warning('Field %s has different type than expected', fld_name)
|
||||
if field_type != internal_field_type.name:
|
||||
logger.warning('Field %s has different type than expected', field_name)
|
||||
continue
|
||||
self._gui[fld_name].value = FIELD_DECODERS[field_type](fld_value)
|
||||
self._gui[field_name].value = FIELD_DECODERS[internal_field_type](field_value)
|
||||
|
||||
def deserialize_from_old_format(self, values: bytes) -> None:
|
||||
"""
|
||||
@ -1689,7 +1689,7 @@ FIELD_DECODERS: typing.Final[collections.abc.Mapping[types.ui.FieldType, collect
|
||||
types.ui.FieldType.TEXT_AUTOCOMPLETE: lambda x: x,
|
||||
types.ui.FieldType.NUMERIC: int,
|
||||
types.ui.FieldType.PASSWORD: lambda x: (CryptoManager().aes_decrypt(x.encode(), UDSK, True).decode()),
|
||||
types.ui.FieldType.HIDDEN: lambda x: None,
|
||||
types.ui.FieldType.HIDDEN: lambda x: x,
|
||||
types.ui.FieldType.CHOICE: lambda x: x,
|
||||
types.ui.FieldType.MULTICHOICE: lambda x: serializer.deserialize(base64.b64decode(x.encode())),
|
||||
types.ui.FieldType.EDITABLELIST: lambda x: serializer.deserialize(base64.b64decode(x.encode())),
|
||||
|
@ -40,7 +40,7 @@ from uds.core.util import ensure
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def validateRegexField(field: ui.gui.TextField, field_value: typing.Optional[str] = None):
|
||||
def validate_regex_field(field: ui.gui.TextField, field_value: typing.Optional[str] = None):
|
||||
"""
|
||||
Validates the multi line fields refering to attributes
|
||||
"""
|
||||
@ -59,7 +59,7 @@ def validateRegexField(field: ui.gui.TextField, field_value: typing.Optional[str
|
||||
raise exceptions.validation.ValidationError(f'Invalid pattern at {field.label}: {line}') from e
|
||||
|
||||
|
||||
def processRegexField(
|
||||
def process_regex_field(
|
||||
field: str, attributes: collections.abc.Mapping[str, typing.Union[str, list[str]]]
|
||||
) -> list[str]:
|
||||
"""Proccesses a field, that can be a multiline field, and returns a list of values
|
||||
|
Loading…
x
Reference in New Issue
Block a user