Skip to content

Commit

Permalink
Add support for Anthropic API
Browse files Browse the repository at this point in the history
  • Loading branch information
njbbaer committed Aug 15, 2024
1 parent 13e9fac commit 254d02b
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 43 deletions.
87 changes: 60 additions & 27 deletions src/api_client.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,73 @@
import os
from abc import ABC, abstractmethod

import httpx

from .chat_completion import AnthropicChatCompletion, OpenRouterChatCompletion

class ApiClient:
TIMEOUT = 120
CONFIG = {
"openai": {
"api_url": "https://api.openai.com",
"api_key_env": "OPENAI_API_KEY",
},
"openrouter": {
"api_url": "https://openrouter.ai/api",
"api_key_env": "OPENROUTER_API_KEY",
},
}

def __init__(self, provider):
self.provider = provider
class APIClient(ABC):
TIMEOUT = 60

async def call_api(self, messages, parameters):
body = {"messages": messages, **parameters}

headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {}
def __init__(self):
self.api_key = os.environ.get(self.ENV_KEY)

async def call_api(self, body):
async with httpx.AsyncClient(timeout=self.TIMEOUT) as client:
url = f"{self.api_url}/v1/chat/completions"
response = await client.post(url, headers=headers, json=body)
response = await client.post(
self.API_URL, headers=self.get_headers(), json=body
)
response.raise_for_status()
return response.json()

@property
def api_url(self):
return self.CONFIG[self.provider]["api_url"]
@abstractmethod
def get_headers(self):
pass

@staticmethod
def create(client_type):
if client_type == "openrouter":
return OpenRouterAPIClient()
elif client_type == "anthropic":
return AnthropicAPIClient()
else:
raise ValueError(f"Unsupported client type: {client_type}")


class OpenRouterAPIClient(APIClient):
API_URL = "https://openrouter.ai/api/v1/chat/completions"
ENV_KEY = "OPENROUTER_API_KEY"

async def call_api(self, messages, parameters, pricing):
body = {"messages": messages, **parameters}
response = await super().call_api(body)
return OpenRouterChatCompletion(response, pricing)

def get_headers(self):
return {"Authorization": f"Bearer {self.api_key}"}


class AnthropicAPIClient(APIClient):
API_URL = "https://api.anthropic.com/v1/messages"
ENV_KEY = "ANTHROPIC_API_KEY"

async def call_api(self, messages, parameters, pricing):
messages, system = self._transform_messages(messages)
body = {"messages": messages, "system": system, **parameters}
response = await super().call_api(body)
return AnthropicChatCompletion(response, pricing)

def get_headers(self):
return {"x-api-key": self.api_key, "anthropic-version": "2023-06-01"}

def _transform_messages(self, original_messages):
messages = []
system = []

for message in original_messages:
if message["role"] == "system":
system.append({"type": "text", "text": message["content"]})
else:
messages.append(message)

@property
def api_key(self):
var = self.CONFIG[self.provider]["api_key_env"]
return os.environ.get(var)
return messages, system
51 changes: 37 additions & 14 deletions src/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ def __init__(self, response, pricing):

@classmethod
async def generate(cls, client, content, parameters, pricing=None):
response = await client.call_api(content, parameters)
completion = cls(response, pricing)
completion = await client.call_api(content, parameters, pricing)
completion.validate()
completion.logger.log(parameters, content, completion.content)
return completion
Expand All @@ -23,34 +22,58 @@ def validate(self):
if self.content == "":
raise Exception("Response was empty")

@property
def cost(self):
if self.pricing:
return (self.prompt_tokens / 1_000_000 * self.pricing[0]) + (
self.completion_tokens / 1_000_000 * self.pricing[1]
)
return 0

@property
def error_message(self):
return self.response.get("error", {}).get("message", "")


class AnthropicChatCompletion(ChatCompletion):
@property
def choice(self):
return self.response["choices"][0]
return self.response["content"][0]

@property
def content(self):
return self.choice["message"]["content"]
return self.choice["text"]

@property
def prompt_tokens(self):
return self.response["usage"]["prompt_tokens"]

@property
def completion_tokens(self):
return self.response["usage"]["completion_tokens"]
return self.response["usage"]["output_tokens"]

@property
def finish_reason(self):
return self.choice["finish_reason"]
return self.response["stop_reason"]


class OpenRouterChatCompletion(ChatCompletion):
@property
def cost(self):
if self.pricing:
return (self.prompt_tokens / 1_000_000 * self.pricing[0]) + (
self.completion_tokens / 1_000_000 * self.pricing[1]
)
return 0
def choice(self):
return self.response["choices"][0]

@property
def error_message(self):
return self.response.get("error", {}).get("message", "")
def content(self):
return self.choice["message"]["content"]

@property
def prompt_tokens(self):
return self.response["usage"]["prompt_tokens"]

@property
def completion_tokens(self):
return self.response["usage"]["completion_tokens"]

@property
def finish_reason(self):
return self.choice["finish_reason"]
4 changes: 2 additions & 2 deletions src/lm_executors/chat_executor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import jinja2
import yaml

from ..api_client import ApiClient
from ..api_client import APIClient
from ..chat_completion import ChatCompletion
from ..resolve_vars import resolve_vars

Expand All @@ -13,7 +13,7 @@ def __init__(self, context):
self.context = context

async def execute(self):
client = ApiClient(self.context.api_provider)
client = APIClient.create(self.context.api_provider)

params = {"max_tokens": 1000}
if self.context.model is not None:
Expand Down

0 comments on commit 254d02b

Please sign in to comment.