Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FastAPI #2

Merged
merged 15 commits into from
Nov 14, 2023
Merged
9 changes: 8 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ push-container:
docker push $(CONTAINER)

server: clients clients/test.pem
poetry run flask --app 'aggrec.app:create_app("../example.toml")' run --debug
poetry run aggrec_server --config example.toml --host 127.0.0.1 --port 8080 --debug

client: test-private.pem
python3 tools/client.py
Expand All @@ -20,6 +20,13 @@ keys: test.pem
test-private.pem:
openssl ecparam -genkey -name prime256v1 -noout -out $@

test-client:
openssl rand 1024 > random.bin
poetry run aggrec_client \
--http-key-id test \
--http-key-file test-private.pem \
random.bin

clients:
mkdir clients

Expand Down
271 changes: 171 additions & 100 deletions aggrec/aggregates.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
import json
import logging
from typing import Dict
from enum import Enum
from functools import lru_cache
from typing import Annotated, Dict
from urllib.parse import urljoin

import aiobotocore.session
import aiomqtt
import boto3
import bson
import paho.mqtt.client as mqtt
from bson.objectid import ObjectId
from flask import Blueprint, Response, current_app, g, jsonify, request, send_file
from werkzeug.exceptions import BadRequest, NotFound
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from .db_models import AggregateMetadata
from .helpers import RequestVerifier, rfc_3339_datetime_now
from .openapi import OPENAPI_DICT
from .settings import Settings

logger = logging.getLogger(__name__)

bp = Blueprint("aggregates", __name__)


