diff --git a/awx/main/management/commands/inventory_import.py b/awx/main/management/commands/inventory_import.py index 45cfcac087..4cef92391f 100644 --- a/awx/main/management/commands/inventory_import.py +++ b/awx/main/management/commands/inventory_import.py @@ -430,7 +430,8 @@ def load_inventory_source(source, all_group=None, group_filter_re=None, original_all_group = all_group if not os.path.exists(source): raise IOError('Source does not exist: %s' % source) - source = os.path.join(os.path.dirname(source) or os.getcwd(), source) + source = os.path.join(os.getcwd(), os.path.dirname(source), + os.path.basename(source)) source = os.path.normpath(os.path.abspath(source)) if os.path.isdir(source): all_group = all_group or MemGroup('all', source) @@ -511,7 +512,7 @@ class Command(NoArgsCommand): ) def init_logging(self): - log_levels = dict(enumerate([logging.ERROR, logging.INFO, + log_levels = dict(enumerate([logging.WARNING, logging.INFO, logging.DEBUG, 0])) self.logger = logging.getLogger('awx.main.commands.inventory_import') self.logger.setLevel(log_levels.get(self.verbosity, 0)) @@ -577,15 +578,24 @@ class Command(NoArgsCommand): # FIXME: Wait or raise error if inventory is being updated by another # source. - def load_into_database(self): - ''' - Load inventory from in-memory groups to the database, overwriting or - merging as appropriate. - ''' + def _batch_add_m2m(self, related_manager, *objs, **kwargs): + key = (related_manager.instance.pk, related_manager.through._meta.db_table) + flush = bool(kwargs.get('flush', False)) + if not hasattr(self, '_batch_add_m2m_cache'): + self._batch_add_m2m_cache = {} + cached_objs = self._batch_add_m2m_cache.setdefault(key, []) + cached_objs.extend(objs) + if len(cached_objs) > 100 or flush: + if len(cached_objs): + related_manager.add(*cached_objs) + self._batch_add_m2m_cache[key] = [] - # Find any hosts in the database without an instance_id set that may - # still have one available via host variables. - db_instance_id_map = {} + def _build_db_instance_id_map(self): + ''' + Find any hosts in the database without an instance_id set that may + still have one available via host variables. + ''' + self.db_instance_id_map = {} if self.instance_id_var: if self.inventory_source.group: host_qs = self.inventory_source.group.all_hosts @@ -597,11 +607,14 @@ class Command(NoArgsCommand): instance_id = host.variables_dict.get(self.instance_id_var, '') if not instance_id: continue - db_instance_id_map[instance_id] = host.pk + self.db_instance_id_map[instance_id] = host.pk - # Update instance ID for each imported host and define a mapping of - # instance IDs to MemHost instances. - mem_instance_id_map = {} + def _build_mem_instance_id_map(self): + ''' + Update instance ID for each imported host and define a mapping of + instance IDs to MemHost instances. + ''' + self.mem_instance_id_map = {} if self.instance_id_var: for mem_host in self.all_group.all_hosts.values(): instance_id = mem_host.variables.get(self.instance_id_var, '') @@ -610,80 +623,113 @@ class Command(NoArgsCommand): mem_host.name, self.instance_id_var) continue mem_host.instance_id = instance_id - mem_instance_id_map[instance_id] = mem_host.name - #self.logger.warning('%r', instance_id_map) + self.mem_instance_id_map[instance_id] = mem_host.name - # If overwrite is set, for each host in the database that is NOT in - # the local list, delete it. When importing from a cloud inventory - # source attached to a specific group, only delete hosts beneath that - # group. Delete each host individually so signal handlers will run. - if self.overwrite: - if self.inventory_source.group: - del_hosts = self.inventory_source.group.all_hosts - # FIXME: Also include hosts from inventory_source.managed_hosts? - else: - del_hosts = self.inventory.hosts.filter(active=True) - instance_ids = set(mem_instance_id_map.keys()) - host_pks = set([v for k,v in db_instance_id_map.items() if k in instance_ids]) - host_names = set(mem_instance_id_map.values()) - set(self.all_group.all_hosts.keys()) + def _delete_hosts(self): + ''' + For each host in the database that is NOT in the local list, delete + it. When importing from a cloud inventory source attached to a + specific group, only delete hosts beneath that group. Delete each + host individually so signal handlers will run. + ''' + if settings.SQL_DEBUG: + queries_before = len(connection.queries) + if self.inventory_source.group: + del_hosts = self.inventory_source.group.all_hosts + # FIXME: Also include hosts from inventory_source.managed_hosts? + else: + del_hosts = self.inventory.hosts.filter(active=True) + if self.instance_id_var: + instance_ids = set(self.mem_instance_id_map.keys()) + host_pks = set([v for k,v in self.db_instance_id_map.items() if k in instance_ids]) + host_names = set(self.mem_instance_id_map.values()) - set(self.all_group.all_hosts.keys()) del_hosts = del_hosts.exclude(Q(name__in=host_names) | Q(instance_id__in=instance_ids) | Q(pk__in=host_pks)) - for host in del_hosts: - host_name = host.name - host.mark_inactive() - self.logger.info('Deleted host "%s"', host_name) + else: + del_hosts = del_hosts.exclude(name__in=self.all_group.all_hosts.keys()) + for host in del_hosts: + host_name = host.name + host.mark_inactive()#from_inventory_import=True) + self.logger.info('Deleted host "%s"', host_name) + if settings.SQL_DEBUG: + self.logger.warning('host deletions took %d queries for %d hosts', + len(connection.queries) - queries_before, + del_hosts.count()) + def _delete_groups(self): + ''' # If overwrite is set, for each group in the database that is NOT in # the local list, delete it. When importing from a cloud inventory # source attached to a specific group, only delete children of that # group. Delete each group individually so signal handlers will run. - if self.overwrite: - if self.inventory_source.group: - del_groups = self.inventory_source.group.all_children - # FIXME: Also include groups from inventory_source.managed_groups? - else: - del_groups = self.inventory.groups.filter(active=True) - group_names = set(self.all_group.all_groups.keys()) - del_groups = del_groups.exclude(name__in=group_names) - for group in del_groups: - group_name = group.name - group.mark_inactive(recompute=False) - self.logger.info('Group "%s" deleted', group_name) + ''' + if settings.SQL_DEBUG: + queries_before = len(connection.queries) + if self.inventory_source.group: + del_groups = self.inventory_source.group.all_children + # FIXME: Also include groups from inventory_source.managed_groups? + else: + del_groups = self.inventory.groups.filter(active=True) + group_names = set(self.all_group.all_groups.keys()) + del_groups = del_groups.exclude(name__in=group_names) + for group in del_groups: + group_name = group.name + group.mark_inactive(recompute=False)#from_inventory_import=True) + self.logger.info('Group "%s" deleted', group_name) + if settings.SQL_DEBUG: + self.logger.warning('group deletions took %d queries for %d groups', + len(connection.queries) - queries_before, + del_groups.count()) - # If overwrite is set, clear all invalid child relationships for groups - # and all invalid host memberships. When importing from a cloud - # inventory source attached to a specific group, only clear - # relationships for hosts and groups that are beneath the inventory - # source group. - if self.overwrite: - if self.inventory_source.group: - db_groups = self.inventory_source.group.all_children - else: - db_groups = self.inventory.groups.filter(active=True) - for db_group in db_groups: - db_children = db_group.children.filter(active=True) - mem_children = self.all_group.all_groups[db_group.name].children - mem_children_names = [g.name for g in mem_children] - for db_child in db_children.exclude(name__in=mem_children_names): - if db_child not in db_group.children.filter(active=True): - continue - db_group.children.remove(db_child) - self.logger.info('Group "%s" removed from group "%s"', - db_child.name, db_group.name) - db_hosts = db_group.hosts.filter(active=True) - mem_hosts = self.all_group.all_groups[db_group.name].hosts - mem_host_names = set([h.name for h in mem_hosts if not h.instance_id]) - mem_instance_ids = set([h.instance_id for h in mem_hosts if h.instance_id]) - db_host_pks = set([v for k,v in db_instance_id_map.items() if k in mem_instance_ids]) - for db_host in db_hosts.exclude(Q(name__in=mem_host_names) | Q(instance_id__in=mem_instance_ids) | Q(pk__in=db_host_pks)): - if db_host not in db_group.hosts.filter(active=True): - continue - db_group.hosts.remove(db_host) - self.logger.info('Host "%s" removed from group "%s"', - db_host.name, db_group.name) + def _delete_group_children_and_hosts(self): + ''' + Clear all invalid child relationships for groups and all invalid host + memberships. When importing from a cloud inventory source attached to + a specific group, only clear relationships for hosts and groups that + are beneath the inventory source group. + ''' + # FIXME: Optimize performance! + if settings.SQL_DEBUG: + queries_before = len(connection.queries) + group_group_count = 0 + group_host_count = 0 + if self.inventory_source.group: + db_groups = self.inventory_source.group.all_children + else: + db_groups = self.inventory.groups.filter(active=True) + for db_group in db_groups: + db_children = db_group.children.filter(active=True) + mem_children = self.all_group.all_groups[db_group.name].children + mem_children_names = [g.name for g in mem_children] + for db_child in db_children.exclude(name__in=mem_children_names): + group_group_count += 1 + if db_child not in db_group.children.filter(active=True): + continue + db_group.children.remove(db_child) + self.logger.info('Group "%s" removed from group "%s"', + db_child.name, db_group.name) + db_hosts = db_group.hosts.filter(active=True) + mem_hosts = self.all_group.all_groups[db_group.name].hosts + mem_host_names = set([h.name for h in mem_hosts if not h.instance_id]) + mem_instance_ids = set([h.instance_id for h in mem_hosts if h.instance_id]) + db_host_pks = set([v for k,v in self.db_instance_id_map.items() if k in mem_instance_ids]) + for db_host in db_hosts.exclude(Q(name__in=mem_host_names) | Q(instance_id__in=mem_instance_ids) | Q(pk__in=db_host_pks)): + group_host_count += 1 + if db_host not in db_group.hosts.filter(active=True): + continue + db_group.hosts.remove(db_host) + self.logger.info('Host "%s" removed from group "%s"', + db_host.name, db_group.name) + if settings.SQL_DEBUG: + self.logger.warning('group-group and group-host deletions took %d queries for %d relationships', + len(connection.queries) - queries_before, + group_group_count + group_host_count) - # Update/overwrite variables from "all" group. If importing from a - # cloud source attached to a specific group, variables will be set on - # the base group, otherwise they will be set on the whole inventory. + def _update_inventory(self): + ''' + Update/overwrite variables from "all" group. If importing from a + cloud source attached to a specific group, variables will be set on + the base group, otherwise they will be set on the whole inventory. + ''' if self.inventory_source.group: all_obj = self.inventory_source.group all_obj.inventory_sources.add(self.inventory_source) @@ -706,151 +752,262 @@ class Command(NoArgsCommand): else: self.logger.info('%s variables unmodified', all_name.capitalize()) - # FIXME: Attribute changes to superuser? - - # For each group in the local list, create it if it doesn't exist in - # the database. Otherwise, update/replace database variables from the - # imported data. Associate with the inventory source group if - # importing from cloud inventory source. - for k,v in self.all_group.all_groups.iteritems(): - variables = json.dumps(v.variables) - defaults = dict(variables=variables, description='imported') - group, created = self.inventory.groups.get_or_create(name=k, - defaults=defaults) - # Access auto one-to-one attribute to create related object. - group.inventory_source - if created: - self.logger.info('Group "%s" added', k) + def _create_update_groups(self): + ''' + For each group in the local list, create it if it doesn't exist in the + database. Otherwise, update/replace database variables from the + imported data. Associate with the inventory source group if importing + from cloud inventory source. + ''' + if settings.SQL_DEBUG: + queries_before = len(connection.queries) + inv_src_group = self.inventory_source.group + group_names = set(self.all_group.all_groups.keys()) + for group in self.inventory.groups.filter(name__in=group_names): + mem_group = self.all_group.all_groups[group.name] + db_variables = group.variables_dict + if self.overwrite_vars or self.overwrite: + db_variables = mem_group.variables else: - db_variables = group.variables_dict + db_variables.update(mem_group.variables) + if db_variables != group.variables_dict: + group.variables = json.dumps(db_variables) + group.save(update_fields=['variables']) if self.overwrite_vars or self.overwrite: - db_variables = v.variables + self.logger.info('Group "%s" variables replaced', group.name) else: - db_variables.update(v.variables) - if db_variables != group.variables_dict: - group.variables = json.dumps(db_variables) - group.save(update_fields=['variables']) - if self.overwrite_vars or self.overwrite: - self.logger.info('Group "%s" variables replaced', k) - else: - self.logger.info('Group "%s" variables updated', k) - else: - self.logger.info('Group "%s" variables unmodified', k) - if self.inventory_source.group and self.inventory_source.group != group: - self.inventory_source.group.children.add(group) - group.inventory_sources.add(self.inventory_source) + self.logger.info('Group "%s" variables updated', group.name) + else: + self.logger.info('Group "%s" variables unmodified', group.name) + group_names.remove(group.name) + if inv_src_group and inv_src_group != group: + self._batch_add_m2m(inv_src_group.children, group) + self._batch_add_m2m(self.inventory_source.groups, group) + for group_name in group_names: + mem_group = self.all_group.all_groups[group_name] + group = self.inventory.groups.create(name=group_name, variables=json.dumps(mem_group.variables), description='imported') + # Access auto one-to-one attribute to create related object. + #group.inventory_source + InventorySource.objects.create(group=group, inventory=self.inventory, name=('%s (%s)' % (group_name, self.inventory.name))) + self.logger.info('Group "%s" added', group.name) + if inv_src_group: + self._batch_add_m2m(inv_src_group.children, group) + self._batch_add_m2m(self.inventory_source.groups, group) + if inv_src_group: + self._batch_add_m2m(inv_src_group.children, flush=True) + self._batch_add_m2m(self.inventory_source.groups, flush=True) + if settings.SQL_DEBUG: + self.logger.warning('group updates took %d queries for %d groups', + len(connection.queries) - queries_before, + len(self.all_group.all_groups)) - # For each host in the local list, create it if it doesn't exist in - # the database. Otherwise, update/replace database variables from the - # imported data. Associate with the inventory source group if - # importing from cloud inventory source. + def _update_db_host_from_mem_host(self, db_host, mem_host): + # Update host variables. + db_variables = db_host.variables_dict + if self.overwrite_vars or self.overwrite: + db_variables = mem_host.variables + else: + db_variables.update(mem_host.variables) + update_fields = [] + if db_variables != db_host.variables_dict: + db_host.variables = json.dumps(db_variables) + update_fields.append('variables') + # Update host enabled flag. + enabled = None + if self.enabled_var and self.enabled_var in mem_host.variables: + value = mem_host.variables[self.enabled_var] + if self.enabled_value is not None: + enabled = bool(unicode(self.enabled_value) == unicode(value)) + else: + enabled = bool(value) + if enabled is not None and db_host.enabled != enabled: + db_host.enabled = enabled + update_fields.append('enabled') + # Update host name. + if mem_host.name != db_host.name: + old_name = db_host.name + db_host.name = mem_host.name + update_fields.append('name') + # Update host instance_id. + if self.instance_id_var: + instance_id = mem_host.variables.get(self.instance_id_var, '') + else: + instance_id = '' + if instance_id != db_host.instance_id: + old_instance_id = db_host.instance_id + db_host.instance_id = instance_id + update_fields.append('instance_id') + # Update host and display message(s) on what changed. + if update_fields: + db_host.save(update_fields=update_fields) + if 'name' in update_fields: + self.logger.info('Host renamed from "%s" to "%s"', old_name, mem_host.name) + if 'instance_id' in update_fields: + if old_instance_id: + self.logger.info('Host "%s" instance_id updated', mem_host.name) + else: + self.logger.info('Host "%s" instance_id added', mem_host.name) + if 'variables' in update_fields: + if self.overwrite_vars or self.overwrite: + self.logger.info('Host "%s" variables replaced', mem_host.name) + else: + self.logger.info('Host "%s" variables updated', mem_host.name) + else: + self.logger.info('Host "%s" variables unmodified', mem_host.name) + if 'enabled' in update_fields: + if enabled: + self.logger.info('Host "%s" is now enabled', mem_host.name) + else: + self.logger.info('Host "%s" is now disabled', mem_host.name) + if self.inventory_source.group: + self._batch_add_m2m(self.inventory_source.group.hosts, db_host) + self._batch_add_m2m(self.inventory_source.hosts, db_host) + #host.update_computed_fields(False, False) + + def _create_update_hosts(self): + ''' + For each host in the local list, create it if it doesn't exist in the + database. Otherwise, update/replace database variables from the + imported data. Associate with the inventory source group if importing + from cloud inventory source. + ''' + if settings.SQL_DEBUG: + queries_before = len(connection.queries) + host_pks_updated = set() + mem_host_pk_map = {} + mem_host_instance_id_map = {} + mem_host_name_map = {} + mem_host_names_to_update = set(self.all_group.all_hosts.keys()) for k,v in self.all_group.all_hosts.iteritems(): - variables = json.dumps(v.variables) - defaults = dict(variables=variables, name=k, description='imported') + instance_id = '' + if self.instance_id_var: + instance_id = v.variables.get(self.instance_id_var, '') + if instance_id in self.db_instance_id_map: + mem_host_pk_map[self.db_instance_id_map[instance_id]] = v + elif instance_id: + mem_host_instance_id_map[instance_id] = v + else: + mem_host_name_map[k] = v + + # Update all existing hosts where we know the PK based on instance_id. + for db_host in self.inventory.hosts.filter(active=True, pk__in=mem_host_pk_map.keys()): + mem_host = mem_host_pk_map[db_host.pk] + self._update_db_host_from_mem_host(db_host, mem_host) + host_pks_updated.add(db_host.pk) + mem_host_names_to_update.discard(mem_host.name) + + # Update all existing hosts where we know the instance_id. + for db_host in self.inventory.hosts.filter(active=True, instance_id__in=mem_host_instance_id_map.keys()).exclude(pk__in=host_pks_updated): + mem_host = mem_host_instance_id_map[db_host.instance_id] + self._update_db_host_from_mem_host(db_host, mem_host) + host_pks_updated.add(db_host.pk) + mem_host_names_to_update.discard(mem_host.name) + + # Update all existing hosts by name. + for db_host in self.inventory.hosts.filter(active=True, name__in=mem_host_name_map.keys()).exclude(pk__in=host_pks_updated): + mem_host = mem_host_name_map[db_host.name] + self._update_db_host_from_mem_host(db_host, mem_host) + host_pks_updated.add(db_host.pk) + mem_host_names_to_update.discard(mem_host.name) + + # Create any new hosts. + for mem_host_name in mem_host_names_to_update: + mem_host = self.all_group.all_hosts[mem_host_name] + host_attrs = dict(variables=json.dumps(mem_host.variables), + name=mem_host_name, description='imported') enabled = None - if self.enabled_var and self.enabled_var in v.variables: - value = v.variables[self.enabled_var] + if self.enabled_var and self.enabled_var in mem_host.variables: + value = mem_host.variables[self.enabled_var] if self.enabled_value is not None: enabled = bool(unicode(self.enabled_value) == unicode(value)) else: enabled = bool(value) - defaults['enabled'] = enabled - instance_id = '' + host_attrs['enabled'] = enabled if self.instance_id_var: - instance_id = v.variables.get(self.instance_id_var, '') - defaults['instance_id'] = instance_id - if instance_id in db_instance_id_map: - attrs = {'pk': db_instance_id_map[instance_id]} - elif instance_id: - attrs = {'instance_id': instance_id} - defaults.pop('instance_id') + instance_id = mem_host.variables.get(self.instance_id_var, '') + host_attrs['instance_id'] = instance_id + db_host = self.inventory.hosts.create(**host_attrs) + if enabled is False: + self.logger.info('Host "%s" added (disabled)', mem_host_name) else: - attrs = {'name': k} - defaults.pop('name') - attrs['defaults'] = defaults - host, created = self.inventory.hosts.get_or_create(**attrs) - if created: - if enabled is False: - self.logger.info('Host "%s" added (disabled)', k) - else: - self.logger.info('Host "%s" added', k) - #self.logger.info('Host variables: %s', variables) - else: - db_variables = host.variables_dict - if self.overwrite_vars or self.overwrite: - db_variables = v.variables - else: - db_variables.update(v.variables) - update_fields = [] - if db_variables != host.variables_dict: - host.variables = json.dumps(db_variables) - update_fields.append('variables') - if enabled is not None and host.enabled != enabled: - host.enabled = enabled - update_fields.append('enabled') - if k != host.name: - old_name = host.name - host.name = k - update_fields.append('name') - if instance_id != host.instance_id: - old_instance_id = host.instance_id - host.instance_id = instance_id - update_fields.append('instance_id') - if update_fields: - host.save(update_fields=update_fields) - if 'name' in update_fields: - self.logger.info('Host renamed from "%s" to "%s"', old_name, k) - if 'instance_id' in update_fields: - if old_instance_id: - self.logger.info('Host "%s" instance_id updated', k) - else: - self.logger.info('Host "%s" instance_id added', k) - if 'variables' in update_fields: - if self.overwrite_vars or self.overwrite: - self.logger.info('Host "%s" variables replaced', k) - else: - self.logger.info('Host "%s" variables updated', k) - else: - self.logger.info('Host "%s" variables unmodified', k) - if 'enabled' in update_fields: - if enabled: - self.logger.info('Host "%s" is now enabled', k) - else: - self.logger.info('Host "%s" is now disabled', k) + self.logger.info('Host "%s" added', mem_host_name) if self.inventory_source.group: - self.inventory_source.group.hosts.add(host) - host.inventory_sources.add(self.inventory_source) - host.update_computed_fields(False, False) + self._batch_add_m2m(self.inventory_source.group.hosts, db_host) + self._batch_add_m2m(self.inventory_source.hosts, db_host) + #host.update_computed_fields(False, False) + if self.inventory_source.group: + self._batch_add_m2m(self.inventory_source.group.hosts, flush=True) + self._batch_add_m2m(self.inventory_source.hosts, flush=True) + + if settings.SQL_DEBUG: + self.logger.warning('host updates took %d queries for %d hosts', + len(connection.queries) - queries_before, + len(self.all_group.all_hosts)) + + def _create_update_group_children(self): + ''' + For each imported group, create all parent-child group relationships. + ''' + if settings.SQL_DEBUG: + queries_before = len(connection.queries) + group_names = [k for k,v in self.all_group.all_groups.iteritems() if v.children] + group_group_count = 0 + for db_group in self.inventory.groups.filter(name__in=group_names): + mem_group = self.all_group.all_groups[db_group.name] + group_group_count += len(mem_group.children) + child_names = set([g.name for g in mem_group.children]) + db_children_qs = self.inventory.groups.filter(name__in=child_names) + for db_child in db_children_qs.filter(children__id=db_group.id): + self.logger.info('Group "%s" already child of group "%s"', db_child.name, db_group.name) + for db_child in db_children_qs.exclude(children__id=db_group.id): + self._batch_add_m2m(db_group.children, db_child) + self.logger.info('Group "%s" added as child of "%s"', db_child.name, db_group.name) + self._batch_add_m2m(db_group.children, flush=True) + if settings.SQL_DEBUG: + self.logger.warning('Group-group updates took %d queries for %d group-group relationships', + len(connection.queries) - queries_before, group_group_count) + + def _create_update_group_hosts(self): # For each host in a mem group, add it to the parent(s) to which it # belongs. - for k,v in self.all_group.all_groups.iteritems(): - if not v.hosts: - continue - db_group = self.inventory.groups.get(name=k) - for h in v.hosts: - if h.instance_id: - db_host = self.inventory.hosts.get(instance_id=h.instance_id) - else: - db_host = self.inventory.hosts.get(name=h.name) - if db_host not in db_group.hosts.all(): - db_group.hosts.add(db_host) - self.logger.info('Host "%s" added to group "%s"', h.name, k) - else: - self.logger.info('Host "%s" already in group "%s"', h.name, k) - - # for each group, draw in child group arrangements - for k,v in self.all_group.all_groups.iteritems(): - if not v.children: - continue - db_group = self.inventory.groups.get(name=k) - for g in v.children: - db_child = self.inventory.groups.get(name=g.name) - if db_child not in db_group.hosts.all(): - db_group.children.add(db_child) - self.logger.info('Group "%s" added as child of "%s"', g.name, k) - else: - self.logger.info('Group "%s" already child of group "%s"', g.name, k) + if settings.SQL_DEBUG: + queries_before = len(connection.queries) + group_names = [k for k,v in self.all_group.all_groups.iteritems() if v.hosts] + group_host_count = 0 + for db_group in self.inventory.groups.filter(name__in=group_names): + mem_group = self.all_group.all_groups[db_group.name] + group_host_count += len(mem_group.hosts) + host_names = set([h.name for h in mem_group.hosts if not h.instance_id]) + host_instance_ids = set([h.instance_id for h in mem_group.hosts if h.instance_id]) + db_hosts_qs = self.inventory.hosts.filter(Q(name__in=host_names) | Q(instance_id__in=host_instance_ids)) + for db_host in db_hosts_qs.filter(groups__id=db_group.id): + self.logger.info('Host "%s" already in group "%s"', db_host.name, db_group.name) + for db_host in db_hosts_qs.exclude(groups__id=db_group.id): + self._batch_add_m2m(db_group.hosts, db_host) + self.logger.info('Host "%s" added to group "%s"', db_host.name, db_group.name) + self._batch_add_m2m(db_group.hosts, flush=True) + if settings.SQL_DEBUG: + self.logger.warning('Group-host updates took %d queries for %d group-host relationships', + len(connection.queries) - queries_before, group_host_count) + + def load_into_database(self): + ''' + Load inventory from in-memory groups to the database, overwriting or + merging as appropriate. + ''' + # FIXME: Attribute changes to superuser? + self._build_db_instance_id_map() + self._build_mem_instance_id_map() + if self.overwrite: + self._delete_hosts() + self._delete_groups() + self._delete_group_children_and_hosts() + self._update_inventory() + self._create_update_groups() + self._create_update_hosts() + self._create_update_group_children() + self._create_update_group_hosts() def check_license(self): reader = LicenseReader() @@ -914,6 +1071,9 @@ class Command(NoArgsCommand): status, tb, exc = 'error', '', None try: + if settings.SQL_DEBUG: + queries_before = len(connection.queries) + # Update inventory update for this command line invocation. with ignore_inventory_computed_fields(): if self.inventory_update: @@ -935,7 +1095,12 @@ class Command(NoArgsCommand): else: with disable_activity_stream(): self.load_into_database() + if settings.SQL_DEBUG: + queries_before2 = len(connection.queries) self.inventory.update_computed_fields() + if settings.SQL_DEBUG: + self.logger.warning('update computed fields took %d queries', + len(connection.queries) - queries_before2) self.check_license() if self.inventory_source.group: @@ -943,13 +1108,18 @@ class Command(NoArgsCommand): else: inv_name = '"%s" (id=%s)' % (self.inventory.name, self.inventory.id) - self.logger.info('Inventory import completed for %s in %0.1fs', - inv_name, time.time() - begin) + if settings.SQL_DEBUG: + self.logger.warning('Inventory import completed for %s in %0.1fs', + inv_name, time.time() - begin) + else: + self.logger.info('Inventory import completed for %s in %0.1fs', + inv_name, time.time() - begin) status = 'successful' - if settings.DEBUG: - sqltime = sum(float(x['time']) for x in connection.queries) - self.logger.info('Inventory import required %d queries ' - 'taking %0.3fs', len(connection.queries), + if settings.SQL_DEBUG: + queries_this_import = connection.queries[queries_before:] + sqltime = sum(float(x['time']) for x in queries_this_import) + self.logger.warning('Inventory import required %d queries ' + 'taking %0.3fs', len(queries_this_import), sqltime) except Exception, e: if isinstance(e, KeyboardInterrupt): diff --git a/awx/main/models/inventory.py b/awx/main/models/inventory.py index 90f67b2d8c..59e7e8ca18 100644 --- a/awx/main/models/inventory.py +++ b/awx/main/models/inventory.py @@ -21,7 +21,7 @@ import zmq # Django from django.conf import settings from django.db import models -from django.db.models import CASCADE, SET_NULL, PROTECT +from django.db.models import Q from django.utils.translation import ugettext_lazy as _ from django.core.exceptions import ValidationError, NON_FIELD_ERRORS from django.core.urlresolvers import reverse @@ -40,6 +40,7 @@ __all__ = ['Inventory', 'Host', 'Group', 'InventorySource', 'InventoryUpdate'] logger = logging.getLogger('awx.main.models.inventory') + class Inventory(CommonModel): ''' an inventory source contains lists and hosts. @@ -120,24 +121,174 @@ class Inventory(CommonModel): variables_dict = VarsDictProperty('variables') + def get_group_hosts_map(self, active=None): + ''' + Return dictionary mapping group_id to set of child host_id's. + ''' + # FIXME: Cache this mapping? + group_hosts_kw = dict(group__inventory_id=self.pk, host__inventory_id=self.pk) + if active is not None: + group_hosts_kw['group__active'] = active + group_hosts_kw['host__active'] = active + group_hosts_qs = Group.hosts.through.objects.filter(**group_hosts_kw) + group_hosts_qs = group_hosts_qs.values_list('group_id', 'host_id') + group_hosts_map = {} + for group_id, host_id in group_hosts_qs: + group_host_ids = group_hosts_map.setdefault(group_id, set()) + group_host_ids.add(host_id) + return group_hosts_map + + def get_group_parents_map(self, active=None): + ''' + Return dictionary mapping group_id to set of parent group_id's. + ''' + # FIXME: Cache this mapping? + group_parents_kw = dict(from_group__inventory_id=self.pk, to_group__inventory_id=self.pk) + if active is not None: + group_parents_kw['from_group__active'] = active + group_parents_kw['to_group__active'] = active + group_parents_qs = Group.parents.through.objects.filter(**group_parents_kw) + group_parents_qs = group_parents_qs.values_list('from_group_id', 'to_group_id') + group_parents_map = {} + for from_group_id, to_group_id in group_parents_qs: + group_parents = group_parents_map.setdefault(from_group_id, set()) + group_parents.add(to_group_id) + return group_parents_map + + def get_group_children_map(self, active=None): + ''' + Return dictionary mapping group_id to set of child group_id's. + ''' + # FIXME: Cache this mapping? + group_parents_kw = dict(from_group__inventory_id=self.pk, to_group__inventory_id=self.pk) + if active is not None: + group_parents_kw['from_group__active'] = active + group_parents_kw['to_group__active'] = active + group_parents_qs = Group.parents.through.objects.filter(**group_parents_kw) + group_parents_qs = group_parents_qs.values_list('from_group_id', 'to_group_id') + group_children_map = {} + for from_group_id, to_group_id in group_parents_qs: + group_children = group_children_map.setdefault(to_group_id, set()) + group_children.add(from_group_id) + return group_children_map + + def update_host_computed_fields(self): + ''' + Update computed fields for all active hosts in this inventory. + ''' + hosts_to_update = {} + hosts_qs = self.hosts.filter(active=True) + # Define queryset of all hosts with active failures. + hosts_with_active_failures = hosts_qs.filter(last_job_host_summary__isnull=False, last_job_host_summary__job__active=True, last_job_host_summary__failed=True).values_list('pk', flat=True) + # Find all hosts that need the has_active_failures flag set. + hosts_to_set = hosts_qs.filter(has_active_failures=False, pk__in=hosts_with_active_failures) + for host_pk in hosts_to_set.values_list('pk', flat=True): + host_updates = hosts_to_update.setdefault(host_pk, {}) + host_updates['has_active_failures'] = True + # Find all hosts that need the has_active_failures flag cleared. + hosts_to_clear = hosts_qs.filter(has_active_failures=True).exclude(pk__in=hosts_with_active_failures) + for host_pk in hosts_to_clear.values_list('pk', flat=True): + host_updates = hosts_to_update.setdefault(host_pk, {}) + host_updates['has_active_failures'] = False + # Define queryset of all hosts with cloud inventory sources. + hosts_with_cloud_inventory = hosts_qs.filter(inventory_sources__active=True, inventory_sources__source__in=CLOUD_INVENTORY_SOURCES).values_list('pk', flat=True) + # Find all hosts that need the has_inventory_sources flag set. + hosts_to_set = hosts_qs.filter(has_inventory_sources=False, pk__in=hosts_with_cloud_inventory) + for host_pk in hosts_to_set.values_list('pk', flat=True): + host_updates = hosts_to_update.setdefault(host_pk, {}) + host_updates['has_inventory_sources'] = True + # Find all hosts that need the has_inventory_sources flag cleared. + hosts_to_clear = hosts_qs.filter(has_inventory_sources=True).exclude(pk__in=hosts_with_cloud_inventory) + for host_pk in hosts_to_clear.values_list('pk', flat=True): + host_updates = hosts_to_updates.setdefault(host_pk, {}) + host_updates['has_inventory_sources'] = False + # Now apply updates to hosts where needed. + for host in hosts_qs.filter(pk__in=hosts_to_update.keys()): + host_updates = hosts_to_update[host.pk] + for field, value in host_updates.items(): + setattr(host, field, value) + host.save(update_fields=host_updates.keys()) + + def update_group_computed_fields(self): + ''' + Update computed fields for all active groups in this inventory. + ''' + group_children_map = self.get_group_children_map(active=True) + group_hosts_map = self.get_group_hosts_map(active=True) + active_host_pks = set(self.hosts.filter(active=True).values_list('pk', flat=True)) + failed_host_pks = set(self.hosts.filter(active=True, last_job_host_summary__job__active=True, last_job_host_summary__failed=True).values_list('pk', flat=True)) + active_group_pks = set(self.groups.filter(active=True).values_list('pk', flat=True)) + failed_group_pks = set() # Update below as we check each group. + groups_with_cloud_pks = set(self.groups.filter(active=True, inventory_sources__active=True, inventory_sources__source__in=CLOUD_INVENTORY_SOURCES).values_list('pk', flat=True)) + groups_to_update = {} + + # Build list of group pks to check, starting with the groups at the + # deepest level within the tree. + root_group_pks = set(self.root_groups.values_list('pk', flat=True)) + group_depths = {} # pk: max_depth + def update_group_depths(group_pk, current_depth=0): + max_depth = group_depths.get(group_pk, 0) + if current_depth > max_depth: + group_depths[group_pk] = current_depth + for child_pk in group_children_map.get(group_pk, set()): + update_group_depths(child_pk, current_depth + 1) + for group_pk in root_group_pks: + update_group_depths(group_pk) + group_pks_to_check = [x[1] for x in sorted([(v,k) for k,v in group_depths.items()], reverse=True)] + + for group_pk in group_pks_to_check: + # Get all children and host pks for this group. + parent_pks_to_check = set([group_pk]) + parent_pks_checked = set() + child_pks = set() + host_pks = set() + while parent_pks_to_check: + for parent_pk in list(parent_pks_to_check): + c_ids = group_children_map.get(parent_pk, set()) + child_pks.update(c_ids) + parent_pks_to_check.remove(parent_pk) + parent_pks_checked.add(parent_pk) + parent_pks_to_check.update(c_ids - parent_pks_checked) + h_ids = group_hosts_map.get(parent_pk, set()) + host_pks.update(h_ids) + # Define updates needed for this group. + group_updates = groups_to_update.setdefault(group_pk, {}) + group_updates.update({ + 'total_hosts': len(active_host_pks & host_pks), + 'has_active_failures': bool(failed_host_pks & host_pks), + 'hosts_with_active_failures': len(failed_host_pks & host_pks), + 'total_groups': len(child_pks), + 'groups_with_active_failures': len(failed_group_pks & child_pks), + 'has_inventory_sources': bool(group_pk in groups_with_cloud_pks), + }) + if group_updates['has_active_failures']: + failed_group_pks.add(group_pk) + + # Now apply updates to each group as needed. + for group in self.groups.filter(pk__in=groups_to_update.keys()): + group_updates = groups_to_update[group.pk] + for field, value in group_updates.items(): + if getattr(group, field) != value: + setattr(group, field, value) + else: + group_updates.pop(field) + if group_updates: + group.save(update_fields=group_updates.keys()) + def update_computed_fields(self, update_groups=True, update_hosts=True): ''' Update model fields that are computed from database relationships. ''' logger.debug("Going to update inventory computed fields") if update_hosts: - for host in self.hosts.filter(active=True): - host.update_computed_fields(update_inventory=False, - update_groups=False) + self.update_host_computed_fields() if update_groups: - for group in self.groups.filter(active=True): - group.update_computed_fields() + self.update_group_computed_fields() active_hosts = self.hosts.filter(active=True) failed_hosts = active_hosts.filter(has_active_failures=True) active_groups = self.groups.filter(active=True) failed_groups = active_groups.filter(has_active_failures=True) active_inventory_sources = self.inventory_sources.filter(active=True, source__in=CLOUD_INVENTORY_SOURCES) - #failed_inventory_sources = active_inventory_sources.filter(last_update_failed=True) failed_inventory_sources = active_inventory_sources.filter(last_job_failed=True) computed_fields = { 'has_active_failures': bool(failed_hosts.count()), @@ -232,14 +383,15 @@ class Host(CommonModelNameNotUnique): def get_absolute_url(self): return reverse('api:host_detail', args=(self.pk,)) - def mark_inactive(self, save=True): + def mark_inactive(self, save=True, from_inventory_import=False): ''' When marking hosts inactive, remove all associations to related inventory sources. ''' super(Host, self).mark_inactive(save=save) - self.inventory_sources.clear() - self.clear_cached_values() + if not from_inventory_import: + self.inventory_sources.clear() + self.clear_cached_values() def update_computed_fields(self, update_inventory=True, update_groups=True): ''' @@ -280,10 +432,19 @@ class Host(CommonModelNameNotUnique): Return all groups of which this host is a member, avoiding infinite recursion in the case of cyclical group relations. ''' - qs = self.groups.distinct() - for group in self.groups.all(): - qs = qs | group.all_parents - return qs + group_parents_map = self.inventory.get_group_parents_map() + group_pks = set(self.groups.values_list('pk', flat=True)) + child_pks_to_check = set() + child_pks_to_check.update(group_pks) + child_pks_checked = set() + while child_pks_to_check: + for child_pk in list(child_pks_to_check): + p_ids = group_parents_map.get(child_pk, set()) + group_pks.update(p_ids) + child_pks_to_check.remove(child_pk) + child_pks_checked.add(child_pk) + child_pks_to_check.update(p_ids - child_pks_checked) + return Group.objects.filter(pk__in=group_pks).distinct() def update_cached_values(self): cacheable_data = {"%s_all_groups" % self.id: [{'id': g.id, 'name': g.name} for g in self.all_groups.all()], @@ -422,7 +583,7 @@ class Group(CommonModelNameNotUnique): mark_actual() update_inventory_computed_fields.delay(self.id, True) - def mark_inactive(self, save=True, recompute=True): + def mark_inactive(self, save=True, recompute=True, from_inventory_import=False): ''' When marking groups inactive, remove all associations to related groups/hosts/inventory_sources. @@ -436,7 +597,9 @@ class Group(CommonModelNameNotUnique): self.hosts.clear() i = self.inventory - if recompute: + if from_inventory_import: + super(Group, self).mark_inactive(save=save) + elif recompute: with ignore_inventory_computed_fields(): mark_actual() i.update_computed_fields() @@ -475,16 +638,21 @@ class Group(CommonModelNameNotUnique): def get_all_parents(self, except_pks=None): ''' - Return all parents of this group recursively, avoiding infinite - recursion in the case of cyclical relations. The group itself will be - excluded unless there is a cycle leading back to it. + Return all parents of this group recursively. The group itself will + be excluded unless there is a cycle leading back to it. ''' - except_pks = except_pks or set() - except_pks.add(self.pk) - qs = self.parents.distinct() - for group in self.parents.exclude(pk__in=except_pks): - qs = qs | group.get_all_parents(except_pks) - return qs + group_parents_map = self.inventory.get_group_parents_map() + child_pks_to_check = set([self.pk]) + child_pks_checked = set() + parent_pks = set() + while child_pks_to_check: + for child_pk in list(child_pks_to_check): + p_ids = group_parents_map.get(child_pk, set()) + parent_pks.update(p_ids) + child_pks_to_check.remove(child_pk) + child_pks_checked.add(child_pk) + child_pks_to_check.update(p_ids - child_pks_checked) + return Group.objects.filter(pk__in=parent_pks).distinct() @property def all_parents(self): @@ -492,16 +660,21 @@ class Group(CommonModelNameNotUnique): def get_all_children(self, except_pks=None): ''' - Return all children of this group recursively, avoiding infinite - recursion in the case of cyclical relations. The group itself will be - excluded unless there is a cycle leading back to it. + Return all children of this group recursively. The group itself will + be excluded unless there is a cycle leading back to it. ''' - except_pks = except_pks or set() - except_pks.add(self.pk) - qs = self.children.distinct() - for group in self.children.exclude(pk__in=except_pks): - qs = qs | group.get_all_children(except_pks) - return qs + group_children_map = self.inventory.get_group_children_map() + parent_pks_to_check = set([self.pk]) + parent_pks_checked = set() + child_pks = set() + while parent_pks_to_check: + for parent_pk in list(parent_pks_to_check): + c_ids = group_children_map.get(parent_pk, set()) + child_pks.update(c_ids) + parent_pks_to_check.remove(parent_pk) + parent_pks_checked.add(parent_pk) + parent_pks_to_check.update(c_ids - parent_pks_checked) + return Group.objects.filter(pk__in=child_pks).distinct() @property def all_children(self): @@ -509,15 +682,22 @@ class Group(CommonModelNameNotUnique): def get_all_hosts(self, except_group_pks=None): ''' - Return all hosts associated with this group or any of its children, - avoiding infinite recursion in the case of cyclical group relations. + Return all hosts associated with this group or any of its children. ''' - except_group_pks = except_group_pks or set() - except_group_pks.add(self.pk) - qs = self.hosts.distinct() - for group in self.children.exclude(pk__in=except_group_pks): - qs = qs | group.get_all_hosts(except_group_pks) - return qs + group_children_map = self.inventory.get_group_children_map() + group_hosts_map = self.inventory.get_group_hosts_map() + parent_pks_to_check = set([self.pk]) + parent_pks_checked = set() + host_pks = set() + while parent_pks_to_check: + for parent_pk in list(parent_pks_to_check): + c_ids = group_children_map.get(parent_pk, set()) + parent_pks_to_check.remove(parent_pk) + parent_pks_checked.add(parent_pk) + parent_pks_to_check.update(c_ids - parent_pks_checked) + h_ids = group_hosts_map.get(parent_pk, set()) + host_pks.update(h_ids) + return Host.objects.filter(pk__in=host_pks).distinct() @property def all_hosts(self): diff --git a/awx/main/tests/commands.py b/awx/main/tests/commands.py index e05ec81388..99c6339b32 100644 --- a/awx/main/tests/commands.py +++ b/awx/main/tests/commands.py @@ -445,7 +445,7 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest): contact_name='AWX Admin', contact_email='awx@example.com', license_date=int(time.time() + 3600), - instance_count=500, + instance_count=10000, ) handle, license_path = tempfile.mkstemp(suffix='.json') os.close(handle) @@ -565,7 +565,7 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest): result, stdout, stderr = self.run_command('inventory_import', inventory_id=new_inv.pk, source=inv_src) - self.assertEqual(result, None) + self.assertEqual(result, None, stdout + stderr) # Check that inventory is populated as expected. new_inv = Inventory.objects.get(pk=new_inv.pk) expected_group_names = set(['servers', 'dbservers', 'webservers']) @@ -637,7 +637,7 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest): source=self.ini_path, overwrite=overwrite, overwrite_vars=overwrite_vars) - self.assertEqual(result, None) + self.assertEqual(result, None, stdout + stderr) # Check that inventory is populated as expected. new_inv = Inventory.objects.get(pk=new_inv.pk) expected_group_names = set(['servers', 'dbservers', 'webservers', @@ -828,7 +828,7 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest): result, stdout, stderr = self.run_command('inventory_import', inventory_id=new_inv.pk, source=source) - self.assertEqual(result, None) + self.assertEqual(result, None, stdout + stderr) # Check that inventory is populated as expected. new_inv = Inventory.objects.get(pk=new_inv.pk) self.assertEqual(old_inv.variables_dict, new_inv.variables_dict) @@ -860,14 +860,13 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest): new_inv = self.organizations[0].inventories.create(name='newec2') self.assertEqual(new_inv.hosts.count(), 0) self.assertEqual(new_inv.groups.count(), 0) - #inv_file = os.path.join(os.path.dirname(__file__), 'data', - # 'large_ec2_inventory.py') os.chdir(os.path.join(os.path.dirname(__file__), 'data')) inv_file = 'large_ec2_inventory.py' + settings.DEBUG = True result, stdout, stderr = self.run_command('inventory_import', inventory_id=new_inv.pk, source=inv_file) - self.assertEqual(result, None, stdout+stderr) + self.assertEqual(result, None, stdout + stderr) # Check that inventory is populated as expected within a reasonable # amount of time. Computed fields should also be updated. new_inv = Inventory.objects.get(pk=new_inv.pk) @@ -875,5 +874,45 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest): self.assertNotEqual(new_inv.groups.count(), 0) self.assertNotEqual(new_inv.total_hosts, 0) self.assertNotEqual(new_inv.total_groups, 0) - self.assertElapsedLessThan(60) + self.assertElapsedLessThan(30) + def _get_ngroups_for_nhosts(self, n): + if n > 0: + return min(n, 10) + ((n - 1) / 10 + 1) + ((n - 1) / 100 + 1) + ((n - 1) / 1000 + 1) + else: + return 0 + + def _check_largeinv_import(self, new_inv, nhosts, nhosts_inactive=0): + self._start_time = time.time() + inv_file = os.path.join(os.path.dirname(__file__), 'data', 'largeinv.py') + ngroups = self._get_ngroups_for_nhosts(nhosts) + os.environ['NHOSTS'] = str(nhosts) + result, stdout, stderr = self.run_command('inventory_import', + inventory_id=new_inv.pk, + source=inv_file, + overwrite=True, verbosity=0) + self.assertEqual(result, None, stdout + stderr) + # Check that inventory is populated as expected within a reasonable + # amount of time. Computed fields should also be updated. + new_inv = Inventory.objects.get(pk=new_inv.pk) + self.assertEqual(new_inv.hosts.filter(active=True).count(), nhosts) + self.assertEqual(new_inv.groups.filter(active=True).count(), ngroups) + self.assertEqual(new_inv.hosts.filter(active=False).count(), nhosts_inactive) + self.assertEqual(new_inv.total_hosts, nhosts) + self.assertEqual(new_inv.total_groups, ngroups) + self.assertElapsedLessThan(30) + + def test_large_inventory_file(self): + new_inv = self.organizations[0].inventories.create(name='largeinv') + self.assertEqual(new_inv.hosts.count(), 0) + self.assertEqual(new_inv.groups.count(), 0) + settings.DEBUG = True + nhosts = 2000 + # Test initial import into empty inventory. + self._check_largeinv_import(new_inv, nhosts, 0) + # Test re-importing and overwriting. + self._check_largeinv_import(new_inv, nhosts, 0) + # Test re-importing with only half as many hosts. + self._check_largeinv_import(new_inv, nhosts / 2, nhosts / 2) + # Test re-importing that clears all hosts. + self._check_largeinv_import(new_inv, 0, nhosts) diff --git a/awx/main/tests/data/largeinv.py b/awx/main/tests/data/largeinv.py new file mode 100755 index 0000000000..178dca6881 --- /dev/null +++ b/awx/main/tests/data/largeinv.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python + +# Python +import json +import optparse +import os + +nhosts = int(os.environ.get('NHOSTS', 100)) + +inv_list = { + '_meta': { + 'hostvars': {}, + }, +} + +for n in xrange(nhosts): + hostname = 'host-%08d.example.com' % n + group_evens_odds = 'evens.example.com' if n % 2 == 0 else 'odds.example.com' + group_threes = 'threes.example.com' if n % 3 == 0 else '' + group_fours = 'fours.example.com' if n % 4 == 0 else '' + group_fives = 'fives.example.com' if n % 5 == 0 else '' + group_sixes = 'sixes.example.com' if n % 6 == 0 else '' + group_sevens = 'sevens.example.com' if n % 7 == 0 else '' + group_eights = 'eights.example.com' if n % 8 == 0 else '' + group_nines = 'nines.example.com' if n % 9 == 0 else '' + group_tens = 'tens.example.com' if n % 10 == 0 else '' + group_by_10s = 'group-%07dX.example.com' % (n / 10) + group_by_100s = 'group-%06dXX.example.com' % (n / 100) + group_by_1000s = 'group-%05dXXX.example.com' % (n / 1000) + for group in [group_evens_odds, group_threes, group_fours, group_fives, group_sixes, group_sevens, group_eights, group_nines, group_tens, group_by_10s]: + if not group: + continue + if group in inv_list: + inv_list[group]['hosts'].append(hostname) + else: + inv_list[group] = {'hosts': [hostname], 'children': [], 'vars': {'group_prefix': group.split('.')[0]}} + if group_by_1000s not in inv_list: + inv_list[group_by_1000s] = {'hosts': [], 'children': [], 'vars': {'group_prefix': group_by_1000s.split('.')[0]}} + if group_by_100s not in inv_list: + inv_list[group_by_100s] = {'hosts': [], 'children': [], 'vars': {'group_prefix': group_by_100s.split('.')[0]}} + if group_by_100s not in inv_list[group_by_1000s]['children']: + inv_list[group_by_1000s]['children'].append(group_by_100s) + if group_by_10s not in inv_list[group_by_100s]['children']: + inv_list[group_by_100s]['children'].append(group_by_10s) + inv_list['_meta']['hostvars'][hostname] = { + 'ansible_ssh_user': 'example', + 'ansible_connection': 'local', + 'host_prefix': hostname.split('.')[0], + 'host_id': n, + } + +if __name__ == '__main__': + parser = optparse.OptionParser() + parser.add_option('--list', action='store_true', dest='list') + parser.add_option('--host', dest='hostname', default='') + options, args = parser.parse_args() + if options.list: + print json.dumps(inv_list, indent=4) + elif options.hostname: + print json.dumps(inv_list['_meta']['hostvars'][options.hostname], indent=4) + else: + print json.dumps({}, indent=4) + diff --git a/awx/main/tests/inventory.py b/awx/main/tests/inventory.py index 79f9a2c5c0..f5da1663be 100644 --- a/awx/main/tests/inventory.py +++ b/awx/main/tests/inventory.py @@ -614,7 +614,7 @@ class InventoryTest(BaseTest): # data used for testing listing all hosts that are transitive members of a group g2 = Group.objects.get(name='web4') - nh = Host.objects.create(name='newhost.example.com', inventory=inva, + nh = Host.objects.create(name='newhost.example.com', inventory=g2.inventory, created_by=self.super_django_user) g2.hosts.add(nh) g2.save()