Backend: Decouple datasets service from HTTP presentation layer (#719)
* Remove HTTP from datasets_service

* Add comments to improve clarity

* Add an exception to the root ServiceError, log errors with the attached exception, and improve docs

* Move error handling from some routes to service (removing HTTP component)
peteski22 authored Jan 23, 2025
1 parent 151fceb commit 118cbee
Showing 7 changed files with 278 additions and 49 deletions.
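The pattern introduced here (services raise plain `ServiceError` subclasses, each route module exposes a mapping from error type to HTTP status, and `create_app` registers one generic handler per entry) can be illustrated with a self-contained sketch. The `Widget*` names below are toy stand-ins, not code from this repository:

```python
# Toy illustration of the decoupling pattern in this commit; Widget* names are made up.
from collections.abc import Callable
from http import HTTPStatus

from fastapi import FastAPI, status
from starlette.requests import Request
from starlette.responses import JSONResponse, Response


class WidgetServiceError(Exception):
    """Service-layer error that knows nothing about HTTP."""

    def __init__(self, message: str):
        super().__init__(message)
        self.message = message


class WidgetNotFoundError(WidgetServiceError):
    def __init__(self, widget_id: int):
        super().__init__(f"Widget '{widget_id}' not found.")


def widget_exception_mappings() -> dict[type[WidgetServiceError], HTTPStatus]:
    # The route module owns the error -> status mapping (cf. dataset_exception_mappings).
    return {WidgetNotFoundError: status.HTTP_404_NOT_FOUND}


def create_error_handler(status_code: HTTPStatus) -> Callable[[Request, WidgetServiceError], Response]:
    def handler(_: Request, exc: WidgetServiceError) -> Response:
        return JSONResponse(status_code=status_code, content={"detail": exc.message})

    return handler


app = FastAPI()


@app.get("/widgets/{widget_id}")
def get_widget(widget_id: int):
    # A real route would call the service; the service raises the domain error.
    raise WidgetNotFoundError(widget_id)


for error_type, code in widget_exception_mappings().items():
    app.add_exception_handler(error_type, create_error_handler(code))
```

With this toy app running (for example via `uvicorn`), `GET /widgets/1` returns 404 with body `{"detail": "Widget '1' not found."}`; the route never constructs an `HTTPException` itself.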
40 changes: 22 additions & 18 deletions lumigator/python/mzai/backend/backend/api/routes/datasets.py
@@ -1,20 +1,38 @@
from http import HTTPStatus
from typing import Annotated
from uuid import UUID

from fastapi import APIRouter, Form, HTTPException, Query, UploadFile, status
from loguru import logger
from fastapi import APIRouter, Form, Query, UploadFile, status
from lumigator_schemas.datasets import DatasetDownloadResponse, DatasetFormat, DatasetResponse
from lumigator_schemas.extras import ListingResponse
from starlette.requests import Request
from starlette.responses import Response

from backend.api.deps import DatasetServiceDep
from backend.api.http_headers import HttpHeaders
from backend.services.exceptions.base_exceptions import ServiceError
from backend.services.exceptions.dataset_exceptions import (
DatasetInvalidError,
DatasetMissingFieldsError,
DatasetNotFoundError,
DatasetSizeError,
DatasetUpstreamError,
)
from backend.settings import settings

router = APIRouter()


def dataset_exception_mappings() -> dict[type[ServiceError], HTTPStatus]:
return {
DatasetNotFoundError: status.HTTP_404_NOT_FOUND,
DatasetMissingFieldsError: status.HTTP_403_FORBIDDEN,
DatasetUpstreamError: status.HTTP_500_INTERNAL_SERVER_ERROR,
DatasetSizeError: status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
DatasetInvalidError: status.HTTP_422_UNPROCESSABLE_ENTITY,
}


@router.post(
"/",
status_code=status.HTTP_201_CREATED,
@@ -60,26 +78,12 @@ def upload_dataset(

@router.get("/{dataset_id}")
def get_dataset(service: DatasetServiceDep, dataset_id: UUID) -> DatasetResponse:
dataset = service.get_dataset(dataset_id)
if not dataset:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Dataset '{dataset_id}' not found.",
)

return dataset
return service.get_dataset(dataset_id)


@router.delete("/{dataset_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_dataset(service: DatasetServiceDep, dataset_id: UUID) -> None:
try:
service.delete_dataset(dataset_id)
except Exception as e:
logger.error(f"Unexpected error deleting dataset ID from DB and S3: {dataset_id}. {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Unexpected error deleting dataset for ID: {dataset_id}",
) from e
service.delete_dataset(dataset_id)


@router.get("/")
34 changes: 34 additions & 0 deletions lumigator/python/mzai/backend/backend/main.py
@@ -1,15 +1,21 @@
import os
import sys
from collections.abc import Callable
from http import HTTPStatus
from pathlib import Path

from alembic import command
from alembic.config import Config
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from starlette.requests import Request
from starlette.responses import JSONResponse, Response

from backend.api.router import api_router
from backend.api.routes.datasets import dataset_exception_mappings
from backend.api.tags import TAGS_METADATA
from backend.services.exceptions.base_exceptions import ServiceError
from backend.settings import settings

LUMIGATOR_APP_TAGS = {
@@ -45,6 +51,21 @@ def _configure_logger():
)


def create_error_handler(status_code: HTTPStatus) -> Callable[[Request, ServiceError], Response]:
"""Creates an error handler function for service errors, using the given status code"""

def handler(_: Request, exc: ServiceError) -> Response:
# Log any inner exceptions as part of handling the service error.
logger.opt(exception=exc).error("Service error")

return JSONResponse(
status_code=status_code,
content={"detail": exc.message},
)

return handler


def create_app() -> FastAPI:
_configure_logger()

@@ -67,6 +88,19 @@ def create_app() -> FastAPI:

app.include_router(api_router)

# Group mappings of service error types to HTTP status code, for routes.
exception_mappings = [
# TODO: Add completions
dataset_exception_mappings() # Datasets
# TODO: Add experiments
# TODO: Add jobs
]

# Add a handler for each error -> status mapping.
for mapping in exception_mappings:
for key, value in mapping.items():
app.add_exception_handler(key, create_error_handler(value))

@app.get("/")
def get_root():
return {"Hello": "Lumigator!🐊"}
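One way to exercise the new wiring is sketched below, using FastAPI's `TestClient` and only names that appear in this diff (`dataset_exception_mappings`, `create_error_handler`, `DatasetNotFoundError`). The test module itself is hypothetical: it assumes `backend.main` can be imported without side effects, and the `DatasetNotFoundError(uuid4())` call simply mirrors its use in `DatasetService.get_dataset`.

```python
# Hypothetical test sketch; only the imported names are taken from this commit.
from uuid import uuid4

from fastapi import FastAPI
from fastapi.testclient import TestClient

from backend.api.routes.datasets import dataset_exception_mappings
from backend.main import create_error_handler
from backend.services.exceptions.dataset_exceptions import DatasetNotFoundError


def build_app() -> FastAPI:
    app = FastAPI()

    @app.get("/raise-not-found")
    def raise_not_found():
        # Mirrors DatasetService.get_dataset() when no record exists.
        raise DatasetNotFoundError(uuid4())

    # Same registration loop as create_app() in backend/main.py.
    for error_type, status_code in dataset_exception_mappings().items():
        app.add_exception_handler(error_type, create_error_handler(status_code))

    return app


def test_dataset_not_found_maps_to_404():
    client = TestClient(build_app())
    response = client.get("/raise-not-found")
    assert response.status_code == 404
    assert "detail" in response.json()
```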
98 changes: 69 additions & 29 deletions lumigator/python/mzai/backend/backend/services/datasets.py
@@ -1,12 +1,11 @@
import csv
import traceback
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import BinaryIO
from uuid import UUID

from datasets import load_dataset
from fastapi import HTTPException, UploadFile, status
from fastapi import UploadFile
from loguru import logger
from lumigator_schemas.datasets import DatasetDownloadResponse, DatasetFormat, DatasetResponse
from lumigator_schemas.extras import ListingResponse
@@ -16,6 +15,13 @@

from backend.records.datasets import DatasetRecord
from backend.repositories.datasets import DatasetRepository
from backend.services.exceptions.dataset_exceptions import (
DatasetInvalidError,
DatasetMissingFieldsError,
DatasetNotFoundError,
DatasetSizeError,
DatasetUpstreamError,
)
from backend.settings import settings

GT_FIELD: str = "ground_truth"
@@ -32,19 +38,28 @@ def validate_file_size(input: BinaryIO, output: BinaryIO, max_size: ByteSize) ->
We can then process the file as a whole once it's been written to a buffer on the server.
Reference: https://github.com/tiangolo/fastapi/issues/362#issuecomment-584104025
:param input: the input buffer
:param output: the output buffer (which is updated to contain the parsed dataset)
:param max_size: the maximum allowed size of the output buffer
:raises DatasetSizeError: if the size of the output buffer (file) is too large
"""
actual_size = 0
for chunk in input:
actual_size += output.write(chunk)
if actual_size > max_size:
raise HTTPException(
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
detail=f"File upload exceeds the {max_size.human_readable(decimal=True)} limit.",
)
raise DatasetSizeError(max_size.human_readable(decimal=True)) from None
return actual_size


def validate_dataset_format(filename: str, format: DatasetFormat):
"""Validates the dataset identified by the filename, based on the format.
:param filename: the filename of the dataset to validate
:param format: the dataset format (e.g. 'job')
:raises DatasetInvalidError: if there is a problem with the dataset file format
:raises DatasetMissingFieldsError: if the dataset is missing any required fields
"""
try:
match format:
case DatasetFormat.JOB:
@@ -54,24 +69,22 @@ def validate_dataset_format(filename: str, format: DatasetFormat):
raise ValueError(f"Unknown dataset format: {format}")
except UnicodeError as e:
logger.opt(exception=e).info("Error processing dataset upload.")
http_exception = HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Dataset is not a valid CSV file.",
)
raise http_exception from e
raise DatasetInvalidError("not a CSV file") from e


def validate_experiment_dataset(filename: str):
"""Validates the dataset (CSV) file to ensure all required fields are present.
:param filename: the filename of the dataset to validate
:raises DatasetMissingFieldsError: if the dataset is missing any of the required fields
"""
with Path(filename).open() as f:
reader = csv.DictReader(f)
fields = set(reader.fieldnames or [])

missing_fields = REQUIRED_EXPERIMENT_FIELDS.difference(fields)
if missing_fields:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Experiment dataset is missing the required fields: {missing_fields}.",
)
raise DatasetMissingFieldsError(missing_fields) from None


def dataset_has_gt(filename: str) -> bool:
@@ -95,13 +108,6 @@ def __init__(
self.s3_client = s3_client
self.s3_filesystem = s3_filesystem

def _raise_not_found(self, dataset_id: UUID) -> None:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Dataset '{dataset_id}' not found.")

def _raise_unhandled_exception(self, e: Exception) -> None:
traceback.print_exc()
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, str(e)) from e

def _get_dataset_record(self, dataset_id: UUID) -> DatasetRecord | None:
return self.dataset_repo.get(dataset_id)

@@ -122,6 +128,10 @@ def _get_s3_key(self, dataset_id: UUID, filename: str) -> str:
def _save_dataset_to_s3(self, temp_fname, record):
"""Converts the specified file to a set of HuggingFace dataset formatted files,
along with a newly recreated CSV file. The files are stored in an S3 bucket.
:param temp_fname: temporary file name to read the dataset from
:param record: the dataset record (DatasetRecord)
:raises DatasetUpstreamError: if there is an exception interacting with S3
"""
# Temp file to be used to contain the recreated CSV file.
temp = NamedTemporaryFile(delete=False)
@@ -144,7 +154,7 @@ def _save_dataset_to_s3(self, temp_fname, record):
if record:
self.dataset_repo.delete(record.id)

self._raise_unhandled_exception(e)
raise DatasetUpstreamError("s3", "error attempting to save dataset to S3", e) from e
finally:
# Clean up temp file
Path(temp.name).unlink()
@@ -157,6 +167,14 @@ def upload_dataset(
generated: bool = False,
generated_by: str | None = None,
) -> DatasetResponse:
"""Attempts to upload and convert the specified dataset (CSV) to HF format which is then
stored in S3.
:raises DatasetSizeError: if the dataset is too large
:raises DatasetInvalidError: if the dataset is invalid
:raises DatasetMissingFieldsError: if the dataset is missing any of the required fields
:raises DatasetUpstreamError: if there is an exception interacting with S3
"""
temp = NamedTemporaryFile(delete=False)
try:
# Write to tempfile and validate size
@@ -191,9 +209,14 @@ def upload_dataset(
return DatasetResponse.model_validate(record)

def get_dataset(self, dataset_id: UUID) -> DatasetResponse | None:
"""Gets the dataset record by its ID.
:param dataset_id: dataset ID
:raises DatasetNotFoundError: if there is no dataset record with that ID
"""
record = self._get_dataset_record(dataset_id)
if record is None:
return None
raise DatasetNotFoundError(dataset_id) from None

return DatasetResponse.model_validate(record)

@@ -221,11 +244,15 @@ def delete_dataset(self, dataset_id: UUID) -> None:
This operation is idempotent; calling it with a record that never existed, or that has
already been deleted, will not raise an error.
:param dataset_id: dataset ID to delete
:raises DatasetNotFoundError: if there is no dataset record with that ID
:raises DatasetUpstreamError: if there is an exception deleting the dataset from S3
"""
record = self._get_dataset_record(dataset_id)
# Early return if the record does not exist (for idempotency).
if record is None:
return None
raise DatasetNotFoundError(dataset_id) from None

try:
# S3 delete is called first; if this fails for any reason other than the file not being
@@ -239,6 +266,10 @@
f"Dataset ID: {dataset_id} was present in the DB but not found on S3... "
f"Cleaning up DB by removing ID. {e}"
)
except Exception as e:
raise DatasetUpstreamError(
"s3", f"error attempting to delete dataset {dataset_id} from S3", e
) from e

# Getting this far means we are OK to remove the record from the DB.
self.dataset_repo.delete(record.id)
@@ -249,11 +280,19 @@ def get_dataset_download(
"""Generate pre-signed download URLs for dataset files.
When supplied, only URLs for files that match the specified extension are returned.
:param dataset_id: ID of the dataset to generate pre-signed download URLs for
:param extension: File extension used to determine which files to generate URLs for
:raises DatasetNotFoundError: if the dataset cannot be found in S3
:raises DatasetUpstreamError: if there is an exception interacting with S3
"""
# Sanitize the input for a file extension.
extension = extension.strip().lower() if extension and extension.strip() else None

record = self._get_dataset_record(dataset_id)
if record is None:
raise DatasetNotFoundError(dataset_id, "error getting dataset download") from None

dataset_key = self._get_s3_key(dataset_id, record.filename)

try:
@@ -263,9 +302,9 @@
)

if s3_response.get("KeyCount") == 0:
raise HTTPException(
status.HTTP_404_NOT_FOUND, f"No files found with prefix '{dataset_key}'."
)
raise DatasetNotFoundError(
dataset_id, f"No S3 files found with prefix '{dataset_key}'"
) from None

download_urls = []
for s3_object in s3_response["Contents"]:
@@ -284,7 +323,8 @@
download_urls.append(download_url)

except Exception as e:
self._raise_unhandled_exception(e)
msg = f"Error generating pre-signed download URLs for dataset {dataset_id}"
raise DatasetUpstreamError("s3", msg, e) from e

return DatasetDownloadResponse(id=dataset_id, download_urls=download_urls)

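The new exception modules (`backend/services/exceptions/base_exceptions.py` and `dataset_exceptions.py`) are among the changed files but are not rendered in this view. Inferred purely from the call sites above (`DatasetNotFoundError(dataset_id)`, `DatasetUpstreamError("s3", msg, e)`, the handler's use of `exc.message`, and the commit note about adding an exception to the root `ServiceError`), the hierarchy plausibly looks something like the sketch below. Every signature here is an assumption, not the committed code.

```python
# Illustrative reconstruction only: the real exception modules are not shown in this
# diff view, and their actual names, fields, and signatures may differ.
class ServiceError(Exception):
    """Base error for backend services; carries a message and an optional inner exception."""

    def __init__(self, message: str, exc: Exception | None = None):
        super().__init__(message)
        self.message = message  # surfaced as the "detail" field by the HTTP error handler
        self.exc = exc  # the exception attached to the root ServiceError (per the commit message)


class DatasetNotFoundError(ServiceError):
    def __init__(self, dataset_id, message: str | None = None, exc: Exception | None = None):
        suffix = f": {message}" if message else ""
        super().__init__(f"Dataset '{dataset_id}' not found{suffix}", exc)


class DatasetUpstreamError(ServiceError):
    def __init__(self, service_name: str, message: str, exc: Exception | None = None):
        super().__init__(f"Upstream {service_name} error: {message}", exc)
```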
