forked from shaba/openuds
Adding MFA support to existing auths
This commit is contained in:
parent
1c65722d24
commit
f4da75cea9
@ -105,6 +105,13 @@ class InternalDBAuth(auths.Authenticator):
|
|||||||
pass
|
pass
|
||||||
return ip
|
return ip
|
||||||
|
|
||||||
|
def mfaIdentifier(self, username: str) -> str:
|
||||||
|
try:
|
||||||
|
self.dbAuthenticator().users.get(name=username, state=State.ACTIVE).mfaData
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return ''
|
||||||
|
|
||||||
def transformUsername(self, username: str) -> str:
|
def transformUsername(self, username: str) -> str:
|
||||||
if self.differentForEachHost.isTrue():
|
if self.differentForEachHost.isTrue():
|
||||||
newUsername = self.getIp() + '-' + username
|
newUsername = self.getIp() + '-' + username
|
||||||
|
@ -112,6 +112,15 @@ class RadiusAuth(auths.Authenticator):
|
|||||||
tooltip=_('If set, this value will be added as group for all radius users'),
|
tooltip=_('If set, this value will be added as group for all radius users'),
|
||||||
tab=gui.ADVANCED_TAB,
|
tab=gui.ADVANCED_TAB,
|
||||||
)
|
)
|
||||||
|
mfaAttr = gui.TextField(
|
||||||
|
length=2048,
|
||||||
|
multiline=2,
|
||||||
|
label=_('MFA attribute'),
|
||||||
|
order=13,
|
||||||
|
tooltip=_('Attribute from where to extract the MFA code'),
|
||||||
|
required=False,
|
||||||
|
tab=gui.MFA_TAB,
|
||||||
|
)
|
||||||
|
|
||||||
def initialize(self, values: typing.Optional[typing.Dict[str, typing.Any]]) -> None:
|
def initialize(self, values: typing.Optional[typing.Dict[str, typing.Any]]) -> None:
|
||||||
pass
|
pass
|
||||||
@ -126,12 +135,25 @@ class RadiusAuth(auths.Authenticator):
|
|||||||
appClassPrefix=self.appClassPrefix.value,
|
appClassPrefix=self.appClassPrefix.value,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def mfaStorageKey(self, username: str) -> str:
|
||||||
|
return 'mfa_' + self.dbAuthenticator().uuid + username
|
||||||
|
|
||||||
|
def mfaIdentifier(self, username: str) -> str:
|
||||||
|
return self.storage.getPickle(self.mfaStorageKey(username)) or ''
|
||||||
|
|
||||||
def authenticate(
|
def authenticate(
|
||||||
self, username: str, credentials: str, groupsManager: 'auths.GroupsManager'
|
self, username: str, credentials: str, groupsManager: 'auths.GroupsManager'
|
||||||
) -> bool:
|
) -> bool:
|
||||||
try:
|
try:
|
||||||
connection = self.radiusClient()
|
connection = self.radiusClient()
|
||||||
groups = connection.authenticate(username=username, password=credentials)
|
groups, mfaCode = connection.authenticate(username=username, password=credentials, mfaField=self.mfaAttr.value.strip())
|
||||||
|
# store the user mfa attribute if it is set
|
||||||
|
if mfaCode:
|
||||||
|
self.storage.putPickle(
|
||||||
|
self.mfaStorageKey(username),
|
||||||
|
mfaCode,
|
||||||
|
)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
authLogLogin(getRequest(), self.dbAuthenticator(), username, 'Access denied by Raiuds')
|
authLogLogin(getRequest(), self.dbAuthenticator(), username, 'Access denied by Raiuds')
|
||||||
return False
|
return False
|
||||||
@ -178,7 +200,7 @@ class RadiusAuth(auths.Authenticator):
|
|||||||
try:
|
try:
|
||||||
connection = self.radiusClient()
|
connection = self.radiusClient()
|
||||||
# Reply is not important...
|
# Reply is not important...
|
||||||
connection.authenticate(cryptoManager().randomString(10), cryptoManager().randomString(10))
|
connection.authenticate(cryptoManager().randomString(10), cryptoManager().randomString(10), mfaField=self.mfaAttr.value.strip())
|
||||||
except client.RadiusAuthenticationError as e:
|
except client.RadiusAuthenticationError as e:
|
||||||
pass
|
pass
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -73,7 +73,8 @@ class RadiusClient:
|
|||||||
self.nasIdentifier = nasIdentifier
|
self.nasIdentifier = nasIdentifier
|
||||||
self.appClassPrefix = appClassPrefix
|
self.appClassPrefix = appClassPrefix
|
||||||
|
|
||||||
def authenticate(self, username: str, password: str) -> typing.List[str]:
|
# Second element of return value is the mfa code from field
|
||||||
|
def authenticate(self, username: str, password: str, mfaField: str) -> typing.Tuple[typing.List[str], str]:
|
||||||
req: pyrad.packet.AuthPacket = self.radiusServer.CreateAuthPacket(
|
req: pyrad.packet.AuthPacket = self.radiusServer.CreateAuthPacket(
|
||||||
code=pyrad.packet.AccessRequest,
|
code=pyrad.packet.AccessRequest,
|
||||||
User_Name=username,
|
User_Name=username,
|
||||||
@ -95,7 +96,11 @@ class RadiusClient:
|
|||||||
groups = [i[groupClassPrefixLen:].decode() for i in typing.cast(typing.Iterable[bytes], reply['Class']) if i.startswith(groupClassPrefix)]
|
groups = [i[groupClassPrefixLen:].decode() for i in typing.cast(typing.Iterable[bytes], reply['Class']) if i.startswith(groupClassPrefix)]
|
||||||
else:
|
else:
|
||||||
logger.info('No "Class (25)" attribute found')
|
logger.info('No "Class (25)" attribute found')
|
||||||
return []
|
return ([], '')
|
||||||
|
|
||||||
return groups
|
# ...and mfa code
|
||||||
|
mfaCode = ''
|
||||||
|
if mfaField and mfaField in reply:
|
||||||
|
mfaCode = ''.join(i[groupClassPrefixLen:].decode() for i in typing.cast(typing.Iterable[bytes], reply['Class']) if i.startswith(groupClassPrefix))
|
||||||
|
return (groups, mfaCode)
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
#
|
#
|
||||||
# Copyright (c) 2012-2019 Virtual Cable S.L.
|
# Copyright (c) 2012-2022 Virtual Cable S.L.U.
|
||||||
# All rights reserved.
|
# All rights reserved.
|
||||||
#
|
#
|
||||||
# Redistribution and use in source and binary forms, with or without modification,
|
# Redistribution and use in source and binary forms, with or without modification,
|
||||||
@ -12,7 +12,7 @@
|
|||||||
# * Redistributions in binary form must reproduce the above copyright notice,
|
# * Redistributions in binary form must reproduce the above copyright notice,
|
||||||
# this list of conditions and the following disclaimer in the documentation
|
# this list of conditions and the following disclaimer in the documentation
|
||||||
# and/or other materials provided with the distribution.
|
# and/or other materials provided with the distribution.
|
||||||
# * Neither the name of Virtual Cable S.L. nor the names of its contributors
|
# * Neither the name of Virtual Cable S.L.U. nor the names of its contributors
|
||||||
# may be used to endorse or promote products derived from this software
|
# may be used to endorse or promote products derived from this software
|
||||||
# without specific prior written permission.
|
# without specific prior written permission.
|
||||||
#
|
#
|
||||||
@ -176,6 +176,16 @@ class RegexLdap(auths.Authenticator):
|
|||||||
tab=_('Advanced'),
|
tab=_('Advanced'),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
mfaAttr = gui.TextField(
|
||||||
|
length=2048,
|
||||||
|
multiline=2,
|
||||||
|
label=_('MFA attribute'),
|
||||||
|
order=13,
|
||||||
|
tooltip=_('Attribute from where to extract the MFA code'),
|
||||||
|
required=False,
|
||||||
|
tab=gui.MFA_TAB,
|
||||||
|
)
|
||||||
|
|
||||||
typeName = _('Regex LDAP Authenticator')
|
typeName = _('Regex LDAP Authenticator')
|
||||||
typeType = 'RegexLdapAuthenticator'
|
typeType = 'RegexLdapAuthenticator'
|
||||||
typeDescription = _('Regular Expressions LDAP authenticator')
|
typeDescription = _('Regular Expressions LDAP authenticator')
|
||||||
@ -205,6 +215,7 @@ class RegexLdap(auths.Authenticator):
|
|||||||
_groupNameAttr: str = ''
|
_groupNameAttr: str = ''
|
||||||
_userNameAttr: str = ''
|
_userNameAttr: str = ''
|
||||||
_altClass: str = ''
|
_altClass: str = ''
|
||||||
|
_mfaAttr: str = ''
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -231,6 +242,7 @@ class RegexLdap(auths.Authenticator):
|
|||||||
# self._regex = values['regex']
|
# self._regex = values['regex']
|
||||||
self._userNameAttr = values['userNameAttr']
|
self._userNameAttr = values['userNameAttr']
|
||||||
self._altClass = values['altClass']
|
self._altClass = values['altClass']
|
||||||
|
self._mfaAttr = values['mfaAttr']
|
||||||
|
|
||||||
def __validateField(self, field: str, fieldLabel: str) -> None:
|
def __validateField(self, field: str, fieldLabel: str) -> None:
|
||||||
"""
|
"""
|
||||||
@ -296,6 +308,12 @@ class RegexLdap(auths.Authenticator):
|
|||||||
logger.debug('Res: %s', res)
|
logger.debug('Res: %s', res)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
def mfaStorageKey(self, username: str) -> str:
|
||||||
|
return 'mfa_' + self.dbAuthenticator().uuid + username
|
||||||
|
|
||||||
|
def mfaIdentifier(self, username: str) -> str:
|
||||||
|
return self.storage.getPickle(self.mfaStorageKey(username)) or ''
|
||||||
|
|
||||||
def valuesDict(self) -> gui.ValuesDictType:
|
def valuesDict(self) -> gui.ValuesDictType:
|
||||||
return {
|
return {
|
||||||
'host': self._host,
|
'host': self._host,
|
||||||
@ -310,12 +328,13 @@ class RegexLdap(auths.Authenticator):
|
|||||||
'groupNameAttr': self._groupNameAttr,
|
'groupNameAttr': self._groupNameAttr,
|
||||||
'userNameAttr': self._userNameAttr,
|
'userNameAttr': self._userNameAttr,
|
||||||
'altClass': self._altClass,
|
'altClass': self._altClass,
|
||||||
|
'mfaAttr': self._mfaAttr,
|
||||||
}
|
}
|
||||||
|
|
||||||
def marshal(self) -> bytes:
|
def marshal(self) -> bytes:
|
||||||
return '\t'.join(
|
return '\t'.join(
|
||||||
[
|
[
|
||||||
'v3',
|
'v4',
|
||||||
self._host,
|
self._host,
|
||||||
self._port,
|
self._port,
|
||||||
gui.boolToStr(self._ssl),
|
gui.boolToStr(self._ssl),
|
||||||
@ -328,6 +347,7 @@ class RegexLdap(auths.Authenticator):
|
|||||||
self._groupNameAttr,
|
self._groupNameAttr,
|
||||||
self._userNameAttr,
|
self._userNameAttr,
|
||||||
self._altClass,
|
self._altClass,
|
||||||
|
self._mfaAttr,
|
||||||
]
|
]
|
||||||
).encode('utf8')
|
).encode('utf8')
|
||||||
|
|
||||||
@ -385,6 +405,24 @@ class RegexLdap(auths.Authenticator):
|
|||||||
self._altClass,
|
self._altClass,
|
||||||
) = vals[1:]
|
) = vals[1:]
|
||||||
self._ssl = gui.strToBool(ssl)
|
self._ssl = gui.strToBool(ssl)
|
||||||
|
elif vals[0] == 'v4':
|
||||||
|
logger.debug("Data v4: %s", vals[1:])
|
||||||
|
(
|
||||||
|
self._host,
|
||||||
|
self._port,
|
||||||
|
ssl,
|
||||||
|
self._username,
|
||||||
|
self._password,
|
||||||
|
self._timeout,
|
||||||
|
self._ldapBase,
|
||||||
|
self._userClass,
|
||||||
|
self._userIdAttr,
|
||||||
|
self._groupNameAttr,
|
||||||
|
self._userNameAttr,
|
||||||
|
self._altClass,
|
||||||
|
self._mfaAttr,
|
||||||
|
) = vals[1:]
|
||||||
|
self._ssl = gui.strToBool(ssl)
|
||||||
|
|
||||||
def __connection(self) -> typing.Any:
|
def __connection(self) -> typing.Any:
|
||||||
"""
|
"""
|
||||||
@ -428,6 +466,9 @@ class RegexLdap(auths.Authenticator):
|
|||||||
+ self.__getAttrsFromField(self._userNameAttr)
|
+ self.__getAttrsFromField(self._userNameAttr)
|
||||||
+ self.__getAttrsFromField(self._groupNameAttr)
|
+ self.__getAttrsFromField(self._groupNameAttr)
|
||||||
)
|
)
|
||||||
|
if self._mfaAttr:
|
||||||
|
attributes = attributes + self.__getAttrsFromField(self._mfaAttr)
|
||||||
|
|
||||||
user = ldaputil.getFirst(
|
user = ldaputil.getFirst(
|
||||||
con=self.__connection(),
|
con=self.__connection(),
|
||||||
base=self._ldapBase,
|
base=self._ldapBase,
|
||||||
@ -517,6 +558,13 @@ class RegexLdap(auths.Authenticator):
|
|||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
# store the user mfa attribute if it is set
|
||||||
|
if self._mfaAttr:
|
||||||
|
self.storage.putPickle(
|
||||||
|
self.mfaStorageKey(username),
|
||||||
|
usr[self._mfaAttr][0],
|
||||||
|
)
|
||||||
|
|
||||||
groupsManager.validate(self.__getGroups(usr))
|
groupsManager.validate(self.__getGroups(usr))
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
Loading…
Reference in New Issue
Block a user