1
0
mirror of https://github.com/dkmstr/openuds.git synced 2025-01-11 05:17:55 +03:00

Several minor fixes

- Error exception show fix
- Oauth2 tokeninfo fix
- Extracted common fields from ldap, etc.. to fields
- Removed "rereading" tag from main.py, nonsense...
This commit is contained in:
Adolfo Gómez García 2024-09-17 02:47:34 +02:00
parent 735a447334
commit 27f950e2b6
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
13 changed files with 185 additions and 135 deletions

View File

@ -1 +0,0 @@
/enterprise

1
server/src/tests/enterprise Symbolic link
View File

@ -0,0 +1 @@
../../../../enterprise/server/src/tests/enterprise

View File

@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
#
# Copyright (c) 2024 Virtual Cable S.L.U.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of Virtual Cable S.L.U. nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
Author: Adolfo Gómez, dkmaster at dkmon dot com
"""
from django.urls import reverse
from uds.core import types
from ...utils.web import test
class UtilityWebTest(test.WEBTestCase):
"""
Test WEB login and logout
"""
def test_error_page(self) -> None:
"""
Test login and logout
"""
for error_type in types.errors.Error:
response = self.client.get(reverse('webapi.error', kwargs={'err': error_type}))
self.assertEqual(response.status_code, 200)
# Response should be {"error": "error message", "code": error code as string}
data = response.json()
self.assertEqual(data['error'], error_type.message)
self.assertEqual(int(data['code']), error_type.value)

View File

@ -72,19 +72,18 @@ class TokenInfo:
id_token: typing.Optional[str]
@staticmethod
def from_dict(dct: dict[str, typing.Any]) -> 'TokenInfo':
def from_dict(dct: collections.abc.Mapping[str, typing.Any]) -> 'TokenInfo':
# expires is -10 to avoid problems with clock sync
return TokenInfo(
access_token=dct['access_token'],
token_type=dct['token_type'],
expires=model.sql_now() + datetime.timedelta(seconds=dct['expires_in'] - 10),
refresh_token=dct['refresh_token'],
refresh_token=dct.get('refresh_token', ''),
scope=dct['scope'],
info=dct.get('info', {}),
id_token=dct.get('id_token', None),
)
class OAuth2Authenticator(auths.Authenticator):
"""
This class represents an OAuth2 Authenticator.
@ -96,7 +95,7 @@ class OAuth2Authenticator(auths.Authenticator):
icon_file = 'oauth2.png'
authorization_endpoint = gui.TextField(
length=64,
length=256,
label=_('Authorization endpoint'),
order=10,
tooltip=_('Authorization endpoint for OAuth2.'),
@ -104,7 +103,7 @@ class OAuth2Authenticator(auths.Authenticator):
tab=_('Server'),
)
client_id = gui.TextField(
length=64,
length=128,
label=_('Client ID'),
order=2,
tooltip=_('Client ID for OAuth2.'),
@ -112,7 +111,7 @@ class OAuth2Authenticator(auths.Authenticator):
tab=_('Server'),
)
client_secret = gui.TextField(
length=64,
length=128,
label=_('Client Secret'),
order=3,
tooltip=_('Client secret for OAuth2.'),
@ -120,7 +119,7 @@ class OAuth2Authenticator(auths.Authenticator):
tab=_('Server'),
)
scope = gui.TextField(
length=64,
length=128,
label=_('Scope'),
order=4,
tooltip=_('Scope for OAuth2.'),
@ -128,7 +127,7 @@ class OAuth2Authenticator(auths.Authenticator):
tab=_('Server'),
)
common_groups = gui.TextField(
length=64,
length=128,
label=_('Common Groups'),
order=5,
tooltip=_('User will be assigned to this groups once authenticated. Comma separated list of groups'),
@ -138,7 +137,7 @@ class OAuth2Authenticator(auths.Authenticator):
# Advanced options
redirection_endpoint = gui.TextField(
length=64,
length=128,
label=_('Redirection endpoint'),
order=90,
tooltip=_('Redirection endpoint for OAuth2. (Filled by UDS)'),
@ -168,7 +167,7 @@ class OAuth2Authenticator(auths.Authenticator):
)
# In case of code, we need to get the token from the token endpoint
token_endpoint = gui.TextField(
length=64,
length=128,
label=_('Token endpoint'),
order=92,
tooltip=_('Token endpoint for OAuth2. Only required for "code" response type.'),
@ -176,7 +175,7 @@ class OAuth2Authenticator(auths.Authenticator):
tab=types.ui.Tab.ADVANCED,
)
info_endpoint = gui.TextField(
length=64,
length=128,
label=_('User information endpoint'),
order=93,
tooltip=_('User information endpoint for OAuth2. Only required for "code" response type.'),
@ -192,36 +191,10 @@ class OAuth2Authenticator(auths.Authenticator):
required=False,
tab=types.ui.Tab.ADVANCED,
)
username_attr = gui.TextField(
length=2048,
lines=2,
label=_('User name attrs'),
order=100,
tooltip=_('Fields from where to extract user name'),
required=True,
tab=_('Attributes'),
)
groupname_attr = gui.TextField(
length=2048,
lines=2,
label=_('Group name attrs'),
order=101,
tooltip=_('Fields from where to extract the groups'),
required=False,
tab=_('Attributes'),
)
realname_attr = gui.TextField(
length=2048,
lines=2,
label=_('Real name attrs'),
order=102,
tooltip=_('Fields from where to extract the real name'),
required=False,
tab=_('Attributes'),
)
username_attr = fields.username_attr_field(order=100)
groupname_attr = fields.groupname_attr_field(order=101)
realname_attr = fields.realname_attr_field(order=102)
def _get_public_keys(self) -> list[typing.Any]: # In fact, any of the PublicKey types
# Get certificates in self.publicKey.value, encoded as PEM
@ -237,14 +210,14 @@ class OAuth2Authenticator(auths.Authenticator):
Returns:
tuple[str, str]: Code verifier and code challenge
"""
codeVerifier = ''.join(secrets.choice(PKCE_ALPHABET) for _ in range(128))
codeChallenge = (
b64decode(hashlib.sha256(codeVerifier.encode('ascii')).digest(), altchars=b'-_')
code_verifier = ''.join(secrets.choice(PKCE_ALPHABET) for _ in range(128))
code_challenge = (
b64decode(hashlib.sha256(code_verifier.encode('ascii')).digest(), altchars=b'-_')
.decode()
.rstrip('=') # remove padding
)
return codeVerifier, codeChallenge
return code_verifier, code_challenge
def _get_response_type_string(self) -> str:
match self.response_type.value:
@ -331,13 +304,15 @@ class OAuth2Authenticator(auths.Authenticator):
if code_verifier:
param_dict['code_verifier'] = code_verifier
req = requests.post(self.token_endpoint.value, data=param_dict, timeout=consts.system.COMMS_TIMEOUT)
logger.debug('Token request: %s %s', req.status_code, req.text)
response = requests.post(
self.token_endpoint.value, data=param_dict, timeout=consts.system.COMMS_TIMEOUT
)
logger.debug('Token request: %s %s', response.status_code, response.text)
if not req.ok:
raise Exception('Error requesting token: {}'.format(req.text))
if not response.ok:
raise Exception('Error requesting token: {}'.format(response.text))
return TokenInfo.from_dict(req.json())
return TokenInfo.from_dict(response.json())
def _request_info(self, token: 'TokenInfo') -> dict[str, typing.Any]:
"""Request user info from the info endpoint using the token received from the token endpoint
@ -557,16 +532,7 @@ class OAuth2Authenticator(auths.Authenticator):
return types.auth.FAILED_AUTH
# Get the token, token_type, expires
token = TokenInfo(
access_token=parameters.get_params.get('access_token', ''),
token_type=parameters.get_params.get('token_type', ''),
expires=model.sql_now()
+ datetime.timedelta(seconds=int(parameters.get_params.get('expires_in', 0))),
refresh_token=parameters.get_params.get('refresh_token', ''),
scope=parameters.get_params.get('scope', ''),
info={},
id_token=None,
)
token = TokenInfo.from_dict(parameters.get_params)
return self._process_token(self._request_info(token), gm)
def auth_callback_openid_code(

View File

@ -133,31 +133,8 @@ class RegexLdap(auths.Authenticator):
required=True,
tab=_('Ldap info'),
)
username_attr = gui.TextField(
length=640,
label=_('User Name Attr'),
lines=2,
default='uid',
order=23,
tooltip=_(
'Attributes that contains the user name attributes or attribute patterns (one for each line)'
),
required=True,
tab=_('Ldap info'),
)
groupname_attr = gui.TextField(
length=640,
label=_('Group Name Attr'),
lines=2,
default='cn',
order=24,
tooltip=_(
'Attribute that contains the group name attributes or attribute patterns (one for each line)'
),
required=True,
tab=_('Ldap info'),
)
# regex = gui.TextField(length=64, label = _('Regular Exp. for groups'), defvalue = '^(.*)', order = 12, tooltip = _('Regular Expression to extract the group name'), required = True)
username_attr = fields.realname_attr_field(tab=_('Ldap info'), order=23, default='uid')
groupname_attr = fields.groupname_attr_field(tab=_('Ldap info'), order=24, default='cn')
alternate_class = gui.TextField(
length=64,

View File

@ -40,7 +40,7 @@ from django.utils.translation import gettext_noop as _
from uds.core import auths, environment, types, exceptions
from uds.core.auths.auth import log_login
from uds.core.ui import gui
from uds.core.util import ensure, fields, ldaputil, validators
from uds.core.util import ensure, fields, ldaputil, validators, auth as auth_utils
# Not imported at runtime, just for type checking
if typing.TYPE_CHECKING:
@ -129,15 +129,7 @@ class SimpleLDAPAuthenticator(auths.Authenticator):
required=True,
tab=_('Ldap info'),
)
username_attr = gui.TextField(
length=64,
label=_('User Name Attr'),
default='uid',
order=33,
tooltip=_('Attributes that contains the user name (list of comma separated values)'),
required=True,
tab=_('Ldap info'),
)
username_attr = fields.realname_attr_field(tab=_('Ldap info'), order=33, default='uid')
group_class = gui.TextField(
length=64,
label=_('Group class'),
@ -195,7 +187,7 @@ class SimpleLDAPAuthenticator(auths.Authenticator):
def initialize(self, values: typing.Optional[dict[str, typing.Any]]) -> None:
if values:
self.username_attr.value = self.username_attr.value.replace(' ', '') # Removes white spaces
auth_utils.validate_regex_field(self.username_attr)
validators.validate_certificate(self.certificate.value)
def unmarshal(self, data: bytes) -> None:
@ -222,6 +214,9 @@ class SimpleLDAPAuthenticator(auths.Authenticator):
self.member_attr.value = vals[12]
self.username_attr.value = vals[13]
# Upgrade to new format
self.username_attr.value = '\n'.join(self.username_attr.value.split(','))
logger.debug("Data: %s", vals[1:])
if vals[0] == 'v2':
@ -331,21 +326,12 @@ class SimpleLDAPAuthenticator(auths.Authenticator):
logger.exception('Exception at __getGroups')
return []
def _get_user_realname(self, usr: ldaputil.LDAPResultType) -> str:
def _get_user_realname(self, user: ldaputil.LDAPResultType) -> str:
'''
Tries to extract the real name for this user. Will return all atttributes (joint)
specified in _userNameAttr (comma separated).
'''
return ' '.join(
[
(
' '.join((str(k) for k in usr.get(id_, '')))
if isinstance(usr.get(id_), list)
else str(usr.get(id_, ''))
)
for id_ in self.username_attr.as_str().split(',')
]
).strip()
return ' '.join(auth_utils.process_regex_field(self.username_attr.value, user))
def authenticate(
self,

View File

@ -65,9 +65,9 @@ class Error(enum.IntEnum):
@property
def message(self) -> str:
try:
return ERROR_STRINGS[self.value]
return str(ERROR_STRINGS[self.value])
except IndexError:
return ERROR_STRINGS[0]
return str(ERROR_STRINGS[0])
@staticmethod
def from_int(value: int) -> 'Error':

View File

@ -48,6 +48,7 @@ class Tab(enum.StrEnum):
MFA = gettext_noop('MFA')
MACHINE = gettext_noop('Machine')
CONFIG = gettext_noop('Config')
ATTRIBUTES = gettext_noop('Attributes')
@staticmethod
def from_str(value: typing.Optional[str]) -> typing.Union['Tab', str, None]:

View File

@ -70,7 +70,7 @@ def get_attributes_regex_field(field: 'ui.gui.TextField|str') -> set[str]:
res: set[str] = set()
for line in content.splitlines():
attr, _pattern = (line.split('=')[0:2] + [''])[0:2]
attr, _pattern = (line.split('=')[0:2] + [''])[0:2] # Endure 2 values
# If attributes concateated with +, add all
if '+' in attr:
@ -124,9 +124,10 @@ def process_regex_field(
return []
for line in field.splitlines():
equalPos = line.find('=')
if equalPos != -1:
attr, pattern = (line[:equalPos], line[equalPos + 1 :])
equal_pos = line.find('=')
if equal_pos != -1:
# attr before first =, pattern after
attr, pattern = (line[:equal_pos], line[equal_pos + 1 :])
# if pattern do not have groups, define one with full re
if pattern.find('(') == -1:
pattern = '(' + pattern + ')'

View File

@ -488,7 +488,9 @@ def login_without_mfa_policy_field(
)
def put_back_to_cache_field(order: int = 120, tab: 'types.ui.Tab|str|None' = types.ui.Tab.ADVANCED) -> ui.gui.ChoiceField:
def put_back_to_cache_field(
order: int = 120, tab: 'types.ui.Tab|str|None' = types.ui.Tab.ADVANCED
) -> ui.gui.ChoiceField:
return ui.gui.ChoiceField(
order=order,
label=_('Put back to cache'),
@ -501,6 +503,61 @@ def put_back_to_cache_field(order: int = 120, tab: 'types.ui.Tab|str|None' = typ
)
def username_attr_field(
tab: 'types.ui.Tab|str|None' = types.ui.Tab.ATTRIBUTES,
default: typing.Optional[str] = None,
order: int = 100,
) -> ui.gui.TextField:
return ui.gui.TextField(
length=640,
label=_('User Name Attr'),
lines=3,
default=default or '',
order=order,
tooltip=_(
'Attributes that contains the user name attributes or attribute patterns (one for each line)'
),
required=True,
tab=tab,
)
def groupname_attr_field(
tab: 'types.ui.Tab|str|None' = types.ui.Tab.ATTRIBUTES,
default: typing.Optional[str] = None,
order: int = 101,
) -> ui.gui.TextField:
return ui.gui.TextField(
length=640,
label=_('Group Name Attr'),
lines=3,
default=default or '',
order=order,
tooltip=_(
'Attribute that contains the group name attributes or attribute patterns (one for each line)'
),
required=True,
tab=tab,
)
def realname_attr_field(
tab: 'types.ui.Tab|str|None' = types.ui.Tab.ATTRIBUTES,
default: typing.Optional[str] = None,
order: int = 102,
) -> ui.gui.TextField:
return ui.gui.TextField(
length=640,
label=_('Real Name Attr'),
lines=3,
default=default or '',
order=order,
tooltip=_('Attribute that contains the real name attributes or attribute patterns (one for each line)'),
required=True,
tab=tab,
)
def onlogout_field_is_persistent(fld: ui.gui.ChoiceField) -> bool:
return fld.value == 'keep-always'

View File

@ -29,7 +29,6 @@
"""
Author: Adolfo Gómez, dkmaster at dkmon dot com
"""
import traceback
import json
import logging
import typing
@ -53,8 +52,8 @@ if typing.TYPE_CHECKING:
logger = logging.getLogger(__name__)
def error_view(request: 'HttpRequest', errorCode: int) -> HttpResponseRedirect:
return HttpResponseRedirect(reverse('page.error', kwargs={'err': errorCode}))
def error_view(request: 'HttpRequest', error_code: int) -> HttpResponseRedirect:
return HttpResponseRedirect(reverse('page.error', kwargs={'err': error_code}))
def error(request: 'HttpRequest', err: str) -> 'HttpResponse':
@ -68,15 +67,23 @@ def exception_view(request: 'HttpRequest', exception: Exception) -> HttpResponse
"""
Tries to render an error page with error information
"""
logger.debug(traceback.format_exc())
# import traceback
# logger.debug(traceback.format_exc())
return error_view(request, types.errors.Error.from_exception(exception))
def error_message(request: 'HttpRequest', err: int) -> 'HttpResponse':
def error_message(request: 'HttpRequest', err: str) -> 'HttpResponse':
"""
Error view, responsible of error display
"""
# get error as integer or replace it by 0
try:
err_int = int(err)
except Exception:
err_int = 0
return HttpResponse(
json.dumps({'error': types.errors.Error.from_int(err).message, 'code': str(err)}),
json.dumps({'error': types.errors.Error.from_int(err_int).message, 'code': str(err)}),
content_type='application/json',
)

View File

@ -133,7 +133,7 @@ def auth_callback_stage2(request: 'ExtendedHttpRequestWithUser', ticket_id: str)
request.build_absolute_uri(str(e)) if e.args and e.args[0] else None,
)
except Exception as e:
logger.exception('authCallback')
logger.error('Error authenticating user: %s', e)
return errors.exception_view(request, e)

View File

@ -99,40 +99,41 @@ def login(request: types.requests.ExtendedHttpRequest, tag: typing.Optional[str]
request.authorized = False # Ensure that on login page, user is unauthorized first
form = LoginForm(request.POST, tag=tag)
loginResult = check_login(request, form, tag)
if loginResult.user:
login_result = check_login(request, form, tag)
if login_result.user:
response = HttpResponseRedirect(reverse('page.index'))
# save tag, weblogin will clear session
tag = request.session.get('tag')
# Tag is not removed from session, so next login will have it even if not provided
# This means than once an url is used, unless manually goes to "/uds/page/login/xxx"
# The tag will be used again
auth.web_login(
request, response, loginResult.user, loginResult.password
request, response, login_result.user, login_result.password
) # data is user password here
# If MFA is provided, we need to redirect to MFA page
request.authorized = True
if (
loginResult.user.manager.get_type().provides_mfa()
and loginResult.user.manager.mfa
and loginResult.user.groups.filter(skip_mfa=types.states.State.ACTIVE).count() == 0
login_result.user.manager.get_type().provides_mfa()
and login_result.user.manager.mfa
and login_result.user.groups.filter(skip_mfa=types.states.State.ACTIVE).count() == 0
):
request.authorized = False
response = HttpResponseRedirect(reverse('page.mfa'))
else:
# If redirection on login failure is found, honor it
if loginResult.url: # Redirection
return HttpResponseRedirect(loginResult.url)
if login_result.url: # Redirection
return HttpResponseRedirect(login_result.url)
if request.ip not in ('127.0.0.1', '::1'): # If not localhost, wait a bit
time.sleep(
random.SystemRandom().randint(1600, 2400) / 1000
) # On failure, wait a bit if not localhost (random wait)
# If error is numeric, redirect...
if loginResult.errid:
return errors.error_view(request, loginResult.errid)
if login_result.errid:
return errors.error_view(request, login_result.errid)
# Error, set error on session for process for js
request.session['errors'] = [loginResult.errstr]
request.session['errors'] = [login_result.errstr]
else:
request.session['tag'] = tag