diff --git a/src/api_client.py b/src/api_client.py index 2f3f3af..af3b55f 100644 --- a/src/api_client.py +++ b/src/api_client.py @@ -1,40 +1,73 @@ import os +from abc import ABC, abstractmethod import httpx +from .chat_completion import AnthropicChatCompletion, OpenRouterChatCompletion -class ApiClient: - TIMEOUT = 120 - CONFIG = { - "openai": { - "api_url": "https://api.openai.com", - "api_key_env": "OPENAI_API_KEY", - }, - "openrouter": { - "api_url": "https://openrouter.ai/api", - "api_key_env": "OPENROUTER_API_KEY", - }, - } - def __init__(self, provider): - self.provider = provider +class APIClient(ABC): + TIMEOUT = 60 - async def call_api(self, messages, parameters): - body = {"messages": messages, **parameters} - - headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {} + def __init__(self): + self.api_key = os.environ.get(self.ENV_KEY) + async def call_api(self, body): async with httpx.AsyncClient(timeout=self.TIMEOUT) as client: - url = f"{self.api_url}/v1/chat/completions" - response = await client.post(url, headers=headers, json=body) + response = await client.post( + self.API_URL, headers=self.get_headers(), json=body + ) response.raise_for_status() return response.json() - @property - def api_url(self): - return self.CONFIG[self.provider]["api_url"] + @abstractmethod + def get_headers(self): + pass + + @staticmethod + def create(client_type): + if client_type == "openrouter": + return OpenRouterAPIClient() + elif client_type == "anthropic": + return AnthropicAPIClient() + else: + raise ValueError(f"Unsupported client type: {client_type}") + + +class OpenRouterAPIClient(APIClient): + API_URL = "https://openrouter.ai/api/v1/chat/completions" + ENV_KEY = "OPENROUTER_API_KEY" + + async def call_api(self, messages, parameters, pricing): + body = {"messages": messages, **parameters} + response = await super().call_api(body) + return OpenRouterChatCompletion(response, pricing) + + def get_headers(self): + return {"Authorization": f"Bearer {self.api_key}"} + + +class AnthropicAPIClient(APIClient): + API_URL = "https://api.anthropic.com/v1/messages" + ENV_KEY = "ANTHROPIC_API_KEY" + + async def call_api(self, messages, parameters, pricing): + messages, system = self._transform_messages(messages) + body = {"messages": messages, "system": system, **parameters} + response = await super().call_api(body) + return AnthropicChatCompletion(response, pricing) + + def get_headers(self): + return {"x-api-key": self.api_key, "anthropic-version": "2023-06-01"} + + def _transform_messages(self, original_messages): + messages = [] + system = [] + + for message in original_messages: + if message["role"] == "system": + system.append({"type": "text", "text": message["content"]}) + else: + messages.append(message) - @property - def api_key(self): - var = self.CONFIG[self.provider]["api_key_env"] - return os.environ.get(var) + return messages, system diff --git a/src/chat_completion.py b/src/chat_completion.py index a0be1f8..2f37672 100644 --- a/src/chat_completion.py +++ b/src/chat_completion.py @@ -9,8 +9,7 @@ def __init__(self, response, pricing): @classmethod async def generate(cls, client, content, parameters, pricing=None): - response = await client.call_api(content, parameters) - completion = cls(response, pricing) + completion = await client.call_api(content, parameters, pricing) completion.validate() completion.logger.log(parameters, content, completion.content) return completion @@ -23,13 +22,27 @@ def validate(self): if self.content == "": raise Exception("Response was empty") + @property + def cost(self): + if self.pricing: + return (self.prompt_tokens / 1_000_000 * self.pricing[0]) + ( + self.completion_tokens / 1_000_000 * self.pricing[1] + ) + return 0 + + @property + def error_message(self): + return self.response.get("error", {}).get("message", "") + + +class AnthropicChatCompletion(ChatCompletion): @property def choice(self): - return self.response["choices"][0] + return self.response["content"][0] @property def content(self): - return self.choice["message"]["content"] + return self.choice["text"] @property def prompt_tokens(self): @@ -37,20 +50,30 @@ def prompt_tokens(self): @property def completion_tokens(self): - return self.response["usage"]["completion_tokens"] + return self.response["usage"]["output_tokens"] @property def finish_reason(self): - return self.choice["finish_reason"] + return self.response["stop_reason"] + +class OpenRouterChatCompletion(ChatCompletion): @property - def cost(self): - if self.pricing: - return (self.prompt_tokens / 1_000_000 * self.pricing[0]) + ( - self.completion_tokens / 1_000_000 * self.pricing[1] - ) - return 0 + def choice(self): + return self.response["choices"][0] @property - def error_message(self): - return self.response.get("error", {}).get("message", "") + def content(self): + return self.choice["message"]["content"] + + @property + def prompt_tokens(self): + return self.response["usage"]["prompt_tokens"] + + @property + def completion_tokens(self): + return self.response["usage"]["completion_tokens"] + + @property + def finish_reason(self): + return self.choice["finish_reason"] diff --git a/src/lm_executors/chat_executor.py b/src/lm_executors/chat_executor.py index 07928f7..4011315 100644 --- a/src/lm_executors/chat_executor.py +++ b/src/lm_executors/chat_executor.py @@ -1,7 +1,7 @@ import jinja2 import yaml -from ..api_client import ApiClient +from ..api_client import APIClient from ..chat_completion import ChatCompletion from ..resolve_vars import resolve_vars @@ -13,7 +13,7 @@ def __init__(self, context): self.context = context async def execute(self): - client = ApiClient(self.context.api_provider) + client = APIClient.create(self.context.api_provider) params = {"max_tokens": 1000} if self.context.model is not None: