1
0
mirror of https://github.com/dkmstr/openuds.git synced 2024-12-22 13:34:04 +03:00

Refactor OAuth2Authenticator to handle logout URL and store token on session

This commit is contained in:
Adolfo Gómez García 2024-09-18 22:09:51 +02:00
parent fcac2d4d23
commit 8e51026336
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23

View File

@ -84,6 +84,7 @@ class TokenInfo:
id_token=dct.get('id_token', None),
)
class OAuth2Authenticator(auths.Authenticator):
"""
This class represents an OAuth2 Authenticator.
@ -191,7 +192,15 @@ class OAuth2Authenticator(auths.Authenticator):
required=False,
tab=types.ui.Tab.ADVANCED,
)
logout_url = gui.TextField(
length=256,
label=_('Logout URL'),
order=95,
tooltip=_('URL to logout from OAuth2 provider. Allows {token} placeholder.'),
required=False,
tab=types.ui.Tab.ADVANCED,
)
username_attr = fields.username_attr_field(order=100)
groupname_attr = fields.groupname_attr_field(order=101)
realname_attr = fields.realname_attr_field(order=102)
@ -346,25 +355,31 @@ class OAuth2Authenticator(auths.Authenticator):
userInfo = req.json()
return userInfo
def _process_token(
self, userInfo: collections.abc.Mapping[str, typing.Any], gm: 'auths.GroupsManager'
def _store_token_on_session(self, request: 'HttpRequest', token: str) -> None:
request.session['oauth2_token'] = token
def _retrieve_token_from_session(self, request: 'HttpRequest') -> str:
return request.session.get('oauth2_token', '')
def _process_userinfo(
self, userinfo: collections.abc.Mapping[str, typing.Any], gm: 'auths.GroupsManager'
) -> types.auth.AuthenticationResult:
# After this point, we don't mind about the token, we only need to authenticate user
# and get some basic info from it
username = ''.join(auth_utils.process_regex_field(self.username_attr.value, userInfo)).replace(' ', '_')
username = ''.join(auth_utils.process_regex_field(self.username_attr.value, userinfo)).replace(' ', '_')
if len(username) == 0:
raise Exception('No username received')
realName = ''.join(auth_utils.process_regex_field(self.realname_attr.value, userInfo))
realname = ''.join(auth_utils.process_regex_field(self.realname_attr.value, userinfo))
# Get groups
groups = auth_utils.process_regex_field(self.groupname_attr.value, userInfo)
groups = auth_utils.process_regex_field(self.groupname_attr.value, userinfo)
# Append common groups
groups.extend(self.common_groups.value.split(','))
# store groups for this username at storage, so we can check it at a later stage
self.storage.save_pickled(username, [realName, groups])
self.storage.save_pickled(username, [realname, groups])
# Validate common groups
gm.validate(groups)
@ -395,7 +410,7 @@ class OAuth2Authenticator(auths.Authenticator):
# All is fine, get user & look for groups
# Process attributes from payload
return self._process_token(payload, gm)
return self._process_userinfo(payload, gm)
except (jwt.InvalidTokenError, IndexError):
# logger.debug('Data was invalid: %s', e)
pass
@ -456,17 +471,22 @@ class OAuth2Authenticator(auths.Authenticator):
case 'openid+code':
return self.auth_callback_openid_code(parameters, groups_manager, request)
case 'openid+token_id':
return self.authcallback_openid_id_token(parameters, groups_manager, request)
return self.auth_callback_openid_id_token(parameters, groups_manager, request)
case _:
raise Exception('Invalid response type')
return auths.SUCCESS_AUTH
def logout(
self,
request: 'types.requests.ExtendedHttpRequest', # pylint: disable=unused-argument
username: str, # pylint: disable=unused-argument
) -> types.auth.AuthenticationResult:
return types.auth.SUCCESS_AUTH
if self.logout_url.value.strip() == '' or (token := self._retrieve_token_from_session(request)) == '':
return types.auth.SUCCESS_AUTH
return types.auth.AuthenticationResult(
types.auth.AuthenticationState.SUCCESS,
url=self.logout_url.value.replace('{token}', urllib.parse.quote(token)),
)
def get_javascript(self, request: 'HttpRequest') -> typing.Optional[str]:
"""
@ -495,8 +515,7 @@ class OAuth2Authenticator(auths.Authenticator):
"""Process the callback for code authorization flow"""
state = parameters.get_params.get('state', '')
# Remove state from cache
code_verifier = self.cache.get(state)
self.cache.remove(state)
code_verifier = self.cache.pop(state)
if not state or not code_verifier:
logger.error('Invalid state received on OAuth2 callback')
@ -512,8 +531,10 @@ class OAuth2Authenticator(auths.Authenticator):
if code_verifier == 'none':
code_verifier = None
token = self._request_token(code, code_verifier)
return self._process_token(self._request_info(token), gm)
token_info = self._request_token(code, code_verifier)
# Store for later use
self._store_token_on_session(request, token_info.access_token)
return self._process_userinfo(self._request_info(token_info), gm)
def auth_callback_token(
self,
@ -532,7 +553,9 @@ class OAuth2Authenticator(auths.Authenticator):
# Get the token, token_type, expires
token = TokenInfo.from_dict(parameters.get_params)
return self._process_token(self._request_info(token), gm)
# Store for later use
self._store_token_on_session(request, token.access_token)
return self._process_userinfo(self._request_info(token), gm)
def auth_callback_openid_code(
self,
@ -562,9 +585,11 @@ class OAuth2Authenticator(auths.Authenticator):
logger.error('No id_token received on OAuth2 callback')
return types.auth.FAILED_AUTH
# Store for later use
self._store_token_on_session(request, token.access_token)
return self._process_token_open_id(token.id_token, nonce, gm)
def authcallback_openid_id_token(
def auth_callback_openid_id_token(
self,
parameters: 'types.auth.AuthCallbackParams',
gm: 'auths.GroupsManager',
@ -585,4 +610,6 @@ class OAuth2Authenticator(auths.Authenticator):
logger.error('Invalid id_token received on OAuth2 callback')
return types.auth.FAILED_AUTH
# Store for later use
self._store_token_on_session(request, id_token)
return self._process_token_open_id(id_token, nonce, gm)