1
0
mirror of https://github.com/ansible/awx.git synced 2024-11-02 01:21:21 +03:00

validate group type params

This commit is contained in:
chris meyers 2018-03-19 11:32:40 -04:00
parent 17795f82e8
commit 1c578cdd74
3 changed files with 79 additions and 10 deletions

View File

@ -305,7 +305,7 @@ class SettingsWrapper(UserSettingsHolder):
settings_to_cache['_awx_conf_preload_expires'] = self._awx_conf_preload_expires
self.cache.set_many(settings_to_cache, timeout=SETTING_CACHE_TIMEOUT)
def _get_local(self, name):
def _get_local(self, name, validate=True):
self._preload_cache()
cache_key = Setting.get_cache_key(name)
try:
@ -368,7 +368,10 @@ class SettingsWrapper(UserSettingsHolder):
field.run_validators(internal_value)
return internal_value
else:
return field.run_validation(value)
if validate:
return field.run_validation(value)
else:
return value
except Exception:
logger.warning(
'The current value "%r" for setting "%s" is invalid.',

View File

@ -8,7 +8,11 @@ from django.core.exceptions import ValidationError
# Django Auth LDAP
import django_auth_ldap.config
from django_auth_ldap.config import LDAPSearch, LDAPSearchUnion
from django_auth_ldap.config import (
LDAPSearch,
LDAPSearchUnion,
LDAPGroupType,
)
# This must be imported so get_subclasses picks it up
from awx.sso.ldap_group_types import PosixUIDGroupType # noqa
@ -28,6 +32,25 @@ def get_subclasses(cls):
yield subclass
class DependsOnMixin():
def get_depends_on(self):
"""
Get the value of the dependent field.
First try to find the value in the request.
Then fall back to the raw value from the setting in the DB.
"""
from django.conf import settings
dependent_key = iter(self.depends_on).next()
if self.context:
request = self.context.get('request', None)
if request and request.data and \
request.data.get(dependent_key, None):
return request.data.get(dependent_key)
res = settings._get_local(dependent_key, validate=False)
return res
class AuthenticationBackendsField(fields.StringListField):
# Mapping of settings that must be set in order to enable each
@ -326,7 +349,15 @@ class LDAPUserAttrMapField(fields.DictField):
return data
class LDAPGroupTypeField(fields.ChoiceField):
VALID_GROUP_TYPE_PARAMS_MAP = {
'LDAPGroupType': ['name_attr'],
'MemberDNGroupType': ['name_attr', 'member_attr'],
'PosixUIDGroupType': ['name_attr', 'ldap_group_user_attr'],
}
class LDAPGroupTypeField(fields.ChoiceField, DependsOnMixin):
default_error_messages = {
'type_error': _('Expected an instance of LDAPGroupType but got {input_type} instead.'),
@ -357,8 +388,7 @@ class LDAPGroupTypeField(fields.ChoiceField):
if not data:
return None
from django.conf import settings
params = getattr(settings, iter(self.depends_on).next(), None) or {}
params = self.get_depends_on() or {}
cls = find_class_in_modules(data)
if not cls:
return None
@ -370,8 +400,9 @@ class LDAPGroupTypeField(fields.ChoiceField):
# took a parameter.
params_sanitized = dict()
if isinstance(cls, LDAPGroupType):
if 'name_attr' in params:
params_sanitized['name_attr'] = params['name_attr']
for k in VALID_GROUP_TYPE_PARAMS_MAP['LDAPGroupType']:
if k in params:
params_sanitized['name_attr'] = params['name_attr']
if data.endswith('MemberDNGroupType'):
params.setdefault('member_attr', 'member')
@ -383,8 +414,22 @@ class LDAPGroupTypeField(fields.ChoiceField):
return cls(**params_sanitized)
class LDAPGroupTypeParamsField(fields.DictField):
pass
class LDAPGroupTypeParamsField(fields.DictField, DependsOnMixin):
default_error_messages = {
'invalid_keys': _('Invalid key(s): {invalid_keys}.'),
}
def to_internal_value(self, value):
value = super(LDAPGroupTypeParamsField, self).to_internal_value(value)
if not value:
return value
group_type_str = self.get_depends_on()
group_type_str = group_type_str or ''
invalid_keys = (set(value.keys()) - set(VALID_GROUP_TYPE_PARAMS_MAP.get(group_type_str, 'LDAPGroupType')))
if invalid_keys:
keys_display = json.dumps(list(invalid_keys)).lstrip('[').rstrip(']')
self.fail('invalid_keys', invalid_keys=keys_display)
return value
class LDAPUserFlagsField(fields.DictField):

View File

@ -1,11 +1,13 @@
import pytest
import mock
from rest_framework.exceptions import ValidationError
from awx.sso.fields import (
SAMLOrgAttrField,
SAMLTeamAttrField,
LDAPGroupTypeParamsField,
)
@ -80,3 +82,22 @@ class TestSAMLTeamAttrField():
field.to_internal_value(data)
assert str(e.value) == str(expected)
class TestLDAPGroupTypeParamsField():
@pytest.mark.parametrize("group_type, data, expected", [
('LDAPGroupType', {'name_attr': 'user', 'bob': ['a', 'b'], 'scooter': 'hello'},
ValidationError('Invalid key(s): "bob", "scooter".')),
('MemberDNGroupType', {'name_attr': 'user', 'member_attr': 'west', 'bob': ['a', 'b'], 'scooter': 'hello'},
ValidationError('Invalid key(s): "bob", "scooter".')),
('PosixUIDGroupType', {'name_attr': 'user', 'member_attr': 'west', 'ldap_group_user_attr': 'legacyThing',
'bob': ['a', 'b'], 'scooter': 'hello'},
ValidationError('Invalid key(s): "bob", "member_attr", "scooter".')),
])
def test_internal_value_invalid(self, group_type, data, expected):
field = LDAPGroupTypeParamsField()
field.get_depends_on = mock.MagicMock(return_value=group_type)
with pytest.raises(type(expected)) as e:
field.to_internal_value(data)
assert str(e.value) == str(expected)