Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support with_overrides setting metadata for map_task subnode instead of parent node #2982

Merged
merged 5 commits into from
Jan 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ def __init__(
**kwargs,
)

self.sub_node_metadata: NodeMetadata = super().construct_node_metadata()
self.sub_node_metadata._name = self.name
Comment on lines +131 to +132
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider using constructor for name setting

Consider using the constructor parameters to set the name property when creating NodeMetadata instead of modifying the protected _name attribute directly. This would follow better encapsulation practices.

Code suggestion
Check the AI-generated fix before applying
Suggested change
self.sub_node_metadata: NodeMetadata = super().construct_node_metadata()
self.sub_node_metadata._name = self.name
self.sub_node_metadata: NodeMetadata = NodeMetadata(name=self.name, timeout=self.metadata.timeout, retries=self.metadata.retry_strategy, interruptible=self.metadata.interruptible)

Code Review Run #a0a338


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged


@property
def name(self) -> str:
return self._name
Expand All @@ -137,16 +140,13 @@ def python_interface(self):
return self._collection_interface

def construct_node_metadata(self) -> NodeMetadata:
# TODO: add support for other Flyte entities
"""
This returns metadata for the parent ArrayNode, not the sub-node getting mapped over
"""
return NodeMetadata(
name=self.name,
)

def construct_sub_node_metadata(self) -> NodeMetadata:
nm = super().construct_node_metadata()
nm._name = self.name
return nm

@property
def min_success_ratio(self) -> Optional[float]:
return self._min_success_ratio
Expand Down
84 changes: 52 additions & 32 deletions flytekit/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,57 @@ def run_entity(self) -> Any:
def metadata(self) -> _workflow_model.NodeMetadata:
return self._metadata

def _override_node_metadata(
self,
name,
timeout: Optional[Union[int, datetime.timedelta]] = None,
retries: Optional[int] = None,
interruptible: typing.Optional[bool] = None,
cache: typing.Optional[bool] = None,
cache_version: typing.Optional[str] = None,
cache_serialize: typing.Optional[bool] = None,
):
from flytekit.core.array_node_map_task import ArrayNodeMapTask

if isinstance(self.flyte_entity, ArrayNodeMapTask):
# override the sub-node's metadata
node_metadata = self.flyte_entity.sub_node_metadata
else:
node_metadata = self._metadata

if timeout is None:
node_metadata._timeout = datetime.timedelta()
elif isinstance(timeout, int):
node_metadata._timeout = datetime.timedelta(seconds=timeout)
elif isinstance(timeout, datetime.timedelta):
node_metadata._timeout = timeout
else:
raise ValueError("timeout should be duration represented as either a datetime.timedelta or int seconds")
if retries is not None:
assert_not_promise(retries, "retries")
node_metadata._retries = (
_literal_models.RetryStrategy(0) if retries is None else _literal_models.RetryStrategy(retries)
)

if interruptible is not None:
assert_not_promise(interruptible, "interruptible")
node_metadata._interruptible = interruptible

if name is not None:
node_metadata._name = name

if cache is not None:
assert_not_promise(cache, "cache")
node_metadata._cacheable = cache

if cache_version is not None:
assert_not_promise(cache_version, "cache_version")
node_metadata._cache_version = cache_version

if cache_serialize is not None:
assert_not_promise(cache_serialize, "cache_serialize")
node_metadata._cache_serializable = cache_serialize

