Skip to content

Commit

Permalink
Merge branch 'message-caching'
Browse files Browse the repository at this point in the history
  • Loading branch information
njbbaer committed Nov 11, 2024
2 parents ebafbfa + 8299a18 commit 73f043f
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 20 deletions.
31 changes: 21 additions & 10 deletions src/simulacrum.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
import textwrap

from .context import Context
from .lm_executors import ChatExecutor
Expand All @@ -9,11 +10,12 @@ def __init__(self, context_file):
self.context = Context(context_file)
self.last_completion = None
self.cost_warning_sent = False
self.instruction_text = None

async def chat(self, user_input, user_name, image_url):
self.context.load()
if user_input:
user_input = self._apply_attribution(user_input, user_name)
user_input = self._inject_instruction(user_input)
self.context.add_message("user", user_input, image_url)
self.context.save()
completion = await ChatExecutor(self.context).execute()
Expand All @@ -37,11 +39,14 @@ def reset_conversation(self):
self.context.save()
self.cost_warning_sent = False

def add_conversation_fact(self, fact):
def add_conversation_fact(self, fact_text):
self.context.load()
self.context.add_conversation_fact(fact)
self.context.add_conversation_fact(fact_text)
self.context.save()

def apply_instruction(self, instruction_text):
self.instruction_text = instruction_text

def undo_last_messages_by_role(self, role):
self.context.load()
num_messages = len(self.context.conversation_messages)
Expand Down Expand Up @@ -69,10 +74,16 @@ def _strip_tag(self, content, tag):
content = re.sub(r"\n{3,}", "\n\n", content)
return content.strip()

def _apply_attribution(self, input, name):
attribute_messages = self.context.vars.get("attribute_messages")
if attribute_messages and name:
if isinstance(attribute_messages, dict) and name in attribute_messages:
name = attribute_messages[name]
input = f"{name}: {input}"
return input
def _inject_instruction(self, text):
if self.instruction_text:
text = textwrap.dedent(
f"""
{text}
<instruct>
{self.instruction_text}
</instruct>
"""
)
self.instruction_text = None
return text
32 changes: 22 additions & 10 deletions src/telegram/telegram_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,12 @@ def __init__(self, context_filepath, telegram_token, authorized_users):
command_handlers = [
(["new", "n"], self.new_conversation_command_handler),
(["retry", "r"], self.retry_command_handler),
(["reply", "rep"], self.reply_command_handler),
(["continue", "co"], self.continue_command_handler),
(["undo", "u"], self.undo_command_handler),
(["fact", "f"], self.add_fact_command_handler),
(["instruct", "i"], self.apply_instruction_command_handler),
(["stats", "s"], self.stats_command_handler),
(["clear", "c"], self.clear_command_handler),
(["clear", "cl"], self.clear_command_handler),
(["help", "h"], self.help_command_handler),
(["start"], self.do_nothing),
]
Expand Down Expand Up @@ -82,7 +83,7 @@ async def retry_command_handler(self, ctx):
await self._chat(ctx, user_message=None)

@message_handler
async def reply_command_handler(self, ctx):
async def continue_command_handler(self, ctx):
await self._chat(ctx, user_message=None)

@message_handler
Expand Down Expand Up @@ -119,12 +120,15 @@ async def clear_command_handler(self, ctx):

@message_handler
async def add_fact_command_handler(self, ctx):
fact_text = re.search(r"/fact (.*)", ctx.message.text)
if fact_text:
self.sim.add_conversation_fact(fact_text.group(1))
await ctx.send_message("`✅ Fact added to conversation`")
else:
await ctx.send_message("`❌ No text provided`")
await self._process_text_after_command(
ctx, self.sim.add_conversation_fact, "`✅ Fact added to conversation`"
)

@message_handler
async def apply_instruction_command_handler(self, ctx):
await self._process_text_after_command(
ctx, self.sim.apply_instruction, "`✅ Instruction applied to next response`"
)

@message_handler
async def help_command_handler(self, ctx):
Expand All @@ -134,10 +138,10 @@ async def help_command_handler(self, ctx):
*Actions*
/new - Start a new conversation
/retry - Retry the last response
/reply - Reply immediately
/undo - Undo the last exchange
/clear - Clear the conversation
/fact - Add a fact to the conversation
/continue - Request another response
*Information*
/stats - Show conversation statistics
Expand Down Expand Up @@ -181,3 +185,11 @@ async def _warn_cost(self, ctx, threshold_high=0.15, threshold_elevated=0.10):
await ctx.send_message(
"🟡 Cost is elevated. Start a new conversation when ready."
)

async def _process_text_after_command(self, ctx, action_method, success_message):
command_text = re.search(r"/\w+\s+(.*)", ctx.message.text)
if command_text:
action_method(command_text.group(1))
await ctx.send_message(success_message)
else:
await ctx.send_message("`❌ No text provided`")

0 comments on commit 73f043f

Please sign in to comment.