Skip to content

Commit

Permalink
Add cacheable artifacts
Browse files Browse the repository at this point in the history
  • Loading branch information
njbbaer committed Aug 16, 2024
1 parent 9302044 commit 3d05d59
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 22 deletions.
16 changes: 10 additions & 6 deletions src/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,11 @@ class AnthropicAPIClient(APIClient):
ENV_KEY = "ANTHROPIC_API_KEY"

def get_headers(self):
    """Build the HTTP headers for Anthropic API requests.

    Returns the API key header, the pinned API version, and the beta
    flag that enables prompt caching.
    """
    return {
        "x-api-key": self.api_key,
        "anthropic-version": "2023-06-01",
        "anthropic-beta": "prompt-caching-2024-07-31",
    }

def prepare_body(self, messages, parameters):
other_messages, system = self._transform_messages(messages)
Expand All @@ -75,9 +79,9 @@ def create_completion(self, response, pricing):

def _transform_messages(self, original_messages):
messages = [msg for msg in original_messages if msg["role"] != "system"]
system = [
{"type": "text", "text": msg["content"]}
for msg in original_messages
if msg["role"] == "system"
]
system = []
for msg in original_messages:
if msg["role"] == "system":
system.extend(msg["content"])

return messages, system
16 changes: 16 additions & 0 deletions src/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ def cost(self):
def error_message(self):
return self.response.get("error", {}).get("message", "")

@property
def cache_creation_input_tokens(self):
    """Tokens written to the prompt cache; base default is 0 for providers without caching."""
    return 0

@property
def cache_read_input_tokens(self):
    """Tokens served from the prompt cache; base default is 0 for providers without caching."""
    return 0


class AnthropicChatCompletion(ChatCompletion):
@property
Expand All @@ -42,6 +50,14 @@ def prompt_tokens(self):
def completion_tokens(self):
return self.response["usage"]["output_tokens"]

@property
def cache_creation_input_tokens(self):
    """Tokens written to the prompt cache for this completion.

    Falls back to 0 when the usage block omits the field (e.g. a response
    produced without the prompt-caching beta), matching the base-class default.
    """
    return self.response["usage"].get("cache_creation_input_tokens", 0)

@property
def cache_read_input_tokens(self):
    """Tokens served from the prompt cache for this completion.

    Falls back to 0 when the usage block omits the field (e.g. a response
    produced without the prompt-caching beta), matching the base-class default.
    """
    return self.response["usage"].get("cache_read_input_tokens", 0)

@property
def finish_reason(self):
    # Anthropic names this field "stop_reason"; exposed under the
    # provider-neutral name used by the rest of the codebase.
    return self.response["stop_reason"]
Expand Down
23 changes: 18 additions & 5 deletions src/lm_executors/chat_executor_template.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
---
{% for artifact in artifacts %}
- role: system
content: |-
{{ chat_prompt | indent(4) }}
content:
- type: text
text: |-
{{ artifact | indent(8) }}
cache_control:
type: ephemeral
{% endfor %}
- role: system
content:
- type: text
text: |-
{{ chat_prompt | indent(8) }}
{% for message in messages %}
{% if message.image_url %}
- role: {{ message.role }}
Expand All @@ -20,7 +31,9 @@
{% endif %}
{% endfor %}
{% if reinforcement_chat_prompt %}
- role: {{ 'system' if not model.startswith('anthropic/') else 'assistant' }}
content: |-
{{ reinforcement_chat_prompt | indent(4) }}
- role: {{ 'system' if not 'claude' in model else 'assistant' }}
content:
- type: text
text: |-
{{ reinforcement_chat_prompt | indent(8) }}
{% endif %}
30 changes: 19 additions & 11 deletions src/telegram/telegram_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,18 +88,26 @@ async def undo_command_handler(self, ctx):

@message_handler
async def stats_command_handler(self, ctx):
    """Reply with cost and token statistics for the conversation and the last message."""
    conversation_cost = (
        f"*Conversation*\n`Cost: ${self.sim.get_conversation_cost():.2f}`"
    )

    last_message_stats = "*Last Message*\n"
    lc = self.sim.last_completion  # hoisted: read once, used for both the check and the stats
    if lc:
        last_message_stats += "\n".join(
            [
                f"`Cost: ${lc.cost:.2f}`",
                f"`Prompt tokens: {lc.prompt_tokens}`",
                f"`Completion tokens: {lc.completion_tokens}`",
                f"`Cache creation tokens: {lc.cache_creation_input_tokens}`",
                f"`Cache read tokens: {lc.cache_read_input_tokens}`",
            ]
        )
    else:
        last_message_stats += "`Not available`"

    await ctx.send_message(f"{conversation_cost}\n\n{last_message_stats}")

@message_handler
async def clear_command_handler(self, ctx):
Expand Down

0 comments on commit 3d05d59

Please sign in to comment.