1
0
mirror of https://github.com/dkmstr/openuds.git synced 2025-03-12 04:58:34 +03:00

Refactor OAuth2Test to include code challenge verification and state value retrieval

Refactor the OAuth2Authenticator class in the authenticator.py file to use the cache.pop() method instead of cache.get() and cache.remove()
This commit is contained in:
Adolfo Gómez García 2024-09-20 00:14:38 +02:00
parent 55f3a697ca
commit 27d8432d6d
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
2 changed files with 21 additions and 9 deletions

View File

@ -30,6 +30,8 @@
Author: Adolfo Gómez, dkmaster at dkmon dot com
'''
from urllib.parse import urlparse, parse_qs
from base64 import b64encode
from hashlib import sha256
from unittest import mock
from tests.utils.test import UDSTestCase
@ -104,6 +106,7 @@ class OAuth2Test(UDSTestCase):
self.assertEqual(query['redirect_uri'], [oauth2.redirection_endpoint.value], kind.as_text)
scopes = set(query['scope'][0].split())
code_challenge = ''
if kind == oauth2_types.ResponseType.PKCE:
self.assertEqual(query['code_challenge_method'], ['S256'], kind.as_text)
code_challenge = query['code_challenge'][0]
@ -124,4 +127,16 @@ class OAuth2Test(UDSTestCase):
expected_length = (oauth2_consts.STATE_LENGTH * 8 + 5) // 6
self.assertEqual(len(state), expected_length, kind.as_text)
# oauth2 cache should contain the state
self.assertIsNotNone(oauth2.cache.get(state), kind.as_text)
state_value = oauth2.cache.get(state)
self.assertIsNotNone(state_value, kind.as_text)
# If pkce, we need to check the code_challenge. code_verifier is stored in the state
if kind == oauth2_types.ResponseType.PKCE:
calc_code_challenge = (
b64encode(sha256(state_value.encode()).digest(), altchars=b'-_')
.decode()
.rstrip('=') # remove padding
)
self.assertEqual(calc_code_challenge, code_challenge, kind.as_text)

View File

@ -370,7 +370,7 @@ class OAuth2Authenticator(auths.Authenticator):
If the token endpoint returns the user info, this method will not be used
Args:
token (TokenInfo): Token received from the token endpoint
token (TokenInfo): Token info received from the token endpoint
Returns:
dict[str, typing.Any]: User info received from the info endpoint
@ -502,10 +502,9 @@ class OAuth2Authenticator(auths.Authenticator):
) -> types.auth.AuthenticationResult:
"""Process the callback for PKCE authorization flow"""
state = parameters.get_params.get('state', '')
stateValue = self.cache.get(state)
self.cache.remove(state)
state_value = self.cache.pop(state)
if not state or not stateValue:
if not state or not state_value:
logger.error('Invalid state received on OAuth2 callback')
return types.auth.FAILED_AUTH
@ -523,8 +522,7 @@ class OAuth2Authenticator(auths.Authenticator):
) -> types.auth.AuthenticationResult:
"""Process the callback for OpenID authorization flow"""
state = parameters.post_params.get('state', '')
nonce = self.cache.get(state)
self.cache.remove(state)
nonce = self.cache.pop(state)
if not state or not nonce:
logger.error('Invalid state received on OAuth2 callback')
@ -555,8 +553,7 @@ class OAuth2Authenticator(auths.Authenticator):
) -> types.auth.AuthenticationResult:
"""Process the callback for OpenID authorization flow"""
state = parameters.post_params.get('state', '')
nonce = self.cache.get(state)
self.cache.remove(state)
nonce = self.cache.pop(state)
if not state or not nonce:
logger.error('Invalid state received on OAuth2 callback')