diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index b66c48443c..ef08ef64de 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -308,8 +308,7 @@ def __init__( if ctx.file_access.is_remote(self.path): self._remote_source = self.path self._local_path = ctx.file_access.get_random_local_path(self._remote_source) - self._downloader = partial( - ctx.file_access.get_data, + self._downloader = lambda: FlyteFilePathTransformer.downloader( ctx=ctx, remote_path=self._remote_source, # type: ignore local_path=self._local_path, @@ -742,6 +741,18 @@ async def async_to_python_value( ff._remote_source = uri return ff + @staticmethod + def downloader( + ctx: FlyteContext, remote_path: typing.Union[str, os.PathLike], local_path: typing.Union[str, os.PathLike] + ) -> None: + """ + Download data from remote_path to local_path. + + We design the downloader as a static method because its behavior is logically + related to this class but don't need to interact with class or instance data. + """ + ctx.file_access.get_data(remote_path, local_path, is_multipart=False) + def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlyteFile[typing.Any]]: if ( literal_type.blob is not None diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index de1a1c8821..be17e26174 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -14,7 +14,7 @@ from urllib.parse import urlparse import uuid import pytest -from unittest import mock +import mock from flytekit import LaunchPlan, kwtypes, WorkflowExecutionPhase from flytekit.configuration import Config, ImageConfig, SerializationSettings