From e77d1a5bca4eab1d12237deafcd5948d96ffa636 Mon Sep 17 00:00:00 2001 From: David Chiu Date: Thu, 9 Nov 2023 20:24:33 +0800 Subject: [PATCH] feat: sync `mongo` to `SimpleMongoReader` of `llama-index` (#624) --- llama_hub/mongo/base.py | 84 ++++++++++++++++++++++++++++------------- 1 file changed, 58 insertions(+), 26 deletions(-) diff --git a/llama_hub/mongo/base.py b/llama_hub/mongo/base.py index 575cf57b35..1579e4b37a 100644 --- a/llama_hub/mongo/base.py +++ b/llama_hub/mongo/base.py @@ -1,6 +1,6 @@ """Mongo client.""" -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union from llama_index.readers.base import BaseReader from llama_index.readers.schema.base import Document @@ -14,8 +14,6 @@ class SimpleMongoReader(BaseReader): Args: host (str): Mongo host. port (int): Mongo port. - max_docs (int): Maximum number of documents to load. - """ def __init__( @@ -23,50 +21,84 @@ def __init__( host: Optional[str] = None, port: Optional[int] = None, uri: Optional[str] = None, - max_docs: int = 1000, ) -> None: """Initialize with parameters.""" try: - import pymongo # noqa: F401 - from pymongo import MongoClient # noqa: F401 - except ImportError: + from pymongo import MongoClient + except ImportError as err: raise ImportError( "`pymongo` package not found, please run `pip install pymongo`" - ) + ) from err + + client: MongoClient if uri: - if uri is None: - raise ValueError("Either `host` and `port` or `uri` must be provided.") - self.client: MongoClient = MongoClient(uri) + client = MongoClient(uri) + elif host and port: + client = MongoClient(host, port) else: - if host is None or port is None: - raise ValueError("Either `host` and `port` or `uri` must be provided.") - self.client = MongoClient(host, port) - self.max_docs = max_docs + raise ValueError("Either `host` and `port` or `uri` must be provided.") + + self.client = client + + def _flatten(self, texts: List[Union[str, List[str]]]) -> List[str]: + result = [] + for text in texts: + result += text if isinstance(text, list) else [text] + return result def load_data( - self, db_name: str, collection_name: str, query_dict: Optional[Dict] = None + self, + db_name: str, + collection_name: str, + field_names: List[str] = ["text"], + separator: str = "", + query_dict: Optional[Dict] = None, + max_docs: int = 0, + metadata_names: Optional[List[str]] = None, ) -> List[Document]: """Load data from the input directory. Args: db_name (str): name of the database. collection_name (str): name of the collection. - query_dict (Optional[Dict]): query to filter documents. + field_names(List[str]): names of the fields to be concatenated. + Defaults to ["text"] + separator (str): separator to be used between fields. + Defaults to "" + query_dict (Optional[Dict]): query to filter documents. Read more + at [official docs](https://www.mongodb.com/docs/manual/reference/method/db.collection.find/#std-label-method-find-query) Defaults to None + max_docs (int): maximum number of documents to load. + Defaults to 0 (no limit) + metadata_names (Optional[List[str]]): names of the fields to be added + to the metadata attribute of the Document. Defaults to None Returns: List[Document]: A list of documents. - """ - documents = [] db = self.client[db_name] - if query_dict is None: - cursor = db[collection_name].find() - else: - cursor = db[collection_name].find(query_dict) + cursor = db[collection_name].find(filter=query_dict or {}, limit=max_docs) + documents = [] for item in cursor: - if "text" not in item: - raise ValueError("`text` field not found in Mongo document.") - documents.append(Document(text=item["text"])) + try: + texts = [item[name] for name in field_names] + except KeyError as err: + raise ValueError( + f"{err.args[0]} field not found in Mongo document." + ) from err + + texts = self._flatten(texts) + text = separator.join(texts) + + if metadata_names is None: + documents.append(Document(text=text)) + else: + try: + metadata = {name: item[name] for name in metadata_names} + except KeyError as err: + raise ValueError( + f"{err.args[0]} field not found in Mongo document." + ) from err + documents.append(Document(text=text, metadata=metadata)) return documents