diff --git a/src/api_client.py b/src/api_client.py
index 043c68b..1185647 100644
--- a/src/api_client.py
+++ b/src/api_client.py
@@ -1,90 +1,55 @@
 import os
-from abc import ABC, abstractmethod
 
 import httpx
 
-from .chat_completion import AnthropicChatCompletion, OpenRouterChatCompletion
+from .chat_completion import ChatCompletion
 from .logger import Logger
 
 
-class APIClient(ABC):
+class OpenRouterAPIClient:
+    API_URL = "https://openrouter.ai/api/v1/chat/completions"
+    ENV_KEY = "OPENROUTER_API_KEY"
     TIMEOUT = 30
 
     def __init__(self):
         self.api_key = os.environ.get(self.ENV_KEY)
         self.logger = Logger("log.yml")
 
-    @staticmethod
-    def create(client_type):
-        clients = {"openrouter": OpenRouterAPIClient, "anthropic": AnthropicAPIClient}
-        if client_type not in clients:
-            raise ValueError(f"Unsupported client type: {client_type}")
-        return clients[client_type]()
-
-    async def request_completion(self, messages, parameters, pricing):
-        body = self.prepare_body(messages, parameters)
-        try:
-            async with httpx.AsyncClient(timeout=self.TIMEOUT) as client:
-                response = await client.post(
-                    self.API_URL, headers=self.get_headers(), json=body
-                )
-                response.raise_for_status()
-                completion = self.create_completion(response.json(), pricing)
-                self.logger.log(parameters, messages, completion.content)
-                return completion
-        except httpx.ReadTimeout:
-            raise Exception("Request timed out")
-
-    @abstractmethod
-    def get_headers(self):
-        pass
-
-    @abstractmethod
-    def prepare_body(self, messages, parameters):
-        pass
-
-    @abstractmethod
-    def create_completion(self, response, pricing):
-        pass
-
-
-class OpenRouterAPIClient(APIClient):
-    API_URL = "https://openrouter.ai/api/v1/chat/completions"
-    ENV_KEY = "OPENROUTER_API_KEY"
-
     def get_headers(self):
         return {"Authorization": f"Bearer {self.api_key}"}
 
     def prepare_body(self, messages, parameters):
         return {"messages": messages, **parameters}
 
-    def create_completion(self, response, pricing):
-        return OpenRouterChatCompletion(response, pricing)
-
-
-class AnthropicAPIClient(APIClient):
-    API_URL = "https://api.anthropic.com/v1/messages"
-    ENV_KEY = "ANTHROPIC_API_KEY"
+    async def request_completion(self, messages, parameters, pricing):
+        body = self.prepare_body(messages, parameters)
+        try:
+            completion_data = await self.get_completion_data(body)
+            completion = ChatCompletion(completion_data, pricing)
+            self.logger.log(parameters, messages, completion.content)
+            return completion
+        except httpx.ReadTimeout:
+            raise Exception("Request timed out")
 
-    def get_headers(self):
-        return {
-            "x-api-key": self.api_key,
-            "anthropic-version": "2023-06-01",
-            "anthropic-beta": "prompt-caching-2024-07-31",
-        }
+    async def get_completion_data(self, body):
+        async with httpx.AsyncClient(timeout=self.TIMEOUT) as client:
+            completion_response = await client.post(
+                self.API_URL, headers=self.get_headers(), json=body
+            )
+            completion_response.raise_for_status()
+            completion_data = completion_response.json()
+            if "error" in completion_data:
+                raise Exception(completion_data["error"])
 
-    def prepare_body(self, messages, parameters):
-        other_messages, system = self._transform_messages(messages)
-        return {"messages": other_messages, "system": system, **parameters}
+            details_data = await self._poll_details(client, completion_data["id"])
+            return {**completion_data, "details": details_data["data"]}
 
-    def create_completion(self, response, pricing):
-        return AnthropicChatCompletion(response, pricing)
+    async def _poll_details(self, client, generation_id, max_attempts=10):
+        details_url = f"https://openrouter.ai/api/v1/generation?id={generation_id}"
 
-    def _transform_messages(self, original_messages):
-        messages = [msg for msg in original_messages if msg["role"] != "system"]
-        system = []
-        for msg in original_messages:
-            if msg["role"] == "system":
-                system.extend(msg["content"])
+        for _ in range(max_attempts):
+            details_response = await client.get(details_url, headers=self.get_headers())
+            if details_response.status_code == 200:
+                return details_response.json()
 
-        return messages, system
+        raise TimeoutError("Details not available after maximum attempts")
diff --git a/src/chat_completion.py b/src/chat_completion.py
index f647b27..4237e9a 100644
--- a/src/chat_completion.py
+++ b/src/chat_completion.py
@@ -12,74 +12,40 @@ def validate(self):
         if self.content == "":
             raise Exception("Response was empty")
 
-    @property
-    def cost(self):
-        if self.pricing:
-            return (self.prompt_tokens / 1_000_000 * self.pricing[0]) + (
-                self.completion_tokens / 1_000_000 * self.pricing[1]
-            )
-        return 0
-
-    @property
-    def error_message(self):
-        return self.response.get("error", {}).get("message", "")
-
-    @property
-    def cache_creation_input_tokens(self):
-        return 0
-
-    @property
-    def cache_read_input_tokens(self):
-        return 0
-
-
-class AnthropicChatCompletion(ChatCompletion):
     @property
     def choice(self):
-        return self.response["content"][0]
+        return self.response["choices"][0]
 
     @property
     def content(self):
-        return self.choice["text"]
+        return self.choice["message"]["content"]
 
     @property
     def prompt_tokens(self):
-        return self.response["usage"]["input_tokens"]
+        return self.response["usage"]["prompt_tokens"]
 
     @property
     def completion_tokens(self):
-        return self.response["usage"]["output_tokens"]
-
-    @property
-    def cache_creation_input_tokens(self):
-        return self.response["usage"]["cache_creation_input_tokens"]
-
-    @property
-    def cache_read_input_tokens(self):
-        return self.response["usage"]["cache_read_input_tokens"]
+        return self.response["usage"]["completion_tokens"]
 
     @property
     def finish_reason(self):
-        return self.response["stop_reason"]
-
-
-class OpenRouterChatCompletion(ChatCompletion):
-    @property
-    def choice(self):
-        return self.response["choices"][0]
+        return self.choice.get("finish_reason")
 
     @property
-    def content(self):
-        return self.choice["message"]["content"]
+    def error_message(self):
+        return self.response.get("error", {}).get("message", "")
 
     @property
-    def prompt_tokens(self):
-        return self.response["usage"]["prompt_tokens"]
+    def cache_discount(self):
+        return self.response["details"]["cache_discount"]
 
     @property
-    def completion_tokens(self):
-        return self.response["usage"]["completion_tokens"]
+    def cache_discount_string(self):
+        sign = "-" if self.cache_discount < 0 else ""
+        amount = f"${abs(self.cache_discount):.2f}"
+        return f"{sign}{amount}"
 
     @property
-    def finish_reason(self):
-        return self.choice.get("finish_reason")
+    def cost(self):
+        return self.response["details"]["total_cost"]
diff --git a/src/lm_executors/chat_executor.py b/src/lm_executors/chat_executor.py
index e50a5ff..ca16759 100644
--- a/src/lm_executors/chat_executor.py
+++ b/src/lm_executors/chat_executor.py
@@ -1,7 +1,7 @@
 import jinja2
 import yaml
 
-from ..api_client import APIClient
+from ..api_client import OpenRouterAPIClient
 from ..resolve_vars import resolve_vars
 
 
@@ -12,9 +12,9 @@ def __init__(self, context):
         self.context = context
 
     async def execute(self):
-        client = APIClient.create(self.context.api_provider)
+        client = OpenRouterAPIClient()
 
-        params = {"max_tokens": 1000}
+        params = {"max_tokens": 1024}
 
         if self.context.model is not None:
             params["model"] = self.context.model
diff --git a/src/lm_executors/chat_executor_template.j2 b/src/lm_executors/chat_executor_template.j2
index 16c30c3..daa7680 100644
--- a/src/lm_executors/chat_executor_template.j2
+++ b/src/lm_executors/chat_executor_template.j2
@@ -25,6 +25,10 @@
     - type: text
       text: |-
        {{ message.content | indent(8) }}
+      {% if loop.index in [(messages|length), (messages|length - 2)] %}
+      cache_control:
+        type: ephemeral
+      {% endif %}
 {% endfor %}
 {% if reinforcement_chat_prompt %}
   - role: {{ 'system' if not 'claude' in model else 'assistant' }}
diff --git a/src/telegram/telegram_bot.py b/src/telegram/telegram_bot.py
index cb8f2a3..29faa5e 100644
--- a/src/telegram/telegram_bot.py
+++ b/src/telegram/telegram_bot.py
@@ -102,10 +102,9 @@ async def stats_command_handler(self, ctx):
             last_message_stats += "\n".join(
                 [
                     f"`Cost: ${lc.cost:.2f}`",
+                    f"`Cache discount: {lc.cache_discount_string}`",
                     f"`Prompt tokens: {lc.prompt_tokens}`",
                     f"`Completion tokens: {lc.completion_tokens}`",
-                    f"`Cache creation tokens: {lc.cache_creation_input_tokens}`",
-                    f"`Cache read tokens: {lc.cache_read_input_tokens}`",
                 ]
             )
         else:
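
Notes (not part of the diff). End-to-end usage of the consolidated client, as a minimal sketch: the import path, model id, and prompt are illustrative assumptions, while `OpenRouterAPIClient`, `request_completion`, `cost`, and `cache_discount_string` are the names introduced above. `pricing` remains in the signature but no longer drives `cost`, which now comes from the generation details endpoint.

```python
import asyncio

from src.api_client import OpenRouterAPIClient  # import path is an assumption


async def main():
    client = OpenRouterAPIClient()  # reads OPENROUTER_API_KEY from the environment
    messages = [{"role": "user", "content": "Say hello."}]
    parameters = {"model": "anthropic/claude-3.5-sonnet", "max_tokens": 1024}

    # pricing is still accepted, but cost is taken from the details endpoint,
    # so None is fine here.
    completion = await client.request_completion(messages, parameters, pricing=None)
    print(completion.content)
    print(f"Cost: ${completion.cost:.2f}")
    print(f"Cache discount: {completion.cache_discount_string}")


asyncio.run(main())
```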
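As written, `_poll_details` re-queries the generation endpoint immediately, so all `max_attempts` requests can fire within a fraction of a second if the stats are not ready yet. A standalone variant with a pause between attempts, keeping the same contract (the helper name and the one-second default are assumptions, not from the diff):

```python
import asyncio

import httpx


async def poll_details(client: httpx.AsyncClient, headers: dict, generation_id: str,
                       max_attempts: int = 10, delay_seconds: float = 1.0) -> dict:
    """Sketch of _poll_details with a delay between attempts."""
    details_url = f"https://openrouter.ai/api/v1/generation?id={generation_id}"
    for _ in range(max_attempts):
        details_response = await client.get(details_url, headers=headers)
        if details_response.status_code == 200:
            return details_response.json()
        await asyncio.sleep(delay_seconds)  # give the stats endpoint time to catch up
    raise TimeoutError("Details not available after maximum attempts")
```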
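The template change marks the last and third-to-last messages with `cache_control: ephemeral` (Jinja's `loop.index` is 1-based). Restated as a plain-Python sketch, with a hypothetical helper name and assuming each message's `content` is a list of text blocks as rendered by the template:

```python
def add_cache_breakpoints(messages: list[dict]) -> list[dict]:
    """Hypothetical equivalent of the template's cache_control placement."""
    n = len(messages)
    for index, message in enumerate(messages, start=1):  # 1-based, like loop.index
        if index in (n, n - 2):
            for block in message["content"]:
                block["cache_control"] = {"type": "ephemeral"}
    return messages
```

Placing the breakpoints near the end of the conversation keeps the longest possible prefix reusable by the provider's prompt cache on later turns.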