Skip to content

Commit

Permalink
property fields in msgspec and dataclass DTOs
Browse files Browse the repository at this point in the history
  • Loading branch information
provinzkraut committed Feb 1, 2025
1 parent b0322d5 commit bebc307
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 28 deletions.
42 changes: 30 additions & 12 deletions litestar/dto/base_dto.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import dataclasses
import typing
from abc import abstractmethod
from inspect import getmodule
Expand All @@ -17,6 +18,7 @@
from litestar.types.builtin_types import NoneType
from litestar.types.composite_types import TypeEncodersMap
from litestar.typing import FieldDefinition
from litestar.utils.signature import ParsedSignature

if TYPE_CHECKING:
from typing import Any, ClassVar, Generator
Expand Down Expand Up @@ -278,18 +280,7 @@ def resolve_generic_wrapper_type(
return None

@staticmethod
def get_model_type_hints(
model_type: type[Any], namespace: dict[str, Any] | None = None
) -> dict[str, FieldDefinition]:
"""Retrieve type annotations for ``model_type``.
Args:
model_type: Any type-annotated class.
namespace: Optional namespace to use for resolving type hints.
Returns:
Parsed type hints for ``model_type`` resolved within the scope of its module.
"""
def get_model_namespace(model_type: type[Any], namespace: dict[str, Any] | None = None) -> dict[str, Any]:
namespace = namespace or {}
namespace.update(vars(typing))
namespace.update(
Expand All @@ -303,6 +294,33 @@ def get_model_type_hints(

if model_module := getmodule(model_type):
namespace.update(vars(model_module))
return namespace

@classmethod
def get_property_fields(cls, model_type: type[Any]) -> dict[str, FieldDefinition]:
return {
name: dataclasses.replace(
ParsedSignature.from_fn(attr.fget, cls.get_model_namespace(model_type)).return_type,
name=name,
)
for name, attr in vars(model_type).items()
if isinstance(attr, property) and attr.fget is not None
}

@staticmethod
def get_model_type_hints(
model_type: type[Any], namespace: dict[str, Any] | None = None
) -> dict[str, FieldDefinition]:
"""Retrieve type annotations for ``model_type``.
Args:
model_type: Any type-annotated class.
namespace: Optional namespace to use for resolving type hints.
Returns:
Parsed type hints for ``model_type`` resolved within the scope of its module.
"""
namespace = AbstractDTO.get_model_namespace(model_type, namespace)

return {
k: FieldDefinition.from_kwarg(annotation=v, name=k)
Expand Down
15 changes: 14 additions & 1 deletion litestar/dto/dataclass_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from litestar.dto.base_dto import AbstractDTO
from litestar.dto.data_structures import DTOFieldDefinition
from litestar.dto.field import extract_dto_field
from litestar.dto.field import DTOField, extract_dto_field
from litestar.params import DependencyKwarg, KwargDefinition
from litestar.types.empty import Empty

Expand All @@ -30,6 +30,8 @@ def generate_field_definitions(
cls, model_type: type[DataclassProtocol]
) -> Generator[DTOFieldDefinition, None, None]:
dc_fields = {f.name: f for f in fields(model_type)}
properties = cls.get_property_fields(model_type)

for key, field_definition in cls.get_model_type_hints(model_type).items():
if not (dc_field := dc_fields.get(key)):
continue
Expand All @@ -53,6 +55,17 @@ def generate_field_definitions(
else field_defintion
)

for key, property_field in properties.items():
if key.startswith("_"):
continue

yield DTOFieldDefinition.from_field_definition(
property_field,
model_name=model_type.__name__,
default_factory=None,
dto_field=DTOField(mark="read-only"),
)

@classmethod
def detect_nested_field(cls, field_definition: FieldDefinition) -> bool:
return hasattr(field_definition.annotation, "__dataclass_fields__")
34 changes: 24 additions & 10 deletions litestar/dto/msgspec_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from litestar.dto.base_dto import AbstractDTO
from litestar.dto.data_structures import DTOFieldDefinition
from litestar.dto.field import DTO_FIELD_META_KEY, extract_dto_field
from litestar.dto.field import DTO_FIELD_META_KEY, DTOField, extract_dto_field
from litestar.plugins.core._msgspec import kwarg_definition_from_field
from litestar.types.empty import Empty

Expand All @@ -24,25 +24,28 @@
T = TypeVar("T", bound="Struct | Collection[Struct]")


def _default_or_empty(value: Any) -> Any:
return Empty if value is NODEFAULT else value


def _default_or_none(value: Any) -> Any:
return None if value is NODEFAULT else value


class MsgspecDTO(AbstractDTO[T], Generic[T]):
"""Support for domain modelling with Msgspec."""

@classmethod
def generate_field_definitions(cls, model_type: type[Struct]) -> Generator[DTOFieldDefinition, None, None]:
msgspec_fields = {f.name: f for f in structs.fields(model_type)}

# TODO: Move out of here
def default_or_empty(value: Any) -> Any:
return Empty if value is NODEFAULT else value

def default_or_none(value: Any) -> Any:
return None if value is NODEFAULT else value

inspect_fields: dict[str, msgspec.inspect.Field] = {
field.name: field
for field in msgspec.inspect.type_info(model_type).fields # type: ignore[attr-defined]
}

property_fields = cls.get_property_fields(model_type)

for key, field_definition in cls.get_model_type_hints(model_type).items():
kwarg_definition, extra = kwarg_definition_from_field(inspect_fields[key])
field_definition = dataclasses.replace(field_definition, kwarg_definition=kwarg_definition)
Expand All @@ -56,12 +59,23 @@ def default_or_none(value: Any) -> Any:
field_definition=field_definition,
dto_field=dto_field,
model_name=model_type.__name__,
default_factory=default_or_none(msgspec_field.default_factory),
default_factory=_default_or_none(msgspec_field.default_factory),
),
default=default_or_empty(msgspec_field.default),
default=_default_or_empty(msgspec_field.default),
name=key,
)

for key, property_field in property_fields.items():
if key.startswith("_"):
continue

yield DTOFieldDefinition.from_field_definition(
property_field,
model_name=model_type.__name__,
default_factory=None,
dto_field=DTOField(mark="read-only"),
)

@classmethod
def detect_nested_field(cls, field_definition: FieldDefinition) -> bool:
return field_definition.is_subclass_of(Struct)
22 changes: 21 additions & 1 deletion tests/unit/test_contrib/test_msgspec.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import itertools
from dataclasses import replace
from typing import TYPE_CHECKING
from unittest.mock import ANY
Expand Down Expand Up @@ -90,6 +91,21 @@ def expected_field_defs(int_factory: Callable[[], int]) -> list[DTOFieldDefiniti
raw=ANY,
kwarg_definition=ANY,
),
replace(
DTOFieldDefinition.from_field_definition(
field_definition=FieldDefinition.from_kwarg(
annotation=str,
name="computed",
),
model_name=ANY,
default_factory=None,
dto_field=DTOField(mark="read-only"),
),
metadata=ANY,
type_wrappers=ANY,
raw=ANY,
kwarg_definition=ANY,
),
]


Expand All @@ -103,9 +119,13 @@ class TestStruct(Struct):
d: int = field(default=1)
e: int = field(default_factory=int_factory)

@property
def computed(self) -> str:
return "i am computed"

field_defs = list(MsgspecDTO.generate_field_definitions(TestStruct))
assert field_defs[0].model_name == "TestStruct"
for field_def, exp in zip(field_defs, expected_field_defs):
for field_def, exp in itertools.zip_longest(expected_field_defs, field_defs, fillvalue=None):
assert field_def == exp


Expand Down
58 changes: 54 additions & 4 deletions tests/unit/test_dto/test_factory/test_dataclass_dto.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import itertools
import sys
from dataclasses import dataclass, field, replace
from typing import ClassVar, List
Expand All @@ -20,6 +21,10 @@ class Model:
c: List[int] = field(default_factory=list) # noqa: UP006
d: ClassVar[float] = 1.0

@property
def computed(self) -> str:
return "i am a property"


@pytest.fixture(name="dto_type")
def fx_dto_type() -> type[DataclassDTO[Model]]:
Expand Down Expand Up @@ -68,8 +73,22 @@ def test_dataclass_field_definitions_38(dto_type: type[DataclassDTO[Model]]) ->
type_wrappers=ANY,
raw=ANY,
),
replace(
DTOFieldDefinition.from_field_definition(
field_definition=FieldDefinition.from_kwarg(
name="computed",
annotation=str,
),
default_factory=None,
model_name=Model.__name__,
dto_field=DTOField(),
),
metadata=ANY,
type_wrappers=ANY,
raw=ANY,
),
]
for field_def, exp in zip(dto_type.generate_field_definitions(Model), expected):
for field_def, exp in itertools.zip_longest(expected, dto_type.generate_field_definitions(Model), fillvalue=None):
assert field_def == exp


