diff --git a/sdk/python/kfp/compiler/compiler_utils.py b/sdk/python/kfp/compiler/compiler_utils.py index 59567a5851e..a52361c97cf 100644 --- a/sdk/python/kfp/compiler/compiler_utils.py +++ b/sdk/python/kfp/compiler/compiler_utils.py @@ -19,12 +19,7 @@ from typing import DefaultDict, Dict, List, Mapping, Set, Tuple, Union from kfp import dsl -from kfp.dsl import constants -from kfp.dsl import for_loop -from kfp.dsl import pipeline_channel -from kfp.dsl import pipeline_context -from kfp.dsl import pipeline_task -from kfp.dsl import tasks_group +from kfp.dsl import constants, for_loop, pipeline_channel, pipeline_context, pipeline_task, tasks_group GroupOrTaskType = Union[tasks_group.TasksGroup, pipeline_task.PipelineTask] @@ -462,7 +457,7 @@ def get_outputs_for_all_groups( group_name_to_group = {group.name: group for group in all_groups} group_name_to_children = { group.name: [group.name for group in group.groups] + - [task.name for task in group.tasks] for group in all_groups + [task.name for task in group.tasks] for group in all_groups } outputs = collections.defaultdict(dict) @@ -762,7 +757,9 @@ def get_dependencies( # then make this validation dsl.Collected-aware elif isinstance(upstream_parent_group, tasks_group.ParallelFor): upstream_tasks_that_downstream_consumers_from = [ - channel.task.name for channel in task._channel_inputs if channel.task + channel.task.name + for channel in task._channel_inputs + if channel.task ] has_data_exchange = upstream_task.name in upstream_tasks_that_downstream_consumers_from # don't raise for .after