Skip to content

Commit

Permalink
Remove AnthropicAPIClient
Browse files Browse the repository at this point in the history
  • Loading branch information
njbbaer committed Oct 26, 2024
1 parent 21b8af0 commit 9ef54ae
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 66 deletions.
74 changes: 10 additions & 64 deletions src/api_client.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,31 @@
import os
from abc import ABC, abstractmethod

import httpx

from .chat_completion import AnthropicChatCompletion, OpenRouterChatCompletion
from .chat_completion import OpenRouterChatCompletion
from .logger import Logger


class APIClient(ABC):
class OpenRouterAPIClient:
API_URL = "https://openrouter.ai/api/v1/chat/completions"
ENV_KEY = "OPENROUTER_API_KEY"
TIMEOUT = 30

def __init__(self):
self.api_key = os.environ.get(self.ENV_KEY)
self.logger = Logger("log.yml")

@staticmethod
def create(client_type):
clients = {"openrouter": OpenRouterAPIClient, "anthropic": AnthropicAPIClient}
if client_type not in clients:
raise ValueError(f"Unsupported client type: {client_type}")
return clients[client_type]()
def get_headers(self):
return {"Authorization": f"Bearer {self.api_key}"}

def prepare_body(self, messages, parameters):
return {"messages": messages, **parameters}

async def request_completion(self, messages, parameters, pricing):
body = self.prepare_body(messages, parameters)
try:
completion_data = await self.get_completion_data(body)
completion = self.create_completion(completion_data, pricing)
completion = OpenRouterChatCompletion(completion_data, pricing)
self.logger.log(parameters, messages, completion.content)
return completion
except httpx.ReadTimeout:
Expand Down Expand Up @@ -53,57 +53,3 @@ async def _poll_details(self, client, generation_id, max_attempts=10):
return details_response.json()

raise TimeoutError("Details not available after maximum attempts")

@abstractmethod
def get_headers(self):
pass

@abstractmethod
def prepare_body(self, messages, parameters):
pass

@abstractmethod
def create_completion(self, response, pricing):
pass


class OpenRouterAPIClient(APIClient):
API_URL = "https://openrouter.ai/api/v1/chat/completions"
ENV_KEY = "OPENROUTER_API_KEY"

def get_headers(self):
return {"Authorization": f"Bearer {self.api_key}"}

def prepare_body(self, messages, parameters):
return {"messages": messages, **parameters}

def create_completion(self, response, pricing):
return OpenRouterChatCompletion(response, pricing)


class AnthropicAPIClient(APIClient):
API_URL = "https://api.anthropic.com/v1/messages"
ENV_KEY = "ANTHROPIC_API_KEY"

def get_headers(self):
return {
"x-api-key": self.api_key,
"anthropic-version": "2023-06-01",
"anthropic-beta": "prompt-caching-2024-07-31",
}

def prepare_body(self, messages, parameters):
other_messages, system = self._transform_messages(messages)
return {"messages": other_messages, "system": system, **parameters}

def create_completion(self, response, pricing):
return AnthropicChatCompletion(response, pricing)

def _transform_messages(self, original_messages):
messages = [msg for msg in original_messages if msg["role"] != "system"]
system = []
for msg in original_messages:
if msg["role"] == "system":
system.extend(msg["content"])

return messages, system
4 changes: 2 additions & 2 deletions src/lm_executors/chat_executor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import jinja2
import yaml

from ..api_client import APIClient
from ..api_client import OpenRouterAPIClient
from ..resolve_vars import resolve_vars


Expand All @@ -12,7 +12,7 @@ def __init__(self, context):
self.context = context

async def execute(self):
client = APIClient.create(self.context.api_provider)
client = OpenRouterAPIClient()

params = {"max_tokens": 1000}
if self.context.model is not None:
Expand Down

0 comments on commit 9ef54ae

Please sign in to comment.