METADATA_HTTP_HEADERS = [
"Content-Length",
Expand All @@ -32,40 +34,75 @@

ALLOWED_CONTENT_TYPES = ["application/vnd.apache.parquet", "application/binary"]


def get_http_request_verifier() -> RequestVerifier:
if "http_request_verifier" not in g:
g.http_request_verifier = RequestVerifier(
client_database=current_app.config["CLIENTS_DATABASE"],
)
logging.info("HTTP request verifier created")
return g.http_request_verifier


def get_s3_client():
if "s3_client" not in g:
g.s3_client = boto3.client(
"s3",
endpoint_url=current_app.config["S3_ENDPOINT_URL"],
aws_access_key_id=current_app.config["S3_ACCESS_KEY_ID"],
aws_secret_access_key=current_app.config["S3_SECRET_ACCESS_KEY"],
aws_session_token=None,
config=boto3.session.Config(signature_version="s3v4"),
router = APIRouter()


class AggregateType(str, Enum):
histogram = "histogram"
vector = "vector"


class AggregateMetadataResponse(BaseModel):
aggregate_id: str
aggregate_type: AggregateType
created: str
creator: str
headers: dict
content_type: str
content_length: int
content_location: str
s3_bucket: str
s3_object_key: str

@classmethod
def from_db_model(cls, metadata: AggregateMetadata, settings: Settings):
aggregate_id = str(metadata.id)
return cls(
aggregate_id=aggregate_id,
aggregate_type=metadata.aggregate_type.value,
created=metadata.id.generation_time.strftime("%Y-%m-%dT%H:%M:%SZ"),
creator=str(metadata.creator),
headers=metadata.http_headers,
content_type=metadata.content_type,
content_length=metadata.content_length,
content_location=urljoin(
settings.metadata_base_url,
f"/api/v1/aggregates/{aggregate_id}/payload",
),
s3_bucket=metadata.s3_bucket,
s3_object_key=metadata.s3_object_key,
)
logging.info("S3 client created")
return g.s3_client


def get_mqtt_client():
if "mqtt_client" not in g:
client = mqtt.Client()
client.connect(current_app.config["MQTT_BROKER"])
g.mqtt_client = client
logging.info("MQTT client created")
return g.mqtt_client
@lru_cache
def get_settings():
return Settings()


def http_request_verifier(settings: Annotated[Settings, Depends(get_settings)]):
return RequestVerifier(client_database=settings.clients_database)


async def s3_client(settings: Annotated[Settings, Depends(get_settings)]):
logger.debug("Returning settings")
session = aiobotocore.session.AioSession()
async with session.create_client(
"s3",
endpoint_url=settings.s3_endpoint_url,
aws_access_key_id=settings.s3_access_key_id,
aws_secret_access_key=settings.s3_secret_access_key,
aws_session_token=None,
config=boto3.session.Config(signature_version="s3v4"),
) as client:
yield client


def get_http_headers() -> Dict[str, str]:
async def mqtt_client(settings: Annotated[Settings, Depends(get_settings)]):
async with aiomqtt.Client(settings.mqtt_broker) as client:
yield client


def get_http_headers(request: Request) -> Dict[str, str]:
"""Get dictionary of relevant metadata HTTP headers"""
res = {}
for header in METADATA_HTTP_HEADERS:
Expand All @@ -74,7 +111,9 @@ def get_http_headers() -> Dict[str, str]:
return res


def get_new_aggregate_event_message(metadata: AggregateMetadata) -> dict:
def get_new_aggregate_event_message(
metadata: AggregateMetadata, settings: Settings
) -> dict:
"""Get new aggregate event message"""
return {
"version": 1,
Expand All @@ -85,114 +124,146 @@ def get_new_aggregate_event_message(metadata: AggregateMetadata) -> dict:
"created": metadata.id.generation_time.strftime("%Y-%m-%dT%H:%M:%SZ"),
"creator": str(metadata.creator),
"metadata_location": urljoin(
current_app.config["METADATA_BASE_URL"],
settings.metadata_base_url,
f"/api/v1/aggregates/{metadata.id}",
),
"content_location": urljoin(
current_app.config["METADATA_BASE_URL"],
settings.metadata_base_url,
f"/api/v1/aggregates/{metadata.id}/payload",
),
"s3_bucket": metadata.s3_bucket,
"s3_object_key": metadata.s3_object_key,
}


@bp.route("/api/v1/openapi", methods=["GET"])
def get_openapi():
return jsonify(OPENAPI_DICT)
def get_s3_object_key(metadata: AggregateMetadata) -> str:
"""Get S3 object key from metadata"""
dt = metadata.id.generation_time
return "/".join(
[
f"type={metadata.aggregate_type.name.lower()}",
f"year={dt.year}",
f"month={dt.month}",
f"day={dt.day}",
f"creator={metadata.creator}",
f"id={metadata.id}",
]
)


@bp.route("/api/v1/aggregate/<aggregate_type>", methods=["POST"])
def create_aggregate(aggregate_type: str):
@router.post("/api/v1/aggregate/{aggregate_type}")
async def create_aggregate(
aggregate_type: AggregateType,
request: Request,
settings: Annotated[Settings, Depends(get_settings)],
s3_client: Annotated[aiobotocore.client.AioBaseClient, Depends(s3_client)],
mqtt_client: Annotated[aiomqtt.Client, Depends(mqtt_client)],
http_request_verifier: Annotated[RequestVerifier, Depends(http_request_verifier)],
):
if aggregate_type not in ALLOWED_AGGREGATE_TYPES:
raise BadRequest(description="Aggregate type not supported")
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Aggregate type not supported")

if request.content_type not in ALLOWED_CONTENT_TYPES:
raise BadRequest(description="Content-Type not supported")
content_type = request.headers.get("content-type", None)

res = get_http_request_verifier().verify(request)
creator = res.get("keyid")
if content_type not in ALLOWED_CONTENT_TYPES:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Content-Type not supported")

logger.info("Create aggregate request by keyid=%s", creator)
res = await http_request_verifier.verify(request)

mqtt_client = get_mqtt_client()
creator = res.get("keyid")
logger.info("Create aggregate request by keyid=%s", creator)

aggregate_id = ObjectId()
location = f"/api/v1/aggregates/{aggregate_id}"

s3_bucket = current_app.config["S3_BUCKET"]
s3_object_key = f"type={aggregate_type}/creator={creator}/{aggregate_id}"

s3 = get_s3_client()
if current_app.config.get("S3_BUCKET_CREATE", False):
try:
s3.create_bucket(Bucket=s3_bucket)
except Exception:
pass
s3.put_object(Bucket=s3_bucket, Key=s3_object_key, Body=request.data)
s3_bucket = settings.s3_bucket

metadata = AggregateMetadata(
id=aggregate_id,
aggregate_type=aggregate_type,
creator=creator,
http_headers=get_http_headers(),
content_type=request.content_type,
content_length=request.content_length,
http_headers=get_http_headers(request),
content_type=content_type,
s3_bucket=s3_bucket,
s3_object_key=s3_object_key,
)

content = await request.body()

metadata.content_length = len(content)
metadata.s3_object_key = get_s3_object_key(metadata)

if settings.s3_bucket_create:
try:
await s3_client.create_bucket(Bucket=s3_bucket)
except Exception:
pass

await s3_client.put_object(
Bucket=s3_bucket, Key=metadata.s3_object_key, Body=content
)
logger.info("Object created: %s", metadata.s3_object_key)

metadata.save()
logger.info("Metadata saved: %s", metadata.id)

mqtt_client.publish(
current_app.config["MQTT_TOPIC"],
json.dumps(get_new_aggregate_event_message(metadata)),
await mqtt_client.publish(
settings.mqtt_topic,
json.dumps(get_new_aggregate_event_message(metadata, settings)),
)

return Response(status=201, headers={"Location": location})
return Response(status_code=status.HTTP_201_CREATED, headers={"Location": location})


@bp.route("/api/v1/aggregates/<aggregate_id>", methods=["GET"])
def get_aggregate_metadata(aggregate_id: str):
@router.get("/api/v1/aggregates/{aggregate_id}")
def get_aggregate_metadata(
aggregate_id: str, settings: Annotated[Settings, Depends(get_settings)]
) -> AggregateMetadataResponse:
try:
aggregate_object_id = ObjectId(aggregate_id)
except bson.errors.InvalidId:
raise NotFound
raise HTTPException(status.HTTP_404_NOT_FOUND)

if metadata := AggregateMetadata.objects(id=aggregate_object_id).first():
return {
"aggregate_id": str(metadata.id),
"aggregate_type": metadata.aggregate_type.value,
"created": metadata.id.generation_time.strftime("%Y-%m-%dT%H:%M:%SZ"),
"creator": str(metadata.creator),
"headers": metadata.http_headers,
"content_type": metadata.content_type,
"content_length": metadata.content_length,
"content_location": urljoin(
current_app.config["METADATA_BASE_URL"],
f"/api/v1/aggregates/{aggregate_id}/payload",
),
"s3_bucket": metadata.s3_bucket,
"s3_object_key": metadata.s3_object_key,
}
return AggregateMetadataResponse.from_db_model(metadata, settings)

raise NotFound
raise HTTPException(status.HTTP_404_NOT_FOUND)


@bp.route("/api/v1/aggregates/<aggregate_id>/payload", methods=["GET"])
def get_aggregate_payload(aggregate_id: str):
@router.get(
"/api/v1/aggregates/{aggregate_id}/payload",
responses={
200: {
"description": "Aggregate payload",
"content": {
"application/vnd.apache.parquet": {},
"application/binary": {},
},
}
},
)
async def get_aggregate_payload(
aggregate_id: str,
settings: Annotated[Settings, Depends(get_settings)],
s3_client: Annotated[aiobotocore.client.AioBaseClient, Depends(s3_client)],
) -> bytes:
try:
aggregate_object_id = ObjectId(aggregate_id)
except bson.errors.InvalidId:
raise NotFound
raise HTTPException(status.HTTP_404_NOT_FOUND)

if metadata := AggregateMetadata.objects(id=aggregate_object_id).first():
s3 = get_s3_client()
s3_obj = s3.get_object(Bucket=metadata.s3_bucket, Key=metadata.s3_object_key)
s3_obj = await s3_client.get_object(
Bucket=metadata.s3_bucket, Key=metadata.s3_object_key
)
metadata_location = f"/api/v1/aggregates/{aggregate_id}"
response = send_file(s3_obj["Body"], mimetype=metadata.content_type)
response.headers.update(
{

return StreamingResponse(
content=s3_obj["Body"],
media_type=metadata.content_type,
headers={
"Link": f'{metadata_location}; rel="about"',
"Content-Length": metadata.content_length,
}
# "Content-Length": str(metadata.content_length),
},
)
return response
raise NotFound

raise HTTPException(status.HTTP_404_NOT_FOUND)
Loading