Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

initial sql integration #82

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions .github/tests/db_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import os
import sqlite3

import pandas as pd
import pytest

import lotus
from lotus.databases import DatabaseConnector, LotusDB
from lotus.models import LM

################################################################################
# Setup
################################################################################
# Set logger level to DEBUG
lotus.logger.setLevel("DEBUG")

# Environment flags to enable/disable tests
ENABLE_OPENAI_TESTS = os.getenv("ENABLE_OPENAI_TESTS", "false").lower() == "true"
ENABLE_OLLAMA_TESTS = os.getenv("ENABLE_OLLAMA_TESTS", "false").lower() == "true"

MODEL_NAME_TO_ENABLED = {
"gpt-4o-mini": ENABLE_OPENAI_TESTS,
"gpt-4o": ENABLE_OPENAI_TESTS,
"ollama/llama3.1": ENABLE_OLLAMA_TESTS,
}
ENABLED_MODEL_NAMES = set([model_name for model_name, is_enabled in MODEL_NAME_TO_ENABLED.items() if is_enabled])


def get_enabled(*candidate_models: str) -> list[str]:
return [model for model in candidate_models if model in ENABLED_MODEL_NAMES]


@pytest.fixture(scope="session")
def setup_models():
models = {}

for model_path in ENABLED_MODEL_NAMES:
models[model_path] = LM(model=model_path)

return models


@pytest.fixture(scope="session")
def setup_sqlite_db():
conn = sqlite3.connect(":memory:")
cursor = conn.cursor()

cursor.execute("""
CREATE TABLE test_table (
id INTEGER PRIMARY KEY,
name TEXT,
age INTEGER
)
""")

cursor.executemany(
"""INSERT INTO test_table (name, age) VALUES (?, ?)""",
[("Alice", 8), ("Bob", 14), ("Charlie", 35), ("Dave", 42)],
)

conn.commit()
conn.close()


@pytest.fixture(autouse=True)
def print_usage_after_each_test(setup_models):
yield # this runs the test
models = setup_models
for model_name, model in models.items():
print(f"\nUsage stats for {model_name} after test:")
model.print_total_usage()
model.reset_stats()
model.reset_cache()


#################################################################################
# Standard Tests
#################################################################################


@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
def test_SQL_db(setup_models, model):
lm = setup_models[model]
lotus.settings.configure(lm=lm)

connector = DatabaseConnector()
connector.connect_sql("sqlite:///:memory:")
lotus_db = LotusDB(connector)

df = lotus_db.query("SELECT * FROM test_table")
assert len(df) > 0

filtered_df = df.sem_filter("{name} is an adult")
assert isinstance(filtered_df, pd.DataFrame)
55 changes: 55 additions & 0 deletions examples/db_examples/sql_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import sqlite3

import lotus
from lotus.databases import DatabaseConnector, LotusDB
from lotus.models import LM

conn = sqlite3.connect("example_movies.db")
cursor = conn.cursor()

# Create the table
cursor.execute("""
CREATE TABLE IF NOT EXISTS movies (
id INTEGER PRIMARY KEY,
title TEXT,
director TEXT,
rating REAL,
release_year INTEGER,
description TEXT
)
""")

cursor.execute("DELETE FROM movies")

# Insert sample data
cursor.executemany(
"""
INSERT INTO movies (title, director, rating, release_year, description)
VALUES (?, ?, ?, ?, ?)
""",
[
("The Matrix", "Wachowskis", 8.7, 1999, "A hacker discovers the reality is simulated."),
("The Godfather", "Francis Coppola", 9.2, 1972, "The rise and fall of a powerful mafia family."),
("Inception", "Christopher Nolan", 8.8, 2010, "A thief enters dreams to steal secrets."),
("Parasite", "Bong Joon-ho", 8.6, 2019, "A poor family schemes to infiltrate a rich household."),
("Interstellar", "Christopher Nolan", 8.6, 2014, "A team travels through a wormhole to save humanity."),
("Titanic", "James Cameron", 7.8, 1997, "A love story set during the Titanic tragedy."),
],
)

conn.commit()
conn.close()


lm = LM(model="gpt-4o-mini")
lotus.settings.configure(lm=lm)

connector = DatabaseConnector()
connector.connect_sql("sqlite:///example_movies.db")
lotus_db = LotusDB(connector)

df = lotus_db.query("SELECT * FROM movies")

user_instruction = "{title} that are science fiction"
df = df.sem_filter(user_instruction)
print(df)
2 changes: 2 additions & 0 deletions lotus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import lotus.dtype_extensions
import lotus.models
import lotus.vector_store
import lotus.databases
import lotus.nl_expression
import lotus.templates
import lotus.utils
Expand Down Expand Up @@ -48,4 +49,5 @@
"vector_store",
"utils",
"dtype_extensions",
"databases",
]
4 changes: 4 additions & 0 deletions lotus/databases/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from lotus.databases.connectors import DatabaseConnector
from lotus.databases.lotus_db import LotusDB

__all__ = ["DatabaseConnector", "LotusDB"]
31 changes: 31 additions & 0 deletions lotus/databases/connectors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from pymongo import MongoClient
from sqlalchemy import create_engine


class DatabaseConnector:
def __init__(self):
self.sql_engine = None
self.nosql_client = None

def connect_sql(self, connection_url: str):
"""Connect to SQL database"""
try:
self.sql_engine = create_engine(connection_url)
return self
except Exception as e:
raise ConnectionError(f"Error connecting to SQL database: {e}")

def connect_nosql(self, connection_url: str):
"""Connect to MongoDB NoSQL database"""
try:
self.nosql_client = MongoClient(connection_url)
return self
except Exception as e:
raise ConnectionError(f"Error connecting to NoSQL database: {e}")

def close_connections(self):
"""Close SQL and NoSQL connections"""
if self.sql_engine:
self.sql_engine.dispose()
if self.nosql_client:
self.nosql_client.close()
43 changes: 43 additions & 0 deletions lotus/databases/lotus_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from typing import Optional

import pandas as pd

import lotus


class LotusDB:
def __init__(self, connector):
self.connector = connector

def query(
self, query: str, db_type: str = "sql", db_name: Optional[str] = None, collection_name: Optional[str] = None
) -> pd.DataFrame:
"""
Executes query and returns a pandas dataframe

Args:
query (str): The query to execute
db_type (str, optional): The type of database to use. Defaults to 'sql'.

Returns:
pd.DataFrame: The result of the query

"""
try:
if db_type == "sql":
if not isinstance(query, str):
raise ValueError("Query must be a string")
lotus.logger.debug("Executing SQL Query")
return pd.read_sql(query, self.connector.sql_engine)
elif db_type == "nosql":
if not collection_name or not db_name:
raise ValueError("Collection name and database is required for NoSQL database")
collection = self.connector.nosql_client[db_name][collection_name]
results = collection.find(query)
return pd.DataFrame(list(results))
else:
raise ValueError("Invalid database type")

except Exception as e:
lotus.logger.error(f"Error executing query: {e}")
raise
5 changes: 4 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,7 @@ tqdm==4.66.4
weaviate-client==4.10.2
pinecone==5.4.2
chromadb==0.6.2
qdrant-client==1.12.2
qdrant-client==1.12.2
psycopg2-binary==2.9.10
SQLAlchemy==2.0.37
pymongo==4.10.1
Loading