diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 4c2aebcc36..4c1da5ccf3 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -475,8 +475,15 @@ def to_click_option( If no custom logic exists, fall back to json.dumps. """ with FlyteContextManager.with_context(flyte_ctx.new_builder()): - encoder = JSONEncoder(python_type) - default_val = encoder.encode(default_val) + if hasattr(default_val, "model_dump_json"): + # pydantic v2 + default_val = default_val.model_dump_json() + elif hasattr(default_val, "json"): + # pydantic v1 + default_val = default_val.json() + else: + encoder = JSONEncoder(python_type) + default_val = encoder.encode(default_val) if literal_var.type.metadata: description_extra = f": {json.dumps(literal_var.type.metadata)}" diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index b5891b2155..877bdf01ea 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -299,14 +299,53 @@ def __init__( self._downloader = downloader self._downloaded = False self._remote_path = remote_path - self._remote_source: typing.Optional[str] = None + self._remote_source: typing.Optional[typing.Union[str, os.PathLike]] = None + + # Setup local path and downloader for delayed downloading + # We introduce another attribute self._local_path to avoid overriding user-defined self.path + self._local_path = self.path + + ctx = FlyteContextManager.current_context() + if ctx.file_access.is_remote(self.path): + self._remote_source = self.path + self._local_path = ctx.file_access.get_random_local_path(self._remote_source) + self._downloader = lambda: FlyteFilePathTransformer.downloader( + ctx=ctx, + remote_path=self._remote_source, # type: ignore + local_path=self._local_path, + ) def __fspath__(self): - # This is where a delayed downloading of the file will happen + """ + Define the file path protocol for opening FlyteFile with the context manager, + following show two common use cases: + + 1. Directly open a FlyteFile with a local path: + + ff = FlyteFile(path=local_path) + with open(ff, "r") as f: + # Read your local file here + # ... + + There's no need to handle downloading of the file because it's on the local file system. + In this case, a dummy downloading will be done. + + 2. Directly open a FlyteFile with a remote path: + + ff = FlyteFile(path=remote_path) + with open(ff, "r") as f: + # Read your remote file here + # ... + + We now support directly opening a FlyteFile with a file from the remote data storage. + In this case, a delayed downloading of the remote file will be done. + For details, please refer to this issue: https://github.com/flyteorg/flyte/issues/6090. + """ if not self._downloaded: + # Download data from remote to local or run dummy downloading for input local path self._downloader() self._downloaded = True - return self.path + return self._local_path def __eq__(self, other): if isinstance(other, FlyteFile): @@ -699,10 +738,24 @@ async def async_to_python_value( _downloader = partial(_flyte_file_downloader, ctx.file_access, uri, local_path) expected_format = FlyteFilePathTransformer.get_format(expected_python_type) - ff = FlyteFile.__class_getitem__(expected_format)(local_path, _downloader) + ff = FlyteFile.__class_getitem__(expected_format)( + path=local_path, downloader=lambda: self.downloader(ctx=ctx, remote_path=uri, local_path=local_path) + ) ff._remote_source = uri return ff + @staticmethod + def downloader( + ctx: FlyteContext, remote_path: typing.Union[str, os.PathLike], local_path: typing.Union[str, os.PathLike] + ) -> None: + """ + Download data from remote_path to local_path. + + We design the downloader as a static method because its behavior is logically + related to this class but don't need to interact with class or instance data. + """ + ctx.file_access.get_data(remote_path, local_path, is_multipart=False) + def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlyteFile[typing.Any]]: if ( literal_type.blob is not None diff --git a/plugins/flytekit-inference/README.md b/plugins/flytekit-inference/README.md index 1bc5c8475e..646200c111 100644 --- a/plugins/flytekit-inference/README.md +++ b/plugins/flytekit-inference/README.md @@ -126,3 +126,66 @@ def model_serving(questions: list[str], gguf: FlyteFile) -> list[str]: return responses ``` + +## vLLM + +The vLLM plugin allows you to serve an LLM hosted on HuggingFace. + +```python +import flytekit as fl +from openai import OpenAI + +model_name = "google/gemma-2b-it" +hf_token_key = "vllm_hf_token" + +vllm_args = { + "model": model_name, + "dtype": "half", + "max-model-len": 2000, +} + +hf_secrets = HFSecret( + secrets_prefix="_FSEC_", + hf_token_key=hf_token_key +) + +vllm_instance = VLLM( + hf_secret=hf_secrets, + arg_dict=vllm_args +) + +image = fl.ImageSpec( + name="vllm_serve", + registry="...", + packages=["flytekitplugins-inference"], +) + + +@fl.task( + pod_template=vllm_instance.pod_template, + container_image=image, + secret_requests=[ + fl.Secret( + key=hf_token_key, mount_requirement=fl.Secret.MountType.ENV_VAR # must be mounted as an env var + ) + ], +) +def model_serving() -> str: + client = OpenAI( + base_url=f"{vllm_instance.base_url}/v1", api_key="vllm" # api key required but ignored + ) + + completion = client.chat.completions.create( + model=model_name, + messages=[ + { + "role": "user", + "content": "Compose a haiku about the power of AI.", + } + ], + temperature=0.5, + top_p=1, + max_tokens=1024, + ) + return completion.choices[0].message.content +``` diff --git a/plugins/flytekit-inference/flytekitplugins/inference/__init__.py b/plugins/flytekit-inference/flytekitplugins/inference/__init__.py index cfd14b09a8..8b43dd16a8 100644 --- a/plugins/flytekit-inference/flytekitplugins/inference/__init__.py +++ b/plugins/flytekit-inference/flytekitplugins/inference/__init__.py @@ -14,3 +14,4 @@ from .nim.serve import NIM, NIMSecrets from .ollama.serve import Model, Ollama +from .vllm.serve import VLLM, HFSecret diff --git a/plugins/flytekit-inference/flytekitplugins/inference/vllm/__init__.py b/plugins/flytekit-inference/flytekitplugins/inference/vllm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-inference/flytekitplugins/inference/vllm/serve.py b/plugins/flytekit-inference/flytekitplugins/inference/vllm/serve.py new file mode 100644 index 0000000000..f353aabda4 --- /dev/null +++ b/plugins/flytekit-inference/flytekitplugins/inference/vllm/serve.py @@ -0,0 +1,85 @@ +from dataclasses import dataclass +from typing import Optional + +from ..sidecar_template import ModelInferenceTemplate + + +@dataclass +class HFSecret: + """ + :param secrets_prefix: The secrets prefix that Flyte appends to all mounted secrets. + :param hf_token_group: The group name for the HuggingFace token. + :param hf_token_key: The key name for the HuggingFace token. + """ + + secrets_prefix: str # _UNION_ or _FSEC_ + hf_token_key: str + hf_token_group: Optional[str] = None + + +class VLLM(ModelInferenceTemplate): + def __init__( + self, + hf_secret: HFSecret, + arg_dict: Optional[dict] = None, + image: str = "vllm/vllm-openai", + health_endpoint: str = "/health", + port: int = 8000, + cpu: int = 2, + gpu: int = 1, + mem: str = "10Gi", + ): + """ + Initialize NIM class for managing a Kubernetes pod template. + + :param hf_secret: Instance of HFSecret for managing hugging face secrets. + :param arg_dict: A dictionary of arguments for the VLLM model server (https://docs.vllm.ai/en/stable/models/engine_args.html). + :param image: The Docker image to be used for the model server container. Default is "vllm/vllm-openai". + :param health_endpoint: The health endpoint for the model server container. Default is "/health". + :param port: The port number for the model server container. Default is 8000. + :param cpu: The number of CPU cores requested for the model server container. Default is 2. + :param gpu: The number of GPU cores requested for the model server container. Default is 1. + :param mem: The amount of memory requested for the model server container. Default is "10Gi". + """ + if hf_secret.hf_token_key is None: + raise ValueError("HuggingFace token key must be provided.") + if hf_secret.secrets_prefix is None: + raise ValueError("Secrets prefix must be provided.") + + self._hf_secret = hf_secret + self._arg_dict = arg_dict + + super().__init__( + image=image, + health_endpoint=health_endpoint, + port=port, + cpu=cpu, + gpu=gpu, + mem=mem, + ) + + self.setup_vllm_pod_template() + + def setup_vllm_pod_template(self): + from kubernetes.client.models import V1EnvVar + + model_server_container = self.pod_template.pod_spec.init_containers[0] + + if self._hf_secret.hf_token_group: + hf_key = f"$({self._hf_secret.secrets_prefix}{self._hf_secret.hf_token_group}_{self._hf_secret.hf_token_key})".upper() + else: + hf_key = f"$({self._hf_secret.secrets_prefix}{self._hf_secret.hf_token_key})".upper() + + model_server_container.env = [ + V1EnvVar(name="HUGGING_FACE_HUB_TOKEN", value=hf_key), + ] + model_server_container.args = self.build_vllm_args() + + def build_vllm_args(self) -> list: + args = [] + if self._arg_dict: + for key, value in self._arg_dict.items(): + args.append(f"--{key}") + if value is not None: + args.append(str(value)) + return args diff --git a/plugins/flytekit-inference/setup.py b/plugins/flytekit-inference/setup.py index c0f42a2e41..ef46849726 100644 --- a/plugins/flytekit-inference/setup.py +++ b/plugins/flytekit-inference/setup.py @@ -19,6 +19,7 @@ f"flytekitplugins.{PLUGIN_NAME}", f"flytekitplugins.{PLUGIN_NAME}.nim", f"flytekitplugins.{PLUGIN_NAME}.ollama", + f"flytekitplugins.{PLUGIN_NAME}.vllm", ], install_requires=plugin_requires, license="apache2", diff --git a/plugins/flytekit-inference/tests/test_vllm.py b/plugins/flytekit-inference/tests/test_vllm.py new file mode 100644 index 0000000000..e1a7901de5 --- /dev/null +++ b/plugins/flytekit-inference/tests/test_vllm.py @@ -0,0 +1,60 @@ +from flytekitplugins.inference import VLLM, HFSecret + + +def test_vllm_init_valid_params(): + vllm_args = { + "model": "google/gemma-2b-it", + "dtype": "half", + "max-model-len": 2000, + } + + hf_secrets = HFSecret( + secrets_prefix="_UNION_", + hf_token_key="vllm_hf_token" + ) + + vllm_instance = VLLM( + hf_secret=hf_secrets, + arg_dict=vllm_args, + image='vllm/vllm-openai:my-tag', + cpu='10', + gpu='2', + mem='50Gi', + port=8080, + ) + + assert len(vllm_instance.pod_template.pod_spec.init_containers) == 1 + assert ( + vllm_instance.pod_template.pod_spec.init_containers[0].image + == 'vllm/vllm-openai:my-tag' + ) + assert ( + vllm_instance.pod_template.pod_spec.init_containers[0].resources.requests[ + "memory" + ] + == "50Gi" + ) + assert ( + vllm_instance.pod_template.pod_spec.init_containers[0].ports[0].container_port + == 8080 + ) + assert vllm_instance.pod_template.pod_spec.init_containers[0].args == ['--model', 'google/gemma-2b-it', '--dtype', 'half', '--max-model-len', '2000'] + assert vllm_instance.pod_template.pod_spec.init_containers[0].env[0].name == 'HUGGING_FACE_HUB_TOKEN' + assert vllm_instance.pod_template.pod_spec.init_containers[0].env[0].value == '$(_UNION_VLLM_HF_TOKEN)' + + + +def test_vllm_default_params(): + vllm_instance = VLLM(hf_secret=HFSecret(secrets_prefix="_FSEC_", hf_token_key="test_token")) + + assert vllm_instance.base_url == "http://localhost:8000" + assert vllm_instance._image == 'vllm/vllm-openai' + assert vllm_instance._port == 8000 + assert vllm_instance._cpu == 2 + assert vllm_instance._gpu == 1 + assert vllm_instance._health_endpoint == "/health" + assert vllm_instance._mem == "10Gi" + assert vllm_instance._arg_dict == None + assert vllm_instance._hf_secret.secrets_prefix == '_FSEC_' + assert vllm_instance._hf_secret.hf_token_key == 'test_token' + assert vllm_instance._hf_secret.hf_token_group == None diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index 5d953350a0..f20a33aea8 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -14,7 +14,7 @@ from urllib.parse import urlparse import uuid import pytest -from mock import mock, patch +import mock from flytekit import LaunchPlan, kwtypes, WorkflowExecutionPhase from flytekit.configuration import Config, ImageConfig, SerializationSettings @@ -29,6 +29,9 @@ from flytekit.clients.friendly import SynchronousFlyteClient as _SynchronousFlyteClient from flytekit.configuration import PlatformConfig +from tests.flytekit.integration.remote.utils import SimpleFileTransfer + + MODULE_PATH = pathlib.Path(__file__).parent / "workflows/basic" CONFIG = os.environ.get("FLYTECTL_CONFIG", str(pathlib.Path.home() / ".flyte" / "config-sandbox.yaml")) # Run `make build-dev` to build and push the image to the local registry. @@ -111,6 +114,14 @@ def test_remote_eager_run(): # child_workflow.parent_wf asynchronously register a parent wf1 with child lp from another wf2. run("eager_example.py", "simple_eager_workflow", "--x", "3") +def test_pydantic_default_input_with_map_task(): + execution_id = run("pydantic_wf.py", "wf") + remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) + execution = remote.fetch_execution(name=execution_id) + execution = remote.wait(execution=execution, timeout=datetime.timedelta(minutes=5)) + print("Execution Error:", execution.error) + assert execution.closure.phase == WorkflowExecutionPhase.SUCCEEDED, f"Execution failed with phase: {execution.closure.phase}" + def test_generic_idl_flytetypes(): os.environ["FLYTE_USE_OLD_DC_FORMAT"] = "true" @@ -804,3 +815,21 @@ def test_get_control_plane_version(): client = _SynchronousFlyteClient(PlatformConfig.for_endpoint("localhost:30080", True)) version = client.get_control_plane_version() assert version == "unknown" or version.startswith("v") + + +def test_open_ff(): + """Test opening FlyteFile from a remote path.""" + # Upload a file to minio s3 bucket + file_transfer = SimpleFileTransfer() + remote_file_path = file_transfer.upload_file(file_type="json") + + execution_id = run("flytefile.py", "wf", "--remote_file_path", remote_file_path) + remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) + execution = remote.fetch_execution(name=execution_id) + execution = remote.wait(execution=execution, timeout=datetime.timedelta(minutes=5)) + assert execution.closure.phase == WorkflowExecutionPhase.SUCCEEDED, f"Execution failed with phase: {execution.closure.phase}" + + # Delete the remote file to free the space + url = urlparse(remote_file_path) + bucket, key = url.netloc, url.path.lstrip("/") + file_transfer.delete_file(bucket=bucket, key=key) diff --git a/tests/flytekit/integration/remote/utils.py b/tests/flytekit/integration/remote/utils.py new file mode 100644 index 0000000000..dadc8c6530 --- /dev/null +++ b/tests/flytekit/integration/remote/utils.py @@ -0,0 +1,98 @@ +""" +Common utilities for flyte remote runs in integration tests. +""" +import os +import json +import tempfile +import pathlib + +import botocore.session +from botocore.client import BaseClient +from flytekit.configuration import Config +from flytekit.remote.remote import FlyteRemote + + +# Define constants +CONFIG = os.environ.get("FLYTECTL_CONFIG", str(pathlib.Path.home() / ".flyte" / "config-sandbox.yaml")) +PROJECT = "flytesnacks" +DOMAIN = "development" + + +class SimpleFileTransfer: + """Utilities for file transfer to minio s3 bucket. + + Mainly support single file uploading and automatic teardown. + """ + + def __init__(self) -> None: + self._remote = FlyteRemote( + config=Config.auto(config_file=CONFIG), + default_project=PROJECT, + default_domain=DOMAIN + ) + self._s3_client = self._get_minio_s3_client(self._remote) + + def _get_minio_s3_client(self, remote: FlyteRemote) -> BaseClient: + """Creat a botocore client.""" + minio_s3_config = remote.file_access.data_config.s3 + sess = botocore.session.get_session() + + return sess.create_client( + "s3", + endpoint_url=minio_s3_config.endpoint, + aws_access_key_id=minio_s3_config.access_key_id, + aws_secret_access_key=minio_s3_config.secret_access_key, + ) + + def upload_file(self, file_type: str) -> str: + """Upload a single file to minio s3 bucket. + + Args: + file_type: File type. Support "txt" and "json". + + Returns: + remote_file_path: Remote file path. + """ + with tempfile.TemporaryDirectory() as tmp_dir: + local_file_path = self._dump_tmp_file(file_type, tmp_dir) + + # Upload to minio s3 bucket + _, remote_file_path = self._remote.upload_file( + to_upload=local_file_path, + project=PROJECT, + domain=DOMAIN, + ) + + return remote_file_path + + def _dump_tmp_file(self, file_type: str, tmp_dir: str) -> str: + """Generate and dump a temporary file locally. + + Args: + file_type: File type. + tmp_dir: Temporary directory. + + Returns: + tmp_file_path: Temporary local file path. + """ + if file_type == "txt": + tmp_file_path = pathlib.Path(tmp_dir) / "test.txt" + with open(tmp_file_path, "w") as f: + f.write("Hello World!") + elif file_type == "json": + d = {"name": "john", "height": 190} + tmp_file_path = pathlib.Path(tmp_dir) / "test.json" + with open(tmp_file_path, "w") as f: + json.dump(d, f) + + return tmp_file_path + + def delete_file(self, bucket: str, key: str) -> None: + """Delete the remote file from minio s3 bucket to free the space. + + Args: + bucket: s3 bucket name. + key: Key name of the object. + """ + res = self._s3_client.delete_object(Bucket=bucket, Key=key) + assert res["ResponseMetadata"]["HTTPStatusCode"] == 204 diff --git a/tests/flytekit/integration/remote/workflows/basic/flytefile.py b/tests/flytekit/integration/remote/workflows/basic/flytefile.py new file mode 100644 index 0000000000..f25b77d907 --- /dev/null +++ b/tests/flytekit/integration/remote/workflows/basic/flytefile.py @@ -0,0 +1,52 @@ +from flytekit import task, workflow +from flytekit.types.file import FlyteFile + + +@task +def create_ff(file_path: str) -> FlyteFile: + """Create a FlyteFile.""" + return FlyteFile(path=file_path) + + +@task +def read_ff(ff: FlyteFile) -> None: + """Read input FlyteFile. + + This can be used in the case in which a FlyteFile is created + in another task pod and read in this task pod. + """ + with open(ff, "r") as f: + content = f.read() + print(f"FILE CONTENT | {content}") + + +@task +def create_and_read_ff(file_path: str) -> FlyteFile: + """Create a FlyteFile and read it. + + Both FlyteFile creation and reading are done in this task pod. + + Args: + file_path: File path. + + Returns: + ff: FlyteFile object. + """ + ff = FlyteFile(path=file_path) + with open(ff, "r") as f: + content = f.read() + print(f"FILE CONTENT | {content}") + + return ff + + +@workflow +def wf(remote_file_path: str) -> None: + ff_1 = create_ff(file_path=remote_file_path) + read_ff(ff=ff_1) + ff_2 = create_and_read_ff(file_path=remote_file_path) + read_ff(ff=ff_2) + + +if __name__ == "__main__": + wf() diff --git a/tests/flytekit/integration/remote/workflows/basic/pydantic_wf.py b/tests/flytekit/integration/remote/workflows/basic/pydantic_wf.py new file mode 100644 index 0000000000..d5e9c32170 --- /dev/null +++ b/tests/flytekit/integration/remote/workflows/basic/pydantic_wf.py @@ -0,0 +1,20 @@ +from pydantic import BaseModel + +from flytekit import map_task +from typing import List +from flytekit import task, workflow + + +class MyBaseModel(BaseModel): + my_floats: List[float] = [1.0, 2.0, 5.0, 10.0] + +@task +def print_float(my_float: float): + print(f"my_float: {my_float}") + +@workflow +def wf(bm: MyBaseModel = MyBaseModel()): + map_task(print_float)(my_float=bm.my_floats) + +if __name__ == "__main__": + wf()