Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

initial data connectors #82

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions .github/tests/db_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import os
import sqlite3

import pandas as pd
import pytest

import lotus
from lotus.databases import DatabaseConnector, LotusDB
from lotus.models import LM

################################################################################
# Setup
################################################################################
# Set the shared lotus logger to DEBUG; this runs once, when the module is imported.
# NOTE(review): presumably `lotus.logger` is a standard `logging.Logger`, so this
# surfaces debug-level messages for the whole test session — confirm.
lotus.logger.setLevel("DEBUG")

# Opt-in switches: a provider's tests run only when its environment variable is
# set to "true" (compared case-insensitively). Unset means disabled.
ENABLE_OPENAI_TESTS = os.environ.get("ENABLE_OPENAI_TESTS", "false").lower() == "true"
ENABLE_OLLAMA_TESTS = os.environ.get("ENABLE_OLLAMA_TESTS", "false").lower() == "true"

# Maps each model name to the switch that gates it.
MODEL_NAME_TO_ENABLED = {
    "gpt-4o-mini": ENABLE_OPENAI_TESTS,
    "gpt-4o": ENABLE_OPENAI_TESTS,
    "ollama/llama3.1": ENABLE_OLLAMA_TESTS,
}

# Names of the models whose switch is on.
ENABLED_MODEL_NAMES = {name for name, enabled in MODEL_NAME_TO_ENABLED.items() if enabled}


def get_enabled(*candidate_models: str) -> list[str]:
    """Return the candidates found in ``ENABLED_MODEL_NAMES``, in the order given.

    Used to build ``parametrize`` lists: when none of the candidates is
    enabled the result is empty.
    """
    enabled = []
    for candidate in candidate_models:
        if candidate in ENABLED_MODEL_NAMES:
            enabled.append(candidate)
    return enabled


@pytest.fixture(scope="session")
def setup_models():
    """Build one ``LM`` per enabled model name, once per test session.

    Returns:
        dict: Model name -> ``LM`` instance constructed with that name.
    """
    return {model_path: LM(model=model_path) for model_path in ENABLED_MODEL_NAMES}


@pytest.fixture(scope="session")
def setup_sqlite_db(tmp_path_factory):
    """Create a populated, file-backed SQLite database and return its URL.

    A ``:memory:`` database exists only for the connection that created it:
    closing that connection discards the data, and any other connection to
    ``:memory:`` gets a fresh, empty database. The sample table is therefore
    written to a file in a pytest-managed temporary directory so the code
    under test can open it through its own connection.

    Args:
        tmp_path_factory: pytest's session-scoped temporary directory factory.

    Returns:
        str: A ``sqlite:///<path>`` URL for the populated database.
    """
    db_path = tmp_path_factory.mktemp("sqlite_db") / "test.db"
    conn = sqlite3.connect(str(db_path))
    try:
        conn.execute(
            """
            CREATE TABLE test_table (
                id INTEGER PRIMARY KEY,
                name TEXT,
                age INTEGER
            )
            """
        )
        conn.executemany(
            """INSERT INTO test_table (name, age) VALUES (?, ?)""",
            [("Alice", 8), ("Bob", 14), ("Charlie", 35), ("Dave", 42)],
        )
        conn.commit()
    finally:
        # Close even if setup fails so the file handle is not leaked.
        conn.close()

    return f"sqlite:///{db_path.as_posix()}"


@pytest.fixture(autouse=True)
def print_usage_after_each_test(setup_models):
    """After every test, print each model's usage, then reset its stats and cache."""
    yield  # this runs the test
    models = setup_models
    for model_name, model in models.items():
        print(f"\nUsage stats for {model_name} after test:")
        model.print_total_usage()
        model.reset_stats()
        model.reset_cache()


#################################################################################
# Standard Tests
#################################################################################


@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
def test_SQL_db(setup_models, setup_sqlite_db, model):
    """Load rows from the sample SQLite database and run a semantic filter on them."""
    lm = setup_models[model]
    lotus.settings.configure(lm=lm)

    connector = DatabaseConnector()
    # Connect to the database the fixture populated; a fresh ":memory:" URL
    # would open an empty database with no test_table in it.
    connector.connect_sql(setup_sqlite_db)
    try:
        lotus_db = LotusDB(connector)
        df = lotus_db.query("SELECT * FROM test_table")
    finally:
        connector.close_connections()

    assert len(df) > 0

    filtered_df = df.sem_filter("{name} is an adult")
    assert isinstance(filtered_df, pd.DataFrame)
