-
Notifications
You must be signed in to change notification settings - Fork 71
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add sqlalchemy support and mongodb support (#321)
Co-authored-by: wangzhihong <[email protected]>
- Loading branch information
Showing
16 changed files
with
1,159 additions
and
561 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
from .sql_tool import SQLiteManger, SqlCall, SqlManager | ||
from .sql_manager import SqlManager | ||
from .mongodb_manager import MongoDBManager | ||
from .db_manager import DBManager, DBResult, DBStatus | ||
|
||
__all__ = ["SqlCall", "SQLiteManger", "SqlManager"] | ||
__all__ = ["DBManager", "SqlManager", "MongoDBManager", "DBResult", "DBStatus"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from enum import Enum, unique | ||
from typing import List, Union | ||
from pydantic import BaseModel | ||
from abc import ABC, abstractmethod | ||
from lazyllm.module import ModuleBase | ||
|
||
|
||
@unique | ||
class DBStatus(Enum): | ||
SUCCESS = 0 | ||
FAIL = 1 | ||
|
||
|
||
class DBResult(BaseModel): | ||
status: DBStatus = DBStatus.SUCCESS | ||
detail: str = "Success" | ||
result: Union[List, None] = None | ||
|
||
class CommonMeta(type(ABC), type(ModuleBase)): | ||
pass | ||
|
||
class DBManager(ABC, ModuleBase, metaclass=CommonMeta): | ||
|
||
def __init__(self, db_type: str): | ||
ModuleBase.__init__(self) | ||
self._db_type = db_type | ||
self._desc = None | ||
|
||
@abstractmethod | ||
def execute_query(self, statement) -> str: | ||
pass | ||
|
||
def forward(self, statement: str) -> str: | ||
return self.execute_query(statement) | ||
|
||
@property | ||
def db_type(self) -> str: | ||
return self._db_type | ||
|
||
@property | ||
@abstractmethod | ||
def desc(self) -> str: pass | ||
|
||
@staticmethod | ||
def _is_dict_all_str(d): | ||
if not isinstance(d, dict): | ||
return False | ||
return all(isinstance(key, str) and (isinstance(value, str) or DBManager._is_dict_all_str(value)) | ||
for key, value in d.items()) | ||
|
||
@staticmethod | ||
def _serialize_uncommon_type(obj): | ||
if not isinstance(obj, int, str, float, bool, tuple, list, dict): | ||
return str(obj) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
import json | ||
from contextlib import contextmanager | ||
from urllib.parse import quote_plus | ||
import pydantic | ||
|
||
from lazyllm.thirdparty import pymongo | ||
|
||
from .db_manager import DBManager, DBResult, DBStatus | ||
|
||
|
||
class CollectionDesc(pydantic.BaseModel): | ||
summary: str = "" | ||
schema_type: dict | ||
schema_desc: dict | ||
|
||
|
||
class MongoDBManager(DBManager): | ||
MAX_TIMEOUT_MS = 5000 | ||
|
||
def __init__(self, user: str, password: str, host: str, port: int, db_name: str, collection_name: str, **kwargs): | ||
super().__init__(db_type="mongodb") | ||
self._user = user | ||
self._password = password | ||
self._host = host | ||
self._port = port | ||
self._db_name = db_name | ||
self._collection_name = collection_name | ||
self._collection = None | ||
self._options_str = kwargs.get("options_str") | ||
self._conn_url = self._gen_conn_url() | ||
self._collection_desc_dict = kwargs.get("collection_desc_dict") | ||
|
||
@property | ||
def db_name(self): | ||
return self._db_name | ||
|
||
@property | ||
def collection_name(self): | ||
return self._collection_name | ||
|
||
def _gen_conn_url(self) -> str: | ||
password = quote_plus(self._password) | ||
conn_url = (f"{self._db_type}://{self._user}:{password}@{self._host}:{self._port}/" | ||
f"{('?' + self._options_str) if self._options_str else ''}") | ||
return conn_url | ||
|
||
@contextmanager | ||
def get_client(self): | ||
client = pymongo.MongoClient(self._conn_url, serverSelectionTimeoutMS=self.MAX_TIMEOUT_MS) | ||
try: | ||
yield client | ||
finally: | ||
client.close() | ||
|
||
@property | ||
def desc(self): | ||
if self._desc is None: | ||
self.set_desc(schema_desc_dict=self._collection_desc_dict) | ||
return self._desc | ||
|
||
def set_desc(self, schema_desc_dict: dict): | ||
self._collection_desc_dict = schema_desc_dict | ||
if schema_desc_dict is None: | ||
with self.get_client() as client: | ||
egs_one = client[self._db_name][self._collection_name].find_one() | ||
if egs_one is not None: | ||
self._desc = "Collection Example:\n" | ||
self._desc += json.dumps(egs_one, ensure_ascii=False, indent=4) | ||
else: | ||
self._desc = "" | ||
try: | ||
collection_desc = CollectionDesc.model_validate(schema_desc_dict) | ||
except pydantic.ValidationError as e: | ||
raise ValueError(f"Validate input schema_desc_dict failed: {str(e)}") | ||
if not self._is_dict_all_str(collection_desc.schema_type): | ||
raise ValueError("schema_type shouble be str or nested str dict") | ||
if not self._is_dict_all_str(collection_desc.schema_desc): | ||
raise ValueError("schema_desc shouble be str or nested str dict") | ||
if collection_desc.summary: | ||
self._desc += f"Collection summary: {collection_desc.summary}\n" | ||
self._desc += "Collection schema:\n" | ||
self._desc += json.dumps(collection_desc.schema_type, ensure_ascii=False, indent=4) | ||
self._desc += "Collection schema description:\n" | ||
self._desc += json.dumps(collection_desc.schema_type, ensure_ascii=False, indent=4) | ||
|
||
def check_connection(self) -> DBResult: | ||
try: | ||
with pymongo.MongoClient(self._conn_url, serverSelectionTimeoutMS=self.MAX_TIMEOUT_MS) as client: | ||
_ = client.server_info() | ||
return DBResult() | ||
except Exception as e: | ||
return DBResult(status=DBStatus.FAIL, detail=str(e)) | ||
|
||
def execute_query(self, statement) -> str: | ||
str_result = "" | ||
try: | ||
pipeline_list = json.loads(statement) | ||
with self.get_client() as client: | ||
collection = client[self._db_name][self._collection_name] | ||
result = list(collection.aggregate(pipeline_list)) | ||
str_result = json.dumps(result, ensure_ascii=False, default=self._serialize_uncommon_type) | ||
except Exception as e: | ||
str_result = f"MongoDB ERROR: {str(e)}" | ||
return str_result |
Oops, something went wrong.