add s3 connection example
StanChan03 committed Jan 21, 2025
1 parent df52f74 commit e02e360
Showing 2 changed files with 111 additions and 4 deletions.
examples/db_examples/s3.py (96 additions, 0 deletions)
@@ -0,0 +1,96 @@
import boto3
import pandas as pd

import lotus
from lotus.data_connectors import DataConnector
from lotus.models import LM

lm = LM(model="gpt-4o-mini")
lotus.settings.configure(lm=lm)
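# Assumes the API key for the chosen model (e.g. OPENAI_API_KEY for
# gpt-4o-mini) is available in the environment.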

# Service configurations
service_configs = {
    "aws": {
        "aws_access_key": "your_aws_access_key",
        "aws_secret_key": "your_aws_secret_key",
        "region": "us-west-2",
        "bucket": "your-aws-bucket",
        "file_path": "data/test.csv",
        "protocol": "s3",
        "endpoint_url": None,
    },
    "minio": {
        "aws_access_key": "accesskey",
        "aws_secret_key": "secretkey",
        "region": None,
        "bucket": "test-bucket",
        "file_path": "data/test.csv",
        "protocol": "http",
        "endpoint_url": "http://localhost:9000",
    },
    "cloudflare_R2": {
        "aws_access_key": "your_r2_access_key",
        "aws_secret_key": "your_r2_secret_key",
        "region": None,
        "bucket": "your-r2-bucket",
        "file_path": "data/test.csv",
        "protocol": "https",
        "endpoint_url": "https://<account_id>.r2.cloudflarestorage.com",
    },
}
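# NOTE (assumptions): the "minio" credentials above match a local MinIO server
# whose root user/password were set to accesskey/secretkey (e.g. via the
# MINIO_ROOT_USER and MINIO_ROOT_PASSWORD environment variables); for
# "cloudflare_R2", replace <account_id> with your own Cloudflare account ID.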

# Get configuration for selected service
service = "minio"
service_config = service_configs[service]

# Create test data
test_data = pd.DataFrame(
    {
        "title": ["The Matrix", "The Godfather", "Inception", "Parasite", "Interstellar", "Titanic"],
        "description": [
            "A hacker discovers the reality is simulated.",
            "The rise and fall of a powerful mafia family.",
            "A thief enters dreams to steal secrets.",
            "A poor family schemes to infiltrate a rich household.",
            "A team travels through a wormhole to save humanity.",
            "A love story set during the Titanic tragedy.",
        ],
    }
)
csv_data = test_data.to_csv(index=False)

# Connect to the S3-compatible service and upload the test data
try:
    session = boto3.Session(
        aws_access_key_id=service_config["aws_access_key"],
        aws_secret_access_key=service_config["aws_secret_key"],
        region_name=service_config["region"],
    )
    s3 = session.resource("s3", endpoint_url=service_config["endpoint_url"])

    try:
        s3.create_bucket(Bucket=service_config["bucket"])
        print(f"Bucket '{service_config['bucket']}' created successfully.")
    except s3.meta.client.exceptions.BucketAlreadyOwnedByYou:
        print(f"Bucket '{service_config['bucket']}' already exists.")
    except Exception as e:
        print(f"Error creating bucket: {e}")

    s3.Bucket(service_config["bucket"]).put_object(Key=service_config["file_path"], Body=csv_data)
except Exception as e:
    print(f"Error connecting to S3 service: {e}")


# Load the data from the S3-compatible service into a DataFrame.
# Values are passed through unchanged so that None stays None
# (e.g. endpoint_url for AWS, region for MinIO and R2).
df = DataConnector.load_from_s3(
    aws_access_key=service_config["aws_access_key"],
    aws_secret_key=service_config["aws_secret_key"],
    region=service_config["region"],
    bucket=service_config["bucket"],
    file_path=service_config["file_path"],
    endpoint_url=service_config["endpoint_url"],
    protocol=service_config["protocol"],
)
user_instruction = "{title} is a science fiction movie"
df = df.sem_filter(user_instruction)
print(df)
lotus/data_connectors/connectors.py (15 additions, 4 deletions)
@@ -1,4 +1,5 @@
 from io import BytesIO, StringIO
+from typing import Optional

 import boto3
 import pandas as pd
@@ -28,17 +29,25 @@ def load_from_db(connection_url: str, query: str) -> pd.DataFrame:

     @staticmethod
     def load_from_s3(
-        aws_access_key: str, aws_secret_key: str, region: str, bucket: str, file_path: str
+        aws_access_key: str,
+        aws_secret_key: str,
+        region: str,
+        bucket: str,
+        file_path: str,
+        endpoint_url: Optional[str] = None,
+        protocol: str = "s3",
     ) -> pd.DataFrame:
         """
-        Loads a pandas DataFrame from an S3 object.
+        Loads a pandas DataFrame from an S3-compatible service.

         Args:
             aws_access_key (str): The AWS access key
             aws_secret_key (str): The AWS secret key
             region (str): The AWS region
             bucket (str): The S3 bucket
             file_path (str): The path to the file in S3
+            endpoint_url (Optional[str]): The endpoint URL for MinIO or R2. Default is None for AWS S3
+            protocol (str): The protocol to use (http for MinIO and https for R2). Default is "s3"

         Returns:
             pd.DataFrame: The loaded DataFrame
@@ -48,12 +57,12 @@ def load_from_s3(
             session = boto3.Session(
                 aws_access_key_id=aws_access_key,
                 aws_secret_access_key=aws_secret_key,
-                region_name=region,
+                region_name=region if protocol == "s3" and endpoint_url is None else None,
             )
         except Exception as e:
             raise ValueError(f"Error creating boto3 session: {e}")

-        s3 = session.resource("s3")
+        s3 = session.resource("s3", endpoint_url=endpoint_url)
         s3_obj = s3.Bucket(bucket).Object(file_path)
         data = s3_obj.get()["Body"].read()
@@ -70,3 +79,5 @@ def load_from_s3(
             return file_mapping[file_type](data)
         except KeyError:
             raise ValueError(f"Unsupported file type: {file_type}")
+        except Exception as e:
+            raise ValueError(f"Error loading from S3-compatible service: {e}")
