Skip to content

Commit

Permalink
Backport changes up to 2025/01/21 (#3078)
Browse files Browse the repository at this point in the history
* Make FlyteUserRuntimeException to return error_code in Container Error (#3059)

* Make FlyteUserRuntimeException to return error_code in the ContainerError

Signed-off-by: Rafael Ribeiro Raposo <[email protected]>

* [Flytekit] Separate remote signal functions (#2933)

* feat: separate remote signal functions

Signed-off-by: mao3267 <[email protected]>

* refactor: make lint

Signed-off-by: mao3267 <[email protected]>

* test: add integration test for separated signal functions

Signed-off-by: mao3267 <[email protected]>

* fix: register workflow to admin

Signed-off-by: mao3267 <[email protected]>

* fix: integration test and approve function

Signed-off-by: mao3267 <[email protected]>

* fix: remove approve node output

Signed-off-by: mao3267 <[email protected]>

* fix: replace single sleep command to retry statement

Signed-off-by: mao3267 <[email protected]>

* fix: update comments

Signed-off-by: mao3267 <[email protected]>

* fix: simplify duplicate retry operations

Signed-off-by: mao3267 <[email protected]>

---------

Signed-off-by: mao3267 <[email protected]>

* Only copy over cat-certificates.crt if it does not exist in base image  (#3067)

* Do not copy over ca-certifcates.crt if the base image has one

Signed-off-by: Thomas J. Fan <[email protected]>

* Only copy over cat-certificates.crt if it does not exist in base image

Signed-off-by: Thomas J. Fan <[email protected]>

---------

Signed-off-by: Thomas J. Fan <[email protected]>

* Support with_overrides setting metadata for map_task subnode instead of parent node (#2982)

* test

Signed-off-by: Paul Dittamo <[email protected]>

* add support for with_overrides for map tasks

Signed-off-by: Paul Dittamo <[email protected]>

* expand unit test

Signed-off-by: Paul Dittamo <[email protected]>

* cleanup

Signed-off-by: Paul Dittamo <[email protected]>

---------

Signed-off-by: Paul Dittamo <[email protected]>

* fix: remove duplication log when execute (#3052)

Signed-off-by: Vincent <[email protected]>

* Fix: Always propagate pytorch task worker process exception timestamp to task exception (#3057)

* Fix: Always propagate pytorch task worker process exception timestamp to task exception

Signed-off-by: Fabio Grätz <[email protected]>

* Fix exist recoverable error test

Signed-off-by: Fabio Grätz <[email protected]>

---------

Signed-off-by: Fabio Grätz <[email protected]>
Co-authored-by: Fabio Grätz <[email protected]>

* Allow user-defined dataclass type transformer (again) (#3075)

* Allow for user-defined dataclass type tranformers

Signed-off-by: Eduardo Apolinario <[email protected]>

* Finish comment and remote user-defined dataclass transformer from registry

Signed-off-by: Eduardo Apolinario <[email protected]>

---------

Signed-off-by: Eduardo Apolinario <[email protected]>
Co-authored-by: Eduardo Apolinario <[email protected]>

---------

Signed-off-by: Rafael Ribeiro Raposo <[email protected]>
Signed-off-by: mao3267 <[email protected]>
Signed-off-by: Thomas J. Fan <[email protected]>
Signed-off-by: Paul Dittamo <[email protected]>
Signed-off-by: Vincent <[email protected]>
Signed-off-by: Fabio Grätz <[email protected]>
Signed-off-by: Eduardo Apolinario <[email protected]>
Co-authored-by: Rafael Raposo <[email protected]>
Co-authored-by: Vincent Chen <[email protected]>
Co-authored-by: Thomas J. Fan <[email protected]>
Co-authored-by: Paul Dittamo <[email protected]>
Co-authored-by: V <[email protected]>
Co-authored-by: Fabio M. Graetz, Ph.D. <[email protected]>
Co-authored-by: Fabio Grätz <[email protected]>
Co-authored-by: Eduardo Apolinario <[email protected]>
  • Loading branch information
9 people authored Jan 22, 2025
1 parent 17841c8 commit 0633001
Show file tree
Hide file tree
Showing 17 changed files with 422 additions and 58 deletions.
2 changes: 1 addition & 1 deletion flytekit/bin/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def _dispatch_execute(
exc_str = get_traceback_str(e)
output_file_dict[error_file_name] = _error_models.ErrorDocument(
_error_models.ContainerError(
code="USER",
code=e.error_code,
message=exc_str,
kind=kind,
origin=_execution_models.ExecutionError.ErrorKind.USER,
Expand Down
3 changes: 2 additions & 1 deletion flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,8 +537,9 @@ def run_remote(
if run_level_params.wait_execution:
msg += " Waiting to complete..."
p = Progress(TimeElapsedColumn(), TextColumn(msg), transient=True)
t = p.add_task("exec")
t = p.add_task("exec", visible=False)
with p:
p.update(t, visible=True)
p.start_task(t)
execution = remote.execute(
entity,
Expand Down
12 changes: 6 additions & 6 deletions flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ def __init__(
**kwargs,
)

self.sub_node_metadata: NodeMetadata = super().construct_node_metadata()
self.sub_node_metadata._name = self.name

@property
def name(self) -> str:
return self._name
Expand All @@ -138,16 +141,13 @@ def python_interface(self):
return self._collection_interface

def construct_node_metadata(self) -> NodeMetadata:
# TODO: add support for other Flyte entities
"""
This returns metadata for the parent ArrayNode, not the sub-node getting mapped over
"""
return NodeMetadata(
name=self.name,
)

def construct_sub_node_metadata(self) -> NodeMetadata:
nm = super().construct_node_metadata()
nm._name = self.name
return nm

@property
def min_success_ratio(self) -> Optional[float]:
return self._min_success_ratio
Expand Down
84 changes: 52 additions & 32 deletions flytekit/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,57 @@ def run_entity(self) -> Any:
def metadata(self) -> _workflow_model.NodeMetadata:
return self._metadata

def _override_node_metadata(
self,
name,
timeout: Optional[Union[int, datetime.timedelta]] = None,
retries: Optional[int] = None,
interruptible: typing.Optional[bool] = None,
cache: typing.Optional[bool] = None,
cache_version: typing.Optional[str] = None,
cache_serialize: typing.Optional[bool] = None,
):
from flytekit.core.array_node_map_task import ArrayNodeMapTask

if isinstance(self.flyte_entity, ArrayNodeMapTask):
# override the sub-node's metadata
node_metadata = self.flyte_entity.sub_node_metadata
else:
node_metadata = self._metadata

if timeout is None:
node_metadata._timeout = datetime.timedelta()
elif isinstance(timeout, int):
node_metadata._timeout = datetime.timedelta(seconds=timeout)
elif isinstance(timeout, datetime.timedelta):
node_metadata._timeout = timeout
else:
raise ValueError("timeout should be duration represented as either a datetime.timedelta or int seconds")
if retries is not None:
assert_not_promise(retries, "retries")
node_metadata._retries = (
_literal_models.RetryStrategy(0) if retries is None else _literal_models.RetryStrategy(retries)
)

if interruptible is not None:
assert_not_promise(interruptible, "interruptible")
node_metadata._interruptible = interruptible

if name is not None:
node_metadata._name = name

if cache is not None:
assert_not_promise(cache, "cache")
node_metadata._cacheable = cache

if cache_version is not None:
assert_not_promise(cache_version, "cache_version")
node_metadata._cache_version = cache_version

if cache_serialize is not None:
assert_not_promise(cache_serialize, "cache_serialize")
node_metadata._cache_serializable = cache_serialize

def with_overrides(
self,
node_name: Optional[str] = None,
Expand Down Expand Up @@ -174,27 +225,6 @@ def with_overrides(
assert_no_promises_in_resources(resources)
self._resources = resources

if timeout is None:
self._metadata._timeout = datetime.timedelta()
elif isinstance(timeout, int):
self._metadata._timeout = datetime.timedelta(seconds=timeout)
elif isinstance(timeout, datetime.timedelta):
self._metadata._timeout = timeout
else:
raise ValueError("timeout should be duration represented as either a datetime.timedelta or int seconds")
if retries is not None:
assert_not_promise(retries, "retries")
self._metadata._retries = (
_literal_models.RetryStrategy(0) if retries is None else _literal_models.RetryStrategy(retries)
)

if interruptible is not None:
assert_not_promise(interruptible, "interruptible")
self._metadata._interruptible = interruptible

if name is not None:
self._metadata._name = name

if task_config is not None:
logger.warning("This override is beta. We may want to revisit this in the future.")
if not isinstance(task_config, type(self.run_entity._task_config)):
Expand All @@ -209,17 +239,7 @@ def with_overrides(
assert_not_promise(accelerator, "accelerator")
self._extended_resources = tasks_pb2.ExtendedResources(gpu_accelerator=accelerator.to_flyte_idl())

if cache is not None:
assert_not_promise(cache, "cache")
self._metadata._cacheable = cache

if cache_version is not None:
assert_not_promise(cache_version, "cache_version")
self._metadata._cache_version = cache_version

if cache_serialize is not None:
assert_not_promise(cache_serialize, "cache_serialize")
self._metadata._cache_serializable = cache_serialize
self._override_node_metadata(name, timeout, retries, interruptible, cache, cache_version, cache_serialize)

return self

Expand Down
13 changes: 10 additions & 3 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1219,9 +1219,6 @@ def _get_transformer(cls, python_type: Type) -> Optional[TypeTransformer[T]]:
if python_type in cls._REGISTRY:
return cls._REGISTRY[python_type]

if dataclasses.is_dataclass(python_type):
return cls._DATACLASS_TRANSFORMER

return None

@classmethod
Expand All @@ -1240,6 +1237,16 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]:
if v is not None:
return v

# flytekit's dataclass type transformer is left for last to give users a chance to register a type transformer
# to handle dataclass-like objects as part of the mro evaluation.
#
# N.B.: keep in mind that there are no compatibility guarantees between these user-defined dataclass transformers
# and the flytekit one. This incompatibility is *not* a new behavior introduced by the recent type engine
# refactor (https://github.com/flyteorg/flytekit/pull/2815), but it is worth calling out explicitly as a known
# limitation nonetheless.
if dataclasses.is_dataclass(python_type):
return cls._DATACLASS_TRANSFORMER

display_pickle_warning(str(python_type))
from flytekit.types.pickle.pickle import FlytePickleTransformer

Expand Down
9 changes: 7 additions & 2 deletions flytekit/exceptions/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,24 @@ class FlyteUserException(_FlyteException):
class FlyteUserRuntimeException(_FlyteException):
_ERROR_CODE = "USER:RuntimeError"

def __init__(self, exc_value: Exception):
def __init__(self, exc_value: Exception, timestamp: typing.Optional[float] = None):
"""
FlyteUserRuntimeException is thrown when a user code raises an exception.
:param exc_value: The exception that was raised from user code.
:param timestamp: The timestamp as fractional seconds since epoch when the exception was raised.
"""
self._exc_value = exc_value
super().__init__(str(exc_value))
super().__init__(str(exc_value), timestamp=timestamp)

@property
def value(self):
return self._exc_value

@property
def error_code(self):
return self._ERROR_CODE


class FlyteTypeException(FlyteUserException, TypeError):
_ERROR_CODE = "USER:TypeError"
Expand Down
4 changes: 3 additions & 1 deletion flytekit/image_spec/default_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@
USER root
$APT_INSTALL_COMMAND
COPY --from=micromamba /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ca-certificates.crt
RUN --mount=from=micromamba,source=/etc/ssl/certs/ca-certificates.crt,target=/tmp/ca-certificates.crt \
[ -f /etc/ssl/certs/ca-certificates.crt ] || \
mkdir -p /etc/ssl/certs/ && cp /tmp/ca-certificates.crt /etc/ssl/certs/ca-certificates.crt
RUN id -u flytekit || useradd --create-home --shell /bin/bash flytekit
RUN chown -R flytekit /root && chown -R flytekit /home
Expand Down
87 changes: 87 additions & 0 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,93 @@ def list_signals(
s = resp.signals
return s

def approve(self, signal_id: str, execution_name: str, project: str = None, domain: str = None):
"""
:param signal_id: The name of the signal, this is the key used in the approve() or wait_for_input() call.
:param execution_name: The name of the execution. This is the tail-end of the URL when looking
at the workflow execution.
:param project: The execution project, will default to the Remote's default project.
:param domain: The execution domain, will default to the Remote's default domain.
"""

wf_exec_id = WorkflowExecutionIdentifier(
project=project or self.default_project, domain=domain or self.default_domain, name=execution_name
)

lt = TypeEngine.to_literal_type(bool)
true_literal = TypeEngine.to_literal(self.context, True, bool, lt)

req = SignalSetRequest(
id=SignalIdentifier(signal_id, wf_exec_id).to_flyte_idl(), value=true_literal.to_flyte_idl()
)

# Response is empty currently, nothing to give back to the user.
self.client.set_signal(req)

def reject(self, signal_id: str, execution_name: str, project: str = None, domain: str = None):
"""
:param signal_id: The name of the signal, this is the key used in the approve() or wait_for_input() call.
:param execution_name: The name of the execution. This is the tail-end of the URL when looking
at the workflow execution.
:param project: The execution project, will default to the Remote's default project.
:param domain: The execution domain, will default to the Remote's default domain.
"""

wf_exec_id = WorkflowExecutionIdentifier(
project=project or self.default_project, domain=domain or self.default_domain, name=execution_name
)

lt = TypeEngine.to_literal_type(bool)
false_literal = TypeEngine.to_literal(self.context, False, bool, lt)

req = SignalSetRequest(
id=SignalIdentifier(signal_id, wf_exec_id).to_flyte_idl(), value=false_literal.to_flyte_idl()
)

# Response is empty currently, nothing to give back to the user.
self.client.set_signal(req)

def set_input(
self,
signal_id: str,
execution_name: str,
value: typing.Union[literal_models.Literal, typing.Any],
project=None,
domain=None,
python_type=None,
literal_type=None,
):
"""
:param signal_id: The name of the signal, this is the key used in the approve() or wait_for_input() call.
:param execution_name: The name of the execution. This is the tail-end of the URL when looking
at the workflow execution.
:param value: This is either a Literal or a Python value which FlyteRemote will invoke the TypeEngine to
convert into a Literal. This argument is only value for wait_for_input type signals.
:param project: The execution project, will default to the Remote's default project.
:param domain: The execution domain, will default to the Remote's default domain.
:param python_type: Provide a python type to help with conversion if the value you provided is not a Literal.
:param literal_type: Provide a Flyte literal type to help with conversion if the value you provided
is not a Literal
"""

wf_exec_id = WorkflowExecutionIdentifier(
project=project or self.default_project, domain=domain or self.default_domain, name=execution_name
)
if isinstance(value, Literal):
logger.debug(f"Using provided {value} as existing Literal value")
lit = value
else:
lt = literal_type or (
TypeEngine.to_literal_type(python_type) if python_type else TypeEngine.to_literal_type(type(value))
)
lit = TypeEngine.to_literal(self.context, value, python_type or type(value), lt)
logger.debug(f"Converted {value} to literal {lit} using literal type {lt}")

req = SignalSetRequest(id=SignalIdentifier(signal_id, wf_exec_id).to_flyte_idl(), value=lit.to_flyte_idl())

# Response is empty currently, nothing to give back to the user.
self.client.set_signal(req)

def set_signal(
self,
signal_id: str,
Expand Down
2 changes: 1 addition & 1 deletion flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ def get_serializable_array_node_map_task(
)
node = workflow_model.Node(
id=entity.name,
metadata=entity.construct_sub_node_metadata(),
metadata=entity.sub_node_metadata,
inputs=node.bindings,
upstream_node_ids=[],
output_aliases=[],
Expand Down
4 changes: 2 additions & 2 deletions plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from flytekit.core.context_manager import FlyteContextManager, OutputMetadata
from flytekit.core.pod_template import PodTemplate
from flytekit.core.resources import convert_resources_to_resource_model
from flytekit.exceptions.user import FlyteRecoverableException
from flytekit.exceptions.user import FlyteRecoverableException, FlyteUserRuntimeException
from flytekit.extend import IgnoreOutputs, TaskPlugins
from flytekit.loggers import logger

Expand Down Expand Up @@ -475,7 +475,7 @@ def fn_partial():
# the automatically assigned timestamp based on exception creation time
raise FlyteRecoverableException(e.format_msg(), timestamp=first_failure.timestamp)
else:
raise RuntimeError(e.format_msg())
raise FlyteUserRuntimeException(e, timestamp=first_failure.timestamp)
except SignalException as e:
logger.exception(f"Elastic launch agent process terminating: {e}")
raise IgnoreOutputs()
Expand Down
38 changes: 36 additions & 2 deletions plugins/flytekit-kf-pytorch/tests/test_elastic_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from flytekit import task, workflow
from flytekit.core.context_manager import FlyteContext, FlyteContextManager, ExecutionState, ExecutionParameters, OutputMetadataTracker
from flytekit.configuration import SerializationSettings
from flytekit.exceptions.user import FlyteRecoverableException
from flytekit.exceptions.user import FlyteRecoverableException, FlyteUserRuntimeException

@pytest.fixture(autouse=True, scope="function")
def restore_env():
Expand Down Expand Up @@ -223,7 +223,7 @@ def wf(recoverable: bool):
with pytest.raises(FlyteRecoverableException):
wf(recoverable=recoverable)
else:
with pytest.raises(RuntimeError):
with pytest.raises(FlyteUserRuntimeException):
wf(recoverable=recoverable)


Expand Down Expand Up @@ -276,3 +276,37 @@ def test_task_omp_set():
assert os.environ["OMP_NUM_THREADS"] == "42"

test_task_omp_set()


def test_exception_timestamp() -> None:
"""Test that the timestamp of the worker process exception is propagated to the task exception."""
@task(
task_config=Elastic(
nnodes=1,
nproc_per_node=2,
)
)
def test_task():
raise Exception("Test exception")

with pytest.raises(Exception) as e:
test_task()

assert e.value.timestamp is not None


def test_recoverable_exception_timestamp() -> None:
"""Test that the timestamp of the worker process exception is propagated to the task exception."""
@task(
task_config=Elastic(
nnodes=1,
nproc_per_node=2,
)
)
def test_task():
raise FlyteRecoverableException("Recoverable test exception")

with pytest.raises(Exception) as e:
test_task()

assert e.value.timestamp is not None
Loading

0 comments on commit 0633001

Please sign in to comment.