From e9feab5e8401e9032ccbda20fd8507a3a4e3c437 Mon Sep 17 00:00:00 2001 From: Denis Date: Wed, 17 Jan 2024 00:15:36 +0300 Subject: [PATCH] Add multimodal model support --- bot/run.py | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/bot/run.py b/bot/run.py index 18b77ad..101169d 100644 --- a/bot/run.py +++ b/bot/run.py @@ -1,6 +1,8 @@ import asyncio import traceback +import io +import base64 from aiogram import Bot, Dispatcher, types from aiogram.enums import ParseMode from aiogram.filters.command import Command, CommandStart @@ -90,10 +92,14 @@ async def modelmanager_callback_handler(query: types.CallbackQuery): modelmanager_builder = InlineKeyboardBuilder() for model in models: modelname = model["name"] + modelfamilies = "" + if model["details"]["families"]: + modelicon = {"llama": "🦙","clip":"📷"} + modelfamilies = "".join([modelicon[family] for family in model['details']['families']]) # Add a button for each model modelmanager_builder.row( types.InlineKeyboardButton( - text=modelname, callback_data=f"model_{modelname}" + text=f"{modelname} {modelfamilies}", callback_data=f"model_{modelname}" ) ) await query.message.edit_text( @@ -106,6 +112,7 @@ async def modelmanager_callback_handler(query: types.CallbackQuery): @dp.callback_query(lambda query: query.data.startswith("model_")) async def model_callback_handler(query: types.CallbackQuery): global modelname + global modelfamily modelname = query.data.split("model_")[1] await query.answer(f"Chosen model: {modelname}") @@ -137,14 +144,24 @@ async def handle_message(message: types.Message): ) if ( is_allowed_user - and message.text + and (message.text or message.caption) and (is_private_chat or (is_supergroup and bot_mentioned)) ): if is_supergroup and bot_mentioned: cutmention = len(botinfo.username) + 2 - prompt = message.text[cutmention:] # + "" + prompt = message.text[cutmention:] or message.caption[cutmention:] # + "" else: - prompt = message.text + prompt = message.text or message.caption + + image_base64='' + if message.content_type=='photo': + image_buffer = io.BytesIO() + await bot.download( + message.photo[-1], + destination=image_buffer + ) + image_base64 = base64.b64encode(image_buffer.getvalue()).decode('utf-8') + await bot.send_chat_action(message.chat.id, "typing") full_response = "" sent_message = None @@ -155,12 +172,12 @@ async def handle_message(message: types.Message): if ACTIVE_CHATS.get(message.from_user.id) is None: ACTIVE_CHATS[message.from_user.id] = { "model": modelname, - "messages": [{"role": "user", "content": prompt}], + "messages": [{"role": "user", "content": prompt, "images": [image_base64]}], "stream": True, } else: ACTIVE_CHATS[message.from_user.id]["messages"].append( - {"role": "user", "content": prompt} + {"role": "user", "content": prompt, "images": [image_base64]} ) logging.info( f"[Request]: Processing '{prompt}' for {message.from_user.first_name} {message.from_user.last_name}"