Merge branch 'master' into bugfix/pickleable-downloader
cosmicBboy authored Jan 4, 2025
2 parents 2060962 + 0ad84f3 commit 6d17e5c
Showing 12 changed files with 476 additions and 7 deletions.
11 changes: 9 additions & 2 deletions flytekit/clis/sdk_in_container/run.py
@@ -475,8 +475,15 @@ def to_click_option(
If no custom logic exists, fall back to json.dumps.
"""
with FlyteContextManager.with_context(flyte_ctx.new_builder()):
if hasattr(default_val, "model_dump_json"):
# pydantic v2
default_val = default_val.model_dump_json()
elif hasattr(default_val, "json"):
# pydantic v1
default_val = default_val.json()
else:
encoder = JSONEncoder(python_type)
default_val = encoder.encode(default_val)
if literal_var.type.metadata:
description_extra = f": {json.dumps(literal_var.type.metadata)}"
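The hunk above duck-types the default value rather than dispatching on the declared type. As a minimal, self-contained sketch of that fallback chain (with `json.dumps` standing in for flytekit's `JSONEncoder`, and the `Point` model purely illustrative):

```python
import json

from pydantic import BaseModel


def serialize_default(default_val) -> str:
    if hasattr(default_val, "model_dump_json"):
        # pydantic v2 models serialize themselves
        return default_val.model_dump_json()
    elif hasattr(default_val, "json"):
        # pydantic v1 models expose .json() instead
        return default_val.json()
    # plain values fall back to a JSON encoder
    return json.dumps(default_val)


class Point(BaseModel):
    x: int = 1
    y: int = 2


print(serialize_default(Point()))  # {"x":1,"y":2}
print(serialize_default([1, 2]))   # [1, 2]
```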

61 changes: 57 additions & 4 deletions flytekit/types/file/file.py
@@ -299,14 +299,53 @@ def __init__(
self._downloader = downloader
self._downloaded = False
self._remote_path = remote_path
self._remote_source: typing.Optional[typing.Union[str, os.PathLike]] = None

# Set up the local path and downloader for delayed downloading
# We introduce another attribute, self._local_path, to avoid overriding the user-defined self.path
self._local_path = self.path

ctx = FlyteContextManager.current_context()
if ctx.file_access.is_remote(self.path):
self._remote_source = self.path
self._local_path = ctx.file_access.get_random_local_path(self._remote_source)
self._downloader = lambda: FlyteFilePathTransformer.downloader(
ctx=ctx,
remote_path=self._remote_source, # type: ignore
local_path=self._local_path,
)

def __fspath__(self):
"""
Define the file path protocol for opening a FlyteFile with a context manager.
The following shows two common use cases:

1. Directly open a FlyteFile with a local path:

ff = FlyteFile(path=local_path)
with open(ff, "r") as f:
# Read your local file here
# ...

There is no need to download the file because it is already on the local
file system; in this case, a dummy (no-op) download is performed.

2. Directly open a FlyteFile with a remote path:

ff = FlyteFile(path=remote_path)
with open(ff, "r") as f:
# Read your remote file here
# ...

Opening a FlyteFile backed by remote data storage is now supported; in this
case, the remote file is downloaded lazily, on first access.

For details, please refer to this issue: https://github.com/flyteorg/flyte/issues/6090.
"""
if not self._downloaded:
# Download data from remote to local, or run the dummy downloader for a local input path
self._downloader()
self._downloaded = True
return self._local_path
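For readers outside the diff, here is a hedged, self-contained sketch of the lazy-download pattern `__fspath__` enables; `LazyFile` and its downloader are illustrative stand-ins, not flytekit API:

```python
import os
from pathlib import Path


class LazyFile(os.PathLike):
    """Illustrative stand-in for FlyteFile's delayed-download behavior."""

    def __init__(self, path: str, downloader):
        self._path = path
        self._downloader = downloader
        self._downloaded = False

    def __fspath__(self) -> str:
        # open() and anything else calling os.fspath() lands here,
        # so the download is deferred until first access.
        if not self._downloaded:
            self._downloader()
            self._downloaded = True
        return self._path


lf = LazyFile("/tmp/example.txt", lambda: Path("/tmp/example.txt").write_text("hi"))
with open(lf) as f:  # triggers __fspath__, which triggers the "download"
    print(f.read())  # hi
```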

def __eq__(self, other):
if isinstance(other, FlyteFile):
@@ -699,10 +738,24 @@ async def async_to_python_value(
_downloader = partial(_flyte_file_downloader, ctx.file_access, uri, local_path)

expected_format = FlyteFilePathTransformer.get_format(expected_python_type)
ff = FlyteFile.__class_getitem__(expected_format)(
path=local_path, downloader=lambda: self.downloader(ctx=ctx, remote_path=uri, local_path=local_path)
)
ff._remote_source = uri
return ff

@staticmethod
def downloader(
ctx: FlyteContext, remote_path: typing.Union[str, os.PathLike], local_path: typing.Union[str, os.PathLike]
) -> None:
"""
Download data from remote_path to local_path.
We design the downloader as a static method because its behavior is logically
related to this class but does not need to interact with class or instance data.
"""
ctx.file_access.get_data(remote_path, local_path, is_multipart=False)
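This is also the point of the bugfix/pickleable-downloader branch: a `functools.partial` over a static method can be pickled, because pickle resolves the function by qualified name, whereas a lambda closing over `ctx` cannot. A minimal sketch under that assumption (the `Transformer` class here is illustrative):

```python
import pickle
from functools import partial


class Transformer:
    @staticmethod
    def downloader(remote_path: str, local_path: str) -> None:
        # Stand-in for ctx.file_access.get_data(remote_path, local_path)
        print(f"downloading {remote_path} -> {local_path}")


dl = partial(Transformer.downloader, "s3://bucket/key", "/tmp/key")
restored = pickle.loads(pickle.dumps(dl))  # a lambda here would raise PicklingError
restored()  # downloading s3://bucket/key -> /tmp/key
```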

def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlyteFile[typing.Any]]:
if (
literal_type.blob is not None
63 changes: 63 additions & 0 deletions plugins/flytekit-inference/README.md
@@ -126,3 +126,66 @@ def model_serving(questions: list[str], gguf: FlyteFile) -> list[str]:

return responses
```

## vLLM

The vLLM plugin allows you to serve an LLM hosted on HuggingFace.

```python
import flytekit as fl
from flytekitplugins.inference import HFSecret, VLLM
from openai import OpenAI

model_name = "google/gemma-2b-it"
hf_token_key = "vllm_hf_token"

vllm_args = {
"model": model_name,
"dtype": "half",
"max-model-len": 2000,
}

hf_secrets = HFSecret(
secrets_prefix="_FSEC_",
hf_token_key=hf_token_key
)

vllm_instance = VLLM(
hf_secret=hf_secrets,
arg_dict=vllm_args
)

image = fl.ImageSpec(
name="vllm_serve",
registry="...",
packages=["flytekitplugins-inference"],
)


@fl.task(
pod_template=vllm_instance.pod_template,
container_image=image,
secret_requests=[
fl.Secret(
key=hf_token_key, mount_requirement=fl.Secret.MountType.ENV_VAR # must be mounted as an env var
)
],
)
def model_serving() -> str:
client = OpenAI(
base_url=f"{vllm_instance.base_url}/v1", api_key="vllm" # api key required but ignored
)

completion = client.chat.completions.create(
model=model_name,
messages=[
{
"role": "user",
"content": "Compose a haiku about the power of AI.",
}
],
temperature=0.5,
top_p=1,
max_tokens=1024,
)
return completion.choices[0].message.content
```
1 change: 1 addition & 0 deletions plugins/flytekit-inference/flytekitplugins/inference/__init__.py
@@ -14,3 +14,4 @@

from .nim.serve import NIM, NIMSecrets
from .ollama.serve import Model, Ollama
from .vllm.serve import VLLM, HFSecret
Empty file: plugins/flytekit-inference/flytekitplugins/inference/vllm/__init__.py
85 changes: 85 additions & 0 deletions plugins/flytekit-inference/flytekitplugins/inference/vllm/serve.py
@@ -0,0 +1,85 @@
from dataclasses import dataclass
from typing import Optional

from ..sidecar_template import ModelInferenceTemplate


@dataclass
class HFSecret:
"""
:param secrets_prefix: The secrets prefix that Flyte appends to all mounted secrets.
:param hf_token_group: The group name for the HuggingFace token.
:param hf_token_key: The key name for the HuggingFace token.
"""

secrets_prefix: str # _UNION_ or _FSEC_
hf_token_key: str
hf_token_group: Optional[str] = None


class VLLM(ModelInferenceTemplate):
def __init__(
self,
hf_secret: HFSecret,
arg_dict: Optional[dict] = None,
image: str = "vllm/vllm-openai",
health_endpoint: str = "/health",
port: int = 8000,
cpu: int = 2,
gpu: int = 1,
mem: str = "10Gi",
):
"""
Initialize the VLLM class for managing a Kubernetes pod template.
:param hf_secret: Instance of HFSecret for managing Hugging Face secrets.
:param arg_dict: A dictionary of arguments for the VLLM model server (https://docs.vllm.ai/en/stable/models/engine_args.html).
:param image: The Docker image to be used for the model server container. Default is "vllm/vllm-openai".
:param health_endpoint: The health endpoint for the model server container. Default is "/health".
:param port: The port number for the model server container. Default is 8000.
:param cpu: The number of CPU cores requested for the model server container. Default is 2.
:param gpu: The number of GPUs requested for the model server container. Default is 1.
:param mem: The amount of memory requested for the model server container. Default is "10Gi".
"""
if hf_secret.hf_token_key is None:
raise ValueError("HuggingFace token key must be provided.")
if hf_secret.secrets_prefix is None:
raise ValueError("Secrets prefix must be provided.")

self._hf_secret = hf_secret
self._arg_dict = arg_dict

super().__init__(
image=image,
health_endpoint=health_endpoint,
port=port,
cpu=cpu,
gpu=gpu,
mem=mem,
)

self.setup_vllm_pod_template()

def setup_vllm_pod_template(self):
from kubernetes.client.models import V1EnvVar

model_server_container = self.pod_template.pod_spec.init_containers[0]

if self._hf_secret.hf_token_group:
hf_key = f"$({self._hf_secret.secrets_prefix}{self._hf_secret.hf_token_group}_{self._hf_secret.hf_token_key})".upper()
else:
hf_key = f"$({self._hf_secret.secrets_prefix}{self._hf_secret.hf_token_key})".upper()

model_server_container.env = [
V1EnvVar(name="HUGGING_FACE_HUB_TOKEN", value=hf_key),
]
model_server_container.args = self.build_vllm_args()

def build_vllm_args(self) -> list:
args = []
if self._arg_dict:
for key, value in self._arg_dict.items():
args.append(f"--{key}")
if value is not None:
args.append(str(value))
return args
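To make the secret plumbing concrete, here is a small sketch mirroring `setup_vllm_pod_template` above (the helper name `hf_env_reference` is ours, not part of the plugin):

```python
from typing import Optional


def hf_env_reference(prefix: str, key: str, group: Optional[str] = None) -> str:
    # Flyte mounts the secret as an env var named <PREFIX>[<GROUP>_]<KEY>,
    # and the container references it via Kubernetes $(VAR) interpolation.
    name = f"{prefix}{group}_{key}" if group else f"{prefix}{key}"
    return f"$({name})".upper()


assert hf_env_reference("_FSEC_", "vllm_hf_token") == "$(_FSEC_VLLM_HF_TOKEN)"
assert hf_env_reference("_UNION_", "token", group="hf") == "$(_UNION_HF_TOKEN)"
```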
1 change: 1 addition & 0 deletions plugins/flytekit-inference/setup.py
@@ -19,6 +19,7 @@
f"flytekitplugins.{PLUGIN_NAME}",
f"flytekitplugins.{PLUGIN_NAME}.nim",
f"flytekitplugins.{PLUGIN_NAME}.ollama",
f"flytekitplugins.{PLUGIN_NAME}.vllm",
],
install_requires=plugin_requires,
license="apache2",
60 changes: 60 additions & 0 deletions plugins/flytekit-inference/tests/test_vllm.py
@@ -0,0 +1,60 @@
from flytekitplugins.inference import VLLM, HFSecret


def test_vllm_init_valid_params():
vllm_args = {
"model": "google/gemma-2b-it",
"dtype": "half",
"max-model-len": 2000,
}

hf_secrets = HFSecret(
secrets_prefix="_UNION_",
hf_token_key="vllm_hf_token"
)

vllm_instance = VLLM(
hf_secret=hf_secrets,
arg_dict=vllm_args,
image='vllm/vllm-openai:my-tag',
cpu='10',
gpu='2',
mem='50Gi',
port=8080,
)

assert len(vllm_instance.pod_template.pod_spec.init_containers) == 1
assert (
vllm_instance.pod_template.pod_spec.init_containers[0].image
== 'vllm/vllm-openai:my-tag'
)
assert (
vllm_instance.pod_template.pod_spec.init_containers[0].resources.requests[
"memory"
]
== "50Gi"
)
assert (
vllm_instance.pod_template.pod_spec.init_containers[0].ports[0].container_port
== 8080
)
assert vllm_instance.pod_template.pod_spec.init_containers[0].args == ['--model', 'google/gemma-2b-it', '--dtype', 'half', '--max-model-len', '2000']
assert vllm_instance.pod_template.pod_spec.init_containers[0].env[0].name == 'HUGGING_FACE_HUB_TOKEN'
assert vllm_instance.pod_template.pod_spec.init_containers[0].env[0].value == '$(_UNION_VLLM_HF_TOKEN)'



def test_vllm_default_params():
vllm_instance = VLLM(hf_secret=HFSecret(secrets_prefix="_FSEC_", hf_token_key="test_token"))

assert vllm_instance.base_url == "http://localhost:8000"
assert vllm_instance._image == 'vllm/vllm-openai'
assert vllm_instance._port == 8000
assert vllm_instance._cpu == 2
assert vllm_instance._gpu == 1
assert vllm_instance._health_endpoint == "/health"
assert vllm_instance._mem == "10Gi"
assert vllm_instance._arg_dict is None
assert vllm_instance._hf_secret.secrets_prefix == '_FSEC_'
assert vllm_instance._hf_secret.hf_token_key == 'test_token'
assert vllm_instance._hf_secret.hf_token_group is None
31 changes: 30 additions & 1 deletion tests/flytekit/integration/remote/test_remote.py
@@ -14,7 +14,7 @@
from urllib.parse import urlparse
import uuid
import pytest
import mock

from flytekit import LaunchPlan, kwtypes, WorkflowExecutionPhase
from flytekit.configuration import Config, ImageConfig, SerializationSettings
@@ -29,6 +29,9 @@
from flytekit.clients.friendly import SynchronousFlyteClient as _SynchronousFlyteClient
from flytekit.configuration import PlatformConfig

from tests.flytekit.integration.remote.utils import SimpleFileTransfer


MODULE_PATH = pathlib.Path(__file__).parent / "workflows/basic"
CONFIG = os.environ.get("FLYTECTL_CONFIG", str(pathlib.Path.home() / ".flyte" / "config-sandbox.yaml"))
# Run `make build-dev` to build and push the image to the local registry.
@@ -111,6 +114,14 @@ def test_remote_eager_run():
# child_workflow.parent_wf asynchronously register a parent wf1 with child lp from another wf2.
run("eager_example.py", "simple_eager_workflow", "--x", "3")

def test_pydantic_default_input_with_map_task():
execution_id = run("pydantic_wf.py", "wf")
remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)
execution = remote.fetch_execution(name=execution_id)
execution = remote.wait(execution=execution, timeout=datetime.timedelta(minutes=5))
print("Execution Error:", execution.error)
assert execution.closure.phase == WorkflowExecutionPhase.SUCCEEDED, f"Execution failed with phase: {execution.closure.phase}"


def test_generic_idl_flytetypes():
os.environ["FLYTE_USE_OLD_DC_FORMAT"] = "true"
@@ -804,3 +815,21 @@ def test_get_control_plane_version():
client = _SynchronousFlyteClient(PlatformConfig.for_endpoint("localhost:30080", True))
version = client.get_control_plane_version()
assert version == "unknown" or version.startswith("v")


def test_open_ff():
"""Test opening FlyteFile from a remote path."""
# Upload a file to minio s3 bucket
file_transfer = SimpleFileTransfer()
remote_file_path = file_transfer.upload_file(file_type="json")

execution_id = run("flytefile.py", "wf", "--remote_file_path", remote_file_path)
remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)
execution = remote.fetch_execution(name=execution_id)
execution = remote.wait(execution=execution, timeout=datetime.timedelta(minutes=5))
assert execution.closure.phase == WorkflowExecutionPhase.SUCCEEDED, f"Execution failed with phase: {execution.closure.phase}"

# Delete the remote file to free the space
url = urlparse(remote_file_path)
bucket, key = url.netloc, url.path.lstrip("/")
file_transfer.delete_file(bucket=bucket, key=key)
