1
0
mirror of https://github.com/dkmstr/openuds.git synced 2024-12-23 17:34:17 +03:00

Basic OAuth2 with code interchange done

This commit is contained in:
Adolfo Gómez García 2023-10-19 19:38:06 +02:00
parent 9419e0f69e
commit ad9755177a
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
2 changed files with 73 additions and 29 deletions

View File

@ -179,24 +179,35 @@ class OAuth2Authenticator(auths.Authenticator):
tab=types.ui.Tab.ADVANCED,
)
# Attributes info fields
userAttribute = gui.TextField(
length=64,
label=_('Username attribute'),
userNameAttr = gui.TextField(
length=2048,
lines=2,
label=_('User name attrs'),
order=100,
tooltip=_('Attribute that contains the username'),
required=True,
tab=_('Attributes'),
)
groupsAttributes = gui.TextField(
length=64,
label=_('Groups attribute'),
order=101,
tooltip=_('Attribute that contains the groups'),
tooltip=_('Fields from where to extract user name'),
required=True,
tab=_('Attributes'),
)
groupNameAttr = gui.TextField(
length=2048,
lines=2,
label=_('Group name attrs'),
order=101,
tooltip=_('Fields from where to extract the groups'),
required=False,
tab=_('Attributes'),
)
realNameAttr = gui.TextField(
length=2048,
lines=2,
label=_('Real name attrs'),
order=102,
tooltip=_('Fields from where to extract the real name'),
required=False,
tab=_('Attributes'),
)
def initialize(self, values: typing.Optional[typing.Dict[str, typing.Any]]) -> None:
if not values:
@ -207,9 +218,8 @@ class OAuth2Authenticator(auths.Authenticator):
gettext('This kind of Authenticator does not support white spaces on field NAME')
)
auth_utils.validateRegexField(self.userAttribute)
auth_utils.validateRegexField(self.userAttribute)
auth_utils.validateRegexField(self.userNameAttr)
auth_utils.validateRegexField(self.userNameAttr)
if self.responseType.value == 'code':
if self.commonGroups.value.strip() == '':
@ -267,7 +277,6 @@ class OAuth2Authenticator(auths.Authenticator):
return TokenInfo.fromJson(req.json())
def authCallback(
self,
parameters: 'types.auth.AuthCallbackParams',
@ -291,6 +300,18 @@ class OAuth2Authenticator(auths.Authenticator):
"""
return f'window.location="{self._getLoginURL(request)}";'
def getGroups(self, username: str, groupsManager: 'auths.GroupsManager'):
data = self.storage.getPickle(username)
if not data:
return
groupsManager.validate(data[1])
def getRealName(self, username: str) -> str:
data = self.storage.getPickle(username)
if not data:
return username
return data[0]
def authCallbackCode(
self,
parameters: 'types.auth.AuthCallbackParams',
@ -322,17 +343,32 @@ class OAuth2Authenticator(auths.Authenticator):
userInfo = token.info
else:
# Get user info
req = requests.get(self.infoEndpoint.value, headers={'Authorization': 'Bearer ' + token.access_token}, timeout=consts.COMMS_TIMEOUT)
req = requests.get(
self.infoEndpoint.value,
headers={'Authorization': 'Bearer ' + token.access_token},
timeout=consts.COMMS_TIMEOUT,
)
if not req.ok:
raise Exception('Error requesting user info: {}'.format(req.text))
userInfo = req.json()
username = ''.join(auth_utils.processRegexField(self.userNameAttr.value, userInfo)).replace(' ', '_')
if len(username) == 0:
raise Exception('No username received')
realName = ''.join(auth_utils.processRegexField(self.realNameAttr.value, userInfo))
# Get groups
groups = auth_utils.processRegexField(self.groupNameAttr.value, userInfo)
# Append common groups
groups.extend(self.commonGroups.value.split(','))
# store groups for this username at storage, so we can check it at a later stage
self.storage.putPickle(username, [realName, groups])
# Validate common groups
groups = self.commonGroups.value.split(',')
gm.validate(groups)
# We don't mind about the token, we only need to authenticate user
# and if we are here, the user is authenticated, so we can return SUCCESS_AUTH
return auths.AuthenticationResult(
auths.AuthenticationSuccess.OK, username=parameters.get_params.get('username', '')
)
return auths.AuthenticationResult(auths.AuthenticationSuccess.OK, username=username)

View File

@ -38,11 +38,15 @@ from uds.core.util import ensure
logger = logging.getLogger(__name__)
def validateRegexField(field: ui.gui.TextField, fieldValue: typing.Optional[str] = None):
"""
Validates the multi line fields refering to attributes
"""
value: str = fieldValue or field.value
if value.strip() == '':
return # Ok, empty
for line in value.splitlines():
if line.find('=') != -1:
pattern = line.split('=')[0:2][1]
@ -53,7 +57,10 @@ def validateRegexField(field: ui.gui.TextField, fieldValue: typing.Optional[str]
except Exception as e:
raise exceptions.ValidationError(f'Invalid pattern at {field.label}: {line}') from e
def processRegexField(field: str, attributes: typing.Mapping[str, typing.Union[str, typing.List[str]]]) -> typing.List[str]:
def processRegexField(
field: str, attributes: typing.Mapping[str, typing.Union[str, typing.List[str]]]
) -> typing.List[str]:
"""Proccesses a field, that can be a multiline field, and returns a list of values
Args:
@ -62,6 +69,9 @@ def processRegexField(field: str, attributes: typing.Mapping[str, typing.Union[s
"""
try:
res: typing.List[str] = []
field = field.strip()
if field == '':
return res
def getAttr(attrName: str) -> typing.List[str]:
try:
@ -71,9 +81,7 @@ def processRegexField(field: str, attributes: typing.Mapping[str, typing.Union[s
# Check all attributes are present, and has only one value
attrValues = [ensure.is_list(attributes.get(a, [''])) for a in attrList]
if not all([len(v) <= 1 for v in attrValues]):
logger.warning(
'Attribute %s do not has exactly one value, skipping %s', attrName, line
)
logger.warning('Attribute %s do not has exactly one value, skipping %s', attrName, line)
return val
val = [''.join(v) for v in attrValues] # flatten