From 0dae058bef7a658d1bc84089c4cf2c8401154088 Mon Sep 17 00:00:00 2001 From: Aaron Tan Date: Wed, 14 Jun 2017 15:45:15 -0400 Subject: [PATCH] Set priority rules for find_matching_hosts. --- awx/api/views.py | 49 +++++++------------ awx/main/models/inventory.py | 12 +++++ .../functional/api/test_job_runtime_params.py | 30 +++++++++++- 3 files changed, 60 insertions(+), 31 deletions(-) diff --git a/awx/api/views.py b/awx/api/views.py index 1fb87ff7d7..3eb49b858d 100644 --- a/awx/api/views.py +++ b/awx/api/views.py @@ -3046,40 +3046,29 @@ class JobTemplateCallback(GenericAPIView): # Find the host objects to search for a match. obj = self.get_object() hosts = obj.inventory.hosts.all() - # First try for an exact match on the name. - try: - return set([hosts.get(name__in=remote_hosts)]) - except (Host.DoesNotExist, Host.MultipleObjectsReturned): - pass - # Next, try matching based on name or ansible_host variables. - matches = set() + # Populate host_mappings + host_mappings = {} for host in hosts: - for host_var in ['ansible_ssh_host', 'ansible_host']: - ansible_host = host.variables_dict.get(host_var, '') - if ansible_host in remote_hosts: - matches.add(host) - if host.name != ansible_host and host.name in remote_hosts: - matches.add(host) + host_name = host.get_effective_host_name() + host_mappings.setdefault(host_name, []) + host_mappings[host_name].append(host) + # Try finding direct match + matches = set() + for host_name in remote_hosts: + if host_name in host_mappings: + matches.update(host_mappings[host_name]) if len(matches) == 1: return matches - # Try to resolve forward addresses for each host to find matches. - for host in hosts: - hostnames = set([host.name]) - for host_var in ['ansible_ssh_host', 'ansible_host']: - ansible_host = host.variables_dict.get(host_var, '') - if ansible_host: - hostnames.add(ansible_host) - for hostname in hostnames: - try: - result = socket.getaddrinfo(hostname, None) - possible_ips = set(x[4][0] for x in result) - possible_ips.discard(hostname) - if possible_ips and possible_ips & remote_hosts: - matches.add(host) - except socket.gaierror: - pass - # Return all matches found. + for host_name in host_mappings: + try: + result = socket.getaddrinfo(host_name, None) + possible_ips = set(x[4][0] for x in result) + possible_ips.discard(host_name) + if possible_ips and possible_ips & remote_hosts: + matches.update(host_mappings[host_name]) + except socket.gaierror: + pass return matches def get(self, request, *args, **kwargs): diff --git a/awx/main/models/inventory.py b/awx/main/models/inventory.py index 05a5cd2364..c8751c4164 100644 --- a/awx/main/models/inventory.py +++ b/awx/main/models/inventory.py @@ -527,6 +527,18 @@ class Host(CommonModelNameNotUnique): self.ansible_facts[module] = facts self.save() + def get_effective_host_name(self): + ''' + Return the name of the host that will be used in actual ansible + command run. + ''' + host_name = self.name + if 'ansible_ssh_host' in self.variables_dict: + host_name = self.variables_dict['ansible_ssh_host'] + if 'ansible_host' in self.variables_dict: + host_name = self.variables_dict['ansible_host'] + return host_name + class Group(CommonModelNameNotUnique): ''' diff --git a/awx/main/tests/functional/api/test_job_runtime_params.py b/awx/main/tests/functional/api/test_job_runtime_params.py index 2941d977fd..69c82e5d28 100644 --- a/awx/main/tests/functional/api/test_job_runtime_params.py +++ b/awx/main/tests/functional/api/test_job_runtime_params.py @@ -3,7 +3,7 @@ import yaml from awx.api.serializers import JobLaunchSerializer from awx.main.models.credential import Credential -from awx.main.models.inventory import Inventory +from awx.main.models.inventory import Inventory, Host from awx.main.models.jobs import Job, JobTemplate from awx.api.versioning import reverse @@ -431,3 +431,31 @@ def test_callback_ignore_unprompted_extra_var(mocker, survey_spec_factory, job_t 'limit': 'single-host'},) mock_job.signal_start.assert_called_once() + + +@pytest.mark.django_db +@pytest.mark.job_runtime_vars +def test_callback_find_matching_hosts(mocker, get, job_template_prompts, admin_user): + job_template = job_template_prompts(False) + job_template.host_config_key = "foo" + job_template.save() + host_with_alias = Host(name='localhost', inventory=job_template.inventory) + host_with_alias.save() + with mocker.patch('awx.main.access.BaseAccess.check_license'): + r = get(reverse('api:job_template_callback', kwargs={'pk': job_template.pk}), + user=admin_user, expect=200) + assert tuple(r.data['matching_hosts']) == ('localhost',) + + +@pytest.mark.django_db +@pytest.mark.job_runtime_vars +def test_callback_extra_var_takes_priority_over_host_name(mocker, get, job_template_prompts, admin_user): + job_template = job_template_prompts(False) + job_template.host_config_key = "foo" + job_template.save() + host_with_alias = Host(name='localhost', variables={'ansible_host': 'foobar'}, inventory=job_template.inventory) + host_with_alias.save() + with mocker.patch('awx.main.access.BaseAccess.check_license'): + r = get(reverse('api:job_template_callback', kwargs={'pk': job_template.pk}), + user=admin_user, expect=200) + assert not r.data['matching_hosts']