From 31fa19da3f0a9e4fa719ac944aa661cebf2c09de Mon Sep 17 00:00:00 2001 From: Casper van der Wel Date: Tue, 7 Nov 2023 14:05:19 +0100 Subject: [PATCH] Work around issue with contextvars in fastapi --- clean_python/fastapi/fastapi_access_logger.py | 59 ++++++++++++------ clean_python/fastapi/service.py | 7 +-- docker-compose.yaml | 5 ++ tests/fastapi/test_fastapi_access_logger.py | 62 +++++++++++++------ 4 files changed, 91 insertions(+), 42 deletions(-) diff --git a/clean_python/fastapi/fastapi_access_logger.py b/clean_python/fastapi/fastapi_access_logger.py index 83fcde7..e60f6d1 100644 --- a/clean_python/fastapi/fastapi_access_logger.py +++ b/clean_python/fastapi/fastapi_access_logger.py @@ -6,17 +6,20 @@ from typing import Callable from typing import Optional from uuid import UUID +from uuid import uuid4 import inject from starlette.background import BackgroundTasks from starlette.requests import Request from starlette.responses import Response -from clean_python import ctx from clean_python import Gateway from clean_python.fluentbit import FluentbitGateway -__all__ = ["FastAPIAccessLogger"] +__all__ = ["FastAPIAccessLogger", "get_correlation_id"] + + +CORRELATION_ID_HEADER = b"x-correlation-id" def get_view_name(request: Request) -> Optional[str]: @@ -32,6 +35,24 @@ def is_health_check(request: Request) -> bool: return get_view_name(request) == "health_check" +def get_correlation_id(request: Request) -> Optional[UUID]: + headers = dict(request.scope["headers"]) + try: + return UUID(headers[CORRELATION_ID_HEADER].decode()) + except (KeyError, ValueError, UnicodeDecodeError): + return None + + +def ensure_correlation_id(request: Request) -> None: + correlation_id = get_correlation_id(request) + if correlation_id is None: + # generate an id and update the request inplace + correlation_id = uuid4() + headers = dict(request.scope["headers"]) + headers[CORRELATION_ID_HEADER] = str(correlation_id).encode() + request.scope["headers"] = list(headers.items()) + + class FastAPIAccessLogger: def __init__(self, hostname: str, gateway_override: Optional[Gateway] = None): self.origin = f"{hostname}-{os.getpid()}" @@ -44,24 +65,27 @@ def gateway(self) -> Gateway: async def __call__( self, request: Request, call_next: Callable[[Request], Awaitable[Response]] ) -> Response: + if request.scope["type"] != "http" or is_health_check(request): + return await call_next(request) + + ensure_correlation_id(request) + time_received = time.time() response = await call_next(request) request_time = time.time() - time_received - if not is_health_check(request): - # Instead of logging directly, set it as background task so that it is - # executed after the response. See https://www.starlette.io/background/. - if response.background is None: - response.background = BackgroundTasks() - response.background.add_task( - log_access, - self.gateway, - request, - response, - time_received, - request_time, - ctx.correlation_id, - ) + # Instead of logging directly, set it as background task so that it is + # executed after the response. See https://www.starlette.io/background/. + if response.background is None: + response.background = BackgroundTasks() + response.background.add_task( + log_access, + self.gateway, + request, + response, + time_received, + request_time, + ) return response @@ -71,7 +95,6 @@ async def log_access( response: Response, time_received: float, request_time: float, - correlation_id: Optional[UUID] = None, ) -> None: """ Create a dictionary with logging data. @@ -96,6 +119,6 @@ async def log_access( "content_length": content_length, "time": time_received, "request_time": request_time, - "correlation_id": str(correlation_id) if correlation_id else None, + "correlation_id": str(get_correlation_id(request)), } await gateway.add(item) diff --git a/clean_python/fastapi/service.py b/clean_python/fastapi/service.py index 0cc6ffe..34361ca 100644 --- a/clean_python/fastapi/service.py +++ b/clean_python/fastapi/service.py @@ -6,12 +6,9 @@ from typing import List from typing import Optional from typing import Set -from uuid import UUID -from uuid import uuid4 from fastapi import Depends from fastapi import FastAPI -from fastapi import Header from fastapi import Request from fastapi.exceptions import RequestValidationError from starlette.types import ASGIApp @@ -36,6 +33,7 @@ from .error_responses import validation_error_handler from .error_responses import ValidationErrorResponse from .fastapi_access_logger import FastAPIAccessLogger +from .fastapi_access_logger import get_correlation_id from .resource import APIVersion from .resource import clean_resources from .resource import Resource @@ -68,12 +66,11 @@ def get_auth_kwargs(auth_client: Optional[OAuth2SPAClientSettings]) -> Dict[str, async def set_context( request: Request, token: Token = Depends(get_token), - x_correlation_id: UUID = Header(default_factory=uuid4), ) -> None: ctx.path = request.url ctx.user = token.user ctx.tenant = token.tenant - ctx.correlation_id = x_correlation_id + ctx.correlation_id = get_correlation_id(request) async def health_check(): diff --git a/docker-compose.yaml b/docker-compose.yaml index 4058ebd..9f3cea6 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -18,3 +18,8 @@ services: MINIO_ROOT_PASSWORD: cleanpython ports: - "9000:9000" + + fluentbit: + image: fluent/fluent-bit:1.9 + ports: + - "24224:24224" diff --git a/tests/fastapi/test_fastapi_access_logger.py b/tests/fastapi/test_fastapi_access_logger.py index 2b243fd..69fac34 100644 --- a/tests/fastapi/test_fastapi_access_logger.py +++ b/tests/fastapi/test_fastapi_access_logger.py @@ -7,9 +7,11 @@ from starlette.responses import JSONResponse from starlette.responses import StreamingResponse -from clean_python import ctx from clean_python import InMemoryGateway from clean_python.fastapi import FastAPIAccessLogger +from clean_python.fastapi import get_correlation_id + +SOME_UUID = uuid4() @pytest.fixture @@ -38,6 +40,7 @@ def req(): (b"accept-encoding", b"gzip, deflate, br"), (b"accept-language", b"en-US,en;q=0.9"), (b"cookie", b"..."), + (b"x-correlation-id", str(SOME_UUID).encode()), ], "state": {}, "method": "GET", @@ -64,23 +67,14 @@ def response(): @pytest.fixture def call_next(response): async def func(request): + assert get_correlation_id(request) == SOME_UUID return response return func -@pytest.fixture -def correlation_id(): - uid = uuid4() - ctx.correlation_id = uid - yield uid - ctx.correlation_id = None - - @mock.patch("time.time", return_value=0.0) -async def test_logging( - time, fastapi_access_logger, req, response, call_next, correlation_id -): +async def test_logging(time, fastapi_access_logger, req, response, call_next): await fastapi_access_logger(req, call_next) assert len(fastapi_access_logger.gateway.data) == 0 await response.background() @@ -101,7 +95,7 @@ async def test_logging( "content_length": 13, "time": 0.0, "request_time": 0.0, - "correlation_id": str(correlation_id), + "correlation_id": str(SOME_UUID), } @@ -116,7 +110,7 @@ def req_minimal(): "scheme": "http", "path": "/", "query_string": "", - "headers": [], + "headers": [(b"abc", b"def")], } return Request(scope) @@ -135,16 +129,27 @@ async def numbers(minimum, maximum): @pytest.fixture def call_next_streaming(streaming_response): async def func(request): + assert get_correlation_id(request) == SOME_UUID return streaming_response return func @mock.patch("time.time", return_value=0.0) +@mock.patch("clean_python.fastapi.fastapi_access_logger.uuid4", return_value=SOME_UUID) async def test_logging_minimal( - time, fastapi_access_logger, req_minimal, streaming_response, call_next_streaming + time, + uuid4, + fastapi_access_logger, + req_minimal, + streaming_response, + call_next_streaming, ): await fastapi_access_logger(req_minimal, call_next_streaming) + assert req_minimal["headers"] == [ + (b"abc", b"def"), + (b"x-correlation-id", str(SOME_UUID).encode()), + ] assert len(fastapi_access_logger.gateway.data) == 0 await streaming_response.background() (actual,) = fastapi_access_logger.gateway.data.values() @@ -164,15 +169,21 @@ async def test_logging_minimal( "content_length": None, "time": 0.0, "request_time": 0.0, - "correlation_id": None, + "correlation_id": str(SOME_UUID), } @pytest.fixture def req_health(): - # a copy-paste from a local session, with some values removed / shortened scope = { "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": "GET", + "scheme": "http", + "path": "/", + "query_string": "", + "headers": [], "route": APIRoute( endpoint=lambda x: x, path="/health", @@ -183,9 +194,22 @@ def req_health(): return Request(scope) +@pytest.fixture +def call_next_no_correlation_id(response): + async def func(request): + assert get_correlation_id(request) is None + return response + + return func + + @mock.patch("time.time", return_value=0.0) async def test_logging_health_check_skipped( - time, fastapi_access_logger, req_health, streaming_response, call_next_streaming + time, + fastapi_access_logger, + req_health, + streaming_response, + call_next_no_correlation_id, ): - await fastapi_access_logger(req_health, call_next_streaming) + await fastapi_access_logger(req_health, call_next_no_correlation_id) assert streaming_response.background is None