Inline vLLM inference provider (#181)
This is just like `local`: it uses `meta-reference` for everything except
inference, where it uses `vllm`.

Docker works, but so far `conda` is easier to use with the vllm
provider. The default container base image does not include all the
libraries needed for every vLLM feature; additional CUDA dependencies
are required.

I started changing the base image used in this template, but that also
required changes to the Dockerfile, so it was getting too involved to
include in this first PR.

Working so far:

* `python -m llama_stack.apis.inference.client localhost 5000 --model Llama3.2-1B-Instruct --stream True`
* `python -m llama_stack.apis.inference.client localhost 5000 --model Llama3.2-1B-Instruct --stream False`

Example:

```
$ python -m llama_stack.apis.inference.client localhost 5000 --model Llama3.2-1B-Instruct --stream False
User>hello world, write me a 2 sentence poem about the moon
Assistant>
The moon glows bright in the midnight sky
A beacon of light,
```

I have only tested these models:

* `Llama3.1-8B-Instruct` - across 4 GPUs (tensor_parallel_size = 4)
* `Llama3.2-1B-Instruct` - on a single GPU (tensor_parallel_size = 1)
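
These two setups map directly onto the `model` and `tensor_parallel_size` fields of the new `VLLMConfig` (see `config.py` below). A minimal sketch of constructing the config for the 4-GPU case, using the values from this commit message:

```
from llama_stack.providers.impls.vllm.config import VLLMConfig

# Mirrors the Llama3.1-8B-Instruct test above; tensor_parallel_size is the
# number of GPUs vLLM shards the model across.
config = VLLMConfig(model="Llama3.1-8B-Instruct", tensor_parallel_size=4)
```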
russellb authored Oct 6, 2024
1 parent 29138a5 commit f73e247
Showing 5 changed files with 421 additions and 0 deletions.
10 changes: 10 additions & 0 deletions llama_stack/distribution/templates/local-vllm-build.yaml
@@ -0,0 +1,10 @@
name: local-vllm
distribution_spec:
  description: Like local, but use vLLM for running LLM inference
  providers:
    inference: vllm
    memory: meta-reference
    safety: meta-reference
    agents: meta-reference
    telemetry: meta-reference
image_type: conda
11 changes: 11 additions & 0 deletions llama_stack/providers/impls/vllm/__init__.py
@@ -0,0 +1,11 @@
from typing import Any

from .config import VLLMConfig


async def get_provider_impl(config: VLLMConfig, _deps) -> Any:
    from .vllm import VLLMInferenceImpl

    impl = VLLMInferenceImpl(config)
    await impl.initialize()
    return impl
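
For illustration, a hedged sketch of how a caller could obtain the inference implementation through this entry point (in practice the stack's provider registry does this wiring; the snippet below is an assumption, not code from this PR):

```
import asyncio

from llama_stack.providers.impls.vllm import get_provider_impl
from llama_stack.providers.impls.vllm.config import VLLMConfig


async def main() -> None:
    config = VLLMConfig(model="Llama3.2-1B-Instruct", tensor_parallel_size=1)
    # Loads the vLLM engine and returns the initialized inference impl.
    impl = await get_provider_impl(config, {})
    print(type(impl).__name__)  # VLLMInferenceImpl


asyncio.run(main())
```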
35 changes: 35 additions & 0 deletions llama_stack/providers/impls/vllm/config.py
@@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field, field_validator

from llama_stack.providers.utils.inference import supported_inference_models


@json_schema_type
class VLLMConfig(BaseModel):
    """Configuration for the vLLM inference provider."""

    model: str = Field(
        default="Llama3.1-8B-Instruct",
        description="Model descriptor from `llama model list`",
    )
    tensor_parallel_size: int = Field(
        default=1,
        description="Number of tensor parallel replicas (number of GPUs to use).",
    )

    @field_validator("model")
    @classmethod
    def validate_model(cls, model: str) -> str:
        permitted_models = supported_inference_models()
        if model not in permitted_models:
            model_list = "\n\t".join(permitted_models)
            raise ValueError(
                f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
            )
        return model
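
Because `validate_model` runs as a pydantic `field_validator`, an unsupported model descriptor fails at config-construction time (the `ValueError` surfaces wrapped in a `ValidationError`) rather than when the vLLM engine loads. A small sketch, with a made-up model name:

```
from pydantic import ValidationError

from llama_stack.providers.impls.vllm.config import VLLMConfig

try:
    VLLMConfig(model="not-a-llama-model")  # hypothetical, unsupported descriptor
except ValidationError as e:
    # The message lists the permitted models from `llama model list`.
    print(e)
```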
(The remaining 2 changed files are not shown here.)
