Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1025 - Add test_dispatch_to_batch #1042

Open
wants to merge 13 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 27 additions & 12 deletions api/scpca_portal/management/commands/dispatch_to_batch.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import logging
from collections import Counter

from django.conf import settings
from django.core.management.base import BaseCommand
from django.template.defaultfilters import pluralize

import boto3

Expand All @@ -20,11 +22,13 @@ class Command(BaseCommand):
# Command description (Django's BaseCommand.help attribute).
help = """
Submits all computed file combinations to the specified AWS Batch job queue
for projects for which computed files have yet to be generated.
If regenerate-all is passed, then the presence of existing computed files is ignored.
If a project-id is passed, then computed files are only submitted for that specific project
and all other projects are ignored.
"""

def add_arguments(self, parser):
    """
    Register the command-line options accepted by this command.

    --regenerate-all: boolean, defaults to False. May be passed bare
        (`--regenerate-all`) or with an explicit value (`--regenerate-all False`).
    --project-id: string, defaults to "" (meaning no project was specified).
    """

    def parse_bool(value):
        # argparse's `type=bool` calls bool() on the raw text, so every non-empty
        # string (including "False") would become True. Interpret the text explicitly.
        normalized = str(value).strip().lower()
        if normalized in ("1", "true", "t", "yes", "y", "on"):
            return True
        if normalized in ("", "0", "false", "f", "no", "n", "off"):
            return False
        # argparse converts a ValueError raised by a `type` callable into a usage error.
        raise ValueError(f"invalid boolean value: {value!r}")

    # nargs="?" + const=True lets the option also be used as a bare flag.
    parser.add_argument(
        "--regenerate-all",
        type=parse_bool,
        nargs="?",
        const=True,
        default=False,
    )
    # Registered exactly once: adding the same option string twice raises an argparse conflict.
    parser.add_argument("--project-id", type=str, default="")

def handle(self, *args, **kwargs):
    """Forward the received keyword options to dispatch_to_batch; positional args are unused."""
    options = kwargs
    self.dispatch_to_batch(**options)