55 changes: 55 additions & 0 deletions examples/db_examples/sql_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import sqlite3

import lotus
from lotus.databases import DatabaseConnector, LotusDB
from lotus.models import LM

# (Re)build the SQLite file that the rest of this example queries.
conn = sqlite3.connect("example_movies.db")
cursor = conn.cursor()

# Schema for the sample table; IF NOT EXISTS keeps re-runs from failing.
create_movies_table = """
CREATE TABLE IF NOT EXISTS movies (
id INTEGER PRIMARY KEY,
title TEXT,
director TEXT,
rating REAL,
release_year INTEGER,
description TEXT
)
"""

# Parameterised insert; one tuple from `sample_movies` fills the placeholders.
insert_movie = """
INSERT INTO movies (title, director, rating, release_year, description)
VALUES (?, ?, ?, ?, ?)
"""

# (title, director, rating, release_year, description)
sample_movies = [
    ("The Matrix", "Wachowskis", 8.7, 1999, "A hacker discovers the reality is simulated."),
    ("The Godfather", "Francis Coppola", 9.2, 1972, "The rise and fall of a powerful mafia family."),
    ("Inception", "Christopher Nolan", 8.8, 2010, "A thief enters dreams to steal secrets."),
    ("Parasite", "Bong Joon-ho", 8.6, 2019, "A poor family schemes to infiltrate a rich household."),
    ("Interstellar", "Christopher Nolan", 8.6, 2014, "A team travels through a wormhole to save humanity."),
    ("Titanic", "James Cameron", 7.8, 1997, "A love story set during the Titanic tragedy."),
]

cursor.execute(create_movies_table)

# Start from an empty table so repeated runs do not accumulate duplicate rows.
cursor.execute("DELETE FROM movies")
cursor.executemany(insert_movie, sample_movies)

conn.commit()
conn.close()


# Configure lotus with the model to use for this example.
lm = LM(model="gpt-4o-mini")
lotus.settings.configure(lm=lm)

# connect_sql() returns the connector itself, so creation and connection chain.
connector = DatabaseConnector().connect_sql("sqlite:///example_movies.db")
lotus_db = LotusDB(connector)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need LotusDB? Can't we just have a connect.load_as_pandas(table_name: str) directly?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe DatabaseConnector just provides a staticmethod

@staticmethod
def load_from_db(connection_url: str, table_name: str) -> pd.DataFrame

And then all the code here becomes

df = DatabaseConnector.load_from_db("sqlite:///example_movies.db", "movies")

Copy link
Collaborator Author

@StanChan03 StanChan03 Jan 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

However, what if someone has a specific query that's not just SELECT * FROM table? Maybe have something like

def load_from_db(connection_url: str, query: str) -> pd.DataFrame

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also if someone has a very large dataset should we consider processing it in batches, and then concat them together?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

def load_from_db(connection_url: str, query: str) -> pd.DataFrame

This works

Also if someone has a very large dataset should we consider processing it in batches, and then concat them together?

If someone has a very large dataset then I think lotus just will not be able to work, since the requirement is that it can be loaded into a pandas dataframe in memory.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense


# Natural-language predicate; {title} refers to the `title` column of the result.
user_instruction = "{title} that are science fiction"

# Load every movie into a DataFrame, then keep the rows the filter accepts.
df = lotus_db.query("SELECT * FROM movies").sem_filter(user_instruction)
print(df)
2 changes: 2 additions & 0 deletions lotus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import lotus.dtype_extensions
import lotus.models
import lotus.vector_store
import lotus.databases
import lotus.nl_expression
import lotus.templates
import lotus.utils
Expand Down Expand Up @@ -48,4 +49,5 @@
"vector_store",
"utils",
"dtype_extensions",
"databases",
]
4 changes: 4 additions & 0 deletions lotus/databases/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from lotus.databases.connectors import DatabaseConnector
from lotus.databases.lotus_db import LotusDB

__all__ = ["DatabaseConnector", "LotusDB"]
31 changes: 31 additions & 0 deletions lotus/databases/connectors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from pymongo import MongoClient
from sqlalchemy import create_engine


class DatabaseConnector:
    """Holds the SQL engine and NoSQL client handles used to reach a database."""

    def __init__(self):
        """Start with no connections; the ``connect_*`` methods fill these in."""
        self.nosql_client = None
        self.sql_engine = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should remove all the nosql stuff since there is not one interface for all the nosql offerings.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree


