Commit 6e18080
Signed-off-by: Yuan Tang <[email protected]>
1 parent: c02a90e
Showing 6 changed files with 328 additions and 0 deletions.
@@ -13,3 +13,4 @@ xcuserdata/
Package.resolved
*.pte
*.ipynb_checkpoints*
.idea
@@ -0,0 +1,10 @@
name: local-vllm
distribution_spec:
  description: Use vLLM for running LLM inference
  providers:
    inference: remote::vllm
    memory: meta-reference
    safety: meta-reference
    agents: meta-reference
    telemetry: meta-reference
image_type: conda
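The remote::vllm provider in this spec expects an OpenAI-compatible vLLM endpoint to already be running. As a rough sketch (not part of this commit; the model id and port are placeholders), such a server can be launched via vLLM's OpenAI-compatible entrypoint:

# Sketch only: start vLLM's OpenAI-compatible server for the provider to call.
# Model id and port are placeholders, not values taken from this commit.
import subprocess

subprocess.run(
    [
        "python",
        "-m",
        "vllm.entrypoints.openai.api_server",
        "--model", "meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
        "--port", "8000",
    ],
    check=True,
)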
@@ -0,0 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .config import VLLMImplConfig
from .vllm import VLLMInferenceAdapter


async def get_adapter_impl(config: VLLMImplConfig, _deps):
    assert isinstance(
        config, VLLMImplConfig
    ), f"Unexpected config type: {type(config)}"

    if config.url is None:
        raise ValueError(
            "Invalid configuration. Specify the URL of the vLLM model serving endpoint."
        )

    impl = VLLMInferenceAdapter(config)
    await impl.initialize()
    return impl
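For illustration only (not part of the commit), the factory above could be driven directly like this, assuming it is imported from the provider package; the URL and token are placeholders:

# Hypothetical driver for get_adapter_impl; all concrete values are placeholders.
import asyncio

async def main() -> None:
    config = VLLMImplConfig(url="http://localhost:8000/v1", api_token="fake")
    impl = await get_adapter_impl(config, {})  # _deps is unused here
    # ... issue inference requests against impl ...
    await impl.shutdown()

asyncio.run(main())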
@@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Optional

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field


# TODO: Any other engine configs
@json_schema_type
class VLLMImplConfig(BaseModel):
    url: Optional[str] = Field(
        default=None,
        description="The URL for the vLLM model serving endpoint",
    )
    api_token: Optional[str] = Field(
        default=None,
        description="The API token",
    )
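Both fields default to None, so a usable config has to supply at least the endpoint URL. A small sketch (not from the commit; URL and token are placeholders for a locally running OpenAI-compatible vLLM server):

# Placeholder values; the token only matters if the server enforces one.
config = VLLMImplConfig(
    url="http://localhost:8000/v1",
    api_token="fake",
)
assert config.url is not None  # the adapter factory rejects a missing URL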
@@ -0,0 +1,262 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import AsyncGenerator

from llama_models.llama3.api.chat_format import ChatFormat

from llama_models.llama3.api.datatypes import Message, StopReason
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model

from openai import OpenAI

from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.providers.utils.inference.prepare_messages import prepare_messages

from .config import VLLMImplConfig

# TODO
VLLM_SUPPORTED_MODELS = {}
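# NOTE: this mapping is expected to go from a Llama model descriptor to the model
# name served by vLLM, e.g. "Llama3.1-8B-Instruct" -> "meta-llama/Llama-3.1-8B-Instruct"
# (illustrative values only, not from this commit). While it stays empty,
# resolve_vllm_model() below rejects every model name it is given.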


class VLLMInferenceAdapter(Inference):
    def __init__(self, config: VLLMImplConfig) -> None:
        self.config = config
        tokenizer = Tokenizer.get_instance()
        self.formatter = ChatFormat(tokenizer)

    @property
    def client(self) -> OpenAI:
        return OpenAI(
            api_key=self.config.api_token,
            base_url=self.config.url,
        )

    async def initialize(self) -> None:
        return

    async def shutdown(self) -> None:
        pass

    async def completion(self, request: CompletionRequest) -> AsyncGenerator:
        raise NotImplementedError()

    def _messages_to_vllm_messages(self, messages: list[Message]) -> list:
        vllm_messages = []
        for message in messages:
            if message.role == "ipython":
                role = "tool"
            else:
                role = message.role
            vllm_messages.append({"role": role, "content": message.content})

        return vllm_messages

    def resolve_vllm_model(self, model_name: str) -> str:
        model = resolve_model(model_name)
        assert (
            model is not None
            and model.descriptor(shorten_default_variant=True)
            in VLLM_SUPPORTED_MODELS
        ), f"Unsupported model: {model_name}, use one of the supported models: {','.join(VLLM_SUPPORTED_MODELS.keys())}"

        return VLLM_SUPPORTED_MODELS.get(
            model.descriptor(shorten_default_variant=True)
        )

    def get_vllm_chat_options(self, request: ChatCompletionRequest) -> dict:
        options = {}
        # TODO
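        # (Illustrative, not from the commit) the natural candidates here are the
        # request's sampling params forwarded as OpenAI-style kwargs, e.g.
        # options["temperature"] = request.sampling_params.temperature and
        # options["top_p"] = request.sampling_params.top_p, when they are set.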
        return options

    async def chat_completion(
        self,
        model: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        # wrapper request to make it easier to pass around (internal only, not exposed to API)
        request = ChatCompletionRequest(
            model=model,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            stream=stream,
            logprobs=logprobs,
        )

        # accumulate sampling params and other options to pass to vLLM
        options = self.get_vllm_chat_options(request)
        vllm_model = self.resolve_vllm_model(request.model)
        messages = prepare_messages(request)
        model_input = self.formatter.encode_dialog_prompt(messages)

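        # NOTE: self.max_tokens (used just below) and self.model_name (used in the
        # assert further down) are never assigned in this adapter, so chat_completion
        # will raise AttributeError unless they are set elsewhere; they appear to be
        # carried over from a similar adapter that defines them during initialization.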
        input_tokens = len(model_input.tokens)
        max_new_tokens = min(
            request.sampling_params.max_tokens or (self.max_tokens - input_tokens),
            self.max_tokens - input_tokens - 1,
        )

        print(f"Calculated max_new_tokens: {max_new_tokens}")

        assert (
            request.model == self.model_name
        ), f"Model mismatch, expected {self.model_name}, got {request.model}"

        if not request.stream:
            r = self.client.chat.completions.create(
                model=vllm_model,
                messages=self._messages_to_vllm_messages(messages),
                max_tokens=max_new_tokens,
                stream=False,
                **options,
            )
            stop_reason = None
            if r.choices[0].finish_reason:
                if (
                    r.choices[0].finish_reason == "stop"
                    or r.choices[0].finish_reason == "eos"
                ):
                    stop_reason = StopReason.end_of_turn
                elif r.choices[0].finish_reason == "length":
                    stop_reason = StopReason.out_of_tokens

            completion_message = self.formatter.decode_assistant_message_from_content(
                r.choices[0].message.content, stop_reason
            )
            yield ChatCompletionResponse(
                completion_message=completion_message,
                logprobs=None,
            )
        else:
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.start,
                    delta="",
                )
            )

            buffer = ""
            ipython = False
            stop_reason = None

            for chunk in self.client.chat.completions.create(
                model=vllm_model,
                messages=self._messages_to_vllm_messages(messages),
                max_tokens=max_new_tokens,
                stream=True,
                **options,
            ):
                if chunk.choices[0].finish_reason:
                    if (
                        stop_reason is None and chunk.choices[0].finish_reason == "stop"
                    ) or (
                        stop_reason is None and chunk.choices[0].finish_reason == "eos"
                    ):
                        stop_reason = StopReason.end_of_turn
                    elif (
                        stop_reason is None
                        and chunk.choices[0].finish_reason == "length"
                    ):
                        stop_reason = StopReason.out_of_tokens
                    break

                text = chunk.choices[0].delta.content
                if text is None:
                    continue

                # check if it's a tool call ( aka starts with <|python_tag|> )
                if not ipython and text.startswith("<|python_tag|>"):
                    ipython = True
                    yield ChatCompletionResponseStreamChunk(
                        event=ChatCompletionResponseEvent(
                            event_type=ChatCompletionResponseEventType.progress,
                            delta=ToolCallDelta(
                                content="",
                                parse_status=ToolCallParseStatus.started,
                            ),
                        )
                    )
                    buffer += text
                    continue

                if ipython:
                    if text == "<|eot_id|>":
                        stop_reason = StopReason.end_of_turn
                        text = ""
                        continue
                    elif text == "<|eom_id|>":
                        stop_reason = StopReason.end_of_message
                        text = ""
                        continue

                    buffer += text
                    delta = ToolCallDelta(
                        content=text,
                        parse_status=ToolCallParseStatus.in_progress,
                    )

                    yield ChatCompletionResponseStreamChunk(
                        event=ChatCompletionResponseEvent(
                            event_type=ChatCompletionResponseEventType.progress,
                            delta=delta,
                            stop_reason=stop_reason,
                        )
                    )
                else:
                    buffer += text
                    yield ChatCompletionResponseStreamChunk(
                        event=ChatCompletionResponseEvent(
                            event_type=ChatCompletionResponseEventType.progress,
                            delta=text,
                            stop_reason=stop_reason,
                        )
                    )

            # parse tool calls and report errors
            message = self.formatter.decode_assistant_message_from_content(
                buffer, stop_reason
            )
            parsed_tool_calls = len(message.tool_calls) > 0
            if ipython and not parsed_tool_calls:
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.progress,
                        delta=ToolCallDelta(
                            content="",
                            parse_status=ToolCallParseStatus.failure,
                        ),
                        stop_reason=stop_reason,
                    )
                )

            for tool_call in message.tool_calls:
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.progress,
                        delta=ToolCallDelta(
                            content=tool_call,
                            parse_status=ToolCallParseStatus.success,
                        ),
                        stop_reason=stop_reason,
                    )
                )

            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.complete,
                    delta="",
                    stop_reason=stop_reason,
                )
            )
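To tie the pieces together, here is an illustrative sketch (not from the commit) of driving the adapter end to end. The model name, URL, and token are placeholders; it assumes VLLM_SUPPORTED_MODELS has an entry for the chosen model and that the attributes noted above (self.max_tokens, self.model_name) have been defined.

# Illustrative driver for VLLMInferenceAdapter; all concrete values are placeholders.
import asyncio

from llama_models.llama3.api.datatypes import UserMessage


async def main() -> None:
    config = VLLMImplConfig(url="http://localhost:8000/v1", api_token="fake")
    adapter = VLLMInferenceAdapter(config)
    await adapter.initialize()

    # Stream a single-turn chat completion and print the text deltas
    # (deltas are plain strings unless a tool call is being parsed).
    async for chunk in adapter.chat_completion(
        model="Llama3.1-8B-Instruct",  # must exist in VLLM_SUPPORTED_MODELS
        messages=[UserMessage(content="Hello!")],
        stream=True,
    ):
        print(chunk.event.delta, end="")

    await adapter.shutdown()


asyncio.run(main())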