Make all methods async def again; add completion() for meta-reference (#270)

PR #201 made several changes while trying to get the stream=False branches of the inference and agents APIs working. As part of this, it made one slightly gratuitous change: turning chat_completion() and its brethren into plain "def" instead of "async def".

The rationale was that this allowed callers within llama-stack to use it as:

```
async for chunk in api.chat_completion(params)
```

However, it caused unnecessary confusion for several folks. Given that clients (e.g., llama-stack-apps) use the SDK methods (which are completely isolated) anyway, this choice was not ideal. Let's revert so the call now looks like:

```
async for chunk in await api.chat_completion(params)
```
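
For context, a minimal sketch (illustrative names only, not the actual llama-stack implementation) of why both `await` and `async for` appear: the method is `async def`, and when `stream=True` it returns an async generator instead of iterating it itself.

```
from typing import Any, AsyncGenerator, Union

class InferenceAPI:
    async def chat_completion(
        self, params: dict, stream: bool = False
    ) -> Union[Any, AsyncGenerator]:
        if stream:
            # Return the async generator unawaited; the caller drives it
            # with `async for`.
            return self._stream_chat_completion(params)
        # Non-streaming: await the helper and return the full response.
        return await self._nonstream_chat_completion(params)

    async def _stream_chat_completion(self, params: dict) -> AsyncGenerator:
        for token in ["Hello", " world"]:  # stand-in for real token streaming
            yield token

    async def _nonstream_chat_completion(self, params: dict) -> str:
        return "Hello world"

# Caller side:
#   async for chunk in await api.chat_completion(params, stream=True): ...
#   response = await api.chat_completion(params, stream=False)
```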

Bonus: added a completion() implementation for the meta-reference provider. Technically, this should have been a separate PR :)
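
A rough usage sketch of the non-streaming path of that completion() (hedged: the caller object and model identifier are illustrative, not taken from this commit); the response fields match the updated CompletionResponse schema in the diff below.

```
# Hypothetical call against an object implementing the Inference API;
# the model identifier is illustrative.
response = await inference_api.completion(
    model="Llama3.1-8B-Instruct",
    content="The capital of France is",
    stream=False,
)
# CompletionResponse now carries raw text plus a stop reason instead of a
# full CompletionMessage.
print(response.content)      # e.g. " Paris"
print(response.stop_reason)  # a StopReason value such as end_of_turn
```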
ashwinb authored Oct 19, 2024
1 parent 95a96af commit 2089427
Showing 23 changed files with 307 additions and 190 deletions.
40 changes: 22 additions & 18 deletions docs/resources/llama-stack-spec.html
@@ -21,7 +21,7 @@
"info": {
"title": "[DRAFT] Llama Stack Specification",
"version": "0.0.1",
"description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-10-10 15:29:56.831109"
"description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-10-18 20:48:17.730988"
},
"servers": [
{
@@ -2830,8 +2830,11 @@
"CompletionResponse": {
"type": "object",
"properties": {
"completion_message": {
"$ref": "#/components/schemas/CompletionMessage"
"content": {
"type": "string"
},
"stop_reason": {
"$ref": "#/components/schemas/StopReason"
},
"logprobs": {
"type": "array",
@@ -2842,7 +2845,8 @@
},
"additionalProperties": false,
"required": [
"completion_message"
"content",
"stop_reason"
],
"title": "Completion response."
},
@@ -6075,49 +6079,49 @@
],
"tags": [
{
"name": "Evaluations"
},
{
"name": "Inspect"
"name": "Models"
},
{
"name": "RewardScoring"
},
{
"name": "Datasets"
"name": "MemoryBanks"
},
{
"name": "Models"
"name": "Shields"
},
{
"name": "Telemetry"
"name": "SyntheticDataGeneration"
},
{
"name": "PostTraining"
"name": "Inference"
},
{
"name": "SyntheticDataGeneration"
"name": "Inspect"
},
{
"name": "BatchInference"
},
{
"name": "Inference"
"name": "Memory"
},
{
"name": "Datasets"
},
{
"name": "Agents"
},
{
"name": "Memory"
"name": "PostTraining"
},
{
"name": "Safety"
"name": "Telemetry"
},
{
"name": "Shields"
"name": "Safety"
},
{
"name": "MemoryBanks"
"name": "Evaluations"
},
{
"name": "BuiltinTool",
31 changes: 17 additions & 14 deletions docs/resources/llama-stack-spec.yaml
@@ -501,14 +501,17 @@ components:
CompletionResponse:
additionalProperties: false
properties:
completion_message:
$ref: '#/components/schemas/CompletionMessage'
content:
type: string
logprobs:
items:
$ref: '#/components/schemas/TokenLogProbs'
type: array
stop_reason:
$ref: '#/components/schemas/StopReason'
required:
- completion_message
- content
- stop_reason
title: Completion response.
type: object
CompletionResponseStreamChunk:
@@ -2507,7 +2510,7 @@ info:
description: "This is the specification of the llama stack that provides\n \
\ a set of endpoints and their corresponding interfaces that are tailored\
\ to\n best leverage Llama Models. The specification is still in\
\ draft and subject to change.\n Generated at 2024-10-10 15:29:56.831109"
\ draft and subject to change.\n Generated at 2024-10-18 20:48:17.730988"
title: '[DRAFT] Llama Stack Specification'
version: 0.0.1
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
@@ -3712,21 +3715,21 @@ security:
servers:
- url: http://any-hosted-llama-stack.com
tags:
- name: Evaluations
- name: Inspect
- name: RewardScoring
- name: Datasets
- name: Models
- name: Telemetry
- name: PostTraining
- name: RewardScoring
- name: MemoryBanks
- name: Shields
- name: SyntheticDataGeneration
- name: BatchInference
- name: Inference
- name: Agents
- name: Inspect
- name: BatchInference
- name: Memory
- name: Datasets
- name: Agents
- name: PostTraining
- name: Telemetry
- name: Safety
- name: Shields
- name: MemoryBanks
- name: Evaluations
- description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" />
name: BuiltinTool
- description: <SchemaDefinition schemaRef="#/components/schemas/CompletionMessage"
4 changes: 1 addition & 3 deletions llama_stack/apis/agents/agents.py
@@ -421,10 +421,8 @@ async def create_agent(
agent_config: AgentConfig,
) -> AgentCreateResponse: ...

# This method is not `async def` because it can result in either an
# `AsyncGenerator` or a `AgentTurnCreateResponse` depending on the value of `stream`.
@webmethod(route="/agents/turn/create")
def create_agent_turn(
async def create_agent_turn(
self,
agent_id: str,
session_id: str,
6 changes: 3 additions & 3 deletions llama_stack/apis/agents/client.py
@@ -67,14 +67,14 @@ async def create_agent_session(
response.raise_for_status()
return AgentSessionCreateResponse(**response.json())

def create_agent_turn(
async def create_agent_turn(
self,
request: AgentTurnCreateRequest,
) -> AsyncGenerator:
if request.stream:
return self._stream_agent_turn(request)
else:
return self._nonstream_agent_turn(request)
return await self._nonstream_agent_turn(request)

async def _stream_agent_turn(
self, request: AgentTurnCreateRequest
@@ -126,7 +126,7 @@ async def _run_agent(

for content in user_prompts:
cprint(f"User> {content}", color="white", attrs=["bold"])
iterator = api.create_agent_turn(
iterator = await api.create_agent_turn(
AgentTurnCreateRequest(
agent_id=create_response.agent_id,
session_id=session_response.session_id,
7 changes: 4 additions & 3 deletions llama_stack/apis/inference/client.py
@@ -42,10 +42,10 @@ async def initialize(self) -> None:
async def shutdown(self) -> None:
pass

def completion(self, request: CompletionRequest) -> AsyncGenerator:
async def completion(self, request: CompletionRequest) -> AsyncGenerator:
raise NotImplementedError()

def chat_completion(
async def chat_completion(
self,
model: str,
messages: List[Message],
@@ -139,7 +139,7 @@ async def run_main(
else:
logprobs_config = None

iterator = client.chat_completion(
assert stream, "Non streaming not supported here"
iterator = await client.chat_completion(
model=model,
messages=[message],
stream=stream,
13 changes: 6 additions & 7 deletions llama_stack/apis/inference/inference.py
@@ -88,7 +88,8 @@ class CompletionRequest(BaseModel):
class CompletionResponse(BaseModel):
"""Completion response."""

completion_message: CompletionMessage
content: str
stop_reason: StopReason
logprobs: Optional[List[TokenLogProbs]] = None


@@ -113,7 +114,7 @@ class BatchCompletionRequest(BaseModel):
class BatchCompletionResponse(BaseModel):
"""Batch completion response."""

completion_message_batch: List[CompletionMessage]
batch: List[CompletionResponse]


@json_schema_type
@@ -165,7 +166,7 @@ class BatchChatCompletionRequest(BaseModel):

@json_schema_type
class BatchChatCompletionResponse(BaseModel):
completion_message_batch: List[CompletionMessage]
batch: List[ChatCompletionResponse]


@json_schema_type
@@ -181,10 +182,8 @@ def get_model(self, identifier: str) -> ModelDef: ...
class Inference(Protocol):
model_store: ModelStore

# This method is not `async def` because it can result in either an
# `AsyncGenerator` or a `CompletionResponse` depending on the value of `stream`.
@webmethod(route="/inference/completion")
def completion(
async def completion(
self,
model: str,
content: InterleavedTextMedia,
@@ -196,7 +195,7 @@ def completion(
# This method is not `async def` because it can result in either an
# `AsyncGenerator` or a `ChatCompletionResponse` depending on the value of `stream`.
@webmethod(route="/inference/chat_completion")
def chat_completion(
async def chat_completion(
self,
model: str,
messages: List[Message],
12 changes: 6 additions & 6 deletions llama_stack/distribution/routers/routers.py
@@ -70,7 +70,7 @@ async def shutdown(self) -> None:
async def register_model(self, model: ModelDef) -> None:
await self.routing_table.register_model(model)

def chat_completion(
async def chat_completion(
self,
model: str,
messages: List[Message],
@@ -93,11 +93,11 @@ def chat_completion(
)
provider = self.routing_table.get_provider_impl(model)
if stream:
return (chunk async for chunk in provider.chat_completion(**params))
return (chunk async for chunk in await provider.chat_completion(**params))
else:
return provider.chat_completion(**params)
return await provider.chat_completion(**params)

def completion(
async def completion(
self,
model: str,
content: InterleavedTextMedia,
@@ -114,9 +114,9 @@ def completion(
logprobs=logprobs,
)
if stream:
return (chunk async for chunk in provider.completion(**params))
return (chunk async for chunk in await provider.completion(**params))
else:
return provider.completion(**params)
return await provider.completion(**params)

async def embeddings(
self,
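
The router change above applies the same convention one level up: its `async def` methods await the selected provider and, for streaming requests, re-expose the provider's chunks through the router's own async generator. A hedged sketch of the idea (illustrative helper, not the exact router code):

```
async def forward_stream(provider, params: dict):
    # Assumes params includes stream=True. Await the provider's async-def
    # method to obtain its chunk generator, then re-yield each chunk so the
    # router presents the same `await` + `async for` shape to its callers.
    async for chunk in await provider.chat_completion(**params):
        yield chunk
```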
4 changes: 2 additions & 2 deletions llama_stack/providers/adapters/inference/bedrock/bedrock.py
@@ -47,7 +47,7 @@ async def initialize(self) -> None:
async def shutdown(self) -> None:
self.client.close()

def completion(
async def completion(
self,
model: str,
content: InterleavedTextMedia,
@@ -283,7 +283,7 @@ def _tools_to_tool_config(
)
return tool_config

def chat_completion(
async def chat_completion(
self,
model: str,
messages: List[Message],
Original file line number Diff line number Diff line change
@@ -48,7 +48,7 @@ async def initialize(self) -> None:
async def shutdown(self) -> None:
pass

def completion(
async def completion(
self,
model: str,
content: InterleavedTextMedia,
@@ -58,7 +58,7 @@ def completion(
) -> AsyncGenerator:
raise NotImplementedError()

def chat_completion(
async def chat_completion(
self,
model: str,
messages: List[Message],
@@ -84,7 +84,7 @@ def chat_completion(
if stream:
return self._stream_chat_completion(request, client)
else:
return self._nonstream_chat_completion(request, client)
return await self._nonstream_chat_completion(request, client)

async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: OpenAI
Original file line number Diff line number Diff line change
@@ -51,7 +51,7 @@ async def initialize(self) -> None:
async def shutdown(self) -> None:
pass

def completion(
async def completion(
self,
model: str,
content: InterleavedTextMedia,
@@ -61,7 +61,7 @@ def completion(
) -> AsyncGenerator:
raise NotImplementedError()

def chat_completion(
async def chat_completion(
self,
model: str,
messages: List[Message],
@@ -87,7 +87,7 @@ def chat_completion(
if stream:
return self._stream_chat_completion(request, client)
else:
return self._nonstream_chat_completion(request, client)
return await self._nonstream_chat_completion(request, client)

async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: Fireworks
