Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: [Ai] Adapt the existing AI capabilities to switch between multi… #1056

Merged
merged 1 commit into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/base/ai/abstractllm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,17 @@ AbstractLLM::AbstractLLM(QObject *parent)
{
qRegisterMetaType<ResponseState>("ResponseState");
}

// Update the model's busy/idle state and notify listeners on a real change.
// NOTE: the previous implementation did loadAcquire() followed by a separate
// storeRelease(), which is a check-then-act race: two threads setting the
// state concurrently could both pass the equality check and both emit, or
// interleave so that a change was stored without a notification. A single
// atomic swap makes the read-modify-write indivisible.
void AbstractLLM::setModelState(LLMState st)
{
    // fetchAndStoreOrdered atomically writes the new state and returns the
    // old one, so the "did it actually change?" decision is race-free.
    const int previous = state.fetchAndStoreOrdered(st);
    if (previous == st)
        return;

    Q_EMIT modelStateChanged();
}

// Return the current model state (Idle/Busy), read with acquire semantics
// so it pairs with the release/ordered store in setModelState().
AbstractLLM::LLMState AbstractLLM::modelState() const
{
    const int raw = state.loadAcquire();
    return static_cast<LLMState>(raw);
}
29 changes: 24 additions & 5 deletions src/base/ai/abstractllm.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ class AbstractLLM : public QObject
{
Q_OBJECT
public:
enum LLMState {
Idle = 0,
Busy
};

enum ResponseState {
Receiving,
Success,
Expand All @@ -22,6 +27,13 @@ class AbstractLLM : public QObject
Canceled
};

enum Locale {
Zh,
En
};

using ResponseHandler = std::function<void(const QString &data, ResponseState state)>;

explicit AbstractLLM(QObject *parent = nullptr);
virtual ~AbstractLLM() {}

Expand All @@ -30,18 +42,25 @@ class AbstractLLM : public QObject
virtual bool checkValid(QString *errStr) = 0;
virtual QJsonObject create(const Conversation &conversation) = 0;
virtual void request(const QJsonObject &data) = 0;
virtual void request(const QString &prompt) = 0;
virtual void generate(const QString &prompt, const QString &suffix) = 0;
virtual void request(const QString &prompt, ResponseHandler handler = nullptr) = 0;
virtual void generate(const QString &prefix, const QString &suffix) = 0; // use to auto compelte
virtual void setTemperature(double temperature) = 0;
virtual void setStream(bool isStream) = 0;
virtual void processResponse(QNetworkReply *reply) = 0;
virtual void setLocale(Locale lc) = 0;
virtual void cancel() = 0;
virtual void setMaxTokens(int maxToken) = 0;
virtual Conversation *getCurrentConversation() = 0;
virtual bool isIdle() = 0;

void setModelState(LLMState st);
LLMState modelState() const;

signals:
void dataReceived(const QString &data, ResponseState statu);
void customDataReceived(const QString &key, const QJsonObject &customData);
void dataReceived(const QString &data, ResponseState state);
void modelStateChanged();

private:
QAtomicInt state { Idle };
};

#endif
14 changes: 14 additions & 0 deletions src/base/ai/conversation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,21 @@
return false;
}

bool Conversation::addResponse(const QString &data)
{
if (!data.isEmpty()) {
const QJsonDocument &document = QJsonDocument::fromJson(data.toUtf8());
if (document.isArray()) {
conversation = document.array();
} else {
conversation.push_back(QJsonObject({ { "role", "assistant" }, {"content", data} }));
}
return true;
}
return false;
}

QString Conversation::getLastResponse() const

Check warning on line 87 in src/base/ai/conversation.cpp

View workflow job for this annotation

GitHub Actions / cppcheck

The function 'getLastResponse' is never used.

Check warning on line 87 in src/base/ai/conversation.cpp

View workflow job for this annotation

GitHub Actions / static-check / static-check