Expand Down Expand Up @@ -62,29 +66,40 @@ def submit_job(

logger.info(f'{job_name} submitted to Batch with jobId {response["jobId"]}')

def dispatch_to_batch(self, project_id: str = "", regenerate_all: bool = False, **kwargs):
    """
    Iterate over all projects that fit the criteria of the passed flags
    and submit jobs to Batch accordingly.

    project_id: when non-empty, only the project with this scpca_id is considered.
    regenerate_all: when False, only projects without project computed files are considered;
        when True, existing computed files are ignored.
    kwargs: any additional command options; accepted and unused here.

    Both flags keep defaults so the method can still be called without arguments.
    """
    projects = Project.objects.all()

    if not regenerate_all:
        projects = projects.filter(project_computed_files__isnull=True)

    if project_id:
        projects = projects.filter(scpca_id=project_id)

    # Tally submissions per resource type for the summary log line below.
    job_counts = Counter()
    for project in projects:
        for download_config_name in project.valid_download_config_names:
            self.submit_job(
                project_id=project.scpca_id,
                download_config_name=download_config_name,
            )
            job_counts["project"] += 1

        for sample in project.samples_to_generate:
            for download_config_name in sample.valid_download_config_names:
                self.submit_job(
                    sample_id=sample.scpca_id,
                    download_config_name=download_config_name,
                )
                job_counts["sample"] += 1

    total_job_count = sum(job_counts.values())
    logger.info(
        "Job submission complete. "
        # "was,were" selects the verb matching the count ("1 job was", "2 jobs were").
        f"{total_job_count} job{pluralize(total_job_count)} "
        f"{pluralize(total_job_count, 'was,were')} submitted "
        f"({job_counts['project']} project job{pluralize(job_counts['project'])}, "
        f"{job_counts['sample']} sample job{pluralize(job_counts['sample'])})."
    )
22 changes: 19 additions & 3 deletions api/scpca_portal/test/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,22 @@ class Meta:


class ProjectFactory(LeafProjectFactory):
    # Related objects generated alongside each project. A declaration can be disabled
    # by passing None for it, e.g. ProjectFactory(computed_file=None).
    computed_file = factory.RelatedFactory(ProjectComputedFileFactory, "project")
    sample = factory.RelatedFactory(SampleFactory, "project")
    library = factory.RelatedFactory(LibraryFactory, "project")
    summary = factory.RelatedFactory(ProjectSummaryFactory, factory_related_name="project")

    @factory.post_generation
    def add_sample_library_relation(self, create, extracted, **kwargs):
        """
        In order for objects to be associated with each other via a ManyToMany relationship,
        both objects must first be created.
        This method makes the sample and library association after their creation above.
        """
        if not create:
            return

        sample = self.samples.first()
        library = self.libraries.first()

        # Either related object may be absent (e.g. sample=None or library=None was passed);
        # there is nothing to associate in that case.
        if sample is None or library is None:
            return

        sample.libraries.add(library)
149 changes: 149 additions & 0 deletions api/scpca_portal/test/management/commands/test_dispatch_to_batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
from functools import partial
from unittest.mock import patch

from django.core.management import call_command
from django.test import TestCase

from scpca_portal.test.factories import ProjectFactory


class TestDispatchToBatch(TestCase):
    """
    Tests for the dispatch_to_batch management command.
    Command.submit_job is patched in every test, and the kwargs of the recorded calls
    are inspected to determine which projects and samples had jobs submitted.
    """

    def setUp(self):
        self.dispatch_to_batch = partial(call_command, "dispatch_to_batch")

    @staticmethod
    def _submitted_ids(mock_submit_job, id_kwarg):
        """Return the set of non-None `id_kwarg` values passed across all submit_job calls."""
        return {
            call.kwargs[id_kwarg]
            for call in mock_submit_job.call_args_list
            if call.kwargs.get(id_kwarg) is not None
        }

    @patch("scpca_portal.management.commands.dispatch_to_batch.Command.submit_job")
    def test_generate_all_missing_files(self, mock_submit_job):
        projects_with_files = [ProjectFactory() for _ in range(3)]
        for project_with_files in projects_with_files:
            self.assertTrue(project_with_files.computed_files.exists())

        projects_no_files = [ProjectFactory(computed_file=None) for _ in range(2)]
        for project_no_files in projects_no_files:
            self.assertFalse(project_no_files.computed_files.exists())

        self.dispatch_to_batch()
        mock_submit_job.assert_called()

        # Only the projects lacking computed files should have been submitted.
        submitted_project_ids = self._submitted_ids(mock_submit_job, "project_id")
        self.assertEqual(len(projects_no_files), len(submitted_project_ids))
        for project_no_files in projects_no_files:
            self.assertIn(project_no_files.scpca_id, submitted_project_ids)
        for project_with_files in projects_with_files:
            self.assertNotIn(project_with_files.scpca_id, submitted_project_ids)

        sample_no_files = projects_no_files[0].samples.first()
        submitted_sample_ids = self._submitted_ids(mock_submit_job, "sample_id")
        self.assertIn(sample_no_files.scpca_id, submitted_sample_ids)

    @patch("scpca_portal.management.commands.dispatch_to_batch.Command.submit_job")
    def test_generate_missing_files_for_passed_project(self, mock_submit_job):
        project = ProjectFactory(computed_file=None)
        self.assertFalse(project.computed_files.exists())
        adtl_projects = [ProjectFactory(computed_file=None) for _ in range(3)]

        self.dispatch_to_batch(project_id=project.scpca_id)
        mock_submit_job.assert_called()

        # Only the explicitly requested project should have been submitted.
        submitted_project_ids = self._submitted_ids(mock_submit_job, "project_id")
        self.assertEqual(len(submitted_project_ids), 1)
        self.assertIn(project.scpca_id, submitted_project_ids)
        for adtl_project in adtl_projects:
            self.assertNotIn(adtl_project.scpca_id, submitted_project_ids)

        sample = project.samples.first()
        submitted_sample_ids = self._submitted_ids(mock_submit_job, "sample_id")
        self.assertEqual(len(submitted_sample_ids), 1)
        self.assertIn(sample.scpca_id, submitted_sample_ids)

    @patch("scpca_portal.management.commands.dispatch_to_batch.Command.submit_job")
    def test_regenerate_all_files(self, mock_submit_job):
        projects_with_files = [ProjectFactory() for _ in range(3)]
        for project_with_files in projects_with_files:
            self.assertTrue(project_with_files.computed_files.exists())

        projects_no_files = [ProjectFactory(computed_file=None) for _ in range(2)]
        for project_no_files in projects_no_files:
            self.assertFalse(project_no_files.computed_files.exists())

        projects = projects_with_files + projects_no_files

        self.dispatch_to_batch(regenerate_all=True)
        mock_submit_job.assert_called()

        # Every project should have been submitted, with or without existing files.
        submitted_project_ids = self._submitted_ids(mock_submit_job, "project_id")
        self.assertEqual(len(projects), len(submitted_project_ids))
        for project in projects:
            self.assertIn(project.scpca_id, submitted_project_ids)

        sample_no_files = projects[0].samples.first()
        submitted_sample_ids = self._submitted_ids(mock_submit_job, "sample_id")
        self.assertIn(sample_no_files.scpca_id, submitted_sample_ids)

    @patch("scpca_portal.management.commands.dispatch_to_batch.Command.submit_job")
    def test_regenerate_files_for_passed_project(self, mock_submit_job):
        project = ProjectFactory()
        self.assertTrue(project.computed_files.exists())
        adtl_projects = [ProjectFactory() for _ in range(3)]
        for adtl_project in adtl_projects:
            self.assertTrue(adtl_project.computed_files.exists())

        self.dispatch_to_batch(project_id=project.scpca_id, regenerate_all=True)
        mock_submit_job.assert_called()

        # Only the requested project should have been submitted, despite its existing files.
        submitted_project_ids = self._submitted_ids(mock_submit_job, "project_id")
        self.assertEqual(len(submitted_project_ids), 1)
        self.assertIn(project.scpca_id, submitted_project_ids)
        for adtl_project in adtl_projects:
            self.assertNotIn(adtl_project.scpca_id, submitted_project_ids)

        sample = project.samples.first()
        submitted_sample_ids = self._submitted_ids(mock_submit_job, "sample_id")
        self.assertEqual(len(submitted_sample_ids), 1)
        self.assertIn(sample.scpca_id, submitted_sample_ids)

    @patch("scpca_portal.management.commands.dispatch_to_batch.Command.submit_job")
    def test_no_missing_computed_files(self, mock_submit_job):
        project = ProjectFactory()
        self.assertTrue(project.computed_files.exists())

        self.dispatch_to_batch()
        mock_submit_job.assert_not_called()

    def test_project_missing_sample_computed_files(self):
        """
        We currently don't support generation of individual samples.
        We plan on removing sample file generation in favor of Dataset downloads.
        """
        pass
Loading