Merge remote-tracking branch 'origin/v3.5-mfa'

This commit is contained in:
Adolfo Gómez García 2022-07-04 22:11:06 +02:00
commit 8ec815a75b
9 changed files with 150 additions and 30 deletions

View File

@ -80,21 +80,21 @@ class Users(DetailHandler):
custom_methods = ['servicesPools', 'userServices']
@staticmethod
def uuid_to_id(iterator):
for v in iterator:
v['id'] = v['uuid']
del v['uuid']
yield v
def getItems(self, parent: Authenticator, item: typing.Optional[str]):
# processes item to change uuid key for id
def uuid_to_id(iterable: typing.Iterable[typing.MutableMapping[str, typing.Any]]):
for v in iterable:
v['id'] = v['uuid']
del v['uuid']
yield v
logger.debug(item)
# Extract authenticator
try:
if item is None:
values = list(
Users.uuid_to_id(
parent.users.all().values(
uuid_to_id(
(i for i in parent.users.all().values(
'uuid',
'name',
'real_name',
@ -104,7 +104,8 @@ class Users(DetailHandler):
'is_admin',
'last_access',
'parent',
)
'mfaData',
))
)
)
for res in values:
@ -127,6 +128,7 @@ class Users(DetailHandler):
'is_admin',
'last_access',
'parent',
'mfaData',
),
)
res['id'] = u.uuid
@ -153,7 +155,7 @@ class Users(DetailHandler):
except Exception:
return _('Current users')
def getFields(self, parent):
def getFields(self, parent: Authenticator):
return [
{
'name': {
@ -198,12 +200,16 @@ class Users(DetailHandler):
'staff_member',
'is_admin',
]
if self._params.get('name', '') == '':
if self._params.get('name', '').strip() == '':
raise RequestError(_('Username cannot be empty'))
if 'password' in self._params:
valid_fields.append('password')
self._params['password'] = cryptoManager().hash(self._params['password'])
if 'mfaData' in self._params:
valid_fields.append('mfaData')
self._params['mfaData'] = self._params['mfaData'].strip()
fields = self.readFieldsFromParams(valid_fields)
if not self._user.is_admin:
@ -224,9 +230,8 @@ class Users(DetailHandler):
user.__dict__.update(fields)
logger.debug('User parent: %s', user.parent)
if auth.isExternalSource is False and (
user.parent is None or user.parent == ''
):
# If internal auth, threat it "special"
if auth.isExternalSource is False and not user.parent:
groups = self.readFieldsFromParams(['groups'])['groups']
logger.debug('Groups: %s', groups)
logger.debug('Got Groups %s', parent.groups.filter(uuid__in=groups))
@ -431,7 +436,7 @@ class Groups(DetailHandler):
fields = self.readFieldsFromParams(valid_fields)
is_pattern = fields.get('name', '').find('pat:') == 0
auth = parent.getInstance()
if item is None: # Create new
if not item: # Create new
if not is_meta and not is_pattern:
auth.createGroup(
fields
@ -484,7 +489,9 @@ class Groups(DetailHandler):
except Exception:
raise self.invalidItemException()
def servicesPools(self, parent: Authenticator, item: str) -> typing.List[typing.Mapping[str, typing.Any]]:
def servicesPools(
self, parent: Authenticator, item: str
) -> typing.List[typing.Mapping[str, typing.Any]]:
uuid = processUuid(item)
group = parent.groups.get(uuid=processUuid(uuid))
res: typing.List[typing.Mapping[str, typing.Any]] = []
@ -505,7 +512,9 @@ class Groups(DetailHandler):
return res
def users(self, parent: Authenticator, item: str) -> typing.List[typing.Mapping[str, typing.Any]]:
def users(
self, parent: Authenticator, item: str
) -> typing.List[typing.Mapping[str, typing.Any]]:
uuid = processUuid(item)
group = parent.groups.get(uuid=processUuid(uuid))

View File

@ -111,6 +111,15 @@ class RadiusAuth(auths.Authenticator):
tooltip=_('If set, this value will be added as group for all radius users'),
tab=gui.ADVANCED_TAB,
)
mfaAttr = gui.TextField(
length=2048,
multiline=2,
label=_('MFA attribute'),
order=13,
tooltip=_('Attribute from where to extract the MFA code'),
required=False,
tab=gui.MFA_TAB,
)
def initialize(self, values: typing.Optional[typing.Dict[str, typing.Any]]) -> None:
pass
@ -125,6 +134,12 @@ class RadiusAuth(auths.Authenticator):
appClassPrefix=self.appClassPrefix.value,
)
def mfaStorageKey(self, username: str) -> str:
return 'mfa_' + self.dbAuthenticator().uuid + username
def mfaIdentifier(self, username: str) -> str:
return self.storage.getPickle(self.mfaStorageKey(username)) or ''
def authenticate(
self,
username: str,
@ -134,7 +149,14 @@ class RadiusAuth(auths.Authenticator):
) -> auths.AuthenticationResult:
try:
connection = self.radiusClient()
groups = connection.authenticate(username=username, password=credentials)
groups, mfaCode = connection.authenticate(username=username, password=credentials, mfaField=self.mfaAttr.value.strip())
# store the user mfa attribute if it is set
if mfaCode:
self.storage.putPickle(
self.mfaStorageKey(username),
mfaCode,
)
except Exception:
authLogLogin(
request,

View File

@ -73,7 +73,8 @@ class RadiusClient:
self.nasIdentifier = nasIdentifier
self.appClassPrefix = appClassPrefix
def authenticate(self, username: str, password: str) -> typing.List[str]:
# Second element of return value is the mfa code from field
def authenticate(self, username: str, password: str, mfaField: str) -> typing.Tuple[typing.List[str], str]:
req: pyrad.packet.AuthPacket = self.radiusServer.CreateAuthPacket(
code=pyrad.packet.AccessRequest,
User_Name=username,
@ -95,7 +96,11 @@ class RadiusClient:
groups = [i[groupClassPrefixLen:].decode() for i in typing.cast(typing.Iterable[bytes], reply['Class']) if i.startswith(groupClassPrefix)]
else:
logger.info('No "Class (25)" attribute found')
return []
return ([], '')
return groups
# ...and mfa code
mfaCode = ''
if mfaField and mfaField in reply:
mfaCode = ''.join(i[groupClassPrefixLen:].decode() for i in typing.cast(typing.Iterable[bytes], reply['Class']) if i.startswith(groupClassPrefix))
return (groups, mfaCode)

View File

@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
#
# Copyright (c) 2012-2019 Virtual Cable S.L.
# Copyright (c) 2012-2022 Virtual Cable S.L.U.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
@ -12,7 +12,7 @@
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of Virtual Cable S.L. nor the names of its contributors
# * Neither the name of Virtual Cable S.L.U. nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
@ -175,6 +175,16 @@ class RegexLdap(auths.Authenticator):
tab=_('Advanced'),
)
mfaAttr = gui.TextField(
length=2048,
multiline=2,
label=_('MFA attribute'),
order=13,
tooltip=_('Attribute from where to extract the MFA code'),
required=False,
tab=gui.MFA_TAB,
)
typeName = _('Regex LDAP Authenticator')
typeType = 'RegexLdapAuthenticator'
typeDescription = _('Regular Expressions LDAP authenticator')
@ -204,6 +214,7 @@ class RegexLdap(auths.Authenticator):
_groupNameAttr: str = ''
_userNameAttr: str = ''
_altClass: str = ''
_mfaAttr: str = ''
def __init__(
self,
@ -230,6 +241,7 @@ class RegexLdap(auths.Authenticator):
# self._regex = values['regex']
self._userNameAttr = values['userNameAttr']
self._altClass = values['altClass']
self._mfaAttr = values['mfaAttr']
def __validateField(self, field: str, fieldLabel: str) -> None:
"""
@ -295,6 +307,12 @@ class RegexLdap(auths.Authenticator):
logger.debug('Res: %s', res)
return res
def mfaStorageKey(self, username: str) -> str:
return 'mfa_' + self.dbAuthenticator().uuid + username
def mfaIdentifier(self, username: str) -> str:
return self.storage.getPickle(self.mfaStorageKey(username)) or ''
def valuesDict(self) -> gui.ValuesDictType:
return {
'host': self._host,
@ -309,12 +327,13 @@ class RegexLdap(auths.Authenticator):
'groupNameAttr': self._groupNameAttr,
'userNameAttr': self._userNameAttr,
'altClass': self._altClass,
'mfaAttr': self._mfaAttr,
}
def marshal(self) -> bytes:
return '\t'.join(
[
'v3',
'v4',
self._host,
self._port,
gui.boolToStr(self._ssl),
@ -327,6 +346,7 @@ class RegexLdap(auths.Authenticator):
self._groupNameAttr,
self._userNameAttr,
self._altClass,
self._mfaAttr,
]
).encode('utf8')
@ -384,6 +404,24 @@ class RegexLdap(auths.Authenticator):
self._altClass,
) = vals[1:]
self._ssl = gui.strToBool(ssl)
elif vals[0] == 'v4':
logger.debug("Data v4: %s", vals[1:])
(
self._host,
self._port,
ssl,
self._username,
self._password,
self._timeout,
self._ldapBase,
self._userClass,
self._userIdAttr,
self._groupNameAttr,
self._userNameAttr,
self._altClass,
self._mfaAttr,
) = vals[1:]
self._ssl = gui.strToBool(ssl)
def __connection(self) -> typing.Any:
"""
@ -427,6 +465,9 @@ class RegexLdap(auths.Authenticator):
+ self.__getAttrsFromField(self._userNameAttr)
+ self.__getAttrsFromField(self._groupNameAttr)
)
if self._mfaAttr:
attributes = attributes + self.__getAttrsFromField(self._mfaAttr)
user = ldaputil.getFirst(
con=self.__connection(),
base=self._ldapBase,
@ -520,6 +561,13 @@ class RegexLdap(auths.Authenticator):
)
return auths.FAILED_AUTH
# store the user mfa attribute if it is set
if self._mfaAttr:
self.storage.putPickle(
self.mfaStorageKey(username),
usr[self._mfaAttr][0],
)
groupsManager.validate(self.__getGroups(usr))
return auths.SUCCESS_AUTH

View File

@ -178,6 +178,9 @@ class StorageAsDict(MutableMapping):
def get(self, key: str, default: typing.Any = None) -> typing.Any:
return self[key] or default
def delete(self, key: str) -> None:
self.__delitem__(key)
# Custom utility methods
@property
def group(self) -> str:

View File

@ -1,4 +1,4 @@
# Generated by Django 3.2.10 on 2022-06-23 19:34
# Generated by Django 3.2.10 on 2022-07-04 21:20
from django.db import migrations, models
import django.db.models.deletion
@ -11,6 +11,11 @@ class Migration(migrations.Migration):
]
operations = [
migrations.AddField(
model_name='user',
name='mfaData',
field=models.CharField(default='', max_length=128),
),
migrations.CreateModel(
name='MFA',
fields=[

View File

@ -66,6 +66,9 @@ class User(UUIDModel):
state = models.CharField(max_length=1, db_index=True)
password = models.CharField(
max_length=128, default=''
) # Only used on "internal" sources or sources that "needs password"
mfaData = models.CharField(
max_length=128, default=''
) # Only used on "internal" sources
staff_member = models.BooleanField(
default=False
@ -202,6 +205,26 @@ class User(UUIDModel):
# This group matches
yield g
# Get custom data
def getCustomData(self, key: str) -> typing.Optional[str]:
"""
Returns the custom data for this user for the provided key.
Usually custom data will be associated with transports, but can be custom data registered by ANY module.
Args:
key: key of the custom data to get
Returns:
The custom data for the key specified as a string (can be empty if key is not found).
If the key exists, the custom data will always contain something, but may be the values are the default ones.
"""
with storage.StorageAccess('manager' + self.manager.uuid) as store:
return store[self.uuid + '_' + key]
def __str__(self):
return 'User {} (id:{}) from auth {}'.format(
self.name, self.id, self.manager.name
@ -217,11 +240,15 @@ class User(UUIDModel):
:note: If destroy raises an exception, the deletion is not taken.
"""
toDelete = kwargs['instance']
toDelete: User = kwargs['instance']
# first, we invoke removeUser. If this raises an exception, user will not
# be removed
toDelete.getManager().removeUser(toDelete.name)
# Remove related stored values
with storage.StorageAccess('manager' + toDelete.manager.uuid) as store:
for key in store.keys():
store.delete(key)
# now removes all "child" of this user, if it has children
User.objects.filter(parent=toDelete.id).delete()

View File

@ -98,7 +98,7 @@ class UserService(UUIDModel): # pylint: disable=too-many-public-methods
state_date = models.DateTimeField(db_index=True)
creation_date = models.DateTimeField(db_index=True)
data = models.TextField(default='')
user: 'models.ForeignKey[UserService, User]' = models.ForeignKey(
user = models.ForeignKey(
User,
on_delete=models.CASCADE,
related_name='userServices',
@ -394,7 +394,7 @@ class UserService(UUIDModel): # pylint: disable=too-many-public-methods
self.os_state = state
self.save(update_fields=['os_state', 'state_date'])
def assignToUser(self, user: User) -> None:
def assignToUser(self, user: typing.Optional[User]) -> None:
"""
Assigns this user deployed service to an user.
@ -403,7 +403,7 @@ class UserService(UUIDModel): # pylint: disable=too-many-public-methods
"""
self.cache_level = 0
self.state_date = getSqlDatetime()
self.user = user # type: ignore
self.user = user
self.save(update_fields=['cache_level', 'state_date', 'user'])
def setInUse(self, inUse: bool) -> None:

View File

@ -487,6 +487,7 @@ gettext("Role");
gettext("Admin");
gettext("Staff member");
gettext("User");
gettext("MFA");
gettext("Groups");
gettext("Cancel");
gettext("Ok");