Skip to content

Commit

Permalink
add sqlalchemy support and mongodb support (#321)
Browse files Browse the repository at this point in the history
Co-authored-by: wangzhihong <[email protected]>
  • Loading branch information
SuperEver and wzh1994 authored Feb 13, 2025
1 parent 380ed25 commit 7aaeaad
Show file tree
Hide file tree
Showing 16 changed files with 1,159 additions and 561 deletions.
332 changes: 222 additions & 110 deletions lazyllm/docs/tools.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion lazyllm/engine/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class NodeArgs(object):
port=NodeArgs(str, None),
db_name=NodeArgs(str, None),
options_str=NodeArgs(str, ""),
tables_info_dict=NodeArgs(list, None),
tables_info_dict=NodeArgs(dict, None),
),
)

Expand Down
2 changes: 1 addition & 1 deletion lazyllm/thirdparty/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,6 @@ def __getattribute__(self, __name):
modules = ['redis', 'huggingface_hub', 'jieba', 'modelscope', 'pandas', 'jwt', 'rank_bm25', 'redisvl', 'datasets',
'deepspeed', 'fire', 'numpy', 'peft', 'torch', 'transformers', 'collie', 'faiss', 'flash_attn', 'google',
'lightllm', 'vllm', 'ChatTTS', 'wandb', 'funasr', 'sklearn', 'torchvision', 'scipy', 'pymilvus',
'sentence_transformers', 'gradio', 'chromadb', 'nltk', 'PIL', 'httpx', 'bm25s', 'kubernetes']
'sentence_transformers', 'gradio', 'chromadb', 'nltk', 'PIL', 'httpx', 'bm25s', 'kubernetes', 'pymongo']
for m in modules:
vars()[m] = PackageWrapper(m)
7 changes: 5 additions & 2 deletions lazyllm/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
ReWOOAgent,
)
from .classifier import IntentClassifier
from .sql import SQLiteManger, SqlManager, SqlCall
from .sql import SqlManager, MongoDBManager, DBResult, DBStatus
from .sql_call import SqlCall
from .tools.http_tool import HttpTool

__all__ = [
Expand All @@ -28,8 +29,10 @@
"ReWOOAgent",
"IntentClassifier",
"SentenceSplitter",
"SQLiteManger",
"SqlManager",
"MongoDBManager",
"DBResult",
"DBStatus",
"SqlCall",
"HttpTool",
]
6 changes: 4 additions & 2 deletions lazyllm/tools/sql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from .sql_tool import SQLiteManger, SqlCall, SqlManager
from .sql_manager import SqlManager
from .mongodb_manager import MongoDBManager
from .db_manager import DBManager, DBResult, DBStatus

__all__ = ["SqlCall", "SQLiteManger", "SqlManager"]
__all__ = ["DBManager", "SqlManager", "MongoDBManager", "DBResult", "DBStatus"]
54 changes: 54 additions & 0 deletions lazyllm/tools/sql/db_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from enum import Enum, unique
from typing import List, Union
from pydantic import BaseModel
from abc import ABC, abstractmethod
from lazyllm.module import ModuleBase


@unique
class DBStatus(Enum):
SUCCESS = 0
FAIL = 1


class DBResult(BaseModel):
status: DBStatus = DBStatus.SUCCESS
detail: str = "Success"
result: Union[List, None] = None

class CommonMeta(type(ABC), type(ModuleBase)):
pass

class DBManager(ABC, ModuleBase, metaclass=CommonMeta):

def __init__(self, db_type: str):
ModuleBase.__init__(self)
self._db_type = db_type
self._desc = None

@abstractmethod
def execute_query(self, statement) -> str:
pass

def forward(self, statement: str) -> str:
return self.execute_query(statement)

@property
def db_type(self) -> str:
return self._db_type

@property
@abstractmethod
def desc(self) -> str: pass

