Skip to content

Commit

Permalink
Merge pull request #62 from rob1nzon/main
Browse files Browse the repository at this point in the history
feat: Add mention handling & format prompt threads
  • Loading branch information
ruecat authored Aug 9, 2024
2 parents 13f530f + 1085a9a commit 0fb16aa
Showing 1 changed file with 45 additions and 13 deletions.
58 changes: 45 additions & 13 deletions bot/run.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from aiogram import Bot, Dispatcher
from aiogram import Bot, Dispatcher, types
from aiogram.enums import ParseMode
from aiogram.filters.command import Command, CommandStart
from aiogram.types import Message
Expand Down Expand Up @@ -34,17 +34,13 @@
CHAT_TYPE_SUPERGROUP = "supergroup"


def is_mentioned_in_group_or_supergroup(message):
return message.chat.type in [CHAT_TYPE_GROUP, CHAT_TYPE_SUPERGROUP] and (
(message.text is not None and message.text.startswith(mention))
or (message.caption is not None and message.caption.startswith(mention))
)
async def get_bot_info():
global mention
if mention is None:
get = await bot.get_me()
mention = f"@{get.username}"
return mention

@dp.message(CommandStart())
async def command_start_handler(message: Message) -> None:
start_message = f"Welcome, <b>{message.from_user.full_name}</b>!"
Expand Down Expand Up @@ -141,17 +137,53 @@ async def about_callback_handler(query: types.CallbackQuery):
@perms_allowed
async def handle_message(message: types.Message):
await get_bot_info()

if message.chat.type == "private":
await ollama_request(message)
if is_mentioned_in_group_or_supergroup(message):
if message.text is not None:
text_without_mention = message.text.replace(mention, "").strip()
prompt = text_without_mention
else:
text_without_mention = message.caption.replace(mention, "").strip()
prompt = text_without_mention
return

if await is_mentioned_in_group_or_supergroup(message):
thread = await collect_message_thread(message)
prompt = format_thread_for_prompt(thread)

await ollama_request(message, prompt)

async def is_mentioned_in_group_or_supergroup(message: types.Message):
if message.chat.type not in ["group", "supergroup"]:
return False

is_mentioned = (
(message.text and message.text.startswith(mention)) or
(message.caption and message.caption.startswith(mention))
)

is_reply_to_bot = (
message.reply_to_message and
message.reply_to_message.from_user.id == bot.id
)

return is_mentioned or is_reply_to_bot

async def collect_message_thread(message: types.Message, thread=None):
if thread is None:
thread = []

thread.insert(0, message)

if message.reply_to_message:
await collect_message_thread(message.reply_to_message, thread)

return thread

def format_thread_for_prompt(thread):
prompt = "Conversation thread:\n\n"
for msg in thread:
sender = "User" if msg.from_user.id != bot.id else "Bot"
content = msg.text or msg.caption or "[No text content]"
prompt += f"{sender}: {content}\n\n"

prompt += "History:"
return prompt

async def process_image(message):
image_base64 = ""
Expand Down

0 comments on commit 0fb16aa

Please sign in to comment.