Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement OpenRouter message caching #101

Merged
merged 5 commits into from
Oct 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 31 additions & 66 deletions src/api_client.py
Original file line number Diff line number Diff line change
@@ -1,90 +1,55 @@
import os
from abc import ABC, abstractmethod

import httpx

from .chat_completion import AnthropicChatCompletion, OpenRouterChatCompletion
from .chat_completion import ChatCompletion
from .logger import Logger


class APIClient(ABC):
class OpenRouterAPIClient:
    """Async client for OpenRouter's chat-completions API.

    Reads the API key from the OPENROUTER_API_KEY environment variable
    and logs every completion through the project Logger.
    """

    API_URL = "https://openrouter.ai/api/v1/chat/completions"
    ENV_KEY = "OPENROUTER_API_KEY"
    TIMEOUT = 30  # per-request timeout in seconds, passed to httpx

    def __init__(self, log_path="log.yml"):
        """Create a client.

        log_path: destination for the request log; generalizes the
        previously hard-coded "log.yml" (default keeps old behavior).
        """
        # May be None when the env var is unset; requests would then be
        # sent without a valid bearer token and rejected server-side.
        self.api_key = os.environ.get(self.ENV_KEY)
        self.logger = Logger(log_path)

@staticmethod
def create(client_type):
clients = {"openrouter": OpenRouterAPIClient, "anthropic": AnthropicAPIClient}
if client_type not in clients:
raise ValueError(f"Unsupported client type: {client_type}")
return clients[client_type]()

async def request_completion(self, messages, parameters, pricing):
body = self.prepare_body(messages, parameters)
try:
async with httpx.AsyncClient(timeout=self.TIMEOUT) as client:
response = await client.post(
self.API_URL, headers=self.get_headers(), json=body
)
response.raise_for_status()
completion = self.create_completion(response.json(), pricing)
self.logger.log(parameters, messages, completion.content)
return completion
except httpx.ReadTimeout:
raise Exception("Request timed out")

@abstractmethod
def get_headers(self):
pass

@abstractmethod
def prepare_body(self, messages, parameters):
pass

@abstractmethod
def create_completion(self, response, pricing):
pass


class OpenRouterAPIClient(APIClient):
API_URL = "https://openrouter.ai/api/v1/chat/completions"
ENV_KEY = "OPENROUTER_API_KEY"

def get_headers(self):
return {"Authorization": f"Bearer {self.api_key}"}

def prepare_body(self, messages, parameters):
    """Assemble the JSON request body.

    Starts from the messages list and overlays every entry of
    *parameters*, preserving the original merge order (parameters
    win on key collisions).
    """
    body = {"messages": messages}
    body.update(parameters)
    return body

def create_completion(self, response, pricing):
return OpenRouterChatCompletion(response, pricing)


class AnthropicAPIClient(APIClient):
API_URL = "https://api.anthropic.com/v1/messages"
ENV_KEY = "ANTHROPIC_API_KEY"
async def request_completion(self, messages, parameters, pricing):
    """Send a chat-completion request and return a ChatCompletion.

    messages: list of chat-message dicts for the "messages" field.
    parameters: extra request fields (model, max_tokens, ...).
    pricing: forwarded to ChatCompletion for cost accounting.

    Raises Exception("Request timed out") on httpx.ReadTimeout,
    chaining the original exception so the traceback is preserved.
    """
    body = self.prepare_body(messages, parameters)
    try:
        # Keep the try minimal: only the awaited HTTP call can time out.
        completion_data = await self.get_completion_data(body)
    except httpx.ReadTimeout as exc:
        # Chain the cause instead of discarding it (original dropped it).
        raise Exception("Request timed out") from exc
    completion = ChatCompletion(completion_data, pricing)
    self.logger.log(parameters, messages, completion.content)
    return completion

def get_headers(self):
return {
"x-api-key": self.api_key,
"anthropic-version": "2023-06-01",
"anthropic-beta": "prompt-caching-2024-07-31",
}
async def get_completion_data(self, body):
    """POST the completion request and attach generation details.

    Raises for HTTP errors, for an "error" payload returned by
    OpenRouter, and (via _poll_details) when the generation details
    never become available.
    """
    async with httpx.AsyncClient(timeout=self.TIMEOUT) as client:
        response = await client.post(
            self.API_URL, headers=self.get_headers(), json=body
        )
        response.raise_for_status()
        payload = response.json()
        if "error" in payload:
            raise Exception(payload["error"])
        details = await self._poll_details(client, payload["id"])
        # Merge the generation details into the completion payload.
        return {**payload, "details": details["data"]}

