Skip to content

Commit

Permalink
improve and document responses
Browse files Browse the repository at this point in the history
  • Loading branch information
jschlyter committed Jan 12, 2024
1 parent f5b0770 commit b3ae9c2
Showing 1 changed file with 50 additions and 21 deletions.
71 changes: 50 additions & 21 deletions aggrec/aggregates.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import bson
import pendulum
from bson.objectid import ObjectId
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response, status
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field

Expand All @@ -34,10 +34,6 @@
"Signature-Input",
]

ALLOWED_AGGREGATE_TYPES = ["histogram", "vector"]

ALLOWED_CONTENT_TYPES = ["application/vnd.apache.parquet", "application/binary"]

router = APIRouter()


Expand All @@ -46,6 +42,11 @@ class AggregateType(str, Enum):
vector = "vector"


class AggregateContentType(str, Enum):
parquet = "application/vnd.apache.parquet"
binary = "application/binary"


class AggregateMetadataResponse(BaseModel):
aggregate_id: str = Field(title="Aggregate identifier")
aggregate_type: AggregateType = Field(title="Aggregate type")
Expand All @@ -54,9 +55,9 @@ class AggregateMetadataResponse(BaseModel):
headers: dict = Field(title="Dictionary of relevant HTTP headers")
content_type: str = Field(title="Content MIME type")
content_length: int = Field(title="Content length")
content_location: str = Field(title="Content local (URL)")
s3_bucket: str = Field(title="S3 Bucket Name")
s3_object_key: str = Field(title="S3 Object Key")
content_location: str = Field(title="Content location (URL)")
s3_bucket: str = Field(title="S3 bucket name")
s3_object_key: str = Field(title="S3 object key")
aggregate_interval_start: datetime | None = Field(
default=None, title="Aggregate interval start"
)
Expand Down Expand Up @@ -210,23 +211,38 @@ def get_s3_object_metadata(metadata: AggregateMetadata) -> dict:
}


@router.post("/api/v1/aggregate/{aggregate_type}")
@router.post(
"/api/v1/aggregate/{aggregate_type}",
status_code=201,
responses={
201: {
"description": "Aggregate created",
"content": None,
"headers": {
"location": {
"description": "Aggregate URL",
"schema": {"type": "string", "format": "uri"},
},
},
}
},
)
async def create_aggregate(
aggregate_type: AggregateType,
content_type: Annotated[AggregateContentType, Header()],
aggregate_interval: Annotated[
str | None,
Header(
description="Aggregate window as an ISO 8601 time interval (start and duration)",
example="1984-01-01T12:00:00Z/PT1M",
),
],
request: Request,
settings: Annotated[Settings, Depends(get_settings)],
s3_client: Annotated[aiobotocore.client.AioBaseClient, Depends(s3_client)],
mqtt_client: Annotated[aiomqtt.Client, Depends(mqtt_client)],
http_request_verifier: Annotated[RequestVerifier, Depends(http_request_verifier)],
):
if aggregate_type not in ALLOWED_AGGREGATE_TYPES:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Aggregate type not supported")

content_type = request.headers.get("content-type", None)

if content_type not in ALLOWED_CONTENT_TYPES:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Content-Type not supported")

res = await http_request_verifier.verify(request)

creator = res.parameters.get("keyid")
Expand All @@ -239,11 +255,11 @@ async def create_aggregate(

s3_bucket = settings.s3_bucket

if aggregate_interval := request.headers.get("Aggregate-Interval"):
if aggregate_interval:
period = pendulum.parse(aggregate_interval)
if not isinstance(period, pendulum.Interval):
raise HTTPException(
status.HTTP_400_BAD_REQUEST, "Invalid Aggregate-Interval"
status.HTTP_422_UNPROCESSABLE_ENTITY, "Invalid Aggregate-Interval"
)
aggregate_interval_start = pendulum_as_datetime(period.start)
aggregate_interval_duration = period.start.diff(period.end).in_seconds()
Expand Down Expand Up @@ -296,7 +312,13 @@ async def create_aggregate(
return Response(status_code=status.HTTP_201_CREATED, headers={"Location": location})


@router.get("/api/v1/aggregates/{aggregate_id}")
@router.get(
"/api/v1/aggregates/{aggregate_id}",
responses={
200: {"model": AggregateMetadataResponse},
404: {},
},
)
def get_aggregate_metadata(
aggregate_id: str, settings: Annotated[Settings, Depends(get_settings)]
) -> AggregateMetadataResponse:
Expand All @@ -316,11 +338,18 @@ def get_aggregate_metadata(
responses={
200: {
"description": "Aggregate payload",
"headers": {
"link": {
"description": 'Linked resources (RFC 8288), rel="about" for metadata URL',
"schema": {"type": "string", "format": "uri"},
},
},
"content": {
"application/vnd.apache.parquet": {},
"application/binary": {},
},
}
},
404: {},
},
)
async def get_aggregate_payload(
Expand Down

0 comments on commit b3ae9c2

Please sign in to comment.