Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Introduce cache_key for cache key customization in SDK #11465

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion sdk/python/kfp/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,7 @@ def run_pipeline(
version_id: Optional[str] = None,
pipeline_root: Optional[str] = None,
enable_caching: Optional[bool] = None,
cache_key: Optional[str] = "",
service_account: Optional[str] = None,
) -> kfp_server_api.V2beta1Run:
"""Runs a specified pipeline.
Expand All @@ -709,6 +710,8 @@ def run_pipeline(
is ``True`` for all tasks by default. If set, the
setting applies to all tasks in the pipeline (overrides the
compile time settings).
cache_key (optional): Customized cache key for this task.
If set, the cache_key will be used as the key for the task's cache.
service_account: Specifies which Kubernetes service
account to use for this run.

Expand All @@ -721,6 +724,7 @@ def run_pipeline(
pipeline_id=pipeline_id,
version_id=version_id,
enable_caching=enable_caching,
cache_key=cache_key,
pipeline_root=pipeline_root,
)

Expand Down Expand Up @@ -806,6 +810,7 @@ def create_recurring_run(
enabled: bool = True,
pipeline_root: Optional[str] = None,
enable_caching: Optional[bool] = None,
cache_key: Optional[str] = "",
service_account: Optional[str] = None,
) -> kfp_server_api.V2beta1RecurringRun:
"""Creates a recurring run.
Expand Down Expand Up @@ -850,6 +855,8 @@ def create_recurring_run(
different caching options for individual tasks. If set, the
setting applies to all tasks in the pipeline (overrides the
compile time settings).
cache_key (optional): Customized cache key for this task.
If set, the cache_key will be used as the key for the task's cache.
service_account: Specifies which Kubernetes service
account this recurring run uses.
Returns:
Expand All @@ -862,6 +869,7 @@ def create_recurring_run(
pipeline_id=pipeline_id,
version_id=version_id,
enable_caching=enable_caching,
cache_key=cache_key,
pipeline_root=pipeline_root,
)

Expand Down Expand Up @@ -908,6 +916,7 @@ def _create_job_config(
pipeline_id: Optional[str],
version_id: Optional[str],
enable_caching: Optional[bool],
cache_key: Optional[str],
pipeline_root: Optional[str],
) -> _JobConfig:
"""Creates a JobConfig with spec and resource_references.
Expand All @@ -928,6 +937,8 @@ def _create_job_config(
different caching options for individual tasks. If set, the
setting applies to all tasks in the pipeline (overrides the
compile time settings).
cache_key (optional): Customized cache key for this task.
If set, the cache_key will be used as the key for the task's cache.
pipeline_root: Root path of the pipeline outputs.

Returns:
Expand Down Expand Up @@ -956,7 +967,8 @@ def _create_job_config(
# settings.
if enable_caching is not None:
_override_caching_options(pipeline_doc.pipeline_spec,
enable_caching)
enable_caching,
cache_key)
pipeline_spec = pipeline_doc.to_dict()

pipeline_version_reference = None
Expand All @@ -983,6 +995,7 @@ def create_run_from_pipeline_func(
namespace: Optional[str] = None,
pipeline_root: Optional[str] = None,
enable_caching: Optional[bool] = None,
cache_key: Optional[str] = "",
service_account: Optional[str] = None,
experiment_id: Optional[str] = None,
) -> RunPipelineResult:
Expand All @@ -1004,6 +1017,8 @@ def create_run_from_pipeline_func(
different caching options for individual tasks. If set, the
setting applies to all tasks in the pipeline (overrides the
compile time settings).
cache_key (optional): Customized cache key for this task.
If set, the cache_key will be used as the key for the task's cache.
service_account: Specifies which Kubernetes service
account to use for this run.
experiment_id: ID of the experiment to add the run to. You cannot specify both experiment_id and experiment_name.
Expand Down Expand Up @@ -1032,6 +1047,7 @@ def create_run_from_pipeline_func(
namespace=namespace,
pipeline_root=pipeline_root,
enable_caching=enable_caching,
cache_key=cache_key,
service_account=service_account,
)

Expand All @@ -1044,6 +1060,7 @@ def create_run_from_pipeline_package(
namespace: Optional[str] = None,
pipeline_root: Optional[str] = None,
enable_caching: Optional[bool] = None,
cache_key: Optional[str] = "",
service_account: Optional[str] = None,
experiment_id: Optional[str] = None,
) -> RunPipelineResult:
Expand All @@ -1065,6 +1082,8 @@ def create_run_from_pipeline_package(
different caching options for individual tasks. If set, the
setting applies to all tasks in the pipeline (overrides the
compile time settings).
cache_key (optional): Customized cache key for this task.
If set, the cache_key will be used as the key for the task's cache.
service_account: Specifies which Kubernetes service
account to use for this run.
experiment_id: ID of the experiment to add the run to. You cannot specify both experiment_id and experiment_name.
Expand Down Expand Up @@ -1105,6 +1124,7 @@ def create_run_from_pipeline_package(
params=arguments,
pipeline_root=pipeline_root,
enable_caching=enable_caching,
cache_key=cache_key,
service_account=service_account,
)
return RunPipelineResult(self, run_info)
Expand Down Expand Up @@ -1681,6 +1701,7 @@ def _safe_load_yaml(stream: TextIO) -> _PipelineDoc:
def _override_caching_options(
pipeline_spec: pipeline_spec_pb2.PipelineSpec,
enable_caching: bool,
cache_key: str="",
) -> None:
"""Overrides caching options.

Expand All @@ -1690,3 +1711,4 @@ def _override_caching_options(
"""
for _, task_spec in pipeline_spec.root.dag.tasks.items():
task_spec.caching_options.enable_cache = enable_caching
task_spec.caching_options.cache_key = cache_key
2 changes: 2 additions & 0 deletions sdk/python/kfp/compiler/pipeline_spec_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ def build_task_spec_for_task(
utils.sanitize_component_name(task.name))
pipeline_task_spec.caching_options.enable_cache = (
task._task_spec.enable_caching)
pipeline_task_spec.caching_options.cache_key = (
task._task_spec.cache_key)

if task._task_spec.retry_policy is not None:
pipeline_task_spec.retry_policy.CopyFrom(
Expand Down
8 changes: 6 additions & 2 deletions sdk/python/kfp/dsl/pipeline_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def __init__(
args: Dict[str, Any],
execute_locally: bool = False,
execution_caching_default: bool = True,
execution_cache_key: str = "",
) -> None:
"""Initilizes a PipelineTask instance."""
# import within __init__ to avoid circular import
Expand Down Expand Up @@ -131,7 +132,8 @@ def __init__(
inputs=dict(args.items()),
dependent_tasks=[],
component_ref=component_spec.name,
enable_caching=execution_caching_default)
enable_caching=execution_caching_default,
cache_key=execution_cache_key)
self._run_after: List[str] = []

self.importer_spec = None
Expand Down Expand Up @@ -301,16 +303,18 @@ def _extract_container_spec_and_convert_placeholders(
return container_spec

@block_if_final()
def set_caching_options(self, enable_caching: bool) -> 'PipelineTask':
def set_caching_options(self, enable_caching: bool, cache_key: str = "") -> 'PipelineTask':
"""Sets caching options for the task.

Args:
enable_caching: Whether to enable caching.
cache_key: Customized cache key for this task.

Returns:
Self return to allow chained setting calls.
"""
self._task_spec.enable_caching = enable_caching
self._task_spec.cache_key = cache_key
return self

def _ensure_container_spec_exists(self) -> None:
Expand Down
3 changes: 3 additions & 0 deletions sdk/python/kfp/dsl/structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,8 @@ class TaskSpec:
from the [items][] collection.
enable_caching (optional): whether or not to enable caching for the task.
Default is True.
cache_key (optional): Customized cache key for this task.
Default is empty string.
display_name (optional): the display name of the task. If not specified,
the task name will be used as the display name.
"""
Expand All @@ -421,6 +423,7 @@ class TaskSpec:
iterator_items: Optional[Any] = None
iterator_item_input: Optional[str] = None
enable_caching: bool = True
cache_key: str = ""
display_name: Optional[str] = None
retry_policy: Optional[RetryPolicy] = None

Expand Down
Loading