def create_completion(self, response, pricing):
return AnthropicChatCompletion(response, pricing)
async def _poll_details(self, client, generation_id, max_attempts=10):
details_url = f"https://openrouter.ai/api/v1/generation?id={generation_id}"

def _transform_messages(self, original_messages):
messages = [msg for msg in original_messages if msg["role"] != "system"]
system = []
for msg in original_messages:
if msg["role"] == "system":
system.extend(msg["content"])
for _ in range(max_attempts):
details_response = await client.get(details_url, headers=self.get_headers())
if details_response.status_code == 200:
return details_response.json()

return messages, system
raise TimeoutError("Details not available after maximum attempts")
64 changes: 15 additions & 49 deletions src/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,74 +12,40 @@ def validate(self):
if self.content == "":
raise Exception("Response was empty")

@property
def cost(self):
if self.pricing:
return (self.prompt_tokens / 1_000_000 * self.pricing[0]) + (
self.completion_tokens / 1_000_000 * self.pricing[1]
)
return 0

@property
def error_message(self):
return self.response.get("error", {}).get("message", "")

@property
def cache_creation_input_tokens(self):
return 0

@property
def cache_read_input_tokens(self):
return 0


class AnthropicChatCompletion(ChatCompletion):
@property
def choice(self):
return self.response["content"][0]
return self.response["choices"][0]

@property
def content(self):
return self.choice["text"]
return self.choice["message"]["content"]

@property
def prompt_tokens(self):
return self.response["usage"]["input_tokens"]
return self.response["usage"]["prompt_tokens"]

@property
def completion_tokens(self):
return self.response["usage"]["output_tokens"]

@property
def cache_creation_input_tokens(self):
return self.response["usage"]["cache_creation_input_tokens"]

@property
def cache_read_input_tokens(self):
return self.response["usage"]["cache_read_input_tokens"]
return self.response["usage"]["completion_tokens"]

@property
def finish_reason(self):
return self.response["stop_reason"]


class OpenRouterChatCompletion(ChatCompletion):
@property
def choice(self):
return self.response["choices"][0]
return self.choice.get("finish_reason")

@property
def content(self):
return self.choice["message"]["content"]
def error_message(self):
return self.response.get("error", {}).get("message", "")

@property
def prompt_tokens(self):
return self.response["usage"]["prompt_tokens"]
def cache_discount(self):
return self.response["details"]["cache_discount"]

@property
def completion_tokens(self):
return self.response["usage"]["completion_tokens"]
def cache_discount_string(self):
sign = "-" if self.cache_discount < 0 else ""
amount = f"${abs(self.cache_discount):.2f}"
return f"{sign}{amount}"

@property
def finish_reason(self):
return self.choice.get("finish_reason")
def cost(self):
return self.response["details"]["total_cost"]
6 changes: 3 additions & 3 deletions src/lm_executors/chat_executor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import jinja2
import yaml

from ..api_client import APIClient
from ..api_client import OpenRouterAPIClient
from ..resolve_vars import resolve_vars


Expand All @@ -12,9 +12,9 @@ def __init__(self, context):
self.context = context

async def execute(self):
client = APIClient.create(self.context.api_provider)
client = OpenRouterAPIClient()

params = {"max_tokens": 1000}
params = {"max_tokens": 1024}
if self.context.model is not None:
params["model"] = self.context.model

Expand Down
4 changes: 4 additions & 0 deletions src/lm_executors/chat_executor_template.j2
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
- type: text
text: |-
{{ message.content | indent(8) }}
{% if loop.index in [(messages|length), (messages|length - 2)] %}
cache_control:
type: ephemeral
{% endif %}
{% endfor %}
{% if reinforcement_chat_prompt %}
- role: {{ 'system' if not 'claude' in model else 'assistant' }}
Expand Down
3 changes: 1 addition & 2 deletions src/telegram/telegram_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,9 @@ async def stats_command_handler(self, ctx):
last_message_stats += "\n".join(
[
f"`Cost: ${lc.cost:.2f}`",
f"`Cache discount: {lc.cache_discount_string}`",
f"`Prompt tokens: {lc.prompt_tokens}`",
f"`Completion tokens: {lc.completion_tokens}`",
f"`Cache creation tokens: {lc.cache_creation_input_tokens}`",
f"`Cache read tokens: {lc.cache_read_input_tokens}`",
]
)
else:
Expand Down