The function 'getLastResponse' is never used.
{
if (!conversation.isEmpty() && conversation.last()["role"].toString() == "assistant") {
return conversation.last()["content"].toString();
Expand Down
1 change: 1 addition & 0 deletions src/base/ai/conversation.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class Conversation
bool addUserData(const QString &data);
bool popUserData();

bool addResponse(const QString &data);
QString getLastResponse() const;
QByteArray getLastByteResponse() const;
bool popLastResponse();
Expand Down
67 changes: 67 additions & 0 deletions src/plugins/aimanager/aimanager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,28 @@
// SPDX-License-Identifier: GPL-3.0-or-later

#include "aimanager.h"
#include "services/ai/aiservice.h"

Check warning on line 6 in src/plugins/aimanager/aimanager.cpp

View workflow job for this annotation

GitHub Actions / cppcheck

Include file: "services/ai/aiservice.h" not found.

Check warning on line 6 in src/plugins/aimanager/aimanager.cpp

View workflow job for this annotation

GitHub Actions / static-check / static-check

Include file: "services/ai/aiservice.h" not found.
#include "openai/openaicompatiblellm.h"
#include "openai/openaicompletionprovider.h"
#include "services/option/optionmanager.h"

Check warning on line 9 in src/plugins/aimanager/aimanager.cpp

View workflow job for this annotation

GitHub Actions / cppcheck

Include file: "services/option/optionmanager.h" not found.

Check warning on line 9 in src/plugins/aimanager/aimanager.cpp

View workflow job for this annotation

GitHub Actions / static-check / static-check

Include file: "services/option/optionmanager.h" not found.
#include "services/editor/editorservice.h"

Check warning on line 10 in src/plugins/aimanager/aimanager.cpp

View workflow job for this annotation

GitHub Actions / cppcheck

Include file: "services/editor/editorservice.h" not found.

Check warning on line 10 in src/plugins/aimanager/aimanager.cpp

View workflow job for this annotation

GitHub Actions / static-check / static-check

Include file: "services/editor/editorservice.h" not found.
#include "option/detailwidget.h"
#include "common/util/eventdefinitions.h"

Check warning on line 12 in src/plugins/aimanager/aimanager.cpp

View workflow job for this annotation

GitHub Actions / cppcheck

Include file: "common/util/eventdefinitions.h" not found.

Check warning on line 12 in src/plugins/aimanager/aimanager.cpp

View workflow job for this annotation

GitHub Actions / static-check / static-check

Include file: "common/util/eventdefinitions.h" not found.
#include "codegeex/codegeexllm.h"
#include "codegeex/codegeexcompletionprovider.h"

#include <QMap>

Check warning on line 16 in src/plugins/aimanager/aimanager.cpp

View workflow job for this annotation

GitHub Actions / cppcheck

Include file: <QMap> not found. Please note: Cppcheck does not need standard library headers to get proper results.

Check warning on line 16 in src/plugins/aimanager/aimanager.cpp

View workflow job for this annotation

GitHub Actions / static-check / static-check

Include file: <QMap> not found. Please note: Cppcheck does not need standard library headers to get proper results.

static const char *kCodeGeeXModelPath = "https://codegeex.cn/prod/code/chatCodeSseV3/chat";

using namespace dpfservice;

class AiManagerPrivate
{
public:
QList<LLMInfo> models;
CodeGeeXCompletionProvider *cgcProvider = nullptr;
OpenAiCompletionProvider *oacProvider = nullptr;
};

AiManager *AiManager::instance()
Expand All @@ -29,6 +37,7 @@
: QObject(parent)
, d(new AiManagerPrivate)
{
initCompleteProvider();
readLLMFromOption();
}

Expand All @@ -52,6 +61,11 @@
if (!info.apikey.isEmpty())
llm->setApiKey(info.apikey);
return llm;
} else if (info.type == LLMType::ZHIPU_CODEGEEX) {
auto llm = new CodeGeeXLLM(this);
llm->setModelName(info.modelName);
llm->setModelPath(info.modelPath);
return llm;
}
}

Expand Down Expand Up @@ -100,6 +114,59 @@
appendModel(info);
}

for (LLMInfo defaultLLM : getDefaultLLM()) {
if (!d->models.contains(defaultLLM))
appendModel(defaultLLM);
}

if (changed)
ai.LLMChanged();

if (!map.value(kCATEGORY_AUTO_COMPLETE).isValid()) {
d->cgcProvider->setInlineCompletionEnabled(false);
d->oacProvider->setInlineCompletionEnabled(false);
return;
}

auto currentCompleteLLMInfo = LLMInfo::fromVariantMap(map.value(kCATEGORY_AUTO_COMPLETE).toMap());
if (currentCompleteLLMInfo.type == LLMType::OPENAI) {
d->oacProvider->setInlineCompletionEnabled(true);
d->oacProvider->setLLM(getLLM(currentCompleteLLMInfo));
d->cgcProvider->setInlineCompletionEnabled(false); // codegeex completion provider use default url
} else if (currentCompleteLLMInfo.type == LLMType::ZHIPU_CODEGEEX) {
d->cgcProvider->setInlineCompletionEnabled(true);
d->oacProvider->setInlineCompletionEnabled(false);
}
}

// Create the two inline-completion providers (CodeGeeX and OpenAI-compatible)
// and register them with the editor service once all plugins have started.
// Both start disabled; option settings decide which one is active later.
void AiManager::initCompleteProvider()
{
    d->cgcProvider = new CodeGeeXCompletionProvider(this);
    d->oacProvider = new OpenAiCompletionProvider(this);
    d->cgcProvider->setInlineCompletionEnabled(false);
    d->oacProvider->setInlineCompletionEnabled(false);

    connect(&dpf::Listener::instance(), &dpf::Listener::pluginsStarted, this, [=] {
        auto editorService = dpfGetService(EditorService);
        // Guard against a missing editor service (e.g. the editor plugin
        // failed to load); the original code dereferenced it unchecked.
        if (!editorService)
            return;
        d->cgcProvider->setInlineCompletionEnabled(true);
        editorService->registerInlineCompletionProvider(d->cgcProvider);
        editorService->registerInlineCompletionProvider(d->oacProvider);
    }, Qt::DirectConnection);
}

