mirror of
https://github.com/dkmstr/openuds.git
synced 2025-01-11 05:17:55 +03:00
Several minor fixes
- Error exception show fix - Oauth2 tokeninfo fix - Extracted common fields from ldap, etc.. to fields - Removed "rereading" tag from main.py, nonsense...
This commit is contained in:
parent
735a447334
commit
27f950e2b6
1
server/src/tests/.gitignore
vendored
1
server/src/tests/.gitignore
vendored
@ -1 +0,0 @@
|
||||
/enterprise
|
1
server/src/tests/enterprise
Symbolic link
1
server/src/tests/enterprise
Symbolic link
@ -0,0 +1 @@
|
||||
../../../../enterprise/server/src/tests/enterprise
|
54
server/src/tests/web/util/test_error.py
Normal file
54
server/src/tests/web/util/test_error.py
Normal file
@ -0,0 +1,54 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright (c) 2024 Virtual Cable S.L.U.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification,
|
||||
# are permitted provided that the following conditions are met:
|
||||
#
|
||||
# * Redistributions of source code must retain the above copyright notice,
|
||||
# this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
# * Neither the name of Virtual Cable S.L.U. nor the names of its contributors
|
||||
# may be used to endorse or promote products derived from this software
|
||||
# without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
"""
|
||||
Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
"""
|
||||
|
||||
from django.urls import reverse
|
||||
|
||||
from uds.core import types
|
||||
|
||||
from ...utils.web import test
|
||||
|
||||
|
||||
class UtilityWebTest(test.WEBTestCase):
|
||||
"""
|
||||
Test WEB login and logout
|
||||
"""
|
||||
|
||||
def test_error_page(self) -> None:
|
||||
"""
|
||||
Test login and logout
|
||||
"""
|
||||
for error_type in types.errors.Error:
|
||||
response = self.client.get(reverse('webapi.error', kwargs={'err': error_type}))
|
||||
self.assertEqual(response.status_code, 200)
|
||||
# Response should be {"error": "error message", "code": error code as string}
|
||||
data = response.json()
|
||||
self.assertEqual(data['error'], error_type.message)
|
||||
self.assertEqual(int(data['code']), error_type.value)
|
@ -72,19 +72,18 @@ class TokenInfo:
|
||||
id_token: typing.Optional[str]
|
||||
|
||||
@staticmethod
|
||||
def from_dict(dct: dict[str, typing.Any]) -> 'TokenInfo':
|
||||
def from_dict(dct: collections.abc.Mapping[str, typing.Any]) -> 'TokenInfo':
|
||||
# expires is -10 to avoid problems with clock sync
|
||||
return TokenInfo(
|
||||
access_token=dct['access_token'],
|
||||
token_type=dct['token_type'],
|
||||
expires=model.sql_now() + datetime.timedelta(seconds=dct['expires_in'] - 10),
|
||||
refresh_token=dct['refresh_token'],
|
||||
refresh_token=dct.get('refresh_token', ''),
|
||||
scope=dct['scope'],
|
||||
info=dct.get('info', {}),
|
||||
id_token=dct.get('id_token', None),
|
||||
)
|
||||
|
||||
|
||||
class OAuth2Authenticator(auths.Authenticator):
|
||||
"""
|
||||
This class represents an OAuth2 Authenticator.
|
||||
@ -96,7 +95,7 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
icon_file = 'oauth2.png'
|
||||
|
||||
authorization_endpoint = gui.TextField(
|
||||
length=64,
|
||||
length=256,
|
||||
label=_('Authorization endpoint'),
|
||||
order=10,
|
||||
tooltip=_('Authorization endpoint for OAuth2.'),
|
||||
@ -104,7 +103,7 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
tab=_('Server'),
|
||||
)
|
||||
client_id = gui.TextField(
|
||||
length=64,
|
||||
length=128,
|
||||
label=_('Client ID'),
|
||||
order=2,
|
||||
tooltip=_('Client ID for OAuth2.'),
|
||||
@ -112,7 +111,7 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
tab=_('Server'),
|
||||
)
|
||||
client_secret = gui.TextField(
|
||||
length=64,
|
||||
length=128,
|
||||
label=_('Client Secret'),
|
||||
order=3,
|
||||
tooltip=_('Client secret for OAuth2.'),
|
||||
@ -120,7 +119,7 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
tab=_('Server'),
|
||||
)
|
||||
scope = gui.TextField(
|
||||
length=64,
|
||||
length=128,
|
||||
label=_('Scope'),
|
||||
order=4,
|
||||
tooltip=_('Scope for OAuth2.'),
|
||||
@ -128,7 +127,7 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
tab=_('Server'),
|
||||
)
|
||||
common_groups = gui.TextField(
|
||||
length=64,
|
||||
length=128,
|
||||
label=_('Common Groups'),
|
||||
order=5,
|
||||
tooltip=_('User will be assigned to this groups once authenticated. Comma separated list of groups'),
|
||||
@ -138,7 +137,7 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
|
||||
# Advanced options
|
||||
redirection_endpoint = gui.TextField(
|
||||
length=64,
|
||||
length=128,
|
||||
label=_('Redirection endpoint'),
|
||||
order=90,
|
||||
tooltip=_('Redirection endpoint for OAuth2. (Filled by UDS)'),
|
||||
@ -168,7 +167,7 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
)
|
||||
# In case of code, we need to get the token from the token endpoint
|
||||
token_endpoint = gui.TextField(
|
||||
length=64,
|
||||
length=128,
|
||||
label=_('Token endpoint'),
|
||||
order=92,
|
||||
tooltip=_('Token endpoint for OAuth2. Only required for "code" response type.'),
|
||||
@ -176,7 +175,7 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
tab=types.ui.Tab.ADVANCED,
|
||||
)
|
||||
info_endpoint = gui.TextField(
|
||||
length=64,
|
||||
length=128,
|
||||
label=_('User information endpoint'),
|
||||
order=93,
|
||||
tooltip=_('User information endpoint for OAuth2. Only required for "code" response type.'),
|
||||
@ -192,36 +191,10 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
required=False,
|
||||
tab=types.ui.Tab.ADVANCED,
|
||||
)
|
||||
|
||||
username_attr = gui.TextField(
|
||||
length=2048,
|
||||
lines=2,
|
||||
label=_('User name attrs'),
|
||||
order=100,
|
||||
tooltip=_('Fields from where to extract user name'),
|
||||
required=True,
|
||||
tab=_('Attributes'),
|
||||
)
|
||||
|
||||
groupname_attr = gui.TextField(
|
||||
length=2048,
|
||||
lines=2,
|
||||
label=_('Group name attrs'),
|
||||
order=101,
|
||||
tooltip=_('Fields from where to extract the groups'),
|
||||
required=False,
|
||||
tab=_('Attributes'),
|
||||
)
|
||||
|
||||
realname_attr = gui.TextField(
|
||||
length=2048,
|
||||
lines=2,
|
||||
label=_('Real name attrs'),
|
||||
order=102,
|
||||
tooltip=_('Fields from where to extract the real name'),
|
||||
required=False,
|
||||
tab=_('Attributes'),
|
||||
)
|
||||
|
||||
username_attr = fields.username_attr_field(order=100)
|
||||
groupname_attr = fields.groupname_attr_field(order=101)
|
||||
realname_attr = fields.realname_attr_field(order=102)
|
||||
|
||||
def _get_public_keys(self) -> list[typing.Any]: # In fact, any of the PublicKey types
|
||||
# Get certificates in self.publicKey.value, encoded as PEM
|
||||
@ -237,14 +210,14 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
Returns:
|
||||
tuple[str, str]: Code verifier and code challenge
|
||||
"""
|
||||
codeVerifier = ''.join(secrets.choice(PKCE_ALPHABET) for _ in range(128))
|
||||
codeChallenge = (
|
||||
b64decode(hashlib.sha256(codeVerifier.encode('ascii')).digest(), altchars=b'-_')
|
||||
code_verifier = ''.join(secrets.choice(PKCE_ALPHABET) for _ in range(128))
|
||||
code_challenge = (
|
||||
b64decode(hashlib.sha256(code_verifier.encode('ascii')).digest(), altchars=b'-_')
|
||||
.decode()
|
||||
.rstrip('=') # remove padding
|
||||
)
|
||||
|
||||
return codeVerifier, codeChallenge
|
||||
return code_verifier, code_challenge
|
||||
|
||||
def _get_response_type_string(self) -> str:
|
||||
match self.response_type.value:
|
||||
@ -331,13 +304,15 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
if code_verifier:
|
||||
param_dict['code_verifier'] = code_verifier
|
||||
|
||||
req = requests.post(self.token_endpoint.value, data=param_dict, timeout=consts.system.COMMS_TIMEOUT)
|
||||
logger.debug('Token request: %s %s', req.status_code, req.text)
|
||||
response = requests.post(
|
||||
self.token_endpoint.value, data=param_dict, timeout=consts.system.COMMS_TIMEOUT
|
||||
)
|
||||
logger.debug('Token request: %s %s', response.status_code, response.text)
|
||||
|
||||
if not req.ok:
|
||||
raise Exception('Error requesting token: {}'.format(req.text))
|
||||
if not response.ok:
|
||||
raise Exception('Error requesting token: {}'.format(response.text))
|
||||
|
||||
return TokenInfo.from_dict(req.json())
|
||||
return TokenInfo.from_dict(response.json())
|
||||
|
||||
def _request_info(self, token: 'TokenInfo') -> dict[str, typing.Any]:
|
||||
"""Request user info from the info endpoint using the token received from the token endpoint
|
||||
@ -557,16 +532,7 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
return types.auth.FAILED_AUTH
|
||||
|
||||
# Get the token, token_type, expires
|
||||
token = TokenInfo(
|
||||
access_token=parameters.get_params.get('access_token', ''),
|
||||
token_type=parameters.get_params.get('token_type', ''),
|
||||
expires=model.sql_now()
|
||||
+ datetime.timedelta(seconds=int(parameters.get_params.get('expires_in', 0))),
|
||||
refresh_token=parameters.get_params.get('refresh_token', ''),
|
||||
scope=parameters.get_params.get('scope', ''),
|
||||
info={},
|
||||
id_token=None,
|
||||
)
|
||||
token = TokenInfo.from_dict(parameters.get_params)
|
||||
return self._process_token(self._request_info(token), gm)
|
||||
|
||||
def auth_callback_openid_code(
|
||||
|
@ -133,31 +133,8 @@ class RegexLdap(auths.Authenticator):
|
||||
required=True,
|
||||
tab=_('Ldap info'),
|
||||
)
|
||||
username_attr = gui.TextField(
|
||||
length=640,
|
||||
label=_('User Name Attr'),
|
||||
lines=2,
|
||||
default='uid',
|
||||
order=23,
|
||||
tooltip=_(
|
||||
'Attributes that contains the user name attributes or attribute patterns (one for each line)'
|
||||
),
|
||||
required=True,
|
||||
tab=_('Ldap info'),
|
||||
)
|
||||
groupname_attr = gui.TextField(
|
||||
length=640,
|
||||
label=_('Group Name Attr'),
|
||||
lines=2,
|
||||
default='cn',
|
||||
order=24,
|
||||
tooltip=_(
|
||||
'Attribute that contains the group name attributes or attribute patterns (one for each line)'
|
||||
),
|
||||
required=True,
|
||||
tab=_('Ldap info'),
|
||||
)
|
||||
# regex = gui.TextField(length=64, label = _('Regular Exp. for groups'), defvalue = '^(.*)', order = 12, tooltip = _('Regular Expression to extract the group name'), required = True)
|
||||
username_attr = fields.realname_attr_field(tab=_('Ldap info'), order=23, default='uid')
|
||||
groupname_attr = fields.groupname_attr_field(tab=_('Ldap info'), order=24, default='cn')
|
||||
|
||||
alternate_class = gui.TextField(
|
||||
length=64,
|
||||
|
@ -40,7 +40,7 @@ from django.utils.translation import gettext_noop as _
|
||||
from uds.core import auths, environment, types, exceptions
|
||||
from uds.core.auths.auth import log_login
|
||||
from uds.core.ui import gui
|
||||
from uds.core.util import ensure, fields, ldaputil, validators
|
||||
from uds.core.util import ensure, fields, ldaputil, validators, auth as auth_utils
|
||||
|
||||
# Not imported at runtime, just for type checking
|
||||
if typing.TYPE_CHECKING:
|
||||
@ -129,15 +129,7 @@ class SimpleLDAPAuthenticator(auths.Authenticator):
|
||||
required=True,
|
||||
tab=_('Ldap info'),
|
||||
)
|
||||
username_attr = gui.TextField(
|
||||
length=64,
|
||||
label=_('User Name Attr'),
|
||||
default='uid',
|
||||
order=33,
|
||||
tooltip=_('Attributes that contains the user name (list of comma separated values)'),
|
||||
required=True,
|
||||
tab=_('Ldap info'),
|
||||
)
|
||||
username_attr = fields.realname_attr_field(tab=_('Ldap info'), order=33, default='uid')
|
||||
group_class = gui.TextField(
|
||||
length=64,
|
||||
label=_('Group class'),
|
||||
@ -195,7 +187,7 @@ class SimpleLDAPAuthenticator(auths.Authenticator):
|
||||
|
||||
def initialize(self, values: typing.Optional[dict[str, typing.Any]]) -> None:
|
||||
if values:
|
||||
self.username_attr.value = self.username_attr.value.replace(' ', '') # Removes white spaces
|
||||
auth_utils.validate_regex_field(self.username_attr)
|
||||
validators.validate_certificate(self.certificate.value)
|
||||
|
||||
def unmarshal(self, data: bytes) -> None:
|
||||
@ -222,6 +214,9 @@ class SimpleLDAPAuthenticator(auths.Authenticator):
|
||||
self.member_attr.value = vals[12]
|
||||
self.username_attr.value = vals[13]
|
||||
|
||||
# Upgrade to new format
|
||||
self.username_attr.value = '\n'.join(self.username_attr.value.split(','))
|
||||
|
||||
logger.debug("Data: %s", vals[1:])
|
||||
|
||||
if vals[0] == 'v2':
|
||||
@ -331,21 +326,12 @@ class SimpleLDAPAuthenticator(auths.Authenticator):
|
||||
logger.exception('Exception at __getGroups')
|
||||
return []
|
||||
|
||||
def _get_user_realname(self, usr: ldaputil.LDAPResultType) -> str:
|
||||
def _get_user_realname(self, user: ldaputil.LDAPResultType) -> str:
|
||||
'''
|
||||
Tries to extract the real name for this user. Will return all atttributes (joint)
|
||||
specified in _userNameAttr (comma separated).
|
||||
'''
|
||||
return ' '.join(
|
||||
[
|
||||
(
|
||||
' '.join((str(k) for k in usr.get(id_, '')))
|
||||
if isinstance(usr.get(id_), list)
|
||||
else str(usr.get(id_, ''))
|
||||
)
|
||||
for id_ in self.username_attr.as_str().split(',')
|
||||
]
|
||||
).strip()
|
||||
return ' '.join(auth_utils.process_regex_field(self.username_attr.value, user))
|
||||
|
||||
def authenticate(
|
||||
self,
|
||||
|
@ -65,9 +65,9 @@ class Error(enum.IntEnum):
|
||||
@property
|
||||
def message(self) -> str:
|
||||
try:
|
||||
return ERROR_STRINGS[self.value]
|
||||
return str(ERROR_STRINGS[self.value])
|
||||
except IndexError:
|
||||
return ERROR_STRINGS[0]
|
||||
return str(ERROR_STRINGS[0])
|
||||
|
||||
@staticmethod
|
||||
def from_int(value: int) -> 'Error':
|
||||
|
@ -48,6 +48,7 @@ class Tab(enum.StrEnum):
|
||||
MFA = gettext_noop('MFA')
|
||||
MACHINE = gettext_noop('Machine')
|
||||
CONFIG = gettext_noop('Config')
|
||||
ATTRIBUTES = gettext_noop('Attributes')
|
||||
|
||||
@staticmethod
|
||||
def from_str(value: typing.Optional[str]) -> typing.Union['Tab', str, None]:
|
||||
|
@ -70,7 +70,7 @@ def get_attributes_regex_field(field: 'ui.gui.TextField|str') -> set[str]:
|
||||
|
||||
res: set[str] = set()
|
||||
for line in content.splitlines():
|
||||
attr, _pattern = (line.split('=')[0:2] + [''])[0:2]
|
||||
attr, _pattern = (line.split('=')[0:2] + [''])[0:2] # Endure 2 values
|
||||
|
||||
# If attributes concateated with +, add all
|
||||
if '+' in attr:
|
||||
@ -124,9 +124,10 @@ def process_regex_field(
|
||||
return []
|
||||
|
||||
for line in field.splitlines():
|
||||
equalPos = line.find('=')
|
||||
if equalPos != -1:
|
||||
attr, pattern = (line[:equalPos], line[equalPos + 1 :])
|
||||
equal_pos = line.find('=')
|
||||
if equal_pos != -1:
|
||||
# attr before first =, pattern after
|
||||
attr, pattern = (line[:equal_pos], line[equal_pos + 1 :])
|
||||
# if pattern do not have groups, define one with full re
|
||||
if pattern.find('(') == -1:
|
||||
pattern = '(' + pattern + ')'
|
||||
|
@ -488,7 +488,9 @@ def login_without_mfa_policy_field(
|
||||
)
|
||||
|
||||
|
||||
def put_back_to_cache_field(order: int = 120, tab: 'types.ui.Tab|str|None' = types.ui.Tab.ADVANCED) -> ui.gui.ChoiceField:
|
||||
def put_back_to_cache_field(
|
||||
order: int = 120, tab: 'types.ui.Tab|str|None' = types.ui.Tab.ADVANCED
|
||||
) -> ui.gui.ChoiceField:
|
||||
return ui.gui.ChoiceField(
|
||||
order=order,
|
||||
label=_('Put back to cache'),
|
||||
@ -501,6 +503,61 @@ def put_back_to_cache_field(order: int = 120, tab: 'types.ui.Tab|str|None' = typ
|
||||
)
|
||||
|
||||
|
||||
def username_attr_field(
|
||||
tab: 'types.ui.Tab|str|None' = types.ui.Tab.ATTRIBUTES,
|
||||
default: typing.Optional[str] = None,
|
||||
order: int = 100,
|
||||
) -> ui.gui.TextField:
|
||||
return ui.gui.TextField(
|
||||
length=640,
|
||||
label=_('User Name Attr'),
|
||||
lines=3,
|
||||
default=default or '',
|
||||
order=order,
|
||||
tooltip=_(
|
||||
'Attributes that contains the user name attributes or attribute patterns (one for each line)'
|
||||
),
|
||||
required=True,
|
||||
tab=tab,
|
||||
)
|
||||
|
||||
|
||||
def groupname_attr_field(
|
||||
tab: 'types.ui.Tab|str|None' = types.ui.Tab.ATTRIBUTES,
|
||||
default: typing.Optional[str] = None,
|
||||
order: int = 101,
|
||||
) -> ui.gui.TextField:
|
||||
return ui.gui.TextField(
|
||||
length=640,
|
||||
label=_('Group Name Attr'),
|
||||
lines=3,
|
||||
default=default or '',
|
||||
order=order,
|
||||
tooltip=_(
|
||||
'Attribute that contains the group name attributes or attribute patterns (one for each line)'
|
||||
),
|
||||
required=True,
|
||||
tab=tab,
|
||||
)
|
||||
|
||||
|
||||
def realname_attr_field(
|
||||
tab: 'types.ui.Tab|str|None' = types.ui.Tab.ATTRIBUTES,
|
||||
default: typing.Optional[str] = None,
|
||||
order: int = 102,
|
||||
) -> ui.gui.TextField:
|
||||
return ui.gui.TextField(
|
||||
length=640,
|
||||
label=_('Real Name Attr'),
|
||||
lines=3,
|
||||
default=default or '',
|
||||
order=order,
|
||||
tooltip=_('Attribute that contains the real name attributes or attribute patterns (one for each line)'),
|
||||
required=True,
|
||||
tab=tab,
|
||||
)
|
||||
|
||||
|
||||
def onlogout_field_is_persistent(fld: ui.gui.ChoiceField) -> bool:
|
||||
return fld.value == 'keep-always'
|
||||
|
||||
|
@ -29,7 +29,6 @@
|
||||
"""
|
||||
Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
"""
|
||||
import traceback
|
||||
import json
|
||||
import logging
|
||||
import typing
|
||||
@ -53,8 +52,8 @@ if typing.TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def error_view(request: 'HttpRequest', errorCode: int) -> HttpResponseRedirect:
|
||||
return HttpResponseRedirect(reverse('page.error', kwargs={'err': errorCode}))
|
||||
def error_view(request: 'HttpRequest', error_code: int) -> HttpResponseRedirect:
|
||||
return HttpResponseRedirect(reverse('page.error', kwargs={'err': error_code}))
|
||||
|
||||
|
||||
def error(request: 'HttpRequest', err: str) -> 'HttpResponse':
|
||||
@ -68,15 +67,23 @@ def exception_view(request: 'HttpRequest', exception: Exception) -> HttpResponse
|
||||
"""
|
||||
Tries to render an error page with error information
|
||||
"""
|
||||
logger.debug(traceback.format_exc())
|
||||
# import traceback
|
||||
# logger.debug(traceback.format_exc())
|
||||
return error_view(request, types.errors.Error.from_exception(exception))
|
||||
|
||||
|
||||
def error_message(request: 'HttpRequest', err: int) -> 'HttpResponse':
|
||||
def error_message(request: 'HttpRequest', err: str) -> 'HttpResponse':
|
||||
"""
|
||||
Error view, responsible of error display
|
||||
"""
|
||||
# get error as integer or replace it by 0
|
||||
|
||||
try:
|
||||
err_int = int(err)
|
||||
except Exception:
|
||||
err_int = 0
|
||||
|
||||
return HttpResponse(
|
||||
json.dumps({'error': types.errors.Error.from_int(err).message, 'code': str(err)}),
|
||||
json.dumps({'error': types.errors.Error.from_int(err_int).message, 'code': str(err)}),
|
||||
content_type='application/json',
|
||||
)
|
||||
|
@ -133,7 +133,7 @@ def auth_callback_stage2(request: 'ExtendedHttpRequestWithUser', ticket_id: str)
|
||||
request.build_absolute_uri(str(e)) if e.args and e.args[0] else None,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception('authCallback')
|
||||
logger.error('Error authenticating user: %s', e)
|
||||
return errors.exception_view(request, e)
|
||||
|
||||
|
||||
|
@ -99,40 +99,41 @@ def login(request: types.requests.ExtendedHttpRequest, tag: typing.Optional[str]
|
||||
request.authorized = False # Ensure that on login page, user is unauthorized first
|
||||
|
||||
form = LoginForm(request.POST, tag=tag)
|
||||
loginResult = check_login(request, form, tag)
|
||||
if loginResult.user:
|
||||
login_result = check_login(request, form, tag)
|
||||
if login_result.user:
|
||||
response = HttpResponseRedirect(reverse('page.index'))
|
||||
# save tag, weblogin will clear session
|
||||
tag = request.session.get('tag')
|
||||
# Tag is not removed from session, so next login will have it even if not provided
|
||||
# This means than once an url is used, unless manually goes to "/uds/page/login/xxx"
|
||||
# The tag will be used again
|
||||
auth.web_login(
|
||||
request, response, loginResult.user, loginResult.password
|
||||
request, response, login_result.user, login_result.password
|
||||
) # data is user password here
|
||||
|
||||
# If MFA is provided, we need to redirect to MFA page
|
||||
request.authorized = True
|
||||
if (
|
||||
loginResult.user.manager.get_type().provides_mfa()
|
||||
and loginResult.user.manager.mfa
|
||||
and loginResult.user.groups.filter(skip_mfa=types.states.State.ACTIVE).count() == 0
|
||||
login_result.user.manager.get_type().provides_mfa()
|
||||
and login_result.user.manager.mfa
|
||||
and login_result.user.groups.filter(skip_mfa=types.states.State.ACTIVE).count() == 0
|
||||
):
|
||||
request.authorized = False
|
||||
response = HttpResponseRedirect(reverse('page.mfa'))
|
||||
|
||||
else:
|
||||
# If redirection on login failure is found, honor it
|
||||
if loginResult.url: # Redirection
|
||||
return HttpResponseRedirect(loginResult.url)
|
||||
if login_result.url: # Redirection
|
||||
return HttpResponseRedirect(login_result.url)
|
||||
|
||||
if request.ip not in ('127.0.0.1', '::1'): # If not localhost, wait a bit
|
||||
time.sleep(
|
||||
random.SystemRandom().randint(1600, 2400) / 1000
|
||||
) # On failure, wait a bit if not localhost (random wait)
|
||||
# If error is numeric, redirect...
|
||||
if loginResult.errid:
|
||||
return errors.error_view(request, loginResult.errid)
|
||||
if login_result.errid:
|
||||
return errors.error_view(request, login_result.errid)
|
||||
|
||||
# Error, set error on session for process for js
|
||||
request.session['errors'] = [loginResult.errstr]
|
||||
request.session['errors'] = [login_result.errstr]
|
||||
else:
|
||||
request.session['tag'] = tag
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user