From 7a54e965747111e9a1c03e3a5fb010371d2b659c Mon Sep 17 00:00:00 2001 From: Paul Dittamo Date: Thu, 5 Dec 2024 12:21:58 -0800 Subject: [PATCH 1/4] test Signed-off-by: Paul Dittamo --- .../unit/core/test_array_node_map_task.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/tests/flytekit/unit/core/test_array_node_map_task.py b/tests/flytekit/unit/core/test_array_node_map_task.py index d4281227db..05e7c8c213 100644 --- a/tests/flytekit/unit/core/test_array_node_map_task.py +++ b/tests/flytekit/unit/core/test_array_node_map_task.py @@ -351,16 +351,28 @@ def my_wf1() -> typing.List[typing.Optional[int]]: def test_map_task_override(serialization_settings): - @task + @task( + timeout=timedelta(seconds=10) + ) def my_mappable_task(a: int) -> typing.Optional[str]: return str(a) + arraynode_maptask = map_task(my_mappable_task) + @workflow def wf(x: typing.List[int]): - map_task(my_mappable_task)(a=x).with_overrides(container_image="random:image") + arraynode_maptask(a=x).with_overrides(container_image="random:image", timeout=timedelta(seconds=20)) assert wf.nodes[0]._container_image == "random:image" + od = OrderedDict() + wf_spec = get_serializable(od, serialization_settings, wf) + + array_node = wf_spec.template.nodes[0] + assert array_node.metadata.timeout == timedelta() + task_spec = od[arraynode_maptask] + assert task_spec.template.metadata.timeout == timedelta(seconds=20) + def test_serialization_metadata(serialization_settings): @task(interruptible=True) From ba2f578ddfa9b9569ec072f2bb24fb162d6ec824 Mon Sep 17 00:00:00 2001 From: Paul Dittamo Date: Wed, 8 Jan 2025 20:17:06 -0800 Subject: [PATCH 2/4] add support for with_overrides for map tasks Signed-off-by: Paul Dittamo --- flytekit/core/array_node_map_task.py | 13 +-- flytekit/core/node.py | 84 ++++++++++++------- flytekit/tools/translator.py | 2 +- .../unit/core/test_array_node_map_task.py | 17 +++- 4 files changed, 74 insertions(+), 42 deletions(-) diff --git a/flytekit/core/array_node_map_task.py b/flytekit/core/array_node_map_task.py index 87150c47c1..dee83d6911 100644 --- a/flytekit/core/array_node_map_task.py +++ b/flytekit/core/array_node_map_task.py @@ -128,6 +128,9 @@ def __init__( **kwargs, ) + self.sub_node_metadata: NodeMetadata = super().construct_node_metadata() + self.sub_node_metadata._name = self.name + @property def name(self) -> str: return self._name @@ -137,15 +140,15 @@ def python_interface(self): return self._collection_interface def construct_node_metadata(self) -> NodeMetadata: - # TODO: add support for other Flyte entities + """ + This returns metadata for the parent ArrayNode, not the sub-node getting mapped over + """ return NodeMetadata( name=self.name, ) - def construct_sub_node_metadata(self) -> NodeMetadata: - nm = super().construct_node_metadata() - nm._name = self.name - return nm + def get_sub_node_metadata(self) -> NodeMetadata: + return self.sub_node_metadata @property def min_success_ratio(self) -> Optional[float]: diff --git a/flytekit/core/node.py b/flytekit/core/node.py index ea089c6fd3..61ae41c060 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -124,6 +124,57 @@ def run_entity(self) -> Any: def metadata(self) -> _workflow_model.NodeMetadata: return self._metadata + def _override_node_metadata( + self, + name, + timeout: Optional[Union[int, datetime.timedelta]] = None, + retries: Optional[int] = None, + interruptible: typing.Optional[bool] = None, + cache: typing.Optional[bool] = None, + cache_version: typing.Optional[str] = None, + cache_serialize: typing.Optional[bool] = None, + ): + from flytekit.core.array_node_map_task import ArrayNodeMapTask + + if isinstance(self.flyte_entity, ArrayNodeMapTask): + # override the sub-node's metadata + node_metadata = self.flyte_entity.sub_node_metadata + else: + node_metadata = self._metadata + + if timeout is None: + node_metadata._timeout = datetime.timedelta() + elif isinstance(timeout, int): + node_metadata._timeout = datetime.timedelta(seconds=timeout) + elif isinstance(timeout, datetime.timedelta): + node_metadata._timeout = timeout + else: + raise ValueError("timeout should be duration represented as either a datetime.timedelta or int seconds") + if retries is not None: + assert_not_promise(retries, "retries") + node_metadata._retries = ( + _literal_models.RetryStrategy(0) if retries is None else _literal_models.RetryStrategy(retries) + ) + + if interruptible is not None: + assert_not_promise(interruptible, "interruptible") + node_metadata._interruptible = interruptible + + if name is not None: + node_metadata._name = name + + if cache is not None: + assert_not_promise(cache, "cache") + node_metadata._cacheable = cache + + if cache_version is not None: + assert_not_promise(cache_version, "cache_version") + node_metadata._cache_version = cache_version + + if cache_serialize is not None: + assert_not_promise(cache_serialize, "cache_serialize") + node_metadata._cache_serializable = cache_serialize + def with_overrides( self, node_name: Optional[str] = None, @@ -174,27 +225,6 @@ def with_overrides( assert_no_promises_in_resources(resources) self._resources = resources - if timeout is None: - self._metadata._timeout = datetime.timedelta() - elif isinstance(timeout, int): - self._metadata._timeout = datetime.timedelta(seconds=timeout) - elif isinstance(timeout, datetime.timedelta): - self._metadata._timeout = timeout - else: - raise ValueError("timeout should be duration represented as either a datetime.timedelta or int seconds") - if retries is not None: - assert_not_promise(retries, "retries") - self._metadata._retries = ( - _literal_models.RetryStrategy(0) if retries is None else _literal_models.RetryStrategy(retries) - ) - - if interruptible is not None: - assert_not_promise(interruptible, "interruptible") - self._metadata._interruptible = interruptible - - if name is not None: - self._metadata._name = name - if task_config is not None: logger.warning("This override is beta. We may want to revisit this in the future.") if not isinstance(task_config, type(self.run_entity._task_config)): @@ -209,17 +239,7 @@ def with_overrides( assert_not_promise(accelerator, "accelerator") self._extended_resources = tasks_pb2.ExtendedResources(gpu_accelerator=accelerator.to_flyte_idl()) - if cache is not None: - assert_not_promise(cache, "cache") - self._metadata._cacheable = cache - - if cache_version is not None: - assert_not_promise(cache_version, "cache_version") - self._metadata._cache_version = cache_version - - if cache_serialize is not None: - assert_not_promise(cache_serialize, "cache_serialize") - self._metadata._cache_serializable = cache_serialize + self._override_node_metadata(name, timeout, retries, interruptible, cache, cache_version, cache_serialize) return self diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index ee905a4218..a2f1e8c7a1 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -624,7 +624,7 @@ def get_serializable_array_node_map_task( ) node = workflow_model.Node( id=entity.name, - metadata=entity.construct_sub_node_metadata(), + metadata=entity.get_sub_node_metadata(), inputs=node.bindings, upstream_node_ids=[], output_aliases=[], diff --git a/tests/flytekit/unit/core/test_array_node_map_task.py b/tests/flytekit/unit/core/test_array_node_map_task.py index 808059d1ac..d7fb5998ec 100644 --- a/tests/flytekit/unit/core/test_array_node_map_task.py +++ b/tests/flytekit/unit/core/test_array_node_map_task.py @@ -351,7 +351,9 @@ def my_wf1() -> typing.List[typing.Optional[int]]: def test_map_task_override(serialization_settings): @task( - timeout=timedelta(seconds=10) + timeout=timedelta(seconds=10), + interruptible=True, + retries=10, ) def my_mappable_task(a: int) -> typing.Optional[str]: return str(a) @@ -360,7 +362,12 @@ def my_mappable_task(a: int) -> typing.Optional[str]: @workflow def wf(x: typing.List[int]): - arraynode_maptask(a=x).with_overrides(container_image="random:image", timeout=timedelta(seconds=20)) + arraynode_maptask(a=x).with_overrides( + container_image="random:image", + timeout=timedelta(seconds=20), + cache=True, + cache_version="v1", + ) assert wf.nodes[0]._container_image == "random:image" @@ -369,8 +376,10 @@ def wf(x: typing.List[int]): array_node = wf_spec.template.nodes[0] assert array_node.metadata.timeout == timedelta() - task_spec = od[arraynode_maptask] - assert task_spec.template.metadata.timeout == timedelta(seconds=20) + sub_node_spec = array_node.array_node.node + assert sub_node_spec.metadata.timeout == timedelta(seconds=20) + assert sub_node_spec.metadata.retries.retries == 10 + assert sub_node_spec.metadata.interruptible def test_serialization_metadata(serialization_settings): From fe58977abb5b51bc67e5c0b9e5aad1c69e9de041 Mon Sep 17 00:00:00 2001 From: Paul Dittamo Date: Wed, 8 Jan 2025 21:40:04 -0800 Subject: [PATCH 3/4] expand unit test Signed-off-by: Paul Dittamo --- .../unit/core/test_array_node_map_task.py | 53 +++++++++++++------ 1 file changed, 38 insertions(+), 15 deletions(-) diff --git a/tests/flytekit/unit/core/test_array_node_map_task.py b/tests/flytekit/unit/core/test_array_node_map_task.py index d7fb5998ec..97693940e0 100644 --- a/tests/flytekit/unit/core/test_array_node_map_task.py +++ b/tests/flytekit/unit/core/test_array_node_map_task.py @@ -9,7 +9,7 @@ import pytest from flyteidl.core import workflow_pb2 as _core_workflow -from flytekit import dynamic, map_task, task, workflow, eager, PythonFunctionTask +from flytekit import dynamic, map_task, task, workflow, eager, PythonFunctionTask, Resources from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings from flytekit.core import context_manager from flytekit.core.array_node_map_task import ArrayNodeMapTask, ArrayNodeMapTaskResolver @@ -21,6 +21,7 @@ LiteralMap, LiteralOffloadedMetadata, ) +from flytekit.models.task import Resources as _resources_models from flytekit.tools.translator import get_serializable from flytekit.types.directory import FlyteDirectory @@ -349,27 +350,44 @@ def my_wf1() -> typing.List[typing.Optional[int]]: assert my_wf1() == [1, None, 3, 4] -def test_map_task_override(serialization_settings): - @task( - timeout=timedelta(seconds=10), - interruptible=True, - retries=10, - ) - def my_mappable_task(a: int) -> typing.Optional[str]: - return str(a) +@task +def my_mappable_task(a: int) -> typing.Optional[str]: + return str(a) + + +@task( + container_image="original-image", + timeout=timedelta(seconds=10), + interruptible=False, + retries=10, + cache=True, + cache_version="original-version", + requests=Resources(cpu=1) +) +def my_mappable_task_1(a: int) -> typing.Optional[str]: + return str(a) - arraynode_maptask = map_task(my_mappable_task) + +@pytest.mark.parametrize( + "task_func", + [my_mappable_task, my_mappable_task_1] +) +def test_map_task_override(serialization_settings, task_func): + array_node_map_task = map_task(task_func) @workflow def wf(x: typing.List[int]): - arraynode_maptask(a=x).with_overrides( - container_image="random:image", + array_node_map_task(a=x).with_overrides( + container_image="new-image", timeout=timedelta(seconds=20), + interruptible=True, + retries=5, cache=True, - cache_version="v1", + cache_version="new-version", + requests=Resources(cpu=2) ) - assert wf.nodes[0]._container_image == "random:image" + assert wf.nodes[0]._container_image == "new-image" od = OrderedDict() wf_spec = get_serializable(od, serialization_settings, wf) @@ -378,8 +396,13 @@ def wf(x: typing.List[int]): assert array_node.metadata.timeout == timedelta() sub_node_spec = array_node.array_node.node assert sub_node_spec.metadata.timeout == timedelta(seconds=20) - assert sub_node_spec.metadata.retries.retries == 10 assert sub_node_spec.metadata.interruptible + assert sub_node_spec.metadata.retries.retries == 5 + assert sub_node_spec.metadata.cacheable + assert sub_node_spec.metadata.cache_version == "new-version" + assert sub_node_spec.target.overrides.resources.requests == [ + _resources_models.ResourceEntry(_resources_models.ResourceName.CPU, "2") + ] def test_serialization_metadata(serialization_settings): From 550ca24c3cbc3cbcd42b85bf84fae8f60c5688e4 Mon Sep 17 00:00:00 2001 From: Paul Dittamo Date: Sun, 12 Jan 2025 22:55:53 -0800 Subject: [PATCH 4/4] cleanup Signed-off-by: Paul Dittamo --- flytekit/core/array_node_map_task.py | 3 --- flytekit/tools/translator.py | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/flytekit/core/array_node_map_task.py b/flytekit/core/array_node_map_task.py index dee83d6911..05690e175b 100644 --- a/flytekit/core/array_node_map_task.py +++ b/flytekit/core/array_node_map_task.py @@ -147,9 +147,6 @@ def construct_node_metadata(self) -> NodeMetadata: name=self.name, ) - def get_sub_node_metadata(self) -> NodeMetadata: - return self.sub_node_metadata - @property def min_success_ratio(self) -> Optional[float]: return self._min_success_ratio diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index a2f1e8c7a1..e74f4c1c71 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -624,7 +624,7 @@ def get_serializable_array_node_map_task( ) node = workflow_model.Node( id=entity.name, - metadata=entity.get_sub_node_metadata(), + metadata=entity.sub_node_metadata, inputs=node.bindings, upstream_node_ids=[], output_aliases=[],