// Build the two built-in CodeGeeX model entries (Lite and Pro).
// Both point at the same chat endpoint and differ only in name/icon.
QList<LLMInfo> AiManager::getDefaultLLM()
{
    QList<LLMInfo> defaults;

    LLMInfo lite;
    lite.icon = QIcon::fromTheme("codegeex_model_lite");
    lite.modelName = CodeGeeXChatModelLite;
    lite.modelPath = kCodeGeeXModelPath;
    lite.type = LLMType::ZHIPU_CODEGEEX;
    defaults.append(lite);

    LLMInfo pro;
    pro.icon = QIcon::fromTheme("codegeex_model_pro");
    pro.modelName = CodeGeeXChatModelPro;
    pro.modelPath = kCodeGeeXModelPath;
    pro.type = LLMType::ZHIPU_CODEGEEX;
    defaults.append(pro);

    return defaults;
}
5 changes: 5 additions & 0 deletions src/plugins/aimanager/aimanager.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@

// option manager -> custom models
// Option-map category keys used when reading/writing model settings.
// NOTE(review): `static const char *` in a header gives every including
// translation unit its own copy of each pointer; `inline constexpr` (C++17)
// would avoid the duplication — confirm the project's C++ standard before
// changing.
static const char *kCATEGORY_CUSTOMMODELS = "CustomModels";
static const char *kCATEGORY_AUTO_COMPLETE = "AutoComplete";
static const char *kCATEGORY_OPTIONKEY = "OptionKey";
// Built-in CodeGeeX model identifiers (sent as the model name in requests).
static const char *CodeGeeXChatModelLite = "codegeex-4";
static const char *CodeGeeXChatModelPro = "codegeex-chat-pro";

class AiManagerPrivate;
class AiManager : public QObject
Expand All @@ -25,8 +28,10 @@ class AiManager : public QObject
void appendModel(const LLMInfo &LLMInfo);
void removeModel(const LLMInfo &LLMInfo);
bool checkModelValid(const LLMInfo &info, QString *errStr);
QList<LLMInfo> getDefaultLLM();

void readLLMFromOption();
void initCompleteProvider();

private:
AiManager(QObject *parent = nullptr);
Expand Down
4 changes: 3 additions & 1 deletion src/plugins/aimanager/aimanager.qrc
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
<RCC>
<qresource prefix="/aimanager">
<qresource prefix="/icons/deepin">
<file>builtin/icons/codegeex_model_lite_16px.svg</file>
<file>builtin/icons/codegeex_model_pro_16px.svg</file>
</qresource>
</RCC>
17 changes: 15 additions & 2 deletions src/plugins/aimanager/aiplugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,24 @@
using namespace std::placeholders;
aiService->getAllModel = std::bind(&AiManager::getAllModel, impl);
aiService->getLLM = std::bind(&AiManager::getLLM, impl, _1);
aiService->getCodeGeeXLLMLite = []() -> LLMInfo {
auto defaultLLMs = AiManager::instance()->getDefaultLLM();
for (auto llm : defaultLLMs) {
if (llm.modelName == CodeGeeXChatModelLite)

Check warning on line 32 in src/plugins/aimanager/aiplugin.cpp

View workflow job for this annotation

GitHub Actions / cppcheck

Consider using std::find_if algorithm instead of a raw loop.

Check warning on line 32 in src/plugins/aimanager/aiplugin.cpp

View workflow job for this annotation

GitHub Actions / static-check / static-check

Consider using std::find_if algorithm instead of a raw loop.
return llm;
}
};
aiService->getCodeGeeXLLMPro = []() -> LLMInfo {
auto defaultLLMs = AiManager::instance()->getDefaultLLM();
for (auto llm : defaultLLMs) {
if (llm.modelName == CodeGeeXChatModelPro)

Check warning on line 39 in src/plugins/aimanager/aiplugin.cpp

View workflow job for this annotation

GitHub Actions / cppcheck

Consider using std::find_if algorithm instead of a raw loop.

Check warning on line 39 in src/plugins/aimanager/aiplugin.cpp

View workflow job for this annotation

GitHub Actions / static-check / static-check

Consider using std::find_if algorithm instead of a raw loop.
return llm;
}
};

auto optionService = dpfGetService(dpfservice::OptionService);
if (optionService) {
// TODO:uncomment the code when everything is ok
// optionService->implGenerator<OptionCustomModelsGenerator>(option::GROUP_AI, OptionCustomModelsGenerator::kitName());
optionService->implGenerator<OptionCustomModelsGenerator>(option::GROUP_AI, OptionCustomModelsGenerator::kitName());
}

return true;
Expand Down
Loading
Loading