mirror of
https://github.com/dkmstr/openuds.git
synced 2025-01-11 05:17:55 +03:00
Refactor OAuth2Authenticator to handle logout URL and store token on session
This commit is contained in:
parent
fcac2d4d23
commit
8e51026336
@ -84,6 +84,7 @@ class TokenInfo:
|
|||||||
id_token=dct.get('id_token', None),
|
id_token=dct.get('id_token', None),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class OAuth2Authenticator(auths.Authenticator):
|
class OAuth2Authenticator(auths.Authenticator):
|
||||||
"""
|
"""
|
||||||
This class represents an OAuth2 Authenticator.
|
This class represents an OAuth2 Authenticator.
|
||||||
@ -191,7 +192,15 @@ class OAuth2Authenticator(auths.Authenticator):
|
|||||||
required=False,
|
required=False,
|
||||||
tab=types.ui.Tab.ADVANCED,
|
tab=types.ui.Tab.ADVANCED,
|
||||||
)
|
)
|
||||||
|
logout_url = gui.TextField(
|
||||||
|
length=256,
|
||||||
|
label=_('Logout URL'),
|
||||||
|
order=95,
|
||||||
|
tooltip=_('URL to logout from OAuth2 provider. Allows {token} placeholder.'),
|
||||||
|
required=False,
|
||||||
|
tab=types.ui.Tab.ADVANCED,
|
||||||
|
)
|
||||||
|
|
||||||
username_attr = fields.username_attr_field(order=100)
|
username_attr = fields.username_attr_field(order=100)
|
||||||
groupname_attr = fields.groupname_attr_field(order=101)
|
groupname_attr = fields.groupname_attr_field(order=101)
|
||||||
realname_attr = fields.realname_attr_field(order=102)
|
realname_attr = fields.realname_attr_field(order=102)
|
||||||
@ -346,25 +355,31 @@ class OAuth2Authenticator(auths.Authenticator):
|
|||||||
userInfo = req.json()
|
userInfo = req.json()
|
||||||
return userInfo
|
return userInfo
|
||||||
|
|
||||||
def _process_token(
|
def _store_token_on_session(self, request: 'HttpRequest', token: str) -> None:
|
||||||
self, userInfo: collections.abc.Mapping[str, typing.Any], gm: 'auths.GroupsManager'
|
request.session['oauth2_token'] = token
|
||||||
|
|
||||||
|
def _retrieve_token_from_session(self, request: 'HttpRequest') -> str:
|
||||||
|
return request.session.get('oauth2_token', '')
|
||||||
|
|
||||||
|
def _process_userinfo(
|
||||||
|
self, userinfo: collections.abc.Mapping[str, typing.Any], gm: 'auths.GroupsManager'
|
||||||
) -> types.auth.AuthenticationResult:
|
) -> types.auth.AuthenticationResult:
|
||||||
# After this point, we don't mind about the token, we only need to authenticate user
|
# After this point, we don't mind about the token, we only need to authenticate user
|
||||||
# and get some basic info from it
|
# and get some basic info from it
|
||||||
|
|
||||||
username = ''.join(auth_utils.process_regex_field(self.username_attr.value, userInfo)).replace(' ', '_')
|
username = ''.join(auth_utils.process_regex_field(self.username_attr.value, userinfo)).replace(' ', '_')
|
||||||
if len(username) == 0:
|
if len(username) == 0:
|
||||||
raise Exception('No username received')
|
raise Exception('No username received')
|
||||||
|
|
||||||
realName = ''.join(auth_utils.process_regex_field(self.realname_attr.value, userInfo))
|
realname = ''.join(auth_utils.process_regex_field(self.realname_attr.value, userinfo))
|
||||||
|
|
||||||
# Get groups
|
# Get groups
|
||||||
groups = auth_utils.process_regex_field(self.groupname_attr.value, userInfo)
|
groups = auth_utils.process_regex_field(self.groupname_attr.value, userinfo)
|
||||||
# Append common groups
|
# Append common groups
|
||||||
groups.extend(self.common_groups.value.split(','))
|
groups.extend(self.common_groups.value.split(','))
|
||||||
|
|
||||||
# store groups for this username at storage, so we can check it at a later stage
|
# store groups for this username at storage, so we can check it at a later stage
|
||||||
self.storage.save_pickled(username, [realName, groups])
|
self.storage.save_pickled(username, [realname, groups])
|
||||||
|
|
||||||
# Validate common groups
|
# Validate common groups
|
||||||
gm.validate(groups)
|
gm.validate(groups)
|
||||||
@ -395,7 +410,7 @@ class OAuth2Authenticator(auths.Authenticator):
|
|||||||
# All is fine, get user & look for groups
|
# All is fine, get user & look for groups
|
||||||
|
|
||||||
# Process attributes from payload
|
# Process attributes from payload
|
||||||
return self._process_token(payload, gm)
|
return self._process_userinfo(payload, gm)
|
||||||
except (jwt.InvalidTokenError, IndexError):
|
except (jwt.InvalidTokenError, IndexError):
|
||||||
# logger.debug('Data was invalid: %s', e)
|
# logger.debug('Data was invalid: %s', e)
|
||||||
pass
|
pass
|
||||||
@ -456,17 +471,22 @@ class OAuth2Authenticator(auths.Authenticator):
|
|||||||
case 'openid+code':
|
case 'openid+code':
|
||||||
return self.auth_callback_openid_code(parameters, groups_manager, request)
|
return self.auth_callback_openid_code(parameters, groups_manager, request)
|
||||||
case 'openid+token_id':
|
case 'openid+token_id':
|
||||||
return self.authcallback_openid_id_token(parameters, groups_manager, request)
|
return self.auth_callback_openid_id_token(parameters, groups_manager, request)
|
||||||
case _:
|
case _:
|
||||||
raise Exception('Invalid response type')
|
raise Exception('Invalid response type')
|
||||||
return auths.SUCCESS_AUTH
|
|
||||||
|
|
||||||
def logout(
|
def logout(
|
||||||
self,
|
self,
|
||||||
request: 'types.requests.ExtendedHttpRequest', # pylint: disable=unused-argument
|
request: 'types.requests.ExtendedHttpRequest', # pylint: disable=unused-argument
|
||||||
username: str, # pylint: disable=unused-argument
|
username: str, # pylint: disable=unused-argument
|
||||||
) -> types.auth.AuthenticationResult:
|
) -> types.auth.AuthenticationResult:
|
||||||
return types.auth.SUCCESS_AUTH
|
if self.logout_url.value.strip() == '' or (token := self._retrieve_token_from_session(request)) == '':
|
||||||
|
return types.auth.SUCCESS_AUTH
|
||||||
|
|
||||||
|
return types.auth.AuthenticationResult(
|
||||||
|
types.auth.AuthenticationState.SUCCESS,
|
||||||
|
url=self.logout_url.value.replace('{token}', urllib.parse.quote(token)),
|
||||||
|
)
|
||||||
|
|
||||||
def get_javascript(self, request: 'HttpRequest') -> typing.Optional[str]:
|
def get_javascript(self, request: 'HttpRequest') -> typing.Optional[str]:
|
||||||
"""
|
"""
|
||||||
@ -495,8 +515,7 @@ class OAuth2Authenticator(auths.Authenticator):
|
|||||||
"""Process the callback for code authorization flow"""
|
"""Process the callback for code authorization flow"""
|
||||||
state = parameters.get_params.get('state', '')
|
state = parameters.get_params.get('state', '')
|
||||||
# Remove state from cache
|
# Remove state from cache
|
||||||
code_verifier = self.cache.get(state)
|
code_verifier = self.cache.pop(state)
|
||||||
self.cache.remove(state)
|
|
||||||
|
|
||||||
if not state or not code_verifier:
|
if not state or not code_verifier:
|
||||||
logger.error('Invalid state received on OAuth2 callback')
|
logger.error('Invalid state received on OAuth2 callback')
|
||||||
@ -512,8 +531,10 @@ class OAuth2Authenticator(auths.Authenticator):
|
|||||||
if code_verifier == 'none':
|
if code_verifier == 'none':
|
||||||
code_verifier = None
|
code_verifier = None
|
||||||
|
|
||||||
token = self._request_token(code, code_verifier)
|
token_info = self._request_token(code, code_verifier)
|
||||||
return self._process_token(self._request_info(token), gm)
|
# Store for later use
|
||||||
|
self._store_token_on_session(request, token_info.access_token)
|
||||||
|
return self._process_userinfo(self._request_info(token_info), gm)
|
||||||
|
|
||||||
def auth_callback_token(
|
def auth_callback_token(
|
||||||
self,
|
self,
|
||||||
@ -532,7 +553,9 @@ class OAuth2Authenticator(auths.Authenticator):
|
|||||||
|
|
||||||
# Get the token, token_type, expires
|
# Get the token, token_type, expires
|
||||||
token = TokenInfo.from_dict(parameters.get_params)
|
token = TokenInfo.from_dict(parameters.get_params)
|
||||||
return self._process_token(self._request_info(token), gm)
|
# Store for later use
|
||||||
|
self._store_token_on_session(request, token.access_token)
|
||||||
|
return self._process_userinfo(self._request_info(token), gm)
|
||||||
|
|
||||||
def auth_callback_openid_code(
|
def auth_callback_openid_code(
|
||||||
self,
|
self,
|
||||||
@ -562,9 +585,11 @@ class OAuth2Authenticator(auths.Authenticator):
|
|||||||
logger.error('No id_token received on OAuth2 callback')
|
logger.error('No id_token received on OAuth2 callback')
|
||||||
return types.auth.FAILED_AUTH
|
return types.auth.FAILED_AUTH
|
||||||
|
|
||||||
|
# Store for later use
|
||||||
|
self._store_token_on_session(request, token.access_token)
|
||||||
return self._process_token_open_id(token.id_token, nonce, gm)
|
return self._process_token_open_id(token.id_token, nonce, gm)
|
||||||
|
|
||||||
def authcallback_openid_id_token(
|
def auth_callback_openid_id_token(
|
||||||
self,
|
self,
|
||||||
parameters: 'types.auth.AuthCallbackParams',
|
parameters: 'types.auth.AuthCallbackParams',
|
||||||
gm: 'auths.GroupsManager',
|
gm: 'auths.GroupsManager',
|
||||||
@ -585,4 +610,6 @@ class OAuth2Authenticator(auths.Authenticator):
|
|||||||
logger.error('Invalid id_token received on OAuth2 callback')
|
logger.error('Invalid id_token received on OAuth2 callback')
|
||||||
return types.auth.FAILED_AUTH
|
return types.auth.FAILED_AUTH
|
||||||
|
|
||||||
|
# Store for later use
|
||||||
|
self._store_token_on_session(request, id_token)
|
||||||
return self._process_token_open_id(id_token, nonce, gm)
|
return self._process_token_open_id(id_token, nonce, gm)
|
||||||
|
Loading…
Reference in New Issue
Block a user