feat: DIA-1384: add cost estimate endpoint #225

Merged
merged 10 commits on Oct 23, 2024
Changes from 5 commits
10 changes: 10 additions & 0 deletions adala/agents/base.py
@@ -216,6 +216,16 @@ def get_teacher_runtime(self, runtime: Optional[str] = None) -> Runtime:
            )
        return runtime

    def get_skills(self) -> List[Skill]:
Contributor:
Why do we need this function? skills must always be a Dict mapping from skill name to Skill (there is a skills_validator function to prevalidate that), so we can always iterate like:

for skill_name in agent.skills.get_skill_names():
    skill = agent.skills[skill_name]

Contributor (author):

The type sig for skills is skills: SerializeAsAny[Union[Skill, SkillSet]], so we can't count on this always being a SkillSet (and therefore having get_skill_names defined), yeah?
Plus, it seems simpler to be able to get the list of skills than the list of names. Since get_skill_names doesn't seem to be used anywhere, I think it's simpler to change it to return the skills themselves (if we can, in fact, assume this is a SkillSet).

Contributor (@niklub), Oct 16, 2024:

"so we can't count on this always being a SkillSet"

We can, because of https://github.com/HumanSignal/Adala/blob/master/adala/agents/base.py#L109

Contributor (author):

Can we change the type sig of skills in Agent, then?
And still, as I said in my first comment, it seems simpler to have a SkillSet.get_skills than SkillSet.get_skill_names, especially since SkillSet.get_skill_names doesn't seem to be used anywhere. I'll go ahead and change that and the type sig of Agent.skills.
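
For illustration, a minimal sketch of that simplification (hypothetical; it assumes Agent.skills has already been coerced to a SkillSet by skills_validator, and that SkillSet stores its skills as a Dict):

    def get_skills(self) -> List[Skill]:
        # Hypothetical simplification: assumes skills_validator has already
        # coerced self.skills to a SkillSet whose .skills is a Dict.
        return list(self.skills.skills.values())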

        if isinstance(self.skills, SkillSet):
            if isinstance(self.skills.skills, Dict):
                skills = list(self.skills.skills.values())
            else:
                skills = self.skills.skills
        else:
            skills = [self.skills]
        return skills

    def run(
        self, input: InternalDataFrame = None, runtime: Optional[str] = None, **kwargs
    ) -> InternalDataFrame:
68 changes: 68 additions & 0 deletions adala/runtimes/_litellm.py
@@ -13,6 +13,7 @@
import instructor
from instructor.exceptions import InstructorRetryException, IncompleteOutputException
import traceback
from adala.runtimes.base import CostEstimate
from adala.utils.exceptions import ConstrainedGenerationError
from adala.utils.internal_data import InternalDataFrame
from adala.utils.parse import (
@@ -527,6 +528,73 @@ async def record_to_record(
        # Extract the single row from the output DataFrame and convert it to a dictionary
        return output_df.iloc[0].to_dict()

    @staticmethod
    def _get_prompt_tokens(string: str, model: str, output_fields: List[str]) -> int:
        user_tokens = litellm.token_counter(model=model, text=string)
        # FIXME: it's surprisingly difficult to get function-call tokens, and doing so
        # doesn't add much value, so hard-code this until something like litellm can do
        # it for us. Currently we'd need to scrape the instructor logs for the
        # function-call info, then use (at best) an OpenAI-specific third-party lib to
        # get a token estimate from that.
        system_tokens = 56 + (6 * len(output_fields))
        return user_tokens + system_tokens
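
As a worked example of the heuristic above: with 3 output fields, the hard-coded estimate adds 56 + 6 × 3 = 74 system-side tokens on top of whatever litellm.token_counter reports for the user prompt (the constants 56 and 6 are the FIXME's assumptions, not derived values).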

    @staticmethod
    def _get_completion_tokens(model: str, output_fields: Optional[List[str]]) -> int:
        max_tokens = litellm.get_model_info(
            model=model, custom_llm_provider="openai"
        ).get("max_tokens", None)
        if not max_tokens:
            raise ValueError(f"Could not find max_tokens for model {model}")
        # extremely rough heuristic, from testing on some anecdotal examples
        n_outputs = len(output_fields) if output_fields else 1
        return min(max_tokens, 4 * n_outputs)

    @classmethod
    def _estimate_cost(
        cls, user_prompt: str, model: str, output_fields: Optional[List[str]]
    ):
        prompt_tokens = cls._get_prompt_tokens(user_prompt, model, output_fields)
        completion_tokens = cls._get_completion_tokens(model, output_fields)
        prompt_cost, completion_cost = litellm.cost_per_token(
            model=model,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
        )
        total_cost = prompt_cost + completion_cost

        return prompt_cost, completion_cost, total_cost
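
To make the arithmetic concrete, a worked sketch with assumed rates (illustrative only; the real per-token figures come from litellm.cost_per_token's pricing tables):

    prompt_tokens = 1_000
    completion_tokens = 4 * 2  # completion heuristic above: 4 tokens per field, 2 fields
    prompt_cost = prompt_tokens * 0.15 / 1_000_000          # assumed $0.15 per 1M prompt tokens
    completion_cost = completion_tokens * 0.60 / 1_000_000  # assumed $0.60 per 1M completion tokens
    total_cost = prompt_cost + completion_cost              # ~= $0.000155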

    def get_cost_estimate(
        self, prompt: str, substitutions: List[Dict], output_fields: Optional[List[str]]
    ) -> CostEstimate:
        try:
            user_prompts = [
                prompt.format(**substitution) for substitution in substitutions
            ]
            cumulative_prompt_cost = 0
            cumulative_completion_cost = 0
            cumulative_total_cost = 0
            for user_prompt in user_prompts:
                prompt_cost, completion_cost, total_cost = self._estimate_cost(
                    user_prompt=user_prompt,
                    model=self.model,
                    output_fields=output_fields,
                )
                cumulative_prompt_cost += prompt_cost
                cumulative_completion_cost += completion_cost
                cumulative_total_cost += total_cost
            return CostEstimate(
                prompt_cost_usd=cumulative_prompt_cost,
                completion_cost_usd=cumulative_completion_cost,
                total_cost_usd=cumulative_total_cost,
            )
        except Exception as e:
            logger.error("Failed to estimate cost: %s", e)
            return CostEstimate(
                prompt_cost_usd=None,
                completion_cost_usd=None,
                total_cost_usd=None,
            )


class LiteLLMVisionRuntime(LiteLLMChatRuntime):
"""
Expand Down
30 changes: 30 additions & 0 deletions adala/runtimes/base.py
@@ -12,6 +12,31 @@
tqdm.pandas()


class CostEstimate(BaseModel):
    prompt_cost_usd: Optional[float]
    completion_cost_usd: Optional[float]
    total_cost_usd: Optional[float]

    def __add__(self, other: "CostEstimate") -> "CostEstimate":
        def _safe_add(lhs: Optional[float], rhs: Optional[float]) -> Optional[float]:
            if lhs is None and rhs is None:
                return None
            _lhs = lhs or 0.0
            _rhs = rhs or 0.0
            return _lhs + _rhs

        prompt_cost_usd = _safe_add(self.prompt_cost_usd, other.prompt_cost_usd)
        completion_cost_usd = _safe_add(
            self.completion_cost_usd, other.completion_cost_usd
        )
        total_cost_usd = _safe_add(self.total_cost_usd, other.total_cost_usd)
        return CostEstimate(
            prompt_cost_usd=prompt_cost_usd,
            completion_cost_usd=completion_cost_usd,
            total_cost_usd=total_cost_usd,
        )
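
For illustration, a usage sketch of the addition above (values are made up; an all-None estimate acts as the identity element, which is how the /estimate-cost endpoint folds per-skill estimates with sum()):

    a = CostEstimate(prompt_cost_usd=0.001, completion_cost_usd=0.002, total_cost_usd=0.003)
    b = CostEstimate(prompt_cost_usd=None, completion_cost_usd=None, total_cost_usd=None)

    total = sum([a, b], CostEstimate(
        prompt_cost_usd=None, completion_cost_usd=None, total_cost_usd=None
    ))
    assert total.prompt_cost_usd == 0.001  # None propagates as 0.0 unless both sides are None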


class Runtime(BaseModelInRegistry):
    """
    Base class representing a generic runtime environment.
@@ -191,6 +216,11 @@ def record_to_batch(
            response_model=response_model,
        )

    def get_cost_estimate(
        self, prompt: str, substitutions: List[Dict], output_fields: Optional[List[str]]
    ) -> CostEstimate:
        raise NotImplementedError("This runtime does not support cost estimates")


class AsyncRuntime(Runtime):
"""Async version of runtime that uses asyncio to process batch of records."""
59 changes: 59 additions & 0 deletions server/app.py
@@ -15,6 +15,7 @@
from aiokafka.errors import UnknownTopicOrPartitionError
from fastapi import HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
import litellm
from pydantic import BaseModel, SerializeAsAny, field_validator, Field, model_validator
from redis import Redis
import time
@@ -24,6 +25,7 @@
from server.handlers.result_handlers import ResultHandler
from server.log_middleware import LogMiddleware
from adala.skills.collection.prompt_improvement import ImprovedPromptResponse
from adala.runtimes.base import CostEstimate
from server.tasks.stream_inference import streaming_parent_task
from server.utils import (
Settings,
@@ -81,6 +83,12 @@ class BatchSubmitted(BaseModel):
    job_id: str


class CostEstimateRequest(BaseModel):
    agent: Agent
    prompt: str
    substitutions: List[Dict]


class Status(Enum):
    PENDING = "Pending"
    INPROGRESS = "InProgress"
@@ -210,6 +218,57 @@ async def submit_batch(batch: BatchData):
    return Response[BatchSubmitted](data=BatchSubmitted(job_id=batch.job_id))


@app.post("/estimate-cost", response_model=Response[CostEstimate])
async def estimate_cost(
request: CostEstimateRequest,
):
"""
Estimates what it would cost to run inference on the batch of data in
`request` (using the run params from `request`)

Args:
request (CostEstimateRequest): Specification for the inference run to
make an estimate for, includes:
agent (adala.agent.Agent): The agent definition, used to get the model
and any other params necessary to estimate cost
prompt (str): The prompt template that will be used for each task
substitutions (List[Dict]): Mappings to substitute (simply using str.format)

Returns:
Response[CostEstimate]: The cost estimate, including the prompt/completion/total costs (in USD)
"""
prompt = request.prompt
substitutions = request.substitutions
agent = request.agent
runtime = agent.get_runtime()

try:
cost_estimates = []
for skill in agent.get_skills():
output_fields = (
list(skill.field_schema.keys()) if skill.field_schema else None
)
cost_estimate = runtime.get_cost_estimate(
prompt=prompt, substitutions=substitutions, output_fields=output_fields
)
cost_estimates.append(cost_estimate)
total_cost_estimate = sum(
cost_estimates,
CostEstimate(
prompt_cost_usd=None, completion_cost_usd=None, total_cost_usd=None
),
)
except NotImplementedError:
return Response[CostEstimate](
data=CostEstimate(
prompt_cost_usd=None,
completion_cost_usd=None,
total_cost_usd=None,
)
)
return Response[CostEstimate](data=total_cost_estimate)
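
For illustration, a usage sketch against a locally running server (the host/port and the minimal agent payload here are assumptions, mirroring the request shape used in tests/test_cost_estimation.py):

    import requests

    payload = {
        "agent": {
            "skills": [
                {
                    "type": "ClassificationSkill",
                    "name": "text_classifier",
                    "instructions": "Classify the text.",
                    "input_template": "{text}",
                    "output_template": "{output}",
                    "labels": ["Positive", "Negative"],
                }
            ],
            "runtimes": {
                "default": {"type": "AsyncLiteLLMChatRuntime", "model": "gpt-4o-mini"}
            },
        },
        "prompt": "Classify this: {text}",
        "substitutions": [{"text": "I love this product"}],
    }
    resp = requests.post("http://localhost:8000/estimate-cost", json=payload)
    print(resp.json()["data"])  # prompt_cost_usd / completion_cost_usd / total_cost_usd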


@app.get("/jobs/{job_id}", response_model=Response[JobStatusResponse])
def get_status(job_id):
"""
Expand Down
78 changes: 78 additions & 0 deletions tests/test_cost_estimation.py
@@ -0,0 +1,78 @@
#!/usr/bin/env python3
from adala.runtimes._litellm import AsyncLiteLLMChatRuntime
from adala.runtimes.base import CostEstimate
from adala.agents import Agent
from adala.skills import ClassificationSkill
import numpy as np
import os
from fastapi.testclient import TestClient
from server.app import app, CostEstimateRequest

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")


def test_simple_estimate_cost():
    runtime = AsyncLiteLLMChatRuntime(model="gpt-4o-mini", api_key=OPENAI_API_KEY)

    cost_estimate = runtime.get_cost_estimate(
        prompt="testing, {text}",
        substitutions=[{"text": "knock knock, who's there"}],
        output_fields=["text"],
    )

    assert isinstance(cost_estimate, CostEstimate)
    assert isinstance(cost_estimate.prompt_cost_usd, float)
    assert isinstance(cost_estimate.completion_cost_usd, float)
    assert isinstance(cost_estimate.total_cost_usd, float)
    assert np.isclose(
        cost_estimate.total_cost_usd,
        cost_estimate.prompt_cost_usd + cost_estimate.completion_cost_usd,
    )


def test_estimate_cost_endpoint():
    test_client = TestClient(app)
    req = {
        "agent": {
            "skills": [
                {
                    "type": "ClassificationSkill",
                    "name": "text_classifier",
                    "instructions": "Always return the answer 'Feature Lack'.",
                    "input_template": "{text}",
                    "output_template": "{output}",
                    "labels": [
                        "Feature Lack",
                        "Price",
                        "Integration Issues",
                        "Usability Concerns",
                        "Competitor Advantage",
                    ],
                }
            ],
            "runtimes": {
                "default": {
                    "type": "AsyncLiteLLMChatRuntime",
                    "model": "gpt-4o-mini",
                    "api_key": OPENAI_API_KEY,
                }
            },
        },
        "prompt": "test {text}",
        "substitutions": [{"text": "test"}],
    }
    resp = test_client.post(
        "/estimate-cost",
        json=req,
    )
    resp_data = resp.json()["data"]
    cost_estimate = CostEstimate(**resp_data)

    assert isinstance(cost_estimate, CostEstimate)
    assert isinstance(cost_estimate.prompt_cost_usd, float)
    assert isinstance(cost_estimate.completion_cost_usd, float)
    assert isinstance(cost_estimate.total_cost_usd, float)
    assert np.isclose(
        cost_estimate.total_cost_usd,
        cost_estimate.prompt_cost_usd + cost_estimate.completion_cost_usd,
    )