def connect_sql(self, connection_url: str):
    """Create a SQLAlchemy engine for ``connection_url`` and store it on the connector.

    NOTE(review): SQLAlchemy engines connect lazily, so an unreachable server
    is normally reported on first use rather than here; this call mainly
    fails for malformed URLs or missing drivers — confirm against the
    SQLAlchemy version in use.

    Args:
        connection_url: Database URL passed unchanged to ``create_engine``.

    Returns:
        The connector itself, so calls can be chained.

    Raises:
        ConnectionError: If ``create_engine`` raises; the original exception
            is attached as ``__cause__``.
    """
    try:
        self.sql_engine = create_engine(connection_url)
    except Exception as e:
        # Chain explicitly so the traceback shows the underlying failure as the cause.
        raise ConnectionError(f"Error connecting to SQL database: {e}") from e
    return self

def connect_nosql(self, connection_url: str):
    """Create a ``MongoClient`` for ``connection_url`` and store it on the connector.

    Args:
        connection_url: MongoDB URL passed unchanged to ``MongoClient``.

    Returns:
        The connector itself, so calls can be chained.

    Raises:
        ConnectionError: If ``MongoClient`` raises; the original exception
            is attached as ``__cause__``.
    """
    try:
        self.nosql_client = MongoClient(connection_url)
    except Exception as e:
        # Chain explicitly so the traceback shows the underlying failure as the cause.
        raise ConnectionError(f"Error connecting to NoSQL database: {e}") from e
    return self

def close_connections(self):
    """Release the SQL engine and the NoSQL client, if either was created.

    Handles are compared against ``None`` rather than tested for truthiness,
    so a handle object that happens to evaluate as false is still released.
    The NoSQL client is closed even when disposing the SQL engine raises.
    """
    try:
        if self.sql_engine is not None:
            self.sql_engine.dispose()
    finally:
        if self.nosql_client is not None:
            self.nosql_client.close()
43 changes: 43 additions & 0 deletions lotus/databases/lotus_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from typing import Optional

import pandas as pd

import lotus


class LotusDB:
    """Run queries through a database connector and return the results as DataFrames."""

    def __init__(self, connector):
        """Store the connector whose ``sql_engine`` / ``nosql_client`` handles ``query`` reads."""
        self.connector = connector

    def query(
        self, query: str, db_type: str = "sql", db_name: Optional[str] = None, collection_name: Optional[str] = None
    ) -> pd.DataFrame:
        """
        Executes query and returns a pandas dataframe

        Args:
            query (str): The query to execute. For ``"sql"`` this is the SQL text.
                For ``"nosql"`` it is passed unchanged to ``collection.find``.
                NOTE(review): ``find`` presumably expects a filter mapping rather
                than a string, despite the annotation — confirm.
            db_type (str, optional): The type of database to use, ``"sql"`` or
                ``"nosql"``. Defaults to 'sql'.
            db_name (str, optional): Database name; required for ``"nosql"``.
            collection_name (str, optional): Collection name; required for ``"nosql"``.

        Returns:
            pd.DataFrame: The result of the query

        Raises:
            ValueError: If ``db_type`` is not recognised, if a SQL query is not a
                string, or if ``db_name`` / ``collection_name`` is missing for NoSQL.
            ConnectionError: If the connector has no handle for the requested
                database type.

        Any exception is logged through ``lotus.logger`` and then re-raised.
        """
        try:
            if db_type == "sql":
                if not isinstance(query, str):
                    raise ValueError("Query must be a string")
                engine = self.connector.sql_engine
                # Fail with a clear message instead of letting pandas choke on None.
                if engine is None:
                    raise ConnectionError("No SQL connection available; call connect_sql() first")
                lotus.logger.debug("Executing SQL Query")
                return pd.read_sql(query, engine)
            elif db_type == "nosql":
                if not collection_name or not db_name:
                    raise ValueError("Collection name and database is required for NoSQL database")
                client = self.connector.nosql_client
                # Without this check, indexing None would raise an unhelpful TypeError.
                if client is None:
                    raise ConnectionError("No NoSQL connection available; call connect_nosql() first")
                collection = client[db_name][collection_name]
                results = collection.find(query)
                return pd.DataFrame(list(results))
            else:
                raise ValueError("Invalid database type")

        except Exception as e:
            lotus.logger.error(f"Error executing query: {e}")
            raise
5 changes: 4 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,7 @@ tqdm==4.66.4
weaviate-client==4.10.2
pinecone==5.4.2
chromadb==0.6.2
qdrant-client==1.12.2
qdrant-client==1.12.2
psycopg2-binary==2.9.10
SQLAlchemy==2.0.37
pymongo==4.10.1
Loading