1
0
mirror of https://github.com/dkmstr/openuds.git synced 2024-12-24 21:34:41 +03:00

Adding MFA support to existing auths

This commit is contained in:
Adolfo Gómez García 2022-07-04 22:10:06 +02:00
parent 1c65722d24
commit f4da75cea9
4 changed files with 90 additions and 8 deletions

View File

@ -105,6 +105,13 @@ class InternalDBAuth(auths.Authenticator):
pass
return ip
def mfaIdentifier(self, username: str) -> str:
try:
self.dbAuthenticator().users.get(name=username, state=State.ACTIVE).mfaData
except Exception:
pass
return ''
def transformUsername(self, username: str) -> str:
if self.differentForEachHost.isTrue():
newUsername = self.getIp() + '-' + username

View File

@ -112,6 +112,15 @@ class RadiusAuth(auths.Authenticator):
tooltip=_('If set, this value will be added as group for all radius users'),
tab=gui.ADVANCED_TAB,
)
mfaAttr = gui.TextField(
length=2048,
multiline=2,
label=_('MFA attribute'),
order=13,
tooltip=_('Attribute from where to extract the MFA code'),
required=False,
tab=gui.MFA_TAB,
)
def initialize(self, values: typing.Optional[typing.Dict[str, typing.Any]]) -> None:
pass
@ -126,12 +135,25 @@ class RadiusAuth(auths.Authenticator):
appClassPrefix=self.appClassPrefix.value,
)
def mfaStorageKey(self, username: str) -> str:
return 'mfa_' + self.dbAuthenticator().uuid + username
def mfaIdentifier(self, username: str) -> str:
return self.storage.getPickle(self.mfaStorageKey(username)) or ''
def authenticate(
self, username: str, credentials: str, groupsManager: 'auths.GroupsManager'
) -> bool:
try:
connection = self.radiusClient()
groups = connection.authenticate(username=username, password=credentials)
groups, mfaCode = connection.authenticate(username=username, password=credentials, mfaField=self.mfaAttr.value.strip())
# store the user mfa attribute if it is set
if mfaCode:
self.storage.putPickle(
self.mfaStorageKey(username),
mfaCode,
)
except Exception:
authLogLogin(getRequest(), self.dbAuthenticator(), username, 'Access denied by Raiuds')
return False
@ -178,7 +200,7 @@ class RadiusAuth(auths.Authenticator):
try:
connection = self.radiusClient()
# Reply is not important...
connection.authenticate(cryptoManager().randomString(10), cryptoManager().randomString(10))
connection.authenticate(cryptoManager().randomString(10), cryptoManager().randomString(10), mfaField=self.mfaAttr.value.strip())
except client.RadiusAuthenticationError as e:
pass
except Exception:

View File

@ -73,7 +73,8 @@ class RadiusClient:
self.nasIdentifier = nasIdentifier
self.appClassPrefix = appClassPrefix
def authenticate(self, username: str, password: str) -> typing.List[str]:
# Second element of return value is the mfa code from field
def authenticate(self, username: str, password: str, mfaField: str) -> typing.Tuple[typing.List[str], str]:
req: pyrad.packet.AuthPacket = self.radiusServer.CreateAuthPacket(
code=pyrad.packet.AccessRequest,
User_Name=username,
@ -95,7 +96,11 @@ class RadiusClient:
groups = [i[groupClassPrefixLen:].decode() for i in typing.cast(typing.Iterable[bytes], reply['Class']) if i.startswith(groupClassPrefix)]
else:
logger.info('No "Class (25)" attribute found')
return []
return ([], '')
return groups
# ...and mfa code
mfaCode = ''
if mfaField and mfaField in reply:
mfaCode = ''.join(i[groupClassPrefixLen:].decode() for i in typing.cast(typing.Iterable[bytes], reply['Class']) if i.startswith(groupClassPrefix))
return (groups, mfaCode)

View File

@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
#
# Copyright (c) 2012-2019 Virtual Cable S.L.
# Copyright (c) 2012-2022 Virtual Cable S.L.U.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
@ -12,7 +12,7 @@
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of Virtual Cable S.L. nor the names of its contributors
# * Neither the name of Virtual Cable S.L.U. nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
@ -176,6 +176,16 @@ class RegexLdap(auths.Authenticator):
tab=_('Advanced'),
)
mfaAttr = gui.TextField(
length=2048,
multiline=2,
label=_('MFA attribute'),
order=13,
tooltip=_('Attribute from where to extract the MFA code'),
required=False,
tab=gui.MFA_TAB,
)
typeName = _('Regex LDAP Authenticator')
typeType = 'RegexLdapAuthenticator'
typeDescription = _('Regular Expressions LDAP authenticator')
@ -205,6 +215,7 @@ class RegexLdap(auths.Authenticator):
_groupNameAttr: str = ''
_userNameAttr: str = ''
_altClass: str = ''
_mfaAttr: str = ''
def __init__(
self,
@ -231,6 +242,7 @@ class RegexLdap(auths.Authenticator):
# self._regex = values['regex']
self._userNameAttr = values['userNameAttr']
self._altClass = values['altClass']
self._mfaAttr = values['mfaAttr']
def __validateField(self, field: str, fieldLabel: str) -> None:
"""
@ -296,6 +308,12 @@ class RegexLdap(auths.Authenticator):
logger.debug('Res: %s', res)
return res
def mfaStorageKey(self, username: str) -> str:
return 'mfa_' + self.dbAuthenticator().uuid + username
def mfaIdentifier(self, username: str) -> str:
return self.storage.getPickle(self.mfaStorageKey(username)) or ''
def valuesDict(self) -> gui.ValuesDictType:
return {
'host': self._host,
@ -310,12 +328,13 @@ class RegexLdap(auths.Authenticator):
'groupNameAttr': self._groupNameAttr,
'userNameAttr': self._userNameAttr,
'altClass': self._altClass,
'mfaAttr': self._mfaAttr,
}
def marshal(self) -> bytes:
return '\t'.join(
[
'v3',
'v4',
self._host,
self._port,
gui.boolToStr(self._ssl),
@ -328,6 +347,7 @@ class RegexLdap(auths.Authenticator):
self._groupNameAttr,
self._userNameAttr,
self._altClass,
self._mfaAttr,
]
).encode('utf8')
@ -385,6 +405,24 @@ class RegexLdap(auths.Authenticator):
self._altClass,
) = vals[1:]
self._ssl = gui.strToBool(ssl)
elif vals[0] == 'v4':
logger.debug("Data v4: %s", vals[1:])
(
self._host,
self._port,
ssl,
self._username,
self._password,
self._timeout,
self._ldapBase,
self._userClass,
self._userIdAttr,
self._groupNameAttr,
self._userNameAttr,
self._altClass,
self._mfaAttr,
) = vals[1:]
self._ssl = gui.strToBool(ssl)
def __connection(self) -> typing.Any:
"""
@ -428,6 +466,9 @@ class RegexLdap(auths.Authenticator):
+ self.__getAttrsFromField(self._userNameAttr)
+ self.__getAttrsFromField(self._groupNameAttr)
)
if self._mfaAttr:
attributes = attributes + self.__getAttrsFromField(self._mfaAttr)
user = ldaputil.getFirst(
con=self.__connection(),
base=self._ldapBase,
@ -517,6 +558,13 @@ class RegexLdap(auths.Authenticator):
)
return False
# store the user mfa attribute if it is set
if self._mfaAttr:
self.storage.putPickle(
self.mfaStorageKey(username),
usr[self._mfaAttr][0],
)
groupsManager.validate(self.__getGroups(usr))
return True