Skip to content

Commit

Permalink
add testing
Browse files Browse the repository at this point in the history
  • Loading branch information
StanChan03 committed Jan 21, 2025
1 parent e02e360 commit aa9acf1
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 15 deletions.
61 changes: 61 additions & 0 deletions .github/tests/db_tests.py → .github/tests/connector_tests.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import sqlite3

import boto3
import pandas as pd
import pytest

Expand Down Expand Up @@ -62,6 +63,45 @@ def setup_sqlite_db():
conn.close()


@pytest.fixture(scope="session")
def setup_minio():
    """Provision a MinIO bucket with a small CSV fixture (session-scoped).

    Expects a MinIO server at http://localhost:9000. Creates (or reuses) the
    configured bucket, uploads a 3-row CSV under the key "test_data.csv",
    and returns the connection config dict for tests to consume.

    Returns:
        dict: MinIO connection settings (keys, region, bucket, endpoint, ...).
    """
    minio_config = {
        "aws_access_key": "accesskey",
        "aws_secret_key": "secretkey",
        "region": None,
        "bucket": "test-bucket",
        # NOTE(review): this "file_path" is not the key uploaded below
        # ("test_data.csv"); dependent tests pass the key explicitly —
        # confirm whether the two were meant to match.
        "file_path": "data/test.csv",
        "protocol": "http",
        "endpoint_url": "http://localhost:9000",
    }

    session = boto3.Session(
        aws_access_key_id=minio_config["aws_access_key"],
        aws_secret_access_key=minio_config["aws_secret_key"],
    )

    s3 = session.resource("s3", endpoint_url=minio_config["endpoint_url"])

    # Create the bucket if missing. S3-compatible services may raise either
    # BucketAlreadyOwnedByYou or BucketAlreadyExists when it is present;
    # both mean the fixture can proceed.
    exceptions = s3.meta.client.exceptions
    try:
        s3.create_bucket(Bucket=minio_config["bucket"])
    except (exceptions.BucketAlreadyOwnedByYou, exceptions.BucketAlreadyExists):
        pass

    # Upload a small, deterministic CSV fixture for the connector tests.
    test_data = pd.DataFrame(
        {
            "id": [1, 2, 3],
            "name": ["Alice", "Bob", "Charlie"],
            "score": [85, 90, 88],
        }
    )
    csv_data = test_data.to_csv(index=False)

    s3.Bucket(minio_config["bucket"]).put_object(Key="test_data.csv", Body=csv_data)

    return minio_config


@pytest.fixture(autouse=True)
def print_usage_after_each_test(setup_models):
yield # this runs the test
Expand Down Expand Up @@ -89,3 +129,24 @@ def test_SQL_db(setup_models, model):

filtered_df = df.sem_filter("{name} is an adult")
assert isinstance(filtered_df, pd.DataFrame)


@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
def test_minio(setup_models, setup_minio, model):
    """Load the CSV fixture from MinIO via DataConnector and sanity-check it."""
    lotus.settings.configure(lm=setup_models[model])
    cfg = setup_minio

    loaded = DataConnector.load_from_s3(
        aws_access_key=cfg["aws_access_key"],
        aws_secret_key=cfg["aws_secret_key"],
        region=cfg["region"],
        bucket=cfg["bucket"],
        file_path="test_data.csv",
        endpoint_url=cfg["endpoint_url"],
        protocol="http",
    )

    # The fixture uploads exactly three rows with these three columns.
    assert not loaded.empty
    assert loaded.shape[0] == 3
    assert set(loaded.columns) == {"id", "name", "score"}
10 changes: 5 additions & 5 deletions examples/db_examples/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"aws": {
"aws_access_key": "your_aws_access_key",
"aws_secret_key": "your_aws_secret_key",
"region": "us-west-2",
"region": "us-east-1",
"bucket": "your-aws-bucket",
"file_path": "data/test.csv",
"protocol": "s3",
Expand Down Expand Up @@ -83,14 +83,14 @@

# loading data from s3
df = DataConnector.load_from_s3(
aws_access_key=str(service_config["aws_access_key"]),
aws_secret_key=str(service_config["aws_secret_key"]),
aws_access_key=(service_config["aws_access_key"]),
aws_secret_key=(service_config["aws_secret_key"]),
region=str(service_config["region"]),
bucket=str(service_config["bucket"]),
file_path=str(service_config["file_path"]),
endpoint_url=str(service_config["endpoint_url"]),
endpoint_url=(service_config["endpoint_url"]),
protocol=str(service_config["protocol"]),
)
user_instruction = "{title} is scienece fiction movie"
user_instruction = "{title} is science fiction movie"
df = df.sem_filter(user_instruction)
print(df)
23 changes: 13 additions & 10 deletions lotus/data_connectors/connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def load_from_db(connection_url: str, query: str) -> pd.DataFrame:

@staticmethod
def load_from_s3(
aws_access_key: str,
aws_secret_key: str,
aws_access_key: Optional[str],
aws_secret_key: Optional[str],
region: str,
bucket: str,
file_path: str,
Expand All @@ -41,24 +41,27 @@ def load_from_s3(
Loads a pandas DataFrame from an S3-compatible service.
Args:
aws_access_key (str): The AWS access key
aws_secret_key (str): The AWS secret key
aws_access_key (str): The AWS access key (None for Public Access)
aws_secret_key (str): The AWS secret key (None for Public Access)
region (str): The AWS region
bucket (str): The S3 bucket
file_path (str): The path to the file in S3
endpoint_url (str): The Minio endpoint URL. Default is None for AWS s3
prtocol (str): The protocol to use (http for Minio and https for R2). Default is "s3"
protocol (str): The protocol to use (http for Minio and https for R2). Default is "s3"
Returns:
pd.DataFrame: The loaded DataFrame
"""
try:
session = boto3.Session(
aws_access_key_id=aws_access_key,
aws_secret_access_key=aws_secret_key,
region_name=region if protocol == "s3" and endpoint_url is None else None,
)
if aws_access_key is None and aws_secret_key is None:
session = boto3.Session(region_name=region)
else:
session = boto3.Session(
aws_access_key_id=aws_access_key,
aws_secret_access_key=aws_secret_key,
region_name=region if protocol == "s3" and endpoint_url is None else None,
)
except Exception as e:
raise ValueError(f"Error creating boto3 session: {e}")

Expand Down

0 comments on commit aa9acf1

Please sign in to comment.