mirror of
https://github.com/dkmstr/openuds.git
synced 2024-12-22 13:34:04 +03:00
Several mfa fixes and improvements
This commit is contained in:
parent
63d1693fea
commit
f2c5ca2e92
@ -30,19 +30,22 @@
|
||||
"""
|
||||
@author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
"""
|
||||
import collections.abc
|
||||
import datetime
|
||||
import random
|
||||
import enum
|
||||
import hashlib
|
||||
import logging
|
||||
import random
|
||||
import typing
|
||||
import collections.abc
|
||||
|
||||
from django.utils.translation import gettext_noop as _, gettext
|
||||
from django.utils.translation import gettext
|
||||
from django.utils.translation import gettext_noop as _
|
||||
|
||||
from uds.core import exceptions, types
|
||||
from uds.core.ui import gui
|
||||
from uds.core.module import Module
|
||||
from uds.core.util.model import sql_datetime
|
||||
from uds.models.network import Network
|
||||
from uds.core import exceptions
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from uds.core.environment import Environment
|
||||
@ -71,10 +74,7 @@ class LoginAllowed(enum.StrEnum):
|
||||
def checkIp() -> bool:
|
||||
if networks is None:
|
||||
return True # No network restrictions, so we allow
|
||||
return any(
|
||||
i.contains(request.ip)
|
||||
for i in Network.objects.filter(uuid__in=list(networks))
|
||||
)
|
||||
return any(i.contains(request.ip) for i in Network.objects.filter(uuid__in=list(networks)))
|
||||
|
||||
if isinstance(action, str):
|
||||
action = LoginAllowed(action)
|
||||
@ -87,17 +87,18 @@ class LoginAllowed(enum.StrEnum):
|
||||
}.get(action, False)
|
||||
|
||||
@staticmethod
|
||||
def choices() -> collections.abc.Mapping[str, str]:
|
||||
return {
|
||||
LoginAllowed.ALLOWED.value: gettext('Allow user login'),
|
||||
LoginAllowed.DENIED.value: gettext('Deny user login'),
|
||||
LoginAllowed.ALLOWED_IF_IN_NETWORKS.value: gettext(
|
||||
'Allow user to login if it IP is in the networks list'
|
||||
),
|
||||
LoginAllowed.DENIED_IF_IN_NETWORKS.value: gettext(
|
||||
'Deny user to login if it IP is in the networks list'
|
||||
),
|
||||
}
|
||||
def choices(include_global_allowance: bool = True) -> list[types.ui.ChoiceItem]:
|
||||
result = [
|
||||
gui.choice_item(LoginAllowed.ALLOWED.value, gettext('Allow user login')),
|
||||
gui.choice_item(LoginAllowed.DENIED.value, gettext('Deny user login'))
|
||||
] if include_global_allowance else []
|
||||
result.extend(
|
||||
[
|
||||
gui.choice_item(LoginAllowed.ALLOWED_IF_IN_NETWORKS.value, gettext('Allow user to login if it IP is in the networks list')),
|
||||
gui.choice_item(LoginAllowed.DENIED_IF_IN_NETWORKS.value, gettext('Deny user to login if it IP is in the networks list')),
|
||||
]
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class MFA(Module):
|
||||
@ -179,9 +180,7 @@ class MFA(Module):
|
||||
"""
|
||||
return ''
|
||||
|
||||
def allow_login_without_identifier(
|
||||
self, request: 'ExtendedHttpRequest'
|
||||
) -> typing.Optional[bool]:
|
||||
def allow_login_without_identifier(self, request: 'ExtendedHttpRequest') -> typing.Optional[bool]:
|
||||
"""
|
||||
If this method returns True, an user that has no "identifier" is allowed to login without MFA
|
||||
Returns:
|
||||
@ -319,11 +318,7 @@ class MFA(Module):
|
||||
data = self._get_data(request, userId)
|
||||
if data and len(data) == 2:
|
||||
validity = validity if validity is not None else 0
|
||||
if (
|
||||
validity > 0
|
||||
and data[0] + datetime.timedelta(seconds=validity)
|
||||
< sql_datetime()
|
||||
):
|
||||
if validity > 0 and data[0] + datetime.timedelta(seconds=validity) < sql_datetime():
|
||||
# if it is no more valid, raise an error
|
||||
# Remove stored code and raise error
|
||||
self._remove_data(request, userId)
|
||||
@ -358,6 +353,4 @@ class MFA(Module):
|
||||
if not mfa:
|
||||
raise exceptions.auth.MFAError('MFA is not enabled')
|
||||
|
||||
return hashlib.sha3_256(
|
||||
(user.name + (user.uuid or '') + mfa.uuid).encode()
|
||||
).hexdigest()
|
||||
return hashlib.sha3_256((user.name + (user.uuid or '') + mfa.uuid).encode()).hexdigest()
|
||||
|
@ -101,7 +101,7 @@ class EmailMFA(mfas.MFA):
|
||||
tab=_('SMTP Server'),
|
||||
)
|
||||
|
||||
emailSubject = gui.TextField(
|
||||
email_subject = gui.TextField(
|
||||
length=128,
|
||||
default='Verification Code',
|
||||
label=_('Subject'),
|
||||
@ -109,26 +109,29 @@ class EmailMFA(mfas.MFA):
|
||||
tooltip=_('Subject of the email'),
|
||||
required=True,
|
||||
tab=_('Config'),
|
||||
old_field_name='emailSubject',
|
||||
)
|
||||
|
||||
fromEmail = gui.TextField(
|
||||
from_email = gui.TextField(
|
||||
length=128,
|
||||
label=_('From Email'),
|
||||
order=11,
|
||||
tooltip=_('Email address that will be used as sender'),
|
||||
required=True,
|
||||
tab=_('Config'),
|
||||
old_field_name='fromEmail',
|
||||
)
|
||||
|
||||
enableHTML = gui.CheckBoxField(
|
||||
enable_html = gui.CheckBoxField(
|
||||
label=_('Enable HTML'),
|
||||
order=13,
|
||||
tooltip=_('Enable HTML in emails'),
|
||||
default=True,
|
||||
tab=_('Config'),
|
||||
old_field_name='enableHTML',
|
||||
)
|
||||
|
||||
allowLoginWithoutMFA = gui.ChoiceField(
|
||||
allow_login_without_mfa = gui.ChoiceField(
|
||||
label=_('Policy for users without MFA support'),
|
||||
order=31,
|
||||
default='0',
|
||||
@ -136,6 +139,7 @@ class EmailMFA(mfas.MFA):
|
||||
required=True,
|
||||
choices=mfas.LoginAllowed.choices(),
|
||||
tab=_('Config'),
|
||||
old_field_name='allowLoginWithoutMFA',
|
||||
)
|
||||
|
||||
networks = gui.MultiChoiceField(
|
||||
@ -145,10 +149,14 @@ class EmailMFA(mfas.MFA):
|
||||
order=32,
|
||||
tooltip=_('Networks for Email OTP authentication'),
|
||||
required=False,
|
||||
choices=lambda: [
|
||||
gui.choice_item(v.uuid, v.name) # type: ignore
|
||||
for v in models.Network.objects.all().order_by('name')
|
||||
],
|
||||
tab=_('Config'),
|
||||
)
|
||||
|
||||
mailTxt = gui.TextField(
|
||||
mail_txt = gui.TextField(
|
||||
length=1024,
|
||||
label=_('Mail text'),
|
||||
order=33,
|
||||
@ -160,9 +168,10 @@ class EmailMFA(mfas.MFA):
|
||||
required=True,
|
||||
default='',
|
||||
tab=_('Config'),
|
||||
old_field_name='mailTxt',
|
||||
)
|
||||
|
||||
mailHtml = gui.TextField(
|
||||
mail_html = gui.TextField(
|
||||
length=1024,
|
||||
label=_('Mail HTML'),
|
||||
order=34,
|
||||
@ -175,6 +184,7 @@ class EmailMFA(mfas.MFA):
|
||||
required=False,
|
||||
default='',
|
||||
tab=_('Config'),
|
||||
old_field_name='mailHtml',
|
||||
)
|
||||
|
||||
def initialize(self, values: 'Module.ValuesType' = None):
|
||||
@ -200,32 +210,26 @@ class EmailMFA(mfas.MFA):
|
||||
self.hostname.value = validators.validate_fqdn(host)
|
||||
|
||||
# now check from email and to email
|
||||
self.fromEmail.value = validators.validate_email(self.fromEmail.value)
|
||||
self.from_email.value = validators.validate_email(self.from_email.value)
|
||||
|
||||
def html(self, request: 'ExtendedHttpRequest', userId: str, username: str) -> str:
|
||||
return gettext('Check your mail. You will receive an email with the verification code')
|
||||
|
||||
def init_gui(self) -> None:
|
||||
# Populate the networks list
|
||||
self.networks.set_choices(
|
||||
[gui.choice_item(v.uuid, v.name) for v in models.Network.objects.all().order_by('name') if v.uuid]
|
||||
)
|
||||
|
||||
def allow_login_without_identifier(self, request: 'ExtendedHttpRequest') -> typing.Optional[bool]:
|
||||
return mfas.LoginAllowed.check_action(self.allowLoginWithoutMFA.value, request, self.networks.value)
|
||||
return mfas.LoginAllowed.check_action(self.allow_login_without_mfa.value, request, self.networks.value)
|
||||
|
||||
def label(self) -> str:
|
||||
return 'OTP received via email'
|
||||
|
||||
@decorators.threaded
|
||||
def doSendCode(self, request: 'ExtendedHttpRequest', identifier: str, code: str) -> None:
|
||||
def send_verification_code_thread(self, request: 'ExtendedHttpRequest', identifier: str, code: str) -> None:
|
||||
# Send and email with the notification
|
||||
with self.login() as smtp:
|
||||
try:
|
||||
# Create message container
|
||||
msg = MIMEMultipart('alternative')
|
||||
msg['Subject'] = self.emailSubject.as_clean_str()
|
||||
msg['From'] = self.fromEmail.as_clean_str()
|
||||
msg['Subject'] = self.email_subject.as_clean_str()
|
||||
msg['From'] = self.from_email.as_clean_str()
|
||||
msg['To'] = identifier
|
||||
|
||||
msg.attach(
|
||||
@ -235,7 +239,7 @@ class EmailMFA(mfas.MFA):
|
||||
)
|
||||
)
|
||||
|
||||
if self.enableHTML.value:
|
||||
if self.enable_html.value:
|
||||
msg.attach(
|
||||
MIMEText(
|
||||
f'<p>A login attemt has been made from <b>{request.ip}</b>.</p><p>To continue, provide the verification code <b>{code}</b></p>',
|
||||
@ -243,7 +247,7 @@ class EmailMFA(mfas.MFA):
|
||||
)
|
||||
)
|
||||
|
||||
smtp.sendmail(self.fromEmail.value, identifier, msg.as_string())
|
||||
smtp.sendmail(self.from_email.value, identifier, msg.as_string())
|
||||
except smtplib.SMTPException as e:
|
||||
logger.error('Error sending email: %s', e)
|
||||
raise
|
||||
@ -256,7 +260,7 @@ class EmailMFA(mfas.MFA):
|
||||
identifier: str,
|
||||
code: str,
|
||||
) -> mfas.MFA.RESULT:
|
||||
self.doSendCode(
|
||||
self.send_verification_code_thread(
|
||||
request,
|
||||
identifier,
|
||||
code,
|
||||
|
@ -99,16 +99,17 @@ class RadiusOTP(mfas.MFA):
|
||||
'If checked, all users must enter OTP, so authentication step is skipped.'
|
||||
),
|
||||
)
|
||||
nasIdentifier = gui.TextField(
|
||||
nas_identifier = gui.TextField(
|
||||
length=64,
|
||||
label=_('NAS Identifier'),
|
||||
default='uds-server',
|
||||
order=5,
|
||||
tooltip=_('NAS Identifier for Radius Server'),
|
||||
required=True,
|
||||
old_field_name='nasIdentifier',
|
||||
)
|
||||
|
||||
responseErrorAction = gui.ChoiceField(
|
||||
response_error_action = gui.ChoiceField(
|
||||
label=_('Radius OTP communication error action'),
|
||||
order=31,
|
||||
default='0',
|
||||
@ -116,6 +117,7 @@ class RadiusOTP(mfas.MFA):
|
||||
required=True,
|
||||
choices=mfas.LoginAllowed.choices(),
|
||||
tab=_('Config'),
|
||||
old_field_name='responseErrorAction',
|
||||
)
|
||||
|
||||
networks = gui.MultiChoiceField(
|
||||
@ -132,7 +134,7 @@ class RadiusOTP(mfas.MFA):
|
||||
tab=_('Config'),
|
||||
)
|
||||
|
||||
allowLoginWithoutMFA = gui.ChoiceField(
|
||||
allow_login_without_mfa = gui.ChoiceField(
|
||||
label=_('User without defined OTP in server'),
|
||||
order=33,
|
||||
default='0',
|
||||
@ -140,21 +142,22 @@ class RadiusOTP(mfas.MFA):
|
||||
required=True,
|
||||
choices=mfas.LoginAllowed.choices(),
|
||||
tab=_('Config'),
|
||||
old_field_name='allowLoginWithoutMFA',
|
||||
)
|
||||
|
||||
def initialize(self, values: 'Module.ValuesType') -> None:
|
||||
return super().initialize(values)
|
||||
|
||||
def radiusClient(self) -> client.RadiusClient:
|
||||
def radius_client(self) -> client.RadiusClient:
|
||||
"""Return a new radius client ."""
|
||||
return client.RadiusClient(
|
||||
self.server.value,
|
||||
self.secret.value.encode(),
|
||||
authPort=self.port.as_int(),
|
||||
nasIdentifier=self.nasIdentifier.value,
|
||||
nasIdentifier=self.nas_identifier.value,
|
||||
)
|
||||
|
||||
def checkResult(self, action: str, request: 'ExtendedHttpRequest') -> mfas.MFA.RESULT:
|
||||
def check_result(self, action: str, request: 'ExtendedHttpRequest') -> mfas.MFA.RESULT:
|
||||
if mfas.LoginAllowed.check_action(action, request, self.networks.value):
|
||||
return mfas.MFA.RESULT.OK
|
||||
raise Exception('User not allowed to login')
|
||||
@ -192,11 +195,11 @@ class RadiusOTP(mfas.MFA):
|
||||
|
||||
web_pwd = web_password(request)
|
||||
try:
|
||||
connection = self.radiusClient()
|
||||
connection = self.radius_client()
|
||||
auth_reply = connection.authenticate_challenge(username, password=web_pwd)
|
||||
except Exception as e:
|
||||
logger.error("Exception found connecting to Radius OTP %s: %s", e.__class__, e)
|
||||
if not mfas.LoginAllowed.check_action(self.responseErrorAction.value, request, self.networks.value):
|
||||
if not mfas.LoginAllowed.check_action(self.response_error_action.value, request, self.networks.value):
|
||||
raise Exception(_('Radius OTP connection error')) from e
|
||||
logger.warning(
|
||||
"Radius OTP connection error: Allowing access to user [%s] from IP [%s] without OTP",
|
||||
@ -213,7 +216,7 @@ class RadiusOTP(mfas.MFA):
|
||||
)
|
||||
# we should not be here: not synchronized user password between auth server and radius server
|
||||
# What do we want to do here ??
|
||||
return self.checkResult(self.responseErrorAction.value, request)
|
||||
return self.check_result(self.response_error_action.value, request)
|
||||
|
||||
if auth_reply.otp_needed == NOT_NEEDED:
|
||||
logger.warning(
|
||||
@ -221,7 +224,7 @@ class RadiusOTP(mfas.MFA):
|
||||
username,
|
||||
request.ip,
|
||||
)
|
||||
return self.checkResult(self.allowLoginWithoutMFA.value, request)
|
||||
return self.check_result(self.allow_login_without_mfa.value, request)
|
||||
|
||||
# Store state for later use, related to this user
|
||||
request.session[client.STATE_VAR_NAME] = auth_reply.state or b''
|
||||
@ -253,7 +256,7 @@ class RadiusOTP(mfas.MFA):
|
||||
|
||||
web_pwd = web_password(request)
|
||||
try:
|
||||
connection = self.radiusClient()
|
||||
connection = self.radius_client()
|
||||
state = request.session.get(client.STATE_VAR_NAME, b'')
|
||||
if state:
|
||||
# Remove state from session
|
||||
@ -264,7 +267,7 @@ class RadiusOTP(mfas.MFA):
|
||||
auth_reply = connection.authenticate_challenge(username, password=web_pwd, otp=code)
|
||||
except Exception as e:
|
||||
logger.error("Exception found connecting to Radius OTP %s: %s", e.__class__, e)
|
||||
if mfas.LoginAllowed.check_action(self.responseErrorAction.value, request, self.networks.value):
|
||||
if mfas.LoginAllowed.check_action(self.response_error_action.value, request, self.networks.value):
|
||||
raise Exception(_('Radius OTP connection error')) from e
|
||||
logger.warning(
|
||||
"Radius OTP connection error: Allowing access to user [%s] from IP [%s] without OTP",
|
||||
|
@ -254,19 +254,17 @@ class SMSMFA(mfas.MFA):
|
||||
order=32,
|
||||
tooltip=_('Networks for SMS authentication'),
|
||||
required=False,
|
||||
choices=lambda: [
|
||||
gui.choice_item(v.uuid, v.name) # type: ignore
|
||||
for v in models.Network.objects.all().order_by('name')
|
||||
],
|
||||
tab=_('Config'),
|
||||
)
|
||||
|
||||
def initialize(self, values: 'Module.ValuesType') -> None:
|
||||
return super().initialize(values)
|
||||
|
||||
def init_gui(self) -> None:
|
||||
# Populate the networks list
|
||||
self.networks.set_choices(
|
||||
[gui.choice_item(v.uuid, v.name) for v in models.Network.objects.all().order_by('name') if v.uuid]
|
||||
)
|
||||
|
||||
def composeSmsUrl(
|
||||
def build_sms_url(
|
||||
self,
|
||||
userId: str, # pylint: disable=unused-argument
|
||||
userName: str,
|
||||
@ -281,7 +279,7 @@ class SMSMFA(mfas.MFA):
|
||||
url = url.replace('{justUsername}', userName.split('@')[0])
|
||||
return url
|
||||
|
||||
def getSession(self) -> requests.Session:
|
||||
def get_session(self) -> requests.Session:
|
||||
session = security.secure_requests_session(verify=self.ignoreCertificateErrors.as_bool())
|
||||
# 0 means no authentication
|
||||
if self.authenticationMethod.value == '1':
|
||||
@ -307,7 +305,7 @@ class SMSMFA(mfas.MFA):
|
||||
def allow_login_without_identifier(self, request: 'ExtendedHttpRequest') -> typing.Optional[bool]:
|
||||
return mfas.LoginAllowed.check_action(self.allowLoginWithoutMFA.value, request, self.networks.value)
|
||||
|
||||
def processResponse(self, request: 'ExtendedHttpRequest', response: requests.Response) -> mfas.MFA.RESULT:
|
||||
def process_response(self, request: 'ExtendedHttpRequest', response: requests.Response) -> mfas.MFA.RESULT:
|
||||
logger.debug('Response: %s', response)
|
||||
if not response.ok:
|
||||
logger.warning(
|
||||
@ -336,7 +334,7 @@ class SMSMFA(mfas.MFA):
|
||||
return mfas.MFA.RESULT.ALLOWED
|
||||
return mfas.MFA.RESULT.OK
|
||||
|
||||
def getData(
|
||||
def _build_data(
|
||||
self,
|
||||
request: 'ExtendedHttpRequest', # pylint: disable=unused-argument
|
||||
userId: str, # pylint: disable=unused-argument
|
||||
@ -356,16 +354,16 @@ class SMSMFA(mfas.MFA):
|
||||
)
|
||||
return data.encode(self.encoding.value)
|
||||
|
||||
def sendSMS_GET(
|
||||
def _send_sms_using_get(
|
||||
self,
|
||||
request: 'ExtendedHttpRequest',
|
||||
userId: str, # pylint: disable=unused-argument
|
||||
username: str, # pylint: disable=unused-argument
|
||||
url: str,
|
||||
) -> mfas.MFA.RESULT:
|
||||
return self.processResponse(request, self.getSession().get(url))
|
||||
return self.process_response(request, self.get_session().get(url))
|
||||
|
||||
def sendSMS_POST(
|
||||
def _send_sms_using_post(
|
||||
self,
|
||||
request: 'ExtendedHttpRequest',
|
||||
userId: str,
|
||||
@ -375,14 +373,14 @@ class SMSMFA(mfas.MFA):
|
||||
phone: str,
|
||||
) -> mfas.MFA.RESULT:
|
||||
# Compose POST data
|
||||
session = self.getSession()
|
||||
bdata = self.getData(request, userId, username, url, code, phone)
|
||||
session = self.get_session()
|
||||
bdata = self._build_data(request, userId, username, url, code, phone)
|
||||
# Add content-length header
|
||||
session.headers['Content-Length'] = str(len(bdata))
|
||||
|
||||
return self.processResponse(request, session.post(url, data=bdata))
|
||||
return self.process_response(request, session.post(url, data=bdata))
|
||||
|
||||
def sendSMS_PUT(
|
||||
def _send_sms_using_put(
|
||||
self,
|
||||
request: 'ExtendedHttpRequest',
|
||||
userId: str,
|
||||
@ -392,10 +390,10 @@ class SMSMFA(mfas.MFA):
|
||||
phone: str,
|
||||
) -> mfas.MFA.RESULT:
|
||||
# Compose POST data
|
||||
bdata = self.getData(request, userId, username, url, code, phone)
|
||||
return self.processResponse(request, self.getSession().put(url, data=bdata))
|
||||
bdata = self._build_data(request, userId, username, url, code, phone)
|
||||
return self.process_response(request, self.get_session().put(url, data=bdata))
|
||||
|
||||
def sendSMS(
|
||||
def _send_sms(
|
||||
self,
|
||||
request: 'ExtendedHttpRequest',
|
||||
userId: str,
|
||||
@ -403,13 +401,13 @@ class SMSMFA(mfas.MFA):
|
||||
code: str,
|
||||
phone: str,
|
||||
) -> mfas.MFA.RESULT:
|
||||
url = self.composeSmsUrl(userId, username, code, phone)
|
||||
url = self.build_sms_url(userId, username, code, phone)
|
||||
if self.sendingMethod.value == 'GET':
|
||||
return self.sendSMS_GET(request, userId, username, url)
|
||||
return self._send_sms_using_get(request, userId, username, url)
|
||||
if self.sendingMethod.value == 'POST':
|
||||
return self.sendSMS_POST(request, userId, username, url, code, phone)
|
||||
return self._send_sms_using_post(request, userId, username, url, code, phone)
|
||||
if self.sendingMethod.value == 'PUT':
|
||||
return self.sendSMS_PUT(request, userId, username, url, code, phone)
|
||||
return self._send_sms_using_put(request, userId, username, url, code, phone)
|
||||
raise Exception('Unknown SMS sending method')
|
||||
|
||||
def label(self) -> str:
|
||||
@ -433,4 +431,4 @@ class SMSMFA(mfas.MFA):
|
||||
userId,
|
||||
identifier,
|
||||
)
|
||||
return self.sendSMS(request, userId, username, code, identifier)
|
||||
return self._send_sms(request, userId, username, code, identifier)
|
||||
|
@ -74,7 +74,7 @@ class TOTP_MFA(mfas.MFA):
|
||||
readonly=True, # This is not editable, as it is used to generate the QR code. Once generated, it can't be changed
|
||||
)
|
||||
|
||||
validWindow = gui.NumericField(
|
||||
valid_window = gui.NumericField(
|
||||
length=2,
|
||||
label=_('Valid Window'),
|
||||
default=1,
|
||||
@ -84,6 +84,7 @@ class TOTP_MFA(mfas.MFA):
|
||||
tooltip=_('Number of valid codes before and after the current one'),
|
||||
required=True,
|
||||
tab=_('Config'),
|
||||
old_field_name='validWindow',
|
||||
)
|
||||
networks = gui.MultiChoiceField(
|
||||
label=_('TOTP networks'),
|
||||
@ -92,26 +93,20 @@ class TOTP_MFA(mfas.MFA):
|
||||
order=32,
|
||||
tooltip=_('Users within these networks will not be asked for OTP'),
|
||||
required=False,
|
||||
choices=lambda: [
|
||||
gui.choice_item(v.uuid, v.name) # type: ignore
|
||||
for v in models.Network.objects.all().order_by('name')
|
||||
],
|
||||
tab=_('Config'),
|
||||
)
|
||||
|
||||
def initialize(self, values: 'Module.ValuesType') -> None:
|
||||
return super().initialize(values)
|
||||
|
||||
@classmethod
|
||||
def initClassGui(cls) -> None:
|
||||
# Populate the networks list
|
||||
cls.networks.set_choices(
|
||||
[
|
||||
gui.choice_item(v.uuid, v.name) # type: ignore
|
||||
for v in models.Network.objects.all().order_by('name')
|
||||
]
|
||||
)
|
||||
|
||||
def allow_login_without_identifier(self, request: 'ExtendedHttpRequest') -> typing.Optional[bool]:
|
||||
return None
|
||||
|
||||
def askForOTP(self, request: 'ExtendedHttpRequest') -> bool:
|
||||
def ask_for_otp(self, request: 'ExtendedHttpRequest') -> bool:
|
||||
"""
|
||||
Check if we need to ask for OTP for a given user
|
||||
|
||||
@ -124,25 +119,25 @@ class TOTP_MFA(mfas.MFA):
|
||||
def label(self) -> str:
|
||||
return gettext('Authentication Code')
|
||||
|
||||
def _userData(self, userId: str) -> tuple[str, bool]:
|
||||
def _user_data(self, userId: str) -> tuple[str, bool]:
|
||||
# Get data from storage related to this user
|
||||
# Data contains the secret and if the user has already logged in already some time
|
||||
# so we show the QR code only once
|
||||
data: typing.Optional[tuple[str, bool]] = self.storage.get_unpickle(userId)
|
||||
if data is None:
|
||||
data = (pyotp.random_base32(), False)
|
||||
self._saveUserData(userId, data)
|
||||
self._save_user_data(userId, data)
|
||||
return data
|
||||
|
||||
def _saveUserData(self, userId: str, data: tuple[str, bool]) -> None:
|
||||
def _save_user_data(self, userId: str, data: tuple[str, bool]) -> None:
|
||||
self.storage.put_pickle(userId, data)
|
||||
|
||||
def _removeUserData(self, userId: str) -> None:
|
||||
def _remove_user_data(self, userId: str) -> None:
|
||||
self.storage.remove(userId)
|
||||
|
||||
def getTOTP(self, userId: str, username: str) -> pyotp.TOTP:
|
||||
def get_totp(self, userId: str, username: str) -> pyotp.TOTP:
|
||||
return pyotp.TOTP(
|
||||
self._userData(userId)[0],
|
||||
self._user_data(userId)[0],
|
||||
issuer=self.issuer.value,
|
||||
name=username,
|
||||
interval=TOTP_INTERVAL,
|
||||
@ -150,11 +145,11 @@ class TOTP_MFA(mfas.MFA):
|
||||
|
||||
def html(self, request: 'ExtendedHttpRequest', userId: str, username: str) -> str:
|
||||
# Get data from storage related to this user
|
||||
qrShown = self._userData(userId)[1]
|
||||
qrShown = self._user_data(userId)[1]
|
||||
if qrShown:
|
||||
return _('Enter your authentication code')
|
||||
# Compose the QR code from provisioning URI
|
||||
totp = self.getTOTP(userId, username)
|
||||
totp = self.get_totp(userId, username)
|
||||
uri = totp.provisioning_uri()
|
||||
img = qrcode.make(uri)
|
||||
imgByteStream = io.BytesIO()
|
||||
@ -181,7 +176,7 @@ class TOTP_MFA(mfas.MFA):
|
||||
identifier: str,
|
||||
validity: typing.Optional[int] = None,
|
||||
) -> 'mfas.MFA.RESULT':
|
||||
if self.askForOTP(request) is False:
|
||||
if self.ask_for_otp(request) is False:
|
||||
return mfas.MFA.RESULT.ALLOWED
|
||||
|
||||
# The data is provided by an external source, so we need to process anything on the request
|
||||
@ -196,25 +191,25 @@ class TOTP_MFA(mfas.MFA):
|
||||
code: str,
|
||||
validity: typing.Optional[int] = None,
|
||||
) -> None:
|
||||
if self.askForOTP(request) is False:
|
||||
if self.ask_for_otp(request) is False:
|
||||
return
|
||||
|
||||
if self.cache.get(userId + code) is not None:
|
||||
raise exceptions.auth.MFAError(gettext('Code is already used. Wait a minute and try again.'))
|
||||
|
||||
# Get data from storage related to this user
|
||||
secret, qrShown = self._userData(userId)
|
||||
secret, qrShown = self._user_data(userId)
|
||||
|
||||
# Validate code
|
||||
if not self.getTOTP(userId, username).verify(
|
||||
code, valid_window=self.validWindow.as_int(), for_time=sql_datetime()
|
||||
if not self.get_totp(userId, username).verify(
|
||||
code, valid_window=self.valid_window.as_int(), for_time=sql_datetime()
|
||||
):
|
||||
raise exceptions.auth.MFAError(gettext('Invalid code'))
|
||||
|
||||
self.cache.put(userId + code, True, self.validWindow.as_int() * (TOTP_INTERVAL + 1))
|
||||
self.cache.put(userId + code, True, self.valid_window.as_int() * (TOTP_INTERVAL + 1))
|
||||
|
||||
if qrShown is False:
|
||||
self._saveUserData(userId, (secret, True)) # Update user data to show QR code only once
|
||||
self._save_user_data(userId, (secret, True)) # Update user data to show QR code only once
|
||||
|
||||
def reset_data(self, userId: str) -> None:
|
||||
self._removeUserData(userId)
|
||||
self._remove_user_data(userId)
|
||||
|
Loading…
Reference in New Issue
Block a user