Remove request arg from chat completion response processing (#240)
Signed-off-by: Yuan Tang <[email protected]>
terrytangyuan authored Oct 15, 2024
1 parent 209cd3d commit 80ada04
Showing 7 changed files with 14 additions and 18 deletions.
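The change is the same in every provider adapter: the shared OpenAI-compat helpers no longer take the ChatCompletionRequest, because turning the provider's output into a ChatCompletionResponse only needs the raw response (or stream) and the ChatFormat used to decode it. Below is a minimal sketch of the resulting call pattern; the helper names and import path come from the diffs in this commit, while the ExampleAdapter class, its client object, and _get_params are simplified placeholders rather than code from the repository.

# Hedged sketch of the post-change call pattern. ExampleAdapter, its client,
# and _get_params are illustrative placeholders, not code from llama-stack.
from llama_stack.providers.utils.inference.openai_compat import (
    process_chat_completion_response,
    process_chat_completion_stream_response,
)


class ExampleAdapter:
    def __init__(self, client, formatter):
        self.client = client        # any OpenAI-compatible completions client
        self.formatter = formatter  # ChatFormat used to decode model output

    def _get_params(self, request, stream=False):
        # Placeholder: real adapters translate the ChatCompletionRequest into
        # provider-specific completion parameters here.
        return {"prompt": "...", "stream": stream}

    async def _nonstream_chat_completion(self, request):
        params = self._get_params(request)
        r = self.client.completions.create(**params)
        # Before this commit: process_chat_completion_response(request, r, self.formatter)
        return process_chat_completion_response(r, self.formatter)

    async def _stream_chat_completion(self, request):
        params = self._get_params(request, stream=True)

        async def _to_async_generator():
            # Wrap the provider's synchronous streaming iterator in an async generator.
            for chunk in self.client.completions.create(**params):
                yield chunk

        stream = _to_async_generator()
        # Before this commit: process_chat_completion_stream_response(request, stream, self.formatter)
        async for chunk in process_chat_completion_stream_response(stream, self.formatter):
            yield chunk
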
@@ -91,7 +91,7 @@ async def _nonstream_chat_completion(
    ) -> ChatCompletionResponse:
        params = self._get_params(request)
        r = client.completions.create(**params)
-       return process_chat_completion_response(request, r, self.formatter)
+       return process_chat_completion_response(r, self.formatter)

    async def _stream_chat_completion(
        self, request: ChatCompletionRequest, client: OpenAI
@@ -105,7 +105,7 @@ async def _to_async_generator():

        stream = _to_async_generator()
        async for chunk in process_chat_completion_stream_response(
-           request, stream, self.formatter
+           stream, self.formatter
        ):
            yield chunk

@@ -94,7 +94,7 @@ async def _nonstream_chat_completion(
    ) -> ChatCompletionResponse:
        params = self._get_params(request)
        r = await client.completion.acreate(**params)
-       return process_chat_completion_response(request, r, self.formatter)
+       return process_chat_completion_response(r, self.formatter)

    async def _stream_chat_completion(
        self, request: ChatCompletionRequest, client: Fireworks
@@ -103,7 +103,7 @@ async def _stream_chat_completion(

        stream = client.completion.acreate(**params)
        async for chunk in process_chat_completion_stream_response(
-           request, stream, self.formatter
+           stream, self.formatter
        ):
            yield chunk

4 changes: 2 additions & 2 deletions llama_stack/providers/adapters/inference/ollama/ollama.py
@@ -143,7 +143,7 @@ async def _nonstream_chat_completion(
        response = OpenAICompatCompletionResponse(
            choices=[choice],
        )
-       return process_chat_completion_response(request, response, self.formatter)
+       return process_chat_completion_response(response, self.formatter)

    async def _stream_chat_completion(
        self, request: ChatCompletionRequest
@@ -163,7 +163,7 @@ async def _generate_and_convert_to_openai_compat():

        stream = _generate_and_convert_to_openai_compat()
        async for chunk in process_chat_completion_stream_response(
-           request, stream, self.formatter
+           stream, self.formatter
        ):
            yield chunk

4 changes: 2 additions & 2 deletions llama_stack/providers/adapters/inference/tgi/tgi.py
@@ -116,7 +116,7 @@ async def _nonstream_chat_completion(
        response = OpenAICompatCompletionResponse(
            choices=[choice],
        )
-       return process_chat_completion_response(request, response, self.formatter)
+       return process_chat_completion_response(response, self.formatter)

    async def _stream_chat_completion(
        self, request: ChatCompletionRequest
@@ -135,7 +135,7 @@ async def _generate_and_convert_to_openai_compat():

        stream = _generate_and_convert_to_openai_compat()
        async for chunk in process_chat_completion_stream_response(
-           request, stream, self.formatter
+           stream, self.formatter
        ):
            yield chunk

4 changes: 2 additions & 2 deletions llama_stack/providers/adapters/inference/together/together.py
@@ -108,7 +108,7 @@ async def _nonstream_chat_completion(
    ) -> ChatCompletionResponse:
        params = self._get_params(request)
        r = client.completions.create(**params)
-       return process_chat_completion_response(request, r, self.formatter)
+       return process_chat_completion_response(r, self.formatter)

    async def _stream_chat_completion(
        self, request: ChatCompletionRequest, client: Together
@@ -123,7 +123,7 @@ async def _to_async_generator():

        stream = _to_async_generator()
        async for chunk in process_chat_completion_stream_response(
-           request, stream, self.formatter
+           stream, self.formatter
        ):
            yield chunk

4 changes: 2 additions & 2 deletions llama_stack/providers/impls/vllm/vllm.py
@@ -207,7 +207,7 @@ async def _nonstream_chat_completion(
        response = OpenAICompatCompletionResponse(
            choices=[choice],
        )
-       return process_chat_completion_response(request, response, self.formatter)
+       return process_chat_completion_response(response, self.formatter)

    async def _stream_chat_completion(
        self, request: ChatCompletionRequest, results_generator: AsyncGenerator
@@ -229,7 +229,7 @@ async def _generate_and_convert_to_openai_compat():

        stream = _generate_and_convert_to_openai_compat()
        async for chunk in process_chat_completion_stream_response(
-           request, stream, self.formatter
+           stream, self.formatter
        ):
            yield chunk

8 changes: 2 additions & 6 deletions llama_stack/providers/utils/inference/openai_compat.py
@@ -50,9 +50,7 @@ def text_from_choice(choice) -> str:


def process_chat_completion_response(
-   request: ChatCompletionRequest,
-   response: OpenAICompatCompletionResponse,
-   formatter: ChatFormat,
+   response: OpenAICompatCompletionResponse, formatter: ChatFormat
) -> ChatCompletionResponse:
    choice = response.choices[0]

@@ -78,9 +76,7 @@ def process_chat_completion_response(


async def process_chat_completion_stream_response(
-   request: ChatCompletionRequest,
-   stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
-   formatter: ChatFormat,
+   stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
) -> AsyncGenerator:
    yield ChatCompletionResponseStreamChunk(
        event=ChatCompletionResponseEvent(
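
For readability, here are the two helper signatures from the hunks above joined back together after the change. This is a reconstruction for reference, not a verbatim copy of the file: bodies are elided, and the type names come from elsewhere in openai_compat.py and are not shown in this diff.

# Reconstructed from the diff above; bodies elided. ChatFormat,
# ChatCompletionResponse, OpenAICompatCompletionResponse, and AsyncGenerator
# are defined or imported elsewhere in openai_compat.py.
def process_chat_completion_response(
    response: OpenAICompatCompletionResponse, formatter: ChatFormat
) -> ChatCompletionResponse:
    ...


async def process_chat_completion_stream_response(
    stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
) -> AsyncGenerator:
    ...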
