diff --git a/hub/management/commands/import_mps.py b/hub/management/commands/import_mps.py index 148282e1a..9cc1e736d 100644 --- a/hub/management/commands/import_mps.py +++ b/hub/management/commands/import_mps.py @@ -191,7 +191,7 @@ def import_mps(self): ).values_list("data") ) - for (party, count) in Counter(all_parties).most_common(): + for party, count in Counter(all_parties).most_common(): shade = party_shades.get(party[0], "#DCDCDC") parties.append(dict(title=party[0], shader=shade)) @@ -244,7 +244,3 @@ def check_for_duplicate_mps(self): mps_to_delete = list(duplicate_mps) mps_to_delete.remove(least_recent_mp) Person.objects.filter(external_id__in=mps_to_delete).delete() - - print("Rerunning MP import scripts") - for command_name in MP_IMPORT_COMMANDS: - call_command(command_name) diff --git a/hub/management/commands/run_all_import_scripts.py b/hub/management/commands/run_all_import_scripts.py new file mode 100644 index 000000000..f63a40a81 --- /dev/null +++ b/hub/management/commands/run_all_import_scripts.py @@ -0,0 +1,93 @@ +from os import listdir +from os.path import isfile, join + +from django.core.management import call_command +from django.core.management.base import BaseCommand + + +class Command(BaseCommand): + help = "Run all of the import scripts" + + def get_scripts(self, *args, **options): + path = "hub/management/commands" + files = [f for f in listdir(path) if isfile(join(path, f))] + # Remove .py suffix, and search for generate_ and import_ scripts + scripts = [ + f[:-3] + for f in files + if ("generate_" in f or f[:7] == "import_") and f != "generate_csv" + ] + generators = [] + importers = [] + for script in scripts: + if "generate_" in script: + generators.append(script) + else: + importers.append(script) + + return {"generators": generators, "importers": importers} + + def add_arguments(self, parser): + parser.add_argument( + "-g", + "--generate", + action="store_true", + help="Run 'generate_' scripts as well as 'import_' scripts", + ) + + def run_generator_scripts(self, generators, *args, **options): + total = str(len(generators)) + failed_generators = {} + for i, generator in enumerate(generators): + print(f"Running command: {generator} ({str(i+1)}/{total})") + try: + call_command(generator) + except Exception as e: + print(f"Error raised: {e}") + print("Moving to next generator...") + failed_generators[generator] = e + print("\n") + + print("Failed generators:") + for generator, e in failed_generators.items(): + print(f" {generator}: {e}") + + def run_importer_scripts(self, imports, *args, **options): + total = str(len(imports)) + i = 1 + failed_imports = {} + priority_imports = ["import_areas", "import_mps", "import_mps_election_results"] + for importer in priority_imports: + imports.remove(importer) + print(f"Running command: {importer} ({str(i)}/{total})") + try: + call_command(importer) + except Exception as e: + print(f"Error raised: {e}") + print("Moving to next importer...") + failed_imports[importer] = e + print("\n") + i += 1 + if failed_imports != {}: + print("One of the priority importers failed. Please fix, and try again") + exit() + for importer in imports: + print(f"Running command: {importer} ({str(i)}/{total})") + try: + call_command(importer) + except Exception as e: + print(f"Error raised: {e}") + print("Moving to next importer...") + failed_imports[importer] = e + print("\n") + i += 1 + + print("Failed importers:") + for importer, e in failed_imports.items(): + print(f" {importer}: {e}") + + def handle(self, generate=False, *args, **options): + scripts = self.get_scripts() + if generate: + self.run_generator_scripts(generators=scripts["generators"]) + self.run_importer_scripts(imports=scripts["importers"]) diff --git a/hub/tests/test_import_mps.py b/hub/tests/test_import_mps.py index e28730348..be855b196 100644 --- a/hub/tests/test_import_mps.py +++ b/hub/tests/test_import_mps.py @@ -157,18 +157,6 @@ def test_correct_mp_left(self, mock_call_command): Person.objects.get(name="Juliet Replacement"), ) - @patch("hub.management.commands.import_mps.call_command") - @patch("hub.management.commands.import_mps.MP_IMPORT_COMMANDS", ["A", "B", "C"]) - def test_data_reimported(self, mock_call_command): - # Run the duplicate MPs method - import_mps = ImportMpsCommand() - import_mps.check_for_duplicate_mps() - - # Check that the call_command was executed once - # for each import command, and then again for - # the initial data to determine the correct duplicate - self.assertEqual(mock_call_command.call_count, 4) - @patch("hub.management.commands.import_mps.call_command") def test_data_not_reimported_if_no_duplicates(self, mock_call_command): # Remove the duplicated MP