Expand Down Expand Up @@ -114,8 +133,22 @@ def test_dataclass_field_definitions(dto_type: type[DataclassDTO[Model]]) -> Non
type_wrappers=ANY,
raw=ANY,
),
replace(
DTOFieldDefinition.from_field_definition(
field_definition=FieldDefinition.from_kwarg(
name="computed",
annotation=str,
),
default_factory=None,
model_name=Model.__name__,
dto_field=DTOField(mark="read-only"),
),
metadata=ANY,
type_wrappers=ANY,
raw=ANY,
),
]
for field_def, exp in zip(dto_type.generate_field_definitions(Model), expected):
for field_def, exp in itertools.zip_longest(expected, dto_type.generate_field_definitions(Model), fillvalue=None):
assert field_def == exp


Expand All @@ -128,8 +161,6 @@ def test_dataclass_detect_nested(dto_type: type[DataclassDTO[Model]]) -> None:


def test_dataclass_dto_annotated_dto_field() -> None:
Annotated[int, DTOField("read-only")]

@dataclass
class Model:
a: Annotated[int, DTOField("read-only")]
Expand All @@ -139,3 +170,22 @@ class Model:
fields = list(dto_type.generate_field_definitions(Model))
assert fields[0].dto_field == DTOField("read-only")
assert fields[1].dto_field == DTOField("read-only")


def test_property_underscore_exclude() -> None:
@dataclass
class Model:
one: str

@property
def _computed(self) -> int:
return 1

@property
def __also_computed(self) -> int:
return 1

dto_type = DataclassDTO[Model]
fields = list(dto_type.generate_field_definitions(Model))
assert fields[0].name == "one"
assert len(fields) == 1

0 comments on commit bebc307

Please sign in to comment.