From 80ada04f768071aa2ffc8630b15a822d395c07db Mon Sep 17 00:00:00 2001 From: Yuan Tang Date: Tue, 15 Oct 2024 16:03:17 -0400 Subject: [PATCH] Remove request arg from chat completion response processing (#240) Signed-off-by: Yuan Tang --- .../providers/adapters/inference/databricks/databricks.py | 4 ++-- .../providers/adapters/inference/fireworks/fireworks.py | 4 ++-- llama_stack/providers/adapters/inference/ollama/ollama.py | 4 ++-- llama_stack/providers/adapters/inference/tgi/tgi.py | 4 ++-- .../providers/adapters/inference/together/together.py | 4 ++-- llama_stack/providers/impls/vllm/vllm.py | 4 ++-- llama_stack/providers/utils/inference/openai_compat.py | 8 ++------ 7 files changed, 14 insertions(+), 18 deletions(-) diff --git a/llama_stack/providers/adapters/inference/databricks/databricks.py b/llama_stack/providers/adapters/inference/databricks/databricks.py index 7e8263dbf0..1410511864 100644 --- a/llama_stack/providers/adapters/inference/databricks/databricks.py +++ b/llama_stack/providers/adapters/inference/databricks/databricks.py @@ -91,7 +91,7 @@ async def _nonstream_chat_completion( ) -> ChatCompletionResponse: params = self._get_params(request) r = client.completions.create(**params) - return process_chat_completion_response(request, r, self.formatter) + return process_chat_completion_response(r, self.formatter) async def _stream_chat_completion( self, request: ChatCompletionRequest, client: OpenAI @@ -105,7 +105,7 @@ async def _to_async_generator(): stream = _to_async_generator() async for chunk in process_chat_completion_stream_response( - request, stream, self.formatter + stream, self.formatter ): yield chunk diff --git a/llama_stack/providers/adapters/inference/fireworks/fireworks.py b/llama_stack/providers/adapters/inference/fireworks/fireworks.py index c85ee00f9f..c82012cba6 100644 --- a/llama_stack/providers/adapters/inference/fireworks/fireworks.py +++ b/llama_stack/providers/adapters/inference/fireworks/fireworks.py @@ -94,7 +94,7 @@ async def _nonstream_chat_completion( ) -> ChatCompletionResponse: params = self._get_params(request) r = await client.completion.acreate(**params) - return process_chat_completion_response(request, r, self.formatter) + return process_chat_completion_response(r, self.formatter) async def _stream_chat_completion( self, request: ChatCompletionRequest, client: Fireworks @@ -103,7 +103,7 @@ async def _stream_chat_completion( stream = client.completion.acreate(**params) async for chunk in process_chat_completion_stream_response( - request, stream, self.formatter + stream, self.formatter ): yield chunk diff --git a/llama_stack/providers/adapters/inference/ollama/ollama.py b/llama_stack/providers/adapters/inference/ollama/ollama.py index acf1546272..c50c869fd5 100644 --- a/llama_stack/providers/adapters/inference/ollama/ollama.py +++ b/llama_stack/providers/adapters/inference/ollama/ollama.py @@ -143,7 +143,7 @@ async def _nonstream_chat_completion( response = OpenAICompatCompletionResponse( choices=[choice], ) - return process_chat_completion_response(request, response, self.formatter) + return process_chat_completion_response(response, self.formatter) async def _stream_chat_completion( self, request: ChatCompletionRequest @@ -163,7 +163,7 @@ async def _generate_and_convert_to_openai_compat(): stream = _generate_and_convert_to_openai_compat() async for chunk in process_chat_completion_stream_response( - request, stream, self.formatter + stream, self.formatter ): yield chunk diff --git a/llama_stack/providers/adapters/inference/tgi/tgi.py b/llama_stack/providers/adapters/inference/tgi/tgi.py index 835649d942..cd0afad0cc 100644 --- a/llama_stack/providers/adapters/inference/tgi/tgi.py +++ b/llama_stack/providers/adapters/inference/tgi/tgi.py @@ -116,7 +116,7 @@ async def _nonstream_chat_completion( response = OpenAICompatCompletionResponse( choices=[choice], ) - return process_chat_completion_response(request, response, self.formatter) + return process_chat_completion_response(response, self.formatter) async def _stream_chat_completion( self, request: ChatCompletionRequest @@ -135,7 +135,7 @@ async def _generate_and_convert_to_openai_compat(): stream = _generate_and_convert_to_openai_compat() async for chunk in process_chat_completion_stream_response( - request, stream, self.formatter + stream, self.formatter ): yield chunk diff --git a/llama_stack/providers/adapters/inference/together/together.py b/llama_stack/providers/adapters/inference/together/together.py index 3231f4657a..750ca126e2 100644 --- a/llama_stack/providers/adapters/inference/together/together.py +++ b/llama_stack/providers/adapters/inference/together/together.py @@ -108,7 +108,7 @@ async def _nonstream_chat_completion( ) -> ChatCompletionResponse: params = self._get_params(request) r = client.completions.create(**params) - return process_chat_completion_response(request, r, self.formatter) + return process_chat_completion_response(r, self.formatter) async def _stream_chat_completion( self, request: ChatCompletionRequest, client: Together @@ -123,7 +123,7 @@ async def _to_async_generator(): stream = _to_async_generator() async for chunk in process_chat_completion_stream_response( - request, stream, self.formatter + stream, self.formatter ): yield chunk diff --git a/llama_stack/providers/impls/vllm/vllm.py b/llama_stack/providers/impls/vllm/vllm.py index e0b063ac97..5cdb1a2ab5 100644 --- a/llama_stack/providers/impls/vllm/vllm.py +++ b/llama_stack/providers/impls/vllm/vllm.py @@ -207,7 +207,7 @@ async def _nonstream_chat_completion( response = OpenAICompatCompletionResponse( choices=[choice], ) - return process_chat_completion_response(request, response, self.formatter) + return process_chat_completion_response(response, self.formatter) async def _stream_chat_completion( self, request: ChatCompletionRequest, results_generator: AsyncGenerator @@ -229,7 +229,7 @@ async def _generate_and_convert_to_openai_compat(): stream = _generate_and_convert_to_openai_compat() async for chunk in process_chat_completion_stream_response( - request, stream, self.formatter + stream, self.formatter ): yield chunk diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index 118880b29b..72db7b18cf 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -50,9 +50,7 @@ def text_from_choice(choice) -> str: def process_chat_completion_response( - request: ChatCompletionRequest, - response: OpenAICompatCompletionResponse, - formatter: ChatFormat, + response: OpenAICompatCompletionResponse, formatter: ChatFormat ) -> ChatCompletionResponse: choice = response.choices[0] @@ -78,9 +76,7 @@ def process_chat_completion_response( async def process_chat_completion_stream_response( - request: ChatCompletionRequest, - stream: AsyncGenerator[OpenAICompatCompletionResponse, None], - formatter: ChatFormat, + stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat ) -> AsyncGenerator: yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent(