diff --git a/awx/api/serializers.py b/awx/api/serializers.py index 873f822171..98090bb611 100644 --- a/awx/api/serializers.py +++ b/awx/api/serializers.py @@ -45,7 +45,7 @@ from awx.main.fields import ImplicitRoleField from awx.main.utils import ( get_type_for_model, get_model_for_type, timestamp_apiformat, camelcase_to_underscore, getattrd, parse_yaml_or_json, - has_model_field_prefetched) + has_model_field_prefetched, extract_ansible_vars) from awx.main.utils.filters import SmartFilter from awx.main.validators import vars_validate_or_raise @@ -2759,6 +2759,14 @@ class AdHocCommandSerializer(UnifiedJobSerializer): ret['name'] = obj.module_name return ret + def validate_extra_vars(self, value): + redacted_extra_vars, removed_vars = extract_ansible_vars(value) + if removed_vars: + raise serializers.ValidationError(_( + "Variables {} are prohibited from use in ad hoc commands." + ).format(",".join(removed_vars))) + return vars_validate_or_raise(value) + class AdHocCommandCancelSerializer(AdHocCommandSerializer): diff --git a/awx/api/views.py b/awx/api/views.py index e07b7f8b96..ba7fc709e1 100644 --- a/awx/api/views.py +++ b/awx/api/views.py @@ -69,7 +69,7 @@ from awx.conf.license import get_license, feature_enabled, feature_exists, Licen from awx.main.models import * # noqa from awx.main.utils import * # noqa from awx.main.utils import ( - callback_filter_out_ansible_extra_vars, + extract_ansible_vars, decrypt_field, ) from awx.main.utils.filters import SmartFilter @@ -3160,7 +3160,8 @@ class JobTemplateCallback(GenericAPIView): # Everything is fine; actually create the job. kv = {"limit": limit, "launch_type": 'callback'} if extra_vars is not None and job_template.ask_variables_on_launch: - kv['extra_vars'] = callback_filter_out_ansible_extra_vars(extra_vars) + extra_vars_redacted, removed = extract_ansible_vars(extra_vars) + kv['extra_vars'] = extra_vars_redacted with transaction.atomic(): job = job_template.create_job(**kv) diff --git a/awx/main/models/unified_jobs.py b/awx/main/models/unified_jobs.py index e148d45862..91285ecc16 100644 --- a/awx/main/models/unified_jobs.py +++ b/awx/main/models/unified_jobs.py @@ -34,7 +34,7 @@ from awx.main.models.mixins import ResourceMixin, TaskManagerUnifiedJobMixin from awx.main.utils import ( decrypt_field, _inventory_updates, copy_model_by_class, copy_m2m_relationships, - get_type_for_model + get_type_for_model, parse_yaml_or_json ) from awx.main.redact import UriCleaner, REPLACE_STR from awx.main.consumers import emit_channel_notification @@ -878,21 +878,13 @@ class UnifiedJob(PolymorphicModel, PasswordFieldsModel, CommonModelNameNotUnique return [] def handle_extra_data(self, extra_data): - if hasattr(self, 'extra_vars'): - extra_vars = {} - if isinstance(extra_data, dict): - extra_vars = extra_data - elif extra_data is None: - return - else: - if extra_data == "": - return - try: - extra_vars = json.loads(extra_data) - except Exception as e: - logger.warn("Exception deserializing extra vars: " + str(e)) + if hasattr(self, 'extra_vars') and extra_data: + try: + extra_data_dict = parse_yaml_or_json(extra_data, silent_failure=False) + except Exception as e: + logger.warn("Exception deserializing extra vars: " + str(e)) evars = self.extra_vars_dict - evars.update(extra_vars) + evars.update(extra_data_dict) self.update_fields(extra_vars=json.dumps(evars)) @property diff --git a/awx/main/tasks.py b/awx/main/tasks.py index 02bea178f5..3bc4bc9c45 100644 --- a/awx/main/tasks.py +++ b/awx/main/tasks.py @@ -56,7 +56,7 @@ from awx.main.utils import (get_ansible_version, get_ssh_version, decrypt_field, check_proot_installed, build_proot_temp_dir, get_licenser, wrap_args_with_proot, get_system_task_capacity, OutputEventFilter, parse_yaml_or_json, ignore_inventory_computed_fields, ignore_inventory_group_removal, - get_type_for_model) + get_type_for_model, extract_ansible_vars) from awx.main.utils.reload import restart_local_services, stop_local_services from awx.main.utils.handlers import configure_external_logger from awx.main.consumers import emit_channel_notification @@ -2139,6 +2139,12 @@ class RunAdHocCommand(BaseTask): args.append('-%s' % ('v' * min(5, ad_hoc_command.verbosity))) if ad_hoc_command.extra_vars_dict: + redacted_extra_vars, removed_vars = extract_ansible_vars(ad_hoc_command.extra_vars_dict) + if removed_vars: + raise ValueError(_( + "unable to use {} variables with ad hoc commands" + ).format(",".format(removed_vars))) + args.extend(['-e', json.dumps(ad_hoc_command.extra_vars_dict)]) args.extend(['-m', ad_hoc_command.module_name]) diff --git a/awx/main/tests/unit/utils/test_common.py b/awx/main/tests/unit/utils/test_common.py index 41f9012040..44f960a673 100644 --- a/awx/main/tests/unit/utils/test_common.py +++ b/awx/main/tests/unit/utils/test_common.py @@ -5,6 +5,7 @@ import os import pytest from uuid import uuid4 +import json from django.core.cache import cache @@ -115,3 +116,12 @@ def test_memoize_parameter_error(): with pytest.raises(common.IllegalArgumentError): fn() + +def test_extract_ansible_vars(): + my_dict = { + "foobar": "baz", + "ansible_connetion_setting": "1928" + } + redacted, var_list = common.extract_ansible_vars(json.dumps(my_dict)) + assert var_list == set(['ansible_connetion_setting']) + assert redacted == {"foobar": "baz"} diff --git a/awx/main/utils/common.py b/awx/main/utils/common.py index dbbc589392..22183c33dd 100644 --- a/awx/main/utils/common.py +++ b/awx/main/utils/common.py @@ -41,7 +41,7 @@ __all__ = ['get_object_or_400', 'get_object_or_403', 'camelcase_to_underscore', 'ignore_inventory_computed_fields', 'ignore_inventory_group_removal', '_inventory_updates', 'get_pk_from_dict', 'getattrd', 'NoDefaultProvided', 'get_current_apps', 'set_current_apps', 'OutputEventFilter', - 'callback_filter_out_ansible_extra_vars', 'get_search_fields', 'get_system_task_capacity', + 'extract_ansible_vars', 'get_search_fields', 'get_system_task_capacity', 'wrap_args_with_proot', 'build_proot_temp_dir', 'check_proot_installed', 'model_to_dict', 'model_instance_diff', 'timestamp_apiformat', 'parse_yaml_or_json', 'RequireDebugTrueOrTest', 'has_model_field_prefetched', 'set_environ', 'IllegalArgumentError',] @@ -581,7 +581,7 @@ def cache_list_capabilities(page, prefetch_list, model, user): obj.capabilities_cache[display_method] = True -def parse_yaml_or_json(vars_str): +def parse_yaml_or_json(vars_str, silent_failure=True): ''' Attempt to parse a string with variables, and if attempt fails, return an empty dictionary. @@ -595,7 +595,9 @@ def parse_yaml_or_json(vars_str): vars_dict = yaml.safe_load(vars_str) assert isinstance(vars_dict, dict) except (yaml.YAMLError, TypeError, AttributeError, AssertionError): - vars_dict = {} + if silent_failure: + return {} + raise return vars_dict @@ -877,13 +879,18 @@ class OutputEventFilter(object): self._current_event_data = None -def callback_filter_out_ansible_extra_vars(extra_vars): - extra_vars_redacted = {} +def is_ansible_variable(key): + return key.startswith('ansible_') + + +def extract_ansible_vars(extra_vars): extra_vars = parse_yaml_or_json(extra_vars) - for key, value in extra_vars.iteritems(): - if not key.startswith('ansible_'): - extra_vars_redacted[key] = value - return extra_vars_redacted + ansible_vars = set([]) + for key in extra_vars.keys(): + if is_ansible_variable(key): + extra_vars.pop(key) + ansible_vars.add(key) + return (extra_vars, ansible_vars) def get_search_fields(model):