Skip to content

Commit

Permalink
Work around issue with contextvars in fastapi
Browse files Browse the repository at this point in the history
  • Loading branch information
caspervdw committed Nov 7, 2023
1 parent 6b23112 commit 31fa19d
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 42 deletions.
59 changes: 41 additions & 18 deletions clean_python/fastapi/fastapi_access_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,20 @@
from typing import Callable
from typing import Optional
from uuid import UUID
from uuid import uuid4

import inject
from starlette.background import BackgroundTasks
from starlette.requests import Request
from starlette.responses import Response

from clean_python import ctx
from clean_python import Gateway
from clean_python.fluentbit import FluentbitGateway

__all__ = ["FastAPIAccessLogger"]
__all__ = ["FastAPIAccessLogger", "get_correlation_id"]


CORRELATION_ID_HEADER = b"x-correlation-id"


def get_view_name(request: Request) -> Optional[str]:
Expand All @@ -32,6 +35,24 @@ def is_health_check(request: Request) -> bool:
return get_view_name(request) == "health_check"


def get_correlation_id(request: Request) -> Optional[UUID]:
headers = dict(request.scope["headers"])
try:
return UUID(headers[CORRELATION_ID_HEADER].decode())
except (KeyError, ValueError, UnicodeDecodeError):
return None


def ensure_correlation_id(request: Request) -> None:
correlation_id = get_correlation_id(request)
if correlation_id is None:
# generate an id and update the request inplace
correlation_id = uuid4()
headers = dict(request.scope["headers"])
headers[CORRELATION_ID_HEADER] = str(correlation_id).encode()
request.scope["headers"] = list(headers.items())


