1
0
mirror of https://github.com/dkmstr/openuds.git synced 2024-12-22 13:34:04 +03:00

Refactor pyrightconfig.json to include src/**/*.py and add stubPath for enterprise/stubs

Working on oauth2 tests
This commit is contained in:
Adolfo Gómez García 2024-09-19 22:21:57 +02:00
parent 6ce5cbe10e
commit 55f3a697ca
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
6 changed files with 208 additions and 132 deletions

View File

@ -1,6 +1,6 @@
{
"include": [
"src"
"src/**/*.py",
],
"exclude": [
"**/scripts",
@ -11,4 +11,5 @@
"reportUnusedImport": true,
"reportMissingTypeStubs": false,
"disableBytesTypePromotions": true,
"stubPath": "../../enterprise/stubs",
}

View File

@ -45,7 +45,7 @@ DATA_TEMPLATE: dict[str, str] = {
'authorization_endpoint': 'https://auth_endpoint.com',
'client_id': 'client_id',
'client_secret': 'client_secret',
'scope': 'openid email profile', # Default scopes
'scope': 'email profile', # Default scopes
'common_groups': 'common_group',
'redirection_endpoint': 'https://redirect_endpoint.com',
'response_type': 'code',

View File

@ -29,6 +29,7 @@
'''
Author: Adolfo Gómez, dkmaster at dkmon dot com
'''
from urllib.parse import urlparse, parse_qs
from unittest import mock
from tests.utils.test import UDSTestCase
@ -36,7 +37,7 @@ from uds.core import types
from . import fixtures
from uds.auths.OAuth2 import types as oauth2_types
from uds.auths.OAuth2 import types as oauth2_types, consts as oauth2_consts
class OAuth2Test(UDSTestCase):
@ -83,3 +84,44 @@ class OAuth2Test(UDSTestCase):
self.assertIsInstance(logout, types.auth.AuthenticationResult)
self.assertTrue(logout.success)
self.assertEqual(logout.url, 'https://logout.com?token=token_value')
def test_get_login_url_code(self) -> None:
for kind in oauth2_types.ResponseType:
with fixtures.create_authenticator(kind) as oauth2:
url = oauth2.get_login_url()
self.assertIsInstance(url, str)
# Parse URL and ensure it's correct
auth_url_info = urlparse(url)
configured_url_info = urlparse(oauth2.authorization_endpoint.value, kind.as_text)
query = parse_qs(auth_url_info.query)
configures_scopes = set(oauth2.scope.value.split())
self.assertEqual(auth_url_info.scheme, configured_url_info.scheme, kind.as_text)
self.assertEqual(auth_url_info.netloc, configured_url_info.netloc, kind.as_text)
self.assertEqual(auth_url_info.path, configured_url_info.path, kind.as_text)
self.assertEqual(query['response_type'], [kind.for_query], kind.as_text)
self.assertEqual(query['client_id'], [oauth2.client_id.value], kind.as_text)
self.assertEqual(query['redirect_uri'], [oauth2.redirection_endpoint.value], kind.as_text)
scopes = set(query['scope'][0].split())
if kind == oauth2_types.ResponseType.PKCE:
self.assertEqual(query['code_challenge_method'], ['S256'], kind.as_text)
code_challenge = query['code_challenge'][0]
self.assertIsInstance(code_challenge, str, kind.as_text)
# All configured scopes should be present
self.assertTrue(configures_scopes.issubset(scopes), kind.as_text)
# And if openid variant, scope should contain openid
if kind in (oauth2_types.ResponseType.OPENID_CODE, oauth2_types.ResponseType.OPENID_ID_TOKEN):
self.assertIn('openid', scopes, kind.as_text)
state = query['state'][0]
self.assertIsInstance(state, str, kind.as_text)
# state is in base64, so it will take a bit more than 16 characters
# Exactly every 6 bits will take 8 bits, so we need to divide by 6 and multiply by 8
# Adjusting to the upper integer
expected_length = (oauth2_consts.STATE_LENGTH * 8 + 5) // 6
self.assertEqual(len(state), expected_length, kind.as_text)
# oauth2 cache should contain the state
self.assertIsNotNone(oauth2.cache.get(state), kind.as_text)

View File

@ -1,7 +1,5 @@
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Virtual Cable S.L.U.
# Copyright (c) 2024 Virtual Cable S.L.U.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
@ -33,17 +31,16 @@ Author: Adolfo Gómez, dkmaster at dkmon dot com
import logging
import hashlib
import secrets
import string
import typing
import collections.abc
import urllib.parse
from base64 import b64decode
from base64 import b64encode
import jwt
from django.utils.translation import gettext
from django.utils.translation import gettext_noop as _
from . import types as oauth2_types
from . import types as oauth2_types, consts as oauth2_consts
from uds.core import auths, consts, exceptions, types
from uds.core.ui import gui
from uds.core.util import fields, auth as auth_utils, security
@ -54,11 +51,6 @@ if typing.TYPE_CHECKING:
logger = logging.getLogger(__name__)
# Alphabet used for PKCE
PKCE_ALPHABET: typing.Final[str] = string.ascii_letters + string.digits + '-._~'
# Length of the State parameter
STATE_LENGTH: typing.Final[int] = 16
class OAuth2Authenticator(auths.Authenticator):
"""
@ -185,34 +177,123 @@ class OAuth2Authenticator(auths.Authenticator):
# Non serializable variables
session: typing.ClassVar['requests.Session'] = security.secure_requests_session()
def initialize(self, values: typing.Optional[dict[str, typing.Any]]) -> None:
if not values:
return
if ' ' in values['name']:
raise exceptions.ui.ValidationError(
gettext('This kind of Authenticator does not support white spaces on field NAME')
)
auth_utils.validate_regex_field(self.username_attr)
auth_utils.validate_regex_field(self.username_attr)
if self.response_type.value in (
oauth2_types.ResponseType.CODE,
oauth2_types.ResponseType.PKCE,
oauth2_types.ResponseType.OPENID_CODE,
):
if self.common_groups.value.strip() == '':
raise exceptions.ui.ValidationError(
gettext('Common groups is required for "code" response types')
)
if self.token_endpoint.value.strip() == '':
raise exceptions.ui.ValidationError(
gettext('Token endpoint is required for "code" response types')
)
# infoEndpoint will not be necesary if the response of tokenEndpoint contains the user info
if self.response_type.value == 'openid+token_id':
# Ensure we have a public key
if self.public_key.value.strip() == '':
raise exceptions.ui.ValidationError(
gettext('Public key is required for "openid+token_id" response type')
)
if self.redirection_endpoint.value.strip() == '' and self.db_obj() and '_request' in values:
request: 'HttpRequest' = values['_request']
self.redirection_endpoint.value = request.build_absolute_uri(self.callback_url())
def auth_callback(
self,
parameters: 'types.auth.AuthCallbackParams',
groups_manager: 'auths.GroupsManager',
request: 'types.requests.ExtendedHttpRequest',
) -> types.auth.AuthenticationResult:
match oauth2_types.ResponseType(self.response_type.value):
case oauth2_types.ResponseType.CODE | oauth2_types.ResponseType.PKCE:
return self.auth_callback_code(parameters, groups_manager, request)
# case 'token':
case oauth2_types.ResponseType.TOKEN:
return self.auth_callback_token(parameters, groups_manager, request)
# case 'openid+code':
case oauth2_types.ResponseType.OPENID_CODE:
return self.auth_callback_openid_code(parameters, groups_manager, request)
# case 'openid+token_id':
case oauth2_types.ResponseType.OPENID_ID_TOKEN:
return self.auth_callback_openid_id_token(parameters, groups_manager, request)
def logout(
self,
request: 'types.requests.ExtendedHttpRequest',
username: str,
) -> types.auth.AuthenticationResult:
if self.logout_url.value.strip() == '' or (token := self.retrieve_token(request)) == '':
return types.auth.SUCCESS_AUTH
return types.auth.AuthenticationResult(
types.auth.AuthenticationState.SUCCESS,
url=self.logout_url.value.replace('{token}', urllib.parse.quote(token)),
)
def get_javascript(self, request: 'HttpRequest') -> typing.Optional[str]:
"""
We will here compose the azure request and send it via http-redirect
"""
return f'window.location="{self.get_login_url()}";'
def get_groups(self, username: str, groups_manager: 'auths.GroupsManager') -> None:
data = self.storage.read_pickled(username)
if not data:
return
groups_manager.validate(data[1])
def get_real_name(self, username: str) -> str:
data = self.storage.read_pickled(username)
if not data:
return username
return data[0]
# own methods
def get_public_keys(self) -> list[typing.Any]: # In fact, any of the PublicKey types
# Get certificates in self.publicKey.value, encoded as PEM
# Return a list of certificates in DER format
return [cert.public_key() for cert in fields.get_certificates_from_field(self.public_key)]
def _code_verifier_and_challenge(self) -> tuple[str, str]:
def code_verifier_and_challenge(self) -> tuple[str, str]:
"""Generate a code verifier and a code challenge for PKCE
Returns:
tuple[str, str]: Code verifier and code challenge
"""
code_verifier = ''.join(secrets.choice(PKCE_ALPHABET) for _ in range(128))
code_verifier = ''.join(secrets.choice(oauth2_consts.PKCE_ALPHABET) for _ in range(128))
code_challenge = (
b64decode(hashlib.sha256(code_verifier.encode('ascii')).digest(), altchars=b'-_')
b64encode(hashlib.sha256(code_verifier.encode()).digest(), altchars=b'-_')
.decode()
.rstrip('=') # remove padding
)
return code_verifier, code_challenge
def _get_login_url(self, request: 'HttpRequest') -> str:
def get_login_url(self) -> str:
"""
:type request: django.http.request.HttpRequest
"""
state: str = secrets.token_urlsafe(STATE_LENGTH)
state: str = secrets.token_urlsafe(oauth2_consts.STATE_LENGTH)
response_type = oauth2_types.ResponseType(self.response_type.value)
param_dict = {
param_dict: dict[str, str] = {
'response_type': response_type.for_query,
'client_id': self.client_id.value,
'redirect_uri': self.redirection_endpoint.value,
@ -225,22 +306,23 @@ class OAuth2Authenticator(auths.Authenticator):
# Code or token flow
# Simply store state, no code_verifier, store "none" as code_verifier to later restore it
self.cache.put(state, 'none', 3600)
case oauth2_types.ResponseType.OPENID_CODE | oauth2_types.ResponseType.OPENID_TOKEN_ID:
case oauth2_types.ResponseType.OPENID_CODE | oauth2_types.ResponseType.OPENID_ID_TOKEN:
# OpenID flow
nonce = secrets.token_urlsafe(STATE_LENGTH)
nonce = secrets.token_urlsafe(oauth2_consts.STATE_LENGTH)
self.cache.put(state, nonce, 3600) # Store nonce
# Fix scope
param_dict['scope'] = 'openid ' + param_dict['scope']
# Fix scope to ensure openid is present
if 'openid' not in param_dict['scope']:
param_dict['scope'] = 'openid ' + param_dict['scope']
# Append nonce
param_dict['nonce'] = nonce
# Add response_mode
param_dict['response_mode'] = 'form_post' # ['query', 'fragment', 'form_post']
case oauth2_types.ResponseType.PKCE:
# PKCE flow
codeVerifier, codeChallenge = self._code_verifier_and_challenge()
param_dict['code_challenge'] = codeChallenge
code_verifier, code_challenge = self.code_verifier_and_challenge()
param_dict['code_challenge'] = code_challenge
param_dict['code_challenge_method'] = 'S256'
self.cache.put(state, codeVerifier, 3600)
self.cache.put(state, code_verifier, 3600)
# Nonce only is used
if False:
@ -253,7 +335,7 @@ class OAuth2Authenticator(auths.Authenticator):
return self.authorization_endpoint.value + '?' + params
def _request_token(self, code: str, code_verifier: typing.Optional[str] = None) -> 'oauth2_types.TokenInfo':
def request_token(self, code: str, code_verifier: typing.Optional[str] = None) -> 'oauth2_types.TokenInfo':
"""Request a token from the token endpoint using the code received from the authorization endpoint
Args:
@ -282,7 +364,7 @@ class OAuth2Authenticator(auths.Authenticator):
return oauth2_types.TokenInfo.from_dict(response.json())
def _request_info(self, token: 'oauth2_types.TokenInfo') -> dict[str, typing.Any]:
def request_userinfo(self, token: 'oauth2_types.TokenInfo') -> dict[str, typing.Any]:
"""Request user info from the info endpoint using the token received from the token endpoint
If the token endpoint returns the user info, this method will not be used
@ -297,7 +379,7 @@ class OAuth2Authenticator(auths.Authenticator):
if self.info_endpoint.value.strip() == '':
if not token.info:
raise Exception('No user info received')
raise Exception('No user info endpoint and token does not contain user info')
userinfo = token.info
else:
# Get user info
@ -314,13 +396,13 @@ class OAuth2Authenticator(auths.Authenticator):
userinfo = req.json()
return userinfo
def _store_token_on_session(self, request: 'HttpRequest', token: str) -> None:
def save_token(self, request: 'HttpRequest', token: str) -> None:
request.session['oauth2_token'] = token
def _retrieve_token_from_session(self, request: 'HttpRequest') -> str:
def retrieve_token(self, request: 'HttpRequest') -> str:
return request.session.get('oauth2_token', '')
def _process_userinfo(
def process_userinfo(
self, userinfo: collections.abc.Mapping[str, typing.Any], gm: 'auths.GroupsManager'
) -> types.auth.AuthenticationResult:
# After this point, we don't mind about the token, we only need to authenticate user
@ -347,7 +429,7 @@ class OAuth2Authenticator(auths.Authenticator):
# and if we are here, the user is authenticated, so we can return SUCCESS_AUTH
return types.auth.AuthenticationResult(types.auth.AuthenticationState.SUCCESS, username=username)
def _process_token_open_id(
def process_token_open_id(
self, token_id: str, nonce: str, gm: 'auths.GroupsManager'
) -> types.auth.AuthenticationResult:
# Get token headers, to extract algorithm
@ -369,7 +451,7 @@ class OAuth2Authenticator(auths.Authenticator):
# All is fine, get user & look for groups
# Process attributes from payload
return self._process_userinfo(payload, gm)
return self.process_userinfo(payload, gm)
except (jwt.InvalidTokenError, IndexError):
# logger.debug('Data was invalid: %s', e)
pass
@ -382,89 +464,6 @@ class OAuth2Authenticator(auths.Authenticator):
return types.auth.FAILED_AUTH
def initialize(self, values: typing.Optional[dict[str, typing.Any]]) -> None:
if not values:
return
if ' ' in values['name']:
raise exceptions.ui.ValidationError(
gettext('This kind of Authenticator does not support white spaces on field NAME')
)
auth_utils.validate_regex_field(self.username_attr)
auth_utils.validate_regex_field(self.username_attr)
if self.response_type.value in (oauth2_types.ResponseType.CODE, oauth2_types.ResponseType.PKCE, oauth2_types.ResponseType.OPENID_CODE):
if self.common_groups.value.strip() == '':
raise exceptions.ui.ValidationError(
gettext('Common groups is required for "code" response types')
)
if self.token_endpoint.value.strip() == '':
raise exceptions.ui.ValidationError(
gettext('Token endpoint is required for "code" response types')
)
# infoEndpoint will not be necesary if the response of tokenEndpoint contains the user info
if self.response_type.value == 'openid+token_id':
# Ensure we have a public key
if self.public_key.value.strip() == '':
raise exceptions.ui.ValidationError(
gettext('Public key is required for "openid+token_id" response type')
)
if self.redirection_endpoint.value.strip() == '' and self.db_obj() and '_request' in values:
request: 'HttpRequest' = values['_request']
self.redirection_endpoint.value = request.build_absolute_uri(self.callback_url())
def auth_callback(
self,
parameters: 'types.auth.AuthCallbackParams',
groups_manager: 'auths.GroupsManager',
request: 'types.requests.ExtendedHttpRequest',
) -> types.auth.AuthenticationResult:
match self.response_type.value:
case 'code' | 'pkce':
return self.auth_callback_code(parameters, groups_manager, request)
case 'token':
return self.auth_callback_token(parameters, groups_manager, request)
case 'openid+code':
return self.auth_callback_openid_code(parameters, groups_manager, request)
case 'openid+token_id':
return self.auth_callback_openid_id_token(parameters, groups_manager, request)
case _:
raise Exception('Invalid response type')
def logout(
self,
request: 'types.requests.ExtendedHttpRequest', # pylint: disable=unused-argument
username: str, # pylint: disable=unused-argument
) -> types.auth.AuthenticationResult:
if self.logout_url.value.strip() == '' or (token := self._retrieve_token_from_session(request)) == '':
return types.auth.SUCCESS_AUTH
return types.auth.AuthenticationResult(
types.auth.AuthenticationState.SUCCESS,
url=self.logout_url.value.replace('{token}', urllib.parse.quote(token)),
)
def get_javascript(self, request: 'HttpRequest') -> typing.Optional[str]:
"""
We will here compose the azure request and send it via http-redirect
"""
return f'window.location="{self._get_login_url(request)}";'
def get_groups(self, username: str, groups_manager: 'auths.GroupsManager') -> None:
data = self.storage.read_pickled(username)
if not data:
return
groups_manager.validate(data[1])
def get_real_name(self, username: str) -> str:
data = self.storage.read_pickled(username)
if not data:
return username
return data[0]
def auth_callback_code(
self,
parameters: 'types.auth.AuthCallbackParams',
@ -473,7 +472,7 @@ class OAuth2Authenticator(auths.Authenticator):
) -> types.auth.AuthenticationResult:
"""Process the callback for code authorization flow"""
state = parameters.get_params.get('state', '')
# Remove state from cache
# Get and remove state from cache
code_verifier = self.cache.pop(state)
if not state or not code_verifier:
@ -490,10 +489,10 @@ class OAuth2Authenticator(auths.Authenticator):
if code_verifier == 'none':
code_verifier = None
token_info = self._request_token(code, code_verifier)
token_info = self.request_token(code, code_verifier)
# Store for later use
self._store_token_on_session(request, token_info.access_token)
return self._process_userinfo(self._request_info(token_info), gm)
self.save_token(request, token_info.access_token)
return self.process_userinfo(self.request_userinfo(token_info), gm)
def auth_callback_token(
self,
@ -513,8 +512,8 @@ class OAuth2Authenticator(auths.Authenticator):
# Get the token, token_type, expires
token = oauth2_types.TokenInfo.from_dict(parameters.get_params)
# Store for later use
self._store_token_on_session(request, token.access_token)
return self._process_userinfo(self._request_info(token), gm)
self.save_token(request, token.access_token)
return self.process_userinfo(self.request_userinfo(token), gm)
def auth_callback_openid_code(
self,
@ -538,15 +537,15 @@ class OAuth2Authenticator(auths.Authenticator):
return types.auth.FAILED_AUTH
# Get the token, token_type, expires
token = self._request_token(code)
token = self.request_token(code)
if not token.id_token:
logger.error('No id_token received on OAuth2 callback')
return types.auth.FAILED_AUTH
# Store for later use
self._store_token_on_session(request, token.access_token)
return self._process_token_open_id(token.id_token, nonce, gm)
self.save_token(request, token.access_token)
return self.process_token_open_id(token.id_token, nonce, gm)
def auth_callback_openid_id_token(
self,
@ -570,5 +569,5 @@ class OAuth2Authenticator(auths.Authenticator):
return types.auth.FAILED_AUTH
# Store for later use
self._store_token_on_session(request, id_token)
return self._process_token_open_id(id_token, nonce, gm)
self.save_token(request, id_token)
return self.process_token_open_id(id_token, nonce, gm)

View File

@ -0,0 +1,34 @@
#
# Copyright (c) 2024 Virtual Cable S.L.U.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of Virtual Cable S.L.U. nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import typing
import string
# Alphabet used for PKCE
PKCE_ALPHABET: typing.Final[str] = string.ascii_letters + string.digits + '-._~'
# Length of the State parameter
STATE_LENGTH: typing.Final[int] = 24

View File

@ -66,7 +66,7 @@ class ResponseType(enum.StrEnum):
CODE = 'code'
PKCE = 'pkce'
TOKEN = 'token'
OPENID_TOKEN_ID = 'openid+token_id'
OPENID_ID_TOKEN = 'openid+token_id'
OPENID_CODE = 'openid+code'
@property
@ -78,7 +78,7 @@ class ResponseType(enum.StrEnum):
return 'code'
case ResponseType.TOKEN:
return 'token'
case ResponseType.OPENID_TOKEN_ID:
case ResponseType.OPENID_ID_TOKEN:
return 'id_token'
case ResponseType.OPENID_CODE:
return 'code'
@ -92,7 +92,7 @@ class ResponseType(enum.StrEnum):
return _('PKCE (authorization code flow with PKCE)')
case ResponseType.TOKEN:
return _('Token (implicit flow)')
case ResponseType.OPENID_TOKEN_ID:
case ResponseType.OPENID_ID_TOKEN:
return _('OpenID Connect Token (implicit flow with OpenID Connect)')
case ResponseType.OPENID_CODE:
return _('OpenID Connect Code (authorization code flow with OpenID Connect)')