Skip to content

Commit

Permalink
Support with_overrides setting metadata for map_task subnode instead of parent node (#2982)
Browse files Browse the repository at this point in the history

* test

Signed-off-by: Paul Dittamo <[email protected]>

* add support for with_overrides for map tasks

Signed-off-by: Paul Dittamo <[email protected]>

* expand unit test

Signed-off-by: Paul Dittamo <[email protected]>

* cleanup

Signed-off-by: Paul Dittamo <[email protected]>

---------

Signed-off-by: Paul Dittamo <[email protected]>
  • Loading branch information
pvditt authored and eapolinario committed Jan 22, 2025
1 parent 6e29fe1 commit 41e77e6
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 46 deletions.
12 changes: 6 additions & 6 deletions flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ def __init__(
**kwargs,
)

self.sub_node_metadata: NodeMetadata = super().construct_node_metadata()
self.sub_node_metadata._name = self.name

@property
def name(self) -> str:
    """Return the name stored on this instance in ``self._name``."""
    return self._name
Expand All @@ -138,16 +141,13 @@ def python_interface(self):
return self._collection_interface

def construct_node_metadata(self) -> NodeMetadata:
    """
    Build the metadata for the parent ArrayNode, not for the sub-node getting mapped over.

    Only ``name`` is passed to ``NodeMetadata``; it is taken from ``self.name``.
    """
    # TODO: add support for other Flyte entities
    parent_node_metadata = NodeMetadata(name=self.name)
    return parent_node_metadata

def construct_sub_node_metadata(self) -> NodeMetadata:
    """
    Build the metadata for the sub-node.

    Starts from the metadata produced by the parent class's ``construct_node_metadata`` and
    replaces its ``_name`` with ``self.name``.
    """
    sub_node_metadata = super().construct_node_metadata()
    sub_node_metadata._name = self.name
    return sub_node_metadata

@property
def min_success_ratio(self) -> Optional[float]:
return self._min_success_ratio
Expand Down
84 changes: 52 additions & 32 deletions flytekit/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,57 @@ def run_entity(self) -> Any:
def metadata(self) -> _workflow_model.NodeMetadata:
return self._metadata

def _override_node_metadata(
    self,
    name: Optional[str],
    timeout: Optional[Union[int, datetime.timedelta]] = None,
    retries: Optional[int] = None,
    interruptible: typing.Optional[bool] = None,
    cache: typing.Optional[bool] = None,
    cache_version: typing.Optional[str] = None,
    cache_serialize: typing.Optional[bool] = None,
) -> None:
    """
    Apply node-metadata overrides in place.

    When ``self.flyte_entity`` is an ``ArrayNodeMapTask`` the overrides are written to that
    entity's ``sub_node_metadata``; otherwise they are written to this node's own ``_metadata``.

    Except for ``timeout``, an argument left as ``None`` leaves the corresponding field untouched.
    ``assert_not_promise`` is called on each non-``None`` value other than ``name`` and
    ``timeout`` before it is stored.

    :param name: new value for the metadata's ``_name``.
    :param timeout: an ``int`` is interpreted as seconds, a ``datetime.timedelta`` is stored
        as-is, and ``None`` stores a zero-length ``datetime.timedelta``.
    :param retries: wrapped in a ``RetryStrategy`` and stored as ``_retries``.
    :param interruptible: stored as ``_interruptible``.
    :param cache: stored as ``_cacheable``.
    :param cache_version: stored as ``_cache_version``.
    :param cache_serialize: stored as ``_cache_serializable``.
    :raises ValueError: if ``timeout`` is not ``None``, an ``int`` or a ``datetime.timedelta``.
    """
    # Function-scope import -- presumably to avoid a circular import between the node and
    # array-node modules; confirm before hoisting to module level.
    from flytekit.core.array_node_map_task import ArrayNodeMapTask

    if isinstance(self.flyte_entity, ArrayNodeMapTask):
        # override the sub-node's metadata
        node_metadata = self.flyte_entity.sub_node_metadata
    else:
        node_metadata = self._metadata

    # NOTE(review): unlike every other field, an omitted timeout is not a no-op -- it resets
    # the stored timeout to a zero-length timedelta. Kept as-is because callers may rely on it.
    if timeout is None:
        node_metadata._timeout = datetime.timedelta()
    elif isinstance(timeout, int):
        node_metadata._timeout = datetime.timedelta(seconds=timeout)
    elif isinstance(timeout, datetime.timedelta):
        node_metadata._timeout = timeout
    else:
        raise ValueError("timeout should be duration represented as either a datetime.timedelta or int seconds")

    if retries is not None:
        assert_not_promise(retries, "retries")
        # retries is known to be non-None here, so no fallback RetryStrategy(0) is needed.
        node_metadata._retries = _literal_models.RetryStrategy(retries)

    if interruptible is not None:
        assert_not_promise(interruptible, "interruptible")
        node_metadata._interruptible = interruptible

    if name is not None:
        node_metadata._name = name

    if cache is not None:
        assert_not_promise(cache, "cache")
        node_metadata._cacheable = cache

    if cache_version is not None:
        assert_not_promise(cache_version, "cache_version")
        node_metadata._cache_version = cache_version

    if cache_serialize is not None:
        assert_not_promise(cache_serialize, "cache_serialize")
        node_metadata._cache_serializable = cache_serialize

def with_overrides(
self,
node_name: Optional[str] = None,
Expand Down Expand Up @@ -174,27 +225,6 @@ def with_overrides(
assert_no_promises_in_resources(resources)
self._resources = resources

if timeout is None:
self._metadata._timeout = datetime.timedelta()
elif isinstance(timeout, int):
self._metadata._timeout = datetime.timedelta(seconds=timeout)
elif isinstance(timeout, datetime.timedelta):
self._metadata._timeout = timeout
else:
raise ValueError("timeout should be duration represented as either a datetime.timedelta or int seconds")
if retries is not None:
assert_not_promise(retries, "retries")
self._metadata._retries = (
_literal_models.RetryStrategy(0) if retries is None else _literal_models.RetryStrategy(retries)
)

if interruptible is not None:
assert_not_promise(interruptible, "interruptible")
self._metadata._interruptible = interruptible

if name is not None:
self._metadata._name = name

if task_config is not None:
logger.warning("This override is beta. We may want to revisit this in the future.")
if not isinstance(task_config, type(self.run_entity._task_config)):
Expand All @@ -209,17 +239,7 @@ def with_overrides(
assert_not_promise(accelerator, "accelerator")
self._extended_resources = tasks_pb2.ExtendedResources(gpu_accelerator=accelerator.to_flyte_idl())

if cache is not None:
assert_not_promise(cache, "cache")
self._metadata._cacheable = cache

if cache_version is not None:
assert_not_promise(cache_version, "cache_version")
self._metadata._cache_version = cache_version

if cache_serialize is not None:
assert_not_promise(cache_serialize, "cache_serialize")
self._metadata._cache_serializable = cache_serialize
self._override_node_metadata(name, timeout, retries, interruptible, cache, cache_version, cache_serialize)

return self

Expand Down
2 changes: 1 addition & 1 deletion flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ def get_serializable_array_node_map_task(
)
node = workflow_model.Node(
id=entity.name,
metadata=entity.construct_sub_node_metadata(),
metadata=entity.sub_node_metadata,
inputs=node.bindings,
upstream_node_ids=[],
output_aliases=[],
Expand Down
58 changes: 51 additions & 7 deletions tests/flytekit/unit/core/test_array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pytest
from flyteidl.core import workflow_pb2 as _core_workflow

from flytekit import dynamic, map_task, task, workflow, PythonFunctionTask
from flytekit import dynamic, map_task, task, workflow, PythonFunctionTask, Resources
from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings
from flytekit.core import context_manager
from flytekit.core.array_node_map_task import ArrayNodeMapTask, ArrayNodeMapTaskResolver
Expand All @@ -22,6 +22,7 @@
LiteralMap,
LiteralOffloadedMetadata,
)
from flytekit.models.task import Resources as _resources_models
from flytekit.tools.translator import get_serializable
from flytekit.types.directory import FlyteDirectory

Expand Down Expand Up @@ -350,16 +351,59 @@ def my_wf1() -> typing.List[typing.Optional[int]]:
assert my_wf1() == [1, None, 3, 4]


def test_map_task_override(serialization_settings):
@task
def my_mappable_task(a: int) -> typing.Optional[str]:
return str(a)
@task
def my_mappable_task(a: int) -> typing.Optional[str]:
return str(a)


# Mappable task declared with explicit task settings (image, timeout, interruptible, retries,
# cache, cache version, resource requests). test_map_task_override is parametrized over this
# task and calls with_overrides with different values for most of these settings --
# presumably so the test can tell an applied override apart from the task's own value.
@task(
    container_image="original-image",
    timeout=timedelta(seconds=10),
    interruptible=False,
    retries=10,
    cache=True,
    cache_version="original-version",
    requests=Resources(cpu=1)
)
def my_mappable_task_1(a: int) -> typing.Optional[str]:
    return str(a)


@pytest.mark.parametrize(
"task_func",
[my_mappable_task, my_mappable_task_1]
)
def test_map_task_override(serialization_settings, task_func):
array_node_map_task = map_task(task_func)

@workflow
def wf(x: typing.List[int]):
map_task(my_mappable_task)(a=x).with_overrides(container_image="random:image")
array_node_map_task(a=x).with_overrides(
container_image="new-image",
timeout=timedelta(seconds=20),
interruptible=True,
retries=5,
cache=True,
cache_version="new-version",
requests=Resources(cpu=2)
)

assert wf.nodes[0]._container_image == "new-image"

od = OrderedDict()
wf_spec = get_serializable(od, serialization_settings, wf)

assert wf.nodes[0]._container_image == "random:image"
array_node = wf_spec.template.nodes[0]
assert array_node.metadata.timeout == timedelta()
sub_node_spec = array_node.array_node.node
assert sub_node_spec.metadata.timeout == timedelta(seconds=20)
assert sub_node_spec.metadata.interruptible
assert sub_node_spec.metadata.retries.retries == 5
assert sub_node_spec.metadata.cacheable
assert sub_node_spec.metadata.cache_version == "new-version"
assert sub_node_spec.target.overrides.resources.requests == [
_resources_models.ResourceEntry(_resources_models.ResourceName.CPU, "2")
]


def test_serialization_metadata(serialization_settings):
Expand Down

0 comments on commit 41e77e6

Please sign in to comment.