From aa9acf1352eb436c5dd45f889a29c177a78ef653 Mon Sep 17 00:00:00 2001 From: StanChan03 Date: Tue, 21 Jan 2025 11:16:28 -0800 Subject: [PATCH] add testing --- .../tests/{db_tests.py => connector_tests.py} | 61 +++++++++++++++++++ examples/db_examples/s3.py | 10 +-- lotus/data_connectors/connectors.py | 23 ++++--- 3 files changed, 79 insertions(+), 15 deletions(-) rename .github/tests/{db_tests.py => connector_tests.py} (59%) diff --git a/.github/tests/db_tests.py b/.github/tests/connector_tests.py similarity index 59% rename from .github/tests/db_tests.py rename to .github/tests/connector_tests.py index e28d12a..85d7a3f 100644 --- a/.github/tests/db_tests.py +++ b/.github/tests/connector_tests.py @@ -1,6 +1,7 @@ import os import sqlite3 +import boto3 import pandas as pd import pytest @@ -62,6 +63,45 @@ def setup_sqlite_db(): conn.close() +@pytest.fixture(scope="session") +def setup_minio(): + minio_config = { + "aws_access_key": "accesskey", + "aws_secret_key": "secretkey", + "region": None, + "bucket": "test-bucket", + "file_path": "data/test.csv", + "protocol": "http", + "endpoint_url": "http://localhost:9000", + } + + session = boto3.Session( + aws_access_key_id=minio_config["aws_access_key"], + aws_secret_access_key=minio_config["aws_secret_key"], + ) + + s3 = session.resource("s3", endpoint_url=minio_config["endpoint_url"]) + + try: + s3.create_bucket(Bucket=minio_config["bucket"]) + except s3.meta.client.exceptions.BucketAlreadyOwnedByYou: + pass + + # Upload test file + test_data = pd.DataFrame( + { + "id": [1, 2, 3], + "name": ["Alice", "Bob", "Charlie"], + "score": [85, 90, 88], + } + ) + csv_data = test_data.to_csv(index=False) + + s3.Bucket(minio_config["bucket"]).put_object(Key="test_data.csv", Body=csv_data) + + return minio_config + + @pytest.fixture(autouse=True) def print_usage_after_each_test(setup_models): yield # this runs the test @@ -89,3 +129,24 @@ def test_SQL_db(setup_models, model): filtered_df = df.sem_filter("{name} is an adult") assert 
isinstance(filtered_df, pd.DataFrame) + + +@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini")) +def test_minio(setup_models, setup_minio, model): + lm = setup_models[model] + lotus.settings.configure(lm=lm) + minio_config = setup_minio + + df = DataConnector.load_from_s3( + aws_access_key=minio_config["aws_access_key"], + aws_secret_key=minio_config["aws_secret_key"], + region=minio_config["region"], + bucket=minio_config["bucket"], + file_path="test_data.csv", + endpoint_url=minio_config["endpoint_url"], + protocol="http", + ) + + assert not df.empty + assert df.shape[0] == 3 + assert set(df.columns) == {"id", "name", "score"} diff --git a/examples/db_examples/s3.py b/examples/db_examples/s3.py index b3a121d..531c90a 100644 --- a/examples/db_examples/s3.py +++ b/examples/db_examples/s3.py @@ -13,7 +13,7 @@ "aws": { "aws_access_key": "your_aws_access_key", "aws_secret_key": "your_aws_secret_key", - "region": "us-west-2", + "region": "us-east-1", "bucket": "your-aws-bucket", "file_path": "data/test.csv", "protocol": "s3", @@ -83,14 +83,14 @@ # loading data from s3 df = DataConnector.load_from_s3( - aws_access_key=str(service_config["aws_access_key"]), - aws_secret_key=str(service_config["aws_secret_key"]), + aws_access_key=service_config["aws_access_key"], + aws_secret_key=service_config["aws_secret_key"], region=str(service_config["region"]), bucket=str(service_config["bucket"]), file_path=str(service_config["file_path"]), - endpoint_url=str(service_config["endpoint_url"]), + endpoint_url=service_config["endpoint_url"], protocol=str(service_config["protocol"]), ) -user_instruction = "{title} is scienece fiction movie" +user_instruction = "{title} is a science fiction movie" df = df.sem_filter(user_instruction) print(df) diff --git a/lotus/data_connectors/connectors.py b/lotus/data_connectors/connectors.py index 4c30aba..181eeb5 100644 --- a/lotus/data_connectors/connectors.py +++ b/lotus/data_connectors/connectors.py @@ -29,8 +29,8 @@ def 
load_from_db(connection_url: str, query: str) -> pd.DataFrame: @staticmethod def load_from_s3( - aws_access_key: str, - aws_secret_key: str, + aws_access_key: Optional[str], + aws_secret_key: Optional[str], region: str, bucket: str, file_path: str, @@ -41,24 +41,27 @@ def load_from_s3( Loads a pandas DataFrame from an S3-compatible service. Args: - aws_access_key (str): The AWS access key - aws_secret_key (str): The AWS secret key + aws_access_key (Optional[str]): The AWS access key (None for public access) + aws_secret_key (Optional[str]): The AWS secret key (None for public access) region (str): The AWS region bucket (str): The S3 bucket file_path (str): The path to the file in S3 endpoint_url (str): The Minio endpoint URL. Default is None for AWS s3 - prtocol (str): The protocol to use (http for Minio and https for R2). Default is "s3" + protocol (str): The protocol to use (http for Minio and https for R2). Default is "s3" Returns: pd.DataFrame: The loaded DataFrame """ try: - session = boto3.Session( - aws_access_key_id=aws_access_key, - aws_secret_access_key=aws_secret_key, - region_name=region if protocol == "s3" and endpoint_url is None else None, - ) + if aws_access_key is None and aws_secret_key is None: + session = boto3.Session(region_name=region) + else: + session = boto3.Session( + aws_access_key_id=aws_access_key, + aws_secret_access_key=aws_secret_key, + region_name=region if protocol == "s3" and endpoint_url is None else None, + ) except Exception as e: raise ValueError(f"Error creating boto3 session: {e}")