From c4eca55a9f904a2461d1d9e8174111b663fbfadf Mon Sep 17 00:00:00 2001 From: Zhang TingAn Date: Wed, 4 Dec 2024 10:24:12 +0800 Subject: [PATCH] feat: [AI] support adding multi LLM optimizing existing features Log:as title --- src/base/CMakeLists.txt | 2 + src/base/ai/abstractllm.cpp | 11 + src/base/ai/abstractllm.h | 45 +++ src/base/ai/conversation.cpp | 149 ++++++++ src/base/ai/conversation.h | 50 +++ src/plugins/CMakeLists.txt | 1 + src/plugins/aimanager/CMakeLists.txt | 54 +++ src/plugins/aimanager/aimanager.cpp | 92 +++++ src/plugins/aimanager/aimanager.h | 34 ++ src/plugins/aimanager/aimanager.json | 15 + src/plugins/aimanager/aimanager.qrc | 4 + src/plugins/aimanager/aiplugin.cpp | 41 +++ src/plugins/aimanager/aiplugin.h | 25 ++ src/plugins/aimanager/eventreceiver.cpp | 26 ++ src/plugins/aimanager/eventreceiver.h | 28 ++ .../openai/openaicompatibleconversation.cpp | 106 ++++++ .../openai/openaicompatibleconversation.h | 22 ++ .../aimanager/openai/openaicompatiblellm.cpp | 339 ++++++++++++++++++ .../aimanager/openai/openaicompatiblellm.h | 42 +++ .../aimanager/option/addmodeldialog.cpp | 232 ++++++++++++ src/plugins/aimanager/option/addmodeldialog.h | 23 ++ .../option/custommodelsoptionwidget.cpp | 69 ++++ .../option/custommodelsoptionwidget.h | 25 ++ src/plugins/aimanager/option/detailwidget.cpp | 207 +++++++++++ src/plugins/aimanager/option/detailwidget.h | 87 +++++ .../option/optioncustommodelsgenerator.cpp | 23 ++ .../option/optioncustommodelsgenerator.h | 18 + src/plugins/codegeex/codegeex.cpp | 3 +- src/plugins/codegeex/codegeex/askapi.cpp | 93 +++-- src/plugins/codegeex/codegeex/askapi.h | 5 + src/plugins/codegeex/codegeexmanager.cpp | 54 ++- src/plugins/codegeex/codegeexmanager.h | 3 + src/services/ai/aiservice.h | 60 +++- src/services/option/optiondatastruct.h | 1 + 34 files changed, 1951 insertions(+), 38 deletions(-) create mode 100644 src/base/ai/abstractllm.cpp create mode 100644 src/base/ai/abstractllm.h create mode 100644 src/base/ai/conversation.cpp create mode 100644 src/base/ai/conversation.h create mode 100644 src/plugins/aimanager/CMakeLists.txt create mode 100644 src/plugins/aimanager/aimanager.cpp create mode 100644 src/plugins/aimanager/aimanager.h create mode 100644 src/plugins/aimanager/aimanager.json create mode 100644 src/plugins/aimanager/aimanager.qrc create mode 100644 src/plugins/aimanager/aiplugin.cpp create mode 100644 src/plugins/aimanager/aiplugin.h create mode 100644 src/plugins/aimanager/eventreceiver.cpp create mode 100644 src/plugins/aimanager/eventreceiver.h create mode 100644 src/plugins/aimanager/openai/openaicompatibleconversation.cpp create mode 100644 src/plugins/aimanager/openai/openaicompatibleconversation.h create mode 100644 src/plugins/aimanager/openai/openaicompatiblellm.cpp create mode 100644 src/plugins/aimanager/openai/openaicompatiblellm.h create mode 100644 src/plugins/aimanager/option/addmodeldialog.cpp create mode 100644 src/plugins/aimanager/option/addmodeldialog.h create mode 100644 src/plugins/aimanager/option/custommodelsoptionwidget.cpp create mode 100644 src/plugins/aimanager/option/custommodelsoptionwidget.h create mode 100644 src/plugins/aimanager/option/detailwidget.cpp create mode 100644 src/plugins/aimanager/option/detailwidget.h create mode 100644 src/plugins/aimanager/option/optioncustommodelsgenerator.cpp create mode 100644 src/plugins/aimanager/option/optioncustommodelsgenerator.h diff --git a/src/base/CMakeLists.txt b/src/base/CMakeLists.txt index aaf727df1..95fb8e1a8 100644 --- a/src/base/CMakeLists.txt +++ b/src/base/CMakeLists.txt @@ -3,6 +3,8 @@ project(duc-base) FILE(GLOB_RECURSE PROJECT_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/*.h" "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/*/*.h" + "${CMAKE_CURRENT_SOURCE_DIR}/*/*.cpp" ) add_library( diff --git a/src/base/ai/abstractllm.cpp b/src/base/ai/abstractllm.cpp new file mode 100644 index 000000000..a26e1dfbc --- /dev/null +++ b/src/base/ai/abstractllm.cpp @@ -0,0 +1,11 @@ +// SPDX-FileCopyrightText: 2024 UnionTech Software Technology Co., Ltd. +// +// SPDX-License-Identifier: GPL-3.0-or-later + +#include "abstractllm.h" + +AbstractLLM::AbstractLLM(QObject *parent) + : QObject(parent) +{ + qRegisterMetaType("ResponseState"); +} diff --git a/src/base/ai/abstractllm.h b/src/base/ai/abstractllm.h new file mode 100644 index 000000000..2117146ad --- /dev/null +++ b/src/base/ai/abstractllm.h @@ -0,0 +1,45 @@ +// SPDX-FileCopyrightText: 2024 UnionTech Software Technology Co., Ltd. +// +// SPDX-License-Identifier: GPL-3.0-or-later + +#ifndef ABSTRACTLLM_H +#define ABSTRACTLLM_H + +#include "conversation.h" + +#include +#include + +class AbstractLLM : public QObject +{ + Q_OBJECT +public: + enum ResponseState { + Receiving, + Success, + CutByLength, + Failed, + Canceled + }; + + explicit AbstractLLM(QObject *parent = nullptr); + virtual ~AbstractLLM() {} + + virtual QString modelPath() const = 0; + virtual bool checkValid(QString *errStr) = 0; + virtual QJsonObject create(const Conversation &conversation) = 0; + virtual void request(const QJsonObject &data) = 0; + virtual void request(const QString &prompt) = 0; + virtual void generate(const QString &prompt, const QString &suffix) = 0; + virtual void setTemperature(double temperature) = 0; + virtual void setStream(bool isStream) = 0; + virtual void processResponse(QNetworkReply *reply) = 0; + virtual void cancel() = 0; + virtual void setMaxTokens(int maxToken) = 0; + virtual Conversation *getCurrentConversation() = 0; + +signals: + void dataReceived(const QString &data, ResponseState statu); +}; + +#endif diff --git a/src/base/ai/conversation.cpp b/src/base/ai/conversation.cpp new file mode 100644 index 000000000..322c1d78f --- /dev/null +++ b/src/base/ai/conversation.cpp @@ -0,0 +1,149 @@ +// SPDX-FileCopyrightText: 2024 UnionTech Software Technology Co., Ltd. +// +// SPDX-License-Identifier: GPL-3.0-or-later + +#include "conversation.h" + +#include +#include + +Conversation::Conversation() +{ +} + +Conversation::~Conversation() +{ +} + +QString Conversation::conversationLastUserData(const QString &conversation) +{ + const QJsonArray &array = QJsonDocument::fromJson(conversation.toUtf8()).array(); + if (!array.isEmpty() && array.last()["role"] == "user") { + return array.last()["content"].toString(); + } + + return conversation; +} + +bool Conversation::setSystemData(const QString &data) +{ + if (!data.isEmpty()) { + for (auto iter = conversation.begin(); iter != conversation.end(); iter++) { + if (iter->toObject().value("role").toString() == "system") + return false; + } + conversation.insert(0, QJsonObject({ { "role", "system" }, {"content", data} })); + return true; + } + return false; +} + +bool Conversation::popSystemData() +{ + if (!conversation.isEmpty() && conversation.at(0)["role"].toString() == "system") { + conversation.removeFirst(); + return true; + } + return false; +} + +bool Conversation::addUserData(const QString &data) +{ + if (!data.isEmpty()) { + const QJsonDocument &document = QJsonDocument::fromJson(data.toUtf8()); + if (document.isArray()) { + conversation = document.array(); + } else { + conversation.push_back(QJsonObject({ { "role", "user" }, {"content", data} })); + } + return true; + } + return false; +} + +bool Conversation::popUserData() +{ + if (!conversation.isEmpty() && conversation.last()["role"].toString() == "user") { + conversation.removeLast(); + return true; + } + return false; +} + +QString Conversation::getLastResponse() const +{ + if (!conversation.isEmpty() && conversation.last()["role"].toString() == "assistant") { + return conversation.last()["content"].toString(); + } + return QString(); +} + +QByteArray Conversation::getLastByteResponse() const +{ + if (!conversation.isEmpty() && conversation.last()["role"].toString() == "assistant") { + return conversation.last()["content"].toVariant().toByteArray(); + } + return QByteArray(); +} + +bool Conversation::popLastResponse() +{ + if (!conversation.isEmpty() && conversation.last()["role"].toString() == "assistant") { + conversation.removeLast(); + return true; + } + return false; +} + +QJsonObject Conversation::getLastTools() const +{ + if (!conversation.isEmpty() && conversation.last()["role"].toString() == "tools") { + return conversation.last()["content"].toObject(); + } + + return QJsonObject(); +} + +bool Conversation::popLastTools() +{ + if (!conversation.isEmpty() && conversation.last()["role"].toString() == "tools") { + conversation.removeLast(); + return true; + } + return false; +} + +bool Conversation::setFunctions(const QJsonArray &functions) +{ + this->functions = functions; + return true; +} + +QJsonArray Conversation::getConversions() const +{ + return conversation; +} + +QJsonArray Conversation::getFunctions() const +{ + return functions; +} + +QJsonArray Conversation::getFunctionTools() const +{ + QJsonArray tools; + for (const QJsonValue &fun : functions) { + QJsonObject tool; + tool["type"] = "function"; + tool["function"] = fun; + tools << tool; + } + + return tools; +} + +void Conversation::clear() +{ + conversation = QJsonArray(); + functions = QJsonArray(); +} diff --git a/src/base/ai/conversation.h b/src/base/ai/conversation.h new file mode 100644 index 000000000..f368591f2 --- /dev/null +++ b/src/base/ai/conversation.h @@ -0,0 +1,50 @@ +// SPDX-FileCopyrightText: 2024 UnionTech Software Technology Co., Ltd. +// +// SPDX-License-Identifier: GPL-3.0-or-later + +#ifndef CONVERSATION_H +#define CONVERSATION_H + +#include +#include + +class Conversation +{ +public: + Conversation(); + virtual ~Conversation(); + + /** + * @brief conversationLastUserData + * @param conversation + * @return + */ + static QString conversationLastUserData(const QString &conversation); + +public: + bool setSystemData(const QString &data); + bool popSystemData(); + + bool addUserData(const QString &data); + bool popUserData(); + + QString getLastResponse() const; + QByteArray getLastByteResponse() const; + bool popLastResponse(); + + QJsonObject getLastTools() const; + bool popLastTools(); + + bool setFunctions(const QJsonArray &functions); + QJsonArray getFunctions() const; + QJsonArray getFunctionTools() const; + + QJsonArray getConversions() const; + + void clear(); +protected: + QJsonArray conversation; + QJsonArray functions; +}; + +#endif // CONVERSATION_H diff --git a/src/plugins/CMakeLists.txt b/src/plugins/CMakeLists.txt index 81359d82a..3d1ff7b82 100644 --- a/src/plugins/CMakeLists.txt +++ b/src/plugins/CMakeLists.txt @@ -45,3 +45,4 @@ add_subdirectory(commandproxy) add_subdirectory(codegeex) add_subdirectory(git) add_subdirectory(linglong) +add_subdirectory(aimanager) diff --git a/src/plugins/aimanager/CMakeLists.txt b/src/plugins/aimanager/CMakeLists.txt new file mode 100644 index 000000000..f499d49ba --- /dev/null +++ b/src/plugins/aimanager/CMakeLists.txt @@ -0,0 +1,54 @@ +cmake_minimum_required(VERSION 3.0.2) + +set(CMAKE_CXX_STANDARD 17) + +project(aimanager) + +find_package(duc-base REQUIRED) +find_package(duc-common REQUIRED) +find_package(duc-framework REQUIRED) +find_package(duc-services REQUIRED) +find_package(Dtk COMPONENTS Widget REQUIRED) + +set(QtFindModules Core Gui Widgets Concurrent) +foreach(QtModule ${QtFindModules}) + find_package(Qt5 COMPONENTS ${QtModule} REQUIRED) + # include qt module private include directors + include_directories(${Qt5${QtModule}_PRIVATE_INCLUDE_DIRS}) + # can use target_link_libraries(xxx ${QtUseModules}) + list(APPEND QtUseModules "Qt5::${QtModule}") + message("QtModule found ${QtModule} OK!") +endforeach() + +FILE(GLOB CODEGEEX_FILES + "${CMAKE_CURRENT_SOURCE_DIR}/*.h" + "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/*/*.h" + "${CMAKE_CURRENT_SOURCE_DIR}/*/*.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/*.json" +) + +add_library(${PROJECT_NAME} + SHARED + ${CODEGEEX_FILES} + aimanager.qrc + ) + +target_link_libraries(${PROJECT_NAME} + duc-framework + duc-base + duc-services + duc-common + ${QtUseModules} + ${PkgUserModules} + ${DtkWidget_LIBRARIES} + ) + + +if(NOT PLUGIN_INSTALL_PATH) + set(PLUGIN_INSTALL_PATH "/usr/lib/${CMAKE_LIBRARY_ARCHITECTURE}/deepin-unioncode/plugins") +endif() + +install(TARGETS ${PROJECT_NAME} LIBRARY DESTINATION ${PLUGIN_INSTALL_PATH}) + + diff --git a/src/plugins/aimanager/aimanager.cpp b/src/plugins/aimanager/aimanager.cpp new file mode 100644 index 000000000..ea3800ebc --- /dev/null +++ b/src/plugins/aimanager/aimanager.cpp @@ -0,0 +1,92 @@ +// SPDX-FileCopyrightText: 2024 UnionTech Software Technology Co., Ltd. +// +// SPDX-License-Identifier: GPL-3.0-or-later + +#include "aimanager.h" +#include "services/ai/aiservice.h" +#include "openai/openaicompatiblellm.h" +#include "services/option/optionmanager.h" +#include "option/detailwidget.h" + +#include + +using namespace dpfservice; + +class AiManagerPrivate +{ +public: + QList models; +}; + +AiManager *AiManager::instance() +{ + static AiManager ins; + return &ins; +} + +AiManager::AiManager(QObject *parent) +: QObject(parent) +, d(new AiManagerPrivate) +{ + readLLMFromOption(); +} + +AiManager::~AiManager() +{ + delete d; +} + +QList AiManager::getAllModel() +{ + return d->models; +} + +AbstractLLM *AiManager::getLLM(const LLMInfo &info) +{ + if (d->models.contains(info)) { + if (info.type == LLMType::OPENAI) { + auto llm = new OpenAiCompatibleLLM(this); + llm->setModelName(info.modelName); + llm->setModelPath(info.modelPath); + if (!info.apikey.isEmpty()) + llm->setApiKey(info.apikey); + return llm; + } + } + + return nullptr; +} + +bool AiManager::checkModelValid(const LLMInfo &info, QString *errStr) +{ + if (info.type == LLMType::OPENAI) { + OpenAiCompatibleLLM llm; + llm.setModelName(info.modelName); + llm.setModelPath(info.modelPath); + llm.setApiKey(info.apikey); + bool valid = llm.checkValid(errStr); + return valid; + } +} + +void AiManager::appendModel(const LLMInfo &info) +{ + if (!d->models.contains(info)) + d->models.append(info); +} + +void AiManager::removeModel(const LLMInfo &info) +{ + if (d->models.contains(info)) + d->models.removeOne(info); +} + +void AiManager::readLLMFromOption() +{ + d->models.clear(); + QMap map = OptionManager::getInstance()->getValue(kCATEGORY_CUSTOMMODELS, kCATEGORY_OPTIONKEY).toMap(); + auto LLMs = map.value(kCATEGORY_CUSTOMMODELS); + for (auto llmInfo : LLMs.toList()) { + appendModel(LLMInfo::fromVariantMap(llmInfo.toMap())); + } +} diff --git a/src/plugins/aimanager/aimanager.h b/src/plugins/aimanager/aimanager.h new file mode 100644 index 000000000..8a6a9744b --- /dev/null +++ b/src/plugins/aimanager/aimanager.h @@ -0,0 +1,34 @@ +// SPDX-FileCopyrightText: 2024 UnionTech Software Technology Co., Ltd. +// +// SPDX-License-Identifier: GPL-3.0-or-later + +#include "base/ai/abstractllm.h" +#include "services/ai/aiservice.h" + +#include + +// option manager -> custom models +static const char *kCATEGORY_CUSTOMMODELS = "CustomModels"; +static const char *kCATEGORY_OPTIONKEY = "OptionKey"; + +class AiManagerPrivate; +class AiManager : public QObject +{ + Q_OBJECT +public: + static AiManager *instance(); + ~AiManager(); + + QList getAllModel(); + AbstractLLM *getLLM(const LLMInfo &info); + + void appendModel(const LLMInfo &LLMInfo); + void removeModel(const LLMInfo &LLMInfo); + bool checkModelValid(const LLMInfo &info, QString *errStr); + + void readLLMFromOption(); + +private: + AiManager(QObject *parent = nullptr); + AiManagerPrivate *d; +}; diff --git a/src/plugins/aimanager/aimanager.json b/src/plugins/aimanager/aimanager.json new file mode 100644 index 000000000..7ea59d15d --- /dev/null +++ b/src/plugins/aimanager/aimanager.json @@ -0,0 +1,15 @@ +{ + "Name" : "Ai Manager", + "Version" : "4.8.2", + "CompatVersion" : "4.8.0", + "Vendor" : "The Uniontech Software Technology Co., Ltd.", + "Copyright" : "Copyright (C) 2020 ~ 2024 Uniontech Software Technology Co., Ltd.", + "License" : [ + "GPL-3.0-or-later" + ], + "Category" : "Core Plugins", + "Description" : "A Ai Manager Plugin for the unioncode.", + "UrlLink" : "https://uosdn.uniontech.com/#document2?dirid=656d40a9bd766615b0b02e5e", + "Depends" : [ + ] +} diff --git a/src/plugins/aimanager/aimanager.qrc b/src/plugins/aimanager/aimanager.qrc new file mode 100644 index 000000000..81446e4bc --- /dev/null +++ b/src/plugins/aimanager/aimanager.qrc @@ -0,0 +1,4 @@ + + + + diff --git a/src/plugins/aimanager/aiplugin.cpp b/src/plugins/aimanager/aiplugin.cpp new file mode 100644 index 000000000..7e7e6fa94 --- /dev/null +++ b/src/plugins/aimanager/aiplugin.cpp @@ -0,0 +1,41 @@ +// SPDX-FileCopyrightText: 2024 UnionTech Software Technology Co., Ltd. +// +// SPDX-License-Identifier: GPL-3.0-or-later + +#include "aiplugin.h" +#include "aimanager.h" +#include "option/optioncustommodelsgenerator.h" +#include "openai/openaicompatiblellm.h" +#include "openai/openaicompatibleconversation.h" +#include "services/ai/aiservice.h" +#include "services/option/optionservice.h" +#include "services/option/optiondatastruct.h" + +#include +#include + +void AiPlugin::initialize() +{ +} + +bool AiPlugin::start() +{ + using namespace dpfservice; + auto aiService = dpfGetService(AiService); + auto impl = AiManager::instance(); + using namespace std::placeholders; + aiService->getAllModel = std::bind(&AiManager::getAllModel, impl); + aiService->getLLM = std::bind(&AiManager::getLLM, impl, _1); + + auto optionService = dpfGetService(dpfservice::OptionService); + if (optionService) { + optionService->implGenerator(option::GROUP_AI, OptionCustomModelsGenerator::kitName()); + } + + return true; +} + +dpf::Plugin::ShutdownFlag AiPlugin::stop() +{ + return Sync; +} diff --git a/src/plugins/aimanager/aiplugin.h b/src/plugins/aimanager/aiplugin.h new file mode 100644 index 000000000..4ba7b2054 --- /dev/null +++ b/src/plugins/aimanager/aiplugin.h @@ -0,0 +1,25 @@ +// SPDX-FileCopyrightText: 2024 UnionTech Software Technology Co., Ltd. +// +// SPDX-License-Identifier: GPL-3.0-or-later + +#ifndef AIMANAGER_H +#define AIMANAGER_H + +#include +#include + +class AiPlugin: public dpf::Plugin +{ + Q_OBJECT + Q_PLUGIN_METADATA(IID "org.deepin.plugin.unioncode" FILE "aimanager.json") +public: + // Initialization function, executed in an asynchronous thread + virtual void initialize() override; + // Start the function and execute it in the main thread. If + // there is an interface operation, please execute it within this function + virtual bool start() override; + // stop function + virtual dpf::Plugin::ShutdownFlag stop() override; +}; + +#endif // AIMANAGER_H diff --git a/src/plugins/aimanager/eventreceiver.cpp b/src/plugins/aimanager/eventreceiver.cpp new file mode 100644 index 000000000..98efad83c --- /dev/null +++ b/src/plugins/aimanager/eventreceiver.cpp @@ -0,0 +1,26 @@ +// SPDX-FileCopyrightText: 2024 UnionTech Software Technology Co., Ltd. +// +// SPDX-License-Identifier: GPL-3.0-or-later + +#include "eventreceiver.h" +#include "common/common.h" + +AiManagerReceiver::AiManagerReceiver(QObject *parent) + : dpf::EventHandler(parent), dpf::AutoEventHandlerRegister() +{ +} + +dpf::EventHandler::Type AiManagerReceiver::type() +{ + return dpf::EventHandler::Type::Async; +} + +QStringList AiManagerReceiver::topics() +{ + return {}; +} + +void AiManagerReceiver::eventProcess(const dpf::Event &event) +{ + QString data = event.data().toString(); +} diff --git a/src/plugins/aimanager/eventreceiver.h b/src/plugins/aimanager/eventreceiver.h new file mode 100644 index 000000000..ddc2f71a4 --- /dev/null +++ b/src/plugins/aimanager/eventreceiver.h @@ -0,0 +1,28 @@ +// SPDX-FileCopyrightText: 2024 UnionTech Software Technology Co., Ltd. +// +// SPDX-License-Identifier: GPL-3.0-or-later + +#ifndef EVENTRECEIVER_H +#define EVENTRECEIVER_H + +#include +#include + +/** + * @brief The EventReceiverDemo class + * It will auto registered, and receive subscribed topics events. + */ +class AiManagerReceiver : public dpf::EventHandler, dpf::AutoEventHandlerRegister +{ + friend class dpf::AutoEventHandlerRegister; + +public: + explicit AiManagerReceiver(QObject *parent = nullptr); + static Type type(); + static QStringList topics(); + +private: + virtual void eventProcess(const dpf::Event &event) override; +}; + +#endif // EVENTRECEIVER_H diff --git a/src/plugins/aimanager/openai/openaicompatibleconversation.cpp b/src/plugins/aimanager/openai/openaicompatibleconversation.cpp new file mode 100644 index 000000000..147552a1a --- /dev/null +++ b/src/plugins/aimanager/openai/openaicompatibleconversation.cpp @@ -0,0 +1,106 @@ +// SPDX-FileCopyrightText: 2024 UnionTech Software Technology Co., Ltd. +// +// SPDX-License-Identifier: GPL-3.0-or-later + +#include "openaicompatibleconversation.h" + +#include +#include + +OpenAiCompatibleConversation::OpenAiCompatibleConversation() +{ +} + +QJsonObject OpenAiCompatibleConversation::parseContentString(const QString &content) +{ + QString deltacontent; + + QRegularExpression regex(R"(data:\s*\{(.*)\})"); + QRegularExpressionMatchIterator iter = regex.globalMatch(content); + + QString finishReason = ""; + while (iter.hasNext()) { + QRegularExpressionMatch match = iter.next(); + QString matchString = match.captured(0); + + int startIndex = matchString.indexOf('{'); + int endIndex = matchString.lastIndexOf('}'); + + if (startIndex >= 0 && endIndex > startIndex) { + QString content = matchString.mid(startIndex, endIndex - startIndex + 1); + + QJsonObject j = QJsonDocument::fromJson(content.toUtf8()).object(); + if (j.contains("choices")) { + const QJsonArray &choices = j["choices"].toArray(); + for (auto choice = choices.begin(); choice != choices.end(); choice++) { + const QJsonObject &cj = choice->toObject(); + if (cj.contains("finish_reason")) + finishReason = cj["finish_reason"].toString(); + if (cj.contains("delta")) { + const QJsonObject &delta = cj["delta"].toObject(); + if (delta.contains("content")) { + const QString &deltaData = delta["content"].toString(); + deltacontent += deltaData; + } + } else if (cj.contains("text")) { + deltacontent += cj["text"].toString(); + } + } + } + } + } + + QJsonObject response; + if (!deltacontent.isEmpty()) { + response["content"] = deltacontent; + } + + if (!finishReason.isEmpty()) + response["finish_reason"] = finishReason; + + return response; +} + +void OpenAiCompatibleConversation::update(const QByteArray &response) +{ + if (response.isEmpty()) + return; + + if (response.startsWith("data:")) { + const QJsonObject &delateData = parseContentString(response); + if (delateData.contains("content")) { + conversation.push_back(QJsonObject({ + { "role", "assistant" }, + { "content", delateData.value("content") } + })); + } + } else { + const QJsonObject &j = QJsonDocument::fromJson(response).object(); + if (j.contains("choices")) { + const QJsonArray &choices = j["choices"].toArray(); + for (auto choice = choices.begin(); choice != choices.end(); choice++) { + const QJsonObject &cj = choice->toObject(); + if (cj.contains("message")) { + if (!cj["message"]["role"].isNull() && !cj["message"]["content"].isNull()) { + conversation.push_back(QJsonObject({ + { "role", cj["message"]["role"] }, + { "content", cj["message"]["content"] } + })); + } + } + } + } else if (j.contains("message")) { + if (!j["message"]["role"].isNull() && !j["message"]["content"].isNull()) { + conversation.push_back(QJsonObject({ + { "role", j["message"]["role"] }, + { "content", j["message"]["content"] } + })); + } + } else if (j.contains("role") && j.contains("content")) { + conversation.push_back(QJsonObject({ + { "role", j["role"] }, + { "content", j["content"] } + })); + } + } +} diff --git a/src/plugins/aimanager/openai/openaicompatibleconversation.h b/src/plugins/aimanager/openai/openaicompatibleconversation.h new file mode 100644 index 000000000..3d48e5bb2 --- /dev/null +++ b/src/plugins/aimanager/openai/openaicompatibleconversation.h @@ -0,0 +1,22 @@ +// SPDX-FileCopyrightText: 2024 UnionTech Software Technology Co., Ltd. +// +// SPDX-License-Identifier: GPL-3.0-or-later + +#ifndef COMPATIBLECONVERSATION_H +#define COMPATIBLECONVERSATION_H + +#include + +class OpenAiCompatibleConversation : public Conversation +{ +public: + explicit OpenAiCompatibleConversation(); + +public: + void update(const QByteArray &response); + +public: + static QJsonObject parseContentString(const QString &content); +}; + +#endif // COMPATIBLECONVERSATION_H \ No newline at end of file diff --git a/src/plugins/aimanager/openai/openaicompatiblellm.cpp b/src/plugins/aimanager/openai/openaicompatiblellm.cpp new file mode 100644 index 000000000..f3eb51d84 --- /dev/null +++ b/src/plugins/aimanager/openai/openaicompatiblellm.cpp @@ -0,0 +1,339 @@ +// SPDX-FileCopyrightText: 2024 UnionTech Software Technology Co., Ltd. +// +// SPDX-License-Identifier: GPL-3.0-or-later + +#include "openaicompatiblellm.h" +#include "openaicompatibleconversation.h" + +#include +#include +#include + +QJsonObject parseNonStreamContent(const QByteArray &data) +{ + QJsonDocument jsonDoc = QJsonDocument::fromJson(data); + if (jsonDoc.isNull()) { + qDebug() << "Failed to parse JSON response"; + return QJsonObject(); + } + + QJsonObject root = jsonDoc.object(); + QJsonObject parseResult; + if (root.contains("choices") && root["choices"].isArray()) { + QJsonArray choices = root["choices"].toArray(); + for (const QJsonValue &choice : choices) { + if (!choice.isObject()) + continue; + if (choice.toObject().contains("message")) { + QJsonObject messageObj = choice.toObject()["message"].toObject(); + parseResult["content"] = messageObj["content"].toString(); + } else if (choice.toObject().contains("text")) { + QString text = choice.toObject()["text"].toString(); + parseResult["content"] = text; + } + if (choice.toObject().contains("finish_reason")) + parseResult["finish_reason"] = root["finish_reason"].toString(); + } + } else if (root.contains("response")) { + QString response = root["response"].toString(); + parseResult["content"] = response; + } + return parseResult; +} + +class OpenAiCompatibleLLMPrivate +{ +public: + OpenAiCompatibleLLMPrivate(OpenAiCompatibleLLM *qq); + ~OpenAiCompatibleLLMPrivate(); + + QNetworkReply *postMessage(const QString &url, const QString &apiKey, const QByteArray &body); + QNetworkReply *getMessage(const QString &url, const QString &apiKey); + + QString modelName { "" }; + QString modelPath { "" }; + QString apiKey { "" }; + double temprature { 1.0 }; + int maxTokens = 0; // default not set + bool stream { true }; + + QByteArray httpResult {}; + bool waitingResponse { false }; + + OpenAiCompatibleConversation *currentConversation = nullptr; + QNetworkAccessManager *manager = nullptr; + OpenAiCompatibleLLM *q = nullptr; +}; + +OpenAiCompatibleLLMPrivate::OpenAiCompatibleLLMPrivate(OpenAiCompatibleLLM *qq) + : q(qq) +{ + manager = new QNetworkAccessManager(qq); + currentConversation = new OpenAiCompatibleConversation(); +} + +OpenAiCompatibleLLMPrivate::~OpenAiCompatibleLLMPrivate() +{ + if (currentConversation) + delete currentConversation; +} + +QNetworkReply *OpenAiCompatibleLLMPrivate::postMessage(const QString &url, const QString &apiKey, const QByteArray &body) +{ + QNetworkRequest request; + request.setUrl(QUrl(url)); + request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json"); + request.setRawHeader("Authorization", "Bearer " + apiKey.toUtf8()); + + if (QThread::currentThread() != qApp->thread()) { + QNetworkAccessManager* threadManager(new QNetworkAccessManager); + OpenAiCompatibleLLM::connect(QThread::currentThread(), &QThread::finished, threadManager, &QNetworkAccessManager::deleteLater); + return threadManager->post(request, body); + } + return manager->post(request, body); +} + +QNetworkReply *OpenAiCompatibleLLMPrivate::getMessage(const QString &url, const QString &apiKey) +{ + QNetworkRequest request; + request.setUrl(QUrl(url)); + request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json"); + request.setRawHeader("Authorization", "Bearer " + apiKey.toUtf8()); + + if (QThread::currentThread() != qApp->thread()) { + QNetworkAccessManager* threadManager(new QNetworkAccessManager); + OpenAiCompatibleLLM::connect(QThread::currentThread(), &QThread::finished, threadManager, &QNetworkAccessManager::deleteLater); + return threadManager->get(request); + } + return manager->get(request); +} + +OpenAiCompatibleLLM::OpenAiCompatibleLLM(QObject *parent) + : AbstractLLM(parent), d(new OpenAiCompatibleLLMPrivate(this)) +{ +} + +OpenAiCompatibleLLM::~OpenAiCompatibleLLM() +{ + if (d) + delete d; +} + +QString OpenAiCompatibleLLM::modelPath() const +{ + return d->modelPath; +} + +Conversation *OpenAiCompatibleLLM::getCurrentConversation() +{ + return d->currentConversation; +} + +void OpenAiCompatibleLLM::setModelName(const QString &name) +{ + d->modelName = name; +} + +void OpenAiCompatibleLLM::setModelPath(const QString &path) +{ + d->modelPath = path; +} + +void OpenAiCompatibleLLM::setApiKey(const QString &key) +{ + d->apiKey = key; +} + +bool OpenAiCompatibleLLM::checkValid(QString *errStr) +{ + // Check if the model is valid + if (d->modelPath.isEmpty()) { + *errStr = "Model path is empty"; + qWarning() << *errStr; + return false; + } + + OpenAiCompatibleConversation c; + c.setSystemData("You are a test assistant"); + c.addUserData("Testing. Just say hi and nothing else"); + + auto obj = create(c); + request(obj); + QEventLoop loop; + bool valid = false; + QString errstr; + + connect(this, &AbstractLLM::dataReceived, &loop, [&, this](const QString & data, ResponseState state){ + if (state == ResponseState::Receiving) + return; + + if (state == ResponseState::Success) { + valid = true; + } else { + *errStr = data; + } + loop.quit(); + }); + + loop.exec(); + return valid; +} + +QJsonObject OpenAiCompatibleLLM::create(const Conversation &conversation) +{ + QJsonObject dataObject; + dataObject.insert("model", d->modelName); + dataObject.insert("messages", conversation.getConversions()); + dataObject.insert("temperature", qBound(0.01, d->temprature, 0.99)); + dataObject.insert("stream", d->stream); + if (d->maxTokens != 0) + dataObject.insert("max_tokens", d->maxTokens); + + return dataObject; +} + +void OpenAiCompatibleLLM::request(const QJsonObject &data) +{ + QByteArray body = QJsonDocument(data).toJson(); + d->httpResult.clear(); + d->waitingResponse = true; + d->currentConversation->update(body); + + QNetworkReply *reply = d->postMessage(modelPath() + "/v1/chat/completions", d->apiKey, body); + connect(this, &OpenAiCompatibleLLM::requstCancel, reply, &QNetworkReply::abort); + connect(reply, &QNetworkReply::finished, this, [=](){ + d->waitingResponse = false; + if (!d->httpResult.isEmpty()) + d->currentConversation->update(d->httpResult); + if (reply->error()) { + qWarning() << "NetWork Error: " << reply->errorString(); + emit dataReceived(reply->errorString(), AbstractLLM::ResponseState::Failed); + return; + } + emit dataReceived("", AbstractLLM::ResponseState::Success); + }); + + processResponse(reply); +} + +void OpenAiCompatibleLLM::request(const QString &prompt) +{ + if (d->waitingResponse) + return; + + d->waitingResponse = true; + + QJsonObject dataObject; + dataObject.insert("model", d->modelName); + dataObject.insert("prompt", prompt); + dataObject.insert("temperature", qBound(0.01, d->temprature, 0.99)); + dataObject.insert("stream", d->stream); + if (d->maxTokens != 0) + dataObject.insert("max_tokens", d->maxTokens); + + QNetworkReply *reply = d->postMessage(modelPath() + "/v1/completions", d->apiKey, QJsonDocument(dataObject).toJson()); + connect(this, &OpenAiCompatibleLLM::requstCancel, reply, &QNetworkReply::abort); + connect(reply, &QNetworkReply::finished, this, [=](){ + d->waitingResponse = false; + if (reply->error()) { + qWarning() << "NetWork Error: " << reply->errorString(); + emit dataReceived(reply->errorString(), AbstractLLM::ResponseState::Failed); + return; + } + emit dataReceived("", AbstractLLM::ResponseState::Success); + }); + + processResponse(reply); +} + +void OpenAiCompatibleLLM::generate(const QString &prompt, const QString &suffix) +{ + if (d->waitingResponse) + return; + + d->waitingResponse = true; + + QJsonObject dataObject; + dataObject.insert("model", d->modelName); + dataObject.insert("suffix", suffix); + dataObject.insert("prompt", prompt); + dataObject.insert("temperature", 0.01); + dataObject.insert("stream", d->stream); + if (d->maxTokens != 0) + dataObject.insert("max_tokens", d->maxTokens); + + QNetworkReply *reply = d->postMessage(modelPath() + "/api/generate", d->apiKey, QJsonDocument(dataObject).toJson()); + connect(this, &OpenAiCompatibleLLM::requstCancel, reply, &QNetworkReply::abort); + connect(reply, &QNetworkReply::finished, this, [=](){ + d->waitingResponse = false; + if (reply->error()) { + qWarning() << "NetWork Error: " << reply->errorString(); + emit dataReceived(reply->errorString(), AbstractLLM::ResponseState::Failed); + return; + } + emit dataReceived("", AbstractLLM::ResponseState::Success); + }); + + processResponse(reply); +} + +void OpenAiCompatibleLLM::setTemperature(double temperature) +{ + d->temprature = temperature; +} + +void OpenAiCompatibleLLM::setStream(bool isStream) +{ + d->stream = isStream; +} + +void OpenAiCompatibleLLM::processResponse(QNetworkReply *reply) +{ + connect(reply, &QNetworkReply::readyRead, [=]() { + if (reply->error()) { + qCritical() << "Error:" << reply->errorString(); + emit dataReceived(reply->errorString(), AbstractLLM::ResponseState::Failed); + } else { + auto data = reply->readAll(); + + // process {"code":,"msg":,"success":false} + QJsonDocument jsonDoc = QJsonDocument::fromJson(data); + if (!jsonDoc.isNull()) { + QJsonObject jsonObj = jsonDoc.object(); + if (jsonObj.contains("success") && !jsonObj.value("success").toBool()) { + emit dataReceived(jsonObj.value("msg").toString(), AbstractLLM::ResponseState::Failed); + return; + } + } + + d->httpResult.append(data); + QString content; + QJsonObject retJson; + if (d->stream) { + retJson = OpenAiCompatibleConversation::parseContentString(QString(data)); + if (retJson.contains("content")) + content = retJson.value("content").toString(); + } else { + retJson = parseNonStreamContent(data); + } + + if (retJson["finish_reason"].toString() == "length") + emit dataReceived(content, AbstractLLM::ResponseState::CutByLength); + else + emit dataReceived(retJson["content"].toString(), AbstractLLM::ResponseState::Receiving); + } + }); +} + +void OpenAiCompatibleLLM::cancel() +{ + d->waitingResponse = false; + d->httpResult.clear(); + emit requstCancel(); + emit dataReceived("", AbstractLLM::ResponseState::Canceled); +} + +void OpenAiCompatibleLLM::setMaxTokens(int maxTokens) +{ + d->maxTokens = maxTokens; +} diff --git a/src/plugins/aimanager/openai/openaicompatiblellm.h b/src/plugins/aimanager/openai/openaicompatiblellm.h new file mode 100644 index 000000000..fd35b0685 --- /dev/null +++ b/src/plugins/aimanager/openai/openaicompatiblellm.h @@ -0,0 +1,42 @@ +// SPDX-FileCopyrightText: 2024 UnionTech Software Technology Co., Ltd. +// +// SPDX-License-Identifier: GPL-3.0-or-later + +#include + +#ifndef OPENAICOMPATIBLELLM_H +#define OPENAICOMPATIBLELLM_H + +class OpenAiCompatibleLLMPrivate; +class OpenAiCompatibleLLM : public AbstractLLM +{ + Q_OBJECT +public: + explicit OpenAiCompatibleLLM(QObject *parent = nullptr); + ~OpenAiCompatibleLLM() override; + + Conversation* getCurrentConversation() override; + void setModelName(const QString &modelName); + void setModelPath(const QString &path); + void setApiKey(const QString &apiKey); + + QString modelPath() const override; + bool checkValid(QString *errStr) override; + QJsonObject create(const Conversation &conversation) override; + void request(const QJsonObject &data) override; // v1/chat/compltions + void request(const QString &prompt) override; // v1/completions + void generate(const QString &prompt, const QString &suffix) override; // api/generate + void setTemperature(double temperature) override; + void setStream(bool isStream) override; + void processResponse(QNetworkReply *reply) override; + void cancel() override; + void setMaxTokens(int maxTokens) override; + +signals: + void requstCancel(); + +private: + OpenAiCompatibleLLMPrivate *d; +}; + +#endif // OPENAICOMPATIBLELLM_H diff --git a/src/plugins/aimanager/option/addmodeldialog.cpp b/src/plugins/aimanager/option/addmodeldialog.cpp new file mode 100644 index 000000000..62e4bd280 --- /dev/null +++ b/src/plugins/aimanager/option/addmodeldialog.cpp @@ -0,0 +1,232 @@ +// SPDX-FileCopyrightText: 2024 UnionTech Software Technology Co., Ltd. +// +// SPDX-License-Identifier: GPL-3.0-or-later + +#include "addmodeldialog.h" +#include "services/ai/aiservice.h" +#include "aimanager.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +DWIDGET_USE_NAMESPACE + +class AddModelDialogPrivate +{ +public: + friend class AddModelDialog; + AddModelDialogPrivate(AddModelDialog *qq) : q(qq) {} + void initUi(); + void initConnection(); + +private slots: + void slotAddModel(); + +private: + DStackedWidget *stackWidget { nullptr }; + DWidget *mainWidget { nullptr }; + DWidget *checkingWidget { nullptr }; + + DLineEdit *leLLMName { nullptr }; + DComboBox *cbLLMType { nullptr }; + DLineEdit *leApiUrl { nullptr }; + DLineEdit *leApiKey { nullptr }; + DPushButton *passwordBtn { nullptr }; + + DSuggestButton *okButton { nullptr }; + DPushButton *cancelButton { nullptr }; + DSpinner *spinner { nullptr }; + + void showWaitingState(bool waiting); + void showErrorInfoDialog(const QString &error); + AddModelDialog *q; + LLMInfo LLMAppended; +}; + +void AddModelDialogPrivate::initUi() +{ + q->setFixedWidth(543); + QLabel *lbModelName = new QLabel(AddModelDialog::tr("Model Name")); + lbModelName->setAlignment(Qt::AlignRight | Qt::AlignVCenter); + leLLMName = new DLineEdit(q); + leLLMName->setPlaceholderText(AddModelDialog::tr("Required, please enter.")); + + QLabel *lbLLMType = new QLabel(AddModelDialog::tr("Model Type")); + lbLLMType->setAlignment(Qt::AlignRight | Qt::AlignVCenter); + cbLLMType = new DComboBox(q); + cbLLMType->addItem(AddModelDialog::tr("OpenAi(Compatible)"), LLMType::OPENAI); + + QLabel *lbApiUrl = new QLabel(AddModelDialog::tr("Api Path")); + lbApiUrl->setAlignment(Qt::AlignRight | Qt::AlignVCenter); + leApiUrl = new DLineEdit(q); + leApiUrl->setPlaceholderText(AddModelDialog::tr("Required, please enter.")); + + QLabel *lbApiKey = new QLabel(AddModelDialog::tr("Api Key")); + lbApiKey->setAlignment(Qt::AlignRight | Qt::AlignVCenter); + leApiKey = new DLineEdit(q); + leApiKey->setEchoMode(QLineEdit::Password); + leApiKey->setPlaceholderText(AddModelDialog::tr("Optional, please enter.")); + passwordBtn = new DPushButton(q); + passwordBtn->setIcon(DStyle::standardIcon(q->style(), DStyle::SP_HidePassword)); + passwordBtn->setCheckable(true); + passwordBtn->setChecked(true); + + q->setWindowTitle(AddModelDialog::tr("Add Model")); + + mainWidget = new QWidget(q); + QVBoxLayout *mainLayout = new QVBoxLayout(mainWidget); + QGridLayout *gridLayout = new QGridLayout; + + gridLayout->setContentsMargins(0, 0, 0, 0); + gridLayout->setSpacing(10); + + gridLayout->addWidget(lbModelName, 0, 0); + gridLayout->addWidget(leLLMName, 0, 1, 1, 2); + + gridLayout->addWidget(lbLLMType, 1, 0); + gridLayout->addWidget(cbLLMType, 1, 1, 1, 2); + + gridLayout->addWidget(lbApiUrl, 2, 0); + gridLayout->addWidget(leApiUrl, 2, 1, 1, 2); + + gridLayout->addWidget(lbApiKey, 3, 0); + gridLayout->addWidget(leApiKey, 3, 1); + gridLayout->addWidget(passwordBtn, 3, 2); + + okButton = new DSuggestButton(q); + okButton->setText(AddModelDialog::tr("Confirm")); + cancelButton = new DPushButton(q); + cancelButton->setText(AddModelDialog::tr("Cancel")); + QHBoxLayout *buttonLayout = new QHBoxLayout; + buttonLayout->addWidget(cancelButton); + buttonLayout->addWidget(new DVerticalLine(q)); + buttonLayout->addWidget(okButton); + + mainLayout->addLayout(gridLayout); + mainLayout->addSpacing(10); + auto label = new DLabel(AddModelDialog::tr("To test if the model is available, the system will send a small amount of information, which will consume a small amount of tokens."), q); + label->setWordWrap(true); + label->setForegroundRole(DPalette::PlaceholderText); + mainLayout->addWidget(label); + mainLayout->addSpacing(10); + mainLayout->addLayout(buttonLayout); + + // check llm valid + checkingWidget = new QWidget; + auto checkingLayout = new QVBoxLayout(checkingWidget); + checkingLayout->setSpacing(20); + spinner = new DSpinner(checkingWidget); + spinner->setFixedSize(32, 32); + QLabel *lbCheck = new QLabel(AddModelDialog::tr("Checking... please wait."), checkingWidget); + lbCheck->setAlignment(Qt::AlignCenter); + checkingLayout->addStretch(1); + checkingLayout->addWidget(spinner, 0, Qt::AlignCenter); + checkingLayout->addWidget(lbCheck, 0, Qt::AlignCenter); + checkingLayout->addStretch(1); + + stackWidget = new DStackedWidget(q); + stackWidget->addWidget(mainWidget); + stackWidget->addWidget(checkingWidget); + q->addContent(stackWidget); +} + +void AddModelDialogPrivate::initConnection() +{ + AddModelDialog::connect(okButton, &DSuggestButton::clicked, q, [=](){ + slotAddModel(); + }); + AddModelDialog::connect(cancelButton, &DSuggestButton::clicked, q, &AddModelDialog::reject); + AddModelDialog::connect(passwordBtn, &DPushButton::clicked, q, [=](){ + if (passwordBtn->isChecked()) { + passwordBtn->setIcon(DStyle::standardIcon(q->style(), DStyle::SP_ShowPassword)); + leApiKey->setEchoMode(QLineEdit::Password); + } else { + passwordBtn->setIcon(DStyle::standardIcon(q->style(), DStyle::SP_HidePassword)); + leApiKey->setEchoMode(QLineEdit::Normal); + } + }); +} + +void AddModelDialogPrivate::slotAddModel() +{ + LLMInfo newLLMInfo; + newLLMInfo.modelName = leLLMName->text(); + if (newLLMInfo.modelName.isEmpty()) { + leLLMName->showAlertMessage(AddModelDialog::tr("This field cannot be empty.")); + return; + } + newLLMInfo.type = cbLLMType->currentData().value(); + newLLMInfo.modelPath = leApiUrl->text(); + if (newLLMInfo.modelPath.isEmpty()) { + leApiUrl->showAlertMessage(AddModelDialog::tr("This field cannot be empty.")); + return; + } + newLLMInfo.apikey = leApiKey->text(); + + showWaitingState(true); + QString errStr; + auto valid = AiManager::instance()->checkModelValid(newLLMInfo, &errStr); + if (valid) { + LLMAppended = newLLMInfo; + q->accept(); + } else if (!errStr.isEmpty() && q->isVisible()) { // if dialog closed, do not show error info + showErrorInfoDialog(errStr); + } + showWaitingState(false); +} + +void AddModelDialogPrivate::showWaitingState(bool waiting) +{ + if (waiting) { + spinner->start(); + stackWidget->setCurrentWidget(checkingWidget); + } else { + spinner->stop(); + stackWidget->setCurrentWidget(mainWidget); + } +} + +void AddModelDialogPrivate::showErrorInfoDialog(const QString &error) +{ + DDialog dialog; + dialog.setIcon(QIcon::fromTheme("dialog-warning")); + dialog.setWindowTitle(AddModelDialog::tr("Error Information")); + auto warningText = new DLabel(error, q); + warningText->setTextInteractionFlags(Qt::TextSelectableByMouse); + warningText->setWordWrap(true); + + DFrame *mainWidget = new DFrame(&dialog); + mainWidget->setFixedSize(360, 140); + QVBoxLayout *mainLayout = new QVBoxLayout(mainWidget); + mainLayout->addWidget(warningText, 0, Qt::AlignTop | Qt::AlignLeft); + dialog.addContent(mainWidget); + + dialog.addButton(AddModelDialog::tr("Confirm")); + dialog.exec(); +} + +AddModelDialog::AddModelDialog(QWidget *parent) +: DDialog(parent), d(new AddModelDialogPrivate(this)) +{ + d->initUi(); + d->initConnection(); +} + +AddModelDialog::~AddModelDialog() +{ + delete d; +} + +LLMInfo AddModelDialog::getNewLLmInfo() +{ + if (!d->LLMAppended.modelName.isEmpty()) + return d->LLMAppended; + else + return LLMInfo(); +} diff --git a/src/plugins/aimanager/option/addmodeldialog.h b/src/plugins/aimanager/option/addmodeldialog.h new file mode 100644 index 000000000..7f667b178 --- /dev/null +++ b/src/plugins/aimanager/option/addmodeldialog.h @@ -0,0 +1,23 @@ +// SPDX-FileCopyrightText: 2024 UnionTech Software Technology Co., Ltd. +// +// SPDX-License-Identifier: GPL-3.0-or-later + +#include +#include + +#include +#include + +class LLMInfo; +class AddModelDialogPrivate; +class AddModelDialog : public Dtk::Widget::DDialog +{ + Q_OBJECT +public: + AddModelDialog(QWidget *parent = nullptr); + ~AddModelDialog(); + LLMInfo getNewLLmInfo(); + +private: + AddModelDialogPrivate *d; +}; diff --git a/src/plugins/aimanager/option/custommodelsoptionwidget.cpp b/src/plugins/aimanager/option/custommodelsoptionwidget.cpp new file mode 100644 index 000000000..7b1df10c0 --- /dev/null +++ b/src/plugins/aimanager/option/custommodelsoptionwidget.cpp @@ -0,0 +1,69 @@ +// SPDX-FileCopyrightText: 2024 UnionTech Software Technology Co., Ltd. +// +// SPDX-License-Identifier: GPL-3.0-or-later + +#include "custommodelsoptionwidget.h" +#include "detailwidget.h" +#include "aimanager.h" + +#include "services/option/optiondatastruct.h" +#include "services/option/optionmanager.h" + +#include + +#include + +DWIDGET_USE_NAMESPACE + +class CustomModelsOptionWidgetPrivate +{ + DTabWidget *tabWidget = nullptr; + friend class CustomModelsOptionWidget; +}; + +CustomModelsOptionWidget::CustomModelsOptionWidget(QWidget *parent) + : PageWidget(parent), d(new CustomModelsOptionWidgetPrivate()) +{ + QHBoxLayout *layout = new QHBoxLayout(this); + d->tabWidget = new DTabWidget(this); + d->tabWidget->tabBar()->setAutoHide(true); + d->tabWidget->setDocumentMode(true); + layout->addWidget(d->tabWidget); + + d->tabWidget->addTab(new DetailWidget(d->tabWidget), kCATEGORY_OPTIONKEY); + QObject::connect(d->tabWidget, &DTabWidget::currentChanged, [this]() { + readConfig(); + }); +} + +CustomModelsOptionWidget::~CustomModelsOptionWidget() +{ + if (d) + delete d; +} + +void CustomModelsOptionWidget::saveConfig() +{ + for (int index = 0; index < d->tabWidget->count(); index++) { + PageWidget *pageWidget = qobject_cast(d->tabWidget->widget(index)); + if (pageWidget) { + QString itemNode = d->tabWidget->tabText(d->tabWidget->currentIndex()); + QMap map; + pageWidget->getUserConfig(map); + OptionManager::getInstance()->setValue(kCATEGORY_CUSTOMMODELS, itemNode, map); + } + } + AiManager::instance()->readLLMFromOption(); +} + +void CustomModelsOptionWidget::readConfig() +{ + for (int index = 0; index < d->tabWidget->count(); index++) { + PageWidget *pageWidget = qobject_cast(d->tabWidget->widget(index)); + if (pageWidget) { + QString itemNode = d->tabWidget->tabText(d->tabWidget->currentIndex()); + QMap map = OptionManager::getInstance()->getValue(kCATEGORY_CUSTOMMODELS, itemNode).toMap(); + pageWidget->setUserConfig(map); + } + } +} diff --git a/src/plugins/aimanager/option/custommodelsoptionwidget.h b/src/plugins/aimanager/option/custommodelsoptionwidget.h new file mode 100644 index 000000000..4f5237187 --- /dev/null +++ b/src/plugins/aimanager/option/custommodelsoptionwidget.h @@ -0,0 +1,25 @@ +// SPDX-FileCopyrightText: 2024 UnionTech Software Technology Co., Ltd. +// +// SPDX-License-Identifier: GPL-3.0-or-later + +#ifndef CUSTOMMODELSOPTIONWIDGET +#define CUSTOMMODELSOPTIONWIDGET + +#include "common/common.h" + +class CustomModelsOptionWidgetPrivate; +class CustomModelsOptionWidget : public PageWidget +{ + Q_OBJECT +public: + explicit CustomModelsOptionWidget(QWidget *parent = nullptr); + ~CustomModelsOptionWidget() override; + + void saveConfig() override; + void readConfig() override; + +private: + CustomModelsOptionWidgetPrivate *const d; +}; + +#endif // CUSTOMMODELSOPTIONWIDGET diff --git a/src/plugins/aimanager/option/detailwidget.cpp b/src/plugins/aimanager/option/detailwidget.cpp new file mode 100644 index 000000000..a6e35d809 --- /dev/null +++ b/src/plugins/aimanager/option/detailwidget.cpp @@ -0,0 +1,207 @@ +// SPDX-FileCopyrightText: 2024 UnionTech Software Technology Co., Ltd. +// +// SPDX-License-Identifier: GPL-3.0-or-later + +#include "detailwidget.h" +#include "addmodeldialog.h" +#include "aimanager.h" + +#include +#include +#include +#include +#ifdef DTKWIDGET_CLASS_DPaletteHelper +#include +#endif + +#include +#include +#include +#include +#include + +DWIDGET_USE_NAMESPACE + +LLMModels::LLMModels(QObject *parent) + : QAbstractTableModel(parent) +{ +} + +LLMModels::~LLMModels() +{ +} + +int LLMModels::rowCount(const QModelIndex &parent) const +{ + Q_UNUSED(parent) + + return LLMs.size(); +} + +int LLMModels::columnCount(const QModelIndex &parent) const +{ + Q_UNUSED(parent) + return 1; +} + +QVariant LLMModels::data(const QModelIndex &index, int role) const { + if (!index.isValid()) + return QVariant(); + + if (index.row() >= LLMs.size()) + return QVariant(); + + const LLMInfo &info = LLMs.at(index.row()); + + switch (role) { + case Qt::UserRole + 1: + return info.modelName; + case Qt::UserRole + 2: + return LLMTypeToString(info.type); + case Qt::ToolTipRole: + return QString("Path:%1").arg(info.modelPath); + default: + return QVariant(); + } +} + +QVariant LLMModels::headerData(int section, Qt::Orientation orientation, int role) const +{ + return QVariant(); +} + +void LLMModels::appendLLM(const LLMInfo &info) +{ + beginResetModel(); + if (!LLMs.contains(info)) + LLMs.append(info); + endResetModel(); +} + +void LLMModels::removeLLM(const LLMInfo &info) +{ + beginResetModel(); + LLMs.removeOne(info); + endResetModel(); +} + +QList LLMModels::allLLMs() +{ + return LLMs; +} + +class DetailWidgetPrivate +{ + friend class DetailWidget; + DListView *modelsView = nullptr; + LLMModels *LLMModel = nullptr; + AddModelDialog *addModelDialog = nullptr; +}; + +DetailWidget::DetailWidget(QWidget *parent) + : PageWidget(parent), d(new DetailWidgetPrivate()) +{ + setupUi(); +} + +DetailWidget::~DetailWidget() +{ + if (d) { + delete d; + } +} + +void DetailWidget::setupUi() +{ + setFixedHeight(300); + QVBoxLayout *vLayout = new QVBoxLayout(this); + vLayout->setContentsMargins(0, 0, 0, 0); + + auto listframe = new DFrame(this); + auto listlayout = new QVBoxLayout(listframe); + listlayout->setContentsMargins(5, 5, 5, 5); + listframe->setLayout(listlayout); + d->modelsView = new DListView(listframe); + d->modelsView->setFrameShape(DFrame::NoFrame); + d->modelsView->setSelectionMode(QAbstractItemView::SingleSelection); + d->modelsView->setSelectionBehavior(QAbstractItemView::SelectRows); + d->modelsView->setEditTriggers(QAbstractItemView::NoEditTriggers); + d->modelsView->setAlternatingRowColors(true); + d->modelsView->setSizePolicy(QSizePolicy::Preferred, QSizePolicy::Preferred); + d->modelsView->setItemDelegate(new TwoLineDelegate(d->modelsView)); + listlayout->addWidget(d->modelsView); + + d->LLMModel = new LLMModels(this); + d->modelsView->setModel(d->LLMModel); + + vLayout->addWidget(listframe); + + QHBoxLayout *buttonLayout = new QHBoxLayout; + buttonLayout->setAlignment(Qt::AlignRight); + DToolButton *buttonAdd = new DToolButton(this); + buttonAdd->setText(tr("Add")); + DToolButton *buttonRemove = new DToolButton(this); + buttonRemove->setText(tr("Remove")); +#ifdef DTKWIDGET_CLASS_DPaletteHelper + DPalette pl(DPaletteHelper::instance()->palette(this)); +#else + const DPalette &pl = DGuiApplicationHelper::instance()->applicationPalette(); +#endif + + QColor warningColor = pl.color(DPalette::ColorGroup::Active, DPalette::ColorType::TextWarning); + QColor highlightColor = pl.color(DPalette::ColorGroup::Active, DPalette::ColorType::LightLively); + auto pa = buttonRemove->palette(); + pa.setColor(QPalette::ButtonText, warningColor); + buttonRemove->setPalette(pa); + pa.setColor(QPalette::ButtonText, highlightColor); + buttonAdd->setPalette(pa); + + buttonLayout->addWidget(buttonRemove); + buttonLayout->addWidget(buttonAdd); + vLayout->addLayout(buttonLayout); + + connect(buttonAdd, &DToolButton::clicked, this, [=](){ + auto dialog = new AddModelDialog(this); + auto code = dialog->exec(); + if (code == QDialog::Accepted) { + d->LLMModel->appendLLM(dialog->getNewLLmInfo()); + } + dialog->deleteLater(); + }); + connect(buttonRemove, &DToolButton::clicked, this, [=](){ + if (!d->modelsView->selectionModel()->hasSelection()) + return; + QModelIndex index = d->modelsView->selectionModel()->selectedIndexes().at(0); + if (!index.isValid()) + return; + auto llmInfo = d->LLMModel->allLLMs().at(index.row()); + d->LLMModel->removeLLM(llmInfo); + }); +} + +bool DetailWidget::getControlValue(QMap &map) +{ + QVariantList LLMs; + for (auto llmInfo : d->LLMModel->allLLMs()) { + LLMs.append(llmInfo.toVariant()); + } + map.insert(kCATEGORY_CUSTOMMODELS, LLMs); + return true; +} + +void DetailWidget::setControlValue(const QMap &map) +{ + for (auto mapData : map.value(kCATEGORY_CUSTOMMODELS).toList()) { + d->LLMModel->appendLLM(LLMInfo::fromVariantMap(mapData.toMap())); + } +} + +void DetailWidget::setUserConfig(const QMap &map) +{ + setControlValue(map); +} + +void DetailWidget::getUserConfig(QMap &map) +{ + getControlValue(map); +} diff --git a/src/plugins/aimanager/option/detailwidget.h b/src/plugins/aimanager/option/detailwidget.h new file mode 100644 index 000000000..ef28e2693 --- /dev/null +++ b/src/plugins/aimanager/option/detailwidget.h @@ -0,0 +1,87 @@ +// SPDX-FileCopyrightText: 2024 UnionTech Software Technology Co., Ltd. +// +// SPDX-License-Identifier: GPL-3.0-or-later + +#ifndef DETAILWIDGET_H +#define DETAILWIDGET_H + +#include "common/widget/pagewidget.h" +#include "base/baseitemdelegate.h" +#include "services/ai/aiservice.h" + +#include + +#include +#include + +class TwoLineDelegate : public DTK_WIDGET_NAMESPACE::DStyledItemDelegate { +public: + TwoLineDelegate(QAbstractItemView *parent = nullptr) : DStyledItemDelegate(parent) {} + + void paint(QPainter *painter, const QStyleOptionViewItem &option, const QModelIndex &index) const override { + QStyleOptionViewItemV4 opt = option; + initStyleOption(&opt, index); + + DStyledItemDelegate::paint(painter, opt, index); + + painter->save(); + + QString leftText = index.data(Qt::UserRole + 1).toString(); + QString rightText = index.data(Qt::UserRole + 2).toString(); + + int textHeight = painter->fontMetrics().height(); + int verticalCenter = option.rect.top() + (option.rect.height() - textHeight) / 2; + + QRect leftTextRect(option.rect.left() + 10, verticalCenter, option.rect.width() / 2 - 15, textHeight); + painter->drawText(leftTextRect, Qt::AlignLeft | Qt::AlignVCenter, leftText); + + QRect rightTextRect(option.rect.left() + option.rect.width() / 2 + 5, verticalCenter, option.rect.width() / 2 - 15, textHeight); + painter->drawText(rightTextRect, Qt::AlignRight | Qt::AlignVCenter, rightText); + + painter->restore(); + } + + QSize sizeHint(const QStyleOptionViewItem &option, const QModelIndex &index) const override { return QSize(option.rect.width(), 36);} +}; + +class LLMModels : public QAbstractTableModel +{ + Q_OBJECT +public: + explicit LLMModels(QObject *parent = nullptr); + ~LLMModels(); + + int rowCount(const QModelIndex &parent = QModelIndex()) const override; + int columnCount(const QModelIndex &parent = QModelIndex()) const override; + QVariant data(const QModelIndex &index, int role) const override; + QVariant headerData(int section, Qt::Orientation orientation, int role) const override; + + void appendLLM(const LLMInfo &info); + void removeLLM(const LLMInfo &info); + QList allLLMs(); + +private: + QList LLMs; +}; + +class DetailWidgetPrivate; +class DetailWidget : public PageWidget +{ + Q_OBJECT +public: + explicit DetailWidget(QWidget *parent = nullptr); + ~DetailWidget() override; + + void setUserConfig(const QMap &map) override; + void getUserConfig(QMap &map) override; + +private: + void setupUi(); + + bool getControlValue(QMap &map); + void setControlValue(const QMap &map); + + DetailWidgetPrivate *const d; +}; + +#endif // DETAILWIDGET_H diff --git a/src/plugins/aimanager/option/optioncustommodelsgenerator.cpp b/src/plugins/aimanager/option/optioncustommodelsgenerator.cpp new file mode 100644 index 000000000..5f55737da --- /dev/null +++ b/src/plugins/aimanager/option/optioncustommodelsgenerator.cpp @@ -0,0 +1,23 @@ +// SPDX-FileCopyrightText: 2024 UnionTech Software Technology Co., Ltd. +// +// SPDX-License-Identifier: GPL-3.0-or-later + +#include "optioncustommodelsgenerator.h" +#include "custommodelsoptionwidget.h" + +#include "common/widget/pagewidget.h" + +#include +#include + +#include + +OptionCustomModelsGenerator::OptionCustomModelsGenerator() +{ + +} + +QWidget *OptionCustomModelsGenerator::optionWidget() +{ + return new CustomModelsOptionWidget(); +} diff --git a/src/plugins/aimanager/option/optioncustommodelsgenerator.h b/src/plugins/aimanager/option/optioncustommodelsgenerator.h new file mode 100644 index 000000000..3f90f88f8 --- /dev/null +++ b/src/plugins/aimanager/option/optioncustommodelsgenerator.h @@ -0,0 +1,18 @@ +// SPDX-FileCopyrightText: 2024 UnionTech Software Technology Co., Ltd. +// +// SPDX-License-Identifier: GPL-3.0-or-later + +#ifndef OPTIONCODEGEEXGENERATOR_H +#define OPTIONCODEGEEXGENERATOR_H + +#include "services/option/optiongenerator.h" + +class OptionCustomModelsGenerator : public dpfservice::OptionGenerator +{ +public: + OptionCustomModelsGenerator(); + inline static QString kitName() { return "Models"; } + virtual QWidget *optionWidget() override; +}; + +#endif // OPTIONCODEGEEXGENERATOR_H diff --git a/src/plugins/codegeex/codegeex.cpp b/src/plugins/codegeex/codegeex.cpp index 7018483b4..c07662b4e 100644 --- a/src/plugins/codegeex/codegeex.cpp +++ b/src/plugins/codegeex/codegeex.cpp @@ -11,6 +11,7 @@ #include "common/common.h" #include "services/window/windowservice.h" #include "services/option/optionservice.h" +#include "services/option/optiondatastruct.h" #include "services/ai/aiservice.h" #include "copilot.h" @@ -44,7 +45,7 @@ bool CodeGeex::start() auto optionService = dpfGetService(dpfservice::OptionService); if (optionService) { - optionService->implGenerator(QObject::tr("AI"), OptionCodeGeeXGenerator::kitName()); + optionService->implGenerator(option::GROUP_AI, OptionCodeGeeXGenerator::kitName()); } Copilot::instance(); diff --git a/src/plugins/codegeex/codegeex/askapi.cpp b/src/plugins/codegeex/codegeex/askapi.cpp index 13b7e70c8..08373c8f8 100644 --- a/src/plugins/codegeex/codegeex/askapi.cpp +++ b/src/plugins/codegeex/codegeex/askapi.cpp @@ -8,6 +8,7 @@ #include "services/window/windowservice.h" #include "services/editor/editorservice.h" #include "common/type/constants.h" +#include #include #include @@ -20,6 +21,7 @@ #include namespace CodeGeeX { +static const QString PrePrompt = "你是一个智能编程助手,可以回答用户任何的问题。问题中可能会带有相关的context,这些context来自工程相关的文件,你要结合这些上下文回答用户问题。 请注意:\n1.用户问题中以@符号开始的标记代表着context内容。/n2.按正确的语言回答,不要被上下文中的字符影响."; static int kCode_Success = 200; @@ -63,7 +65,7 @@ AskApiPrivate::AskApiPrivate(AskApi *qq) : q(qq), manager(new QNetworkAccessManager(qq)) { - connect(q, &AskApi::stopReceive, this, [=](){ terminated = true; }); + connect(q, &AskApi::stopReceive, this, [=]() { terminated = true; }); } QNetworkReply *AskApiPrivate::postMessage(const QString &url, const QString &token, const QByteArray &body) @@ -73,6 +75,12 @@ QNetworkReply *AskApiPrivate::postMessage(const QString &url, const QString &tok QNetworkRequest request(url); request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json"); request.setRawHeader("code-token", token.toUtf8()); + + if (QThread::currentThread() != qApp->thread()) { + QNetworkAccessManager* threadManager(new QNetworkAccessManager); + AskApi::connect(QThread::currentThread(), &QThread::finished, threadManager, &QNetworkAccessManager::deleteLater); + return threadManager->post(request, body); + } return manager->post(request, body); } @@ -82,6 +90,11 @@ QNetworkReply *AskApiPrivate::getMessage(const QString &url, const QString &toke request.setHeader(QNetworkRequest::ContentTypeHeader, "application/x-www-form-urlencoded"); request.setRawHeader("code-token", token.toUtf8()); + if (QThread::currentThread() != qApp->thread()) { + QNetworkAccessManager* threadManager(new QNetworkAccessManager); + AskApi::connect(QThread::currentThread(), &QThread::finished, threadManager, &QNetworkAccessManager::deleteLater); + return threadManager->get(request); + } return manager->get(request); } @@ -189,34 +202,6 @@ QByteArray AskApiPrivate::assembleSSEChatBody(const QString &prompt, const QStri jsonObject.insert("locale", locale); jsonObject.insert("model", model); - using dpfservice::ProjectService; - ProjectService *prjService = dpfGetService(ProjectService); - auto currentProjectPath = prjService->getActiveProjectInfo().workspaceFolder(); - - if (codebaseEnabled && currentProjectPath != "") { - QString queryText = prompt; - QJsonObject result = CodeGeeXManager::instance()->query(currentProjectPath, queryText.remove("@CodeBase"), 50); - QJsonArray chunks = result["Chunks"].toArray(); - if (!chunks.isEmpty()) { - if (result["Completed"].toBool() == false) - emit q->notify(0, CodeGeeXManager::tr("The indexing of project %1 has not been completed, which may cause the results to be inaccurate.").arg(currentProjectPath)); - jsonObject["history"] = QJsonArray(); - QString context; - context += prompt; - context += "\n参考下面这些代码片段,回答上面的问题, @CodeBase表示针对这些代码所属的代码工程进行提问,其不属于真正的问题内容,回答时忽略@CodeBase。不要参考其他的代码和上下文,数据不够充分的情况下提示用户.\n"; - for (auto chunk : chunks) { - context += chunk.toObject()["fileName"].toString(); - context += '\n'; - context += chunk.toObject()["content"].toString(); - context += "\n\n"; - } - jsonObject["prompt"] = context; - } else if (CodeGeeXManager::instance()->condaHasInstalled()) { - emit q->noChunksFounded(); - return {}; - } - } - if (!referenceFiles.isEmpty()) { auto fileDatas = parseFile(referenceFiles); jsonObject.insert("command", "file_augment"); @@ -344,7 +329,7 @@ void AskApi::sendQueryRequest(const QString &codeToken) QString url = "https://codegeex.cn/prod/code/oauth/getUserInfo"; QNetworkReply *reply = d->getMessage(url, codeToken); - connect(reply, &QNetworkReply::finished, [=]() { + connect(reply, &QNetworkReply::finished, this, [=]() { if (reply->error()) { qCritical() << "Error:" << reply->errorString(); return; @@ -381,7 +366,7 @@ void AskApi::slotSendMessage(const QString url, const QString &token, const QByt { QNetworkReply *reply = d->postMessage(url, token, body); connect(this, &AskApi::stopReceive, reply, [reply]() { - reply->close(); + reply->abort(); }); d->processResponse(reply); } @@ -408,11 +393,49 @@ void AskApi::postSSEChat(const QString &url, } #endif - QtConcurrent::run([prompt, machineId, jsonArray, talkId, url, token, this]() { - QByteArray body = d->assembleSSEChatBody(prompt, machineId, jsonArray, talkId); - if (!body.isEmpty()) - emit syncSendMessage(url, token, body); + QByteArray body = d->assembleSSEChatBody(prompt, machineId, jsonArray, talkId); + if (!body.isEmpty()) + emit syncSendMessage(url, token, body); +} + +QString AskApi::syncQuickAsk(const QString &url, + const QString &token, + const QString &prompt, + const QString &talkId) +{ + d->terminated = false; + + QJsonObject jsonObject; + jsonObject.insert("ide", qApp->applicationName()); + jsonObject.insert("ide_version", version()); + jsonObject.insert("prompt", prompt); + jsonObject.insert("history", {}); + jsonObject.insert("locale", d->locale); + jsonObject.insert("model", chatModelLite); + jsonObject.insert("talkId", talkId); + jsonObject.insert("stream", false); + + QByteArray body = d->jsonToByteArray(jsonObject); + QNetworkReply *reply = d->postMessage(url, token, body); + QEventLoop loop; + connect(reply, &QNetworkReply::finished, reply, [&]() { + loop.exit(); + }); + connect(this, &AskApi::stopReceive, reply, [reply, &loop]() { + reply->abort(); + loop.exit(); }); + loop.exec(); + + QJsonParseError error; + QJsonDocument jsonDocument = QJsonDocument::fromJson(reply->readAll(), &error); + QJsonObject obj = jsonDocument.object(); + if (error.error != QJsonParseError::NoError) { + qCritical() << "JSON parse error: " << error.errorString(); + return ""; + } + + return obj["text"].toString(); } void AskApi::postNewSession(const QString &url, diff --git a/src/plugins/codegeex/codegeex/askapi.h b/src/plugins/codegeex/codegeex/askapi.h index ebef99db4..2c04d059b 100644 --- a/src/plugins/codegeex/codegeex/askapi.h +++ b/src/plugins/codegeex/codegeex/askapi.h @@ -68,6 +68,11 @@ class AskApi : public QObject const QMultiMap &history, const QString &talkId); + QString syncQuickAsk(const QString &url, + const QString &token, + const QString &prompt, + const QString &talkId); + void postNewSession(const QString &url, const QString &token, const QString &prompt, diff --git a/src/plugins/codegeex/codegeexmanager.cpp b/src/plugins/codegeex/codegeexmanager.cpp index 507c20376..5c3f5c50f 100644 --- a/src/plugins/codegeex/codegeexmanager.cpp +++ b/src/plugins/codegeex/codegeexmanager.cpp @@ -226,6 +226,54 @@ void CodeGeeXManager::setMessage(const QString &prompt) Q_EMIT setTextToSend(prompt); } +QString CodeGeeXManager::getChunks(const QString &queryText) +{ + using dpfservice::ProjectService; + ProjectService *prjService = dpfGetService(ProjectService); + auto currentProjectPath = prjService->getActiveProjectInfo().workspaceFolder(); + + if (currentProjectPath != "") { + QJsonObject result = CodeGeeXManager::instance()->query(currentProjectPath, queryText, 20); + QJsonArray chunks = result["Chunks"].toArray(); + if (!chunks.isEmpty()) { + if (result["Completed"].toBool() == false) + emit askApi.notify(0, CodeGeeXManager::tr("The indexing of project %1 has not been completed, which may cause the results to be inaccurate.").arg(currentProjectPath)); + QString context; + context += "\n\n"; + for (auto chunk : chunks) { + context += chunk.toObject()["fileName"].toString(); + context += '\n'; + context += chunk.toObject()["content"].toString(); + context += "\n\n"; + } + context += "\n"; + return context; + } else if (CodeGeeXManager::instance()->condaHasInstalled()) { + emit askApi.noChunksFounded(); + return ""; + } + } + + return ""; +} + +QString CodeGeeXManager::promptPreProcessing(const QString &originText) +{ + QString processedText = originText; + + QString message = originText; + if (askApi.codebaseEnabled()) { + QString prompt = QString("Translate this passage into English :\"%1\", with the requirements: Do not provide responses other than translation.").arg(message.remove("@CodeBase")); + auto englishPrompt = askApi.syncQuickAsk(kUrlSSEChat, sessionId, prompt, currentTalkID); + QString chunksContext = getChunks(englishPrompt); + if (!chunksContext.isEmpty()) + message.append(chunksContext); + processedText = originText + chunksContext; + } + + return processedText; +} + void CodeGeeXManager::sendMessage(const QString &prompt) { QString askId = "User" + QString::number(QDateTime::currentMSecsSinceEpoch()); @@ -241,7 +289,11 @@ void CodeGeeXManager::sendMessage(const QString &prompt) } QString machineId = QSysInfo::machineUniqueId(); QString talkId = currentTalkID; - askApi.postSSEChat(kUrlSSEChat, sessionId, prompt, machineId, history, talkId); + + QtConcurrent::run([=, this](){ + auto processedText = promptPreProcessing(prompt); + askApi.postSSEChat(kUrlSSEChat, sessionId, processedText, machineId, history, talkId); + }); startReceiving(); } diff --git a/src/plugins/codegeex/codegeexmanager.h b/src/plugins/codegeex/codegeexmanager.h index 4d7e4286a..ef11ab432 100644 --- a/src/plugins/codegeex/codegeexmanager.h +++ b/src/plugins/codegeex/codegeexmanager.h @@ -106,6 +106,9 @@ class CodeGeeXManager : public QObject */ QJsonObject query(const QString &projectPath, const QString &query, int topItems); + QString promptPreProcessing(const QString &originText); + QString getChunks(const QString &queryText); + Q_SIGNALS: void loginSuccessed(); void logoutSuccessed(); diff --git a/src/services/ai/aiservice.h b/src/services/ai/aiservice.h index 6a565bf78..ad77be1ce 100644 --- a/src/services/ai/aiservice.h +++ b/src/services/ai/aiservice.h @@ -6,10 +6,63 @@ #define AISERVICE_H #include "services/services_global.h" +#include "base/ai/abstractllm.h" #include #include +enum LLMType { + OPENAI, + ZHIPU_CODEGEEX +}; + +static QString LLMTypeToString(LLMType type) +{ + switch (type) { + case OPENAI: + return "OpenAi"; + case ZHIPU_CODEGEEX: + return QObject::tr("ZhiPu"); + default: + return "UnKnown"; + } +} + +struct LLMInfo +{ + QString modelName = ""; + QString modelPath = ""; + QString apikey = ""; + LLMType type; + bool operator==(const LLMInfo &info) const + { + return (this->modelName == info.modelName + && this->modelPath == info.modelPath + && this->apikey == info.apikey + && this->type == info.type); + }; + QVariant toVariant() const + { + QVariantMap map; + map["modelName"] = modelName; + map["modelPath"] = modelPath; + map["apikey"] = apikey; + map["type"] = static_cast(type); + return QVariant(map); + } + static LLMInfo fromVariantMap(const QVariantMap &map) + { + LLMInfo info; + info.modelName = map["modelName"].toString(); + info.modelPath = map["modelPath"].toString(); + info.apikey = map["apikey"].toString(); + info.type = static_cast(map["type"].toInt()); + return info; + } +}; + +Q_DECLARE_METATYPE(LLMType); + namespace dpfservice { // service interface class SERVICE_EXPORT AiService final : public dpf::PluginService, dpf::AutoServiceRegister @@ -28,11 +81,16 @@ class SERVICE_EXPORT AiService final : public dpf::PluginService, dpf::AutoServi } + // codegeex DPF_INTERFACE(bool, available); DPF_INTERFACE(void, askQuestion, const QString &prompt, QIODevice *pipe); // pipe recevice data from ai DPF_INTERFACE(void, askQuestionWithHistory, const QString &prompt, const QMultiMap history, QIODevice *pipe); + + // custom model + DPF_INTERFACE(AbstractLLM *, getLLM, const LLMInfo &info); + DPF_INTERFACE(QList, getAllModel); }; -} // namespace dpfservice +} // namespace dpfservice #endif // AISERVICE_H diff --git a/src/services/option/optiondatastruct.h b/src/services/option/optiondatastruct.h index 98055706a..77d3fad6d 100644 --- a/src/services/option/optiondatastruct.h +++ b/src/services/option/optiondatastruct.h @@ -20,5 +20,6 @@ static const QString CATEGORY_NINJA{"Ninja"}; static const QString GROUP_GENERAL{QObject::tr("General")}; static const QString GROUP_LANGUAGE{QObject::tr("Language")}; +static const QString GROUP_AI{QObject::tr("AI")}; } #endif // OPTIONDATASTRUCT_H