From b3ae9c22b594b8f93e60d9d34d6a3bcb27a686ba Mon Sep 17 00:00:00 2001 From: Jakob Schlyter Date: Fri, 12 Jan 2024 13:49:17 +0100 Subject: [PATCH] improve and document responses --- aggrec/aggregates.py | 71 +++++++++++++++++++++++++++++++------------- 1 file changed, 50 insertions(+), 21 deletions(-) diff --git a/aggrec/aggregates.py b/aggrec/aggregates.py index 60e3a35..5b7a869 100644 --- a/aggrec/aggregates.py +++ b/aggrec/aggregates.py @@ -13,7 +13,7 @@ import bson import pendulum from bson.objectid import ObjectId -from fastapi import APIRouter, Depends, HTTPException, Request, Response, status +from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response, status from fastapi.responses import StreamingResponse from pydantic import BaseModel, Field @@ -34,10 +34,6 @@ "Signature-Input", ] -ALLOWED_AGGREGATE_TYPES = ["histogram", "vector"] - -ALLOWED_CONTENT_TYPES = ["application/vnd.apache.parquet", "application/binary"] - router = APIRouter() @@ -46,6 +42,11 @@ class AggregateType(str, Enum): vector = "vector" +class AggregateContentType(str, Enum): + parquet = "application/vnd.apache.parquet" + binary = "application/binary" + + class AggregateMetadataResponse(BaseModel): aggregate_id: str = Field(title="Aggregate identifier") aggregate_type: AggregateType = Field(title="Aggregate type") @@ -54,9 +55,9 @@ class AggregateMetadataResponse(BaseModel): headers: dict = Field(title="Dictionary of relevant HTTP headers") content_type: str = Field(title="Content MIME type") content_length: int = Field(title="Content length") - content_location: str = Field(title="Content local (URL)") - s3_bucket: str = Field(title="S3 Bucket Name") - s3_object_key: str = Field(title="S3 Object Key") + content_location: str = Field(title="Content location (URL)") + s3_bucket: str = Field(title="S3 bucket name") + s3_object_key: str = Field(title="S3 object key") aggregate_interval_start: datetime | None = Field( default=None, title="Aggregate interval start" ) @@ -210,23 +211,38 @@ def get_s3_object_metadata(metadata: AggregateMetadata) -> dict: } -@router.post("/api/v1/aggregate/{aggregate_type}") +@router.post( + "/api/v1/aggregate/{aggregate_type}", + status_code=201, + responses={ + 201: { + "description": "Aggregate created", + "content": None, + "headers": { + "location": { + "description": "Aggregate URL", + "schema": {"type": "string", "format": "uri"}, + }, + }, + } + }, +) async def create_aggregate( aggregate_type: AggregateType, + content_type: Annotated[AggregateContentType, Header()], + aggregate_interval: Annotated[ + str | None, + Header( + description="Aggregate window as an ISO 8601 time interval (start and duration)", + example="1984-01-01T12:00:00Z/PT1M", + ), + ], request: Request, settings: Annotated[Settings, Depends(get_settings)], s3_client: Annotated[aiobotocore.client.AioBaseClient, Depends(s3_client)], mqtt_client: Annotated[aiomqtt.Client, Depends(mqtt_client)], http_request_verifier: Annotated[RequestVerifier, Depends(http_request_verifier)], ): - if aggregate_type not in ALLOWED_AGGREGATE_TYPES: - raise HTTPException(status.HTTP_400_BAD_REQUEST, "Aggregate type not supported") - - content_type = request.headers.get("content-type", None) - - if content_type not in ALLOWED_CONTENT_TYPES: - raise HTTPException(status.HTTP_400_BAD_REQUEST, "Content-Type not supported") - res = await http_request_verifier.verify(request) creator = res.parameters.get("keyid") @@ -239,11 +255,11 @@ async def create_aggregate( s3_bucket = settings.s3_bucket - if aggregate_interval := request.headers.get("Aggregate-Interval"): + if aggregate_interval: period = pendulum.parse(aggregate_interval) if not isinstance(period, pendulum.Interval): raise HTTPException( - status.HTTP_400_BAD_REQUEST, "Invalid Aggregate-Interval" + status.HTTP_422_UNPROCESSABLE_ENTITY, "Invalid Aggregate-Interval" ) aggregate_interval_start = pendulum_as_datetime(period.start) aggregate_interval_duration = period.start.diff(period.end).in_seconds() @@ -296,7 +312,13 @@ async def create_aggregate( return Response(status_code=status.HTTP_201_CREATED, headers={"Location": location}) -@router.get("/api/v1/aggregates/{aggregate_id}") +@router.get( + "/api/v1/aggregates/{aggregate_id}", + responses={ + 200: {"model": AggregateMetadataResponse}, + 404: {}, + }, +) def get_aggregate_metadata( aggregate_id: str, settings: Annotated[Settings, Depends(get_settings)] ) -> AggregateMetadataResponse: @@ -316,11 +338,18 @@ def get_aggregate_metadata( responses={ 200: { "description": "Aggregate payload", + "headers": { + "link": { + "description": 'Linked resources (RFC 8288), rel="about" for metadata URL', + "schema": {"type": "string", "format": "uri"}, + }, + }, "content": { "application/vnd.apache.parquet": {}, "application/binary": {}, }, - } + }, + 404: {}, }, ) async def get_aggregate_payload(