diff --git a/awx/main/management/commands/inventory_import.py b/awx/main/management/commands/inventory_import.py index fe77558621..28b3c24011 100644 --- a/awx/main/management/commands/inventory_import.py +++ b/awx/main/management/commands/inventory_import.py @@ -2,21 +2,20 @@ # All Rights Reserved. # Python -import datetime -import logging -import sys -from optparse import make_option -import subprocess -import traceback import glob -import exceptions +import json +import logging +from optparse import make_option +import os +import shlex +import subprocess +import sys +import traceback # Django -from django.core.management.base import BaseCommand, CommandError +from django.core.management.base import NoArgsCommand, CommandError from django.db import transaction from django.contrib.auth.models import User -from django.utils.dateparse import parse_datetime -from django.utils.timezone import now # AWX from awx.main.models import * @@ -340,7 +339,7 @@ class GenericLoader(object): LOGGER.debug("analyzing type of source") if not os.path.exists(src): LOGGER.debug("source missing") - raise ImportException("source does not exist") + raise CommandError("source does not exist") if os.path.isdir(src): self.memGroup = memGroup = MemGroup('all', src) for f in glob.glob("%s/*" % src): @@ -365,15 +364,14 @@ class GenericLoader(object): def result(self): return self.memGroup -class Command(BaseCommand): +class Command(NoArgsCommand): ''' Management command to import directory, INI, or dynamic inventory ''' help = 'Import or sync external inventory sources' - args = '[, , ...]' - option_list = BaseCommand.option_list + ( + option_list = NoArgsCommand.option_list + ( make_option('--inventory-name', dest='inventory_name', type='str', default=None, metavar='n', help='name of inventory source to sync'), make_option('--inventory-id', dest='inventory_id', type='int', default=None, metavar='i', @@ -400,14 +398,7 @@ class Command(BaseCommand): self.logger.propagate = False @transaction.commit_on_success - def handle(self, *args, **options): - try: - self.main(args, options) - except ImportException, ie: - print ie.msg - - def main(self, args, options): - + def handle_noargs(self, **options): self.verbosity = int(options.get('verbosity', 1)) self.init_logging() @@ -422,17 +413,13 @@ class Command(BaseCommand): LOGGER.debug("id=%s" % id) if name is not None and id is not None: - self.logger.error("--inventory-name and --inventory-id are mutually exclusive") - sys.exit(1) + raise CommandError("--inventory-name and --inventory-id are mutually exclusive") if name is None and id is None: - self.logger.error("--inventory-name or --inventory-id are required") - sys.exit(1) + raise CommandError("--inventory-name or --inventory-id is required") if (overwrite or overwrite_vars) and keep_vars: - self.logger.error("--overwrite/--overwrite-vars and --keep-vars are mutually exclusive") - sys.exit(1) + raise CommandError("--overwrite/--overwrite-vars and --keep-vars are mutually exclusive") if not source: - self.logger.error("--source is required") - sys.exit(1) + raise CommandError("--source is required") LOGGER.debug("preparing loader") @@ -451,7 +438,7 @@ class Command(BaseCommand): inventory = Inventory.objects.filter(name=name) count = inventory.count() if count != 1: - raise ImportException("%d inventory objects matched, expected 1" % count) + raise CommandError("%d inventory objects matched, expected 1" % count) inventory = inventory.all()[0] print "MODIFYING INVENTORY: %s" % inventory.name diff --git a/awx/main/tests/commands.py b/awx/main/tests/commands.py index a07b66a107..78cf7bc915 100644 --- a/awx/main/tests/commands.py +++ b/awx/main/tests/commands.py @@ -7,6 +7,7 @@ import os import StringIO import sys import tempfile +import time # Django from django.conf import settings @@ -16,10 +17,38 @@ from django.core.management.base import CommandError from django.utils.timezone import now # AWX +from awx.main.licenses import LicenseWriter from awx.main.models import * from awx.main.tests.base import BaseTest -__all__ = ['CleanupDeletedTest'] +__all__ = ['CleanupDeletedTest', 'InventoryImportTest'] + +TEST_INVENTORY_INI = '''\ +[webservers] +web1.example.com +web2.example.com +web3.example.com + +[webservers:vars] +webvar=blah + +[dbservers] +db1.example.com +db2.example.com + +[dbservers:vars] +dbvar=ugh + +[servers:children] +webservers +dbservers + +[servers:vars] +varb=B + +[all:vars] +vara=A +''' class BaseCommandTest(BaseTest): ''' @@ -45,44 +74,7 @@ class BaseCommandTest(BaseTest): if os.path.exists(tf): os.remove(tf) - def run_command(self, name, *args, **options): - ''' - Run a management command and capture its stdout/stderr along with any - exceptions. - ''' - command_runner = options.pop('command_runner', call_command) - stdin_fileobj = options.pop('stdin_fileobj', None) - options.setdefault('verbosity', 1) - options.setdefault('interactive', False) - original_stdin = sys.stdin - original_stdout = sys.stdout - original_stderr = sys.stderr - if stdin_fileobj: - sys.stdin = stdin_fileobj - sys.stdout = StringIO.StringIO() - sys.stderr = StringIO.StringIO() - result = None - try: - result = command_runner(name, *args, **options) - except Exception, e: - result = e - except SystemExit, e: - result = e - finally: - captured_stdout = sys.stdout.getvalue() - captured_stderr = sys.stderr.getvalue() - sys.stdin = original_stdin - sys.stdout = original_stdout - sys.stderr = original_stderr - return result, captured_stdout, captured_stderr - -class CleanupDeletedTest(BaseCommandTest): - ''' - Test cases for cleanup_deleted management command. - ''' - - def setUp(self): - super(CleanupDeletedTest, self).setUp() + def create_test_inventories(self): self.setup_users() self.organizations = self.make_organizations(self.super_django_user, 2) self.projects = self.make_projects(self.normal_django_user, 2) @@ -124,6 +116,47 @@ class CleanupDeletedTest(BaseCommandTest): group.parents.add(groups[3]) self.groups.extend(groups) + + def run_command(self, name, *args, **options): + ''' + Run a management command and capture its stdout/stderr along with any + exceptions. + ''' + command_runner = options.pop('command_runner', call_command) + stdin_fileobj = options.pop('stdin_fileobj', None) + options.setdefault('verbosity', 1) + options.setdefault('interactive', False) + original_stdin = sys.stdin + original_stdout = sys.stdout + original_stderr = sys.stderr + if stdin_fileobj: + sys.stdin = stdin_fileobj + sys.stdout = StringIO.StringIO() + sys.stderr = StringIO.StringIO() + result = None + try: + result = command_runner(name, *args, **options) + except Exception, e: + result = e + except SystemExit, e: + result = e + finally: + captured_stdout = sys.stdout.getvalue() + captured_stderr = sys.stderr.getvalue() + sys.stdin = original_stdin + sys.stdout = original_stdout + sys.stderr = original_stderr + return result, captured_stdout, captured_stderr + +class CleanupDeletedTest(BaseCommandTest): + ''' + Test cases for cleanup_deleted management command. + ''' + + def setUp(self): + super(CleanupDeletedTest, self).setUp() + self.create_test_inventories() + def get_model_counts(self): def get_models(m): if not m._meta.abstract: @@ -204,3 +237,108 @@ class CleanupDeletedTest(BaseCommandTest): counts_after = self.get_user_counts() self.assertNotEqual(counts_before, counts_after) self.assertFalse(counts_after[1]) + +class InventoryImportTest(BaseCommandTest): + ''' + Test cases for inventory_import management command. + ''' + + def setUp(self): + super(InventoryImportTest, self).setUp() + self.create_test_inventories() + self.create_test_ini() + self.create_test_license_file() + + def create_test_license_file(self): + writer = LicenseWriter( + company_name='AWX', + contact_name='AWX Admin', + contact_email='awx@example.com', + license_date=int(time.time() + 3600), + instance_count=500, + ) + handle, license_path = tempfile.mkstemp(suffix='.json') + os.close(handle) + writer.write_file(license_path) + self._temp_files.append(license_path) + os.environ['AWX_LICENSE_FILE'] = license_path + + def create_test_ini(self): + handle, self.ini_path = tempfile.mkstemp(suffix='.txt') + ini_file = os.fdopen(handle, 'w') + ini_file.write(TEST_INVENTORY_INI) + ini_file.close() + self._temp_files.append(self.ini_path) + + def test_invalid_options(self): + inventory_id = self.inventories[0].pk + inventory_name = self.inventories[0].name + # No options specified. + result, stdout, stderr = self.run_command('inventory_import') + self.assertTrue(isinstance(result, CommandError), result) + self.assertTrue('inventory-id' in str(result)) + self.assertTrue('required' in str(result)) + # Both inventory ID and name. + result, stdout, stderr = self.run_command('inventory_import', + inventory_id=inventory_id, + inventory_name=inventory_name) + self.assertTrue(isinstance(result, CommandError), result) + self.assertTrue('inventory-id' in str(result)) + self.assertTrue('exclusive' in str(result)) + # Inventory ID with overwrite and keep_vars. + result, stdout, stderr = self.run_command('inventory_import', + inventory_id=inventory_id, + overwrite=True, keep_vars=True) + self.assertTrue(isinstance(result, CommandError), result) + self.assertTrue('overwrite-vars' in str(result)) + self.assertTrue('exclusive' in str(result)) + result, stdout, stderr = self.run_command('inventory_import', + inventory_id=inventory_id, + overwrite_vars=True, + keep_vars=True) + self.assertTrue(isinstance(result, CommandError), result) + self.assertTrue('overwrite-vars' in str(result)) + self.assertTrue('exclusive' in str(result)) + # Inventory ID, but no source. + result, stdout, stderr = self.run_command('inventory_import', + inventory_id=inventory_id) + self.assertTrue(isinstance(result, CommandError), result) + self.assertTrue('--source' in str(result)) + self.assertTrue('required' in str(result)) + # Inventory ID, with invalid source. + invalid_source = ''.join([os.path.splitext(self.ini_path)[0] + '-invalid', + os.path.splitext(self.ini_path)[1]]) + result, stdout, stderr = self.run_command('inventory_import', + inventory_id=inventory_id, + source=invalid_source) + self.assertTrue(isinstance(result, CommandError), result) + self.assertTrue('not exist' in str(result)) + # Invalid inventory ID. + invalid_id = Inventory.objects.order_by('-pk')[0].pk + 1 + result, stdout, stderr = self.run_command('inventory_import', + inventory_id=invalid_id, + source=self.ini_path) + self.assertTrue(isinstance(result, CommandError), result) + self.assertTrue('matched' in str(result)) + # Invalid inventory name. + invalid_name = 'invalid inventory name' + result, stdout, stderr = self.run_command('inventory_import', + inventory_name=invalid_name, + source=self.ini_path) + self.assertTrue(isinstance(result, CommandError), result) + self.assertTrue('matched' in str(result)) + + def test_ini_file(self): + # New empty inventory. + new_inv = self.organizations[0].inventories.create(name='newb') + self.assertEqual(new_inv.hosts.count(), 0) + self.assertEqual(new_inv.groups.count(), 0) + result, stdout, stderr = self.run_command('inventory_import', + inventory_id=new_inv.pk, + source=self.ini_path) + self.assertEqual(result, None) + # FIXME + + def test_executable_file(self): + pass + # FIXME