diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index 7dda4f5588..d0405aa0d8 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -162,44 +162,6 @@ jobs: fail_ci_if_error: false files: coverage.xml - test-hypothesis: - needs: - - detect-python-versions - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - os: [ubuntu-latest] - python-version: ${{fromJson(needs.detect-python-versions.outputs.python-versions)}} - steps: - - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - - name: Cache pip - uses: actions/cache@v3 - with: - # This path is specific to Ubuntu - path: ~/.cache/pip - # Look to see if there is a cache hit for the corresponding requirements files - key: ${{ format('{0}-pip-{1}', runner.os, hashFiles('dev-requirements.in', 'requirements.in')) }} - - name: Install dependencies - run: | - pip install uv - make setup-global-uv - uv pip freeze - - name: Test with coverage - env: - FLYTEKIT_HYPOTHESIS_PROFILE: ci - run: | - make unit_test_hypothesis - - name: Codecov - uses: codecov/codecov-action@v3.1.4 - with: - fail_ci_if_error: false - files: coverage.xml - test-serialization: needs: - detect-python-versions @@ -299,6 +261,9 @@ jobs: FLYTEKIT_IMAGE: localhost:30000/flytekit:dev FLYTEKIT_CI: 1 PYTEST_OPTS: -n2 + AWS_ENDPOINT_URL: 'http://localhost:30002' + AWS_ACCESS_KEY_ID: minio + AWS_SECRET_ACCESS_KEY: miniostorage run: | make ${{ matrix.makefile-cmd }} - name: Codecov diff --git a/Makefile b/Makefile index 0ff0246f72..08039e5ccc 100644 --- a/Makefile +++ b/Makefile @@ -72,10 +72,6 @@ unit_test: # Run serial tests without any parallelism $(PYTEST) -m "serial" tests/flytekit/unit/ --ignore=tests/flytekit/unit/extras/ --ignore=tests/flytekit/unit/models --ignore=tests/flytekit/unit/extend ${CODECOV_OPTS} -.PHONY: unit_test_hypothesis -unit_test_hypothesis: - $(PYTEST_AND_OPTS) -m "hypothesis" tests/flytekit/unit/experimental ${CODECOV_OPTS} - .PHONY: unit_test_extras unit_test_extras: PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python $(PYTEST_AND_OPTS) tests/flytekit/unit/extras tests/flytekit/unit/extend ${CODECOV_OPTS} diff --git a/dev-requirements.in b/dev-requirements.in index 20aba11e9d..5241f02605 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -1,5 +1,4 @@ -e file:. -flyteidl @ git+https://github.com/flyteorg/flyte.git@master#subdirectory=flyteidl coverage[toml] hypothesis diff --git a/docs/source/experimental.rst b/docs/source/experimental.rst index f1376be846..479490493e 100644 --- a/docs/source/experimental.rst +++ b/docs/source/experimental.rst @@ -10,6 +10,3 @@ Experimental Features .. 
autosummary:: :nosignatures: :toctree: generated/ - - ~experimental.eager - ~experimental.EagerException diff --git a/flytekit/__init__.py b/flytekit/__init__.py index a6eb70004b..6cd2b85564 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -239,7 +239,7 @@ from flytekit.core.reference_entity import LaunchPlanReference, TaskReference, WorkflowReference from flytekit.core.resources import Resources from flytekit.core.schedule import CronSchedule, FixedRate -from flytekit.core.task import Secret, reference_task, task +from flytekit.core.task import Secret, eager, reference_task, task from flytekit.core.type_engine import BatchSize from flytekit.core.workflow import ImperativeWorkflow as Workflow from flytekit.core.workflow import WorkflowFailurePolicy, reference_workflow, workflow diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 084e8f733b..49103319d0 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -1,7 +1,6 @@ import asyncio import contextlib import datetime -import inspect import os import pathlib import signal @@ -177,10 +176,6 @@ def _dispatch_execute( # Step2 # Invoke task - dispatch_execute outputs = task_def.dispatch_execute(ctx, idl_input_literals) - if inspect.iscoroutine(outputs): - # Handle eager-mode (async) tasks - logger.info("Output is a coroutine") - outputs = _get_working_loop().run_until_complete(outputs) # Step3a if isinstance(outputs, VoidPromise): diff --git a/flytekit/configuration/plugin.py b/flytekit/configuration/plugin.py index 3d43844d39..0c0e36c280 100644 --- a/flytekit/configuration/plugin.py +++ b/flytekit/configuration/plugin.py @@ -18,6 +18,7 @@ ``` """ +import os from typing import Optional, Protocol, runtime_checkable from click import Group @@ -59,10 +60,22 @@ def get_remote( config: Optional[str], project: str, domain: str, data_upload_location: Optional[str] = None ) -> FlyteRemote: """Get FlyteRemote object for CLI session.""" + cfg_file = get_config_file(config) + + # The assumption here (if there's no config file that means we want sandbox) is too broad. + # todo: can improve this in the future, rather than just checking one env var, auto() with + # nothing configured should probably not return sandbox but can consider if cfg_file is None: - cfg_obj = Config.for_sandbox() - logger.info("No config files found, creating remote with sandbox config") + # We really are just looking for endpoint, client_id, and client_secret. These correspond to the env vars + # FLYTE_PLATFORM_URL, FLYTE_CREDENTIALS_CLIENT_ID, FLYTE_CREDENTIALS_CLIENT_SECRET + # auto() should pick these up. 
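For context, a minimal sketch of the environment-variable path this change enables. The variable names come from the comment above; the endpoint and secret values are hypothetical:

```python
import os

from flytekit.configuration import Config

# Hypothetical values for illustration: with no config file on disk, these
# env vars let Config.auto(None) build a real platform config instead of
# falling back to the sandbox defaults.
os.environ["FLYTE_PLATFORM_URL"] = "dns:///flyte.example.com"
os.environ["FLYTE_CREDENTIALS_CLIENT_ID"] = "my-client-id"
os.environ["FLYTE_CREDENTIALS_CLIENT_SECRET"] = "my-client-secret"

cfg = Config.auto(None)  # picks up the env vars above
```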
+ if "FLYTE_PLATFORM_URL" in os.environ: + cfg_obj = Config.auto(None) + logger.warning(f"Auto-created config object to pick up env vars {cfg_obj}") + else: + cfg_obj = Config.for_sandbox() + logger.info("No config files found, creating remote with sandbox config") else: # pragma: no cover cfg_obj = Config.auto(config) logger.debug(f"Creating remote with config {cfg_obj}" + (f" with file {config}" if config else "")) diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 607518f0fb..6430aa9eac 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -21,7 +21,6 @@ import asyncio import collections import datetime -import inspect import warnings from abc import abstractmethod from base64 import b64encode @@ -142,6 +141,7 @@ class TaskMetadata(object): retries: int = 0 timeout: Optional[Union[datetime.timedelta, int]] = None pod_template_name: Optional[str] = None + is_eager: bool = False def __post_init__(self): if self.timeout: @@ -181,6 +181,7 @@ def to_taskmetadata_model(self) -> _task_model.TaskMetadata: cache_serializable=self.cache_serialize, pod_template_name=self.pod_template_name, cache_ignore_input_vars=self.cache_ignore_input_vars, + is_eager=self.is_eager, ) @@ -340,9 +341,6 @@ def local_execute( # if one is changed and not the other. outputs_literal_map = self.sandbox_execute(ctx, input_literal_map) - if inspect.iscoroutine(outputs_literal_map): - return outputs_literal_map - outputs_literals = outputs_literal_map.literals # TODO maybe this is the part that should be done for local execution, we pass the outputs to some special @@ -759,29 +757,6 @@ def dispatch_execute( raise raise FlyteUserRuntimeException(e) from e - if inspect.iscoroutine(native_outputs): - # If native outputs is a coroutine, then this is an eager workflow. - if exec_ctx.execution_state: - if exec_ctx.execution_state.mode == ExecutionState.Mode.LOCAL_TASK_EXECUTION: - # Just return task outputs as a coroutine if the eager workflow is being executed locally, - # outside of a workflow. This preserves the expectation that the eager workflow is an async - # function. - return native_outputs - elif exec_ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION: - # If executed inside of a workflow being executed locally, then run the coroutine to get the - # actual results. - return asyncio.run( - self._async_execute( - native_inputs, - native_outputs, - ctx, - exec_ctx, - new_user_params, - ) - ) - - return self._async_execute(native_inputs, native_outputs, ctx, exec_ctx, new_user_params) - # Lets run the post_execute method. This may result in a IgnoreOutputs Exception, which is # bubbled up to be handled at the callee layer. native_outputs = self.post_execute(new_user_params, native_outputs) diff --git a/flytekit/core/constants.py b/flytekit/core/constants.py index ffedfedfe5..903e5d5ced 100644 --- a/flytekit/core/constants.py +++ b/flytekit/core/constants.py @@ -22,3 +22,19 @@ # Set this environment variable to true to force the task to return non-zero exit code on failure. FLYTE_FAIL_ON_ERROR = "FLYTE_FAIL_ON_ERROR" + +# Executions launched by the current eager task will be tagged with this key:current_eager_exec_name +EAGER_TAG_KEY = "eager-exec" + +# Executions launched by the current eager task will be tagged with this key:root_eager_exec_name, only relevant +# for nested eager tasks. This is how you identify the root execution. +EAGER_TAG_ROOT_KEY = "eager-root-exec" + +# The environment variable that will be set to the root eager task execution name. 
This is how you pass down the +# root eager execution. +EAGER_ROOT_ENV_NAME = "_F_EE_ROOT" + +# This is a special key used to store metadata about the cache key in a literal type. +CACHE_KEY_METADATA = "cache-key-metadata" + +SERIALIZATION_FORMAT = "serialization-format" diff --git a/flytekit/core/context_manager.py b/flytekit/core/context_manager.py index 59dfc91a94..d804cbddc8 100644 --- a/flytekit/core/context_manager.py +++ b/flytekit/core/context_manager.py @@ -16,6 +16,7 @@ import logging as _logging import os import pathlib +import signal import tempfile import traceback import typing @@ -24,6 +25,7 @@ from dataclasses import dataclass, field from datetime import datetime, timezone from enum import Enum +from types import FrameType from typing import Generator, List, Optional, Union from flytekit.configuration import Config, SecretsConfig, SerializationSettings @@ -37,8 +39,10 @@ from flytekit.models.core import identifier as _identifier if typing.TYPE_CHECKING: - from flytekit import Deck from flytekit.clients import friendly as friendly_client # noqa + from flytekit.clients.friendly import SynchronousFlyteClient + from flytekit.core.worker_queue import Controller + from flytekit.deck.deck import Deck # TODO: resolve circular import from flytekit.core.python_auto_container import TaskResolverMixin @@ -526,6 +530,10 @@ class Mode(Enum): # This is the mode that is used to indicate a dynamic task DYNAMIC_TASK_EXECUTION = 4 + EAGER_EXECUTION = 5 + + EAGER_LOCAL_EXECUTION = 6 + mode: Optional[ExecutionState.Mode] working_dir: Union[os.PathLike, str] engine_dir: Optional[Union[os.PathLike, str]] @@ -586,6 +594,7 @@ def is_local_execution(self) -> bool: return ( self.mode == ExecutionState.Mode.LOCAL_TASK_EXECUTION or self.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION + or self.mode == ExecutionState.Mode.EAGER_LOCAL_EXECUTION ) @@ -663,6 +672,7 @@ class FlyteContext(object): in_a_condition: bool = False origin_stackframe: Optional[traceback.FrameSummary] = None output_metadata_tracker: Optional[OutputMetadataTracker] = None + worker_queue: Optional[Controller] = None @property def user_space_params(self) -> Optional[ExecutionParameters]: @@ -689,6 +699,7 @@ def new_builder(self) -> Builder: execution_state=self.execution_state, in_a_condition=self.in_a_condition, output_metadata_tracker=self.output_metadata_tracker, + worker_queue=self.worker_queue, ) def enter_conditional_section(self) -> Builder: @@ -713,6 +724,12 @@ def with_serialization_settings(self, ss: SerializationSettings) -> Builder: def with_output_metadata_tracker(self, t: OutputMetadataTracker) -> Builder: return self.new_builder().with_output_metadata_tracker(t) + def with_worker_queue(self, wq: Controller) -> Builder: + return self.new_builder().with_worker_queue(wq) + + def with_client(self, c: SynchronousFlyteClient) -> Builder: + return self.new_builder().with_client(c) + def new_compilation_state(self, prefix: str = "") -> CompilationState: """ Creates and returns a default compilation state. 
For most of the code this should be the entrypoint @@ -774,6 +791,7 @@ class Builder(object): serialization_settings: Optional[SerializationSettings] = None in_a_condition: bool = False output_metadata_tracker: Optional[OutputMetadataTracker] = None + worker_queue: Optional[Controller] = None def build(self) -> FlyteContext: return FlyteContext( @@ -785,6 +803,7 @@ def build(self) -> FlyteContext: serialization_settings=self.serialization_settings, in_a_condition=self.in_a_condition, output_metadata_tracker=self.output_metadata_tracker, + worker_queue=self.worker_queue, ) def enter_conditional_section(self) -> FlyteContext.Builder: @@ -833,6 +852,14 @@ def with_output_metadata_tracker(self, t: OutputMetadataTracker) -> FlyteContext self.output_metadata_tracker = t return self + def with_worker_queue(self, wq: Controller) -> FlyteContext.Builder: + self.worker_queue = wq + return self + + def with_client(self, c: SynchronousFlyteClient) -> FlyteContext.Builder: + self.flyte_client = c + return self + def new_compilation_state(self, prefix: str = "") -> CompilationState: """ Creates and returns a default compilation state. For most of the code this should be the entrypoint @@ -871,6 +898,12 @@ class FlyteContextManager(object): FlyteContextManager.pop_context() """ + signal_handlers: typing.List[typing.Callable[[int, FrameType], typing.Any]] = [] + + @staticmethod + def add_signal_handler(handler: typing.Callable[[int, FrameType], typing.Any]): + FlyteContextManager.signal_handlers.append(handler) + @staticmethod def get_origin_stackframe(limit=2) -> traceback.FrameSummary: ss = traceback.extract_stack(limit=limit + 1) @@ -954,6 +987,13 @@ def initialize(): user_space_path = os.path.join(cfg.local_sandbox_path, "user_space") pathlib.Path(user_space_path).mkdir(parents=True, exist_ok=True) + def main_signal_handler(signum: int, frame: FrameType): + for handler in FlyteContextManager.signal_handlers: + handler(signum, frame) + exit(1) + + signal.signal(signal.SIGINT, main_signal_handler) + # Note we use the SdkWorkflowExecution object purely for formatting into the ex:project:domain:name format users # are already acquainted with default_context = FlyteContext(file_access=default_local_file_access_provider) diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index 7035147016..0640bc2eb5 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -423,47 +423,34 @@ async def async_put_raw_data( r = await self._put(from_path, to_path, **kwargs) return r or to_path + # See https://github.com/fsspec/s3fs/issues/871 for more background and pending work on the fsspec side to + # support effectively async open(). For now these use-cases below will revert to sync calls. 
# raw bytes if isinstance(lpath, bytes): - fs = await self.get_async_filesystem_for_path(to_path) - if isinstance(fs, AsyncFileSystem): - async with fs.open_async(to_path, "wb", **kwargs) as s: - s.write(lpath) - else: - with fs.open(to_path, "wb", **kwargs) as s: - s.write(lpath) - + fs = self.get_filesystem_for_path(to_path) + with fs.open(to_path, "wb", **kwargs) as s: + s.write(lpath) return to_path # If lpath is a buffered reader of some kind if isinstance(lpath, io.BufferedReader) or isinstance(lpath, io.BytesIO): if not lpath.readable(): raise FlyteAssertion("Buffered reader must be readable") - fs = await self.get_async_filesystem_for_path(to_path) + fs = self.get_filesystem_for_path(to_path) lpath.seek(0) - if isinstance(fs, AsyncFileSystem): - async with fs.open_async(to_path, "wb", **kwargs) as s: - while data := lpath.read(read_chunk_size_bytes): - s.write(data) - else: - with fs.open(to_path, "wb", **kwargs) as s: - while data := lpath.read(read_chunk_size_bytes): - s.write(data) + with fs.open(to_path, "wb", **kwargs) as s: + while data := lpath.read(read_chunk_size_bytes): + s.write(data) return to_path if isinstance(lpath, io.StringIO): if not lpath.readable(): raise FlyteAssertion("Buffered reader must be readable") - fs = await self.get_async_filesystem_for_path(to_path) + fs = self.get_filesystem_for_path(to_path) lpath.seek(0) - if isinstance(fs, AsyncFileSystem): - async with fs.open_async(to_path, "wb", **kwargs) as s: - while data_str := lpath.read(read_chunk_size_bytes): - s.write(data_str.encode(encoding)) - else: - with fs.open(to_path, "wb", **kwargs) as s: - while data_str := lpath.read(read_chunk_size_bytes): - s.write(data_str.encode(encoding)) + with fs.open(to_path, "wb", **kwargs) as s: + while data_str := lpath.read(read_chunk_size_bytes): + s.write(data_str.encode(encoding)) return to_path raise FlyteAssertion(f"Unsupported lpath type {type(lpath)}") diff --git a/flytekit/core/options.py b/flytekit/core/options.py index ad35bc3ea1..79d46c2039 100644 --- a/flytekit/core/options.py +++ b/flytekit/core/options.py @@ -1,6 +1,5 @@ import typing from dataclasses import dataclass -from typing import Callable, Optional from flytekit.models import common as common_models from flytekit.models import security @@ -35,9 +34,6 @@ class Options(object): notifications: typing.Optional[typing.List[common_models.Notification]] = None disable_notifications: typing.Optional[bool] = None overwrite_cache: typing.Optional[bool] = None - file_uploader: Optional[Callable] = ( - None # This is used by the translator to upload task files, like pickled code etc - ) @classmethod def default_from( diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index f5eea7b161..3db02175b5 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -1,13 +1,13 @@ from __future__ import annotations +import asyncio import collections import datetime -import inspect import typing from collections.abc import Iterable from copy import deepcopy from enum import Enum -from typing import Any, Coroutine, Dict, List, Optional, Set, Tuple, Union, cast +from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast from google.protobuf import struct_pb2 as _struct from typing_extensions import Annotated, Protocol, get_args, get_origin @@ -109,6 +109,27 @@ def my_wf(in1: int, in2: int) -> int: translate_inputs_to_literals = loop_manager.synced(_translate_inputs_to_literals) +async def _translate_inputs_to_native( + ctx: FlyteContext, + incoming_values: Dict[str, Any], + 
flyte_interface_types: Dict[str, _interface_models.Variable], +) -> Dict[str, Any]: + if incoming_values is None: + raise AssertionError("Incoming values cannot be None, must be a dict") + + result = {} # So as to not overwrite the input_kwargs + for k, v in incoming_values.items(): + if k not in flyte_interface_types: + raise AssertionError(f"Received unexpected keyword argument {k}") + v = await resolve_attr_path_recursively(v) + result[k] = v + + return result + + +translate_inputs_to_native = loop_manager.synced(_translate_inputs_to_native) + + async def resolve_attr_path_recursively(v: Any) -> Any: """ This function resolves the attribute path in a nested structure recursively. @@ -1386,9 +1407,47 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr def local_execution_mode(self) -> ExecutionState.Mode: ... +async def async_flyte_entity_call_handler( + entity: SupportsNodeCreation, *args, **kwargs +) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, None]: + """ + This is a limited async version of the main call handler. + """ + # Make sure arguments are part of interface + for k, v in kwargs.items(): + if k not in entity.python_interface.inputs: + raise AssertionError(f"Received unexpected keyword argument '{k}' in function '{entity.name}'") + + # Check if we have more arguments than expected + if len(args) > len(entity.python_interface.inputs): + raise AssertionError( + f"Received more arguments than expected in function '{entity.name}'. Expected {len(entity.python_interface.inputs)} but got {len(args)}" + ) + + # Convert args to kwargs + for arg, input_name in zip(args, entity.python_interface.inputs.keys()): + if input_name in kwargs: + raise AssertionError(f"Got multiple values for argument '{input_name}' in function '{entity.name}'") + kwargs[input_name] = arg + + ctx = FlyteContextManager.current_context() + # This handles the case where we call other entities from within a running eager task. + if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.EAGER_EXECUTION: + # for nested eager, async, and sync tasks alike, submit to the informer. + if not ctx.worker_queue: + raise AssertionError("Worker queue missing, must be set when trying to execute tasks in an eager workflow") + loop = asyncio.get_running_loop() + fut = ctx.worker_queue.add(loop, entity, input_kwargs=kwargs) + result = await fut + return result + + # eager local execution, and all other call patterns are handled by the sync version + return flyte_entity_call_handler(entity, **kwargs) + + def flyte_entity_call_handler( entity: SupportsNodeCreation, *args, **kwargs -) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, Coroutine, None]: +) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, None]: """ This function is the call handler for tasks, workflows, and launch plans (which redirects to the underlying workflow). The logic is the same for all three, but we did not want to create base class, hence this separate @@ -1421,6 +1480,13 @@ def flyte_entity_call_handler( kwargs[input_name] = arg ctx = FlyteContextManager.current_context() + # This handles the case where we call other entities from within a running eager task. + if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.EAGER_EXECUTION: + # call the blocking version of the async call handler + # This is a recursive call, the async handler also calls this function, so this conditional must match + # the one in the async function perfectly, otherwise you'll get infinite recursion. + return loop_manager.run_sync(async_flyte_entity_call_handler, entity, **kwargs) + if ctx.execution_state and ( ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION or ctx.execution_state.mode == ExecutionState.Mode.LOCAL_TASK_EXECUTION ): raise ( ) if ctx.compilation_state is not None and ctx.compilation_state.mode == 1: return create_and_link_node(ctx, entity=entity, **kwargs) + + # This handles the case where we're already in a local execution state if ctx.execution_state and ctx.execution_state.is_local_execution(): + original_mode = ctx.execution_state.mode mode = cast(LocallyExecutable, entity).local_execution_mode() omt = OutputMetadataTracker() with FlyteContextManager.with_context( @@ -1448,7 +1517,15 @@ def flyte_entity_call_handler( return create_task_output(vals, entity.python_interface) else: return None - return cast(LocallyExecutable, entity).local_execute(ctx, **kwargs) + if original_mode == ExecutionState.Mode.EAGER_LOCAL_EXECUTION: + # When calling a local task/eager task/launch plan/subworkflow, we want the results to be Python + # native values, not wrapped in Promises. + local_execute_results = cast(LocallyExecutable, entity).local_execute(ctx, **kwargs) + return create_native_named_tuple(ctx, local_execute_results, entity.python_interface) + else: + return cast(LocallyExecutable, entity).local_execute(ctx, **kwargs) + + # This condition kicks off a new local execution. else: mode = cast(LocallyExecutable, entity).local_execution_mode() omt = OutputMetadataTracker() @@ -1465,9 +1542,6 @@ def flyte_entity_call_handler( else: raise ValueError(f"Received an output when workflow local execution expected None. 
Received: {result}") - if inspect.iscoroutine(result): - return result - if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.DYNAMIC_TASK_EXECUTION: return result diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py index c63199ad04..5c7ad290aa 100644 --- a/flytekit/core/python_function_task.py +++ b/flytekit/core/python_function_task.py @@ -18,20 +18,32 @@ from __future__ import annotations import inspect +import os +import signal from abc import ABC from collections import OrderedDict from contextlib import suppress from enum import Enum -from typing import Any, Callable, Iterable, List, Optional, TypeVar, Union, cast +from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union, cast +from flytekit.configuration import ImageConfig, SerializationSettings from flytekit.core import launch_plan as _annotated_launch_plan -from flytekit.core.base_task import Task, TaskResolverMixin +from flytekit.core.base_task import Task, TaskMetadata, TaskResolverMixin +from flytekit.core.constants import EAGER_ROOT_ENV_NAME from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager from flytekit.core.docstring import Docstring from flytekit.core.interface import transform_function_to_interface -from flytekit.core.promise import VoidPromise, translate_inputs_to_literals +from flytekit.core.promise import ( + Promise, + VoidPromise, + async_flyte_entity_call_handler, + translate_inputs_to_literals, +) from flytekit.core.python_auto_container import PythonAutoContainerTask, default_task_resolver +from flytekit.core.tracked_abc import FlyteTrackedABC from flytekit.core.tracker import extract_task_module, is_functools_wrapped_module_level, isnested, istestfunction +from flytekit.core.utils import _dnsify +from flytekit.core.worker_queue import Controller from flytekit.core.workflow import ( PythonFunctionWorkflow, WorkflowBase, @@ -39,12 +51,16 @@ WorkflowMetadata, WorkflowMetadataDefaults, ) +from flytekit.deck.deck import Deck +from flytekit.exceptions.eager import EagerException +from flytekit.exceptions.system import FlyteNonRecoverableSystemException from flytekit.exceptions.user import FlyteValueException from flytekit.loggers import logger from flytekit.models import dynamic_job as _dynamic_job from flytekit.models import literals as _literal_models from flytekit.models import task as task_models from flytekit.models.admin import workflow as admin_workflow_models +from flytekit.utils.asyn import loop_manager T = TypeVar("T") @@ -118,12 +134,14 @@ def __init__( """ :param T task_config: Configuration object for Task. Should be a unique type for that specific Task :param Callable task_function: Python function that has type annotations and works for the task - :param Optional[List[str]] ignore_input_vars: When supplied, these input variables will be removed from the interface. This + :param Optional[List[str]] ignore_input_vars: When supplied, these input variables will be removed from the + interface. This can be used to inject some client side variables only. Prefer using ExecutionParams :param Optional[ExecutionBehavior] execution_mode: Defines how the execution should behave, for example executing normally or specially handling a dynamic case. 
:param str task_type: String task type to be associated with this Task - :param Optional[Iterable[Union["PythonFunctionTask", "_annotated_launch_plan.LaunchPlan", WorkflowBase]]] node_dependency_hints: + :param Optional[Iterable[Union["PythonFunctionTask", "_annotated_launch_plan.LaunchPlan", WorkflowBase]]] + node_dependency_hints: A list of tasks, launchplans, or workflows that this task depends on. This is only for dynamic tasks/workflows, where flyte cannot automatically determine the dependencies prior to runtime. :param bool pickle_untyped: If set to True, the task will pickle untyped outputs. This is just a convenience @@ -202,11 +220,6 @@ def execute(self, **kwargs) -> Any: """ if self.execution_mode == self.ExecutionBehavior.DEFAULT: return self._task_function(**kwargs) - elif self.execution_mode == self.ExecutionBehavior.EAGER: - # if the task is a coroutine function, inject the context object so that the async_entity - # has access to the FlyteContext. - kwargs["async_ctx"] = FlyteContextManager.current_context() - return self._task_function(**kwargs) elif self.execution_mode == self.ExecutionBehavior.DYNAMIC: return self.dynamic_execute(self._task_function, **kwargs) @@ -286,7 +299,8 @@ def compile_into_workflow( if not isinstance(model, task_models.TaskSpec): raise TypeError( - f"Unexpected type for serialized form of task. Expected {task_models.TaskSpec}, but got {type(model)}" + f"Unexpected type for serialized form of task. Expected {task_models.TaskSpec}, " + f"but got {type(model)}" ) # Store the valid task template so that we can pass it to the @@ -389,3 +403,207 @@ def _write_decks(self, native_inputs, native_outputs_as_map, ctx, new_user_param python_dependencies_deck.append(renderer.to_html()) return super()._write_decks(native_inputs, native_outputs_as_map, ctx, new_user_params) + + +class AsyncPythonFunctionTask(PythonFunctionTask[T], metaclass=FlyteTrackedABC): + """ + This is the base task for eager tasks, as well as normal async tasks. + It really only needs to override the call function. + """ + + async def __call__( # type: ignore[override] + self, *args: object, **kwargs: object + ) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, None]: + return await async_flyte_entity_call_handler(self, *args, **kwargs) # type: ignore + + async def async_execute(self, *args, **kwargs) -> Any: + """ + Overrides the base execute function. This function does not handle dynamic at all. Eager and dynamic don't mix. + """ + # Args is present because the asyn helper function passes it, but everything should be in kwargs by this point + assert not args + if self.execution_mode == self.ExecutionBehavior.DEFAULT: + return await self._task_function(**kwargs) + elif self.execution_mode == self.ExecutionBehavior.DYNAMIC: + raise NotImplementedError + + execute = loop_manager.synced(async_execute) + + +class EagerAsyncPythonFunctionTask(AsyncPythonFunctionTask[T], metaclass=FlyteTrackedABC): + """ + This is the base eager task (aka eager workflow) type. It replaces the previously experimental eager task type circa + Q4 2024. Users unfamiliar with this concept should refer to the documentation for more information. + Essentially, Python becomes Propeller: every task invocation creates a stack frame on the Flyte cluster in + the form of an execution rather than on the actual memory stack. 
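For a sense of what this model buys the user, here is a minimal sketch, assuming the async-task support added in flytekit/core/task.py later in this diff; the task and eager workflow names are illustrative, not part of the patch:

```python
import asyncio

from flytekit import task, eager


@task
async def double(x: int) -> int:
    # An async def under @task becomes an AsyncPythonFunctionTask
    # (see the task.py changes later in this diff).
    return x * 2


@eager
async def fan_out(n: int) -> int:
    # On a backend, each call below becomes its own Flyte execution; awaiting
    # them together with asyncio.gather fans the calls out concurrently.
    results = await asyncio.gather(*[double(x=i) for i in range(n)])
    return sum(results)
```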
+ + """ + + def __init__( + self, + task_config: T, + task_function: Callable, + task_type="python-task", + ignore_input_vars: Optional[List[str]] = None, + task_resolver: Optional[TaskResolverMixin] = None, + node_dependency_hints: Optional[ + Iterable[Union["PythonFunctionTask", "_annotated_launch_plan.LaunchPlan", WorkflowBase]] + ] = None, + **kwargs, + ): + # delete execution mode from kwargs + if "execution_mode" in kwargs: + del kwargs["execution_mode"] + + if "metadata" in kwargs: + kwargs["metadata"].is_eager = True + else: + kwargs["metadata"] = TaskMetadata(is_eager=True) + + super().__init__( + task_config, + task_function, + task_type, + ignore_input_vars, + PythonFunctionTask.ExecutionBehavior.EAGER, + task_resolver, + node_dependency_hints, + **kwargs, + ) + + def local_execution_mode(self) -> ExecutionState.Mode: + return ExecutionState.Mode.EAGER_LOCAL_EXECUTION + + async def async_execute(self, *args, **kwargs) -> Any: + """ + Overrides the base execute function. This function does not handle dynamic at all. Eager and dynamic don't mix. + + Some notes on the different call scenarios, since eager tasks behave a little differently than other tasks. + a) starting local execution - eager_task() + -> last condition of call handler, + -> set execution mode and self.local_execute() + -> self.execute(native_vals) + -> 1) -> task function() or 2) -> self.run_with_backend() # fn name will be changed. + b) inside an eager task local execution - calling normal_task() + -> call handler detects in eager local execution (middle part of call handler) + -> call normal_task's local_execute() + c) inside an eager task local execution - calling async_normal_task() + -> produces a coro, which when awaited/run + -> call handler detects in eager local execution (middle part of call handler) + -> call async_normal_task's local_execute() + -> call AsyncPythonFunctionTask's async_execute(), which awaits the task function + d) inside an eager task local execution - calling another_eager_task() + -> produces a coro, which when awaited/run + -> call handler detects in eager local execution (middle part of call handler) + -> call another_eager_task's local_execute() + -> results are returned instead of being passed to create_native_named_tuple + e) eager_task, starting backend execution from entrypoint.py + -> eager_task.dispatch_execute(literals) + -> eager_task.execute(native values) + -> awaits eager_task.run_with_backend() # fn name will be changed + f) in an eager task during backend execution, calling any flyte_entity() + -> add the entity to the worker queue and await the result. + """ + # Args is present because the asyn helper function passes it, but everything should be in kwargs by this point + assert len(args) == 1 + ctx = FlyteContextManager.current_context() + is_local_execution = cast(ExecutionState, ctx.execution_state).is_local_execution() + if not is_local_execution: + # a real execution + return await self.run_with_backend(**kwargs) + else: + # set local mode and proceed with running the function. This makes nested entity calls run + # locally and return native Python values. + mode = self.local_execution_mode() + with FlyteContextManager.with_context( + ctx.with_execution_state(cast(ExecutionState, ctx.execution_state).with_params(mode=mode)) + ): + return await self._task_function(**kwargs) + + def execute(self, **kwargs) -> Any: + ctx = FlyteContextManager.current_context() + is_local_execution = cast(ExecutionState, ctx.execution_state).is_local_execution() + builder = ctx.new_builder() + if not is_local_execution: + # ensure that the worker queue is in context + if not ctx.worker_queue: + from flytekit.configuration.plugin import get_plugin + + # This should be read from transport at real runtime if available, but if not, we should either run + # remote in interactive mode, or let users configure the version to use. + ss = ctx.serialization_settings + if not ss: + ss = SerializationSettings( + image_config=ImageConfig.auto_default_image(), + ) + + # In order to build the controller, we really just need a remote. + project = ( + ctx.user_space_params.execution_id.project + if ctx.user_space_params and ctx.user_space_params.execution_id + else "flytesnacks" + ) + domain = ( + ctx.user_space_params.execution_id.domain + if ctx.user_space_params and ctx.user_space_params.execution_id + else "development" + ) + raw_output = ctx.user_space_params.raw_output_prefix if ctx.user_space_params else None + logger.info(f"Constructing default remote with no config and {project}, {domain}, {raw_output}") + remote = get_plugin().get_remote( config=None, project=project, domain=domain, data_upload_location=raw_output ) + + # tag is the current execution id + # root tag is read from the environment variable if it exists; if not, it's the current execution id + if not ctx.user_space_params or not ctx.user_space_params.execution_id: + raise AssertionError("User facing context and execution ID should be present when not running locally") + tag = ctx.user_space_params.execution_id.name + root_tag = os.environ.get(EAGER_ROOT_ENV_NAME, tag) + + # Prefix is a combination of the name of this eager workflow, and the current execution id. + prefix = self.name.split(".")[-1][:8] + prefix = f"e-{prefix}-{tag[:5]}" + prefix = _dnsify(prefix) + # Note: The construction of this object is in this function because this function should be on the + # main thread of pyflyte-execute. It needs to be on the main thread because signal handlers can only + # be installed on the main thread. + c = Controller(remote=remote, ss=ss, tag=tag, root_tag=root_tag, exec_prefix=prefix) + handler = c.get_signal_handler() + signal.signal(signal.SIGINT, handler) + signal.signal(signal.SIGTERM, handler) + builder = ctx.with_worker_queue(c) + else: + raise AssertionError("Worker queue should not already be present in the context for eager execution") + with FlyteContextManager.with_context(builder): + return loop_manager.run_sync(self.async_execute, self, **kwargs) + + # easier to use explicit input kwargs + async def run_with_backend(self, **kwargs): + """ + This is the main entry point to kick off a live run: running locally while using a + Flyte backend, or running for real on a Flyte backend. 
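One consequence of this failure handling is that user code can catch a failed sub-execution and recover instead of failing the whole eager run. This sketch assumes a failed sub-execution surfaces as ``EagerException``, which is what the informer in flytekit/core/worker_queue.py (later in this diff) sets; the task names are illustrative:

```python
from flytekit import task, eager
from flytekit.exceptions.eager import EagerException


@task
def may_fail(x: int) -> int:
    if x < 0:
        raise ValueError("negative input")
    return x


@eager
async def resilient(x: int) -> int:
    try:
        # A plain (sync) task called from eager mode blocks on its sub-execution.
        return may_fail(x=x)
    except EagerException:
        # The sub-execution failed; recover with a fallback value instead of
        # failing the whole eager run.
        return 0
```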
+ """ + + # set up context + ctx = FlyteContextManager.current_context() + mode = ExecutionState.Mode.EAGER_EXECUTION + builder = ctx.with_execution_state(cast(ExecutionState, ctx.execution_state).with_params(mode=mode)) + + with FlyteContextManager.with_context(builder) as ctx: + base_error = None + try: + result = await self._task_function(**kwargs) + except EagerException as ee: + # Catch and re-raise a different exception to render Deck even in case of failure. + logger.error(f"Leaving eager execution because of {ee}") + base_error = ee + + html = cast(Controller, ctx.worker_queue).render_html() + Deck("eager workflow", html) + + if base_error: + # now have to fail this eager task, because we don't want it to show up as succeeded. + raise FlyteNonRecoverableSystemException(base_error) + return result diff --git a/flytekit/core/resources.py b/flytekit/core/resources.py index 8a99dbf2ea..9a334f98f6 100644 --- a/flytekit/core/resources.py +++ b/flytekit/core/resources.py @@ -1,6 +1,7 @@ -from dataclasses import dataclass -from typing import List, Optional, Union +from dataclasses import dataclass, fields +from typing import Any, List, Optional, Union +from kubernetes.client import V1Container, V1PodSpec, V1ResourceRequirements from mashumaro.mixins.json import DataClassJSONMixin from flytekit.models import task as task_models @@ -73,7 +74,10 @@ def _convert_resources_to_resource_entries(resources: Resources) -> List[_Resour resource_entries.append(_ResourceEntry(name=_ResourceName.GPU, value=str(resources.gpu))) if resources.ephemeral_storage is not None: resource_entries.append( - _ResourceEntry(name=_ResourceName.EPHEMERAL_STORAGE, value=str(resources.ephemeral_storage)) + _ResourceEntry( + name=_ResourceName.EPHEMERAL_STORAGE, + value=str(resources.ephemeral_storage), + ) ) return resource_entries @@ -96,3 +100,49 @@ def convert_resources_to_resource_model( if limits is not None: limit_entries = _convert_resources_to_resource_entries(limits) return task_models.Resources(requests=request_entries, limits=limit_entries) + + +def pod_spec_from_resources( + k8s_pod_name: str, + requests: Optional[Resources] = None, + limits: Optional[Resources] = None, + k8s_gpu_resource_key: str = "nvidia.com/gpu", +) -> dict[str, Any]: + def _construct_k8s_pods_resources(resources: Optional[Resources], k8s_gpu_resource_key: str): + if resources is None: + return None + + resources_map = { + "cpu": "cpu", + "mem": "memory", + "gpu": k8s_gpu_resource_key, + "ephemeral_storage": "ephemeral-storage", + } + + k8s_pod_resources = {} + + for resource in fields(resources): + resource_value = getattr(resources, resource.name) + if resource_value is not None: + k8s_pod_resources[resources_map[resource.name]] = resource_value + + return k8s_pod_resources + + requests = _construct_k8s_pods_resources(resources=requests, k8s_gpu_resource_key=k8s_gpu_resource_key) + limits = _construct_k8s_pods_resources(resources=limits, k8s_gpu_resource_key=k8s_gpu_resource_key) + requests = requests or limits + limits = limits or requests + + k8s_pod = V1PodSpec( + containers=[ + V1Container( + name=k8s_pod_name, + resources=V1ResourceRequirements( + requests=requests, + limits=limits, + ), + ) + ] + ) + + return k8s_pod.to_dict() diff --git a/flytekit/core/task.py b/flytekit/core/task.py index 1196fd95c7..6451e742c5 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -1,8 +1,9 @@ from __future__ import annotations import datetime +import inspect import os -from functools import update_wrapper +from functools import 
partial, update_wrapper from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union, overload from typing_extensions import ParamSpec # type: ignore @@ -12,7 +13,7 @@ from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin from flytekit.core.interface import Interface, output_name_generator, transform_function_to_interface from flytekit.core.pod_template import PodTemplate -from flytekit.core.python_function_task import PythonFunctionTask +from flytekit.core.python_function_task import AsyncPythonFunctionTask, EagerAsyncPythonFunctionTask, PythonFunctionTask from flytekit.core.reference_entity import ReferenceEntity, TaskReference from flytekit.core.resources import Resources from flytekit.core.utils import str2bool @@ -354,9 +355,22 @@ def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]: timeout=timeout, ) - decorated_fn = decorate_function(fn) - - task_instance = TaskPlugins.find_pythontask_plugin(type(task_config))( + if inspect.iscoroutinefunction(fn): + # TODO: figure out vscode decoration for async tasks, wait to do this until this vscode pattern has + # stabilized. + # https://github.com/flyteorg/flyte/issues/6071 + decorated_fn = fn + else: + decorated_fn = decorate_function(fn) + task_plugin = TaskPlugins.find_pythontask_plugin(type(task_config)) + if inspect.iscoroutinefunction(fn): + if task_plugin is PythonFunctionTask: + task_plugin = AsyncPythonFunctionTask + else: + if not issubclass(task_plugin, AsyncPythonFunctionTask): + raise AssertionError(f"Task plugin {task_plugin} is not compatible with async functions") + + task_instance = task_plugin( task_config, decorated_fn, metadata=_metadata, @@ -489,3 +503,96 @@ def execute(self, **kwargs) -> Any: return values[0] else: return tuple(values) + + +def eager( + _fn=None, + *args, + **kwargs, +) -> Union[EagerAsyncPythonFunctionTask, partial]: + """Eager workflow decorator. + + This type of task will execute all Flyte entities within it eagerly, meaning that all python constructs can be + used inside of an ``@eager``-decorated function. This is because eager workflows use a + :py:class:`~flytekit.remote.remote.FlyteRemote` object to kick off executions when a flyte entity needs to produce a + value. Think of it this way: every Flyte entity that is called becomes an execution with its + own Flyte URL. Results (or the error) are fetched when the execution is finished. + + For example: + + .. code-block:: python + + from flytekit import task, eager + + @task + def add_one(x: int) -> int: + return x + 1 + + @task + def double(x: int) -> int: + return x * 2 + + @eager + async def eager_workflow(x: int) -> int: + out = add_one(x=x) + return double(x=out) + + # run locally with asyncio + if __name__ == "__main__": + import asyncio + + result = asyncio.run(eager_workflow(x=1)) + print(f"Result: {result}") # "Result: 4" + + Unlike :py:func:`dynamic workflows `, eager workflows are not compiled into a workflow spec, but + use python's `async `__ capabilities to execute flyte entities. + + .. note:: + + Eager workflows only support `@task`, `@workflow`, and `@eager` entities. Conditionals are not supported, use a + plain Python if statement instead. + + .. important:: + + A ``client_secret_group`` and ``client_secret_key`` are needed for authenticating via + :py:class:`~flytekit.remote.remote.FlyteRemote` using the ``client_credentials`` authentication, which is + configured via :py:class:`~flytekit.configuration.PlatformConfig`. + + .. code-block:: python + + from flytekit.remote import FlyteRemote + from flytekit.configuration import Config + + @eager( + remote=FlyteRemote(config=Config.auto(config_file="config.yaml")), + client_secret_group="my_client_secret_group", + client_secret_key="my_client_secret_key", + ) + async def eager_workflow(x: int) -> int: + out = await add_one(x) + return await double(x=out) + + Where ``config.yaml`` is a flytectl-compatible config file. + For more details, see `here `__. + + When using a sandbox cluster started with ``flytectl demo start``, however, the ``client_secret_group`` + and ``client_secret_key`` are not needed: + + .. code-block:: python + + @eager + async def eager_workflow(x: int) -> int: + ... + """ + + if _fn is None: + return partial( + eager, + **kwargs, + ) + + if "enable_deck" in kwargs: + del kwargs["enable_deck"] + + et = EagerAsyncPythonFunctionTask(task_config=None, task_function=_fn, enable_deck=True, **kwargs) + return et diff --git a/flytekit/core/tracker.py b/flytekit/core/tracker.py index b0a6525ecd..da8be53de6 100644 --- a/flytekit/core/tracker.py +++ b/flytekit/core/tracker.py @@ -68,7 +68,7 @@ def _find_instance_module(): # Try to find the module and filename in the case that we're in the __main__ module # This is useful in cases that use FlyteRemote to load tasks/workflows that are defined # in the same file as where FlyteRemote is being invoked to register and execute Flyte - # entities. One such case is with the `eager` decorator in the flytekit.experimental module. + # entities. One such case is with the `eager` decorator. mod = InstanceTrackingMeta._get_module_from_main(frame.f_globals) if mod is None: return None, None diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index d22e0fcfc5..29092035b7 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -36,7 +36,7 @@ from typing_extensions import Annotated, get_args, get_origin from flytekit.core.annotation import FlyteAnnotation -from flytekit.core.constants import FLYTE_USE_OLD_DC_FORMAT, MESSAGEPACK +from flytekit.core.constants import CACHE_KEY_METADATA, FLYTE_USE_OLD_DC_FORMAT, MESSAGEPACK, SERIALIZATION_FORMAT from flytekit.core.context_manager import FlyteContext from flytekit.core.hash import HashMethod from flytekit.core.type_helpers import load_type_from_tag @@ -53,6 +53,9 @@ from flytekit.models.types import LiteralType, SimpleType, TypeStructure, UnionType from flytekit.utils.asyn import loop_manager +if typing.TYPE_CHECKING: + from flytekit.core.interface import Interface + T = typing.TypeVar("T") DEFINITIONS = "definitions" TITLE = "title" @@ -662,7 +665,12 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: # This is for attribute access in FlytePropeller. 
ts = TypeStructure(tag="", dataclass_type=literal_type) - return _type_models.LiteralType(simple=_type_models.SimpleType.STRUCT, metadata=schema, structure=ts) + return _type_models.LiteralType( + simple=_type_models.SimpleType.STRUCT, + metadata=schema, + structure=ts, + annotation=TypeAnnotationModel({CACHE_KEY_METADATA: {SERIALIZATION_FORMAT: MESSAGEPACK}}), + ) def to_generic_literal( self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType @@ -1141,7 +1149,6 @@ class TypeEngine(typing.Generic[T]): _RESTRICTED_TYPES: typing.List[type] = [] _DATACLASS_TRANSFORMER: TypeTransformer = DataclassTransformer() # type: ignore _ENUM_TRANSFORMER: TypeTransformer = EnumTransformer() # type: ignore - has_lazy_import = False lazy_import_lock = threading.Lock() @classmethod @@ -1250,16 +1257,7 @@ def lazy_import_transformers(cls): # Avoid a race condition where concurrent threads may exit lazy_import_transformers before the transformers # have been imported. This could be implemented without a lock if you assume python assignments are atomic # and re-registering transformers is acceptable, but I decided to play it safe. - if cls.has_lazy_import: - return - cls.has_lazy_import = True - from flytekit.types.structured import ( - register_arrow_handlers, - register_bigquery_handlers, - register_pandas_handlers, - register_snowflake_handlers, - ) - from flytekit.types.structured.structured_dataset import DuplicateHandlerError + from flytekit.types.structured import lazy_import_structured_dataset_handler if is_imported("tensorflow"): from flytekit.extras import tensorflow # noqa: F401 @@ -1274,29 +1272,11 @@ def lazy_import_transformers(cls): from flytekit.types.schema.types_pandas import PandasSchemaReader, PandasSchemaWriter # noqa: F401 except ValueError: logger.debug("Transformer for pandas is already registered.") - try: - register_pandas_handlers() - except DuplicateHandlerError: - logger.debug("Transformer for pandas is already registered.") - if is_imported("pyarrow"): - try: - register_arrow_handlers() - except DuplicateHandlerError: - logger.debug("Transformer for arrow is already registered.") - if is_imported("google.cloud.bigquery"): - try: - register_bigquery_handlers() - except DuplicateHandlerError: - logger.debug("Transformer for bigquery is already registered.") if is_imported("numpy"): from flytekit.types import numpy # noqa: F401 if is_imported("PIL"): from flytekit.types.file import image # noqa: F401 - if is_imported("snowflake.connector"): - try: - register_snowflake_handlers() - except DuplicateHandlerError: - logger.debug("Transformer for snowflake is already registered.") + lazy_import_structured_dataset_handler() @classmethod def to_literal_type(cls, python_type: Type[T]) -> LiteralType: @@ -1316,7 +1296,12 @@ def to_literal_type(cls, python_type: Type[T]) -> LiteralType: ) data = x.data if data is not None: + # Double-check that `data` does not contain a key called `cache-key-metadata` + if CACHE_KEY_METADATA in data: + raise AssertionError(f"FlyteAnnotation cannot contain `{CACHE_KEY_METADATA}`.") idl_type_annotation = TypeAnnotationModel(annotations=data) + if res.annotation: + idl_type_annotation = TypeAnnotationModel.merge_annotations(idl_type_annotation, res.annotation) res = LiteralType.from_flyte_idl(res.to_flyte_idl()) res._annotation = idl_type_annotation return res @@ -2128,7 +2113,10 @@ def get_literal_type(self, t: Type[dict]) -> LiteralType: return _type_models.LiteralType(map_value_type=sub_type) except Exception as e: raise 
ValueError(f"Type of Generic List type is not supported, {e}") - return _type_models.LiteralType(simple=_type_models.SimpleType.STRUCT) + return _type_models.LiteralType( + simple=_type_models.SimpleType.STRUCT, + annotation=TypeAnnotationModel({CACHE_KEY_METADATA: {SERIALIZATION_FORMAT: MESSAGEPACK}}), + ) async def async_to_literal( self, ctx: FlyteContext, python_val: typing.Any, python_type: Type[dict], expected: LiteralType @@ -2636,6 +2624,38 @@ def get_literal(self, key: str) -> Literal: return self._literals[key] + def as_python_native(self, python_interface: Interface) -> typing.Any: + """ + This should return the native Python representation, compatible with unpacking. + This function relies on Python interface outputs being ordered correctly. + + :param python_interface: Only outputs are used but easier to pass the whole interface. + """ + if len(self.literals) == 0: + return None + + if self.variable_map is None: + raise AssertionError(f"Variable map is empty in literals resolver with {self.literals}") + + # Trigger get() on everything to make sure native values are present using the python interface as type hint + for lit_key, lit in self.literals.items(): + self.get(lit_key, as_type=python_interface.outputs.get(lit_key)) + + # if 1 item, then return 1 item + if len(self.native_values) == 1: + return next(iter(self.native_values.values())) + + # if more than 1 item, then return a tuple - can ignore naming the tuple unless it becomes a problem + # This relies on python_interface.outputs being ordered correctly. + res = cast(typing.Tuple[typing.Any, ...], ()) + for var_name, _ in python_interface.outputs.items(): + if var_name not in self.native_values: + raise ValueError(f"Key {var_name} is not in the native values") + + res += (self.native_values[var_name],) + + return res + def __getitem__(self, key: str): # First check to see if it's even in the literal map. if key not in self._literals: diff --git a/flytekit/core/worker_queue.py b/flytekit/core/worker_queue.py new file mode 100644 index 0000000000..df35e473ab --- /dev/null +++ b/flytekit/core/worker_queue.py @@ -0,0 +1,426 @@ +from __future__ import annotations + +import asyncio +import atexit +import hashlib +import re +import threading +import typing +from concurrent.futures import Future +from dataclasses import dataclass + +from flytekit.configuration import ImageConfig, SerializationSettings +from flytekit.core.base_task import PythonTask +from flytekit.core.constants import EAGER_ROOT_ENV_NAME, EAGER_TAG_KEY, EAGER_TAG_ROOT_KEY +from flytekit.core.launch_plan import LaunchPlan +from flytekit.core.options import Options +from flytekit.core.reference_entity import ReferenceEntity +from flytekit.core.utils import _dnsify +from flytekit.core.workflow import WorkflowBase +from flytekit.exceptions.system import FlyteSystemException +from flytekit.loggers import developer_logger, logger +from flytekit.models.common import Labels +from flytekit.models.core.execution import WorkflowExecutionPhase + +if typing.TYPE_CHECKING: + from flytekit.remote.remote_callable import RemoteEntity + + RunnableEntity = typing.Union[WorkflowBase, LaunchPlan, PythonTask, ReferenceEntity, RemoteEntity] + from flytekit.remote import FlyteRemote, FlyteWorkflowExecution + + +standard_output_format = re.compile(r"^o\d+$") + +NODE_HTML_TEMPLATE = """ + + + + +
+<h3>{entity_type}: {entity_name}</h3> + +<p> + <strong>Execution:</strong> + {execution_name} +</p> + +<details> +<summary>Inputs</summary> +<pre> +{inputs} +</pre> +</details> + +<details> +<summary>Outputs</summary> +<pre> +{outputs} +</pre> +</details>
+""" + + +@dataclass +class WorkItem: + entity: RunnableEntity + input_kwargs: dict[str, typing.Any] + fut: asyncio.Future + result: typing.Any = None + error: typing.Optional[BaseException] = None + + wf_exec: typing.Optional[FlyteWorkflowExecution] = None + + def set_result(self, result: typing.Any): + assert self.wf_exec is not None + developer_logger.debug(f"Setting result for {self.wf_exec.id.name} on thread {threading.current_thread().name}") + self.result = result + # need to convert from literals resolver to literals and then to python native. + self.fut._loop.call_soon_threadsafe(self.fut.set_result, result) + + def set_error(self, e: BaseException): + developer_logger.debug( + f"Setting error for {self.wf_exec.id.name if self.wf_exec else 'unstarted execution'}" + f" on thread {threading.current_thread().name} to {e}" + ) + self.error = e + self.fut._loop.call_soon_threadsafe(self.fut.set_exception, e) + + def set_exec(self, wf_exec: FlyteWorkflowExecution): + self.wf_exec = wf_exec + + @property + def ready(self) -> bool: + return self.wf_exec is not None and (self.result is not None or self.error is not None) + + +class Informer(typing.Protocol): + def watch(self, work: WorkItem): ... + + +class PollingInformer: + def __init__(self, remote: FlyteRemote, loop: asyncio.AbstractEventLoop): + self.remote: FlyteRemote = remote + self.__loop = loop + + async def watch_one(self, work: WorkItem): + assert work.wf_exec is not None + logger.debug(f"Starting watching execution {work.wf_exec.id} on {threading.current_thread().name}") + while True: + # not really async but pretend it is for now, change to to_thread in the future. + developer_logger.debug(f"Looping on {work.wf_exec.id.name}") + self.remote.sync_execution(work.wf_exec) + if work.wf_exec.is_done: + developer_logger.debug(f"Execution {work.wf_exec.id.name} is done.") + break + await asyncio.sleep(2) + + # set results + # but first need to convert from literals resolver to literals and then to python native. 
+ if work.wf_exec.closure.phase == WorkflowExecutionPhase.SUCCEEDED: + from flytekit.core.interface import Interface + from flytekit.core.type_engine import TypeEngine + + if not work.entity.python_interface: + for k, _ in work.entity.interface.outputs.items(): + if not re.match(standard_output_format, k): + raise AssertionError( + f"Entity without python interface found, and output name {k} does not match standard format o[0-9]+" + ) + + num_outputs = len(work.entity.interface.outputs) + python_outputs_interface: typing.Dict[str, typing.Type] = {} + # Iterate in order so that we add to the interface in the correct order + for i in range(num_outputs): + key = f"o{i}" + if key not in work.entity.interface.outputs: + raise AssertionError( + f"Output name {key} not found in outputs {[k for k in work.entity.interface.outputs.keys()]}" + ) + var_type = work.entity.interface.outputs[key].type + python_outputs_interface[key] = TypeEngine.guess_python_type(var_type) + py_iface = Interface(inputs=typing.cast(dict[str, typing.Type], {}), outputs=python_outputs_interface) + else: + py_iface = work.entity.python_interface + + results = work.wf_exec.outputs.as_python_native(py_iface) + + work.set_result(results) + elif work.wf_exec.closure.phase == WorkflowExecutionPhase.FAILED: + from flytekit.exceptions.eager import EagerException + + exc = EagerException(f"Error executing {work.entity.name} with error: {work.wf_exec.closure.error}") + work.set_error(exc) + + def watch(self, work: WorkItem): + coro = self.watch_one(work) + # both run_coroutine_threadsafe and self.__loop.create_task seem to work, but the first makes + # more sense in case *this* thread is ever different than the thread that self.__loop is running on. + f = asyncio.run_coroutine_threadsafe(coro, self.__loop) + + developer_logger.debug(f"Started watch with future {f}") + + def cb(fut: Future): + """ + This cb takes care of any exceptions that might be thrown in the watch_one coroutine + Note: This is a concurrent Future not an asyncio Future + """ + e = fut.exception() + if e: + logger.error(f"Error in watch for {work.entity.name} with {work.input_kwargs}: {e}") + work.set_error(e) + + f.add_done_callback(cb) + + +# A flag to ensure the handler runs only once +handling_signal = 0 + + +class Controller: + def __init__(self, remote: FlyteRemote, ss: SerializationSettings, tag: str, root_tag: str, exec_prefix: str): + logger.debug( + f"Creating Controller for eager execution with {remote.config.platform.endpoint}," + f" {tag=}, {root_tag=}, {exec_prefix=} and ss: {ss}" + ) + # Set up things for this controller to operate + from flytekit.utils.asyn import _selector_policy + + with _selector_policy(): + self.__loop = asyncio.new_event_loop() + self.__loop.set_exception_handler(self.exc_handler) + self.__runner_thread: threading.Thread = threading.Thread( + target=self._execute, daemon=True, name="controller-loop-runner" + ) + self.__runner_thread.start() + atexit.register(self._close) + + # Things for actually kicking off and monitoring + self.entries: typing.Dict[str, typing.List[WorkItem]] = {} + self.informer = PollingInformer(remote=remote, loop=self.__loop) + self.remote = remote + self.ss = ss + self.exec_prefix = exec_prefix + + # Executions should be tracked in the following way: + # a) you should be able to list by label, all executions generated by the current eager task, + # b) in the case of nested eager, you should be able to list by label, all executions from the root eager task + # c) within a given eager task, the execution 
names should be deterministic and unique
+
+        # To achieve this, this Controller will:
+        #   a) set a label to track the root eager task execution
+        #   b) set an environment variable to represent the root eager task execution for downstream tasks to read
+        #   c) set a label to track the current eager task exec
+        #   d) create deterministic execution names by combining:
+        #      i) the current eager execution name (aka the tag)
+        #      ii) the entity type being run
+        #      iii) the entity name being run
+        #      iv) the order in which it's called
+        #      v) the input_kwargs
+        #      hash the above, and then prepend it with a prefix.
+        self.tag = tag
+        self.root_tag = root_tag
+
+    def _close(self) -> None:
+        if self.__loop:
+            self.__loop.stop()
+
+    @staticmethod
+    def exc_handler(loop, context):
+        logger.error(f"Caught exception in loop {loop} with context {context}")
+
+    def _execute(self) -> None:
+        loop = self.__loop
+        try:
+            loop.run_forever()
+        finally:
+            logger.error("Controller event loop stopped.")
+
+    def get_labels(self) -> Labels:
+        """
+        These labels keep track of the current and root (in case of nested) eager execution, that is responsible for
+        kicking off this execution.
+        """
+        l = {EAGER_TAG_KEY: self.tag, EAGER_TAG_ROOT_KEY: self.root_tag}
+        return Labels(l)
+
+    def get_env(self) -> typing.Dict[str, str]:
+        """
+        In order for downstream tasks to correctly set the root label, this needs to pass down that information.
+        """
+        return {EAGER_ROOT_ENV_NAME: self.root_tag}
+
+    def get_execution_name(self, entity: RunnableEntity, idx: int, input_kwargs: dict[str, typing.Any]) -> str:
+        """Make a deterministic name"""
+        # todo: Move the transform of python native values to literals up to the controller, and use pb hashing/user
+        #   provided hashmethods to comprise the input_kwargs part. Merely printing input_kwargs is not strictly correct
+        #   https://github.com/flyteorg/flyte/issues/6069
+        components = f"{self.tag}-{type(entity)}-{entity.name}-{idx}-{input_kwargs}"
+
+        # hash the components into something deterministic
+        hex = hashlib.md5(components.encode()).hexdigest()
+        # just take the first 16 chars.
+        hex = hex[:16]
+        name = entity.name.split(".")[-1]
+        name = name[:8]  # just take the first 8 chars
+        exec_name = f"{self.exec_prefix}-{name}-{hex}"
+        exec_name = _dnsify(exec_name)
+        return exec_name
+
+    def launch_and_start_watch(self, wi: WorkItem, idx: int):
+        """This function launches executions. This is called via the loop, so it needs exception handling"""
+        try:
+            if wi.result is None and wi.error is None:
+                l = self.get_labels()
+                e = self.get_env()
+                options = Options(labels=l)
+                exec_name = self.get_execution_name(wi.entity, idx, wi.input_kwargs)
+                logger.info(f"Generated execution name {exec_name} for {idx}th call of {wi.entity.name}")
+                from flytekit.remote.remote_callable import RemoteEntity
+
+                if isinstance(wi.entity, RemoteEntity):
+                    version = wi.entity.id.version
+                else:
+                    version = self.ss.version
+
+                # todo: if the execution already exists, remote.execute will return that execution. in the future
+                #   we can add input checking to make sure the inputs are indeed a match.
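
As a worked illustration of the naming scheme described above (a condensed sketch that skips the final _dnsify normalization): identical calls hash to identical names, which is what makes the "execution already exists" behavior in the todo above safe to rely on.

    import hashlib

    def example_execution_name(exec_prefix: str, tag: str, entity, idx: int, input_kwargs: dict) -> str:
        # Same recipe as get_execution_name: hash (tag, entity type, entity name,
        # call order, inputs), keep 16 hex chars, and prepend a truncated name.
        components = f"{tag}-{type(entity)}-{entity.name}-{idx}-{input_kwargs}"
        digest = hashlib.md5(components.encode()).hexdigest()[:16]
        short = entity.name.split(".")[-1][:8]
        return f"{exec_prefix}-{short}-{digest}"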
+ wf_exec = self.remote.execute( + entity=wi.entity, + execution_name=exec_name, + inputs=wi.input_kwargs, + version=version, + image_config=self.ss.image_config, + options=options, + envs=e, + ) + logger.info(f"Successfully started execution {wf_exec.id.name}") + wi.set_exec(wf_exec) + + # if successful then start watch on the execution + self.informer.watch(wi) + else: + raise AssertionError( + "This launch function should not be invoked for work items already" " with result or error" + ) + except Exception as e: + # all exceptions get registered onto the future. + logger.error(f"Error launching execution for {wi.entity.name} with {wi.input_kwargs}") + wi.set_error(e) + + def add( + self, task_loop: asyncio.AbstractEventLoop, entity: RunnableEntity, input_kwargs: dict[str, typing.Any] + ) -> asyncio.Future: + """ + Add an entity along with the requested inputs to be submitted to Admin for running and return a future + """ + # need to also check to see if the entity has already been registered, and if not, register it. + fut = task_loop.create_future() + i = WorkItem(entity=entity, input_kwargs=input_kwargs, fut=fut) + + # For purposes of awaiting an execution, we don't need to keep track of anything, but doing so for Deck + if entity.name not in self.entries: + self.entries[entity.name] = [] + self.entries[entity.name].append(i) + idx = len(self.entries[entity.name]) - 1 + + # trigger a run of the launching function. + self.__loop.call_soon_threadsafe(self.launch_and_start_watch, i, idx) + return fut + + def render_html(self) -> str: + """Render the callstack as a deck presentation to be shown after eager workflow execution.""" + + from flytekit.core.base_task import PythonTask + from flytekit.core.python_function_task import AsyncPythonFunctionTask, EagerAsyncPythonFunctionTask + from flytekit.core.workflow import WorkflowBase + + output = "

<h2>Nodes</h2>
<hr>
" + + def _entity_type(entity) -> str: + if isinstance(entity, EagerAsyncPythonFunctionTask): + return "Eager Workflow" + elif isinstance(entity, AsyncPythonFunctionTask): + return "Async Task" + elif isinstance(entity, PythonTask): + return "Task" + elif isinstance(entity, WorkflowBase): + return "Workflow" + return str(type(entity)) + + for entity_name, items_list in self.entries.items(): + for item in items_list: + if not item.ready: + logger.warning( + f"Item for {item.entity.name} with inputs {item.input_kwargs}" + f" isn't ready, skipping for deck rendering..." + ) + continue + kind = _entity_type(item.entity) + output = f"{output}\n" + NODE_HTML_TEMPLATE.format( + entity_type=kind, + entity_name=item.entity.name, + execution_name=item.wf_exec.id.name, # type: ignore[union-attr] + url=self.remote.generate_console_url(item.wf_exec), + inputs=item.input_kwargs, + outputs=item.result if item.result else item.error, + ) + + return output + + def get_signal_handler(self): + """ + TODO: At some point, this loop would be ideally managed by the loop manager, and the signal handler should + gracefully initiate shutdown of all loops, calling .cancel() on all tasks, allowing each loop to clean up, + starting with the deepest loop/thread first and working up. + https://github.com/flyteorg/flyte/issues/6068 + """ + + def signal_handler(signum, frame): + logger.warning(f"Received signal {signum}, initiating signal handler") + global handling_signal + if handling_signal: + if handling_signal > 2: + exit(1) + logger.warning("Signal already being handled. Please wait...") + handling_signal += 1 + return + + handling_signal += 1 + self.__loop.stop() # stop the loop + for _, work_list in self.entries.items(): + for work in work_list: + if work.wf_exec and not work.wf_exec.is_done: + try: + self.remote.terminate(work.wf_exec, "eager execution cancelled") + logger.warning(f"Cancelled {work.wf_exec.id.name}") + except FlyteSystemException as e: + logger.info(f"Error cancelling {work.wf_exec.id.name}, may already be cancelled: {e}") + exit(1) + + return signal_handler + + @classmethod + def for_sandbox(cls, exec_prefix: typing.Optional[str] = None) -> Controller: + from flytekit.core.context_manager import FlyteContextManager + from flytekit.remote import FlyteRemote + + ctx = FlyteContextManager.current_context() + remote = FlyteRemote.for_sandbox(default_project="flytesnacks", default_domain="development") + rand = ctx.file_access.get_random_string() + ss = ctx.serialization_settings + if not ss: + ss = SerializationSettings( + image_config=ImageConfig.auto_default_image(), + version=f"v{rand[:8]}", + ) + root_tag = tag = f"eager-local-{rand}" + exec_prefix = exec_prefix or f"e-{rand[:16]}" + + c = Controller(remote=remote, ss=ss, tag=tag, root_tag=root_tag, exec_prefix=exec_prefix) + return c diff --git a/flytekit/exceptions/eager.py b/flytekit/exceptions/eager.py new file mode 100644 index 0000000000..806c66884d --- /dev/null +++ b/flytekit/exceptions/eager.py @@ -0,0 +1,31 @@ +class EagerException(Exception): + """Raised when a node in an eager workflow encounters an error. + + This exception should be used in an :py:func:`@eager ` workflow function to + catch exceptions that are raised by tasks or subworkflows. + + .. 
code-block:: python
+
+        from flytekit import eager, task
+        from flytekit.exceptions.eager import EagerException
+
+        @task
+        def add_one(x: int) -> int:
+            if x < 0:
+                raise ValueError("x must be positive")
+            return x + 1
+
+        @task
+        def double(x: int) -> int:
+            return x * 2
+
+        @eager
+        async def eager_workflow(x: int) -> int:
+            try:
+                out = await add_one(x=x)
+            except EagerException:
+                # The ValueError is caught
+                # and raised as an EagerException
+                raise
+            return await double(x=out)
+    """
diff --git a/flytekit/experimental/__init__.py b/flytekit/experimental/__init__.py
index af7e5c5971..e08e5676c7 100644
--- a/flytekit/experimental/__init__.py
+++ b/flytekit/experimental/__init__.py
@@ -2,4 +2,3 @@
 
 # TODO(eapolinario): Remove this once a new flytekit release is out and
 # references are updated in the monodocs build.
-from flytekit.experimental.eager_function import EagerException, eager
diff --git a/flytekit/experimental/eager_function.py b/flytekit/experimental/eager_function.py
index f5c0051de2..e35a1597d1 100644
--- a/flytekit/experimental/eager_function.py
+++ b/flytekit/experimental/eager_function.py
@@ -1,560 +1,15 @@
-import asyncio
-import inspect
 import os
-import signal
-from contextlib import asynccontextmanager
-from datetime import datetime, timedelta, timezone
-from functools import partial, wraps
-from typing import List, Optional
+from typing import Optional
 
-from flytekit import Deck, Secret, current_context
+from flytekit import current_context
 from flytekit.configuration import DataConfig, PlatformConfig, S3Config
-from flytekit.core.base_task import PythonTask
 from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager
-from flytekit.core.python_function_task import PythonFunctionTask
-from flytekit.core.task import task
-from flytekit.core.workflow import WorkflowBase
 from flytekit.loggers import logger
-from flytekit.models.core.execution import WorkflowExecutionPhase
 from flytekit.remote import FlyteRemote
 
 FLYTE_SANDBOX_INTERNAL_ENDPOINT = "flyte-sandbox-grpc.flyte:8089"
 FLYTE_SANDBOX_MINIO_ENDPOINT = "http://flyte-sandbox-minio.flyte:9000"
 
-NODE_HTML_TEMPLATE = """
-
-
-
-

-<h3>{entity_type}: {entity_name}</h3>
-
-<p>
-    <strong>Execution:</strong>
-    <a target="_blank" href="{url}">{execution_name}</a>
-</p>
-
-<details>
-<summary>Inputs</summary>
-<pre>{inputs}</pre>
-</details>
-
-<details>
-<summary>Outputs</summary>
-<pre>{outputs}</pre>
-</details>
-
-""" - - -class EagerException(Exception): - """Raised when a node in an eager workflow encounters an error. - - This exception should be used in an :py:func:`@eager ` workflow function to - catch exceptions that are raised by tasks or subworkflows. - - .. code-block:: python - - from flytekit import task - from flytekit.experimental import eager, EagerException - - @task - def add_one(x: int) -> int: - if x < 0: - raise ValueError("x must be positive") - return x + 1 - - @task - def double(x: int) -> int: - return x * 2 - - @eager - async def eager_workflow(x: int) -> int: - try: - out = await add_one(x=x) - except EagerException: - # The ValueError error is caught - # and raised as an EagerException - raise - return await double(x=out) - """ - - -class AsyncEntity: - """A wrapper around a Flyte entity (task, workflow, launch plan) that allows it to be executed asynchronously.""" - - def __init__( - self, - entity, - remote: Optional[FlyteRemote], - ctx: FlyteContext, - async_stack: "AsyncStack", - timeout: Optional[timedelta] = None, - poll_interval: Optional[timedelta] = None, - local_entrypoint: bool = False, - ): - self.entity = entity - self.ctx = ctx - self.async_stack = async_stack - self.execution_state = self.ctx.execution_state.mode - self.remote = remote - self.local_entrypoint = local_entrypoint - if self.remote is not None: - logger.debug(f"Using remote config: {self.remote.config}") - else: - logger.debug("Not using remote, executing locally") - self._timeout = timeout - self._poll_interval = poll_interval - self._execution = None - - async def __call__(self, **kwargs): - logger.debug(f"Calling {self.entity}: {self.entity.name}") - - # ensure async context is provided - if "async_ctx" in kwargs: - kwargs.pop("async_ctx") - - if getattr(self.entity, "execution_mode", None) == PythonFunctionTask.ExecutionBehavior.DYNAMIC: - raise EagerException( - "Eager workflows currently do not work with dynamic workflows. " - "If you need to use a subworkflow, use a static @workflow or nested @eager workflow." 
- ) - - if not self.local_entrypoint and self.ctx.execution_state.is_local_execution(): - # If running as a local workflow execution, just execute the python function - try: - if isinstance(self.entity, WorkflowBase): - out = self.entity._workflow_function(**kwargs) - if inspect.iscoroutine(out): - # need to handle invocation of AsyncEntity tasks within the workflow - out = await out - return out - elif isinstance(self.entity, PythonTask): - # invoke the task-decorated entity - out = self.entity(**kwargs) - if inspect.iscoroutine(out): - out = await out - return out - else: - raise ValueError(f"Entity type {type(self.entity)} not supported for local execution") - except Exception as exc: - raise EagerException( - f"Error executing {type(self.entity)} {self.entity.name} with {type(exc)}: {exc}" - ) from exc - - # this is a hack to handle the case when the task.name doesn't contain the fully - # qualified module name - entity_name = ( - f"{self.entity._instantiated_in}.{self.entity.name}" - if self.entity._instantiated_in not in self.entity.name - else self.entity.name - ) - - if isinstance(self.entity, WorkflowBase): - remote_entity = self.remote.fetch_workflow(name=entity_name) - elif isinstance(self.entity, PythonTask): - remote_entity = self.remote.fetch_task(name=entity_name) - else: - raise ValueError(f"Entity type {type(self.entity)} not supported for local execution") - - execution = self.remote.execute(remote_entity, inputs=kwargs, type_hints=self.entity.python_interface.inputs) - self._execution = execution - - url = self.remote.generate_console_url(execution) - msg = f"Running flyte {type(self.entity)} {entity_name} on remote cluster: {url}" - if self.local_entrypoint: - logger.info(msg) - else: - logger.debug(msg) - - node = AsyncNode(self, entity_name, execution, url) - self.async_stack.set_node(node) - - poll_interval = self._poll_interval or timedelta(seconds=30) - time_to_give_up = ( - (datetime.max.replace(tzinfo=timezone.utc)) - if self._timeout is None - else datetime.now(timezone.utc) + self._timeout - ) - - while datetime.now(timezone.utc) < time_to_give_up: - execution = self.remote.sync(execution) - if execution.closure.phase in {WorkflowExecutionPhase.FAILED}: - raise EagerException(f"Error executing {self.entity.name} with error: {execution.closure.error}") - elif execution.is_done: - break - await asyncio.sleep(poll_interval.total_seconds()) - - outputs = {} - for key, type_ in self.entity.python_interface.outputs.items(): - outputs[key] = execution.outputs.get(key, as_type=type_) - - if len(outputs) == 1: - out, *_ = outputs.values() - return out - return outputs - - async def terminate(self): - execution = self.remote.sync(self._execution) - logger.debug(f"Cleaning up execution: {execution}") - if not execution.is_done: - self.remote.terminate( - execution, - f"Execution terminated by eager workflow execution {self.async_stack.parent_execution_id}.", - ) - - poll_interval = self._poll_interval or timedelta(seconds=6) - time_to_give_up = ( - (datetime.max.replace(tzinfo=timezone.utc)) - if self._timeout is None - else datetime.now(timezone.utc) + self._timeout - ) - - while datetime.now(timezone.utc) < time_to_give_up: - execution = self.remote.sync(execution) - if execution.is_done: - break - await asyncio.sleep(poll_interval.total_seconds()) - - return True - - -class AsyncNode: - """A node in the async callstack.""" - - def __init__(self, async_entity, entity_name, execution=None, url=None): - self.entity_name = entity_name - self.async_entity = async_entity - 
self.execution = execution
-        self._url = url
-
-    @property
-    def url(self) -> str:
-        # make sure that internal flyte sandbox endpoint is replaced with localhost endpoint when rendering the urls
-        # for flyte decks
-        endpoint_root = FLYTE_SANDBOX_INTERNAL_ENDPOINT.replace("http://", "")
-        if endpoint_root in self._url:
-            return self._url.replace(endpoint_root, "localhost:30080")
-        return self._url
-
-    @property
-    def entity_type(self) -> str:
-        if (
-            isinstance(self.async_entity.entity, PythonTask)
-            and getattr(self.async_entity.entity, "execution_mode", None) == PythonFunctionTask.ExecutionBehavior.EAGER
-        ):
-            return "Eager Workflow"
-        elif isinstance(self.async_entity.entity, PythonTask):
-            return "Task"
-        elif isinstance(self.async_entity.entity, WorkflowBase):
-            return "Workflow"
-        return str(type(self.async_entity.entity))
-
-    def __repr__(self):
-        ex_id = self.execution.id
-        execution_id = None if self.execution is None else f"{ex_id.project}:{ex_id.domain}:{ex_id.name}"
-        return (
-            f"<async node {self.entity_name} | execution: {execution_id} | url: {self.url}>"
-        )
-
-
-class AsyncStack:
-    """A stack data structure to track the call stack of an eager workflow."""
-
-    def __init__(self, parent_task_id: Optional[str], parent_execution_id: Optional[str]):
-        self.parent_task_id = parent_task_id
-        self.parent_execution_id = parent_execution_id
-        self._call_stack: List[AsyncNode] = []
-
-    @property
-    def call_stack(self) -> List[AsyncNode]:
-        return self._call_stack
-
-    def set_node(self, node: AsyncNode):
-        self._call_stack.append(node)
-
-
-async def render_deck(async_stack):
-    """Render the callstack as a deck presentation to be shown after eager workflow execution."""
-
-    def get_io(dict_like):
-        try:
-            return {k: dict_like.get(k) for k in dict_like}
-        except Exception:
-            return dict_like
-
-    output = "

<h2>Nodes</h2>
<hr>
" - for node in async_stack.call_stack: - node_inputs = get_io(node.execution.inputs) - if node.execution.closure.phase in {WorkflowExecutionPhase.FAILED}: - node_outputs = None - else: - node_outputs = get_io(node.execution.outputs) - - output = f"{output}\n" + NODE_HTML_TEMPLATE.format( - entity_type=node.entity_type, - entity_name=node.entity_name, - execution_name=node.execution.id.name, - url=node.url, - inputs=node_inputs, - outputs=node_outputs, - ) - - Deck("eager workflow", output) - - -@asynccontextmanager -async def eager_context( - fn, - remote: Optional[FlyteRemote], - ctx: FlyteContext, - async_stack: AsyncStack, - timeout: Optional[timedelta] = None, - poll_interval: Optional[timedelta] = None, - local_entrypoint: bool = False, -): - """This context manager overrides all tasks in the global namespace with async versions.""" - - _original_cache = {} - - # override tasks with async version - for k, v in fn.__globals__.items(): - if isinstance(v, (PythonTask, WorkflowBase)): - _original_cache[k] = v - fn.__globals__[k] = AsyncEntity(v, remote, ctx, async_stack, timeout, poll_interval, local_entrypoint) - - try: - yield - finally: - # restore old tasks - for k, v in _original_cache.items(): - fn.__globals__[k] = v - - -async def node_cleanup_async(sig, loop, async_stack: AsyncStack): - """Clean up subtasks when eager workflow parent is done. - - This applies either if the eager workflow completes successfully, fails, or is cancelled by the user. - """ - logger.debug(f"Cleaning up async nodes on signal: {sig}") - terminations = [] - for node in async_stack.call_stack: - terminations.append(node.async_entity.terminate()) - results = await asyncio.gather(*terminations) - logger.debug(f"Successfully terminated subtasks {results}") - - -def node_cleanup(sig, frame, loop, async_stack: AsyncStack): - """Clean up subtasks when eager workflow parent is done. - - This applies either if the eager workflow completes successfully, fails, or is cancelled by the user. - """ - logger.debug(f"Cleaning up async nodes on signal: {sig}") - terminations = [] - for node in async_stack.call_stack: - terminations.append(node.async_entity.terminate()) - results = asyncio.gather(*terminations) - results = asyncio.run(results) - logger.debug(f"Successfully terminated subtasks {results}") - loop.close() - - -def eager( - _fn=None, - *, - remote: Optional[FlyteRemote] = None, - client_secret_group: Optional[str] = None, - client_secret_key: Optional[str] = None, - timeout: Optional[timedelta] = None, - poll_interval: Optional[timedelta] = None, - local_entrypoint: bool = False, - client_secret_env_var: Optional[str] = None, - **kwargs, -): - """Eager workflow decorator. - - :param remote: A :py:class:`~flytekit.remote.FlyteRemote` object to use for executing Flyte entities. - :param client_secret_group: The client secret group to use for this workflow. - :param client_secret_key: The client secret key to use for this workflow. - :param timeout: The timeout duration specifying how long to wait for a task/workflow execution within the eager - workflow to complete or terminate. By default, the eager workflow will wait indefinitely until complete. - :param poll_interval: The poll interval for checking if a task/workflow execution within the eager workflow has - finished. If not specified, the default poll interval is 6 seconds. 
-    :param local_entrypoint: If True, the eager workflow can be executed locally but will use the provided
-        :py:func:`~flytekit.remote.FlyteRemote` object to create task/workflow executions. This is useful for local
-        testing against a remote Flyte cluster.
-    :param client_secret_env_var: if specified, binds the client secret to the specified environment variable for
-        remote authentication.
-    :param kwargs: keyword-arguments forwarded to :py:func:`~flytekit.task`.
-
-    This type of workflow will execute all flyte entities within it eagerly, meaning that all python constructs can be
-    used inside of an ``@eager``-decorated function. This is because eager workflows use a
-    :py:class:`~flytekit.remote.remote.FlyteRemote` object to kick off executions when a flyte entity needs to produce a
-    value.
-
-    For example:
-
-    .. code-block:: python
-
-        from flytekit import task
-        from flytekit.experimental import eager
-
-        @task
-        def add_one(x: int) -> int:
-            return x + 1
-
-        @task
-        def double(x: int) -> int:
-            return x * 2
-
-        @eager
-        async def eager_workflow(x: int) -> int:
-            out = await add_one(x=x)
-            return await double(x=out)
-
-        # run locally with asyncio
-        if __name__ == "__main__":
-            import asyncio
-
-            result = asyncio.run(eager_workflow(x=1))
-            print(f"Result: {result}")  # "Result: 4"
-
-    Unlike :py:func:`dynamic workflows `, eager workflows are not compiled into a workflow spec, but
-    use python's `async `__ capabilities to execute flyte entities.
-
-    .. note::
-
-        Eager workflows only support `@task`, `@workflow`, and `@eager` entities. Dynamic workflows and launchplans are
-        currently not supported.
-
-    Note that the ``@eager`` function is an ``async`` function. Under the hood, tasks and workflows called inside
-    an ``@eager`` workflow are executed asynchronously. This means that task and workflow calls will return an awaitable,
-    which needs to be awaited.
-
-    .. important::
-
-        A ``client_secret_group`` and ``client_secret_key`` are needed for authenticating via
-        :py:class:`~flytekit.remote.remote.FlyteRemote` using the ``client_credentials`` authentication, which is
-        configured via :py:class:`~flytekit.configuration.PlatformConfig`.
-
-        .. code-block:: python
-
-            from flytekit.remote import FlyteRemote
-            from flytekit.configuration import Config
-
-            @eager(
-                remote=FlyteRemote(config=Config.auto(config_file="config.yaml")),
-                client_secret_group="my_client_secret_group",
-                client_secret_key="my_client_secret_key",
-            )
-            async def eager_workflow(x: int) -> int:
-                out = await add_one(x=x)
-                return await double(x=out)
-
-        Where ``config.yaml`` contains a flytectl-compatible config file.
-        For more details, see `here `__.
-
-        When using a sandbox cluster started with ``flytectl demo start``, however, the ``client_secret_group``
-        and ``client_secret_key`` are not needed:
-
-        .. code-block:: python
-
-            @eager(remote=FlyteRemote(config=Config.for_sandbox()))
-            async def eager_workflow(x: int) -> int:
-                ...
-
-    .. important::
-
-        When using ``local_entrypoint=True`` you also need to specify the ``remote`` argument. In this case, the eager
-        workflow runtime will be local, but all task/subworkflow invocations will occur on the specified Flyte cluster.
-        This argument is primarily used for testing and debugging eager workflow logic locally.
- - """ - - if _fn is None: - return partial( - eager, - remote=remote, - client_secret_group=client_secret_group, - client_secret_key=client_secret_key, - timeout=timeout, - poll_interval=poll_interval, - local_entrypoint=local_entrypoint, - client_secret_env_var=client_secret_env_var, - **kwargs, - ) - - if local_entrypoint and remote is None: - raise ValueError("Must specify remote argument if local_entrypoint is True") - - @wraps(_fn) - async def wrapper(*args, **kws): - # grab the "async_ctx" argument injected by PythonFunctionTask.execute - logger.debug("Starting") - _remote = remote - - # locally executed nested eager workflows won't have async_ctx injected into the **kws input - ctx = kws.pop("async_ctx", None) - task_id, execution_id = None, None - if ctx: - exec_params = ctx.user_space_params - task_id = exec_params.task_id - execution_id = exec_params.execution_id - - async_stack = AsyncStack(task_id, execution_id) - _remote = _prepare_remote( - _remote, ctx, client_secret_group, client_secret_key, local_entrypoint, client_secret_env_var - ) - - # make sure sub-nodes as cleaned up on termination signal - loop = asyncio.get_event_loop() - node_cleanup_partial = partial(node_cleanup_async, async_stack=async_stack) - cleanup_fn = partial(asyncio.ensure_future, node_cleanup_partial(signal.SIGTERM, loop)) - signal.signal(signal.SIGTERM, partial(node_cleanup, loop=loop, async_stack=async_stack)) - - async with eager_context(_fn, _remote, ctx, async_stack, timeout, poll_interval, local_entrypoint): - try: - if _remote is not None: - with _remote.remote_context(): - out = await _fn(*args, **kws) - else: - out = await _fn(*args, **kws) - # need to await for _fn to complete, then invoke the deck - await render_deck(async_stack) - return out - finally: - # in case the cleanup function hasn't been called yet, call it at the end of the eager workflow - await cleanup_fn() - - secret_requests = kwargs.pop("secret_requests", None) or [] - try: - secret_requests.append(Secret(group=client_secret_group, key=client_secret_key)) - except ValueError: - pass - - return task( - wrapper, - secret_requests=secret_requests, - enable_deck=True, - execution_mode=PythonFunctionTask.ExecutionBehavior.EAGER, - **kwargs, - ) - def _prepare_remote( remote: Optional[FlyteRemote], @@ -599,8 +54,6 @@ def _internal_demo_remote(remote: FlyteRemote) -> FlyteRemote: platform=PlatformConfig( endpoint=FLYTE_SANDBOX_INTERNAL_ENDPOINT, insecure=True, - auth_mode="Pkce", - client_id=remote.config.platform.client_id, ), data_config=DataConfig( s3=S3Config( diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index c8f046f5fa..16235a68ec 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -1,6 +1,6 @@ import asyncio +import inspect import json -import signal import sys import time import typing @@ -8,7 +8,7 @@ from collections import OrderedDict from dataclasses import asdict, dataclass from functools import partial -from types import FrameType, coroutine +from types import FrameType from typing import Any, Dict, List, Optional, Union from flyteidl.admin.agent_pb2 import Agent @@ -285,7 +285,6 @@ def execute(self: PythonTask, **kwargs) -> LiteralMap: output_prefix = ctx.file_access.get_random_remote_directory() agent = AgentRegistry.get_agent(task_template.type, task_template.task_type_version) - resource = asyncio.run( self._do(agent=agent, template=task_template, output_prefix=output_prefix, inputs=kwargs) ) @@ -322,13 +321,14 @@ class 
AsyncAgentExecutorMixin: Asynchronous tasks are tasks that take a long time to complete, such as running a query. """ - _clean_up_task: coroutine = None + _clean_up_task: bool = False _agent: AsyncAgentBase = None def execute(self: PythonTask, **kwargs) -> LiteralMap: ctx = FlyteContext.current_context() ss = ctx.serialization_settings or SerializationSettings(ImageConfig()) output_prefix = ctx.file_access.get_random_remote_directory() + self.resource_meta = None from flytekit.tools.translator import get_serializable @@ -380,7 +380,8 @@ async def _create( output_prefix=output_prefix, ) - signal.signal(signal.SIGINT, partial(self.signal_handler, resource_meta)) # type: ignore + FlyteContextManager.add_signal_handler(partial(self.agent_signal_handler, resource_meta)) + self.resource_meta = resource_meta return resource_meta async def _get(self: PythonTask, resource_meta: ResourceMeta) -> Resource: @@ -397,7 +398,6 @@ async def _get(self: PythonTask, resource_meta: ResourceMeta) -> Resource: time.sleep(1) resource = await mirror_async_methods(self._agent.get, resource_meta=resource_meta) if self._clean_up_task: - await self._clean_up_task sys.exit(1) phase = resource.phase @@ -415,7 +415,11 @@ async def _get(self: PythonTask, resource_meta: ResourceMeta) -> Resource: return resource - def signal_handler(self, resource_meta: ResourceMeta, signum: int, frame: FrameType) -> Any: - if self._clean_up_task is None: - co = mirror_async_methods(self._agent.delete, resource_meta=resource_meta) - self._clean_up_task = asyncio.create_task(co) + def agent_signal_handler(self, resource_meta: ResourceMeta, signum: int, frame: FrameType) -> Any: + if inspect.iscoroutinefunction(self._agent.delete): + # Use asyncio.run to run the async function in the main thread since the loop manager is killed when the + # signal is received. + asyncio.run(self._agent.delete(resource_meta)) + else: + self._agent.delete(resource_meta) + self._clean_up_task = True diff --git a/flytekit/extras/pydantic_transformer/transformer.py b/flytekit/extras/pydantic_transformer/transformer.py index dc6751218b..e9048d8880 100644 --- a/flytekit/extras/pydantic_transformer/transformer.py +++ b/flytekit/extras/pydantic_transformer/transformer.py @@ -8,11 +8,12 @@ from pydantic import BaseModel from flytekit import FlyteContext -from flytekit.core.constants import FLYTE_USE_OLD_DC_FORMAT, MESSAGEPACK +from flytekit.core.constants import CACHE_KEY_METADATA, FLYTE_USE_OLD_DC_FORMAT, MESSAGEPACK, SERIALIZATION_FORMAT from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError from flytekit.core.utils import str2bool from flytekit.loggers import logger from flytekit.models import types +from flytekit.models.annotation import TypeAnnotation as TypeAnnotationModel from flytekit.models.literals import Binary, Literal, Scalar from flytekit.models.types import LiteralType, TypeStructure @@ -37,7 +38,12 @@ def get_literal_type(self, t: Type[BaseModel]) -> LiteralType: # This is for attribute access in FlytePropeller. 
ts = TypeStructure(tag="", dataclass_type=literal_type) - return types.LiteralType(simple=types.SimpleType.STRUCT, metadata=schema, structure=ts) + return types.LiteralType( + simple=types.SimpleType.STRUCT, + metadata=schema, + structure=ts, + annotation=TypeAnnotationModel({CACHE_KEY_METADATA: {SERIALIZATION_FORMAT: MESSAGEPACK}}), + ) def to_generic_literal( self, diff --git a/flytekit/lazy_import/lazy_module.py b/flytekit/lazy_import/lazy_module.py index 993d38a149..3112682d9d 100644 --- a/flytekit/lazy_import/lazy_module.py +++ b/flytekit/lazy_import/lazy_module.py @@ -2,10 +2,12 @@ import sys import types -LAZY_MODULES = [] +class _LazyModule(types.ModuleType): + """ + `lazy_module` returns an instance of this class if the module is not found in the python environment. + """ -class LazyModule(types.ModuleType): def __init__(self, module_name: str): super().__init__(module_name) self._module_name = module_name @@ -17,8 +19,12 @@ def __getattribute__(self, attr): def is_imported(module_name): """ This function is used to check if a module has been imported by the regular import. + Return false if module is lazy imported and not used yet. """ - return module_name in sys.modules and module_name not in LAZY_MODULES + return ( + module_name in sys.modules + and object.__getattribute__(lazy_module(module_name), "__class__").__name__ != "_LazyModule" + ) def lazy_module(fullname): @@ -37,11 +43,12 @@ def lazy_module(fullname): if spec is None or spec.loader is None: # Return a lazy module if the module is not found in the python environment, # so that we can raise a proper error when the user tries to access an attribute in the module. - return LazyModule(fullname) + # The reason to do this is because importlib.util.LazyLoader still requires + # the module to be installed even if you don't use it. + return _LazyModule(fullname) loader = importlib.util.LazyLoader(spec.loader) spec.loader = loader module = importlib.util.module_from_spec(spec) sys.modules[fullname] = module - LAZY_MODULES.append(module) loader.exec_module(module) return module diff --git a/flytekit/models/annotation.py b/flytekit/models/annotation.py index 1c17aabc5e..eef5504cce 100644 --- a/flytekit/models/annotation.py +++ b/flytekit/models/annotation.py @@ -42,6 +42,19 @@ def from_flyte_idl(cls, proto): return cls(annotations=_json_format.MessageToDict(proto.annotations)) + @classmethod + def merge_annotations(cls, annotation: "TypeAnnotation", other_annotation: "TypeAnnotation") -> "TypeAnnotation": + """ + Merges two annotations together. If the same key exists in both annotations, the value in the other annotation + will be used. 
+ :param TypeAnnotation annotation: The first annotation + :param TypeAnnotation other_annotation: The second annotation + :rtype: TypeAnnotation + """ + merged_annotations = annotation.annotations.copy() + merged_annotations.update(other_annotation.annotations) + return cls(annotations=merged_annotations) + def __eq__(self, x: object) -> bool: if not isinstance(x, self.__class__): return False diff --git a/flytekit/models/task.py b/flytekit/models/task.py index 755744ce27..960555fd9b 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/task.py @@ -179,6 +179,7 @@ def __init__( cache_serializable, pod_template_name, cache_ignore_input_vars, + is_eager: bool = False, ): """ Information needed at runtime to determine behavior such as whether or not outputs are discoverable, timeouts, @@ -200,6 +201,7 @@ def __init__( single instance over identical inputs is executed, other concurrent executions wait for the cached results. :param pod_template_name: The name of the existing PodTemplate resource which will be used in this task. :param cache_ignore_input_vars: Input variables that should not be included when calculating hash for cache. + :param is_eager: """ self._discoverable = discoverable self._runtime = runtime @@ -211,6 +213,11 @@ def __init__( self._cache_serializable = cache_serializable self._pod_template_name = pod_template_name self._cache_ignore_input_vars = cache_ignore_input_vars + self._is_eager = is_eager + + @property + def is_eager(self): + return self._is_eager @property def discoverable(self): @@ -310,13 +317,14 @@ def to_flyte_idl(self): cache_serializable=self.cache_serializable, pod_template_name=self.pod_template_name, cache_ignore_input_vars=self.cache_ignore_input_vars, + is_eager=self.is_eager, ) if self.timeout: tm.timeout.FromTimedelta(self.timeout) return tm @classmethod - def from_flyte_idl(cls, pb2_object): + def from_flyte_idl(cls, pb2_object: _core_task.TaskMetadata): """ :param flyteidl.core.task_pb2.TaskMetadata pb2_object: :rtype: TaskMetadata @@ -332,6 +340,7 @@ def from_flyte_idl(cls, pb2_object): cache_serializable=pb2_object.cache_serializable, pod_template_name=pb2_object.pod_template_name, cache_ignore_input_vars=pb2_object.cache_ignore_input_vars, + is_eager=pb2_object.is_eager, ) diff --git a/flytekit/remote/executions.py b/flytekit/remote/executions.py index 5095504784..bbad66a8ea 100644 --- a/flytekit/remote/executions.py +++ b/flytekit/remote/executions.py @@ -150,7 +150,7 @@ def is_done(self) -> bool: } @property - def outputs(self): + def outputs(self) -> Optional[LiteralsResolver]: outputs = super().outputs if outputs and self._type_hints: outputs.update_type_hints(self._type_hints) diff --git a/flytekit/remote/lazy_entity.py b/flytekit/remote/lazy_entity.py index 1df2197329..21664adf9a 100644 --- a/flytekit/remote/lazy_entity.py +++ b/flytekit/remote/lazy_entity.py @@ -2,6 +2,7 @@ from threading import Lock from flytekit import FlyteContext +from flytekit.models.core.identifier import Identifier from flytekit.remote.remote_callable import RemoteEntity T = typing.TypeVar("T", bound=RemoteEntity) @@ -26,6 +27,10 @@ def __init__(self, name: str, getter: typing.Callable[[], T], *args, **kwargs): def name(self) -> str: return self._name + @property + def id(self) -> Identifier: + return getattr(self.entity, "id") + def entity_fetched(self) -> bool: with self._mutex: return self._entity is not None diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 549f0045d3..08bb677171 100644 --- a/flytekit/remote/remote.py +++ 
b/flytekit/remote/remote.py @@ -524,7 +524,9 @@ def get_launch_plan_from_then_node( if node.branch_node: get_launch_plan_from_branch(node.branch_node, node_launch_plans) - return FlyteWorkflow.promote_from_closure(compiled_wf, node_launch_plans) + flyte_workflow = FlyteWorkflow.promote_from_closure(compiled_wf, node_launch_plans) + flyte_workflow.template._id = workflow_id + return flyte_workflow def _upgrade_launchplan(self, lp: launch_plan_models.LaunchPlan) -> FlyteLaunchPlan: """ @@ -863,13 +865,17 @@ async def _serialize_and_register( cp_task_entity_map = OrderedDict(filter(lambda x: isinstance(x[1], task_models.TaskSpec), m.items())) tasks = [] loop = asyncio.get_running_loop() - for entity, cp_entity in cp_task_entity_map.items(): + for task_entity, cp_entity in cp_task_entity_map.items(): tasks.append( loop.run_in_executor( None, - functools.partial(self.raw_register, cp_entity, serialization_settings, version, og_entity=entity), + functools.partial( + self.raw_register, cp_entity, serialization_settings, version, og_entity=task_entity + ), ) ) + if task_entity == entity: + registered_entity = await tasks[-1] identifiers_or_exceptions = [] identifiers_or_exceptions.extend(await asyncio.gather(*tasks, return_exceptions=True)) @@ -882,15 +888,17 @@ async def _serialize_and_register( raise ie # serial register cp_other_entities = OrderedDict(filter(lambda x: not isinstance(x[1], task_models.TaskSpec), m.items())) - for entity, cp_entity in cp_other_entities.items(): + for non_task_entity, cp_entity in cp_other_entities.items(): try: identifiers_or_exceptions.append( - self.raw_register(cp_entity, serialization_settings, version, og_entity=entity) + self.raw_register(cp_entity, serialization_settings, version, og_entity=non_task_entity) ) except RegistrationSkipped as e: logger.info(f"Skipping registration... {e}") continue - return identifiers_or_exceptions[-1] + if non_task_entity == entity: + registered_entity = identifiers_or_exceptions[-1] + return registered_entity def register_task( self, @@ -1334,6 +1342,7 @@ def _execute( type_hints = type_hints or {} literal_map = {} + with self.remote_context() as ctx: input_flyte_type_map = entity.interface.inputs @@ -1361,9 +1370,7 @@ def _execute( ) lit = TypeEngine.to_literal(ctx, v, hint, variable.type) literal_map[k] = lit - literal_inputs = literal_models.LiteralMap(literals=literal_map) - try: # Currently, this will only execute the flyte entity referenced by # flyte_id in the same project and domain. However, it is possible to execute it in a different project @@ -1606,6 +1613,7 @@ def execute( tags=tags, cluster_pool=cluster_pool, execution_cluster_label=execution_cluster_label, + options=options, ) if isinstance(entity, WorkflowBase): return self.execute_local_workflow( @@ -1903,6 +1911,7 @@ def execute_local_task( tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, execution_cluster_label: typing.Optional[str] = None, + options: typing.Optional[Options] = None, ) -> FlyteWorkflowExecution: """ Execute a @task-decorated function or TaskTemplate task. @@ -1921,6 +1930,8 @@ def execute_local_task( :param tags: Tags to set for the execution. :param cluster_pool: Specify cluster pool on which newly created execution should be placed. :param execution_cluster_label: Specify label of cluster(s) on which newly created execution should be placed. + :param options: Options to customize the execution. + :return: FlyteWorkflowExecution object. 
""" ss = SerializationSettings( @@ -1945,8 +1956,10 @@ def execute_local_task( if self.interactive_mode_enabled: ss.fast_serialization_settings = self._pickle_and_upload_entity(entity, pickled_target_dict) + # TODO: If this is being registered from eager, it will not reflect the full serialization settings + # object (look into the function, the passed in ss is basically ignored). How should it be piped in? + # https://github.com/flyteorg/flyte/issues/6070 flyte_task: FlyteTask = self.register_task(entity, ss, version) - return self.execute( flyte_task, inputs, @@ -1957,6 +1970,7 @@ def execute_local_task( wait=wait, type_hints=entity.python_interface.inputs, overwrite_cache=overwrite_cache, + options=options, envs=envs, tags=tags, cluster_pool=cluster_pool, diff --git a/flytekit/remote/remote_callable.py b/flytekit/remote/remote_callable.py index 4dbf83bdbb..f71b784066 100644 --- a/flytekit/remote/remote_callable.py +++ b/flytekit/remote/remote_callable.py @@ -5,6 +5,7 @@ from flytekit.core.promise import Promise, VoidPromise, create_and_link_node_from_remote, extract_obj_name from flytekit.exceptions import user as user_exceptions from flytekit.loggers import logger +from flytekit.models.core.identifier import Identifier from flytekit.models.core.workflow import NodeMetadata @@ -20,6 +21,10 @@ def __init__(self, *args, **kwargs): @abstractmethod def name(self) -> str: ... + @property + @abstractmethod + def id(self) -> Identifier: ... + def construct_node_metadata(self) -> NodeMetadata: """ Used when constructing the node that encapsulates this task as part of a broader workflow definition. diff --git a/flytekit/tools/fast_registration.py b/flytekit/tools/fast_registration.py index 3bf209700c..a3698102bf 100644 --- a/flytekit/tools/fast_registration.py +++ b/flytekit/tools/fast_registration.py @@ -6,6 +6,7 @@ import pathlib import posixpath import shutil +import stat import subprocess import tarfile import tempfile @@ -114,9 +115,6 @@ def fast_package( ignores = default_ignores ignore = IgnoreGroup(source, ignores) - # Remove this after original tar command is removed. - digest = compute_digest(source, ignore.is_ignored) - # This function is temporarily split into two, to support the creation of the tar file in both the old way, # copying the underlying items in the source dir by doing a listdir, and the new way, relying on a list of files. if options and ( @@ -151,6 +149,9 @@ def fast_package( # Original tar command - This condition to be removed in the future after serialize is removed. else: + # Remove this after original tar command is removed. 
+ digest = compute_digest(source, ignore.is_ignored) + # Compute where the archive should be written archive_fname = f"{FAST_PREFIX}{digest}{FAST_FILEENDING}" if output_dir is None: @@ -168,7 +169,6 @@ def fast_package( arcname=ws_file, filter=lambda x: ignore.tar_filter(tar_strip_file_attributes(x)), ) - # tar.list(verbose=True) compress_tarball(tar_path, archive_fname) @@ -190,6 +190,11 @@ def compute_digest_for_file(path: os.PathLike, rel_path: os.PathLike) -> None: logger.info(f"Skipping non-existent file {path}") return + # Skip socket files + if stat.S_ISSOCK(os.stat(path).st_mode): + logger.info(f"Skip socket file {path}") + return + if filter: if filter(rel_path): return diff --git a/flytekit/tools/ignore.py b/flytekit/tools/ignore.py index e41daf0904..e2aefef596 100644 --- a/flytekit/tools/ignore.py +++ b/flytekit/tools/ignore.py @@ -48,7 +48,7 @@ def _list_ignored(self) -> Dict: out = subprocess.run(["git", "ls-files", "-io", "--exclude-standard"], cwd=self.root, capture_output=True) if out.returncode == 0: return dict.fromkeys(out.stdout.decode("utf-8").split("\n")[:-1]) - logger.warning(f"Could not determine ignored files due to:\n{out.stderr}\nNot applying any filters") + logger.info(f"Could not determine ignored files due to:\n{out.stderr}\nNot applying any filters") return {} logger.info("No git executable found, not applying any filters") return {} diff --git a/flytekit/tools/script_mode.py b/flytekit/tools/script_mode.py index fa8634361d..6580fa6462 100644 --- a/flytekit/tools/script_mode.py +++ b/flytekit/tools/script_mode.py @@ -5,6 +5,7 @@ import os import shutil import site +import stat import sys import tarfile import tempfile @@ -169,6 +170,10 @@ def list_all_files(source_path: str, deref_symlinks, ignore_group: Optional[Igno if not os.path.exists(abspath): logger.info(f"Skipping non-existent file {abspath}") continue + # Skip socket files + if stat.S_ISSOCK(os.stat(abspath).st_mode): + logger.info(f"Skip socket file {abspath}") + continue if ignore_group: if ignore_group.is_ignored(abspath): continue @@ -183,23 +188,39 @@ def list_all_files(source_path: str, deref_symlinks, ignore_group: Optional[Igno return all_files +def _file_is_in_directory(file: str, directory: str) -> bool: + """Return True if file is in directory and in its children.""" + try: + return os.path.commonpath([file, directory]) == directory + except ValueError as e: + # ValueError is raised by windows if the paths are not from the same drive + logger.debug(f"{file} and {directory} are not in the same drive: {str(e)}") + return False + + def list_imported_modules_as_files(source_path: str, modules: List[ModuleType]) -> List[str]: """Copies modules into destination that are in modules. The module files are copied only if: 1. Not a site-packages. These are installed packages and not user files. - 2. Not in the bin. These are also installed and not user files. + 2. Not in the sys.base_prefix or sys.prefix. These are also installed and not user files. 3. Does not share a common path with the source_path. """ # source path is the folder holding the main script. # but in register/package case, there are multiple folders. # identify a common root amongst the packages listed? - site_packages = site.getsitepackages() - site_packages_set = set(site_packages) - bin_directory = os.path.dirname(sys.executable) files = [] flytekit_root = os.path.dirname(flytekit.__file__) + # These directories contain installed packages or modules from the Python standard library. 
+ # If a module is from these directories, then they are not user files. + invalid_directories = [ + flytekit_root, + sys.prefix, + sys.base_prefix, + site.getusersitepackages(), + ] + site.getsitepackages() + for mod in modules: try: mod_file = mod.__file__ @@ -209,37 +230,11 @@ def list_imported_modules_as_files(source_path: str, modules: List[ModuleType]) if mod_file is None: continue - # Check to see if mod_file is in site_packages or bin_directory, which are - # installed packages & libraries that are not user files. This happens when - # there is a virtualenv like `.venv` in the working directory. - try: - # Do not upload code if it is from the flytekit library - if os.path.commonpath([flytekit_root, mod_file]) == flytekit_root: - continue - - if os.path.commonpath(site_packages + [mod_file]) in site_packages_set: - # Do not upload files from site-packages - continue - - if os.path.commonpath([bin_directory, mod_file]) == bin_directory: - # Do not upload from the bin directory - continue - - except ValueError: - # ValueError is raised by windows if the paths are not from the same drive - # If the files are not in the same drive, then mod_file is not - # in the site-packages or bin directory. - pass + if any(_file_is_in_directory(mod_file, directory) for directory in invalid_directories): + continue - try: - common_path = os.path.commonpath([mod_file, source_path]) - if common_path != source_path: - # Do not upload files that do not share a common directory with the source - continue - except ValueError: - # ValueError is raised by windows if the paths are not from the same drive - # If they are not in the same directory, then they do not share a common path, - # so we do not upload the file. + if not _file_is_in_directory(mod_file, source_path): + # Only upload files where the module file in the source directory continue files.append(mod_file) @@ -251,7 +246,7 @@ def add_imported_modules_from_source(source_path: str, destination: str, modules """Copies modules into destination that are in modules. The module files are copied only if: 1. Not a site-packages. These are installed packages and not user files. - 2. Not in the bin. These are also installed and not user files. + 2. Not in the sys.base_prefix or sys.prefix. These are also installed and not user files. 3. Does not share a common path with the source_path. """ # source path is the folder holding the main script. 
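
To make the filtering rule above concrete, here is a small self-contained sketch built on the same commonpath check as the _file_is_in_directory helper introduced in this patch; the module and source paths in the usage line are hypothetical:

    import os
    import site
    import sys

    def should_upload(mod_file: str, source_path: str) -> bool:
        # Mirrors list_imported_modules_as_files: reject anything under an
        # installed-packages root (the real code also excludes flytekit's own
        # install dir), then require the file to live under the source path.
        def in_dir(file: str, directory: str) -> bool:
            try:
                return os.path.commonpath([file, directory]) == directory
            except ValueError:  # e.g. different drives on Windows
                return False

        invalid = [sys.prefix, sys.base_prefix, site.getusersitepackages(), *site.getsitepackages()]
        if any(in_dir(mod_file, d) for d in invalid):
            return False
        return in_dir(mod_file, source_path)

    # should_upload("/home/user/proj/helpers.py", "/home/user/proj") -> True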
diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index 868f657610..5c7a6d5eb4 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -22,6 +22,7 @@ from flytekit.core.python_auto_container import ( PythonAutoContainerTask, ) +from flytekit.core.python_function_task import EagerAsyncPythonFunctionTask from flytekit.core.reference_entity import ReferenceEntity, ReferenceSpec, ReferenceTemplate from flytekit.core.task import ReferenceTask from flytekit.core.utils import ClassDecorator, _dnsify @@ -155,6 +156,9 @@ def get_serializable_task( for entity_hint in entity.node_dependency_hints: get_serializable(entity_mapping, settings, entity_hint, options) + if isinstance(entity, EagerAsyncPythonFunctionTask): + settings = settings.with_serialized_context() + container = entity.get_container(settings) # This pod will be incorrect when doing fast serialize pod = entity.get_k8s_pod(settings) diff --git a/flytekit/types/structured/__init__.py b/flytekit/types/structured/__init__.py index 254ff16721..ce1affd01e 100644 --- a/flytekit/types/structured/__init__.py +++ b/flytekit/types/structured/__init__.py @@ -13,15 +13,15 @@ """ from flytekit.deck.renderer import ArrowRenderer, TopFrameRenderer +from flytekit.lazy_import.lazy_module import is_imported from flytekit.loggers import logger from .structured_dataset import ( + DuplicateHandlerError, StructuredDataset, StructuredDatasetDecoder, StructuredDatasetEncoder, - StructuredDatasetMetadata, StructuredDatasetTransformerEngine, - StructuredDatasetType, ) @@ -84,3 +84,27 @@ def register_snowflake_handlers(): "We won't register snowflake handler for structured dataset because " "we can't find package snowflake-connector-python" ) + + +def lazy_import_structured_dataset_handler(): + if is_imported("pandas"): + try: + register_pandas_handlers() + register_csv_handlers() + except DuplicateHandlerError: + logger.debug("Transformer for pandas is already registered.") + if is_imported("pyarrow"): + try: + register_arrow_handlers() + except DuplicateHandlerError: + logger.debug("Transformer for arrow is already registered.") + if is_imported("google.cloud.bigquery"): + try: + register_bigquery_handlers() + except DuplicateHandlerError: + logger.debug("Transformer for bigquery is already registered.") + if is_imported("snowflake.connector"): + try: + register_snowflake_handlers() + except DuplicateHandlerError: + logger.debug("Transformer for snowflake is already registered.") diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index 19afd6733e..da9cc79753 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -170,6 +170,18 @@ def literal(self) -> Optional[literals.StructuredDataset]: return self._literal_sd def open(self, dataframe_type: Type[DF]): + from flytekit.types.structured import lazy_import_structured_dataset_handler + + """ + Load the handler if needed. For the use case like: + @task + def t1(sd: StructuredDataset): + import pandas as pd + sd.open(pd.DataFrame).all() + + pandas is imported inside the task, so pandnas handler won't be loaded during deserialization in type engine. 
+ """ + lazy_import_structured_dataset_handler() self._dataframe_type = dataframe_type return self diff --git a/flytekit/utils/asyn.py b/flytekit/utils/asyn.py index d1edb67436..0ae25e9bb7 100644 --- a/flytekit/utils/asyn.py +++ b/flytekit/utils/asyn.py @@ -59,6 +59,15 @@ def _execute(self) -> None: finally: loop.close() + def get_exc_handler(self): + def exc_handler(loop, context): + logger.error( + f"Taskrunner for {self.__runner_thread.name if self.__runner_thread else 'no thread'} caught" + f" exception in {loop}: {context}" + ) + + return exc_handler + def run(self, coro: Any) -> Any: """Synchronously run a coroutine on a background thread.""" name = f"{threading.current_thread().name} : loop-runner" @@ -66,9 +75,13 @@ def run(self, coro: Any) -> Any: if self.__loop is None: with _selector_policy(): self.__loop = asyncio.new_event_loop() + + exc_handler = self.get_exc_handler() + self.__loop.set_exception_handler(exc_handler) self.__runner_thread = threading.Thread(target=self._execute, daemon=True, name=name) self.__runner_thread.start() fut = asyncio.run_coroutine_threadsafe(coro, self.__loop) + res = fut.result(None) return res diff --git a/plugins/flytekit-dbt/setup.py b/plugins/flytekit-dbt/setup.py index 08899d42ce..f683449fea 100644 --- a/plugins/flytekit-dbt/setup.py +++ b/plugins/flytekit-dbt/setup.py @@ -4,10 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = [ - "flytekit>=1.3.0b2", - "dbt-core<1.8.0", -] +plugin_requires = ["flytekit>=1.3.0b2", "dbt-core>=1.6.0,<1.8.0", "networkx>=2.5"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-inference/flytekitplugins/inference/ollama/serve.py b/plugins/flytekit-inference/flytekitplugins/inference/ollama/serve.py index 81e68618ca..c8f93c585e 100644 --- a/plugins/flytekit-inference/flytekitplugins/inference/ollama/serve.py +++ b/plugins/flytekit-inference/flytekitplugins/inference/ollama/serve.py @@ -164,7 +164,7 @@ def __init__(self, *args, **kwargs): name=container_name, image="python:3.11-slim", command=["/bin/sh", "-c"], - args=[f"pip install requests && pip install ollama && {command}"], + args=[f"pip install requests && pip install ollama==0.3.3 && {command}"], resources=V1ResourceRequirements( requests={ "cpu": self._model_cpu, diff --git a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py index 474901544d..e6359641ca 100644 --- a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py +++ b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py @@ -69,9 +69,10 @@ def encode( df.to_parquet(output_bytes) if structured_dataset.uri is not None: + output_bytes.seek(0) fs = ctx.file_access.get_filesystem_for_path(path=structured_dataset.uri) with fs.open(structured_dataset.uri, "wb") as s: - s.write(output_bytes) + s.write(output_bytes.read()) output_uri = structured_dataset.uri else: remote_fn = "00000" # 00000 is our default unnamed parquet filename diff --git a/plugins/flytekit-polars/tests/test_polars_plugin_sd.py b/plugins/flytekit-polars/tests/test_polars_plugin_sd.py index c2d4a39be7..9acae1c274 100644 --- a/plugins/flytekit-polars/tests/test_polars_plugin_sd.py +++ b/plugins/flytekit-polars/tests/test_polars_plugin_sd.py @@ -5,7 +5,7 @@ import pytest from flytekitplugins.polars.sd_transformers import PolarsDataFrameRenderer from typing_extensions import Annotated -from packaging import version +import numpy as np from polars.testing import assert_frame_equal from flytekit import 
kwtypes, task, workflow @@ -134,3 +134,28 @@ def consume_sd_return_sd(sd: StructuredDataset) -> StructuredDataset: opened_sd = opened_sd.collect() assert_frame_equal(opened_sd, polars_df) + + +def test_with_uri(): + temp_file = tempfile.mktemp() + + @task + def random_dataframe(num_rows: int) -> StructuredDataset: + feature_1_list = np.random.randint(low=100, high=999, size=(num_rows,)) + feature_2_list = np.random.normal(loc=0, scale=1, size=(num_rows, )) + pl_df = pl.DataFrame({'protein_length': feature_1_list, + 'protein_feature': feature_2_list}) + sd = StructuredDataset(dataframe=pl_df, uri=temp_file) + return sd + + @task + def consume(df: pd.DataFrame): + print(df.head(5)) + print(df.describe()) + + @workflow + def my_wf(num_rows: int): + pl = random_dataframe(num_rows=num_rows) + consume(pl) + + my_wf(num_rows=100) diff --git a/plugins/flytekit-ray/README.md b/plugins/flytekit-ray/README.md index f7db403a6c..250321fd05 100644 --- a/plugins/flytekit-ray/README.md +++ b/plugins/flytekit-ray/README.md @@ -7,3 +7,5 @@ To install the plugin, run the following command: ```bash pip install flytekitplugins-ray ``` + +All [examples](https://docs.flyte.org/en/latest/flytesnacks/examples/ray_plugin/index.html) showcasing execution of Ray jobs using the plugin can be found in the documentation. diff --git a/plugins/flytekit-ray/flytekitplugins/ray/task.py b/plugins/flytekit-ray/flytekitplugins/ray/task.py index 3620a0494c..98a653a990 100644 --- a/plugins/flytekit-ray/flytekitplugins/ray/task.py +++ b/plugins/flytekit-ray/flytekitplugins/ray/task.py @@ -98,12 +98,7 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any] ), worker_group_spec=[ WorkerGroupSpec( - c.group_name, - c.replicas, - c.min_replicas, - c.max_replicas, - c.ray_start_params, - c.k8s_pod, + c.group_name, c.replicas, c.min_replicas, c.max_replicas, c.ray_start_params, c.k8s_pod ) for c in cfg.worker_node_config ], diff --git a/plugins/flytekit-ray/tests/test_ray.py b/plugins/flytekit-ray/tests/test_ray.py index 737cdf6f4a..c943067013 100644 --- a/plugins/flytekit-ray/tests/test_ray.py +++ b/plugins/flytekit-ray/tests/test_ray.py @@ -4,16 +4,29 @@ import ray import yaml from flytekitplugins.ray import HeadNodeConfig -from flytekitplugins.ray.models import RayCluster, RayJob, WorkerGroupSpec, HeadGroupSpec +from flytekitplugins.ray.models import ( + HeadGroupSpec, + RayCluster, + RayJob, + WorkerGroupSpec, +) from flytekitplugins.ray.task import RayJobConfig, WorkerNodeConfig from google.protobuf.json_format import MessageToDict -from flytekit.models.task import K8sPod from flytekit import PythonFunctionTask, task from flytekit.configuration import Image, ImageConfig, SerializationSettings +from flytekit.models.task import K8sPod config = RayJobConfig( - worker_node_config=[WorkerNodeConfig(group_name="test_group", replicas=3, min_replicas=0, max_replicas=10, k8s_pod=K8sPod(pod_spec={"str": "worker", "int": 1}))], + worker_node_config=[ + WorkerNodeConfig( + group_name="test_group", + replicas=3, + min_replicas=0, + max_replicas=10, + k8s_pod=K8sPod(pod_spec={"str": "worker", "int": 1}), + ) + ], head_node_config=HeadNodeConfig(k8s_pod=K8sPod(pod_spec={"str": "head", "int": 2})), runtime_env={"pip": ["numpy"]}, enable_autoscaling=True, @@ -44,7 +57,19 @@ def t1(a: int) -> str: ) ray_job_pb = RayJob( - ray_cluster=RayCluster(worker_group_spec=[WorkerGroupSpec(group_name="test_group", replicas=3, min_replicas=0, max_replicas=10, k8s_pod=K8sPod(pod_spec={"str": "worker", "int": 1}))], 
head_group_spec=HeadGroupSpec(k8s_pod=K8sPod(pod_spec={"str": "head", "int": 2})), enable_autoscaling=True), + ray_cluster=RayCluster( + worker_group_spec=[ + WorkerGroupSpec( + group_name="test_group", + replicas=3, + min_replicas=0, + max_replicas=10, + k8s_pod=K8sPod(pod_spec={"str": "worker", "int": 1}), + ) + ], + head_group_spec=HeadGroupSpec(k8s_pod=K8sPod(pod_spec={"str": "head", "int": 2})), + enable_autoscaling=True, + ), runtime_env=base64.b64encode(json.dumps({"pip": ["numpy"]}).encode()).decode(), runtime_env_yaml=yaml.dump({"pip": ["numpy"]}), shutdown_after_job_finishes=True, diff --git a/pyproject.toml b/pyproject.toml index 6b50981faf..58c107cdc3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ dependencies = [ "jsonlines", "jsonpickle", "keyring>=18.0.1", + "kubernetes>=12.0.1", "markdown-it-py", "marshmallow-enum", "marshmallow-jsonschema>=0.12.0", diff --git a/tests/flytekit/integration/experimental/eager_workflows.py b/tests/flytekit/integration/experimental/eager_workflows.py deleted file mode 100644 index 2dbc28a640..0000000000 --- a/tests/flytekit/integration/experimental/eager_workflows.py +++ /dev/null @@ -1,154 +0,0 @@ -import asyncio -import os -import typing -from functools import partial -from pathlib import Path - -import pandas as pd - -from flytekit import task, workflow -from flytekit.configuration import Config -from flytekit.experimental import EagerException, eager -from flytekit.remote import FlyteRemote -from flytekit.types.directory import FlyteDirectory -from flytekit.types.file import FlyteFile -from flytekit.types.structured import StructuredDataset - -remote = FlyteRemote( - config=Config.for_sandbox(), - default_project="flytesnacks", - default_domain="development", -) - - -eager_partial = partial(eager, remote=remote) - - -@task -def add_one(x: int) -> int: - return x + 1 - - -@task -def double(x: int) -> int: - return x * 2 - - -@task -def gt_0(x: int) -> bool: - return x > 0 - - -@task -def raises_exc(x: int) -> int: - if x == 0: - raise TypeError - return x - - -@task -def create_structured_dataset() -> StructuredDataset: - df = pd.DataFrame({"a": [1, 2, 3]}) - return StructuredDataset(dataframe=df) - - -@task -def create_file() -> FlyteFile: - fname = "/tmp/flytekit_test_file" - with open(fname, "w") as fh: - fh.write("some data\n") - return FlyteFile(path=fname) - - -@task -def create_directory() -> FlyteDirectory: - dirname = "/tmp/flytekit_test_dir" - Path(dirname).mkdir(exist_ok=True, parents=True) - with open(os.path.join(dirname, "file"), "w") as tmp: - tmp.write("some data\n") - return FlyteDirectory(path=dirname) - - -@eager_partial -async def simple_eager_wf(x: int) -> int: - out = await add_one(x=x) - return await double(x=out) - - -@eager_partial -async def conditional_eager_wf(x: int) -> int: - if await gt_0(x=x): - return -1 - return 1 - - -@eager_partial -async def try_except_eager_wf(x: int) -> int: - try: - return await raises_exc(x=x) - except EagerException: - return -1 - - -@eager_partial -async def gather_eager_wf(x: int) -> typing.List[int]: - results = await asyncio.gather(*[add_one(x=x) for _ in range(10)]) - return results - - -@eager_partial -async def nested_eager_wf(x: int) -> int: - out = await simple_eager_wf(x=x) - return await double(x=out) - - -@workflow -def wf_with_eager_wf(x: int) -> int: - out = simple_eager_wf(x=x) - return double(x=out) - - -@workflow -def subworkflow(x: int) -> int: - return add_one(x=x) - - -@eager_partial -async def eager_wf_with_subworkflow(x: int) -> int: - out = 
await subworkflow(x=x) - return await double(x=out) - - -@eager_partial -async def eager_wf_structured_dataset() -> int: - dataset = await create_structured_dataset() - df = dataset.open(pd.DataFrame).all() - return int(df["a"].sum()) - - -@eager_partial -async def eager_wf_flyte_file() -> str: - file = await create_file() - file.download() - with open(file.path) as f: - data = f.read().strip() - return data - - -@eager_partial -async def eager_wf_flyte_directory() -> str: - directory = await create_directory() - directory.download() - with open(os.path.join(directory.path, "file")) as f: - data = f.read().strip() - return data - - -@eager(remote=remote, local_entrypoint=True) -async def eager_wf_local_entrypoint(x: int) -> int: - out = await add_one(x=x) - return await double(x=out) - - -if __name__ == "__main__": - print(asyncio.run(simple_eager_wf(x=1))) diff --git a/tests/flytekit/integration/experimental/test_eager_workflows.py b/tests/flytekit/integration/experimental/test_eager_workflows.py deleted file mode 100644 index ad1bc44112..0000000000 --- a/tests/flytekit/integration/experimental/test_eager_workflows.py +++ /dev/null @@ -1,116 +0,0 @@ -"""Eager workflow integration tests. - -These tests are currently not run in CI. In order to run this locally you'll need to start a -local flyte cluster, and build and push a flytekit development image: - -``` - -# if you already have a local cluster running, tear it down and start fresh -flytectl demo teardown -v - -# start a local flyte cluster -flytectl demo start - -# build and push the image -docker build . -f Dockerfile.dev -t localhost:30000/flytekit:dev --build-arg PYTHON_VERSION=3.9 -docker push localhost:30000/flytekit:dev - -# run the tests -pytest tests/flytekit/integration/experimental/test_eager_workflows.py -``` -""" - -import asyncio -import os -import subprocess -import time -from pathlib import Path - -import pytest - -from flytekit.configuration import Config -from flytekit.remote import FlyteRemote - -from .eager_workflows import eager_wf_local_entrypoint - -MODULE = "eager_workflows" -MODULE_PATH = Path(__file__).parent / f"{MODULE}.py" -CONFIG = os.environ.get("FLYTECTL_CONFIG", str(Path.home() / ".flyte" / "config-sandbox.yaml")) -IMAGE = os.environ.get("FLYTEKIT_IMAGE", "localhost:30000/flytekit:dev") - - -@pytest.fixture(scope="session") -def register(): - subprocess.run( - [ - "pyflyte", - "-c", - CONFIG, - "register", - "--image", - IMAGE, - "--project", - "flytesnacks", - "--domain", - "development", - MODULE_PATH, - ] - ) - - -@pytest.mark.skipif( - os.environ.get("FLYTEKIT_CI", False), reason="Running workflows with sandbox cluster fails due to memory pressure" -) -@pytest.mark.parametrize( - "entity_type, entity_name, input, output", - [ - ("eager", "simple_eager_wf", 1, 4), - ("eager", "conditional_eager_wf", 1, -1), - ("eager", "conditional_eager_wf", -10, 1), - ("eager", "try_except_eager_wf", 1, 1), - ("eager", "try_except_eager_wf", 0, -1), - ("eager", "gather_eager_wf", 1, [2] * 10), - ("eager", "nested_eager_wf", 1, 8), - ("eager", "eager_wf_with_subworkflow", 1, 4), - ("eager", "eager_wf_structured_dataset", None, 6), - ("eager", "eager_wf_flyte_file", None, "some data"), - ("eager", "eager_wf_flyte_directory", None, "some data"), - ("workflow", "wf_with_eager_wf", 1, 8), - ], -) -def test_eager_workflows(register, entity_type, entity_name, input, output): - remote = FlyteRemote( - config=Config.auto(config_file=CONFIG), - default_project="flytesnacks", - default_domain="development", - ) - - 
fetch_method = {
-        "eager": remote.fetch_task,
-        "workflow": remote.fetch_workflow,
-    }[entity_type]
-
-    entity = None
-    for i in range(100):
-        try:
-            entity = fetch_method(name=f"{MODULE}.{entity_name}")
-            break
-        except Exception:
-            print(f"retry {i}")
-            time.sleep(6)
-            continue
-
-    if entity is None:
-        raise RuntimeError("failed to fetch entity")
-
-    inputs = {} if input is None else {"x": input}
-    execution = remote.execute(entity, inputs=inputs, wait=True)
-    assert execution.outputs["o0"] == output
-
-
-@pytest.mark.skipif(
-    os.environ.get("FLYTEKIT_CI", False), reason="Running workflows with sandbox cluster fails due to memory pressure"
-)
-def test_eager_workflow_local_entrypoint(register):
-    result = asyncio.run(eager_wf_local_entrypoint(x=1))
-    assert result == 4
diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py
index 67523a4dd0..5d953350a0 100644
--- a/tests/flytekit/integration/remote/test_remote.py
+++ b/tests/flytekit/integration/remote/test_remote.py
@@ -107,6 +107,11 @@ def test_remote_run():
     run("default_lp.py", "my_wf")
 
 
+def test_remote_eager_run():
+    # Register and run the simple eager workflow from eager_example.py on the remote cluster.
+    run("eager_example.py", "simple_eager_workflow", "--x", "3")
+
+
 def test_generic_idl_flytetypes():
     os.environ["FLYTE_USE_OLD_DC_FORMAT"] = "true"
     # default inputs for flyte types in dataclass
diff --git a/tests/flytekit/integration/remote/workflows/basic/eager_example.py b/tests/flytekit/integration/remote/workflows/basic/eager_example.py
new file mode 100644
index 0000000000..b72ff3cd37
--- /dev/null
+++ b/tests/flytekit/integration/remote/workflows/basic/eager_example.py
@@ -0,0 +1,17 @@
+from flytekit.core.task import task, eager
+
+
+@task
+def add_one(x: int) -> int:
+    return x + 1
+
+
+@eager
+async def simple_eager_workflow(x: int) -> int:
+    # This is the normal way of calling tasks. Within an eager workflow the call blocks
+    # until the task's result is available, behaving like an awaited coroutine.
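+    # Locally this simply runs in-process; when the file is registered and executed on a
+    # cluster, each task call should be submitted as its own execution, with the workflow
+    # resuming once the result is ready.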
+    out = add_one(x=x)
+    return out
diff --git a/tests/flytekit/unit/core/test_array_node_map_task.py b/tests/flytekit/unit/core/test_array_node_map_task.py
index d4281227db..ed1fc7fdd0 100644
--- a/tests/flytekit/unit/core/test_array_node_map_task.py
+++ b/tests/flytekit/unit/core/test_array_node_map_task.py
@@ -9,13 +9,12 @@
 import pytest
 from flyteidl.core import workflow_pb2 as _core_workflow
 
-from flytekit import dynamic, map_task, task, workflow, PythonFunctionTask
+from flytekit import dynamic, map_task, task, workflow, eager, PythonFunctionTask
 from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings
 from flytekit.core import context_manager
 from flytekit.core.array_node_map_task import ArrayNodeMapTask, ArrayNodeMapTaskResolver
 from flytekit.core.task import TaskMetadata
 from flytekit.core.type_engine import TypeEngine
-from flytekit.experimental.eager_function import eager
 from flytekit.extras.accelerators import GPUAccelerator
 from flytekit.models.literals import (
     Literal,
diff --git a/tests/flytekit/unit/core/test_async.py b/tests/flytekit/unit/core/test_async.py
new file mode 100644
index 0000000000..c706833ab9
--- /dev/null
+++ b/tests/flytekit/unit/core/test_async.py
@@ -0,0 +1,63 @@
+import pytest
+from flytekit.core.task import task, eager
+from flytekit.core.worker_queue import Controller
+from flytekit.utils.asyn import loop_manager
+from flytekit.core.context_manager import FlyteContextManager
+from flytekit.configuration import Config, ImageConfig, SerializationSettings, Image
+from flytekit.core.data_persistence import FileAccessProvider
+from flytekit.tools.translator import get_serializable
+from collections import OrderedDict
+
+default_img = Image(name="default", fqn="test", tag="tag")
+serialization_settings = SerializationSettings(
+    project="project",
+    domain="domain",
+    version="version",
+    env=None,
+    image_config=ImageConfig(default_image=default_img, images=[default_img]),
+)
+
+
+@task
+def add_one(x: int) -> int:
+    return x + 1
+
+
+@eager(environment={"a": "b"})
+async def simple_eager_workflow(x: int) -> int:
+    # This is the normal way of calling tasks. Within an eager workflow the call blocks
+    # until the task's result is available, behaving like an awaited coroutine.
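+    # The environment kwarg on the decorator above should flow into the serialized
+    # container spec; test_serialization below expects two env vars, so presumably one
+    # more is injected by the eager machinery.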
+ out = add_one(x=x) + return out + + +@pytest.mark.asyncio +async def test_easy_1(): + res = await simple_eager_workflow(x=1) + print(res) + assert res == 2 + + +@pytest.mark.skip +def test_easy_2(): + ctx = FlyteContextManager.current_context() + dc = Config.for_sandbox().data_config + raw_output = f"s3://my-s3-bucket/testing/async_test/raw_output/" + print(f"Using raw output location: {raw_output}") + provider = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix=raw_output, data_config=dc) + c = Controller.for_sandbox() + c.remote._interactive_mode_enabled = True + with FlyteContextManager.with_context( + ctx.with_file_access(provider).with_client(c.remote.client).with_worker_queue(c) + ): + res = loop_manager.run_sync(simple_eager_workflow.run_with_backend, x=1) + assert res == 2 + + +def test_serialization(): + se_spec = get_serializable(OrderedDict(), serialization_settings, simple_eager_workflow) + assert se_spec.template.metadata.is_eager + assert len(se_spec.template.container.env) == 2 diff --git a/tests/flytekit/unit/core/test_async_more_semantics.py b/tests/flytekit/unit/core/test_async_more_semantics.py new file mode 100644 index 0000000000..05321f2065 --- /dev/null +++ b/tests/flytekit/unit/core/test_async_more_semantics.py @@ -0,0 +1,119 @@ +import asyncio +import typing +import pytest + +from flytekit.core.task import task, eager +from flytekit.configuration import Config +from flytekit.core.context_manager import FlyteContextManager +from flytekit.core.data_persistence import FileAccessProvider +from flytekit.remote.remote import FlyteRemote +from flytekit.utils.asyn import loop_manager + + +@task +def add_one(x: int) -> int: + return x + 1 + + +@task +async def a_double(x: int) -> int: + return x * 2 + + +@task +def double(x: int) -> int: + return x * 2 + + +@eager +async def base_wf(x: int) -> int: + out = add_one(x=x) + doubled = a_double(x=x) + if out - await doubled < 0: + return -1 + final = double(x=out) + return final + + +@eager +async def parent_wf(a: int, b: int) -> typing.Tuple[int, int]: + print("hi") + t1 = asyncio.create_task(base_wf(x=a)) + t2 = asyncio.create_task(base_wf(x=b)) + # Test this again in the future + # currently behaving as python does. + # print("hi2", flush=True) + # await asyncio.sleep(0.01) + # time.sleep(5) + # Since eager workflows are also async tasks, we can use the general async pattern with them. + i1, i2 = await asyncio.gather(t1, t2) + return i1, i2 + + +@pytest.mark.asyncio +async def test_nested_all_local(): + res = await parent_wf(a=1, b=2) + print(res) + assert res == (4, -1) + + +@pytest.mark.skip +def test_nested_local_backend(): + ctx = FlyteContextManager.current_context() + remote = FlyteRemote(Config.for_sandbox()) + dc = Config.for_sandbox().data_config + raw_output = f"s3://my-s3-bucket/testing/async_test/raw_output/" + print(f"Using raw output location: {raw_output}") + provider = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix=raw_output, data_config=dc) + + with FlyteContextManager.with_context(ctx.with_file_access(provider).with_client(remote.client)): + res = loop_manager.run_sync(parent_wf.run_with_backend, a=1, b=100) + print(res) + # Nested eagers just run against the backend like any other task. 
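+    # Note: with a=1, b=100 the workflow logic above evaluates to (4, -1) when run
+    # purely locally (base_wf(1) -> double(add_one(1)) = 4; base_wf(100) -> -1 because
+    # 101 - 200 < 0), so the values below look like placeholders for this skipped test.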
+ assert res == (42, 44) + + +@eager +async def level_3(x: int) -> int: + out = add_one(x=x) + return out + + +@eager +async def level_2(x: int) -> int: + out = add_one(x=x) + level_3_res = await level_3(x=out) + final_res = double(x=level_3_res) + return final_res + + +@eager +async def level_1() -> typing.Tuple[int, int]: + i1 = add_one(x=5) + t2 = asyncio.create_task(level_2(x=1)) + + # don't forget the comma + i2, = await asyncio.gather(t2) + return i1, i2 + + +@pytest.mark.asyncio +async def test_nested_level_1_local(): + res = await level_1() + print(res) + assert res == (6, 6) + + +@pytest.mark.skip +def test_nested_local_backend_level(): + ctx = FlyteContextManager.current_context() + remote = FlyteRemote(Config.for_sandbox()) + dc = Config.for_sandbox().data_config + raw_output = f"s3://my-s3-bucket/testing/async_test/raw_output/" + print(f"Using raw output location: {raw_output}") + provider = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix=raw_output, data_config=dc) + + with FlyteContextManager.with_context(ctx.with_file_access(provider).with_client(remote.client)): + res = loop_manager.run_sync(level_1.run_with_backend) + print(res) + assert res == (42, 42) diff --git a/tests/flytekit/unit/core/test_async_with_dynamic.py b/tests/flytekit/unit/core/test_async_with_dynamic.py new file mode 100644 index 0000000000..f1c8c34303 --- /dev/null +++ b/tests/flytekit/unit/core/test_async_with_dynamic.py @@ -0,0 +1,75 @@ +import asyncio +import time +import typing +import pytest + +from flytekit.core.task import task, eager +from flytekit.core.dynamic_workflow_task import dynamic +from flytekit.configuration import Config +from flytekit.core.context_manager import FlyteContextManager +from flytekit.core.data_persistence import FileAccessProvider +from flytekit.remote.remote import FlyteRemote +from flytekit.utils.asyn import loop_manager +from flytekit.core.worker_queue import Controller + + +@task +def add_one(x: int) -> int: + return x + 1 + + +@task +async def a_double(x: int) -> int: + return x * 2 + + +@task +def double(x: int) -> int: + return x * 2 + + +@dynamic +def level_3_dt(x: int) -> int: + out = add_one(x=x) + return out + + +@eager +async def level_2(x: int) -> int: + out = add_one(x=x) + level_3_res = level_3_dt(x=out) + final_res = double(x=level_3_res) + return final_res + + +@eager +async def level_1() -> typing.Tuple[int, int]: + i1 = add_one(x=10) + t2 = asyncio.create_task(level_2(x=1)) + + # don't forget the comma + i2, = await asyncio.gather(t2) + return i1, i2 + + +@pytest.mark.asyncio +async def test_nested_level_1_local(): + res = await level_1() + print(res) + assert res == (11, 6) + + +@pytest.mark.skip +def test_nested_local_backend(): + ctx = FlyteContextManager.current_context() + remote = FlyteRemote(Config.for_sandbox()) + dc = Config.for_sandbox().data_config + raw_output = f"s3://my-s3-bucket/testing/async_test/raw_output/" + print(f"Using raw output location: {raw_output}") + provider = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix=raw_output, data_config=dc) + + c = Controller.for_sandbox() + with FlyteContextManager.with_context(ctx.with_file_access(provider).with_client(remote.client).with_worker_queue(c)) as ctx: + res = loop_manager.run_sync(level_2.run_with_backend, x=1000) + print(res) + assert res == 43 diff --git a/tests/flytekit/unit/core/test_async_with_wflp.py b/tests/flytekit/unit/core/test_async_with_wflp.py new file mode 100644 index 0000000000..2d14c2c6c4 --- /dev/null +++ 
b/tests/flytekit/unit/core/test_async_with_wflp.py @@ -0,0 +1,78 @@ +import asyncio +import typing +import pytest + +from flytekit.core.task import task, eager +from flytekit.core.workflow import workflow +from flytekit.configuration import Config +from flytekit.core.context_manager import FlyteContextManager +from flytekit.core.launch_plan import LaunchPlan +from flytekit.core.data_persistence import FileAccessProvider +from flytekit.remote.remote import FlyteRemote +from flytekit.utils.asyn import loop_manager + + +@task +def add_one(x: int) -> int: + return x + 1 + + +@task +async def a_double(x: int) -> int: + return x * 2 + + +@task +def double(x: int) -> int: + return x * 2 + + +@workflow +def level_3_subwf(x: int) -> int: + out = add_one(x=x) + return out + + +ctx = FlyteContextManager.current_context() +level_3_lp = LaunchPlan.get_or_create(level_3_subwf, "level_3_lp") + + +@eager +async def level_2(x: int) -> int: + out = add_one(x=x) + level_3_res = level_3_subwf(x=out) + level_3_lp_res = level_3_lp(x=level_3_res) + final_res = double(x=level_3_lp_res) + return final_res + + +@eager +async def level_1() -> typing.Tuple[int, int]: + i1 = add_one(x=5) + t2 = asyncio.create_task(level_2(x=1)) + + # don't forget the comma + i2, = await asyncio.gather(t2) + return i1, i2 + + +@pytest.mark.asyncio +async def test_nested_level_1_local(): + res = await level_1() + print(res) + assert res == (6, 8) + + +@pytest.mark.skip +def test_nested_local_backend(): + ctx = FlyteContextManager.current_context() + remote = FlyteRemote(Config.for_sandbox()) + dc = Config.for_sandbox().data_config + raw_output = f"s3://my-s3-bucket/testing/async_test/raw_output/" + print(f"Using raw output location: {raw_output}") + provider = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix=raw_output, data_config=dc) + + with FlyteContextManager.with_context(ctx.with_file_access(provider).with_client(remote.client)): + res = loop_manager.run_sync(level_1.run_with_backend) + print(res) + assert res == (42, 42) diff --git a/tests/flytekit/unit/core/test_data_persistence.py b/tests/flytekit/unit/core/test_data_persistence.py index d992ed1fa5..116717b92d 100644 --- a/tests/flytekit/unit/core/test_data_persistence.py +++ b/tests/flytekit/unit/core/test_data_persistence.py @@ -1,16 +1,17 @@ import io import os -import fsspec import pathlib import random import string import sys import tempfile +import fsspec import mock import pytest from azure.identity import ClientSecretCredential, DefaultAzureCredential +from flytekit.configuration import Config from flytekit.core.data_persistence import FileAccessProvider from flytekit.core.local_fsspec import FlyteLocalFileSystem @@ -207,3 +208,18 @@ def __init__(self, *args, **kwargs): fp = FileAccessProvider("/tmp", "s3://my-bucket") fp.get_filesystem("testgetfs", test_arg="test_arg") + + +@pytest.mark.sandbox_test +def test_put_raw_data_bytes(): + dc = Config.for_sandbox().data_config + raw_output = f"s3://my-s3-bucket/" + provider = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix=raw_output, data_config=dc) + prefix = provider.get_random_string() + provider.put_raw_data(lpath=b"hello", upload_prefix=prefix, file_name="hello_bytes") + provider.put_raw_data(lpath=io.BytesIO(b"hello"), upload_prefix=prefix, file_name="hello_bytes_io") + provider.put_raw_data(lpath=io.StringIO("hello"), upload_prefix=prefix, file_name="hello_string_io") + + fs = provider.get_filesystem("s3") + listing = fs.ls(f"{raw_output}{prefix}/") + assert len(listing) 
== 3 diff --git a/tests/flytekit/unit/core/test_generice_idl_type_engine.py b/tests/flytekit/unit/core/test_generice_idl_type_engine.py index f790d03bd4..be3b735e64 100644 --- a/tests/flytekit/unit/core/test_generice_idl_type_engine.py +++ b/tests/flytekit/unit/core/test_generice_idl_type_engine.py @@ -496,7 +496,10 @@ def recursive_assert( ) recursive_assert( d.get_literal_type(typing.Dict[str, dict]), - LiteralType(simple=SimpleType.STRUCT), + LiteralType( + simple=SimpleType.STRUCT, + annotation=TypeAnnotation({"cache-key-metadata": {"serialization-format": "msgpack"}}) + ), ) recursive_assert( d.get_literal_type(typing.Dict[str, typing.Dict[str, str]]), @@ -505,7 +508,10 @@ def recursive_assert( ) recursive_assert( d.get_literal_type(typing.Dict[str, typing.Dict[int, str]]), - LiteralType(simple=SimpleType.STRUCT), + LiteralType( + simple=SimpleType.STRUCT, + annotation=TypeAnnotation({"cache-key-metadata": {"serialization-format": "msgpack"}}) + ), expected_depth=2, ) recursive_assert( @@ -515,12 +521,18 @@ def recursive_assert( ) recursive_assert( d.get_literal_type(typing.Dict[str, typing.Dict[str, typing.Dict[str, dict]]]), - LiteralType(simple=SimpleType.STRUCT), + LiteralType( + simple=SimpleType.STRUCT, + annotation=TypeAnnotation({"cache-key-metadata": {"serialization-format": "msgpack"}}) + ), expected_depth=3, ) recursive_assert( d.get_literal_type(typing.Dict[str, typing.Dict[str, typing.Dict[int, dict]]]), - LiteralType(simple=SimpleType.STRUCT), + LiteralType( + simple=SimpleType.STRUCT, + annotation=TypeAnnotation({"cache-key-metadata": {"serialization-format": "msgpack"}}) + ), expected_depth=2, ) @@ -2352,10 +2364,6 @@ def _check_annotation(t, annotation): typing_extensions.Annotated[int, FlyteAnnotation({"d": {"test": "data"}, "l": ["nested", ["list"]]})], {"d": {"test": "data"}, "l": ["nested", ["list"]]}, ) - _check_annotation( - typing_extensions.Annotated[int, FlyteAnnotation(InnerStruct(a=1, b="fizz", c=[1]))], - InnerStruct(a=1, b="fizz", c=[1]), - ) def test_annotated_list(): @@ -2441,112 +2449,223 @@ class AnnotatedDataclassTest(DataClassJsonMixin): @pytest.mark.parametrize( "t,expected_type", [ - (dict, LiteralType(simple=SimpleType.STRUCT)), - # Annotations are not being copied over to the LiteralType ( - typing_extensions.Annotated[dict, "a-tag"], - LiteralType(simple=SimpleType.STRUCT), + dict, + LiteralType( + simple=SimpleType.STRUCT, + annotation=TypeAnnotation( + { + "cache-key-metadata": { + "serialization-format": "msgpack", + }, + }, + ), + ), ), - (typing.Dict[int, str], LiteralType(simple=SimpleType.STRUCT)), + # Only the special `cache-key-metadata` TypeAnnotations is copied over to the LiteralType ( - typing.Dict[str, int], - LiteralType(map_value_type=LiteralType(simple=SimpleType.INTEGER)), + typing_extensions.Annotated[dict, "a-tag"], + LiteralType( + simple=SimpleType.STRUCT, + annotation=TypeAnnotation( + { + "cache-key-metadata": { + "serialization-format": "msgpack", + }, + }, + ), + ), ), ( - typing.Dict[str, str], - LiteralType(map_value_type=LiteralType(simple=SimpleType.STRING)), + typing.Dict[int, str], + LiteralType( + simple=SimpleType.STRUCT, + annotation=TypeAnnotation( + { + "cache-key-metadata": { + "serialization-format": "msgpack", + }, + }, + ), + + ), ), ( - typing.Dict[str, typing.List[int]], - LiteralType(map_value_type=LiteralType(collection_type=LiteralType(simple=SimpleType.INTEGER))), + typing.Dict[str, int], + LiteralType( + map_value_type=LiteralType(simple=SimpleType.INTEGER), + ), ), - (typing.Dict[int, 
typing.List[int]], LiteralType(simple=SimpleType.STRUCT)), ( - typing.Dict[int, typing.Dict[int, int]], - LiteralType(simple=SimpleType.STRUCT), + typing.Dict[str, str], + LiteralType( + map_value_type=LiteralType(simple=SimpleType.STRING), + ), ), ( - typing.Dict[str, typing.Dict[int, int]], - LiteralType(map_value_type=LiteralType(simple=SimpleType.STRUCT)), + typing.Dict[str, typing.List[int]], + LiteralType( + map_value_type=LiteralType(collection_type=LiteralType(simple=SimpleType.INTEGER)), + ), ), ( - typing.Dict[str, typing.Dict[str, int]], - LiteralType(map_value_type=LiteralType(map_value_type=LiteralType(simple=SimpleType.INTEGER))), + typing.Dict[int, typing.List[int]], + LiteralType( + simple=SimpleType.STRUCT, + annotation=TypeAnnotation( + { + "cache-key-metadata": { + "serialization-format": "msgpack", + }, + }, + ), + ), ), ( - DataclassTest, - LiteralType( - simple=SimpleType.STRUCT, - metadata={ - "type": "object", - "title": "DataclassTest", - "properties": { - "a": {"type": "integer"}, - "b": {"type": "string"} + typing.Dict[int, typing.Dict[int, int]], + LiteralType( + simple=SimpleType.STRUCT, + annotation=TypeAnnotation( + { + "cache-key-metadata": { + "serialization-format": "msgpack", }, - "additionalProperties": False, - "required": ["a", "b"] }, - structure=TypeStructure( - tag="", - dataclass_type={ - "a": LiteralType(simple=SimpleType.INTEGER), - "b": LiteralType(simple=SimpleType.STRING), + ), + ), + ), + ( + typing.Dict[str, typing.Dict[int, int]], + LiteralType( + map_value_type=LiteralType( + simple=SimpleType.STRUCT, + annotation=TypeAnnotation( + { + "cache-key-metadata": { + "serialization-format": "msgpack", + }, }, ), ), + ), ), - # Similar to the dict[int, str] case, the annotation is not being copied over to the LiteralType ( - Annotated[DataclassTest, "another-tag"], - LiteralType( - simple=SimpleType.STRUCT, - metadata={ - "type": "object", - "title": "DataclassTest", - "properties": { - "a": {"type": "integer"}, - "b": {"type": "string"} + typing.Dict[str, typing.Dict[str, int]], + LiteralType( + map_value_type=LiteralType( + map_value_type=LiteralType( + simple=SimpleType.INTEGER, + ), + ), + ), + ), + ( + DataclassTest, + LiteralType( + simple=SimpleType.STRUCT, + metadata={ + "type": "object", + "title": "DataclassTest", + "properties": { + "a": {"type": "integer"}, + "b": {"type": "string"} + }, + "additionalProperties": False, + "required": ["a", "b"] + }, + structure=TypeStructure( + tag="", + dataclass_type={ + "a": LiteralType(simple=SimpleType.INTEGER), + "b": LiteralType(simple=SimpleType.STRING), + }, + ), + annotation=TypeAnnotation( + { + "cache-key-metadata": { + "serialization-format": "msgpack", }, - "additionalProperties": False, - "required": ["a", "b"] }, - structure=TypeStructure( - tag="", - dataclass_type={ - "a": LiteralType(simple=SimpleType.INTEGER), - "b": LiteralType(simple=SimpleType.STRING), + ), + ), + ), + # Similar to the dict[int, str] case, the annotation is not being copied over to the LiteralType + ( + Annotated[DataclassTest, "another-tag"], + LiteralType( + simple=SimpleType.STRUCT, + metadata={ + "type": "object", + "title": "DataclassTest", + "properties": { + "a": {"type": "integer"}, + "b": {"type": "string"} + }, + "additionalProperties": False, + "required": ["a", "b"] + }, + structure=TypeStructure( + tag="", + dataclass_type={ + "a": LiteralType(simple=SimpleType.INTEGER), + "b": LiteralType(simple=SimpleType.STRING), + }, + ), + annotation=TypeAnnotation( + { + "cache-key-metadata": { + 
"serialization-format": "msgpack", }, - ), + }, ), + ), ), # Notice how the annotation in the field is not carried over either ( - Annotated[AnnotatedDataclassTest, "another-tag"], - LiteralType( - simple=SimpleType.STRUCT, - metadata={ - "type": "object", - "title": "AnnotatedDataclassTest", - "properties": { - "a": {"type": "integer"}, - "b": {"type": "string"} - }, - "additionalProperties": False, - "required": ["a", "b"] + Annotated[AnnotatedDataclassTest, "another-tag"], + LiteralType( + simple=SimpleType.STRUCT, + metadata={ + "type": "object", + "title": "AnnotatedDataclassTest", + "properties": { + "a": {"type": "integer"}, + "b": {"type": "string"}, + }, + "additionalProperties": False, + "required": ["a", "b"], + }, + structure=TypeStructure( + tag="", + dataclass_type={ + "a": LiteralType(simple=SimpleType.INTEGER), + "b": LiteralType(simple=SimpleType.STRING), }, - structure=TypeStructure( - tag="", - dataclass_type={ - "a": LiteralType(simple=SimpleType.INTEGER), - "b": LiteralType(simple=SimpleType.STRING), + ), + annotation=TypeAnnotation( + { + "cache-key-metadata": { + "serialization-format": "msgpack", }, - ), + }, + ), + ), + ), + # Notice how the `FlyteAnnotation` is carried over + ( + Annotated[int, FlyteAnnotation({"some-annotation": "a value"})], + LiteralType( + simple=SimpleType.INTEGER, + annotation=TypeAnnotation( + { + "some-annotation": "a value", + }, ), + ), ), ], ) -def test_annotated_dicts(t, expected_type): +def test_annotated(t, expected_type): assert TypeEngine.to_literal_type(t) == expected_type @@ -3320,35 +3439,6 @@ def outer_workflow(input: OuterWorkflowInput) -> OuterWorkflowOutput: assert none_value_output is None, f"None value was {none_value_output}, not None as expected" -@pytest.mark.serial -def test_lazy_import_transformers_concurrently(): - # Ensure that next call to TypeEngine.lazy_import_transformers doesn't skip the import. Mark as serial to ensure - # this achieves what we expect. - TypeEngine.has_lazy_import = False - - # Configure the mocks similar to https://stackoverflow.com/questions/29749193/python-unit-testing-with-two-mock-objects-how-to-verify-call-order - after_import_mock, mock_register = mock.Mock(), mock.Mock() - mock_wrapper = mock.Mock() - mock_wrapper.mock_register = mock_register - mock_wrapper.after_import_mock = after_import_mock - - with mock.patch.object(StructuredDatasetTransformerEngine, "register", new=mock_register): - def run(): - TypeEngine.lazy_import_transformers() - after_import_mock() - - N = 5 - with ThreadPoolExecutor(max_workers=N) as executor: - futures = [executor.submit(run) for _ in range(N)] - [f.result() for f in futures] - - # Assert that all the register calls come before anything else. 
- assert mock_wrapper.mock_calls[-N:] == [mock.call.after_import_mock()] * N - expected_number_of_register_calls = len(mock_wrapper.mock_calls) - N - assert all([mock_call[0] == "mock_register" for mock_call in - mock_wrapper.mock_calls[:expected_number_of_register_calls]]) - - @pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10, 585 requires >=3.9") def test_option_list_with_pipe(): pt = list[int] | None diff --git a/tests/flytekit/unit/core/test_literals_resolver.py b/tests/flytekit/unit/core/test_literals_resolver.py index 39b6c9c6ea..31260c7aae 100644 --- a/tests/flytekit/unit/core/test_literals_resolver.py +++ b/tests/flytekit/unit/core/test_literals_resolver.py @@ -10,6 +10,7 @@ from flytekit.models import interface as interface_models from flytekit.models.literals import Literal, LiteralCollection, LiteralMap, Primitive, Scalar from flytekit.types.structured.structured_dataset import StructuredDataset +from flytekit.core.interface import Interface @pytest.mark.parametrize( @@ -134,3 +135,60 @@ def test_interface(): guessed_df = lr.get("my_df") # Using the user specified type, so number of columns is correct. assert len(guessed_df.metadata.structured_dataset_type.columns) == 2 + + +def get_simple_lr() -> LiteralsResolver: + lm = { + "my_map": Literal( + map=LiteralMap( + literals={ + "k1": Literal(scalar=Scalar(primitive=Primitive(string_value="v1"))), + "k2": Literal(scalar=Scalar(primitive=Primitive(string_value="2"))), + }, + ) + ), + "my_list": Literal( + collection=LiteralCollection( + literals=[ + Literal(scalar=Scalar(primitive=Primitive(integer=1))), + Literal(scalar=Scalar(primitive=Primitive(integer=2))), + Literal(scalar=Scalar(primitive=Primitive(integer=3))), + ] + ) + ), + "val_a": Literal(scalar=Scalar(primitive=Primitive(integer=21828))), + } + + variable_map = { + "my_map": interface_models.Variable(type=TypeEngine.to_literal_type(typing.Dict[str, str]), description=""), + "my_list": interface_models.Variable(type=TypeEngine.to_literal_type(typing.List[int]), description=""), + "val_a": interface_models.Variable(type=TypeEngine.to_literal_type(int), description=""), + } + + lr = LiteralsResolver(literals=lm, variable_map=variable_map) + + return lr + + +def test_get_python_native_vm(): + lr = get_simple_lr() + lr._variable_map = None + pif = Interface(inputs={}, outputs={ + "my_map": typing.Dict[str, str], + "my_list": typing.List[int], + "val_a": int, + }) + with pytest.raises(AssertionError): + lr.as_python_native(pif) + + +def test_get_python_native(): + lr = get_simple_lr() + pif = Interface(inputs={}, outputs={ + "my_map": typing.Dict[str, str], + "my_list": typing.List[int], + "val_a": int, + }) + + out = lr.as_python_native(pif) + assert out == ({'k1': 'v1', 'k2': '2'}, [1, 2, 3], 21828) diff --git a/tests/flytekit/unit/core/test_resources.py b/tests/flytekit/unit/core/test_resources.py index 5dd9926039..1c09a111e3 100644 --- a/tests/flytekit/unit/core/test_resources.py +++ b/tests/flytekit/unit/core/test_resources.py @@ -1,10 +1,14 @@ from typing import Dict import pytest +from kubernetes.client import V1Container, V1PodSpec, V1ResourceRequirements import flytekit.models.task as _task_models from flytekit import Resources -from flytekit.core.resources import convert_resources_to_resource_model +from flytekit.core.resources import ( + pod_spec_from_resources, + convert_resources_to_resource_model, +) _ResourceName = _task_models.Resources.ResourceName @@ -101,3 +105,53 @@ def test_resources_round_trip(): json_str = original.to_json() 
result = Resources.from_json(json_str) assert original == result + + +def test_pod_spec_from_resources_requests_limits_set(): + requests = Resources(cpu="1", mem="1Gi", gpu="1", ephemeral_storage="1Gi") + limits = Resources(cpu="4", mem="2Gi", gpu="1", ephemeral_storage="1Gi") + k8s_pod_name = "foo" + + expected_pod_spec = V1PodSpec( + containers=[ + V1Container( + name=k8s_pod_name, + resources=V1ResourceRequirements( + requests={ + "cpu": "1", + "memory": "1Gi", + "nvidia.com/gpu": "1", + "ephemeral-storage": "1Gi", + }, + limits={ + "cpu": "4", + "memory": "2Gi", + "nvidia.com/gpu": "1", + "ephemeral-storage": "1Gi", + }, + ), + ) + ] + ) + pod_spec = pod_spec_from_resources(k8s_pod_name=k8s_pod_name, requests=requests, limits=limits) + assert expected_pod_spec == V1PodSpec(**pod_spec) + + +def test_pod_spec_from_resources_requests_set(): + requests = Resources(cpu="1", mem="1Gi") + limits = None + k8s_pod_name = "foo" + + expected_pod_spec = V1PodSpec( + containers=[ + V1Container( + name=k8s_pod_name, + resources=V1ResourceRequirements( + requests={"cpu": "1", "memory": "1Gi"}, + limits={"cpu": "1", "memory": "1Gi"}, + ), + ) + ] + ) + pod_spec = pod_spec_from_resources(k8s_pod_name=k8s_pod_name, requests=requests, limits=limits) + assert expected_pod_spec == V1PodSpec(**pod_spec) diff --git a/tests/flytekit/unit/core/test_serialization.py b/tests/flytekit/unit/core/test_serialization.py index 379af3cc93..82b69859bd 100644 --- a/tests/flytekit/unit/core/test_serialization.py +++ b/tests/flytekit/unit/core/test_serialization.py @@ -16,6 +16,7 @@ from flytekit.exceptions.user import FlyteAssertion, FlyteMissingTypeException from flytekit.image_spec.image_spec import ImageBuildEngine from flytekit.models.admin.workflow import WorkflowSpec +from flytekit.models.annotation import TypeAnnotation from flytekit.models.literals import ( BindingData, BindingDataCollection, @@ -829,9 +830,7 @@ def wf_with_input() -> typing.Dict[str, int]: } ) - assert wf_with_input_spec.template.interface.outputs["o0"].type == LiteralType( - map_value_type=LiteralType(simple=SimpleType.INTEGER) - ) + assert wf_with_input_spec.template.interface.outputs["o0"].type == LiteralType(map_value_type=LiteralType(simple=SimpleType.INTEGER)) assert wf_with_input() == input_val diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 109c647ae5..48f7e4b959 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -486,7 +486,10 @@ def recursive_assert( ) recursive_assert( d.get_literal_type(typing.Dict[str, dict]), - LiteralType(simple=SimpleType.STRUCT), + LiteralType( + simple=SimpleType.STRUCT, + annotation=TypeAnnotation({"cache-key-metadata": {"serialization-format": "msgpack"}}) + ), ) recursive_assert( d.get_literal_type(typing.Dict[str, typing.Dict[str, str]]), @@ -495,7 +498,10 @@ def recursive_assert( ) recursive_assert( d.get_literal_type(typing.Dict[str, typing.Dict[int, str]]), - LiteralType(simple=SimpleType.STRUCT), + LiteralType( + simple=SimpleType.STRUCT, + annotation=TypeAnnotation({"cache-key-metadata": {"serialization-format": "msgpack"}}) + ), expected_depth=2, ) recursive_assert( @@ -505,12 +511,18 @@ def recursive_assert( ) recursive_assert( d.get_literal_type(typing.Dict[str, typing.Dict[str, typing.Dict[str, dict]]]), - LiteralType(simple=SimpleType.STRUCT), + LiteralType( + simple=SimpleType.STRUCT, + annotation=TypeAnnotation({"cache-key-metadata": {"serialization-format": "msgpack"}}) + 
), expected_depth=3, ) recursive_assert( d.get_literal_type(typing.Dict[str, typing.Dict[str, typing.Dict[int, dict]]]), - LiteralType(simple=SimpleType.STRUCT), + LiteralType( + simple=SimpleType.STRUCT, + annotation=TypeAnnotation({"cache-key-metadata": {"serialization-format": "msgpack"}}) + ), expected_depth=2, ) @@ -2357,10 +2369,6 @@ def _check_annotation(t, annotation): typing_extensions.Annotated[int, FlyteAnnotation({"d": {"test": "data"}, "l": ["nested", ["list"]]})], {"d": {"test": "data"}, "l": ["nested", ["list"]]}, ) - _check_annotation( - typing_extensions.Annotated[int, FlyteAnnotation(InnerStruct(a=1, b="fizz", c=[1]))], - InnerStruct(a=1, b="fizz", c=[1]), - ) def test_annotated_list(): @@ -2446,112 +2454,223 @@ class AnnotatedDataclassTest(DataClassJsonMixin): @pytest.mark.parametrize( "t,expected_type", [ - (dict, LiteralType(simple=SimpleType.STRUCT)), - # Annotations are not being copied over to the LiteralType ( - typing_extensions.Annotated[dict, "a-tag"], - LiteralType(simple=SimpleType.STRUCT), + dict, + LiteralType( + simple=SimpleType.STRUCT, + annotation=TypeAnnotation( + { + "cache-key-metadata": { + "serialization-format": "msgpack", + }, + }, + ), + ), ), - (typing.Dict[int, str], LiteralType(simple=SimpleType.STRUCT)), + # Only the special `cache-key-metadata` TypeAnnotations is copied over to the LiteralType ( - typing.Dict[str, int], - LiteralType(map_value_type=LiteralType(simple=SimpleType.INTEGER)), + typing_extensions.Annotated[dict, "a-tag"], + LiteralType( + simple=SimpleType.STRUCT, + annotation=TypeAnnotation( + { + "cache-key-metadata": { + "serialization-format": "msgpack", + }, + }, + ), + ), ), ( - typing.Dict[str, str], - LiteralType(map_value_type=LiteralType(simple=SimpleType.STRING)), + typing.Dict[int, str], + LiteralType( + simple=SimpleType.STRUCT, + annotation=TypeAnnotation( + { + "cache-key-metadata": { + "serialization-format": "msgpack", + }, + }, + ), + + ), ), ( - typing.Dict[str, typing.List[int]], - LiteralType(map_value_type=LiteralType(collection_type=LiteralType(simple=SimpleType.INTEGER))), + typing.Dict[str, int], + LiteralType( + map_value_type=LiteralType(simple=SimpleType.INTEGER), + ), ), - (typing.Dict[int, typing.List[int]], LiteralType(simple=SimpleType.STRUCT)), ( - typing.Dict[int, typing.Dict[int, int]], - LiteralType(simple=SimpleType.STRUCT), + typing.Dict[str, str], + LiteralType( + map_value_type=LiteralType(simple=SimpleType.STRING), + ), ), ( - typing.Dict[str, typing.Dict[int, int]], - LiteralType(map_value_type=LiteralType(simple=SimpleType.STRUCT)), + typing.Dict[str, typing.List[int]], + LiteralType( + map_value_type=LiteralType(collection_type=LiteralType(simple=SimpleType.INTEGER)), + ), ), ( - typing.Dict[str, typing.Dict[str, int]], - LiteralType(map_value_type=LiteralType(map_value_type=LiteralType(simple=SimpleType.INTEGER))), + typing.Dict[int, typing.List[int]], + LiteralType( + simple=SimpleType.STRUCT, + annotation=TypeAnnotation( + { + "cache-key-metadata": { + "serialization-format": "msgpack", + }, + }, + ), + ), ), ( - DataclassTest, - LiteralType( - simple=SimpleType.STRUCT, - metadata={ - "type": "object", - "title": "DataclassTest", - "properties": { - "a": {"type": "integer"}, - "b": {"type": "string"} + typing.Dict[int, typing.Dict[int, int]], + LiteralType( + simple=SimpleType.STRUCT, + annotation=TypeAnnotation( + { + "cache-key-metadata": { + "serialization-format": "msgpack", }, - "additionalProperties": False, - "required": ["a", "b"] }, - structure=TypeStructure( - tag="", - 
dataclass_type={ - "a": LiteralType(simple=SimpleType.INTEGER), - "b": LiteralType(simple=SimpleType.STRING), + ), + ), + ), + ( + typing.Dict[str, typing.Dict[int, int]], + LiteralType( + map_value_type=LiteralType( + simple=SimpleType.STRUCT, + annotation=TypeAnnotation( + { + "cache-key-metadata": { + "serialization-format": "msgpack", + }, }, ), ), + ), ), - # Similar to the dict[int, str] case, the annotation is not being copied over to the LiteralType ( - Annotated[DataclassTest, "another-tag"], - LiteralType( - simple=SimpleType.STRUCT, - metadata={ - "type": "object", - "title": "DataclassTest", - "properties": { - "a": {"type": "integer"}, - "b": {"type": "string"} + typing.Dict[str, typing.Dict[str, int]], + LiteralType( + map_value_type=LiteralType( + map_value_type=LiteralType( + simple=SimpleType.INTEGER, + ), + ), + ), + ), + ( + DataclassTest, + LiteralType( + simple=SimpleType.STRUCT, + metadata={ + "type": "object", + "title": "DataclassTest", + "properties": { + "a": {"type": "integer"}, + "b": {"type": "string"} + }, + "additionalProperties": False, + "required": ["a", "b"] + }, + structure=TypeStructure( + tag="", + dataclass_type={ + "a": LiteralType(simple=SimpleType.INTEGER), + "b": LiteralType(simple=SimpleType.STRING), + }, + ), + annotation=TypeAnnotation( + { + "cache-key-metadata": { + "serialization-format": "msgpack", }, - "additionalProperties": False, - "required": ["a", "b"] }, - structure=TypeStructure( - tag="", - dataclass_type={ - "a": LiteralType(simple=SimpleType.INTEGER), - "b": LiteralType(simple=SimpleType.STRING), + ), + ), + ), + # Similar to the dict[int, str] case, the annotation is not being copied over to the LiteralType + ( + Annotated[DataclassTest, "another-tag"], + LiteralType( + simple=SimpleType.STRUCT, + metadata={ + "type": "object", + "title": "DataclassTest", + "properties": { + "a": {"type": "integer"}, + "b": {"type": "string"} + }, + "additionalProperties": False, + "required": ["a", "b"] + }, + structure=TypeStructure( + tag="", + dataclass_type={ + "a": LiteralType(simple=SimpleType.INTEGER), + "b": LiteralType(simple=SimpleType.STRING), + }, + ), + annotation=TypeAnnotation( + { + "cache-key-metadata": { + "serialization-format": "msgpack", }, - ), + }, ), + ), ), # Notice how the annotation in the field is not carried over either ( - Annotated[AnnotatedDataclassTest, "another-tag"], - LiteralType( - simple=SimpleType.STRUCT, - metadata={ - "type": "object", - "title": "AnnotatedDataclassTest", - "properties": { - "a": {"type": "integer"}, - "b": {"type": "string"}, - }, - "additionalProperties": False, - "required": ["a", "b"], + Annotated[AnnotatedDataclassTest, "another-tag"], + LiteralType( + simple=SimpleType.STRUCT, + metadata={ + "type": "object", + "title": "AnnotatedDataclassTest", + "properties": { + "a": {"type": "integer"}, + "b": {"type": "string"}, + }, + "additionalProperties": False, + "required": ["a", "b"], + }, + structure=TypeStructure( + tag="", + dataclass_type={ + "a": LiteralType(simple=SimpleType.INTEGER), + "b": LiteralType(simple=SimpleType.STRING), }, - structure=TypeStructure( - tag="", - dataclass_type={ - "a": LiteralType(simple=SimpleType.INTEGER), - "b": LiteralType(simple=SimpleType.STRING), + ), + annotation=TypeAnnotation( + { + "cache-key-metadata": { + "serialization-format": "msgpack", }, - ), + }, + ), + ), + ), + # Notice how the `FlyteAnnotation` is carried over + ( + Annotated[int, FlyteAnnotation({"some-annotation": "a value"})], + LiteralType( + simple=SimpleType.INTEGER, + 
annotation=TypeAnnotation( + { + "some-annotation": "a value", + }, ), + ), ), ], ) -def test_annotated_dicts(t, expected_type): +def test_annotated(t, expected_type): assert TypeEngine.to_literal_type(t) == expected_type @@ -3332,10 +3451,6 @@ def outer_workflow(input: OuterWorkflowInput) -> OuterWorkflowOutput: @pytest.mark.serial def test_lazy_import_transformers_concurrently(): - # Ensure that next call to TypeEngine.lazy_import_transformers doesn't skip the import. Mark as serial to ensure - # this achieves what we expect. - TypeEngine.has_lazy_import = False - # Configure the mocks similar to https://stackoverflow.com/questions/29749193/python-unit-testing-with-two-mock-objects-how-to-verify-call-order after_import_mock, mock_register = mock.Mock(), mock.Mock() mock_wrapper = mock.Mock() @@ -3352,11 +3467,11 @@ def run(): futures = [executor.submit(run) for _ in range(N)] [f.result() for f in futures] - # Assert that all the register calls come before anything else. - assert mock_wrapper.mock_calls[-N:] == [mock.call.after_import_mock()] * N + assert mock_wrapper.mock_calls[-1] == mock.call.after_import_mock() expected_number_of_register_calls = len(mock_wrapper.mock_calls) - N + assert sum([mock_call[0] == "mock_register" for mock_call in mock_wrapper.mock_calls]) == expected_number_of_register_calls assert all([mock_call[0] == "mock_register" for mock_call in - mock_wrapper.mock_calls[:expected_number_of_register_calls]]) + mock_wrapper.mock_calls[:int(len(mock_wrapper.mock_calls)/N)-1]]) @pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10, 585 requires >=3.9") diff --git a/tests/flytekit/unit/core/test_worker_queue.py b/tests/flytekit/unit/core/test_worker_queue.py new file mode 100644 index 0000000000..1c1cfb1446 --- /dev/null +++ b/tests/flytekit/unit/core/test_worker_queue.py @@ -0,0 +1,109 @@ +import mock +import pytest +import asyncio + +from flytekit.core.task import task +from flytekit.remote.remote import FlyteRemote +from flytekit.core.worker_queue import Controller, WorkItem +from flytekit.configuration import ImageConfig, LocalConfig, SerializationSettings +from flytekit.utils.asyn import loop_manager + + +@mock.patch("flytekit.core.worker_queue.Controller.launch_and_start_watch") +def test_controller(mock_start): + @task + def t1() -> str: + return "hello" + + remote = FlyteRemote.for_sandbox() + ss = SerializationSettings( + image_config=ImageConfig.auto_default_image(), + ) + c = Controller(remote, ss, tag="exec-id", root_tag="exec-id", exec_prefix="e-unit-test") + + def _mock_start(wi: WorkItem, idx: int): + assert c.entries[wi.entity.name][idx] is wi + wi.wf_exec = mock.MagicMock() # just to pass the assert + wi.set_result("hello") + + mock_start.side_effect = _mock_start + + async def fake_eager(): + loop = asyncio.get_running_loop() + f = c.add(loop, entity=t1, input_kwargs={}) + res = await f + assert res == "hello" + + loop_manager.run_sync(fake_eager) + + +@pytest.mark.asyncio +@mock.patch("flytekit.core.worker_queue.Controller") +async def test_controller_launch(mock_controller): + @task + def t2() -> str: + return "hello" + + def _mock_execute( + entity, + execution_name: str, + inputs, + version, + image_config, + options, + envs, + ): + assert entity is t2 + assert execution_name.startswith("e-unit-test-t2-") + assert envs == {'_F_EE_ROOT': 'exec-id'} + print(entity, execution_name, inputs, version, image_config, options, envs) + wf_exec = mock.MagicMock() + return wf_exec + + remote = mock.MagicMock() + 
remote.execute.side_effect = _mock_execute + mock_controller.informer.watch.return_value = True + + loop = asyncio.get_running_loop() + fut = loop.create_future() + wi = WorkItem(t2, input_kwargs={}, fut=fut) + + ss = SerializationSettings( + image_config=ImageConfig.auto_default_image(), + ) + c = Controller(remote, ss, tag="exec-id", root_tag="exec-id", exec_prefix="e-unit-test") + + c.launch_and_start_watch(wi, 0) + assert wi.error is None + + wi.result = 5 + c.launch_and_start_watch(wi, 0) + # Function shouldn't be called if item already has a result + with pytest.raises(AssertionError): + await fut + + +@pytest.mark.asyncio +async def test_wi(): + @task + def t1() -> str: + return "hello" + + loop = asyncio.get_running_loop() + fut = loop.create_future() + wi = WorkItem(t1, input_kwargs={}, fut=fut) + + with pytest.raises(AssertionError): + wi.set_result("hello") + + assert not wi.ready + + wi.wf_exec = mock.MagicMock() + wi.set_result("hello") + assert wi.ready + + fut2 = loop.create_future() + wi = WorkItem(t1, input_kwargs={}, fut=fut2) + wi.set_error(ValueError("hello")) + with pytest.raises(ValueError): + await fut2 diff --git a/tests/flytekit/unit/experimental/test_eager_workflows.py b/tests/flytekit/unit/experimental/test_eager_workflows.py deleted file mode 100644 index 898d11a5ba..0000000000 --- a/tests/flytekit/unit/experimental/test_eager_workflows.py +++ /dev/null @@ -1,306 +0,0 @@ -import asyncio -import mock -import os -import sys -import typing -from pathlib import Path - -import hypothesis.strategies as st -import pytest -from hypothesis import given - -from flytekit import dynamic, task, workflow - -from flytekit.bin.entrypoint import _get_working_loop, _dispatch_execute -from flytekit.core import context_manager -from flytekit.core.promise import VoidPromise -from flytekit.exceptions.user import FlyteValidationException -from flytekit.experimental import EagerException, eager -from flytekit.models import literals as _literal_models -from flytekit.types.directory import FlyteDirectory -from flytekit.types.file import FlyteFile -from flytekit.types.structured import StructuredDataset - -INTEGER_ST = st.integers(min_value=-10_000_000, max_value=10_000_000) - - -@task -def add_one(x: int) -> int: - return x + 1 - - -@task -def double(x: int) -> int: - return x * 2 - - -@task -def gt_0(x: int) -> bool: - return x > 0 - - -@task -def raises_exc(x: int) -> int: - if x == 0: - raise TypeError - return x - - -@dynamic -def dynamic_wf(x: int) -> int: - out = add_one(x=x) - return double(x=out) - - -@given(x_input=INTEGER_ST) -@pytest.mark.hypothesis -def test_simple_eager_workflow(x_input: int): - """Testing simple eager workflow with just tasks.""" - - @eager - async def eager_wf(x: int) -> int: - out = await add_one(x=x) - return await double(x=out) - - result = asyncio.run(eager_wf(x=x_input)) - assert result == (x_input + 1) * 2 - - -@given(x_input=INTEGER_ST) -@pytest.mark.hypothesis -def test_conditional_eager_workflow(x_input: int): - """Test eager workflow with conditional logic.""" - - @eager - async def eager_wf(x: int) -> int: - if await gt_0(x=x): - return -1 - return 1 - - result = asyncio.run(eager_wf(x=x_input)) - if x_input > 0: - assert result == -1 - else: - assert result == 1 - - -@given(x_input=INTEGER_ST) -@pytest.mark.hypothesis -def test_try_except_eager_workflow(x_input: int): - """Test eager workflow with try/except logic.""" - - @eager - async def eager_wf(x: int) -> int: - try: - return await raises_exc(x=x) - except EagerException: - return -1 - - 
result = asyncio.run(eager_wf(x=x_input)) - if x_input == 0: - assert result == -1 - else: - assert result == x_input - - -@given(x_input=INTEGER_ST, n_input=st.integers(min_value=1, max_value=20)) -@pytest.mark.hypothesis -def test_gather_eager_workflow(x_input: int, n_input: int): - """Test eager workflow with asyncio gather.""" - - @eager - async def eager_wf(x: int, n: int) -> typing.List[int]: - results = await asyncio.gather(*[add_one(x=x) for _ in range(n)]) - return results - - results = asyncio.run(eager_wf(x=x_input, n=n_input)) - assert results == [x_input + 1 for _ in range(n_input)] - - -@given(x_input=INTEGER_ST) -@pytest.mark.hypothesis -def test_eager_workflow_with_dynamic_exception(x_input: int): - """Test eager workflow with dynamic workflow is not supported.""" - - @eager - async def eager_wf(x: int) -> typing.List[int]: - return await dynamic_wf(x=x) - - with pytest.raises(EagerException, match="Eager workflows currently do not work with dynamic workflows"): - asyncio.run(eager_wf(x=x_input)) - - -@eager -async def nested_eager_wf(x: int) -> int: - return await add_one(x=x) - - -@given(x_input=INTEGER_ST) -@pytest.mark.hypothesis -def test_nested_eager_workflow(x_input: int): - """Testing running nested eager workflows.""" - - @eager - async def eager_wf(x: int) -> int: - out = await nested_eager_wf(x=x) - return await double(x=out) - - result = asyncio.run(eager_wf(x=x_input)) - assert result == (x_input + 1) * 2 - - -@given(x_input=INTEGER_ST) -@pytest.mark.hypothesis -def test_eager_workflow_within_workflow(x_input: int): - """Testing running eager workflow within a static workflow.""" - - @eager - async def eager_wf(x: int) -> int: - return await add_one(x=x) - - @workflow - def wf(x: int) -> int: - out = eager_wf(x=x) - return double(x=out) - - result = wf(x=x_input) - assert result == (x_input + 1) * 2 - - -@workflow -def subworkflow(x: int) -> int: - return add_one(x=x) - - -@given(x_input=INTEGER_ST) -@pytest.mark.hypothesis -def test_workflow_within_eager_workflow(x_input: int): - """Testing running a static workflow within an eager workflow.""" - - @eager - async def eager_wf(x: int) -> int: - out = await subworkflow(x=x) - return await double(x=out) - - result = asyncio.run(eager_wf(x=x_input)) - assert result == (x_input + 1) * 2 - - -@given(x_input=INTEGER_ST) -@pytest.mark.hypothesis -def test_local_task_eager_workflow_exception(x_input: int): - """Testing simple eager workflow with a local function task doesn't work.""" - - @task - def local_task(x: int) -> int: - return x - - @eager - async def eager_wf_with_local(x: int) -> int: - return await local_task(x=x) - - with pytest.raises(TypeError): - asyncio.run(eager_wf_with_local(x=x_input)) - - -@given(x_input=INTEGER_ST) -@pytest.mark.filterwarnings("ignore:coroutine 'AsyncEntity.__call__' was never awaited") -@pytest.mark.hypothesis -def test_local_workflow_within_eager_workflow_exception(x_input: int): - """Cannot call a locally-defined workflow within an eager workflow""" - - @workflow - def local_wf(x: int) -> int: - return add_one(x=x) - - @eager - async def eager_wf(x: int) -> int: - out = await local_wf(x=x) - return await double(x=out) - - with pytest.raises(FlyteValidationException): - asyncio.run(eager_wf(x=x_input)) - - -@task -def create_structured_dataset() -> StructuredDataset: - import pandas as pd - - df = pd.DataFrame({"a": [1, 2, 3]}) - return StructuredDataset(dataframe=df) - - -@task -def create_file() -> FlyteFile: - fname = "/tmp/flytekit_test_file" - with open(fname, "w") as fh: - 
fh.write("some data\n") - return FlyteFile(path=fname) - - -@task -def create_directory() -> FlyteDirectory: - dirname = "/tmp/flytekit_test_dir" - Path(dirname).mkdir(exist_ok=True, parents=True) - with open(os.path.join(dirname, "file"), "w") as tmp: - tmp.write("some data\n") - return FlyteDirectory(path=dirname) - - -@pytest.mark.skipif("pandas" not in sys.modules, reason="Pandas is not installed.") -@pytest.mark.hypothesis -def test_eager_workflow_with_offloaded_types(): - """Test eager workflow that eager workflows work with offloaded types.""" - import pandas as pd - - @eager - async def eager_wf_structured_dataset() -> int: - dataset = await create_structured_dataset() - df = dataset.open(pd.DataFrame).all() - return df["a"].sum() - - @eager - async def eager_wf_flyte_file() -> str: - file = await create_file() - with open(file.path) as f: - data = f.read().strip() - return data - - @eager - async def eager_wf_flyte_directory() -> str: - directory = await create_directory() - with open(os.path.join(directory.path, "file")) as f: - data = f.read().strip() - return data - - result = asyncio.run(eager_wf_structured_dataset()) - assert result == 6 - - result = asyncio.run(eager_wf_flyte_file()) - assert result == "some data" - - result = asyncio.run(eager_wf_flyte_directory()) - assert result == "some data" - - -@mock.patch("flytekit.core.utils.load_proto_from_file") -@mock.patch("flytekit.core.data_persistence.FileAccessProvider.get_data") -@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") -@mock.patch("flytekit.core.utils.write_proto_to_file") -def test_eager_workflow_dispatch(mock_write_to_file, mock_put_data, mock_get_data, mock_load_proto, event_loop): - """Test that event loop is preserved after executing eager workflow via dispatch.""" - - @eager - async def eager_wf(): - await asyncio.sleep(0.1) - return - - ctx = context_manager.FlyteContext.current_context() - with context_manager.FlyteContextManager.with_context( - ctx.with_execution_state( - ctx.execution_state.with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION) - ) - ) as ctx: - _dispatch_execute(ctx, lambda: eager_wf, "inputs path", "outputs prefix") - loop_after_execute = asyncio.get_event_loop_policy().get_event_loop() - assert event_loop == loop_after_execute diff --git a/tests/flytekit/unit/extras/pydantic_transformer/test_pydantic_basemodel_transformer.py b/tests/flytekit/unit/extras/pydantic_transformer/test_pydantic_basemodel_transformer.py index 05ba54903f..d929e4d4fa 100644 --- a/tests/flytekit/unit/extras/pydantic_transformer/test_pydantic_basemodel_transformer.py +++ b/tests/flytekit/unit/extras/pydantic_transformer/test_pydantic_basemodel_transformer.py @@ -10,9 +10,12 @@ from pydantic import BaseModel, Field from flytekit import task, workflow +from flytekit.core.constants import CACHE_KEY_METADATA, MESSAGEPACK, SERIALIZATION_FORMAT from flytekit.core.context_manager import FlyteContextManager from flytekit.core.type_engine import TypeEngine +from flytekit.models.annotation import TypeAnnotation from flytekit.models.literals import Literal, Scalar +from flytekit.models.types import LiteralType, SimpleType from flytekit.types.directory import FlyteDirectory from flytekit.types.file import FlyteFile from flytekit.types.schema import FlyteSchema @@ -980,3 +983,12 @@ def wf_return_bm(bm: bm) -> bm: return bm assert wf_return_bm(bm=bm(a=1, b=2)) == bm(a=1, b=2) + + +def test_basemodel_literal_type_annotation(): + class BM(BaseModel): + a: int = -1 + b: float = 2.1 + c: str 
= "Hello, Flyte" + + assert TypeEngine.to_literal_type(BM).annotation == TypeAnnotation({CACHE_KEY_METADATA: {SERIALIZATION_FORMAT: MESSAGEPACK}}) diff --git a/tests/flytekit/unit/lazy_module/test_lazy_module.py b/tests/flytekit/unit/lazy_module/test_lazy_module.py index 83c0fb86a7..169954a1c5 100644 --- a/tests/flytekit/unit/lazy_module/test_lazy_module.py +++ b/tests/flytekit/unit/lazy_module/test_lazy_module.py @@ -1,12 +1,24 @@ -import pytest +import sys +from unittest.mock import Mock -from flytekit.lazy_import.lazy_module import LazyModule, lazy_module +import pytest +from flytekit.lazy_import.lazy_module import _LazyModule, lazy_module, is_imported def test_lazy_module(): mod = lazy_module("click") assert mod.__name__ == "click" mod = lazy_module("fake_module") - assert isinstance(mod, LazyModule) + + sys.modules["fake_module"] = mod + assert not is_imported("fake_module") + assert isinstance(mod, _LazyModule) with pytest.raises(ImportError, match="Module fake_module is not yet installed."): print(mod.attr) + + non_lazy_module = Mock() + non_lazy_module.__name__ = 'NonLazyModule' + sys.modules["fake_module"] = non_lazy_module + assert is_imported("fake_module") + + assert is_imported("dataclasses") diff --git a/tests/flytekit/unit/remote/test_remote.py b/tests/flytekit/unit/remote/test_remote.py index 85b51ca2fd..df54eda6af 100644 --- a/tests/flytekit/unit/remote/test_remote.py +++ b/tests/flytekit/unit/remote/test_remote.py @@ -17,15 +17,13 @@ from mock import ANY, MagicMock, patch import flytekit.configuration -from flytekit import CronSchedule, ImageSpec, LaunchPlan, WorkflowFailurePolicy, task, workflow, reference_task, \ - map_task, dynamic +from flytekit import CronSchedule, ImageSpec, LaunchPlan, WorkflowFailurePolicy, task, workflow, reference_task, map_task, dynamic, eager from flytekit.configuration import Config, DefaultImages, Image, ImageConfig, SerializationSettings from flytekit.core.base_task import PythonTask from flytekit.core.context_manager import FlyteContextManager from flytekit.core.type_engine import TypeEngine from flytekit.exceptions import user as user_exceptions from flytekit.exceptions.user import FlyteEntityNotExistException, FlyteAssertion -from flytekit.experimental.eager_function import eager from flytekit.models import common as common_models from flytekit.models import security from flytekit.models.admin.workflow import Workflow, WorkflowClosure @@ -35,7 +33,7 @@ from flytekit.models.core.identifier import Identifier, ResourceType, WorkflowExecutionIdentifier from flytekit.models.execution import Execution from flytekit.models.task import Task -from flytekit.remote import FlyteTask +from flytekit.remote import FlyteTask, FlyteWorkflow from flytekit.remote.lazy_entity import LazyEntity from flytekit.remote.remote import FlyteRemote, _get_git_repo_url, _get_pickled_target_dict from flytekit.tools.translator import Options, get_serializable, get_serializable_launch_plan @@ -773,17 +771,17 @@ def t1(a: int) -> int: return a + 1 @task - def t2(a: int) -> int: + async def t2(a: int) -> int: return a * 2 @eager async def eager_wf(a: int) -> int: - out = await t1(a=a) + out = t1(a=a) if out < 0: return -1 return await t2(a=out) - with pytest.raises(FlyteAssertion): + with pytest.raises(FlyteAssertion, match="Eager tasks are not supported in interactive mode"): _get_pickled_target_dict(eager_wf) @@ -811,3 +809,46 @@ def wf() -> int: # the second one should rr.register_launch_plan(lp2, version="1", serialization_settings=ss) 
diff --git a/tests/flytekit/unit/remote/test_remote.py b/tests/flytekit/unit/remote/test_remote.py
index 85b51ca2fd..df54eda6af 100644
--- a/tests/flytekit/unit/remote/test_remote.py
+++ b/tests/flytekit/unit/remote/test_remote.py
@@ -17,15 +17,13 @@
 from mock import ANY, MagicMock, patch
 
 import flytekit.configuration
-from flytekit import CronSchedule, ImageSpec, LaunchPlan, WorkflowFailurePolicy, task, workflow, reference_task, \
-    map_task, dynamic
+from flytekit import CronSchedule, ImageSpec, LaunchPlan, WorkflowFailurePolicy, task, workflow, reference_task, map_task, dynamic, eager
 from flytekit.configuration import Config, DefaultImages, Image, ImageConfig, SerializationSettings
 from flytekit.core.base_task import PythonTask
 from flytekit.core.context_manager import FlyteContextManager
 from flytekit.core.type_engine import TypeEngine
 from flytekit.exceptions import user as user_exceptions
 from flytekit.exceptions.user import FlyteEntityNotExistException, FlyteAssertion
-from flytekit.experimental.eager_function import eager
 from flytekit.models import common as common_models
 from flytekit.models import security
 from flytekit.models.admin.workflow import Workflow, WorkflowClosure
@@ -35,7 +33,7 @@
 from flytekit.models.core.identifier import Identifier, ResourceType, WorkflowExecutionIdentifier
 from flytekit.models.execution import Execution
 from flytekit.models.task import Task
-from flytekit.remote import FlyteTask
+from flytekit.remote import FlyteTask, FlyteWorkflow
 from flytekit.remote.lazy_entity import LazyEntity
 from flytekit.remote.remote import FlyteRemote, _get_git_repo_url, _get_pickled_target_dict
 from flytekit.tools.translator import Options, get_serializable, get_serializable_launch_plan
@@ -773,17 +771,17 @@ def t1(a: int) -> int:
         return a + 1
 
     @task
-    def t2(a: int) -> int:
+    async def t2(a: int) -> int:
         return a * 2
 
     @eager
     async def eager_wf(a: int) -> int:
-        out = await t1(a=a)
+        out = t1(a=a)
         if out < 0:
             return -1
         return await t2(a=out)
 
-    with pytest.raises(FlyteAssertion):
+    with pytest.raises(FlyteAssertion, match="Eager tasks are not supported in interactive mode"):
         _get_pickled_target_dict(eager_wf)
 
 
@@ -811,3 +809,46 @@
     # the second one should
     rr.register_launch_plan(lp2, version="1", serialization_settings=ss)
     mock_client.update_launch_plan.assert_called()
+
+
+@mock.patch("flytekit.remote.remote.FlyteRemote.client")
+def test_register_task_with_node_dependency_hints(mock_client):
+    @task
+    def task0():
+        return None
+
+    @workflow
+    def workflow0():
+        return task0()
+
+    @dynamic(node_dependency_hints=[workflow0])
+    def dynamic0():
+        return workflow0()
+
+    @workflow
+    def workflow1():
+        return dynamic0()
+
+    rr = FlyteRemote(
+        Config.for_sandbox(),
+        default_project="flytesnacks",
+        default_domain="development",
+    )
+
+    ss = SerializationSettings(
+        image_config=ImageConfig.from_images("docker.io/abc:latest"),
+        version="dummy_version",
+    )
+
+    registered_task = rr.register_task(dynamic0, ss)
+    assert isinstance(registered_task, FlyteTask)
+    assert registered_task.id.resource_type == ResourceType.TASK
+    assert registered_task.id.project == "flytesnacks"
+    assert registered_task.id.domain == "development"
+    # When running via `make unit_test` there is a `__-channelexec__` prefix added to the name.
+    assert registered_task.id.name.endswith("tests.flytekit.unit.remote.test_remote.dynamic0")
+    assert registered_task.id.version == "dummy_version"
+
+    registered_workflow = rr.register_workflow(workflow1, ss)
+    assert isinstance(registered_workflow, FlyteWorkflow)
+    assert registered_workflow.id == Identifier(ResourceType.WORKFLOW, "flytesnacks", "development", "tests.flytekit.unit.remote.test_remote.workflow1", "dummy_version")
diff --git a/tests/flytekit/unit/tools/test_fast_registration.py b/tests/flytekit/unit/tools/test_fast_registration.py
index 0888f678eb..f99ecb82ee 100644
--- a/tests/flytekit/unit/tools/test_fast_registration.py
+++ b/tests/flytekit/unit/tools/test_fast_registration.py
@@ -1,6 +1,9 @@
 import os
+import socket
 import subprocess
+import sys
 import tarfile
+import tempfile
 import time
 from hashlib import md5
 from pathlib import Path
@@ -48,6 +51,33 @@ def flyte_project(tmp_path):
     return tmp_path
 
 
+@pytest.mark.skipif(
+    sys.platform == "win32",
+    reason="Skip if running on windows since Unix Domain Sockets do not exist in that OS",
+)
+def test_skip_socket_file():
+    tmp_dir = tempfile.mkdtemp()
+
+    tree = {
+        "data": {"large.file": "", "more.files": ""},
+        "src": {
+            "workflows": {
+                "hello_world.py": "print('Hello World!')",
+            },
+        },
+    }
+
+    # Add a socket file
+    socket_path = tmp_dir + "/test.sock"
+    server_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+    server_socket.bind(socket_path)
+
+    subprocess.run(["git", "init", str(tmp_dir)])
+
+    # Assert that this runs successfully
+    compute_digest(str(tmp_dir))
+
+
 def test_package(flyte_project, tmp_path):
     archive_fname = fast_package(source=flyte_project, output_dir=tmp_path)
     with tarfile.open(archive_fname) as tar:
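
Both socket tests in this patch guard the same failure mode: the fast-registration digest walk reads every file it hashes, and a Unix domain socket cannot be opened for reading, so non-regular files have to be filtered out before hashing. The underlying check is plausibly equivalent to this stdlib-only predicate (the helper name is illustrative, not flytekit's actual code):

    import os
    import stat


    def is_hashable_file(path: str) -> bool:
        # Only regular files can be opened and read for a digest; sockets,
        # FIFOs, and device nodes must be skipped during the tree walk.
        mode = os.lstat(path).st_mode
        return stat.S_ISREG(mode)
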
diff --git a/tests/flytekit/unit/tools/test_script_mode.py b/tests/flytekit/unit/tools/test_script_mode.py
index b40cd9dc78..6d14cac107 100644
--- a/tests/flytekit/unit/tools/test_script_mode.py
+++ b/tests/flytekit/unit/tools/test_script_mode.py
@@ -1,9 +1,21 @@
 import os
+import socket
 import subprocess
 import sys
+import tempfile
+from pathlib import Path
+from types import ModuleType
+from unittest.mock import patch
 
-from flytekit.tools.script_mode import compress_scripts, hash_file, add_imported_modules_from_source, get_all_modules
+import pytest
+
+import flytekit
 from flytekit.core.tracker import import_module_from_file
+from flytekit.tools.script_mode import compress_scripts, hash_file, add_imported_modules_from_source, get_all_modules, \
+    list_all_files
+from flytekit.tools.script_mode import (
+    list_imported_modules_as_files,
+)
 
 MAIN_WORKFLOW = """
 from flytekit import task, workflow
@@ -237,3 +249,90 @@ def test_get_all_modules(tmp_path):
     # Workflow exists, so it is imported
     workflow_file.write_text(WORKFLOW_CONTENT)
     assert n_sys_modules + 1 == len(get_all_modules(os.fspath(source_dir), "my_workflows.main"))
+
+
+@patch("flytekit.tools.script_mode.sys")
+@patch("site.getsitepackages")
+def test_list_imported_modules_as_files(mock_getsitepackage, mock_sys, tmp_path):
+
+    bin_directory = Path(os.path.dirname(sys.executable))
+    flytekit_root = Path(os.path.dirname(flytekit.__file__))
+    source_path = tmp_path / "project"
+
+    # Site packages should be excluded
+    site_packages = [
+        str(source_path / ".venv" / "lib" / "python3.10" / "site-packages"),
+        str(source_path / ".venv" / "local" / "lib" / "python3.10" / "dist-packages"),
+        str(source_path / ".venv" / "lib" / "python3" / "dist-packages"),
+        str(source_path / ".venv" / "lib" / "python3.10" / "dist-packages"),
+    ]
+    mock_getsitepackage.return_value = site_packages
+
+    # lib module that should be excluded, even if it is in the same root as source_path
+    lib_path = source_path / "micromamba" / "envs" / "my-env"
+    lib_modules = [
+        (ModuleType("lib_module"), str(lib_path / "module.py"))
+    ]
+    # mock the sys prefix to be in the source path
+    mock_sys.prefix = str(lib_path)
+
+    # bin module that should be excluded
+    bin_modules = [
+        (ModuleType("bin_module"), str(bin_directory / "bin" / "module.py"))
+    ]
+    # site modules that should be excluded
+    site_modules = [
+        (ModuleType("site_module_1"), str(Path(site_packages[0]) / "package" / "module_1.py")),
+        (ModuleType("site_module_2"), str(Path(site_packages[1]) / "package" / "module_2.py")),
+        (ModuleType("site_module_3"), str(Path(site_packages[2]) / "package" / "module_3.py")),
+        (ModuleType("site_module_4"), str(Path(site_packages[3]) / "package" / "module_4.py")),
+    ]
+
+    # local modules that should be included
+    local_modules = [
+        (ModuleType("local_module_1"), str(source_path / "package_a" / "module_1.py")),
+        (ModuleType("local_module_2"), str(source_path / "package_a" / "module_2.py")),
+        (ModuleType("local_module_3"), str(source_path / "package_b" / "module_3.py")),
+        (ModuleType("local_module_4"), str(source_path / "package_b" / "module_4.py")),
+    ]
+    flyte_modules = [
+        (ModuleType("flyte_module"), str(flytekit_root / "package" / "module.py"))
+    ]
+
+    module_path_pairs = local_modules + flyte_modules + bin_modules + lib_modules + site_modules
+
+    for m, p in module_path_pairs:
+        m.__file__ = p
+
+    modules = [m for m, _ in module_path_pairs]
+
+    file_list = list_imported_modules_as_files(str(source_path), modules)
+
+    assert sorted(file_list) == sorted([p for _, p in local_modules])
+
+
+@pytest.mark.skipif(
+    sys.platform == "win32",
+    reason="Skip if running on windows since Unix Domain Sockets do not exist in that OS",
+)
+def test_list_all_files_skip_socket_files():
+    tmp_dir = Path(tempfile.mkdtemp())
+
+    source_dir = tmp_dir / "source"
+    source_dir.mkdir()
+
+    file1 = source_dir / "file1.py"
+    file1.write_text("")
+
+    socket_file = source_dir / "test.socket"
+    server_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+    server_socket.bind(os.fspath(socket_file))
+
+    files = list(list_all_files(os.fspath(source_dir), False))
+
+    # Ensure that the socket file is not in the list of files
+    assert str(socket_file) not in files
+
+    # Ensure that the regular file is the only file in the list
+    assert len(files) == 1
+    assert str(file1) in files
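
`test_list_imported_modules_as_files` spells out the packaging rule for script mode: a loaded module's file is shipped only when it sits under the user's source root and outside interpreter-owned locations (sys.prefix, the directory of the executable, site.getsitepackages(), and the flytekit installation itself). A simplified predicate capturing that rule, as a sketch rather than the shipped implementation:

    import os
    import site
    import sys

    import flytekit


    def should_package(file_path: str, source_root: str) -> bool:
        # Files owned by the interpreter or its environment are never shipped.
        excluded_prefixes = [
            sys.prefix,
            os.path.dirname(sys.executable),
            os.path.dirname(flytekit.__file__),
            *site.getsitepackages(),
        ]
        if any(file_path.startswith(prefix) for prefix in excluded_prefixes):
            return False
        # Everything else is shipped only if it sits under the source root.
        return file_path.startswith(source_root)
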
diff --git a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset.py b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset.py
index 5433b79a9c..c8efbe272d 100644
--- a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset.py
+++ b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset.py
@@ -22,7 +22,6 @@
 from flytekit.models.literals import StructuredDatasetMetadata
 from flytekit.models.types import LiteralType, SchemaType, SimpleType, StructuredDatasetType
 from flytekit.tools.translator import get_serializable
-from flytekit.types.file import FlyteFile
 from flytekit.types.structured.structured_dataset import (
     PARQUET,
     StructuredDataset,
diff --git a/tests/flytekit/unit/utils/test_asyn.py b/tests/flytekit/unit/utils/test_asyn.py
index b8ce75b2a7..69d27f827a 100644
--- a/tests/flytekit/unit/utils/test_asyn.py
+++ b/tests/flytekit/unit/utils/test_asyn.py
@@ -1,5 +1,8 @@
 import os
 import threading
+import sys
+
+import mock
 import pytest
 import asyncio
 from typing import List, Dict, Optional
@@ -74,7 +77,7 @@ async def async_function(n: int, orig: int) -> str:
     print(f"Async[{n}] Started! CTX id {id(ctx)} @ depth {ctx.vals['depth']} Thread: {threading.current_thread().name}")
 
     if n > 0:
-        await asyncio.sleep(0.5)  # Simulate some async work
+        await asyncio.sleep(0.01)  # Simulate some async work
         result = sync_function(n - 1, orig)  # Call the synchronous function
         return f"Async[{n}]: {result}"
     else:
@@ -122,3 +125,46 @@ def test_recursive_calling():
     # things like pytorch elastic.
     for k in loop_manager._runner_map.keys():
         assert str(os.getpid()) in k
+
+
+@mock.patch("flytekit.utils.asyn._TaskRunner.get_exc_handler")
+def test_error_two_ways(mock_getter):
+
+    # First reset everything so that the _TaskRunners get recreated
+    keys = [k for k in loop_manager._runner_map.keys()]
+    for k in keys:
+        l = loop_manager._runner_map[k]
+        l._close()
+        del loop_manager._runner_map[k]
+
+    # Test exception handling two ways
+    mock_handler = mock.MagicMock()
+    mock_getter.return_value = mock_handler
+
+    async def runner_1():
+        loop = asyncio.get_running_loop()
+        fut = loop.create_future()
+        fut.set_exception(ValueError("Future failed!"))
+
+    # this should trigger the exception handler because there's an uncaught exception on a future.
+
+    loop_manager.run_sync(runner_1)
+
+    def sync_error():
+        raise ValueError("This is a test2")
+
+    async def get_exc():
+        raise ValueError("This is a test")
+
+    async def runner_2():
+        loop = asyncio.get_running_loop()
+        loop.call_soon_threadsafe(sync_error)
+        t = loop.create_task(get_exc())
+        return await t
+
+    # This should trigger the handler because the call_soon_threadsafe callback raises a ValueError as the
+    # first step, so when await t is run, the scheduled sync_error function will raise.
+    with pytest.raises(ValueError):
+        loop_manager.run_sync(runner_2)
+
+    assert mock_handler.call_count == 2
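
For context on `test_error_two_ways`: `loop_manager` keeps one background event loop per process and thread (hence the pid-keyed `_runner_map` asserted on in `test_recursive_calling`), each `_TaskRunner` installs the exception handler returned by `get_exc_handler`, and `run_sync` is the bridge that lets synchronous code drive a coroutine on that loop. A minimal usage sketch, assuming `run_sync` forwards positional arguments to the coroutine function:

    from flytekit.utils.asyn import loop_manager


    async def double(n: int) -> int:
        return n * 2


    # Blocks the calling thread until the coroutine completes on the
    # background loop owned by this thread's _TaskRunner.
    result = loop_manager.run_sync(double, 21)
    assert result == 42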