Pydantic 2 model with type parameter produces invalid OpenAPI schema #1365

Open
sastraxi opened this issue Jan 13, 2025 · 0 comments

Describe the bug
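A Pydantic model (a BaseModel subclass) with a type parameter is automatically picked up by PydanticExtension, which serializes it under a name that contains square brackets, e.g. WrappedResponse[MessageResponse]. This name is invalid as an OpenAPI component name: the OpenAPI specification requires component keys to match ^[a-zA-Z0-9\.\-_]+$, which excludes [ and ].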

To Reproduce
Here's the base code to play around with.

# schemas.py
from typing import Generic, TypeVar

from pydantic import BaseModel

T = TypeVar("T", bound=BaseModel)
U = TypeVar("U", bound=BaseModel)

class MessageResponse(BaseModel):
    msg: str

class WrappedResponse(BaseModel, Generic[U]):
    data: U

    # `with` is a reserved keyword in Python, so the helper is named `with_`.
    # The return annotation is quoted because the class is not yet defined here.
    @staticmethod
    def with_(data: T) -> "WrappedResponse[T]":
        return WrappedResponse(data=data)

# views module (imports from .schemas)
from django.shortcuts import get_object_or_404
from drf_spectacular.utils import extend_schema
from rest_framework import viewsets
from rest_framework.decorators import action
from rest_framework.request import Request
from rest_framework.response import Response

from .schemas import WrappedResponse, MessageResponse

class MyViewset(viewsets.GenericViewSet):
    @extend_schema(
        # The parametrized model is what ends up named "WrappedResponse[MessageResponse]".
        responses={200: WrappedResponse[MessageResponse]},
    )
    @action(detail=False, methods=["get"], url_path="say-hello")
    def say_hello(self, request: Request):
        data = WrappedResponse.with_(MessageResponse(msg="Hello, world!"))
        return Response(data.model_dump())

Expected behavior
The PydanticExtension class should be able to generate names that are safe for OpenAPI schemas, e.g. by replacing [ and ] with _.
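To make the suggested renaming concrete, here is a minimal sketch of the kind of sanitizing step I have in mind. It is not drf-spectacular's actual code: safe_component_name is a hypothetical helper, and collapsing each run of disallowed characters into a single underscore (then trimming the ends) is just one possible convention.

```python
import re

# OpenAPI component names must match ^[a-zA-Z0-9\.\-_]+$ .
# Anything outside that character set is treated as invalid here.
_INVALID_CHARS = re.compile(r"[^a-zA-Z0-9._-]+")


def safe_component_name(name: str) -> str:
    """Hypothetical helper: turn a Pydantic generic name into a valid component name.

    Each run of disallowed characters (such as "[", "]" or ", ") becomes one
    underscore, and leading/trailing underscores are trimmed.
    """
    return _INVALID_CHARS.sub("_", name).strip("_")


if __name__ == "__main__":
    print(safe_component_name("WrappedResponse[MessageResponse]"))
```

With this convention the trailing ] does not leave a dangling underscore, and nested parameters flatten into a single underscore-separated name. Two different generic names could in principle collapse to the same result, so a real fix would probably need a collision check.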

Alternatively, I should be able to provide my own OpenApiSerializerExtension that picks up the generic type and has a higher priority than the built-in extension. Unfortunately, because class names are directly compared in the registry (including type parameters), we cannot register one extension that picks up all types.

Another approach we could take would be to override _get_serializer_name in our own custom AutoSchema (get_serializer_name won't work because it's never called for Pydantic models). Because this is a private method, it doesn't feel correct to override it.

The approach that we ended up going with, until this is fixed, is to override the serializer name inside of e.g. WrappedResponse:

class WrappedResponse(OutputSchema, Generic[U]):
    data: U

    @classmethod
    def model_parametrized_name(cls, params: tuple[type[Any], ...]) -> str:
        """
        Ensures that types like WrappedResponse[MySchema] get names
        that are valid OpenAPI schema names (e.g. WrappedResponse_MySchema).
        """
        # Assumes every parameter is a plain class that has a __name__.
        param_names = '_'.join(param.__name__ for param in params)
        return f"{cls.__name__}_{param_names}"

    # `with` is a reserved keyword in Python, so the helper is named `with_`.
    @staticmethod
    def with_(data: T) -> "WrappedResponse[T]":
        return WrappedResponse(data=data)

I'm not sure if model_parametrized_name is used elsewhere, so I'm not 100% comfortable with this solution. What do you think is the best way to solve this?
