From be1dc74d7d4d816133e617b6a61bdc52e99fbc11 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 13 Dec 2024 04:28:56 +0000 Subject: [PATCH] feat: Introduce cache_key to sdk Signed-off-by: Ze Mao --- sdk/python/kfp/client/client.py | 24 ++++++++++++++++++- .../kfp/compiler/pipeline_spec_builder.py | 2 ++ sdk/python/kfp/dsl/pipeline_task.py | 8 +++++-- sdk/python/kfp/dsl/structures.py | 3 +++ 4 files changed, 34 insertions(+), 3 deletions(-) diff --git a/sdk/python/kfp/client/client.py b/sdk/python/kfp/client/client.py index f8897236343..926c3c6eb4a 100644 --- a/sdk/python/kfp/client/client.py +++ b/sdk/python/kfp/client/client.py @@ -686,6 +686,7 @@ def run_pipeline( version_id: Optional[str] = None, pipeline_root: Optional[str] = None, enable_caching: Optional[bool] = None, + cache_key: Optional[str] = "", service_account: Optional[str] = None, ) -> kfp_server_api.V2beta1Run: """Runs a specified pipeline. @@ -709,6 +710,8 @@ def run_pipeline( is ``True`` for all tasks by default. If set, the setting applies to all tasks in the pipeline (overrides the compile time settings). + cache_key (optional): Customized cache key for this task. + If set, the cache_key will be used as the key for the task's cache. service_account: Specifies which Kubernetes service account to use for this run. @@ -721,6 +724,7 @@ def run_pipeline( pipeline_id=pipeline_id, version_id=version_id, enable_caching=enable_caching, + cache_key=cache_key, pipeline_root=pipeline_root, ) @@ -806,6 +810,7 @@ def create_recurring_run( enabled: bool = True, pipeline_root: Optional[str] = None, enable_caching: Optional[bool] = None, + cache_key: Optional[str] = "", service_account: Optional[str] = None, ) -> kfp_server_api.V2beta1RecurringRun: """Creates a recurring run. @@ -850,6 +855,8 @@ def create_recurring_run( different caching options for individual tasks. If set, the setting applies to all tasks in the pipeline (overrides the compile time settings). + cache_key (optional): Customized cache key for this task. + If set, the cache_key will be used as the key for the task's cache. service_account: Specifies which Kubernetes service account this recurring run uses. Returns: @@ -862,6 +869,7 @@ def create_recurring_run( pipeline_id=pipeline_id, version_id=version_id, enable_caching=enable_caching, + cache_key=cache_key, pipeline_root=pipeline_root, ) @@ -908,6 +916,7 @@ def _create_job_config( pipeline_id: Optional[str], version_id: Optional[str], enable_caching: Optional[bool], + cache_key: Optional[str], pipeline_root: Optional[str], ) -> _JobConfig: """Creates a JobConfig with spec and resource_references. @@ -928,6 +937,8 @@ def _create_job_config( different caching options for individual tasks. If set, the setting applies to all tasks in the pipeline (overrides the compile time settings). + cache_key (optional): Customized cache key for this task. + If set, the cache_key will be used as the key for the task's cache. pipeline_root: Root path of the pipeline outputs. Returns: @@ -956,7 +967,8 @@ def _create_job_config( # settings. if enable_caching is not None: _override_caching_options(pipeline_doc.pipeline_spec, - enable_caching) + enable_caching, + cache_key) pipeline_spec = pipeline_doc.to_dict() pipeline_version_reference = None @@ -983,6 +995,7 @@ def create_run_from_pipeline_func( namespace: Optional[str] = None, pipeline_root: Optional[str] = None, enable_caching: Optional[bool] = None, + cache_key: Optional[str] = "", service_account: Optional[str] = None, experiment_id: Optional[str] = None, ) -> RunPipelineResult: @@ -1004,6 +1017,8 @@ def create_run_from_pipeline_func( different caching options for individual tasks. If set, the setting applies to all tasks in the pipeline (overrides the compile time settings). + cache_key (optional): Customized cache key for this task. + If set, the cache_key will be used as the key for the task's cache. service_account: Specifies which Kubernetes service account to use for this run. experiment_id: ID of the experiment to add the run to. You cannot specify both experiment_id and experiment_name. @@ -1032,6 +1047,7 @@ def create_run_from_pipeline_func( namespace=namespace, pipeline_root=pipeline_root, enable_caching=enable_caching, + cache_key=cache_key, service_account=service_account, ) @@ -1044,6 +1060,7 @@ def create_run_from_pipeline_package( namespace: Optional[str] = None, pipeline_root: Optional[str] = None, enable_caching: Optional[bool] = None, + cache_key: Optional[str] = "", service_account: Optional[str] = None, experiment_id: Optional[str] = None, ) -> RunPipelineResult: @@ -1065,6 +1082,8 @@ def create_run_from_pipeline_package( different caching options for individual tasks. If set, the setting applies to all tasks in the pipeline (overrides the compile time settings). + cache_key (optional): Customized cache key for this task. + If set, the cache_key will be used as the key for the task's cache. service_account: Specifies which Kubernetes service account to use for this run. experiment_id: ID of the experiment to add the run to. You cannot specify both experiment_id and experiment_name. @@ -1105,6 +1124,7 @@ def create_run_from_pipeline_package( params=arguments, pipeline_root=pipeline_root, enable_caching=enable_caching, + cache_key=cache_key, service_account=service_account, ) return RunPipelineResult(self, run_info) @@ -1681,6 +1701,7 @@ def _safe_load_yaml(stream: TextIO) -> _PipelineDoc: def _override_caching_options( pipeline_spec: pipeline_spec_pb2.PipelineSpec, enable_caching: bool, + cache_key: str="", ) -> None: """Overrides caching options. @@ -1690,3 +1711,4 @@ def _override_caching_options( """ for _, task_spec in pipeline_spec.root.dag.tasks.items(): task_spec.caching_options.enable_cache = enable_caching + task_spec.caching_options.cache_key = cache_key diff --git a/sdk/python/kfp/compiler/pipeline_spec_builder.py b/sdk/python/kfp/compiler/pipeline_spec_builder.py index 3061faab5e4..d95d3c52d95 100644 --- a/sdk/python/kfp/compiler/pipeline_spec_builder.py +++ b/sdk/python/kfp/compiler/pipeline_spec_builder.py @@ -124,6 +124,8 @@ def build_task_spec_for_task( utils.sanitize_component_name(task.name)) pipeline_task_spec.caching_options.enable_cache = ( task._task_spec.enable_caching) + pipeline_task_spec.caching_options.cache_key = ( + task._task_spec.cache_key) if task._task_spec.retry_policy is not None: pipeline_task_spec.retry_policy.CopyFrom( diff --git a/sdk/python/kfp/dsl/pipeline_task.py b/sdk/python/kfp/dsl/pipeline_task.py index b41a14ef82d..082b696a7b8 100644 --- a/sdk/python/kfp/dsl/pipeline_task.py +++ b/sdk/python/kfp/dsl/pipeline_task.py @@ -99,6 +99,7 @@ def __init__( args: Dict[str, Any], execute_locally: bool = False, execution_caching_default: bool = True, + execution_cache_key: str = "", ) -> None: """Initilizes a PipelineTask instance.""" # import within __init__ to avoid circular import @@ -131,7 +132,8 @@ def __init__( inputs=dict(args.items()), dependent_tasks=[], component_ref=component_spec.name, - enable_caching=execution_caching_default) + enable_caching=execution_caching_default, + cache_key=execution_cache_key) self._run_after: List[str] = [] self.importer_spec = None @@ -301,16 +303,18 @@ def _extract_container_spec_and_convert_placeholders( return container_spec @block_if_final() - def set_caching_options(self, enable_caching: bool) -> 'PipelineTask': + def set_caching_options(self, enable_caching: bool, cache_key: str = "") -> 'PipelineTask': """Sets caching options for the task. Args: enable_caching: Whether to enable caching. + cache_key: Customized cache key for this task. Returns: Self return to allow chained setting calls. """ self._task_spec.enable_caching = enable_caching + self._task_spec.cache_key = cache_key return self def _ensure_container_spec_exists(self) -> None: diff --git a/sdk/python/kfp/dsl/structures.py b/sdk/python/kfp/dsl/structures.py index 5a73d93b35c..fa161cffcbe 100644 --- a/sdk/python/kfp/dsl/structures.py +++ b/sdk/python/kfp/dsl/structures.py @@ -409,6 +409,8 @@ class TaskSpec: from the [items][] collection. enable_caching (optional): whether or not to enable caching for the task. Default is True. + cache_key (optional): Customized cache key for this task. + Default is empty string. display_name (optional): the display name of the task. If not specified, the task name will be used as the display name. """ @@ -421,6 +423,7 @@ class TaskSpec: iterator_items: Optional[Any] = None iterator_item_input: Optional[str] = None enable_caching: bool = True + cache_key: str = "" display_name: Optional[str] = None retry_policy: Optional[RetryPolicy] = None