Add GCP Storage/Bucket support #91

Open
wants to merge 4 commits into base: main
2 changes: 2 additions & 0 deletions .gitignore
@@ -5,3 +5,5 @@ config/authorized_keys
 config/rclone
 tpdocs/
 .env
+.venv
+
72 changes: 37 additions & 35 deletions build/COPY_ROOT_1/opt/ai-dock/api-wrapper/requestmodels/models.py
@@ -1,58 +1,60 @@
-from typing import List, Union, Dict, Annotated
+from typing import Dict
 from pydantic import BaseModel, Field
 import os
+import json
 
+if os.environ.get("GCP_CREDENTIALS"):
+    with open(os.environ["GCP_CREDENTIALS"]) as f:
+        _GCP_CREDENTIALS = json.load(f)
+else:
+    _GCP_CREDENTIALS = {}
 
 class S3Config(BaseModel):
-    access_key_id: str = Field(default="")
-    secret_access_key: str = Field(default="")
-    endpoint_url: str = Field(default="")
-    bucket_name: str = Field(default="")
+    access_key_id: str = Field(default=os.environ.get("S3_ACCESS_KEY_ID", ""))
+    secret_access_key: str = Field(
+        default=os.environ.get("S3_SECRET_ACCESS_KEY", ""))
+    endpoint_url: str = Field(default=os.environ.get("S3_ENDPOINT_URL", ""))
+    bucket_name: str = Field(default=os.environ.get("S3_BUCKET_NAME", ""))
     connect_timeout: int = Field(default=5)
     connect_attempts: int = Field(default=1)
 
-    @staticmethod
-    def get_defaults():
-        return {
-            "access_key_id": "",
-            "secret_access_key": "",
-            "endpoint_url": "",
-            "bucket_name": "",
-            "connect_timeout": "5",
-            "connect_attempts": "1"
-        }
+    def get_config(self):
+        config = {"access_key_id": self.access_key_id,
+                  "secret_access_key": self.secret_access_key,
+                  "endpoint_url": self.endpoint_url,
+                  "bucket_name": self.bucket_name,
+                  "connect_timeout": self.connect_timeout,
+                  "connect_attempts": self.connect_attempts}
+        set_values = sum(1 for v in config.values() if v)
+        return config if set_values > 2 else {}
+
+class GcpConfig(BaseModel):
+    credentials: Dict = Field(default_factory=_GCP_CREDENTIALS.copy)
+    project_id: str = Field(default=os.environ.get("GCP_PROJECT_ID", ""))
+    bucket_name: str = Field(default=os.environ.get("GCP_BUCKET_NAME", ""))
 
     def get_config(self):
-        return {
-            "access_key_id": getattr(self, "access_key_id", os.environ.get("S3_ACCESS_KEY_ID", "")),
-            "secret_access_key": getattr(self, "secret_access_key", os.environ.get("S3_SECRET_ACCESS_KEY", "")),
-            "endpoint_url": getattr(self, "endpoint_url", os.environ.get("S3_ENDPOINT_URL", "")),
-            "bucket_name": getattr(self, "bucket_name", os.environ.get("S3_BUCKET_NAME", "")),
-            "connect_timeout": "5",
-            "connect_attempts": "1"
-        }
+        config = {"credentials": self.credentials,
+                  "project_id": self.project_id,
+                  "bucket_name": self.bucket_name}
+        set_values = sum(1 for v in config.values() if v)
+        return config if set_values > 0 else {}
 
 class WebHook(BaseModel):
     url: str = Field(default="")
-    extra_params: Dict = Field(default={})
-
-    @staticmethod
-    def get_defaults():
-        return {
-            "url": "",
-            "extra_params": {}
-        }
+    extra_params: Dict = Field(default_factory=dict)
 
     def has_valid_url(self):
         return network.is_url(self.url)
 
 class Input(BaseModel):
     request_id: str = Field(default="")
     modifier: str = Field(default="")
-    modifications: Dict = Field(default={})
-    workflow_json: Dict = Field(default={})
-    s3: S3Config = Field(default=S3Config.get_defaults())
-    webhook: WebHook = Field(default=WebHook.get_defaults())
+    modifications: Dict = Field(default_factory=dict)
+    workflow_json: Dict = Field(default_factory=dict)
+    s3: S3Config = Field(default_factory=S3Config)
+    gcp: GcpConfig = Field(default_factory=GcpConfig)
+    webhook: WebHook = Field(default_factory=WebHook)
 
 class Payload(BaseModel):
     input: Input
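For context, and not part of the diff itself: a hedged sketch of how a request could exercise the new gcp block alongside the existing s3 block, based on the Input, S3Config and GcpConfig fields above. All bucket, project and key values are invented for illustration.

    # Hypothetical request payload; key names follow the Payload/Input models in this PR.
    payload = {
        "input": {
            "request_id": "abc123",
            "s3": {
                "access_key_id": "AKIA...",                # invented
                "secret_access_key": "...",                # invented
                "endpoint_url": "https://s3.example.com",  # invented
                "bucket_name": "my-s3-bucket"              # invented
            },
            "gcp": {
                # "credentials" omitted here: it then defaults to the service-account
                # JSON loaded from the file named by GCP_CREDENTIALS at import time.
                "project_id": "my-project",                # invented
                "bucket_name": "my-gcp-bucket"             # invented
            }
        }
    }
    # With no per-request values and no S3_*/GCP_* environment variables set,
    # S3Config.get_config() and GcpConfig.get_config() return {} and the worker
    # skips the corresponding upload.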
3 changes: 3 additions & 0 deletions build/COPY_ROOT_1/opt/ai-dock/api-wrapper/requirements.txt
@@ -2,8 +2,11 @@ aiocache
 pydantic>=2
 aiobotocore
 aiofiles
+aiogoogle
 aiohttp
 fastapi==0.103
+google-auth
+google-cloud-storage
 pathlib
 python-magic
 uvicorn==0.23
@@ -1,7 +1,13 @@
+import datetime
+import aiogoogle.auth.creds
+import aiogoogle.client
 import asyncio
+import itertools
 import aiobotocore.session
 import aiofiles
 import aiofiles.os
+from google.oauth2 import service_account
+from google.cloud.storage import _signing as signing
 from config import config
 from pathlib import Path
 
@@ -33,7 +39,31 @@ async def work(self):
             result = await self.response_store.get(request_id)
 
             await self.move_assets(request_id, result)
-            await self.upload_assets(request_id, request.input.s3.get_config(), result)
+
+            named_upload_tasks = []
+            if (s3_config := request.input.s3.get_config()):
+                async def upload_s3_assets():
+                    return ("s3", await self.upload_s3_assets(request_id, s3_config, result))
+                named_upload_tasks.append(
+                    asyncio.create_task(upload_s3_assets()))
+            if (gcp_config := request.input.gcp.get_config()):
+                async def upload_gcp_assets():
+                    return ("gcp", await self.upload_gcp_assets(request_id, gcp_config, result))
+                named_upload_tasks.append(
+                    asyncio.create_task(upload_gcp_assets()))
+            if named_upload_tasks:
+                named_presigned_urls = dict(await asyncio.gather(*named_upload_tasks))
+                presigned_urls = itertools.zip_longest(
+                    named_presigned_urls.get("s3", []),
+                    named_presigned_urls.get("gcp", []),
+                    fillvalue=None)
+                for obj, (s3_url, gcp_url) in zip(result.output, presigned_urls):
+                    if s3_url:
+                        # Keeping for backward compatibility
+                        obj["url"] = s3_url
+                        obj["s3_url"] = s3_url
+                    if gcp_url:
+                        obj["gcp_url"] = gcp_url
 
             result.status = "success"
             result.message = "Process complete."
@@ -77,7 +107,7 @@ async def move_assets(self, request_id, result):
                 "local_path": new_path
             })
 
-    async def upload_assets(self, request_id, s3_config, result):
+    async def upload_s3_assets(self, request_id, s3_config, result):
        session = aiobotocore.session.get_session()
        async with session.create_client(
            's3',
@@ -96,16 +126,12 @@ async def upload_assets(self, request_id, s3_config, result):
                 tasks.append(task)
 
             # Run all tasks concurrently
-            presigned_urls = await asyncio.gather(*tasks)
-
-            # Append the presigned URLs to the respective objects
-            for obj, url in zip(result.output, presigned_urls):
-                obj["url"] = url
+            return await asyncio.gather(*tasks)
 
-    async def upload_file_and_get_url(self, requst_id, s3_client, bucket_name, local_path):
+    async def upload_file_and_get_url(self, request_id, s3_client, bucket_name, local_path):
         # Get the file name from the local path
-        file_name = f"{requst_id}/{Path(local_path).name}"
-        print (f"uploading {file_name}")
+        file_name = f"{request_id}/{Path(local_path).name}"
+        print(f"uploading to s3 {file_name}")
 
         try:
             # Upload the file
@@ -116,9 +142,51 @@ async def upload_file_and_get_url(self, requst_id, s3_client, bucket_name, local
             presigned_url = await s3_client.generate_presigned_url(
                 'get_object',
                 Params={'Bucket': bucket_name, 'Key': file_name},
-                ExpiresIn=604800 # URL expiration time in seconds
+                ExpiresIn=int(datetime.timedelta(days=7).total_seconds()),
             )
             return presigned_url
         except Exception as e:
-            print(f"Error uploading {local_path}: {e}")
-            return None
+            print(f"Error uploading to s3 {local_path}: {e}")
+            return None
+
+    async def upload_gcp_assets(self, request_id, gcp_config, result):
+        creds = aiogoogle.auth.creds.ServiceAccountCreds(
+            scopes=["https://www.googleapis.com/auth/cloud-platform"],
+            **gcp_config["credentials"],
+        )
+        google_credentials = service_account.Credentials.from_service_account_info(
+            gcp_config["credentials"])
+        aiog_client = aiogoogle.client.Aiogoogle(service_account_creds=creds)
+        async with aiog_client:
+            # Not needed as we are using provided service account creds. Uncomment if using discovery.
+            # await aiog_client.service_account_manager.detect_default_creds_source()
+            storage = await aiog_client.discover("storage", "v1")
+            tasks = []
+            for obj in result.output:
+                local_path = obj["local_path"]
+                task = asyncio.create_task(self.upload_file_to_gcp_and_get_url(
+                    request_id, aiog_client, storage, gcp_config["bucket_name"], local_path, google_credentials))
+                tasks.append(task)
+
+            # Run all tasks concurrently
+            return await asyncio.gather(*tasks)
+
+    async def upload_file_to_gcp_and_get_url(self, request_id, aiog_client, storage, bucket_name, local_path, google_credentials):
+        destination_path = f"{request_id}/{Path(local_path).name}"
+        print(f"uploading to gcp {destination_path}")
+
+        try:
+            await aiog_client.as_service_account(storage.objects.insert(
+                bucket=bucket_name,
+                name=destination_path,
+                upload_file=local_path,
+            ), full_res=True)
+            return signing.generate_signed_url_v4(
+                google_credentials,
+                f"/{bucket_name}/{destination_path}",
+                expiration=datetime.timedelta(days=7),
+                method="GET",
+            )
+        except Exception as e:
+            print(f"Error uploading to gcp {local_path}: {e}")
+            return None
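For illustration only: a possible shape of one result.output entry after the reworked work() method above has run with both backends configured. The key names come from this diff; the file names, paths and URLs are invented.

    # Hypothetical entry in result.output after both upload tasks finish.
    # "url" duplicates "s3_url" and is kept for backward compatibility.
    {
        "local_path": "/some/output/dir/abc123/ComfyUI_00001_.png",  # invented path
        "url": "https://s3.example.com/my-s3-bucket/abc123/ComfyUI_00001_.png?X-Amz-Signature=...",
        "s3_url": "https://s3.example.com/my-s3-bucket/abc123/ComfyUI_00001_.png?X-Amz-Signature=...",
        "gcp_url": "https://storage.googleapis.com/my-gcp-bucket/abc123/ComfyUI_00001_.png?X-Goog-Signature=..."
    }
    # If only one backend is configured, itertools.zip_longest pads the other
    # side with None and the corresponding key is simply not added.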