diff --git a/src/scanspec/core.py b/src/scanspec/core.py index 75565f20..d180d547 100644 --- a/src/scanspec/core.py +++ b/src/scanspec/core.py @@ -198,8 +198,7 @@ def add_member(self, cls: type): # called muliple times for the same member, do no process if it wouldn't # change the member list return - if cls is self._base_class: - return + self._members.append(cls) union = self._make_union() if union: diff --git a/src/scanspec/regions.py b/src/scanspec/regions.py index c1ef8e22..2a345c06 100644 --- a/src/scanspec/regions.py +++ b/src/scanspec/regions.py @@ -1,8 +1,8 @@ from __future__ import annotations from collections.abc import Iterator -from dataclasses import is_dataclass -from typing import Generic +from dataclasses import asdict, is_dataclass +from typing import Any, Generic, Mapping import numpy as np from pydantic import BaseModel, Field @@ -66,6 +66,10 @@ def __sub__(self, other) -> DifferenceOf[Axis]: def __xor__(self, other) -> SymmetricDifferenceOf[Axis]: return if_instance_do(other, Region, lambda o: SymmetricDifferenceOf(self, o)) + def serialize(self) -> Mapping[str, Any]: + """Serialize the Region to a dictionary.""" + return asdict(self) # type: ignore + @staticmethod def deserialize(obj): """Deserialize the Region from a dictionary.""" diff --git a/tests/test_serialization.py b/tests/test_serialization.py index adb5729d..89c2e7d7 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -4,7 +4,7 @@ import pytest from pydantic import ValidationError -from scanspec.regions import Circle, Rectangle, UnionOf +from scanspec.regions import Circle, Rectangle, Region, UnionOf from scanspec.specs import Line, Mask, Spec, Spiral @@ -15,6 +15,20 @@ def test_line_serializes() -> None: assert Spec.deserialize(serialized) == ob +def test_circle_serializes() -> None: + ob = Circle("x", "y", x_middle=0, y_middle=1, radius=4) + serialized = { + "x_axis": "x", + "y_axis": "y", + "x_middle": 0.0, + "y_middle": 1.0, + "radius": 4.0, + "type": "Circle", + } + assert ob.serialize() == serialized + assert Region.deserialize(serialized) == ob + + def test_masked_circle_serializes() -> None: ob = Mask(Line("x", 0, 1, 4), Circle("x", "y", x_middle=0, y_middle=1, radius=4)) serialized = {