Skip to content

Commit

Permalink
make _downloader function in FlyteFile/Directory pickleable
Browse files Browse the repository at this point in the history
Signed-off-by: Niels Bantilan <[email protected]>
  • Loading branch information
cosmicBboy committed Jan 2, 2025
1 parent c95cc63 commit 8a0804e
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
8 changes: 6 additions & 2 deletions flytekit/types/directory/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import random
import typing
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import Any, Dict, Generator, Tuple
from uuid import UUID
Expand Down Expand Up @@ -665,8 +666,7 @@ async def async_to_python_value(

batch_size = get_batch_size(expected_python_type)

def _downloader():
return ctx.file_access.get_data(uri, local_folder, is_multipart=True, batch_size=batch_size)
_downloader = partial(_flyte_directory_downloader, ctx, uri, local_folder, batch_size)

expected_format = self.get_format(expected_python_type)

Expand All @@ -683,4 +683,8 @@ def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlyteDirec
raise ValueError(f"Transformer {self} cannot reverse {literal_type}")


def _flyte_directory_downloader(ctx: FlyteContext, uri: str, local_folder: str, batch_size: int):
return ctx.file_access.get_data(uri, local_folder, is_multipart=True, batch_size=batch_size)


TypeEngine.register(FlyteDirToMultipartBlobTransformer())
8 changes: 6 additions & 2 deletions flytekit/types/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import typing
from contextlib import contextmanager
from dataclasses import dataclass, field
from functools import partial
from typing import Dict, cast
from urllib.parse import unquote

Expand Down Expand Up @@ -694,8 +695,7 @@ async def async_to_python_value(
# For the remote case, return an FlyteFile object that can download
local_path = ctx.file_access.get_random_local_path(uri)

def _downloader():
return ctx.file_access.get_data(uri, local_path, is_multipart=False)
_downloader = partial(_flyte_file_downloader, ctx, uri, local_path)

expected_format = FlyteFilePathTransformer.get_format(expected_python_type)
ff = FlyteFile.__class_getitem__(expected_format)(local_path, _downloader)
Expand All @@ -714,4 +714,8 @@ def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlyteFile[
raise ValueError(f"Transformer {self} cannot reverse {literal_type}")


def _flyte_file_downloader(ctx: FlyteContext, uri: str, local_path: str):
return ctx.file_access.get_data(uri, local_path, is_multipart=False)


TypeEngine.register(FlyteFilePathTransformer(), additional_types=[os.PathLike])

0 comments on commit 8a0804e

Please sign in to comment.