Merge remote-tracking branch 'origin/v3.5-mfa'

This commit is contained in:
Adolfo Gómez García 2022-07-04 22:11:06 +02:00
commit 8ec815a75b
9 changed files with 150 additions and 30 deletions

View File

@ -80,21 +80,21 @@ class Users(DetailHandler):
custom_methods = ['servicesPools', 'userServices'] custom_methods = ['servicesPools', 'userServices']
@staticmethod
def uuid_to_id(iterator):
for v in iterator:
v['id'] = v['uuid']
del v['uuid']
yield v
def getItems(self, parent: Authenticator, item: typing.Optional[str]): def getItems(self, parent: Authenticator, item: typing.Optional[str]):
# processes item to change uuid key for id
def uuid_to_id(iterable: typing.Iterable[typing.MutableMapping[str, typing.Any]]):
for v in iterable:
v['id'] = v['uuid']
del v['uuid']
yield v
logger.debug(item) logger.debug(item)
# Extract authenticator # Extract authenticator
try: try:
if item is None: if item is None:
values = list( values = list(
Users.uuid_to_id( uuid_to_id(
parent.users.all().values( (i for i in parent.users.all().values(
'uuid', 'uuid',
'name', 'name',
'real_name', 'real_name',
@ -104,7 +104,8 @@ class Users(DetailHandler):
'is_admin', 'is_admin',
'last_access', 'last_access',
'parent', 'parent',
) 'mfaData',
))
) )
) )
for res in values: for res in values:
@ -127,6 +128,7 @@ class Users(DetailHandler):
'is_admin', 'is_admin',
'last_access', 'last_access',
'parent', 'parent',
'mfaData',
), ),
) )
res['id'] = u.uuid res['id'] = u.uuid
@ -153,7 +155,7 @@ class Users(DetailHandler):
except Exception: except Exception:
return _('Current users') return _('Current users')
def getFields(self, parent): def getFields(self, parent: Authenticator):
return [ return [
{ {
'name': { 'name': {
@ -198,12 +200,16 @@ class Users(DetailHandler):
'staff_member', 'staff_member',
'is_admin', 'is_admin',
] ]
if self._params.get('name', '') == '': if self._params.get('name', '').strip() == '':
raise RequestError(_('Username cannot be empty')) raise RequestError(_('Username cannot be empty'))
if 'password' in self._params: if 'password' in self._params:
valid_fields.append('password') valid_fields.append('password')
self._params['password'] = cryptoManager().hash(self._params['password']) self._params['password'] = cryptoManager().hash(self._params['password'])
if 'mfaData' in self._params:
valid_fields.append('mfaData')
self._params['mfaData'] = self._params['mfaData'].strip()
fields = self.readFieldsFromParams(valid_fields) fields = self.readFieldsFromParams(valid_fields)
if not self._user.is_admin: if not self._user.is_admin:
@ -224,9 +230,8 @@ class Users(DetailHandler):
user.__dict__.update(fields) user.__dict__.update(fields)
logger.debug('User parent: %s', user.parent) logger.debug('User parent: %s', user.parent)
if auth.isExternalSource is False and ( # If internal auth, threat it "special"
user.parent is None or user.parent == '' if auth.isExternalSource is False and not user.parent:
):
groups = self.readFieldsFromParams(['groups'])['groups'] groups = self.readFieldsFromParams(['groups'])['groups']
logger.debug('Groups: %s', groups) logger.debug('Groups: %s', groups)
logger.debug('Got Groups %s', parent.groups.filter(uuid__in=groups)) logger.debug('Got Groups %s', parent.groups.filter(uuid__in=groups))
@ -431,7 +436,7 @@ class Groups(DetailHandler):
fields = self.readFieldsFromParams(valid_fields) fields = self.readFieldsFromParams(valid_fields)
is_pattern = fields.get('name', '').find('pat:') == 0 is_pattern = fields.get('name', '').find('pat:') == 0
auth = parent.getInstance() auth = parent.getInstance()
if item is None: # Create new if not item: # Create new
if not is_meta and not is_pattern: if not is_meta and not is_pattern:
auth.createGroup( auth.createGroup(
fields fields
@ -484,7 +489,9 @@ class Groups(DetailHandler):
except Exception: except Exception:
raise self.invalidItemException() raise self.invalidItemException()
def servicesPools(self, parent: Authenticator, item: str) -> typing.List[typing.Mapping[str, typing.Any]]: def servicesPools(
self, parent: Authenticator, item: str
) -> typing.List[typing.Mapping[str, typing.Any]]:
uuid = processUuid(item) uuid = processUuid(item)
group = parent.groups.get(uuid=processUuid(uuid)) group = parent.groups.get(uuid=processUuid(uuid))
res: typing.List[typing.Mapping[str, typing.Any]] = [] res: typing.List[typing.Mapping[str, typing.Any]] = []
@ -505,7 +512,9 @@ class Groups(DetailHandler):
return res return res
def users(self, parent: Authenticator, item: str) -> typing.List[typing.Mapping[str, typing.Any]]: def users(
self, parent: Authenticator, item: str
) -> typing.List[typing.Mapping[str, typing.Any]]:
uuid = processUuid(item) uuid = processUuid(item)
group = parent.groups.get(uuid=processUuid(uuid)) group = parent.groups.get(uuid=processUuid(uuid))

View File

@ -111,6 +111,15 @@ class RadiusAuth(auths.Authenticator):
tooltip=_('If set, this value will be added as group for all radius users'), tooltip=_('If set, this value will be added as group for all radius users'),
tab=gui.ADVANCED_TAB, tab=gui.ADVANCED_TAB,
) )
mfaAttr = gui.TextField(
length=2048,
multiline=2,
label=_('MFA attribute'),
order=13,
tooltip=_('Attribute from where to extract the MFA code'),
required=False,
tab=gui.MFA_TAB,
)
def initialize(self, values: typing.Optional[typing.Dict[str, typing.Any]]) -> None: def initialize(self, values: typing.Optional[typing.Dict[str, typing.Any]]) -> None:
pass pass
@ -125,6 +134,12 @@ class RadiusAuth(auths.Authenticator):
appClassPrefix=self.appClassPrefix.value, appClassPrefix=self.appClassPrefix.value,
) )
def mfaStorageKey(self, username: str) -> str:
return 'mfa_' + self.dbAuthenticator().uuid + username
def mfaIdentifier(self, username: str) -> str:
return self.storage.getPickle(self.mfaStorageKey(username)) or ''
def authenticate( def authenticate(
self, self,
username: str, username: str,
@ -134,7 +149,14 @@ class RadiusAuth(auths.Authenticator):
) -> auths.AuthenticationResult: ) -> auths.AuthenticationResult:
try: try:
connection = self.radiusClient() connection = self.radiusClient()
groups = connection.authenticate(username=username, password=credentials) groups, mfaCode = connection.authenticate(username=username, password=credentials, mfaField=self.mfaAttr.value.strip())
# store the user mfa attribute if it is set
if mfaCode:
self.storage.putPickle(
self.mfaStorageKey(username),
mfaCode,
)
except Exception: except Exception:
authLogLogin( authLogLogin(
request, request,

View File

@ -73,7 +73,8 @@ class RadiusClient:
self.nasIdentifier = nasIdentifier self.nasIdentifier = nasIdentifier
self.appClassPrefix = appClassPrefix self.appClassPrefix = appClassPrefix
def authenticate(self, username: str, password: str) -> typing.List[str]: # Second element of return value is the mfa code from field
def authenticate(self, username: str, password: str, mfaField: str) -> typing.Tuple[typing.List[str], str]:
req: pyrad.packet.AuthPacket = self.radiusServer.CreateAuthPacket( req: pyrad.packet.AuthPacket = self.radiusServer.CreateAuthPacket(
code=pyrad.packet.AccessRequest, code=pyrad.packet.AccessRequest,
User_Name=username, User_Name=username,
@ -95,7 +96,11 @@ class RadiusClient:
groups = [i[groupClassPrefixLen:].decode() for i in typing.cast(typing.Iterable[bytes], reply['Class']) if i.startswith(groupClassPrefix)] groups = [i[groupClassPrefixLen:].decode() for i in typing.cast(typing.Iterable[bytes], reply['Class']) if i.startswith(groupClassPrefix)]
else: else:
logger.info('No "Class (25)" attribute found') logger.info('No "Class (25)" attribute found')
return [] return ([], '')
return groups # ...and mfa code
mfaCode = ''
if mfaField and mfaField in reply:
mfaCode = ''.join(i[groupClassPrefixLen:].decode() for i in typing.cast(typing.Iterable[bytes], reply['Class']) if i.startswith(groupClassPrefix))
return (groups, mfaCode)

View File

@ -1,7 +1,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
# Copyright (c) 2012-2019 Virtual Cable S.L. # Copyright (c) 2012-2022 Virtual Cable S.L.U.
# All rights reserved. # All rights reserved.
# #
# Redistribution and use in source and binary forms, with or without modification, # Redistribution and use in source and binary forms, with or without modification,
@ -12,7 +12,7 @@
# * Redistributions in binary form must reproduce the above copyright notice, # * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation # this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution. # and/or other materials provided with the distribution.
# * Neither the name of Virtual Cable S.L. nor the names of its contributors # * Neither the name of Virtual Cable S.L.U. nor the names of its contributors
# may be used to endorse or promote products derived from this software # may be used to endorse or promote products derived from this software
# without specific prior written permission. # without specific prior written permission.
# #
@ -175,6 +175,16 @@ class RegexLdap(auths.Authenticator):
tab=_('Advanced'), tab=_('Advanced'),
) )
mfaAttr = gui.TextField(
length=2048,
multiline=2,
label=_('MFA attribute'),
order=13,
tooltip=_('Attribute from where to extract the MFA code'),
required=False,
tab=gui.MFA_TAB,
)
typeName = _('Regex LDAP Authenticator') typeName = _('Regex LDAP Authenticator')
typeType = 'RegexLdapAuthenticator' typeType = 'RegexLdapAuthenticator'
typeDescription = _('Regular Expressions LDAP authenticator') typeDescription = _('Regular Expressions LDAP authenticator')
@ -204,6 +214,7 @@ class RegexLdap(auths.Authenticator):
_groupNameAttr: str = '' _groupNameAttr: str = ''
_userNameAttr: str = '' _userNameAttr: str = ''
_altClass: str = '' _altClass: str = ''
_mfaAttr: str = ''
def __init__( def __init__(
self, self,
@ -230,6 +241,7 @@ class RegexLdap(auths.Authenticator):
# self._regex = values['regex'] # self._regex = values['regex']
self._userNameAttr = values['userNameAttr'] self._userNameAttr = values['userNameAttr']
self._altClass = values['altClass'] self._altClass = values['altClass']
self._mfaAttr = values['mfaAttr']
def __validateField(self, field: str, fieldLabel: str) -> None: def __validateField(self, field: str, fieldLabel: str) -> None:
""" """
@ -295,6 +307,12 @@ class RegexLdap(auths.Authenticator):
logger.debug('Res: %s', res) logger.debug('Res: %s', res)
return res return res
def mfaStorageKey(self, username: str) -> str:
return 'mfa_' + self.dbAuthenticator().uuid + username
def mfaIdentifier(self, username: str) -> str:
return self.storage.getPickle(self.mfaStorageKey(username)) or ''
def valuesDict(self) -> gui.ValuesDictType: def valuesDict(self) -> gui.ValuesDictType:
return { return {
'host': self._host, 'host': self._host,
@ -309,12 +327,13 @@ class RegexLdap(auths.Authenticator):
'groupNameAttr': self._groupNameAttr, 'groupNameAttr': self._groupNameAttr,
'userNameAttr': self._userNameAttr, 'userNameAttr': self._userNameAttr,
'altClass': self._altClass, 'altClass': self._altClass,
'mfaAttr': self._mfaAttr,
} }
def marshal(self) -> bytes: def marshal(self) -> bytes:
return '\t'.join( return '\t'.join(
[ [
'v3', 'v4',
self._host, self._host,
self._port, self._port,
gui.boolToStr(self._ssl), gui.boolToStr(self._ssl),
@ -327,6 +346,7 @@ class RegexLdap(auths.Authenticator):
self._groupNameAttr, self._groupNameAttr,
self._userNameAttr, self._userNameAttr,
self._altClass, self._altClass,
self._mfaAttr,
] ]
).encode('utf8') ).encode('utf8')
@ -384,6 +404,24 @@ class RegexLdap(auths.Authenticator):
self._altClass, self._altClass,
) = vals[1:] ) = vals[1:]
self._ssl = gui.strToBool(ssl) self._ssl = gui.strToBool(ssl)
elif vals[0] == 'v4':
logger.debug("Data v4: %s", vals[1:])
(
self._host,
self._port,
ssl,
self._username,
self._password,
self._timeout,
self._ldapBase,
self._userClass,
self._userIdAttr,
self._groupNameAttr,
self._userNameAttr,
self._altClass,
self._mfaAttr,
) = vals[1:]
self._ssl = gui.strToBool(ssl)
def __connection(self) -> typing.Any: def __connection(self) -> typing.Any:
""" """
@ -427,6 +465,9 @@ class RegexLdap(auths.Authenticator):
+ self.__getAttrsFromField(self._userNameAttr) + self.__getAttrsFromField(self._userNameAttr)
+ self.__getAttrsFromField(self._groupNameAttr) + self.__getAttrsFromField(self._groupNameAttr)
) )
if self._mfaAttr:
attributes = attributes + self.__getAttrsFromField(self._mfaAttr)
user = ldaputil.getFirst( user = ldaputil.getFirst(
con=self.__connection(), con=self.__connection(),
base=self._ldapBase, base=self._ldapBase,
@ -520,6 +561,13 @@ class RegexLdap(auths.Authenticator):
) )
return auths.FAILED_AUTH return auths.FAILED_AUTH
# store the user mfa attribute if it is set
if self._mfaAttr:
self.storage.putPickle(
self.mfaStorageKey(username),
usr[self._mfaAttr][0],
)
groupsManager.validate(self.__getGroups(usr)) groupsManager.validate(self.__getGroups(usr))
return auths.SUCCESS_AUTH return auths.SUCCESS_AUTH

View File

@ -178,6 +178,9 @@ class StorageAsDict(MutableMapping):
def get(self, key: str, default: typing.Any = None) -> typing.Any: def get(self, key: str, default: typing.Any = None) -> typing.Any:
return self[key] or default return self[key] or default
def delete(self, key: str) -> None:
self.__delitem__(key)
# Custom utility methods # Custom utility methods
@property @property
def group(self) -> str: def group(self) -> str:

View File

@ -1,4 +1,4 @@
# Generated by Django 3.2.10 on 2022-06-23 19:34 # Generated by Django 3.2.10 on 2022-07-04 21:20
from django.db import migrations, models from django.db import migrations, models
import django.db.models.deletion import django.db.models.deletion
@ -11,6 +11,11 @@ class Migration(migrations.Migration):
] ]
operations = [ operations = [
migrations.AddField(
model_name='user',
name='mfaData',
field=models.CharField(default='', max_length=128),
),
migrations.CreateModel( migrations.CreateModel(
name='MFA', name='MFA',
fields=[ fields=[

View File

@ -66,6 +66,9 @@ class User(UUIDModel):
state = models.CharField(max_length=1, db_index=True) state = models.CharField(max_length=1, db_index=True)
password = models.CharField( password = models.CharField(
max_length=128, default='' max_length=128, default=''
) # Only used on "internal" sources or sources that "needs password"
mfaData = models.CharField(
max_length=128, default=''
) # Only used on "internal" sources ) # Only used on "internal" sources
staff_member = models.BooleanField( staff_member = models.BooleanField(
default=False default=False
@ -202,6 +205,26 @@ class User(UUIDModel):
# This group matches # This group matches
yield g yield g
# Get custom data
def getCustomData(self, key: str) -> typing.Optional[str]:
"""
Returns the custom data for this user for the provided key.
Usually custom data will be associated with transports, but can be custom data registered by ANY module.
Args:
key: key of the custom data to get
Returns:
The custom data for the key specified as a string (can be empty if key is not found).
If the key exists, the custom data will always contain something, but may be the values are the default ones.
"""
with storage.StorageAccess('manager' + self.manager.uuid) as store:
return store[self.uuid + '_' + key]
def __str__(self): def __str__(self):
return 'User {} (id:{}) from auth {}'.format( return 'User {} (id:{}) from auth {}'.format(
self.name, self.id, self.manager.name self.name, self.id, self.manager.name
@ -217,11 +240,15 @@ class User(UUIDModel):
:note: If destroy raises an exception, the deletion is not taken. :note: If destroy raises an exception, the deletion is not taken.
""" """
toDelete = kwargs['instance'] toDelete: User = kwargs['instance']
# first, we invoke removeUser. If this raises an exception, user will not # first, we invoke removeUser. If this raises an exception, user will not
# be removed # be removed
toDelete.getManager().removeUser(toDelete.name) toDelete.getManager().removeUser(toDelete.name)
# Remove related stored values
with storage.StorageAccess('manager' + toDelete.manager.uuid) as store:
for key in store.keys():
store.delete(key)
# now removes all "child" of this user, if it has children # now removes all "child" of this user, if it has children
User.objects.filter(parent=toDelete.id).delete() User.objects.filter(parent=toDelete.id).delete()

View File

@ -98,7 +98,7 @@ class UserService(UUIDModel): # pylint: disable=too-many-public-methods
state_date = models.DateTimeField(db_index=True) state_date = models.DateTimeField(db_index=True)
creation_date = models.DateTimeField(db_index=True) creation_date = models.DateTimeField(db_index=True)
data = models.TextField(default='') data = models.TextField(default='')
user: 'models.ForeignKey[UserService, User]' = models.ForeignKey( user = models.ForeignKey(
User, User,
on_delete=models.CASCADE, on_delete=models.CASCADE,
related_name='userServices', related_name='userServices',
@ -394,7 +394,7 @@ class UserService(UUIDModel): # pylint: disable=too-many-public-methods
self.os_state = state self.os_state = state
self.save(update_fields=['os_state', 'state_date']) self.save(update_fields=['os_state', 'state_date'])
def assignToUser(self, user: User) -> None: def assignToUser(self, user: typing.Optional[User]) -> None:
""" """
Assigns this user deployed service to an user. Assigns this user deployed service to an user.
@ -403,7 +403,7 @@ class UserService(UUIDModel): # pylint: disable=too-many-public-methods
""" """
self.cache_level = 0 self.cache_level = 0
self.state_date = getSqlDatetime() self.state_date = getSqlDatetime()
self.user = user # type: ignore self.user = user
self.save(update_fields=['cache_level', 'state_date', 'user']) self.save(update_fields=['cache_level', 'state_date', 'user'])
def setInUse(self, inUse: bool) -> None: def setInUse(self, inUse: bool) -> None:

View File

@ -487,6 +487,7 @@ gettext("Role");
gettext("Admin"); gettext("Admin");
gettext("Staff member"); gettext("Staff member");
gettext("User"); gettext("User");
gettext("MFA");
gettext("Groups"); gettext("Groups");
gettext("Cancel"); gettext("Cancel");
gettext("Ok"); gettext("Ok");