diff --git a/awx/main/tests/inventory.py b/awx/main/tests/inventory.py index 036246e0bd..93ef6817f2 100644 --- a/awx/main/tests/inventory.py +++ b/awx/main/tests/inventory.py @@ -1480,6 +1480,7 @@ class InventoryUpdatesTest(BaseTransactionTest): # updated because the instance ID matches. enabled_host_pks = set(self.inventory.hosts.filter(enabled=True).values_list('pk', flat=True)) instance_types = {} + key_names = {} for host in self.inventory.hosts.all(): host.enabled = False host.name = 'changed-%s' % host.name @@ -1488,6 +1489,10 @@ class InventoryUpdatesTest(BaseTransactionTest): instance_type = host.variables_dict.get('ec2_instance_type', '') if instance_type: instance_types.setdefault(instance_type, []).append(host.pk) + # Get key names for later use with instance_filters. + key_name = host.variables_dict.get('ec2_key_name', '') + if key_name: + key_names.setdefault(key_name, []).append(host.pk) old_host_pks = set(self.inventory.hosts.values_list('pk', flat=True)) self.check_inventory_source(inventory_source, initial=False, enabled_host_pks=enabled_host_pks) new_host_pks = set(self.inventory.hosts.values_list('pk', flat=True)) @@ -1495,18 +1500,28 @@ class InventoryUpdatesTest(BaseTransactionTest): # Verify that main group is in top level groups (hasn't been added as # its own child). self.assertTrue(self.group in self.inventory.root_groups) - # Now add instance filters and verify that only the matching hosts are + + # Now add instance filters and verify that only matching hosts are # synced, specify new cache path to force refresh. - cache_path2 = tempfile.mkdtemp(prefix='awx_ec2_') - self._temp_paths.append(cache_path2) + cache_path = tempfile.mkdtemp(prefix='awx_ec2_') + self._temp_paths.append(cache_path) instance_type = max(instance_types.items(), key=lambda x: len(x[1]))[0] inventory_source.instance_filters = 'instance-type=%s' % instance_type - inventory_source.source_vars = '---\n\nnested_groups: false\ncache_path: %s\n' % cache_path2 + inventory_source.source_vars = '---\n\nnested_groups: false\ncache_path: %s\n' % cache_path inventory_source.overwrite = True inventory_source.save() self.check_inventory_source(inventory_source, initial=False) - new_host_pks = set(self.inventory.hosts.filter(active=True).values_list('pk', flat=True)) - self.assertEqual(new_host_pks, set(instance_types[instance_type])) + for host in self.inventory.hosts.filter(active=True): + self.assertEqual(host.variables_dict['ec2_instance_type'], instance_type) + + # Try invalid instance filters: empty, only "=", more than one "=", whitespace + cache_path = tempfile.mkdtemp(prefix='awx_ec2_') + self._temp_paths.append(cache_path) + key_name = max(key_names.items(), key=lambda x: len(x[1]))[0] + inventory_source.instance_filters = ',=,image-id=ami=12345678,instance-type=%s, key-name=%s' % (instance_type, key_name) + inventory_source.source_vars = '---\n\nnested_groups: false\ncache_path: %s\n' % cache_path + inventory_source.save() + self.check_inventory_source(inventory_source, initial=False) def test_update_from_ec2_with_nested_groups(self): source_username = getattr(settings, 'TEST_AWS_ACCESS_KEY_ID', '') diff --git a/awx/plugins/inventory/ec2.py b/awx/plugins/inventory/ec2.py index 3d88cbf450..b6e9e39177 100755 --- a/awx/plugins/inventory/ec2.py +++ b/awx/plugins/inventory/ec2.py @@ -294,11 +294,16 @@ class Ec2Inventory(object): except ConfigParser.NoOptionError, e: self.pattern_exclude = None - # Instance filters (see boto and EC2 API docs) + # Instance filters (see boto and EC2 API docs). Ignore invalid filters. self.ec2_instance_filters = defaultdict(list) if config.has_option('ec2', 'instance_filters'): - for x in config.get('ec2', 'instance_filters', '').split(','): - filter_key, filter_value = x.split('=') + for instance_filter in config.get('ec2', 'instance_filters', '').split(','): + instance_filter = instance_filter.strip() + if not instance_filter or '=' not in instance_filter: + continue + filter_key, filter_value = [x.strip() for x in instance_filter.split('=', 1)] + if not filter_key: + continue self.ec2_instance_filters[filter_key].append(filter_value) def parse_cli_args(self):