class FastAPIAccessLogger:
def __init__(self, hostname: str, gateway_override: Optional[Gateway] = None):
self.origin = f"{hostname}-{os.getpid()}"
Expand All @@ -44,24 +65,27 @@ def gateway(self) -> Gateway:
async def __call__(
self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
if request.scope["type"] != "http" or is_health_check(request):
return await call_next(request)

ensure_correlation_id(request)

time_received = time.time()
response = await call_next(request)
request_time = time.time() - time_received

if not is_health_check(request):
# Instead of logging directly, set it as background task so that it is
# executed after the response. See https://www.starlette.io/background/.
if response.background is None:
response.background = BackgroundTasks()
response.background.add_task(
log_access,
self.gateway,
request,
response,
time_received,
request_time,
ctx.correlation_id,
)
# Instead of logging directly, set it as background task so that it is
# executed after the response. See https://www.starlette.io/background/.
if response.background is None:
response.background = BackgroundTasks()
response.background.add_task(
log_access,
self.gateway,
request,
response,
time_received,
request_time,
)
return response


Expand All @@ -71,7 +95,6 @@ async def log_access(
response: Response,
time_received: float,
request_time: float,
correlation_id: Optional[UUID] = None,
) -> None:
"""
Create a dictionary with logging data.
Expand All @@ -96,6 +119,6 @@ async def log_access(
"content_length": content_length,
"time": time_received,
"request_time": request_time,
"correlation_id": str(correlation_id) if correlation_id else None,
"correlation_id": str(get_correlation_id(request)),
}
await gateway.add(item)
7 changes: 2 additions & 5 deletions clean_python/fastapi/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,9 @@
from typing import List
from typing import Optional
from typing import Set
from uuid import UUID
from uuid import uuid4

from fastapi import Depends
from fastapi import FastAPI
from fastapi import Header
from fastapi import Request
from fastapi.exceptions import RequestValidationError
from starlette.types import ASGIApp
Expand All @@ -36,6 +33,7 @@
from .error_responses import validation_error_handler
from .error_responses import ValidationErrorResponse
from .fastapi_access_logger import FastAPIAccessLogger
from .fastapi_access_logger import get_correlation_id
from .resource import APIVersion
from .resource import clean_resources
from .resource import Resource
Expand Down Expand Up @@ -68,12 +66,11 @@ def get_auth_kwargs(auth_client: Optional[OAuth2SPAClientSettings]) -> Dict[str,
async def set_context(
request: Request,
token: Token = Depends(get_token),
x_correlation_id: UUID = Header(default_factory=uuid4),
) -> None:
ctx.path = request.url
ctx.user = token.user
ctx.tenant = token.tenant
ctx.correlation_id = x_correlation_id
ctx.correlation_id = get_correlation_id(request)


async def health_check():
Expand Down
5 changes: 5 additions & 0 deletions docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,8 @@ services:
MINIO_ROOT_PASSWORD: cleanpython
ports:
- "9000:9000"

fluentbit:
image: fluent/fluent-bit:1.9
ports:
- "24224:24224"
62 changes: 43 additions & 19 deletions tests/fastapi/test_fastapi_access_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
from starlette.responses import JSONResponse
from starlette.responses import StreamingResponse

from clean_python import ctx
from clean_python import InMemoryGateway
from clean_python.fastapi import FastAPIAccessLogger
from clean_python.fastapi import get_correlation_id

SOME_UUID = uuid4()


@pytest.fixture
Expand Down Expand Up @@ -38,6 +40,7 @@ def req():
(b"accept-encoding", b"gzip, deflate, br"),
(b"accept-language", b"en-US,en;q=0.9"),
(b"cookie", b"..."),
(b"x-correlation-id", str(SOME_UUID).encode()),
],
"state": {},
"method": "GET",
Expand All @@ -64,23 +67,14 @@ def response():
@pytest.fixture
def call_next(response):
async def func(request):
assert get_correlation_id(request) == SOME_UUID
return response

return func


@pytest.fixture
def correlation_id():
uid = uuid4()
ctx.correlation_id = uid
yield uid
ctx.correlation_id = None


@mock.patch("time.time", return_value=0.0)
async def test_logging(
time, fastapi_access_logger, req, response, call_next, correlation_id
):
async def test_logging(time, fastapi_access_logger, req, response, call_next):
await fastapi_access_logger(req, call_next)
assert len(fastapi_access_logger.gateway.data) == 0
await response.background()
Expand All @@ -101,7 +95,7 @@ async def test_logging(
"content_length": 13,
"time": 0.0,
"request_time": 0.0,
"correlation_id": str(correlation_id),
"correlation_id": str(SOME_UUID),
}


Expand All @@ -116,7 +110,7 @@ def req_minimal():
"scheme": "http",
"path": "/",
"query_string": "",
"headers": [],
"headers": [(b"abc", b"def")],
}
return Request(scope)

Expand All @@ -135,16 +129,27 @@ async def numbers(minimum, maximum):
@pytest.fixture
def call_next_streaming(streaming_response):
async def func(request):
assert get_correlation_id(request) == SOME_UUID
return streaming_response

return func


@mock.patch("time.time", return_value=0.0)
@mock.patch("clean_python.fastapi.fastapi_access_logger.uuid4", return_value=SOME_UUID)
async def test_logging_minimal(
time, fastapi_access_logger, req_minimal, streaming_response, call_next_streaming
time,
uuid4,
fastapi_access_logger,
req_minimal,
streaming_response,
call_next_streaming,
):
await fastapi_access_logger(req_minimal, call_next_streaming)
assert req_minimal["headers"] == [
(b"abc", b"def"),
(b"x-correlation-id", str(SOME_UUID).encode()),
]
assert len(fastapi_access_logger.gateway.data) == 0
await streaming_response.background()
(actual,) = fastapi_access_logger.gateway.data.values()
Expand All @@ -164,15 +169,21 @@ async def test_logging_minimal(
"content_length": None,
"time": 0.0,
"request_time": 0.0,
"correlation_id": None,
"correlation_id": str(SOME_UUID),
}


@pytest.fixture
def req_health():
# a copy-paste from a local session, with some values removed / shortened
scope = {
"type": "http",
"asgi": {"version": "3.0"},
"http_version": "1.1",
"method": "GET",
"scheme": "http",
"path": "/",
"query_string": "",
"headers": [],
"route": APIRoute(
endpoint=lambda x: x,
path="/health",
Expand All @@ -183,9 +194,22 @@ def req_health():
return Request(scope)


@pytest.fixture
def call_next_no_correlation_id(response):
async def func(request):
assert get_correlation_id(request) is None
return response

return func


@mock.patch("time.time", return_value=0.0)
async def test_logging_health_check_skipped(
time, fastapi_access_logger, req_health, streaming_response, call_next_streaming
time,
fastapi_access_logger,
req_health,
streaming_response,
call_next_no_correlation_id,
):
await fastapi_access_logger(req_health, call_next_streaming)
await fastapi_access_logger(req_health, call_next_no_correlation_id)
assert streaming_response.background is None

0 comments on commit 31fa19d

Please sign in to comment.