mirror of
https://github.com/dkmstr/openuds.git
synced 2024-12-23 17:34:17 +03:00
Basic OAuth2 with code interchange done
This commit is contained in:
parent
9419e0f69e
commit
ad9755177a
@ -178,25 +178,36 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
required=False,
|
||||
tab=types.ui.Tab.ADVANCED,
|
||||
)
|
||||
|
||||
# Attributes info fields
|
||||
userAttribute = gui.TextField(
|
||||
length=64,
|
||||
label=_('Username attribute'),
|
||||
|
||||
userNameAttr = gui.TextField(
|
||||
length=2048,
|
||||
lines=2,
|
||||
label=_('User name attrs'),
|
||||
order=100,
|
||||
tooltip=_('Attribute that contains the username'),
|
||||
tooltip=_('Fields from where to extract user name'),
|
||||
required=True,
|
||||
tab=_('Attributes'),
|
||||
)
|
||||
groupsAttributes = gui.TextField(
|
||||
length=64,
|
||||
label=_('Groups attribute'),
|
||||
|
||||
groupNameAttr = gui.TextField(
|
||||
length=2048,
|
||||
lines=2,
|
||||
label=_('Group name attrs'),
|
||||
order=101,
|
||||
tooltip=_('Attribute that contains the groups'),
|
||||
required=True,
|
||||
tooltip=_('Fields from where to extract the groups'),
|
||||
required=False,
|
||||
tab=_('Attributes'),
|
||||
)
|
||||
|
||||
realNameAttr = gui.TextField(
|
||||
length=2048,
|
||||
lines=2,
|
||||
label=_('Real name attrs'),
|
||||
order=102,
|
||||
tooltip=_('Fields from where to extract the real name'),
|
||||
required=False,
|
||||
tab=_('Attributes'),
|
||||
)
|
||||
|
||||
|
||||
def initialize(self, values: typing.Optional[typing.Dict[str, typing.Any]]) -> None:
|
||||
if not values:
|
||||
@ -206,10 +217,9 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
raise exceptions.ValidationError(
|
||||
gettext('This kind of Authenticator does not support white spaces on field NAME')
|
||||
)
|
||||
|
||||
auth_utils.validateRegexField(self.userAttribute)
|
||||
auth_utils.validateRegexField(self.userAttribute)
|
||||
|
||||
|
||||
auth_utils.validateRegexField(self.userNameAttr)
|
||||
auth_utils.validateRegexField(self.userNameAttr)
|
||||
|
||||
if self.responseType.value == 'code':
|
||||
if self.commonGroups.value.strip() == '':
|
||||
@ -266,7 +276,6 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
raise Exception('Error requesting token: {}'.format(req.text))
|
||||
|
||||
return TokenInfo.fromJson(req.json())
|
||||
|
||||
|
||||
def authCallback(
|
||||
self,
|
||||
@ -291,6 +300,18 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
"""
|
||||
return f'window.location="{self._getLoginURL(request)}";'
|
||||
|
||||
def getGroups(self, username: str, groupsManager: 'auths.GroupsManager'):
|
||||
data = self.storage.getPickle(username)
|
||||
if not data:
|
||||
return
|
||||
groupsManager.validate(data[1])
|
||||
|
||||
def getRealName(self, username: str) -> str:
|
||||
data = self.storage.getPickle(username)
|
||||
if not data:
|
||||
return username
|
||||
return data[0]
|
||||
|
||||
def authCallbackCode(
|
||||
self,
|
||||
parameters: 'types.auth.AuthCallbackParams',
|
||||
@ -313,26 +334,41 @@ class OAuth2Authenticator(auths.Authenticator):
|
||||
return auths.FAILED_AUTH
|
||||
|
||||
token = self._requestToken(request, code)
|
||||
|
||||
|
||||
userInfo: typing.Dict[str, typing.Any]
|
||||
|
||||
|
||||
if self.infoEndpoint.value.strip() == '':
|
||||
if not token.info:
|
||||
raise Exception('No user info received')
|
||||
userInfo = token.info
|
||||
else:
|
||||
# Get user info
|
||||
req = requests.get(self.infoEndpoint.value, headers={'Authorization': 'Bearer ' + token.access_token}, timeout=consts.COMMS_TIMEOUT)
|
||||
req = requests.get(
|
||||
self.infoEndpoint.value,
|
||||
headers={'Authorization': 'Bearer ' + token.access_token},
|
||||
timeout=consts.COMMS_TIMEOUT,
|
||||
)
|
||||
if not req.ok:
|
||||
raise Exception('Error requesting user info: {}'.format(req.text))
|
||||
userInfo = req.json()
|
||||
|
||||
username = ''.join(auth_utils.processRegexField(self.userNameAttr.value, userInfo)).replace(' ', '_')
|
||||
if len(username) == 0:
|
||||
raise Exception('No username received')
|
||||
|
||||
realName = ''.join(auth_utils.processRegexField(self.realNameAttr.value, userInfo))
|
||||
|
||||
# Get groups
|
||||
groups = auth_utils.processRegexField(self.groupNameAttr.value, userInfo)
|
||||
# Append common groups
|
||||
groups.extend(self.commonGroups.value.split(','))
|
||||
|
||||
# store groups for this username at storage, so we can check it at a later stage
|
||||
self.storage.putPickle(username, [realName, groups])
|
||||
|
||||
# Validate common groups
|
||||
groups = self.commonGroups.value.split(',')
|
||||
gm.validate(groups)
|
||||
|
||||
# We don't mind about the token, we only need to authenticate user
|
||||
# and if we are here, the user is authenticated, so we can return SUCCESS_AUTH
|
||||
return auths.AuthenticationResult(
|
||||
auths.AuthenticationSuccess.OK, username=parameters.get_params.get('username', '')
|
||||
)
|
||||
return auths.AuthenticationResult(auths.AuthenticationSuccess.OK, username=username)
|
||||
|
@ -38,11 +38,15 @@ from uds.core.util import ensure
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def validateRegexField(field: ui.gui.TextField, fieldValue: typing.Optional[str] = None):
|
||||
"""
|
||||
Validates the multi line fields refering to attributes
|
||||
"""
|
||||
value: str = fieldValue or field.value
|
||||
if value.strip() == '':
|
||||
return # Ok, empty
|
||||
|
||||
for line in value.splitlines():
|
||||
if line.find('=') != -1:
|
||||
pattern = line.split('=')[0:2][1]
|
||||
@ -53,15 +57,21 @@ def validateRegexField(field: ui.gui.TextField, fieldValue: typing.Optional[str]
|
||||
except Exception as e:
|
||||
raise exceptions.ValidationError(f'Invalid pattern at {field.label}: {line}') from e
|
||||
|
||||
def processRegexField(field: str, attributes: typing.Mapping[str, typing.Union[str, typing.List[str]]]) -> typing.List[str]:
|
||||
|
||||
def processRegexField(
|
||||
field: str, attributes: typing.Mapping[str, typing.Union[str, typing.List[str]]]
|
||||
) -> typing.List[str]:
|
||||
"""Proccesses a field, that can be a multiline field, and returns a list of values
|
||||
|
||||
|
||||
Args:
|
||||
field (str): Field to process
|
||||
attributes (typing.Dict[str, typing.List[str]]): Attributes to use on processing
|
||||
"""
|
||||
try:
|
||||
res: typing.List[str] = []
|
||||
field = field.strip()
|
||||
if field == '':
|
||||
return res
|
||||
|
||||
def getAttr(attrName: str) -> typing.List[str]:
|
||||
try:
|
||||
@ -71,9 +81,7 @@ def processRegexField(field: str, attributes: typing.Mapping[str, typing.Union[s
|
||||
# Check all attributes are present, and has only one value
|
||||
attrValues = [ensure.is_list(attributes.get(a, [''])) for a in attrList]
|
||||
if not all([len(v) <= 1 for v in attrValues]):
|
||||
logger.warning(
|
||||
'Attribute %s do not has exactly one value, skipping %s', attrName, line
|
||||
)
|
||||
logger.warning('Attribute %s do not has exactly one value, skipping %s', attrName, line)
|
||||
return val
|
||||
|
||||
val = [''.join(v) for v in attrValues] # flatten
|
||||
|
Loading…
Reference in New Issue
Block a user