feat: DIA-1384: add cost estimate endpoint #225

Merged
merged 10 commits on Oct 23, 2024
96 changes: 96 additions & 0 deletions server/app.py
@@ -15,6 +15,7 @@
from aiokafka.errors import UnknownTopicOrPartitionError
from fastapi import HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
import litellm
from pydantic import BaseModel, SerializeAsAny, field_validator, Field, model_validator
from redis import Redis
import time
@@ -81,6 +82,19 @@ class BatchSubmitted(BaseModel):
    job_id: str


class CostEstimate(BaseModel):
    prompt_cost_usd: Optional[float]
    completion_cost_usd: Optional[float]
    total_cost_usd: Optional[float]


class CostEstimateRequest(BaseModel):
    prompt: str
    substitutions: List[Dict]
    model: str
    output_fields: List[str]
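
# For illustration only (not part of this PR's diff): a request body matching
# CostEstimateRequest above might look like the following; the prompt, substitution
# values, and model are invented for the example.
#
#   {
#       "prompt": "Summarize the following text: {text}",
#       "substitutions": [{"text": "first document"}, {"text": "second document"}],
#       "model": "gpt-3.5-turbo",
#       "output_fields": ["summary"],
#   }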


class Status(Enum):
    PENDING = "Pending"
    INPROGRESS = "InProgress"
@@ -210,6 +224,88 @@ async def submit_batch(batch: BatchData):
    return Response[BatchSubmitted](data=BatchSubmitted(job_id=batch.job_id))


def get_prompt_tokens(string: str, model: str, output_fields: List[str]) -> int:
    user_tokens = litellm.token_counter(model=model, text=string)
    # FIXME: It's surprisingly difficult to get function-call tokens, and doing so doesn't add much
    # value, so this is hard-coded until something like litellm supports doing it for us. Currently it
    # seems we'd need to scrape the instructor logs to get the function call info, then use (at best)
    # an OpenAI-specific third-party lib to get a token estimate from that.
    system_tokens = 56 + (6 * len(output_fields))
    return user_tokens + system_tokens
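
# For a rough sense of scale (illustrative arithmetic, not part of this PR's diff):
# with 3 output fields, the hard-coded overhead above is 56 + (6 * 3) = 74 tokens
# on top of the counted user-prompt tokens.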


def get_completion_tokens(model: str, output_fields: List[str]) -> int:
    max_tokens = litellm.get_model_info(model=model, custom_llm_provider="openai").get(
        "max_tokens", None
    )
    if not max_tokens:
        raise ValueError(f"Could not determine max_tokens for model {model}")
    # extremely rough heuristic, from testing on some anecdotal examples
    return min(max_tokens, 4 * len(output_fields))
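
# Illustrative arithmetic (not part of this PR's diff): with 3 output fields, this
# estimates min(max_tokens, 4 * 3) = 12 completion tokens, assuming the model's
# max_tokens is at least 12.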


def _estimate_cost(user_prompt: str, model: str, output_fields: List[str]):
    prompt_tokens = get_prompt_tokens(user_prompt, model, output_fields)
    completion_tokens = get_completion_tokens(model, output_fields)
    prompt_cost, completion_cost = litellm.cost_per_token(
        model=model,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )
    total_cost = prompt_cost + completion_cost

    return prompt_cost, completion_cost, total_cost


@app.post("/estimate-cost", response_model=Response[CostEstimate])
async def estimate_cost(
    request: CostEstimateRequest,
):
    """
    Estimates the cost of running the given prompt against the given model.
    The prompt is rendered once per substitution, token counts are estimated for each
    rendered prompt, and the per-prompt costs are summed across the batch.

    Args:
        request (CostEstimateRequest): The prompt, substitutions, model, and output fields to estimate.

    Returns:
        Response[CostEstimate]: Estimated prompt, completion, and total cost in USD,
        or all None if the estimate could not be computed.
    """
    prompt = request.prompt
    substitutions = request.substitutions
    model = request.model
    output_fields = request.output_fields
    try:
        user_prompts = [prompt.format(**substitution) for substitution in substitutions]
        cumulative_prompt_cost = 0
        cumulative_completion_cost = 0
        cumulative_total_cost = 0
        for user_prompt in user_prompts:
            prompt_cost, completion_cost, total_cost = _estimate_cost(
                user_prompt=user_prompt,
                model=model,
                output_fields=output_fields,
            )
            cumulative_prompt_cost += prompt_cost
            cumulative_completion_cost += completion_cost
            cumulative_total_cost += total_cost
        return Response[CostEstimate](
            data=CostEstimate(
                prompt_cost_usd=cumulative_prompt_cost,
                completion_cost_usd=cumulative_completion_cost,
                total_cost_usd=cumulative_total_cost,
            )
        )

    except Exception as e:
        logger.error("Failed to estimate cost: %s", e)
        return Response[CostEstimate](
            data=CostEstimate(
                prompt_cost_usd=None,
                completion_cost_usd=None,
                total_cost_usd=None,
            )
        )
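
# A hedged sketch (not part of this PR's diff) of how a client might call the new
# endpoint; the host, port, and payload values are invented for illustration:
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/estimate-cost",
#       json={
#           "prompt": "Summarize the following text: {text}",
#           "substitutions": [{"text": "first document"}, {"text": "second document"}],
#           "model": "gpt-3.5-turbo",
#           "output_fields": ["summary"],
#       },
#   )
#   # resp.json()["data"] holds prompt_cost_usd, completion_cost_usd, and total_cost_usd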


@app.get("/jobs/{job_id}", response_model=Response[JobStatusResponse])
def get_status(job_id):
"""