@staticmethod
def _is_dict_all_str(d):
if not isinstance(d, dict):
return False
return all(isinstance(key, str) and (isinstance(value, str) or DBManager._is_dict_all_str(value))
for key, value in d.items())

@staticmethod
def _serialize_uncommon_type(obj):
if not isinstance(obj, int, str, float, bool, tuple, list, dict):
return str(obj)
104 changes: 104 additions & 0 deletions lazyllm/tools/sql/mongodb_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import json
from contextlib import contextmanager
from urllib.parse import quote_plus
import pydantic

from lazyllm.thirdparty import pymongo

from .db_manager import DBManager, DBResult, DBStatus


class CollectionDesc(pydantic.BaseModel):
summary: str = ""
schema_type: dict
schema_desc: dict


class MongoDBManager(DBManager):
MAX_TIMEOUT_MS = 5000

def __init__(self, user: str, password: str, host: str, port: int, db_name: str, collection_name: str, **kwargs):
super().__init__(db_type="mongodb")
self._user = user
self._password = password
self._host = host
self._port = port
self._db_name = db_name
self._collection_name = collection_name
self._collection = None
self._options_str = kwargs.get("options_str")
self._conn_url = self._gen_conn_url()
self._collection_desc_dict = kwargs.get("collection_desc_dict")

@property
def db_name(self):
return self._db_name

@property
def collection_name(self):
return self._collection_name

def _gen_conn_url(self) -> str:
password = quote_plus(self._password)
conn_url = (f"{self._db_type}://{self._user}:{password}@{self._host}:{self._port}/"
f"{('?' + self._options_str) if self._options_str else ''}")
return conn_url

@contextmanager
def get_client(self):
client = pymongo.MongoClient(self._conn_url, serverSelectionTimeoutMS=self.MAX_TIMEOUT_MS)
try:
yield client
finally:
client.close()

@property
def desc(self):
if self._desc is None:
self.set_desc(schema_desc_dict=self._collection_desc_dict)
return self._desc

def set_desc(self, schema_desc_dict: dict):
self._collection_desc_dict = schema_desc_dict
if schema_desc_dict is None:
with self.get_client() as client:
egs_one = client[self._db_name][self._collection_name].find_one()
if egs_one is not None:
self._desc = "Collection Example:\n"
self._desc += json.dumps(egs_one, ensure_ascii=False, indent=4)
else:
self._desc = ""
try:
collection_desc = CollectionDesc.model_validate(schema_desc_dict)
except pydantic.ValidationError as e:
raise ValueError(f"Validate input schema_desc_dict failed: {str(e)}")
if not self._is_dict_all_str(collection_desc.schema_type):
raise ValueError("schema_type shouble be str or nested str dict")
if not self._is_dict_all_str(collection_desc.schema_desc):
raise ValueError("schema_desc shouble be str or nested str dict")
if collection_desc.summary:
self._desc += f"Collection summary: {collection_desc.summary}\n"
self._desc += "Collection schema:\n"
self._desc += json.dumps(collection_desc.schema_type, ensure_ascii=False, indent=4)
self._desc += "Collection schema description:\n"
self._desc += json.dumps(collection_desc.schema_type, ensure_ascii=False, indent=4)

def check_connection(self) -> DBResult:
try:
with pymongo.MongoClient(self._conn_url, serverSelectionTimeoutMS=self.MAX_TIMEOUT_MS) as client:
_ = client.server_info()
return DBResult()
except Exception as e:
return DBResult(status=DBStatus.FAIL, detail=str(e))

def execute_query(self, statement) -> str:
str_result = ""
try:
pipeline_list = json.loads(statement)
with self.get_client() as client:
collection = client[self._db_name][self._collection_name]
result = list(collection.aggregate(pipeline_list))
str_result = json.dumps(result, ensure_ascii=False, default=self._serialize_uncommon_type)
except Exception as e:
str_result = f"MongoDB ERROR: {str(e)}"
return str_result
Loading

0 comments on commit 7aaeaad

Please sign in to comment.