def with_overrides(
self,
node_name: Optional[str] = None,
Expand Down Expand Up @@ -174,27 +225,6 @@ def with_overrides(
assert_no_promises_in_resources(resources)
self._resources = resources

if timeout is None:
self._metadata._timeout = datetime.timedelta()
elif isinstance(timeout, int):
self._metadata._timeout = datetime.timedelta(seconds=timeout)
elif isinstance(timeout, datetime.timedelta):
self._metadata._timeout = timeout
else:
raise ValueError("timeout should be duration represented as either a datetime.timedelta or int seconds")
if retries is not None:
assert_not_promise(retries, "retries")
self._metadata._retries = (
_literal_models.RetryStrategy(0) if retries is None else _literal_models.RetryStrategy(retries)
)

if interruptible is not None:
assert_not_promise(interruptible, "interruptible")
self._metadata._interruptible = interruptible

if name is not None:
self._metadata._name = name

if task_config is not None:
logger.warning("This override is beta. We may want to revisit this in the future.")
if not isinstance(task_config, type(self.run_entity._task_config)):
Expand All @@ -209,17 +239,7 @@ def with_overrides(
assert_not_promise(accelerator, "accelerator")
self._extended_resources = tasks_pb2.ExtendedResources(gpu_accelerator=accelerator.to_flyte_idl())

if cache is not None:
assert_not_promise(cache, "cache")
self._metadata._cacheable = cache

if cache_version is not None:
assert_not_promise(cache_version, "cache_version")
self._metadata._cache_version = cache_version

if cache_serialize is not None:
assert_not_promise(cache_serialize, "cache_serialize")
self._metadata._cache_serializable = cache_serialize
self._override_node_metadata(name, timeout, retries, interruptible, cache, cache_version, cache_serialize)

return self

Expand Down
2 changes: 1 addition & 1 deletion flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,7 @@ def get_serializable_array_node_map_task(
)
node = workflow_model.Node(
id=entity.name,
metadata=entity.construct_sub_node_metadata(),
metadata=entity.sub_node_metadata,
inputs=node.bindings,
upstream_node_ids=[],
output_aliases=[],
Expand Down
58 changes: 51 additions & 7 deletions tests/flytekit/unit/core/test_array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pytest
from flyteidl.core import workflow_pb2 as _core_workflow

from flytekit import dynamic, map_task, task, workflow, eager, PythonFunctionTask
from flytekit import dynamic, map_task, task, workflow, eager, PythonFunctionTask, Resources
from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings
from flytekit.core import context_manager
from flytekit.core.array_node_map_task import ArrayNodeMapTask, ArrayNodeMapTaskResolver
Expand All @@ -21,6 +21,7 @@
LiteralMap,
LiteralOffloadedMetadata,
)
from flytekit.models.task import Resources as _resources_models
from flytekit.tools.translator import get_serializable
from flytekit.types.directory import FlyteDirectory

Expand Down Expand Up @@ -349,16 +350,59 @@ def my_wf1() -> typing.List[typing.Optional[int]]:
assert my_wf1() == [1, None, 3, 4]


def test_map_task_override(serialization_settings):
@task
def my_mappable_task(a: int) -> typing.Optional[str]:
return str(a)
@task
def my_mappable_task(a: int) -> typing.Optional[str]:
return str(a)


@task(
container_image="original-image",
timeout=timedelta(seconds=10),
interruptible=False,
retries=10,
cache=True,
cache_version="original-version",
requests=Resources(cpu=1)
)
def my_mappable_task_1(a: int) -> typing.Optional[str]:
return str(a)


@pytest.mark.parametrize(
"task_func",
[my_mappable_task, my_mappable_task_1]
)
def test_map_task_override(serialization_settings, task_func):
array_node_map_task = map_task(task_func)

@workflow
def wf(x: typing.List[int]):
map_task(my_mappable_task)(a=x).with_overrides(container_image="random:image")
array_node_map_task(a=x).with_overrides(
container_image="new-image",
timeout=timedelta(seconds=20),
interruptible=True,
retries=5,
cache=True,
cache_version="new-version",
requests=Resources(cpu=2)
)

assert wf.nodes[0]._container_image == "new-image"

od = OrderedDict()
wf_spec = get_serializable(od, serialization_settings, wf)

assert wf.nodes[0]._container_image == "random:image"
array_node = wf_spec.template.nodes[0]
assert array_node.metadata.timeout == timedelta()
sub_node_spec = array_node.array_node.node
assert sub_node_spec.metadata.timeout == timedelta(seconds=20)
assert sub_node_spec.metadata.interruptible
assert sub_node_spec.metadata.retries.retries == 5
assert sub_node_spec.metadata.cacheable
assert sub_node_spec.metadata.cache_version == "new-version"
assert sub_node_spec.target.overrides.resources.requests == [
_resources_models.ResourceEntry(_resources_models.ResourceName.CPU, "2")
]


def test_serialization_metadata(serialization_settings):
Expand Down
Loading