diff --git a/src/api_client.py b/src/api_client.py index 78aa869..b6236df 100644 --- a/src/api_client.py +++ b/src/api_client.py @@ -64,7 +64,11 @@ class AnthropicAPIClient(APIClient): ENV_KEY = "ANTHROPIC_API_KEY" def get_headers(self): - return {"x-api-key": self.api_key, "anthropic-version": "2023-06-01"} + return { + "x-api-key": self.api_key, + "anthropic-version": "2023-06-01", + "anthropic-beta": "prompt-caching-2024-07-31", + } def prepare_body(self, messages, parameters): other_messages, system = self._transform_messages(messages) @@ -75,9 +79,9 @@ def create_completion(self, response, pricing): def _transform_messages(self, original_messages): messages = [msg for msg in original_messages if msg["role"] != "system"] - system = [ - {"type": "text", "text": msg["content"]} - for msg in original_messages - if msg["role"] == "system" - ] + system = [] + for msg in original_messages: + if msg["role"] == "system": + system.extend(msg["content"]) + return messages, system diff --git a/src/chat_completion.py b/src/chat_completion.py index 68f19ea..33fb93d 100644 --- a/src/chat_completion.py +++ b/src/chat_completion.py @@ -24,6 +24,14 @@ def cost(self): def error_message(self): return self.response.get("error", {}).get("message", "") + @property + def cache_creation_input_tokens(self): + return 0 + + @property + def cache_read_input_tokens(self): + return 0 + class AnthropicChatCompletion(ChatCompletion): @property @@ -42,6 +50,14 @@ def prompt_tokens(self): def completion_tokens(self): return self.response["usage"]["output_tokens"] + @property + def cache_creation_input_tokens(self): + return self.response["usage"]["cache_creation_input_tokens"] + + @property + def cache_read_input_tokens(self): + return self.response["usage"]["cache_read_input_tokens"] + @property def finish_reason(self): return self.response["stop_reason"] diff --git a/src/lm_executors/chat_executor_template.yml b/src/lm_executors/chat_executor_template.yml index 85c7a51..f441779 100644 --- a/src/lm_executors/chat_executor_template.yml +++ b/src/lm_executors/chat_executor_template.yml @@ -1,7 +1,18 @@ --- +{% for artifact in artifacts %} - role: system - content: |- - {{ chat_prompt | indent(4) }} + content: + - type: text + text: |- + {{ artifact | indent(8) }} + cache_control: + type: ephemeral +{% endfor %} +- role: system + content: + - type: text + text: |- + {{ chat_prompt | indent(8) }} {% for message in messages %} {% if message.image_url %} - role: {{ message.role }} @@ -20,7 +31,9 @@ {% endif %} {% endfor %} {% if reinforcement_chat_prompt %} -- role: {{ 'system' if not model.startswith('anthropic/') else 'assistant' }} - content: |- - {{ reinforcement_chat_prompt | indent(4) }} +- role: {{ 'system' if not 'claude' in model else 'assistant' }} + content: + - type: text + text: |- + {{ reinforcement_chat_prompt | indent(8) }} {% endif %} diff --git a/src/telegram/telegram_bot.py b/src/telegram/telegram_bot.py index fc6cf3d..03b92fd 100644 --- a/src/telegram/telegram_bot.py +++ b/src/telegram/telegram_bot.py @@ -88,18 +88,26 @@ async def undo_command_handler(self, ctx): @message_handler async def stats_command_handler(self, ctx): - lines = [] - lines.append("*Conversation*") - lines.append(f"`Cost: ${self.sim.get_conversation_cost():.2f}`") - lines.append("\n*Last Message*") - last_completion = self.sim.last_completion - if last_completion: - lines.append(f"`Cost: ${last_completion.cost:.2f}`") - lines.append(f"`Prompt tokens: {last_completion.prompt_tokens}`") - lines.append(f"`Completion tokens: {last_completion.completion_tokens}`") + conversation_cost = ( + f"*Conversation*\n`Cost: ${self.sim.get_conversation_cost():.2f}`" + ) + + last_message_stats = "*Last Message*\n" + if self.sim.last_completion: + lc = self.sim.last_completion + last_message_stats += "\n".join( + [ + f"`Cost: ${lc.cost:.2f}`", + f"`Prompt tokens: {lc.prompt_tokens}`", + f"`Completion tokens: {lc.completion_tokens}`", + f"`Cache creation tokens: {lc.cache_creation_input_tokens}`", + f"`Cache read tokens: {lc.cache_read_input_tokens}`", + ] + ) else: - lines.append("`Not available`") - await ctx.send_message("\n".join(lines)) + last_message_stats += "`Not available`" + + await ctx.send_message(f"{conversation_cost}\n\n{last_message_stats}") @message_handler async def clear_command_handler(self, ctx):