Implement (chat_)completion for vllm provider
This is the start of an inline inference provider using vllm as a
library.

Issue #142
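For context, using vllm "as a library" means the provider loads the model and runs the engine in-process through vLLM's Python API instead of calling out to a separate vLLM server. A rough sketch of that in-process API (illustrative only, not the provider code; the model path and sampling values are placeholders, and the provider itself maps Llama Stack model descriptors to local checkpoints):

```
from vllm import LLM, SamplingParams

# Load the model in-process; tensor_parallel_size is the number of GPUs to shard across.
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", tensor_parallel_size=1)

outputs = llm.generate(
    ["Write me a 2 sentence poem about the moon"],
    SamplingParams(temperature=0.7, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```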

Working so far:

* `python -m llama_stack.apis.inference.client localhost 5000 --model Llama3.2-1B-Instruct --stream True`
* `python -m llama_stack.apis.inference.client localhost 5000 --model Llama3.2-1B-Instruct --stream False`

Example:

```
$ python -m llama_stack.apis.inference.client localhost 5000 --model Llama3.2-1B-Instruct --stream False
User>hello world, write me a 2 sentence poem about the moon
Assistant>
The moon glows bright in the midnight sky
A beacon of light,
```

I have only tested these models (see the config sketch after this list):

* `Llama3.1-8B-Instruct` - across 4 GPUs (tensor_parallel_size = 4)
* `Llama3.2-1B-Instruct` - on a single GPU (tensor_parallel_size = 1)
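For reference, a sketch of how these two setups map onto the `VLLMConfig` fields added in this commit (illustrative usage, not code from the commit):

```
from llama_stack.providers.impls.vllm.config import VLLMConfig

config_8b = VLLMConfig(model="Llama3.1-8B-Instruct", tensor_parallel_size=4)  # 4 GPUs
config_1b = VLLMConfig(model="Llama3.2-1B-Instruct", tensor_parallel_size=1)  # single GPU
```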

Signed-off-by: Russell Bryant <[email protected]>
russellb committed Oct 4, 2024
1 parent b261242 commit 2444737
Showing 5 changed files with 374 additions and 15 deletions.
5 changes: 5 additions & 0 deletions llama_stack/providers/adapters/inference/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
10 changes: 9 additions & 1 deletion llama_stack/providers/adapters/inference/ollama/ollama.py
@@ -63,7 +63,15 @@ async def completion(
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        raise NotImplementedError()
+        messages = [Message(role="user", content=content)]
+        async for result in self.chat_completion(
+            model=model,
+            messages=messages,
+            sampling_params=sampling_params,
+            stream=stream,
+            logprobs=logprobs,
+        ):
+            yield result
 
     def _messages_to_ollama_messages(self, messages: list[Message]) -> list:
         ollama_messages = []
34 changes: 32 additions & 2 deletions llama_stack/providers/impls/vllm/config.py
@@ -1,5 +1,35 @@
-from pydantic import BaseModel
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_models.schema_utils import json_schema_type
+from pydantic import BaseModel, Field, field_validator
+
+from llama_stack.providers.utils.inference import supported_inference_models
 
 
+@json_schema_type
 class VLLMConfig(BaseModel):
-    pass
+    """Configuration for the vLLM inference provider."""
+
+    model: str = Field(
+        default="Llama3.1-8B-Instruct",
+        description="Model descriptor from `llama model list`",
+    )
+    tensor_parallel_size: int = Field(
+        default=1,
+        description="Number of tensor parallel replicas (number of GPUs to use).",
+    )
+
+    @field_validator("model")
+    @classmethod
+    def validate_model(cls, model: str) -> str:
+        permitted_models = supported_inference_models()
+        if model not in permitted_models:
+            model_list = "\n\t".join(permitted_models)
+            raise ValueError(
+                f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
+            )
+        return model
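As an illustration of the validator above (a sketch, not part of the commit): constructing the config with a model outside `supported_inference_models()` fails at validation time, with pydantic surfacing the `ValueError` as a `ValidationError`.

```
from pydantic import ValidationError
from llama_stack.providers.impls.vllm.config import VLLMConfig

VLLMConfig()  # defaults: Llama3.1-8B-Instruct on a single GPU

try:
    # "not-a-real-model" is assumed to be absent from supported_inference_models()
    VLLMConfig(model="not-a-real-model")
except ValidationError as err:
    print(err)  # message includes "Unknown model: `not-a-real-model`. Choose from [...]"
```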
