1
0
mirror of https://github.com/dkmstr/openuds.git synced 2025-01-11 05:17:55 +03:00

Refactor OAuth2Authenticator to handle logout URL and store token on session

This commit is contained in:
Adolfo Gómez García 2024-09-18 22:09:51 +02:00
parent fcac2d4d23
commit 8e51026336
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23

View File

@ -84,6 +84,7 @@ class TokenInfo:
id_token=dct.get('id_token', None), id_token=dct.get('id_token', None),
) )
class OAuth2Authenticator(auths.Authenticator): class OAuth2Authenticator(auths.Authenticator):
""" """
This class represents an OAuth2 Authenticator. This class represents an OAuth2 Authenticator.
@ -191,6 +192,14 @@ class OAuth2Authenticator(auths.Authenticator):
required=False, required=False,
tab=types.ui.Tab.ADVANCED, tab=types.ui.Tab.ADVANCED,
) )
logout_url = gui.TextField(
length=256,
label=_('Logout URL'),
order=95,
tooltip=_('URL to logout from OAuth2 provider. Allows {token} placeholder.'),
required=False,
tab=types.ui.Tab.ADVANCED,
)
username_attr = fields.username_attr_field(order=100) username_attr = fields.username_attr_field(order=100)
groupname_attr = fields.groupname_attr_field(order=101) groupname_attr = fields.groupname_attr_field(order=101)
@ -346,25 +355,31 @@ class OAuth2Authenticator(auths.Authenticator):
userInfo = req.json() userInfo = req.json()
return userInfo return userInfo
def _process_token( def _store_token_on_session(self, request: 'HttpRequest', token: str) -> None:
self, userInfo: collections.abc.Mapping[str, typing.Any], gm: 'auths.GroupsManager' request.session['oauth2_token'] = token
def _retrieve_token_from_session(self, request: 'HttpRequest') -> str:
return request.session.get('oauth2_token', '')
def _process_userinfo(
self, userinfo: collections.abc.Mapping[str, typing.Any], gm: 'auths.GroupsManager'
) -> types.auth.AuthenticationResult: ) -> types.auth.AuthenticationResult:
# After this point, we don't mind about the token, we only need to authenticate user # After this point, we don't mind about the token, we only need to authenticate user
# and get some basic info from it # and get some basic info from it
username = ''.join(auth_utils.process_regex_field(self.username_attr.value, userInfo)).replace(' ', '_') username = ''.join(auth_utils.process_regex_field(self.username_attr.value, userinfo)).replace(' ', '_')
if len(username) == 0: if len(username) == 0:
raise Exception('No username received') raise Exception('No username received')
realName = ''.join(auth_utils.process_regex_field(self.realname_attr.value, userInfo)) realname = ''.join(auth_utils.process_regex_field(self.realname_attr.value, userinfo))
# Get groups # Get groups
groups = auth_utils.process_regex_field(self.groupname_attr.value, userInfo) groups = auth_utils.process_regex_field(self.groupname_attr.value, userinfo)
# Append common groups # Append common groups
groups.extend(self.common_groups.value.split(',')) groups.extend(self.common_groups.value.split(','))
# store groups for this username at storage, so we can check it at a later stage # store groups for this username at storage, so we can check it at a later stage
self.storage.save_pickled(username, [realName, groups]) self.storage.save_pickled(username, [realname, groups])
# Validate common groups # Validate common groups
gm.validate(groups) gm.validate(groups)
@ -395,7 +410,7 @@ class OAuth2Authenticator(auths.Authenticator):
# All is fine, get user & look for groups # All is fine, get user & look for groups
# Process attributes from payload # Process attributes from payload
return self._process_token(payload, gm) return self._process_userinfo(payload, gm)
except (jwt.InvalidTokenError, IndexError): except (jwt.InvalidTokenError, IndexError):
# logger.debug('Data was invalid: %s', e) # logger.debug('Data was invalid: %s', e)
pass pass
@ -456,18 +471,23 @@ class OAuth2Authenticator(auths.Authenticator):
case 'openid+code': case 'openid+code':
return self.auth_callback_openid_code(parameters, groups_manager, request) return self.auth_callback_openid_code(parameters, groups_manager, request)
case 'openid+token_id': case 'openid+token_id':
return self.authcallback_openid_id_token(parameters, groups_manager, request) return self.auth_callback_openid_id_token(parameters, groups_manager, request)
case _: case _:
raise Exception('Invalid response type') raise Exception('Invalid response type')
return auths.SUCCESS_AUTH
def logout( def logout(
self, self,
request: 'types.requests.ExtendedHttpRequest', # pylint: disable=unused-argument request: 'types.requests.ExtendedHttpRequest', # pylint: disable=unused-argument
username: str, # pylint: disable=unused-argument username: str, # pylint: disable=unused-argument
) -> types.auth.AuthenticationResult: ) -> types.auth.AuthenticationResult:
if self.logout_url.value.strip() == '' or (token := self._retrieve_token_from_session(request)) == '':
return types.auth.SUCCESS_AUTH return types.auth.SUCCESS_AUTH
return types.auth.AuthenticationResult(
types.auth.AuthenticationState.SUCCESS,
url=self.logout_url.value.replace('{token}', urllib.parse.quote(token)),
)
def get_javascript(self, request: 'HttpRequest') -> typing.Optional[str]: def get_javascript(self, request: 'HttpRequest') -> typing.Optional[str]:
""" """
We will here compose the azure request and send it via http-redirect We will here compose the azure request and send it via http-redirect
@ -495,8 +515,7 @@ class OAuth2Authenticator(auths.Authenticator):
"""Process the callback for code authorization flow""" """Process the callback for code authorization flow"""
state = parameters.get_params.get('state', '') state = parameters.get_params.get('state', '')
# Remove state from cache # Remove state from cache
code_verifier = self.cache.get(state) code_verifier = self.cache.pop(state)
self.cache.remove(state)
if not state or not code_verifier: if not state or not code_verifier:
logger.error('Invalid state received on OAuth2 callback') logger.error('Invalid state received on OAuth2 callback')
@ -512,8 +531,10 @@ class OAuth2Authenticator(auths.Authenticator):
if code_verifier == 'none': if code_verifier == 'none':
code_verifier = None code_verifier = None
token = self._request_token(code, code_verifier) token_info = self._request_token(code, code_verifier)
return self._process_token(self._request_info(token), gm) # Store for later use
self._store_token_on_session(request, token_info.access_token)
return self._process_userinfo(self._request_info(token_info), gm)
def auth_callback_token( def auth_callback_token(
self, self,
@ -532,7 +553,9 @@ class OAuth2Authenticator(auths.Authenticator):
# Get the token, token_type, expires # Get the token, token_type, expires
token = TokenInfo.from_dict(parameters.get_params) token = TokenInfo.from_dict(parameters.get_params)
return self._process_token(self._request_info(token), gm) # Store for later use
self._store_token_on_session(request, token.access_token)
return self._process_userinfo(self._request_info(token), gm)
def auth_callback_openid_code( def auth_callback_openid_code(
self, self,
@ -562,9 +585,11 @@ class OAuth2Authenticator(auths.Authenticator):
logger.error('No id_token received on OAuth2 callback') logger.error('No id_token received on OAuth2 callback')
return types.auth.FAILED_AUTH return types.auth.FAILED_AUTH
# Store for later use
self._store_token_on_session(request, token.access_token)
return self._process_token_open_id(token.id_token, nonce, gm) return self._process_token_open_id(token.id_token, nonce, gm)
def authcallback_openid_id_token( def auth_callback_openid_id_token(
self, self,
parameters: 'types.auth.AuthCallbackParams', parameters: 'types.auth.AuthCallbackParams',
gm: 'auths.GroupsManager', gm: 'auths.GroupsManager',
@ -585,4 +610,6 @@ class OAuth2Authenticator(auths.Authenticator):
logger.error('Invalid id_token received on OAuth2 callback') logger.error('Invalid id_token received on OAuth2 callback')
return types.auth.FAILED_AUTH return types.auth.FAILED_AUTH
# Store for later use
self._store_token_on_session(request, id_token)
return self._process_token_open_id(id_token, nonce, gm) return self._process_token_open_id(id_token, nonce, gm)