Implement (chat_)completion for vllm provider
This is the start of an inline inference provider using vllm as a library.

Issue #142

Working so far:

* `python -m llama_stack.apis.inference.client localhost 5000 --model Llama3.2-1B-Instruct --stream True`
* `python -m llama_stack.apis.inference.client localhost 5000 --model Llama3.2-1B-Instruct --stream False`

Example:

```
$ python -m llama_stack.apis.inference.client localhost 5000 --model Llama3.2-1B-Instruct --stream False
User>hello world, write me a 2 sentence poem about the moon
Assistant> The moon glows bright in the midnight sky
A beacon of light,
```

I have only tested these models:

* `Llama3.1-8B-Instruct` - across 4 GPUs (tensor_parallel_size = 4)
* `Llama3.2-1B-Instruct` - on a single GPU (tensor_parallel_size = 1)

Signed-off-by: Russell Bryant <[email protected]>
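For context, here is a rough sketch of how the knobs mentioned above (a model descriptor plus `tensor_parallel_size`) map onto vLLM when it is used as a library. This is not the provider's actual code path, and the Hugging Face model id and sampling values are assumptions made purely for illustration:

```python
# Illustrative sketch only: vLLM's offline LLM API driven by the two settings
# the commit message mentions. NOT the provider implementation; the HF model id
# below is an assumed mapping for the "Llama3.2-1B-Instruct" descriptor.
from vllm import LLM, SamplingParams

# Single GPU, matching the Llama3.2-1B-Instruct test above (tensor_parallel_size = 1).
llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",  # assumed HF id for the descriptor
    tensor_parallel_size=1,
)

params = SamplingParams(temperature=0.7, max_tokens=64)
outputs = llm.generate(["Write me a 2 sentence poem about the moon."], params)
print(outputs[0].outputs[0].text)
```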
Showing 5 changed files with 374 additions and 15 deletions.
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
@@ -1,5 +1,35 @@
-from pydantic import BaseModel
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_models.schema_utils import json_schema_type
+from pydantic import BaseModel, Field, field_validator
+
+from llama_stack.providers.utils.inference import supported_inference_models
+
 
 @json_schema_type
 class VLLMConfig(BaseModel):
-    pass
+    """Configuration for the vLLM inference provider."""
+
+    model: str = Field(
+        default="Llama3.1-8B-Instruct",
+        description="Model descriptor from `llama model list`",
+    )
+    tensor_parallel_size: int = Field(
+        default=1,
+        description="Number of tensor parallel replicas (number of GPUs to use).",
+    )
+
+    @field_validator("model")
+    @classmethod
+    def validate_model(cls, model: str) -> str:
+        permitted_models = supported_inference_models()
+        if model not in permitted_models:
+            model_list = "\n\t".join(permitted_models)
+            raise ValueError(
+                f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
+            )
+        return model
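For reviewers, a minimal sketch of how this config behaves at construction time follows. It is not part of the change; the import path is a guess (the diff does not show the file name), and the accepted model descriptor assumes it appears in `supported_inference_models()`:

```python
# Illustrative only: exercising VLLMConfig's model validator.
# The import path below is an assumption, as is the accepted model name.
from pydantic import ValidationError

from llama_stack.providers.impls.vllm.config import VLLMConfig  # path assumed

# Valid descriptor (assuming it is listed by `llama model list`).
cfg = VLLMConfig(model="Llama3.2-1B-Instruct", tensor_parallel_size=1)
print(cfg.tensor_parallel_size)  # -> 1

# An unknown descriptor fails fast: the ValueError raised by validate_model()
# surfaces as a pydantic ValidationError listing the permitted models.
try:
    VLLMConfig(model="not-a-real-model")
except ValidationError as err:
    print(err)
```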