Skip to content

Commit

Permalink
make FlyteFile and Directory pickleable
Browse files (browse the repository at this point in the history)
Signed-off-by: Niels Bantilan <[email protected]>
  • Loading branch information
cosmicBboy committed Jan 3, 2025
1 parent 8a0804e commit 2060962
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 11 deletions.
7 changes: 4 additions & 3 deletions flytekit/types/directory/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from flytekit.core.constants import MESSAGEPACK
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.type_engine import AsyncTypeTransformer, TypeEngine, TypeTransformerFailedError, get_batch_size
from flytekit.core.data_persistence import FileAccessProvider
from flytekit.exceptions.user import FlyteAssertion
from flytekit.extras.pydantic_transformer.decorator import model_serializer, model_validator
from flytekit.models import types as _type_models
Expand Down Expand Up @@ -666,7 +667,7 @@ async def async_to_python_value(

batch_size = get_batch_size(expected_python_type)

_downloader = partial(_flyte_directory_downloader, ctx, uri, local_folder, batch_size)
_downloader = partial(_flyte_directory_downloader, ctx.file_access, uri, local_folder, batch_size)

expected_format = self.get_format(expected_python_type)

Expand All @@ -683,8 +684,8 @@ def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlyteDirec
raise ValueError(f"Transformer {self} cannot reverse {literal_type}")


# Download helper for directories; the transformer's async_to_python_value binds it
# with partial(...) to build the `_downloader` callable.
# This diff hunk shows two signatures for the same helper:
#   - first definition: receives the whole FlyteContext and calls ctx.file_access.get_data
#   - second definition: receives only a FileAccessProvider and calls get_data on it directly
# Both forward uri, local_folder, is_multipart=True and batch_size to get_data and
# return whatever get_data returns.
# NOTE(review): presumably the second form is the post-commit one (the commit title is
# "make FlyteFile and Directory pickleable"), so the partial no longer captures the
# full context object — confirm against the repository.
def _flyte_directory_downloader(ctx: FlyteContext, uri: str, local_folder: str, batch_size: int):
return ctx.file_access.get_data(uri, local_folder, is_multipart=True, batch_size=batch_size)
def _flyte_directory_downloader(file_access_provider: FileAccessProvider, uri: str, local_folder: str, batch_size: int):
return file_access_provider.get_data(uri, local_folder, is_multipart=True, batch_size=batch_size)


TypeEngine.register(FlyteDirToMultipartBlobTransformer())
8 changes: 4 additions & 4 deletions flytekit/types/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from flytekit.core.constants import MESSAGEPACK
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.data_persistence import FileAccessProvider
from flytekit.core.type_engine import (
AsyncTypeTransformer,
TypeEngine,
Expand Down Expand Up @@ -695,12 +696,11 @@ async def async_to_python_value(
# For the remote case, return an FlyteFile object that can download
local_path = ctx.file_access.get_random_local_path(uri)

_downloader = partial(_flyte_file_downloader, ctx, uri, local_path)
_downloader = partial(_flyte_file_downloader, ctx.file_access, uri, local_path)

expected_format = FlyteFilePathTransformer.get_format(expected_python_type)
ff = FlyteFile.__class_getitem__(expected_format)(local_path, _downloader)
ff._remote_source = uri

return ff

def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlyteFile[typing.Any]]:
Expand All @@ -714,8 +714,8 @@ def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlyteFile[
raise ValueError(f"Transformer {self} cannot reverse {literal_type}")


# Download helper for single files; the transformer's async_to_python_value binds it
# with partial(...) and hands the result to the FlyteFile it constructs.
# This diff hunk shows two signatures for the same helper:
#   - first definition: receives the whole FlyteContext and calls ctx.file_access.get_data
#   - second definition: receives only a FileAccessProvider and calls get_data on it directly
# Both forward uri and local_path with is_multipart=False and return whatever
# get_data returns.
# NOTE(review): presumably the second form is the post-commit one (the commit title is
# "make FlyteFile and Directory pickleable"), so the partial no longer captures the
# full context object — confirm against the repository.
def _flyte_file_downloader(ctx: FlyteContext, uri: str, local_path: str):
return ctx.file_access.get_data(uri, local_path, is_multipart=False)
def _flyte_file_downloader(file_access_provider: FileAccessProvider, uri: str, local_path: str):
return file_access_provider.get_data(uri, local_path, is_multipart=False)


TypeEngine.register(FlyteFilePathTransformer(), additional_types=[os.PathLike])
30 changes: 27 additions & 3 deletions tests/flytekit/unit/core/test_flyte_directory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import pathlib
import pickle
import shutil
import tempfile
import typing
Expand All @@ -20,7 +21,7 @@
from flytekit.core.workflow import workflow
from flytekit.exceptions.user import FlyteAssertion
from flytekit.models.core.types import BlobType
from flytekit.models.literals import LiteralMap
from flytekit.models.literals import LiteralMap, Blob, BlobMetadata
from flytekit.types.directory.types import FlyteDirectory, FlyteDirToMultipartBlobTransformer
from google.protobuf import json_format as _json_format
from google.protobuf import struct_pb2 as _struct
Expand Down Expand Up @@ -407,8 +408,7 @@ def my_wf(path: SvgDirectory) -> DC:
assert dc1 == dc2


def test_input_from_flyte_console_attribute_access_flytefile(
local_dummy_directory):
def test_input_from_flyte_console_attribute_access_flytefile(local_dummy_directory):
# Flyte Console will send the input data as protobuf Struct

dict_obj = {"path": local_dummy_directory}
Expand All @@ -422,3 +422,27 @@ def test_input_from_flyte_console_attribute_access_flytefile(
FlyteContextManager.current_context(), upstream_output, FlyteDirectory)
assert isinstance(downstream_input, FlyteDirectory)
assert downstream_input == FlyteDirectory(local_dummy_directory)


def test_flyte_directory_is_pickleable():
    """A FlyteDirectory built from an s3:// multipart blob literal survives pickling.

    The literal is converted with ``TypeEngine.to_python_value`` and the
    resulting object is dumped and loaded with ``pickle``; the restored
    object must compare equal to the one that was pickled.
    """
    # Assemble the literal from the inside out: blob type -> blob -> scalar.
    multipart_type = BlobType(
        dimensionality=BlobType.BlobDimensionality.MULTIPART,
        format="",
    )
    remote_blob = Blob(
        uri="s3://sample-path/directory",
        metadata=BlobMetadata(type=multipart_type),
    )
    upstream_output = Literal(scalar=Scalar(blob=remote_blob))

    current_ctx = FlyteContextManager.current_context()
    downstream_input = TypeEngine.to_python_value(current_ctx, upstream_output, FlyteDirectory)

    # Round trip through pickle and compare with the pre-pickle object.
    restored_input = pickle.loads(pickle.dumps(downstream_input))
    assert downstream_input == restored_input
27 changes: 26 additions & 1 deletion tests/flytekit/unit/core/test_flyte_file.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import os
import pathlib
import pickle
import tempfile
import typing
from unittest.mock import MagicMock, patch
Expand All @@ -19,7 +20,7 @@
from flytekit.core.type_engine import TypeEngine
from flytekit.core.workflow import workflow
from flytekit.models.core.types import BlobType
from flytekit.models.literals import LiteralMap
from flytekit.models.literals import LiteralMap, Blob, BlobMetadata
from flytekit.types.file.file import FlyteFile, FlyteFilePathTransformer
from google.protobuf import json_format as _json_format
from google.protobuf import struct_pb2 as _struct
Expand Down Expand Up @@ -782,3 +783,27 @@ def test_input_from_flyte_console_attribute_access_flytefile(local_dummy_file):
downstream_input = TypeEngine.to_python_value(
FlyteContextManager.current_context(), upstream_output, FlyteFile)
assert downstream_input == FlyteFile(local_dummy_file)


def test_flyte_file_is_pickleable():
    """A FlyteFile built from an s3:// single-blob literal survives pickling.

    The literal is converted with ``TypeEngine.to_python_value`` and the
    resulting object is dumped and loaded with ``pickle``; the restored
    object must compare equal to the one that was pickled.
    """
    # Assemble the literal from the inside out: blob type -> blob -> scalar.
    single_type = BlobType(
        dimensionality=BlobType.BlobDimensionality.SINGLE,
        format="txt",
    )
    remote_blob = Blob(
        uri="s3://sample-path/file",
        metadata=BlobMetadata(type=single_type),
    )
    upstream_output = Literal(scalar=Scalar(blob=remote_blob))

    current_ctx = FlyteContextManager.current_context()
    downstream_input = TypeEngine.to_python_value(current_ctx, upstream_output, FlyteFile)

    # Round trip through pickle and compare with the pre-pickle object.
    restored_input = pickle.loads(pickle.dumps(downstream_input))
    assert downstream_input == restored_input

0 comments on commit 2060962

Please sign in to comment.