Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make FlyteFile and FlyteDirectory pickleable #3030

Merged
merged 11 commits into from
Jan 8, 2025
11 changes: 4 additions & 7 deletions flytekit/types/directory/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import random
import typing
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import Any, Dict, Generator, Tuple
from uuid import UUID
Expand Down Expand Up @@ -374,23 +375,20 @@ def listdir(cls, directory: FlyteDirectory) -> typing.List[typing.Union[FlyteDir
paths.append(FlyteDirectory(joined_path))
return paths

def create_downloader(_remote_path: str, _local_path: str, is_multipart: bool):
return lambda: file_access.get_data(_remote_path, _local_path, is_multipart=is_multipart)

fs = file_access.get_filesystem_for_path(final_path)
for key in fs.listdir(final_path):
remote_path = os.path.join(final_path, key["name"].split(os.sep)[-1])
if key["type"] == "file":
local_path = file_access.get_random_local_path()
os.makedirs(pathlib.Path(local_path).parent, exist_ok=True)
downloader = create_downloader(remote_path, local_path, is_multipart=False)
downloader = partial(file_access.get_data, remote_path, local_path, is_multipart=False)

flyte_file: FlyteFile = FlyteFile(local_path, downloader=downloader)
flyte_file._remote_source = remote_path
paths.append(flyte_file)
else:
local_folder = file_access.get_random_local_directory()
downloader = create_downloader(remote_path, local_folder, is_multipart=True)
downloader = partial(file_access.get_data, remote_path, local_folder, is_multipart=True)

flyte_directory: FlyteDirectory = FlyteDirectory(path=local_folder, downloader=downloader)
flyte_directory._remote_source = remote_path
Expand Down Expand Up @@ -665,8 +663,7 @@ async def async_to_python_value(

batch_size = get_batch_size(expected_python_type)

def _downloader():
return ctx.file_access.get_data(uri, local_folder, is_multipart=True, batch_size=batch_size)
_downloader = partial(ctx.file_access.get_data, uri, local_folder, is_multipart=True, batch_size=batch_size)

expected_format = self.get_format(expected_python_type)

Expand Down
24 changes: 7 additions & 17 deletions flytekit/types/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import typing
from contextlib import contextmanager
from dataclasses import dataclass, field
from functools import partial
from typing import Dict, cast
from urllib.parse import unquote

Expand Down Expand Up @@ -307,7 +308,8 @@ def __init__(
if ctx.file_access.is_remote(self.path):
self._remote_source = self.path
self._local_path = ctx.file_access.get_random_local_path(self._remote_source)
self._downloader = lambda: FlyteFilePathTransformer.downloader(
self._downloader = partial(
ctx.file_access.get_data,
ctx=ctx,
remote_path=self._remote_source, # type: ignore
local_path=self._local_path,
Expand Down Expand Up @@ -732,26 +734,14 @@ async def async_to_python_value(

# For the remote case, return a FlyteFile object that can download the remote file
local_path = ctx.file_access.get_random_local_path(uri)

_downloader = partial(ctx.file_access.get_data, remote_path=uri, local_path=local_path, is_multipart=False)

expected_format = FlyteFilePathTransformer.get_format(expected_python_type)
ff = FlyteFile.__class_getitem__(expected_format)(
path=local_path, downloader=lambda: self.downloader(ctx=ctx, remote_path=uri, local_path=local_path)
)
ff = FlyteFile.__class_getitem__(expected_format)(path=local_path, downloader=_downloader)
ff._remote_source = uri

return ff

@staticmethod
def downloader(
    ctx: FlyteContext, remote_path: typing.Union[str, os.PathLike], local_path: typing.Union[str, os.PathLike]
) -> None:
    """
    Download data from remote_path to local_path.

    We design the downloader as a static method because its behavior is logically
    related to this class but doesn't need to interact with class or instance data.

    :param ctx: Flyte context whose ``file_access`` object performs the transfer.
    :param remote_path: Location of the data to fetch.
    :param local_path: Path the fetched data is downloaded to.
    """
    # is_multipart=False requests a single-file transfer; the directory code paths
    # pass is_multipart=True to the same get_data call instead.
    ctx.file_access.get_data(remote_path, local_path, is_multipart=False)

def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlyteFile[typing.Any]]:
if (
literal_type.blob is not None
Expand Down
30 changes: 27 additions & 3 deletions tests/flytekit/unit/core/test_flyte_directory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import pathlib
import pickle
import shutil
import tempfile
import typing
Expand All @@ -20,7 +21,7 @@
from flytekit.core.workflow import workflow
from flytekit.exceptions.user import FlyteAssertion
from flytekit.models.core.types import BlobType
from flytekit.models.literals import LiteralMap
from flytekit.models.literals import LiteralMap, Blob, BlobMetadata
from flytekit.types.directory.types import FlyteDirectory, FlyteDirToMultipartBlobTransformer
from google.protobuf import json_format as _json_format
from google.protobuf import struct_pb2 as _struct
Expand Down Expand Up @@ -407,8 +408,7 @@ def my_wf(path: SvgDirectory) -> DC:
assert dc1 == dc2


def test_input_from_flyte_console_attribute_access_flytefile(
local_dummy_directory):
def test_input_from_flyte_console_attribute_access_flytefile(local_dummy_directory):
# Flyte Console will send the input data as protobuf Struct

dict_obj = {"path": local_dummy_directory}
Expand All @@ -422,3 +422,27 @@ def test_input_from_flyte_console_attribute_access_flytefile(
FlyteContextManager.current_context(), upstream_output, FlyteDirectory)
assert isinstance(downstream_input, FlyteDirectory)
assert downstream_input == FlyteDirectory(local_dummy_directory)


def test_flyte_directory_is_pickleable():
    """Check that a FlyteDirectory built from a remote blob literal survives pickling."""
    # Multipart blob with an empty format string, pointing at a remote URI.
    blob_type = BlobType(
        dimensionality=BlobType.BlobDimensionality.MULTIPART,
        format="",
    )
    remote_blob = Blob(
        uri="s3://sample-path/directory",
        metadata=BlobMetadata(type=blob_type),
    )
    upstream_output = Literal(scalar=Scalar(blob=remote_blob))

    # Let the type engine turn the literal into the Python-side FlyteDirectory.
    current_ctx = FlyteContextManager.current_context()
    downstream_input = TypeEngine.to_python_value(current_ctx, upstream_output, FlyteDirectory)

    # Round trip through pickle; the restored object must compare equal to the original.
    restored_input = pickle.loads(pickle.dumps(downstream_input))
    assert downstream_input == restored_input
Comment on lines +445 to +448
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding more pickle test assertions

Consider adding more assertions to verify the pickled/unpickled FlyteDirectory object's properties like uri and other attributes are preserved correctly after deserialization.

Code suggestion
Check the AI-generated fix before applying
Suggested change
# test round trip pickling
pickled_input = pickle.dumps(downstream_input)
unpickled_input = pickle.loads(pickled_input)
assert downstream_input == unpickled_input
# test round trip pickling
pickled_input = pickle.dumps(downstream_input)
unpickled_input = pickle.loads(pickled_input)
assert downstream_input == unpickled_input
assert unpickled_input.uri == "s3://sample-path/directory"

Code Review Run #639caa


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

27 changes: 26 additions & 1 deletion tests/flytekit/unit/core/test_flyte_file.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import os
import pathlib
import pickle
import tempfile
import typing
from unittest.mock import MagicMock, patch
Expand All @@ -19,7 +20,7 @@
from flytekit.core.type_engine import TypeEngine
from flytekit.core.workflow import workflow
from flytekit.models.core.types import BlobType
from flytekit.models.literals import LiteralMap
from flytekit.models.literals import LiteralMap, Blob, BlobMetadata
from flytekit.types.file.file import FlyteFile, FlyteFilePathTransformer
from google.protobuf import json_format as _json_format
from google.protobuf import struct_pb2 as _struct
Expand Down Expand Up @@ -782,3 +783,27 @@ def test_input_from_flyte_console_attribute_access_flytefile(local_dummy_file):
downstream_input = TypeEngine.to_python_value(
FlyteContextManager.current_context(), upstream_output, FlyteFile)
assert downstream_input == FlyteFile(local_dummy_file)


def test_flyte_file_is_pickleable():
    """Check that a FlyteFile built from a remote blob literal survives pickling."""
    # Single-part blob declared with the "txt" format, pointing at a remote URI.
    blob_type = BlobType(
        dimensionality=BlobType.BlobDimensionality.SINGLE,
        format="txt",
    )
    remote_blob = Blob(
        uri="s3://sample-path/file",
        metadata=BlobMetadata(type=blob_type),
    )
    upstream_output = Literal(scalar=Scalar(blob=remote_blob))

    # Let the type engine turn the literal into the Python-side FlyteFile.
    current_ctx = FlyteContextManager.current_context()
    downstream_input = TypeEngine.to_python_value(current_ctx, upstream_output, FlyteFile)

    # Round trip through pickle; the restored object must compare equal to the original.
    restored_input = pickle.loads(pickle.dumps(downstream_input))
    assert downstream_input == restored_input
Loading