Skip to content

Commit

Permalink
Merge pull request #18 from rob1nzon/Add-img-support
Browse files Browse the repository at this point in the history
  • Loading branch information
ruecat authored Jan 17, 2024
2 parents 88e2074 + e9feab5 commit 470ce16
Showing 1 changed file with 23 additions and 6 deletions.
29 changes: 23 additions & 6 deletions bot/run.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asyncio
import traceback

import io
import base64
from aiogram import Bot, Dispatcher, types
from aiogram.enums import ParseMode
from aiogram.filters.command import Command, CommandStart
Expand Down Expand Up @@ -90,10 +92,14 @@ async def modelmanager_callback_handler(query: types.CallbackQuery):
modelmanager_builder = InlineKeyboardBuilder()
for model in models:
modelname = model["name"]
modelfamilies = ""
if model["details"]["families"]:
modelicon = {"llama": "🦙","clip":"📷"}
modelfamilies = "".join([modelicon[family] for family in model['details']['families']])
# Add a button for each model
modelmanager_builder.row(
types.InlineKeyboardButton(
text=modelname, callback_data=f"model_{modelname}"
text=f"{modelname} {modelfamilies}", callback_data=f"model_{modelname}"
)
)
await query.message.edit_text(
Expand All @@ -106,6 +112,7 @@ async def modelmanager_callback_handler(query: types.CallbackQuery):
@dp.callback_query(lambda query: query.data.startswith("model_"))
async def model_callback_handler(query: types.CallbackQuery):
global modelname
global modelfamily
modelname = query.data.split("model_")[1]
await query.answer(f"Chosen model: {modelname}")

Expand Down Expand Up @@ -137,14 +144,24 @@ async def handle_message(message: types.Message):
)
if (
is_allowed_user
and message.text
and (message.text or message.caption)
and (is_private_chat or (is_supergroup and bot_mentioned))
):
if is_supergroup and bot_mentioned:
cutmention = len(botinfo.username) + 2
prompt = message.text[cutmention:] # + ""
prompt = message.text[cutmention:] or message.caption[cutmention:] # + ""
else:
prompt = message.text
prompt = message.text or message.caption

image_base64=''
if message.content_type=='photo':
image_buffer = io.BytesIO()
await bot.download(
message.photo[-1],
destination=image_buffer
)
image_base64 = base64.b64encode(image_buffer.getvalue()).decode('utf-8')

await bot.send_chat_action(message.chat.id, "typing")
full_response = ""
sent_message = None
Expand All @@ -155,12 +172,12 @@ async def handle_message(message: types.Message):
if ACTIVE_CHATS.get(message.from_user.id) is None:
ACTIVE_CHATS[message.from_user.id] = {
"model": modelname,
"messages": [{"role": "user", "content": prompt}],
"messages": [{"role": "user", "content": prompt, "images": [image_base64]}],
"stream": True,
}
else:
ACTIVE_CHATS[message.from_user.id]["messages"].append(
{"role": "user", "content": prompt}
{"role": "user", "content": prompt, "images": [image_base64]}
)
logging.info(
f"[Request]: Processing '{prompt}' for {message.from_user.first_name} {message.from_user.last_name}"
Expand Down

0 comments on commit 470ce16

Please sign in to comment.