From e0fd54b2ee0e5381dbc4243edf0b309ec613c794 Mon Sep 17 00:00:00 2001 From: David Farrington Date: Thu, 19 Dec 2024 18:40:00 +0000 Subject: [PATCH] Fix issue where tasks preceeding parallel for loops that recieve pipeline parameters, are wrongly expected to have task attributes Signed-off-by: David Farrington --- sdk/python/kfp/compiler/compiler_test.py | 47 +++++++++++++++-------- sdk/python/kfp/compiler/compiler_utils.py | 4 +- 2 files changed, 33 insertions(+), 18 deletions(-) diff --git a/sdk/python/kfp/compiler/compiler_test.py b/sdk/python/kfp/compiler/compiler_test.py index 898187e36c3..90778bb0cae 100644 --- a/sdk/python/kfp/compiler/compiler_test.py +++ b/sdk/python/kfp/compiler/compiler_test.py @@ -419,7 +419,7 @@ def test_set_description_through_pipeline_decorator(self): @dsl.pipeline(description='Prefer me.') def my_pipeline(): - """Don't prefer me""" + """Don't prefer me.""" VALID_PRODUCER_COMPONENT_SAMPLE(input_param='input') self.assertEqual(my_pipeline.pipeline_spec.pipeline_info.description, @@ -441,7 +441,8 @@ def test_set_description_through_pipeline_docstring_long(self): def my_pipeline(): """Docstring-specified description. - More information about this pipeline.""" + More information about this pipeline. + """ VALID_PRODUCER_COMPONENT_SAMPLE(input_param='input') self.assertEqual( @@ -2429,6 +2430,7 @@ def pipeline_with_multiline_definition( sample_input1: bool = True, sample_input2: str = 'string') -> str: """docstring short description. + docstring long description. docstring long description. """ op1 = my_comp(string=sample_input2, model=sample_input1) @@ -2455,10 +2457,9 @@ def pipeline_with_multiline_definition( def pipeline_with_multiline_definition( sample_input1: bool = True, sample_input2: str = 'string') -> str: - """ - docstring long description. - docstring long description. - docstring long description. + """docstring long description. + + docstring long description. docstring long description. """ op1 = my_comp(string=sample_input2, model=sample_input1) result = op1.output @@ -2487,8 +2488,8 @@ def test_idempotency_on_comment_with_multiline_docstring(self): def my_pipeline(sample_input1: bool = True, sample_input2: str = 'string') -> str: """docstring short description. - docstring long description. - docstring long description. + + docstring long description. docstring long description. """ op1 = my_comp(string=sample_input2, model=sample_input1) result = op1.output @@ -4144,7 +4145,7 @@ def my_pipeline( string: str, in_artifact: Input[Artifact], ) -> Outputs: - """Pipeline description. Returns + """Pipeline description. Returns. Args: string: Return Pipeline input string. Returns @@ -4607,7 +4608,9 @@ class TestDslOneOf(unittest.TestCase): # To help narrow the tests further (we already test lots of aspects in the following cases), we choose focus on the dsl.OneOf behavior, not the conditional logic if If/Elif/Else. This is more verbose, but more maintainable and the behavior under test is clearer. def test_if_else_returned(self): - """Uses If and Else branches, parameters passed to dsl.OneOf, dsl.OneOf returned from a pipeline, and different output keys on dsl.OneOf channels.""" + """Uses If and Else branches, parameters passed to dsl.OneOf, dsl.OneOf + returned from a pipeline, and different output keys on dsl.OneOf + channels.""" @dsl.pipeline def roll_die_pipeline() -> str: @@ -4668,7 +4671,9 @@ def roll_die_pipeline() -> str: ) def test_if_elif_else_returned(self): - """Uses If, Elif, and Else branches, parameters passed to dsl.OneOf, dsl.OneOf returned from a pipeline, and different output keys on dsl.OneOf channels.""" + """Uses If, Elif, and Else branches, parameters passed to dsl.OneOf, + dsl.OneOf returned from a pipeline, and different output keys on + dsl.OneOf channels.""" @dsl.pipeline def roll_die_pipeline() -> str: @@ -4743,7 +4748,9 @@ def roll_die_pipeline() -> str: ) def test_if_elif_else_consumed(self): - """Uses If, Elif, and Else branches, parameters passed to dsl.OneOf, dsl.OneOf passed to a consumer task, and different output keys on dsl.OneOf channels.""" + """Uses If, Elif, and Else branches, parameters passed to dsl.OneOf, + dsl.OneOf passed to a consumer task, and different output keys on + dsl.OneOf channels.""" @dsl.pipeline def roll_die_pipeline(): @@ -4820,7 +4827,9 @@ def roll_die_pipeline(): ) def test_if_else_consumed_and_returned(self): - """Uses If, Elif, and Else branches, parameters passed to dsl.OneOf, and dsl.OneOf passed to a consumer task and returned from the pipeline.""" + """Uses If, Elif, and Else branches, parameters passed to dsl.OneOf, + and dsl.OneOf passed to a consumer task and returned from the + pipeline.""" @dsl.pipeline def flip_coin_pipeline() -> str: @@ -4893,7 +4902,8 @@ def flip_coin_pipeline() -> str: ) def test_if_else_consumed_and_returned_artifacts(self): - """Uses If, Elif, and Else branches, artifacts passed to dsl.OneOf, and dsl.OneOf passed to a consumer task and returned from the pipeline.""" + """Uses If, Elif, and Else branches, artifacts passed to dsl.OneOf, and + dsl.OneOf passed to a consumer task and returned from the pipeline.""" @dsl.pipeline def flip_coin_pipeline() -> Artifact: @@ -5060,7 +5070,8 @@ def flip_coin_pipeline(execute_pipeline: bool): print_task_2.outputs['a']) def test_deeply_nested_consumed(self): - """Uses If, Elif, Else, and OneOf deeply nested within multiple dub-DAGs.""" + """Uses If, Elif, Else, and OneOf deeply nested within multiple dub- + DAGs.""" @dsl.pipeline def flip_coin_pipeline(execute_pipeline: bool): @@ -5159,7 +5170,8 @@ def flip_coin_pipeline(execute_pipeline: bool): print_task_2.outputs['a']) def test_oneof_in_condition(self): - """Tests that dsl.OneOf's channel can be consumed in a downstream group nested one level""" + """Tests that dsl.OneOf's channel can be consumed in a downstream group + nested one level.""" @dsl.pipeline def roll_die_pipeline(repeat_on: str = 'Got heads!'): @@ -5212,7 +5224,8 @@ def roll_die_pipeline(repeat_on: str = 'Got heads!'): ) def test_consumed_in_nested_groups(self): - """Tests that dsl.OneOf's channel can be consumed in a downstream group nested multiple levels""" + """Tests that dsl.OneOf's channel can be consumed in a downstream group + nested multiple levels.""" @dsl.pipeline def roll_die_pipeline( diff --git a/sdk/python/kfp/compiler/compiler_utils.py b/sdk/python/kfp/compiler/compiler_utils.py index dc10665944f..5ab4bb1f694 100644 --- a/sdk/python/kfp/compiler/compiler_utils.py +++ b/sdk/python/kfp/compiler/compiler_utils.py @@ -762,7 +762,9 @@ def get_dependencies( # then make this validation dsl.Collected-aware elif isinstance(upstream_parent_group, tasks_group.ParallelFor): upstream_tasks_that_downstream_consumers_from = [ - channel.task.name for channel in task._channel_inputs + channel.task.name + for channel in task._channel_inputs + if channel.task ] has_data_exchange = upstream_task.name in upstream_tasks_that_downstream_consumers_from # don't raise for .after