diff --git a/src/scanspec/core.py b/src/scanspec/core.py index 39f2a57a..92fe5b24 100644 --- a/src/scanspec/core.py +++ b/src/scanspec/core.py @@ -1,17 +1,12 @@ from __future__ import annotations -import dataclasses from collections.abc import Callable, Iterable, Iterator, Sequence -from functools import partial -from inspect import isclass +from functools import lru_cache from typing import ( Any, Generic, Literal, TypeVar, - Union, - get_origin, - get_type_hints, ) import numpy as np @@ -19,10 +14,9 @@ ConfigDict, Field, GetCoreSchemaHandler, - TypeAdapter, ) from pydantic.dataclasses import rebuild_dataclass -from pydantic.fields import FieldInfo +from pydantic_core.core_schema import tagged_union_schema __all__ = [ "if_instance_do", @@ -43,13 +37,13 @@ def discriminated_union_of_subclasses( - cls, + cls: type, discriminator: str = "type", -): +) -> type: """Add all subclasses of super_cls to a discriminated union. For all subclasses of super_cls, add a discriminator field to identify - the type. Raw JSON should look like {"type": <type name>, params for + the type. Raw JSON should look like {<discriminator>: <type name>, params for <type name>...}. Example:: @@ -107,131 +101,69 @@ def calculate(self) -> int: super_cls: The superclass of the union, Expression in the above example discriminator: The discriminator that will be inserted into the serialized documents for type determination. Defaults to "type". - config: A pydantic config class to be inserted into all - subclasses. Defaults to None. Returns: - Type | Callable[[Type], Type]: A decorator that adds the necessary + Type: A decorator that adds the necessary functionality to a class. 
""" tagged_union = _TaggedUnion(cls, discriminator) - _tagged_unions[cls] = tagged_union - cls.__init_subclass__ = classmethod(partial(__init_subclass__, discriminator)) - cls.__get_pydantic_core_schema__ = classmethod( - partial(__get_pydantic_core_schema__, tagged_union=tagged_union) - ) - return cls + def add_subclass_to_union(subclass): + # Add a discriminator field to a subclass so it can + # be identified when deserializing + subclass.__annotations__ = { + **subclass.__annotations__, + discriminator: Literal[subclass.__name__], # type: ignore + } + setattr(subclass, discriminator, Field(subclass.__name__, repr=False)) # type: ignore -T = TypeVar("T", type, Callable) + def default_handler(subclass, source_type: Any, handler: GetCoreSchemaHandler): + tagged_union.add_member(subclass) + return handler(subclass) + subclass.__get_pydantic_core_schema__ = classmethod(default_handler) -def deserialize_as(cls, obj): - return _tagged_unions[cls].type_adapter.validate_python(obj) + def get_schema_of_union(cls, source_type: Any, handler: GetCoreSchemaHandler): + # Rebuild any dataclass (including this one) that references this union + # Note that this has to be done after the creation of the dataclass so that + # previously created classes can refer to this newly created class + return tagged_union.schema(handler) + cls.__init_subclass__ = classmethod(add_subclass_to_union) + cls.__get_pydantic_core_schema__ = classmethod(get_schema_of_union) + return cls -def uses_tagged_union(cls_or_func: T) -> T: - """ - Decorator that processes the type hints of a class or function to detect and - register any tagged unions. If a tagged union is detected in the type hints, - it registers the class or function as a referrer to that tagged union. - Args: - cls_or_func (T): The class or function to be processed for tagged unions. - Returns: - T: The original class or function, unmodified. 
- """ - for k, v in get_type_hints(cls_or_func).items(): - tagged_union = _tagged_unions.get(get_origin(v) or v, None) - if tagged_union: - tagged_union.add_referrer(cls_or_func, k) - return cls_or_func + +T = TypeVar("T", type, Callable) class _TaggedUnion: def __init__(self, base_class: type, discriminator: str): self._base_class = base_class - # The members of the tagged union, i.e. subclasses of the baseclasses - self._members: list[type] = [] # Classes and their field names that refer to this tagged union - self._referrers: dict[type | Callable, set[str]] = {} - self.type_adapter: TypeAdapter = TypeAdapter(None) self._discriminator = discriminator - - def _make_union(self): - if len(self._members) > 0: - return Union[tuple(self._members)] # type: ignore # noqa - - def _set_discriminator(self, cls: type | Callable, field_name: str, field: Any): - # Set the field to use the `type` discriminator on deserialize - # https://docs.pydantic.dev/2.8/concepts/unions/#discriminated-unions-with-str-discriminators - if isclass(cls): - assert isinstance( - field, FieldInfo - ), f"Expected {cls.__name__}.{field_name} to be a Pydantic field, not {field!r}" # noqa: E501 - field.discriminator = self._discriminator + # The members of the tagged union, i.e. 
subclasses of the baseclass + self._members: list[type] = [] def add_member(self, cls: type): if cls in self._members: - # A side effect of hooking to __get_pydantic_core_schema__ is that it is - # called muliple times for the same member, do no process if it wouldn't - # change the member list return - self._members.append(cls) - union = self._make_union() - if union: - # There are more than 1 subclasses in the union, so set all the referrers - # to use this union - for referrer, fields in self._referrers.items(): - if isclass(referrer): - for field in dataclasses.fields(referrer): - if field.name in fields: - field.type = union - self._set_discriminator(referrer, field.name, field.default) - rebuild_dataclass(referrer, force=True) - # Make a type adapter for use in deserialization - self.type_adapter = TypeAdapter(union) - - def add_referrer(self, cls: type | Callable, attr_name: str): - self._referrers.setdefault(cls, set()).add(attr_name) - union = self._make_union() - if union: - # There are more than 1 subclasses in the union, so set the referrer - # (which is currently being constructed) to use it - # note that we use annotations as the class has not been turned into - # a dataclass yet - cls.__annotations__[attr_name] = union - self._set_discriminator(cls, attr_name, getattr(cls, attr_name, None)) - - -_tagged_unions: dict[type, _TaggedUnion] = {} - - -def __init_subclass__(discriminator: str, cls: type): - # Add a discriminator field to the class so it can - # be identified when deserailizing, and make sure it is last in the list - cls.__annotations__ = { - **cls.__annotations__, - discriminator: Literal[cls.__name__], # type: ignore - } - cls.type = Field(cls.__name__, repr=False) # type: ignore - # Replace any bare annotation with a discriminated union of subclasses - # and register this class as one that refers to that union so it can be updated - for k, v in get_type_hints(cls).items(): - # This works for Expression[T] or Expression - tagged_union = 
_tagged_unions.get(get_origin(v) or v, None) - if tagged_union: - tagged_union.add_referrer(cls, k) - - -def __get_pydantic_core_schema__( - cls, source_type: Any, handler: GetCoreSchemaHandler, tagged_union: _TaggedUnion -): - # Rebuild any dataclass (including this one) that references this union - # Note that this has to be done after the creation of the dataclass so that - # previously created classes can refer to this newly created class - tagged_union.add_member(cls) - return handler(source_type) + for member in self._members: + if member != cls: + rebuild_dataclass(member, force=True) + + def schema(self, handler): + return tagged_union_schema( + make_schema(tuple(self._members), handler), + discriminator=self._discriminator, + ref=self._base_class.__name__, + ) + + +@lru_cache(1) +def make_schema(members: tuple[type, ...], handler): + return {member.__name__: handler(member) for member in members} def if_instance_do(x: Any, cls: type, func: Callable): diff --git a/src/scanspec/regions.py b/src/scanspec/regions.py index 0c9b0b1f..acf1237a 100644 --- a/src/scanspec/regions.py +++ b/src/scanspec/regions.py @@ -5,14 +5,13 @@ from typing import Any, Generic import numpy as np -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, TypeAdapter from pydantic.dataclasses import dataclass from .core import ( AxesPoints, Axis, StrictConfig, - deserialize_as, discriminated_union_of_subclasses, if_instance_do, ) @@ -71,9 +70,8 @@ def serialize(self) -> Mapping[str, Any]: return asdict(self) # type: ignore @staticmethod - def deserialize(obj): - """Deserialize the Region from a dictionary.""" - return deserialize_as(Region, obj) + def deserialize(obj: Mapping[str, Any]) -> Region: + return TypeAdapter(Region).validate_python(obj) def get_mask(region: Region[Axis], points: AxesPoints[Axis]) -> np.ndarray: diff --git a/src/scanspec/service.py b/src/scanspec/service.py index e52ef5d9..ac8a147f 100644 --- a/src/scanspec/service.py +++ 
b/src/scanspec/service.py @@ -11,7 +11,7 @@ from pydantic import Field from pydantic.dataclasses import dataclass -from scanspec.core import AxesPoints, Frames, Path, uses_tagged_union +from scanspec.core import AxesPoints, Frames, Path from .specs import Line, Spec @@ -27,7 +27,6 @@ @dataclass -@uses_tagged_union class ValidResponse: """Response model for spec validation.""" @@ -44,7 +43,6 @@ class PointsFormat(str, Enum): @dataclass -@uses_tagged_union class PointsRequest: """A request for generated scan points.""" @@ -125,7 +123,6 @@ class SmallestStepResponse: @app.post("/valid", response_model=ValidResponse) -@uses_tagged_union def valid( spec: Spec = Body(..., examples=[_EXAMPLE_SPEC]), ) -> ValidResponse | JSONResponse: @@ -198,7 +195,6 @@ def bounds( @app.post("/gap", response_model=GapResponse) -@uses_tagged_union def gap( spec: Spec = Body( ..., @@ -224,7 +220,6 @@ def gap( @app.post("/smalleststep", response_model=SmallestStepResponse) -@uses_tagged_union def smallest_step( spec: Spec = Body(..., examples=[_EXAMPLE_SPEC]), ) -> SmallestStepResponse: diff --git a/src/scanspec/specs.py b/src/scanspec/specs.py index d51e5a6f..74c4d262 100644 --- a/src/scanspec/specs.py +++ b/src/scanspec/specs.py @@ -8,7 +8,7 @@ ) import numpy as np -from pydantic import Field, validate_call +from pydantic import Field, TypeAdapter, validate_call from pydantic.dataclasses import dataclass from .core import ( @@ -18,7 +18,6 @@ Path, SnakedFrames, StrictConfig, - deserialize_as, discriminated_union_of_subclasses, gap_between_frames, if_instance_do, @@ -107,13 +106,12 @@ def concat(self, other: Spec) -> Concat[Axis]: return Concat(self, other) def serialize(self) -> Mapping[str, Any]: - """Serialize the spec to a dictionary.""" + """Serialize the Spec to a dictionary.""" return asdict(self) # type: ignore @staticmethod - def deserialize(obj): - """Deserialize the spec from a dictionary.""" - return deserialize_as(Spec, obj) + def deserialize(obj: Mapping[str, Any]) -> Spec: + 
return TypeAdapter(Spec).validate_python(obj) @dataclass(config=StrictConfig) diff --git a/tests/test_basemodel.py b/tests/test_basemodel.py new file mode 100644 index 00000000..3f42e798 --- /dev/null +++ b/tests/test_basemodel.py @@ -0,0 +1,40 @@ +import pytest +from pydantic import BaseModel, TypeAdapter + +from scanspec.specs import Line, Spec + + +class Foo(BaseModel): + spec: Spec + + +simple_foo = Foo(spec=Line("x", 1, 2, 5)) +nested_foo = Foo(spec=(Line("x", 1, 2, 5) + Line("x", 1, 2, 5)) * Line("y", 1, 2, 5)) + + +@pytest.mark.parametrize("model", [simple_foo, nested_foo]) +def test_model_validation(model: Foo): + # To/from Python dict + as_dict = model.model_dump() + deserialized = Foo.model_validate(as_dict) + assert deserialized == model + + # To/from Json dict + as_json = model.model_dump_json() + deserialized = Foo.model_validate_json(as_json) + assert deserialized == model + + +@pytest.mark.parametrize("model", [simple_foo, nested_foo]) +def test_type_adapter(model: Foo): + type_adapter = TypeAdapter(Foo) + + # To/from Python dict + as_dict = model.model_dump() + deserialized = type_adapter.validate_python(as_dict) + assert deserialized == model + + # To/from Json dict + as_json = model.model_dump_json() + deserialized = type_adapter.validate_json(as_json) + assert deserialized == model