From ab1a8e697ec63f9295a3facf95c327824f067532 Mon Sep 17 00:00:00 2001 From: StanChan03 Date: Tue, 14 Jan 2025 20:07:12 -0800 Subject: [PATCH 1/3] initial sql integration --- examples/db_examples/sql_db.py | 55 ++++++++++++++++++++++++++++++++++ lotus/__init__.py | 2 ++ lotus/databases/__init__.py | 4 +++ lotus/databases/connectors.py | 31 +++++++++++++++++++ lotus/databases/lotus_db.py | 33 ++++++++++++++++++++ requirements.txt | 5 +++- 6 files changed, 129 insertions(+), 1 deletion(-) create mode 100644 examples/db_examples/sql_db.py create mode 100644 lotus/databases/__init__.py create mode 100644 lotus/databases/connectors.py create mode 100644 lotus/databases/lotus_db.py diff --git a/examples/db_examples/sql_db.py b/examples/db_examples/sql_db.py new file mode 100644 index 0000000..5711365 --- /dev/null +++ b/examples/db_examples/sql_db.py @@ -0,0 +1,55 @@ +import sqlite3 + +import lotus +from lotus.databases import DatabaseConnector, LotusDB +from lotus.models import LM + +conn = sqlite3.connect("example_movies.db") +cursor = conn.cursor() + +# Create the table +cursor.execute(""" +CREATE TABLE IF NOT EXISTS movies ( + id INTEGER PRIMARY KEY, + title TEXT, + director TEXT, + rating REAL, + release_year INTEGER, + description TEXT +) +""") + +cursor.execute("DELETE FROM movies") + +# Insert sample data +cursor.executemany( + """ +INSERT INTO movies (title, director, rating, release_year, description) +VALUES (?, ?, ?, ?, ?) +""", + [ + ("The Matrix", "Wachowskis", 8.7, 1999, "A hacker discovers the reality is simulated."), + ("The Godfather", "Francis Coppola", 9.2, 1972, "The rise and fall of a powerful mafia family."), + ("Inception", "Christopher Nolan", 8.8, 2010, "A thief enters dreams to steal secrets."), + ("Parasite", "Bong Joon-ho", 8.6, 2019, "A poor family schemes to infiltrate a rich household."), + ("Interstellar", "Christopher Nolan", 8.6, 2014, "A team travels through a wormhole to save humanity."), + ("Titanic", "James Cameron", 7.8, 1997, "A love story set during the Titanic tragedy."), + ], +) + +conn.commit() +conn.close() + + +lm = LM(model="gpt-4o-mini") +lotus.settings.configure(lm=lm) + +connector = DatabaseConnector() +connector.connect_sql("sqlite:///example_movies.db") +lotus_db = LotusDB(connector) + +df = lotus_db.query("SELECT * FROM movies") + +user_instruction = "{title} that are science fiction" +df = df.sem_filter(user_instruction) +print(df) diff --git a/lotus/__init__.py b/lotus/__init__.py index d20f710..d1c8bf8 100644 --- a/lotus/__init__.py +++ b/lotus/__init__.py @@ -2,6 +2,7 @@ import lotus.dtype_extensions import lotus.models import lotus.vector_store +import lotus.databases import lotus.nl_expression import lotus.templates import lotus.utils @@ -48,4 +49,5 @@ "vector_store", "utils", "dtype_extensions", + "databases", ] diff --git a/lotus/databases/__init__.py b/lotus/databases/__init__.py new file mode 100644 index 0000000..8281a94 --- /dev/null +++ b/lotus/databases/__init__.py @@ -0,0 +1,4 @@ +from lotus.databases.connectors import DatabaseConnector +from lotus.databases.lotus_db import LotusDB + +__all__ = ["DatabaseConnector", "LotusDB"] diff --git a/lotus/databases/connectors.py b/lotus/databases/connectors.py new file mode 100644 index 0000000..7466f88 --- /dev/null +++ b/lotus/databases/connectors.py @@ -0,0 +1,31 @@ +from pymongo import MongoClient +from sqlalchemy import create_engine + + +class DatabaseConnector: + def __init__(self): + self.sql_engine = None + self.nosql_client = None + + def connect_sql(self, connection_url: str): + """Connect to SQL database""" + try: + self.sql_engine = create_engine(connection_url) + return self + except Exception as e: + raise ConnectionError(f"Error connecting to SQL database: {e}") + + def connect_nosql(self, connection_url: str): + """Connect to MongoDB NoSQL database""" + try: + self.nosql_client = MongoClient(connection_url) + return self + except Exception as e: + raise ConnectionError(f"Error connecting to NoSQL database: {e}") + + def close_connections(self): + """Close SQL and NoSQL connections""" + if self.sql_engine: + self.sql_engine.dispose() + if self.nosql_client: + self.nosql_client.close() diff --git a/lotus/databases/lotus_db.py b/lotus/databases/lotus_db.py new file mode 100644 index 0000000..a59e89a --- /dev/null +++ b/lotus/databases/lotus_db.py @@ -0,0 +1,33 @@ +from typing import Optional + +import pandas as pd + + +class LotusDB: + def __init__(self, connector): + self.connector = connector + + def query( + self, query: str, db_type: str = "sql", db_name: Optional[str] = None, collection_name: Optional[str] = None + ) -> pd.DataFrame: + """ + Executes query and returns a pandas dataframe + + Args: + query (str): The query to execute + db_type (str, optional): The type of database to use. Defaults to 'sql'. + + Returns: + pd.DataFrame: The result of the query + + """ + if db_type == "sql": + return pd.read_sql(query, self.connector.sql_engine) + elif db_type == "nosql": + if not collection_name or not db_name: + raise ValueError("Collection name and database is required for NoSQL database") + collection = self.connector.nosql_client[db_name][collection_name] + results = collection.find(query) + return pd.DataFrame(list(results)) + else: + raise ValueError("Invalid database type") diff --git a/requirements.txt b/requirements.txt index e645c71..04018c2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,7 @@ tqdm==4.66.4 weaviate-client==4.10.2 pinecone==5.4.2 chromadb==0.6.2 -qdrant-client==1.12.2 \ No newline at end of file +qdrant-client==1.12.2 +psycopg2-binary==2.9.10 +SQLAlchemy==2.0.37 +pymongo==4.10.1 \ No newline at end of file From 668c90e5bcce972deba0e02538f101392b981cde Mon Sep 17 00:00:00 2001 From: StanChan03 Date: Thu, 16 Jan 2025 12:08:57 -0800 Subject: [PATCH 2/3] add sql db tests --- .github/tests/db_tests.py | 94 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) create mode 100644 .github/tests/db_tests.py diff --git a/.github/tests/db_tests.py b/.github/tests/db_tests.py new file mode 100644 index 0000000..dcd9fe8 --- /dev/null +++ b/.github/tests/db_tests.py @@ -0,0 +1,94 @@ +import os +import sqlite3 + +import pandas as pd +import pytest + +import lotus +from lotus.databases import DatabaseConnector, LotusDB +from lotus.models import LM + +################################################################################ +# Setup +################################################################################ +# Set logger level to DEBUG +lotus.logger.setLevel("DEBUG") + +# Environment flags to enable/disable tests +ENABLE_OPENAI_TESTS = os.getenv("ENABLE_OPENAI_TESTS", "false").lower() == "true" +ENABLE_OLLAMA_TESTS = os.getenv("ENABLE_OLLAMA_TESTS", "false").lower() == "true" + +MODEL_NAME_TO_ENABLED = { + "gpt-4o-mini": ENABLE_OPENAI_TESTS, + "gpt-4o": ENABLE_OPENAI_TESTS, + "ollama/llama3.1": ENABLE_OLLAMA_TESTS, +} +ENABLED_MODEL_NAMES = set([model_name for model_name, is_enabled in MODEL_NAME_TO_ENABLED.items() if is_enabled]) + + +def get_enabled(*candidate_models: str) -> list[str]: + return [model for model in candidate_models if model in ENABLED_MODEL_NAMES] + + +@pytest.fixture(scope="session") +def setup_models(): + models = {} + + for model_path in ENABLED_MODEL_NAMES: + models[model_path] = LM(model=model_path) + + return models + + +@pytest.fixture(scope="session") +def setup_sqlite_db(): + conn = sqlite3.connect(":memory:") + cursor = conn.cursor() + + cursor.execute(""" + CREATE TABLE test_table ( + id INTEGER PRIMARY KEY, + name TEXT, + age INTEGER + ) + """) + + cursor.executemany( + """INSERT INTO test_table (name, age) VALUES (?, ?)""", + [("Alice", 8), ("Bob", 14), ("Charlie", 35), ("Dave", 42)], + ) + + conn.commit() + conn.close() + + +@pytest.fixture(autouse=True) +def print_usage_after_each_test(setup_models): + yield # this runs the test + models = setup_models + for model_name, model in models.items(): + print(f"\nUsage stats for {model_name} after test:") + model.print_total_usage() + model.reset_stats() + model.reset_cache() + + +################################################################################# +# Standard Tests +################################################################################# + + +@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini")) +def test_SQL_db(setup_models, model): + lm = setup_models[model] + lotus.settings.configure(lm=lm) + + connector = DatabaseConnector() + connector.connect_sql("sqlite:///:memory:") + lotus_db = LotusDB(connector) + + df = lotus_db.query("SELECT * FROM test_table") + assert len(df) > 0 + + filtered_df = df.sem_filter("{name} is an adult") + assert isinstance(filtered_df, pd.DataFrame) From cc4ddcfb352cb4143474cfac42ef82d77e21aa6f Mon Sep 17 00:00:00 2001 From: StanChan03 Date: Thu, 16 Jan 2025 16:44:12 -0800 Subject: [PATCH 3/3] improve logging --- lotus/databases/lotus_db.py | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/lotus/databases/lotus_db.py b/lotus/databases/lotus_db.py index a59e89a..4550676 100644 --- a/lotus/databases/lotus_db.py +++ b/lotus/databases/lotus_db.py @@ -2,6 +2,8 @@ import pandas as pd +import lotus + class LotusDB: def __init__(self, connector): @@ -21,13 +23,21 @@ def query( pd.DataFrame: The result of the query """ - if db_type == "sql": - return pd.read_sql(query, self.connector.sql_engine) - elif db_type == "nosql": - if not collection_name or not db_name: - raise ValueError("Collection name and database is required for NoSQL database") - collection = self.connector.nosql_client[db_name][collection_name] - results = collection.find(query) - return pd.DataFrame(list(results)) - else: - raise ValueError("Invalid database type") + try: + if db_type == "sql": + if not isinstance(query, str): + raise ValueError("Query must be a string") + lotus.logger.debug("Executing SQL Query") + return pd.read_sql(query, self.connector.sql_engine) + elif db_type == "nosql": + if not collection_name or not db_name: + raise ValueError("Collection name and database is required for NoSQL database") + collection = self.connector.nosql_client[db_name][collection_name] + results = collection.find(query) + return pd.DataFrame(list(results)) + else: + raise ValueError("Invalid database type") + + except Exception as e: + lotus.logger.error(f"Error executing query: {e}") + raise