mirror of
https://github.com/dkmstr/openuds.git
synced 2025-03-12 04:58:34 +03:00
Refactor OAuth2Test to include code challenge verification and state value retrieval
Refactor the OAuth2Authenticator class in the authenticator.py file to use the cache.pop() method instead of cache.get() and cache.remove()
This commit is contained in:
parent
55f3a697ca
commit
27d8432d6d
@ -30,6 +30,8 @@
|
||||
Author: Adolfo Gómez, dkmaster at dkmon dot com
|
||||
'''
|
||||
from urllib.parse import urlparse, parse_qs
|
||||
from base64 import b64encode
|
||||
from hashlib import sha256
|
||||
from unittest import mock
|
||||
|
||||
from tests.utils.test import UDSTestCase
|
||||
@ -104,6 +106,7 @@ class OAuth2Test(UDSTestCase):
|
||||
self.assertEqual(query['redirect_uri'], [oauth2.redirection_endpoint.value], kind.as_text)
|
||||
scopes = set(query['scope'][0].split())
|
||||
|
||||
code_challenge = ''
|
||||
if kind == oauth2_types.ResponseType.PKCE:
|
||||
self.assertEqual(query['code_challenge_method'], ['S256'], kind.as_text)
|
||||
code_challenge = query['code_challenge'][0]
|
||||
@ -124,4 +127,16 @@ class OAuth2Test(UDSTestCase):
|
||||
expected_length = (oauth2_consts.STATE_LENGTH * 8 + 5) // 6
|
||||
self.assertEqual(len(state), expected_length, kind.as_text)
|
||||
# oauth2 cache should contain the state
|
||||
self.assertIsNotNone(oauth2.cache.get(state), kind.as_text)
|
||||
state_value = oauth2.cache.get(state)
|
||||
self.assertIsNotNone(state_value, kind.as_text)
|
||||
|
||||
# If pkce, we need to check the code_challenge. code_verifier is stored in the state
|
||||
if kind == oauth2_types.ResponseType.PKCE:
|
||||
calc_code_challenge = (
|
||||
b64encode(sha256(state_value.encode()).digest(), altchars=b'-_')
|
||||
.decode()
|
||||
.rstrip('=') # remove padding
|
||||
)
|
||||
|
||||
self.assertEqual(calc_code_challenge, code_challenge, kind.as_text)
|
||||
|
||||
|
@ -370,7 +370,7 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
If the token endpoint returns the user info, this method will not be used
|
||||
|
||||
Args:
|
||||
token (TokenInfo): Token received from the token endpoint
|
||||
token (TokenInfo): Token info received from the token endpoint
|
||||
|
||||
Returns:
|
||||
dict[str, typing.Any]: User info received from the info endpoint
|
||||
@ -502,10 +502,9 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
) -> types.auth.AuthenticationResult:
|
||||
"""Process the callback for PKCE authorization flow"""
|
||||
state = parameters.get_params.get('state', '')
|
||||
stateValue = self.cache.get(state)
|
||||
self.cache.remove(state)
|
||||
state_value = self.cache.pop(state)
|
||||
|
||||
if not state or not stateValue:
|
||||
if not state or not state_value:
|
||||
logger.error('Invalid state received on OAuth2 callback')
|
||||
return types.auth.FAILED_AUTH
|
||||
|
||||
@ -523,8 +522,7 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
) -> types.auth.AuthenticationResult:
|
||||
"""Process the callback for OpenID authorization flow"""
|
||||
state = parameters.post_params.get('state', '')
|
||||
nonce = self.cache.get(state)
|
||||
self.cache.remove(state)
|
||||
nonce = self.cache.pop(state)
|
||||
|
||||
if not state or not nonce:
|
||||
logger.error('Invalid state received on OAuth2 callback')
|
||||
@ -555,8 +553,7 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
) -> types.auth.AuthenticationResult:
|
||||
"""Process the callback for OpenID authorization flow"""
|
||||
state = parameters.post_params.get('state', '')
|
||||
nonce = self.cache.get(state)
|
||||
self.cache.remove(state)
|
||||
nonce = self.cache.pop(state)
|
||||
|
||||
if not state or not nonce:
|
||||
logger.error('Invalid state received on OAuth2 callback')
|
||||
|
Loading…
x
Reference in New Issue
Block a user