diff --git a/app/grandchallenge/algorithms/forms.py b/app/grandchallenge/algorithms/forms.py index 4154a48c6..fd7a29675 100644 --- a/app/grandchallenge/algorithms/forms.py +++ b/app/grandchallenge/algorithms/forms.py @@ -56,6 +56,7 @@ AlgorithmModel, AlgorithmPermissionRequest, Job, + annotate_input_output_counts, ) from grandchallenge.algorithms.serializers import ( AlgorithmImageSerializer, @@ -462,37 +463,32 @@ def user_algorithms_for_phase(self): desired_model_subquery = AlgorithmModel.objects.filter( algorithm=OuterRef("pk"), is_desired_version=True ) - return ( - get_objects_for_user(self._user, "algorithms.change_algorithm") - .annotate( - total_input_count=Count("inputs", distinct=True), - total_output_count=Count("outputs", distinct=True), - relevant_input_count=Count( - "inputs", filter=Q(inputs__in=inputs), distinct=True - ), - relevant_output_count=Count( - "outputs", filter=Q(outputs__in=outputs), distinct=True - ), - has_active_image=Exists(desired_image_subquery), - active_image_pk=desired_image_subquery.values_list( - "pk", flat=True - ), - active_model_pk=desired_model_subquery.values_list( - "pk", flat=True - ), - active_image_comment=desired_image_subquery.values_list( - "comment", flat=True - ), - active_model_comment=desired_model_subquery.values_list( - "comment", flat=True - ), - ) - .filter( - total_input_count=len(inputs), - total_output_count=len(outputs), - relevant_input_count=len(inputs), - relevant_output_count=len(outputs), - ) + annotated_qs = annotate_input_output_counts( + queryset=get_objects_for_user( + self._user, "algorithms.change_algorithm" + ), + inputs=inputs, + outputs=outputs, + ) + return annotated_qs.annotate( + has_active_image=Exists(desired_image_subquery), + active_image_pk=desired_image_subquery.values_list( + "pk", flat=True + ), + active_model_pk=desired_model_subquery.values_list( + "pk", flat=True + ), + active_image_comment=desired_image_subquery.values_list( + "comment", flat=True + ), + active_model_comment=desired_model_subquery.values_list( + "comment", flat=True + ), + ).filter( + input_count=len(inputs), + output_count=len(outputs), + relevant_input_count=len(inputs), + relevant_output_count=len(outputs), ) @cached_property diff --git a/app/grandchallenge/algorithms/models.py b/app/grandchallenge/algorithms/models.py index 0392b369d..2c2958a93 100644 --- a/app/grandchallenge/algorithms/models.py +++ b/app/grandchallenge/algorithms/models.py @@ -69,6 +69,23 @@ JINJA_ENGINE = sandbox.ImmutableSandboxedEnvironment() +def annotate_input_output_counts(queryset, inputs=None, outputs=None): + return queryset.annotate( + input_count=Count("inputs", distinct=True), + output_count=Count("outputs", distinct=True), + relevant_input_count=Count( + "inputs", + filter=Q(inputs__in=inputs) if inputs is not None else Q(), + distinct=True, + ), + relevant_output_count=Count( + "outputs", + filter=Q(outputs__in=outputs) if outputs is not None else Q(), + distinct=True, + ), + ) + + class AlgorithmInterfaceManager(models.Manager): def create( @@ -96,22 +113,6 @@ def create( def delete(self): raise NotImplementedError("Bulk delete is not allowed.") - def with_input_output_counts(self, inputs=None, outputs=None): - return self.annotate( - input_count=Count("inputs", distinct=True), - output_count=Count("outputs", distinct=True), - relevant_input_count=Count( - "inputs", - filter=Q(inputs__in=inputs) if inputs is not None else Q(), - distinct=True, - ), - relevant_output_count=Count( - "outputs", - filter=Q(outputs__in=outputs) if outputs is not None else Q(), - distinct=True, - ), - ) - class AlgorithmInterface(UUIDModel): inputs = models.ManyToManyField( @@ -144,10 +145,11 @@ class AlgorithmInterfaceOutput(models.Model): def get_existing_interface_for_inputs_and_outputs( *, inputs, outputs, model=AlgorithmInterface ): + annotated_qs = annotate_input_output_counts( + model.objects.all(), inputs=inputs, outputs=outputs + ) try: - return model.objects.with_input_output_counts( - inputs=inputs, outputs=outputs - ).get( + return annotated_qs.get( relevant_input_count=len(inputs), relevant_output_count=len(outputs), input_count=len(inputs), @@ -998,18 +1000,12 @@ def get_jobs_with_same_inputs( # the existing civs and filter on both counts so as to not include jobs # with partially overlapping inputs # or jobs with more inputs than the existing civs - existing_jobs = ( - Job.objects.filter(**unique_kwargs) - .annotate( - input_count=Count("inputs", distinct=True), - input_match_count=Count( - "inputs", filter=Q(inputs__in=existing_civs), distinct=True - ), - ) - .filter( - input_count=input_interface_count, - input_match_count=input_interface_count, - ) + annotated_qs = annotate_input_output_counts( + queryset=Job.objects.filter(**unique_kwargs), inputs=existing_civs + ) + existing_jobs = annotated_qs.filter( + input_count=input_interface_count, + relevant_input_count=input_interface_count, ) return existing_jobs diff --git a/app/grandchallenge/algorithms/serializers.py b/app/grandchallenge/algorithms/serializers.py index af9857b42..ac930f13d 100644 --- a/app/grandchallenge/algorithms/serializers.py +++ b/app/grandchallenge/algorithms/serializers.py @@ -19,6 +19,7 @@ AlgorithmInterface, AlgorithmModel, Job, + annotate_input_output_counts, ) from grandchallenge.components.backends.exceptions import ( CIVNotEditableException, @@ -288,10 +289,11 @@ def validate_inputs_and_return_matching_interface(self, *, inputs): the algorithm and returns that AlgorithmInterface """ provided_inputs = {i["interface"] for i in inputs} + annotated_qs = annotate_input_output_counts( + self._algorithm.interfaces, inputs=provided_inputs + ) try: - interface = self._algorithm.interfaces.with_input_output_counts( - inputs=provided_inputs - ).get( + interface = annotated_qs.get( relevant_input_count=len(provided_inputs), input_count=len(provided_inputs), )