diff --git a/.gitignore b/.gitignore
index 30dd45d..5655d21 100644
--- a/.gitignore
+++ b/.gitignore
@@ -25,6 +25,7 @@ wheels/
 share/python-wheels/
 *.egg-info/
 .installed.cfg
+*.cfg
 *.egg
 MANIFEST
 *.csv
@@ -39,3 +40,17 @@
 src/vdf_io/notebooks/chroma/*
 *.pem
 src/vdf_io/notebooks/**.jpg
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Testing folders
+testing/
+tests/
+test_results/
+test_reports/
\ No newline at end of file
diff --git a/docs/mongodb_README.md b/docs/mongodb_README.md
new file mode 100644
index 0000000..ed3f336
--- /dev/null
+++ b/docs/mongodb_README.md
@@ -0,0 +1,72 @@
+# MongoDB Import/Export Utility
+
+This guide provides a comprehensive overview of how to import and export VDF (Vector Data Format) datasets to and from MongoDB collections.
+
+## Prerequisites
+
+Ensure you have reviewed the root [README](../README.md) of this repository before proceeding.
+
+## Command-Line Usage
+
+### Shared Arguments
+
+- `<connection_string>`: Your MongoDB Atlas connection string.
+- `<database>`: The name of your MongoDB database.
+- `<collection>`: The name of your MongoDB collection.
+- `<vector_dim>`: The dimension of the vector columns to be imported/exported. If not specified, the script will auto-detect the dimension.
+
+### 1. Exporting Data from MongoDB
+
+To export data from a MongoDB collection to a VDF (Vector Data Format) dataset:
+
+```bash
+export_vdf mongodb --connection_string <connection_string> --database <database> --collection <collection> --vector_dim <vector_dim>
+```
+
+### 2. Importing Data to MongoDB
+
+To import data from a VDF dataset into a MongoDB collection:
+
+```bash
+import_vdf -d <path_to_vdf_dataset> mongodb --connection_string <connection_string> --database <database> --collection <collection> --vector_dim <vector_dim>
+```
+
+**Additional Argument** for Import:
+
+- `<path_to_vdf_dataset>`: Path to the VDF dataset directory on your system, passed via `-d`.
+
+### Example Usage
+
+#### Export Example
+
+To export data from a MongoDB collection called `my_collection` in the database `my_database`, where vectors are of dimension 128:
+
+```bash
+export_vdf mongodb --connection_string "mongodb+srv://<username>:<password>@<cluster>.mongodb.net/?retryWrites=true&w=majority" --database "my_database" --collection "my_collection" --vector_dim 128
+```
+
+#### Import Example
+
+To import data from a VDF dataset located in `/path/to/vdf/dataset` into the MongoDB collection `sample_collection`:
+
+```bash
+import_vdf -d /path/to/vdf/dataset mongodb --connection_string "mongodb+srv://<username>:<password>@<cluster>.mongodb.net/?retryWrites=true&w=majority" --database "sample_database" --collection "sample_collection" --vector_dim 128
+```
+
+## Key Features
+
+- **Batch Processing**: Both import and export operations support batching for improved efficiency.
+- **Data Type Conversion**: Automatically converts data types to corresponding MongoDB-compatible formats.
+- **Auto-detection**: If the `vector_dim` parameter is not specified, the utility will automatically detect the dimension of the vectors.
+- **Interactive Mode**: The utility will prompt for any missing arguments if they are not provided via the command line.
+
+## Additional Notes
+
+- Always verify that your `<connection_string>` contains the correct username, password, cluster name, and database details.
+- Ensure the VDF dataset is properly formatted to match MongoDB's expected data types and structure.
+
+## Troubleshooting
+
+- Ensure that your IP address is configured in the **Network Access** section of your MongoDB Atlas dashboard to allow connections to your MongoDB instance. If you encounter difficulties with the connection string format, consult [MongoDB's official documentation](https://www.mongodb.com/docs/atlas/connect-to-cluster/) for detailed guidance.
+
+- For any issues related to vector dimension mismatches, verify that the vector dimension in the VDF dataset matches the `vector_dim` parameter you provide during import or export operations.
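After an import, the collection can be inspected directly with `pymongo` to confirm that the expected number of documents arrived. The snippet below is a minimal sketch, separate from the patch itself; the connection string, database, and collection names are placeholders that must match the values passed to `import_vdf`:

```python
# Illustrative check only: count the imported documents and print one of them.
from pymongo import MongoClient

client = MongoClient("mongodb+srv://<username>:<password>@<cluster>.mongodb.net/")
collection = client["sample_database"]["sample_collection"]

print("Documents in collection:", collection.count_documents({}))

sample = collection.find_one()
if sample is not None:
    # The vector field should come back as a list of floats with vector_dim entries.
    print("Sample document keys:", sorted(sample.keys()))
```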
diff --git a/requirements.txt b/requirements.txt
index 260025b..fc08161 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -34,4 +34,5 @@ mlx_embedding_models
 azure-search-documents
 azure-identity
 turbopuffer[fast]
-psycopg2
\ No newline at end of file
+psycopg2
+pymongo
\ No newline at end of file
diff --git a/src/vdf_io/export_vdf/mongodb_export.py b/src/vdf_io/export_vdf/mongodb_export.py
new file mode 100644
index 0000000..c7cbc57
--- /dev/null
+++ b/src/vdf_io/export_vdf/mongodb_export.py
@@ -0,0 +1,236 @@
+import json
+import os
+from typing import Dict, List
+import pymongo
+import pandas as pd
+from tqdm import tqdm
+from vdf_io.meta_types import NamespaceMeta
+from vdf_io.names import DBNames
+from vdf_io.util import set_arg_from_input
+from vdf_io.export_vdf.vdb_export_cls import ExportVDB
+from bson import ObjectId, Binary, Regex, Timestamp, Decimal128, Code
+import logging
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class ExportMongoDB(ExportVDB):
+    DB_NAME_SLUG = DBNames.MONGODB
+
+    @classmethod
+    def make_parser(cls, subparsers):
+        parser_mongodb = subparsers.add_parser(
+            cls.DB_NAME_SLUG, help="Export data from MongoDB"
+        )
+        parser_mongodb.add_argument(
+            "--connection_string", type=str, help="MongoDB Atlas Connection string"
+        )
+        parser_mongodb.add_argument(
+            "--vector_dim", type=int, help="Expected dimension of vector columns"
+        )
+        parser_mongodb.add_argument(
+            "--database", type=str, help="MongoDB Atlas Database name"
+        )
+        parser_mongodb.add_argument(
+            "--collection", type=str, help="MongoDB Atlas collection to export"
+        )
+        parser_mongodb.add_argument(
+            "--batch_size",
+            type=int,
+            help="Batch size for exporting data",
+            default=10_000,
+        )
+
+    @classmethod
+    def export_vdb(cls, args):
+        set_arg_from_input(
+            args,
+            "connection_string",
+            "Enter the MongoDB Atlas connection string: ",
+            str,
+        )
+        set_arg_from_input(
+            args,
+            "database",
+            "Enter the MongoDB Atlas database name: ",
+            str,
+        )
+        set_arg_from_input(
+            args,
+            "collection",
+            "Enter the name of collection to export: ",
+            str,
+        )
+        set_arg_from_input(
+            args,
+            "vector_dim",
+            "Enter the expected dimension of vector columns: ",
+            int,
+        )
+        mongodb_atlas_export = ExportMongoDB(args)
+        mongodb_atlas_export.all_collections = mongodb_atlas_export.get_index_names()
+        mongodb_atlas_export.get_data()
+        return mongodb_atlas_export
+
+    def __init__(self, args):
+        super().__init__(args)
+        try:
+            self.client = pymongo.MongoClient(
+                args["connection_string"], serverSelectionTimeoutMS=5000
+            )
+            self.client.server_info()
+            logger.info("Successfully connected to MongoDB")
+        except pymongo.errors.ServerSelectionTimeoutError as err:
+            logger.error(f"Failed to connect to MongoDB: {err}")
+            raise
+
+        try:
+            self.db = self.client[args["database"]]
+        except Exception as err:
+            logger.error(f"Failed to select MongoDB database: {err}")
+            raise
+
+        try:
+            self.collection = self.db[args["collection"]]
+        except Exception as err:
+            logger.error(f"Failed to select MongoDB collection: {err}")
+            raise
+
+    def get_index_names(self):
+        collection_name = self.args.get("collection", None)
+        if collection_name is not None:
+            if collection_name not in self.db.list_collection_names():
+                logger.error(
+                    f"Collection '{collection_name}' does not exist in the database."
+                )
+                raise ValueError(
+                    f"Collection '{collection_name}' does not exist in the database."
+                )
+            return [collection_name]
+        else:
+            return self.get_all_index_names()
+
+    def get_all_index_names(self):
+        return self.db.list_collection_names()
+
+    def flatten_dict(self, d, parent_key="", sep="#SEP#"):
+        items = []
+        type_conversions = {
+            ObjectId: lambda v: f"BSON_ObjectId_{str(v)}",
+            Binary: lambda v: f"BSON_Binary_{v.decode('utf-8', errors='ignore')}",
+            Regex: lambda v: f"BSON_Regex_{json.dumps({'pattern': v.pattern, 'options': v.options})}",
+            Timestamp: lambda v: f"BSON_Timestamp_{v.as_datetime().isoformat()}",
+            Decimal128: lambda v: f"BSON_Decimal128_{float(v.to_decimal())}",
+            Code: lambda v: f"BSON_Code_{str(v.code)}",
+        }
+
+        for key, value in d.items():
+            new_key = f"{parent_key}{sep}{key}" if parent_key else key
+            conversion = type_conversions.get(type(value))
+
+            if conversion:
+                items.append((new_key, conversion(value)))
+            elif isinstance(value, dict):
+                items.extend(self.flatten_dict(value, new_key, sep=sep).items())
+            elif isinstance(value, list):
+                if all(isinstance(v, dict) and "$numberDouble" in v for v in value):
+                    float_list = [float(v["$numberDouble"]) for v in value]
+                    items.append((new_key, float_list))
+                else:
+                    items.append((new_key, value))
+            else:
+                items.append((new_key, value))
+
+        return dict(items)
+
+    def get_data(self):
+        object_columns_list = []
+        vector_columns = []
+        expected_dim = self.args.get("vector_dim")
+        collection_name = self.args["collection"]
+        batch_size = self.args["batch_size"]
+
+        vectors_directory = self.create_vec_dir(collection_name)
+
+        total_documents = self.collection.count_documents({})
+        total_batches = (total_documents + batch_size - 1) // batch_size
+        total = 0
+        index_metas: Dict[str, List[NamespaceMeta]] = {}
+
+        if expected_dim is None:
+            logger.info("Vector dimension not provided. Detecting from data...")
+            sample_doc = self.collection.find_one()
+            if sample_doc:
+                flat_doc = self.flatten_dict(sample_doc)
+                for key, value in flat_doc.items():
+                    if isinstance(value, list) and all(
+                        isinstance(x, (int, float)) for x in value
+                    ):
+                        expected_dim = len(value)
+                        logger.info(
+                            f"Detected vector dimension: {expected_dim} from column: {key}"
+                        )
+                        break
+
+        if expected_dim is None:
+            expected_dim = 0
+            logger.warning("No vector columns detected in the data")
+
+        for i in tqdm(range(total_batches), desc=f"Exporting {collection_name}"):
+            cursor = self.collection.find().skip(i * batch_size).limit(batch_size)
+            batch_data = list(cursor)
+            if not batch_data:
+                break
+
+            flattened_data = []
+            for document in batch_data:
+                flat_doc = self.flatten_dict(document)
+
+                for key in flat_doc:
+                    if isinstance(flat_doc[key], dict):
+                        flat_doc[key] = json.dumps(flat_doc[key])
+                    elif flat_doc[key] == "":
+                        flat_doc[key] = None
+
+                flattened_data.append(flat_doc)
+
+            df = pd.DataFrame(flattened_data)
+            df = df.dropna(axis=1, how="all")
+
+            for column in df.columns:
+                if (
+                    isinstance(df[column].iloc[0], list)
+                    and len(df[column].iloc[0]) == expected_dim
+                ):
+                    vector_columns.append(column)
+                else:
+                    object_columns_list.append(column)
+                    df[column] = df[column].astype(str)
+
+            parquet_file = os.path.join(vectors_directory, f"{i}.parquet")
+            df.to_parquet(parquet_file)
+            total += len(df)
+
+        namespace_metas = [
+            self.get_namespace_meta(
+                collection_name,
+                vectors_directory,
+                total=total,
+                num_vectors_exported=total,
+                dim=expected_dim,
+                vector_columns=vector_columns,
+                distance="cosine",
+            )
+        ]
+        index_metas[collection_name] = namespace_metas
+
+        self.file_structure.append(os.path.join(self.vdf_directory, "VDF_META.json"))
+        internal_metadata = self.get_basic_vdf_meta(index_metas)
+        meta_text = json.dumps(internal_metadata.model_dump(), indent=4)
+        tqdm.write(meta_text)
+        with open(os.path.join(self.vdf_directory, "VDF_META.json"), "w") as json_file:
+            json_file.write(meta_text)
+
+        logger.info(f"Export completed. Total documents exported: {total}")
+        return True
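As a rough illustration of the flattening convention implemented by `ExportMongoDB.flatten_dict` above (the document and its values here are made up): nested keys are joined with the `#SEP#` separator and BSON values are serialized as prefixed strings before each batch is written to parquet.

```python
# Hypothetical input document and the flattened form it would be reduced to.
from bson import ObjectId

nested_doc = {
    "_id": ObjectId("5f2b1d6c8e4b0c2a3d4e5f60"),
    "meta": {"year": 2010, "genres": ["sci-fi", "thriller"]},
    "plot_embedding": [0.12, -0.08, 0.33],
}

# Expected flattened result (illustrative):
# {
#     "_id": "BSON_ObjectId_5f2b1d6c8e4b0c2a3d4e5f60",
#     "meta#SEP#year": 2010,
#     "meta#SEP#genres": ["sci-fi", "thriller"],
#     "plot_embedding": [0.12, -0.08, 0.33],
# }
```

Columns whose lists match the detected vector dimension are recorded as vector columns; every other column is cast to a string before the parquet write.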
diff --git a/src/vdf_io/import_vdf/mongodb_import.py b/src/vdf_io/import_vdf/mongodb_import.py
new file mode 100644
index 0000000..ddfe299
--- /dev/null
+++ b/src/vdf_io/import_vdf/mongodb_import.py
@@ -0,0 +1,249 @@
+from dotenv import load_dotenv
+from tqdm import tqdm
+import pymongo
+import logging
+import re
+import ast
+import numpy as np
+from bson import ObjectId, Binary, Regex, Timestamp, Decimal128, Code
+import json
+from datetime import datetime
+from vdf_io.constants import DEFAULT_BATCH_SIZE, INT_MAX
+from vdf_io.names import DBNames
+from vdf_io.util import (
+    cleanup_df,
+    divide_into_batches,
+    set_arg_from_input,
+)
+from vdf_io.import_vdf.vdf_import_cls import ImportVDB
+
+load_dotenv()
+logger = logging.getLogger(__name__)
+
+
+class ImportMongoDB(ImportVDB):
+    DB_NAME_SLUG = DBNames.MONGODB
+
+    @classmethod
+    def make_parser(cls, subparsers):
+        parser_mongodb = subparsers.add_parser(
+            cls.DB_NAME_SLUG, help="Import data to MongoDB"
+        )
+        parser_mongodb.add_argument(
+            "--connection_string", type=str, help="MongoDB Atlas Connection string"
+        )
+        parser_mongodb.add_argument(
+            "--database", type=str, help="MongoDB Atlas Database name"
+        )
+        parser_mongodb.add_argument(
+            "--collection", type=str, help="MongoDB Atlas collection to import into"
+        )
+        parser_mongodb.add_argument(
+            "--vector_dim", type=int, help="Expected dimension of vector columns"
+        )
+
+    @classmethod
+    def import_vdb(cls, args):
+        """
+        Import data to MongoDB
+        """
+        set_arg_from_input(
+            args,
+            "connection_string",
+            "Enter the MongoDB connection string: ",
+            str,
+        )
+        set_arg_from_input(
+            args,
+            "database",
+            "Enter the MongoDB database name: ",
+            str,
+        )
+        set_arg_from_input(
+            args,
+            "collection",
+            "Enter the name of collection: ",
+            str,
+        )
+        set_arg_from_input(
+            args, "vector_dim", "Enter the expected dimension of vector columns: ", int
+        )
+        mongodb_import = ImportMongoDB(args)
+        mongodb_import.upsert_data()
+        return mongodb_import
+
+    def __init__(self, args):
+        super().__init__(args)
+
+        try:
+            self.client = pymongo.MongoClient(
+                args["connection_string"], serverSelectionTimeoutMS=5000
+            )
+            self.client.server_info()
+            logger.info("Successfully connected to MongoDB")
+        except pymongo.errors.ServerSelectionTimeoutError as err:
+            logger.error(f"Failed to connect to MongoDB: {err}")
+            raise
+
+        try:
+            self.db = self.client[args["database"]]
+        except Exception as err:
+            logger.error(f"Failed to select MongoDB database: {err}")
+            raise
+
+        try:
+            self.collection = self.db[args["collection"]]
+        except Exception as err:
+            logger.error(f"Failed to select MongoDB collection: {err}")
+            raise
+
+    def convert_types(self, documents):
+        return [self.convert_document(doc) for doc in documents]
+
+    def convert_document(self, doc):
+        converted_doc = {}
+        for key, value in doc.items():
+            parts = key.split("#SEP#")
+            value = self.convert_value(value)
+            self.nested_set(converted_doc, parts, value)
+        return converted_doc
+
+    def nested_set(self, dic, keys, value):
+        for key in keys[:-1]:
+            # If the key already exists and is not a dictionary, we need to handle it
+            if key in dic and not isinstance(dic[key], dict):
+                dic[key] = {}  # Overwrite with an empty dictionary
+
+            dic = dic.setdefault(key, {})
+
+        dic[keys[-1]] = value  # Set the final key to the value
+
+    def convert_value(self, value):
+        if isinstance(value, np.ndarray):
+            # Convert numpy array to list: MongoDB can't handle numpy arrays directly
+            return value.tolist()
+
+        # Check if the value is a string
+        if isinstance(value, str):
+            # Check if the string is a date in "YYYY-MM-DD" format or extended ISO format
+            date_pattern = r"^\d{4}-\d{2}-\d{2}$"  # Regex for "YYYY-MM-DD"
+            iso_pattern = (
+                r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}.\d{3}Z$"  # Extended ISO
+            )
+
+            if re.match(date_pattern, value):  # If it matches "YYYY-MM-DD"
+                try:
+                    return datetime.strptime(value, "%Y-%m-%d")  # Convert to datetime
+                except ValueError:
+                    pass
+
+            if re.match(iso_pattern, value):
+                try:
+                    return datetime.fromisoformat(value)
+                except ValueError:
+                    pass
+
+            try:
+                return int(value)
+            except ValueError:
+                try:
+                    return float(value)
+                except ValueError:
+                    pass
+
+            try:
+                # Try to evaluate if the string is a list (e.g., for arrays like genres)
+                value = ast.literal_eval(value)
+                if isinstance(value, list):
+                    return value  # Return as a list if it is a list
+            except (ValueError, SyntaxError):
+                # If it's not an array or number, leave it as a string
+                pass
+
+            # Handle special BSON formats, as before
+            if value.startswith("BSON_ObjectId_"):
+                return ObjectId(value[14:])
+            elif value.startswith("BSON_Binary_"):
+                return Binary(value[12:].encode("utf-8"))
+            elif value.startswith("BSON_Regex_"):
+                regex_dict = json.loads(value[11:])
+                return Regex(regex_dict["pattern"], regex_dict["options"])
+            elif value.startswith("BSON_Timestamp_"):
+                # "BSON_Timestamp_" is 15 characters; bson.Timestamp requires a
+                # (time, inc) pair, so a zero increment is used on re-import.
+                return Timestamp(datetime.fromisoformat(value[15:]), 0)
+            elif value.startswith("BSON_Decimal128_"):
+                return Decimal128(value[16:])
+            elif value.startswith("BSON_Code_"):
+                return Code(value[10:])
+
+        elif isinstance(value, list):
+            return [self.convert_value(item) for item in value]
+
+        return value
+
+    def upsert_data(self):
+        max_hit = False
+        self.total_imported_count = 0
+        indexes_content = self.vdf_meta["indexes"]
+        index_names = list(indexes_content.keys())
+        if len(index_names) == 0:
+            raise ValueError("No indexes found in VDF_META.json")
+
+        for index_name, index_meta in tqdm(
+            indexes_content.items(), desc="Importing indexes"
+        ):
+            for namespace_meta in tqdm(index_meta, desc="Importing namespaces"):
+                self.set_dims(namespace_meta, index_name)
+                data_path = namespace_meta["data_path"]
+                final_data_path = self.get_final_data_path(data_path)
+                parquet_files = self.get_parquet_files(final_data_path)
+
+                for file in tqdm(parquet_files, desc="Iterating parquet files"):
+                    file_path = self.get_file_path(final_data_path, file)
+                    try:
+                        df = self.read_parquet_progress(
+                            file_path,
+                            max_num_rows=(self.args.get("max_num_rows") or INT_MAX),
+                        )
+                    except Exception as e:
+                        logger.error(
+                            f"Error reading Parquet file {file_path}: {str(e)}"
+                        )
+                        continue
+                    df = cleanup_df(df)
+
+                    BATCH_SIZE = self.args.get("batch_size") or DEFAULT_BATCH_SIZE
+                    for batch in tqdm(
+                        divide_into_batches(df, BATCH_SIZE),
+                        desc="Importing batches",
+                        total=len(df) // BATCH_SIZE,
+                    ):
+                        if self.total_imported_count + len(batch) >= (
+                            self.args.get("max_num_rows") or INT_MAX
+                        ):
+                            batch = batch[
+                                : (self.args.get("max_num_rows") or INT_MAX)
+                                - self.total_imported_count
+                            ]
+                            max_hit = True
+
+                        documents = batch.to_dict("records")
+
+                        try:
+                            documents = self.convert_types(documents)
+                            self.collection.insert_many(documents)
+                            self.total_imported_count += len(batch)
+                        except pymongo.errors.BulkWriteError as e:
+                            logger.error(f"Error during bulk insert: {str(e.details)}")
+
+                        if max_hit:
+                            break
+
+                    tqdm.write(f"Imported {self.total_imported_count} rows")
+                    tqdm.write(
+                        f"New collection size: {self.collection.count_documents({})}"
+                    )
+                    if max_hit:
+                        break
+
+        logger.info(
+            f"Data import completed. Total rows imported: {self.total_imported_count}"
+        )
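Conversely, a minimal sketch of the inverse mapping performed by `convert_document`, `nested_set`, and `convert_value` during import (the keys and values here are illustrative): `#SEP#`-joined keys are re-nested into sub-documents, numeric strings are parsed back into numbers, and `BSON_*`-prefixed strings are restored to BSON types.

```python
# Hypothetical flattened row from a parquet file and the document it becomes.
from bson import ObjectId

flat_row = {
    "_id": "BSON_ObjectId_5f2b1d6c8e4b0c2a3d4e5f60",
    "meta#SEP#year": "2010",
    "plot_embedding": [0.12, -0.08, 0.33],
}

# Expected reconstructed document (illustrative):
# {
#     "_id": ObjectId("5f2b1d6c8e4b0c2a3d4e5f60"),
#     "meta": {"year": 2010},
#     "plot_embedding": [0.12, -0.08, 0.33],
# }
```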
diff --git a/src/vdf_io/names.py b/src/vdf_io/names.py
index 6697e05..95baeae 100644
--- a/src/vdf_io/names.py
+++ b/src/vdf_io/names.py
@@ -12,3 +12,4 @@
     ASTRADB = "astradb"
     AZUREAI = "azureai"
     TURBOPUFFER = "turbopuffer"
+    MONGODB = "mongodb"
diff --git a/src/vdf_io/util.py b/src/vdf_io/util.py
index 18554a8..41d93f6 100644
--- a/src/vdf_io/util.py
+++ b/src/vdf_io/util.py
@@ -215,6 +215,10 @@ def expand_shorthand_path(shorthand_path):
         "euclidean_distance": Distance.EUCLID,
         "dot_product": Distance.DOT,
     },
+    DBNames.MONGODB: {
+        "cosine": Distance.COSINE,
+        "euclidean": Distance.EUCLID,
+    },
 }