From ec6be11d4d80cd8bc73fa2acb358fd8e2ce5a032 Mon Sep 17 00:00:00 2001 From: Zoheb Shaikh Date: Wed, 7 Aug 2024 10:05:49 +0000 Subject: [PATCH] added more tests --- src/scanspec/core.py | 3 +-- src/scanspec/regions.py | 8 ++++++-- tests/test_serialization.py | 16 +++++++++++++++- 3 files changed, 22 insertions(+), 5 deletions(-) diff --git a/src/scanspec/core.py b/src/scanspec/core.py index 75565f20..d180d547 100644 --- a/src/scanspec/core.py +++ b/src/scanspec/core.py @@ -198,8 +198,7 @@ def add_member(self, cls: type): # called muliple times for the same member, do no process if it wouldn't # change the member list return - if cls is self._base_class: - return + self._members.append(cls) union = self._make_union() if union: diff --git a/src/scanspec/regions.py b/src/scanspec/regions.py index c1ef8e22..2a345c06 100644 --- a/src/scanspec/regions.py +++ b/src/scanspec/regions.py @@ -1,8 +1,8 @@ from __future__ import annotations from collections.abc import Iterator -from dataclasses import is_dataclass -from typing import Generic +from dataclasses import asdict, is_dataclass +from typing import Any, Generic, Mapping import numpy as np from pydantic import BaseModel, Field @@ -66,6 +66,10 @@ def __sub__(self, other) -> DifferenceOf[Axis]: def __xor__(self, other) -> SymmetricDifferenceOf[Axis]: return if_instance_do(other, Region, lambda o: SymmetricDifferenceOf(self, o)) + def serialize(self) -> Mapping[str, Any]: + """Serialize the Region to a dictionary.""" + return asdict(self) # type: ignore + @staticmethod def deserialize(obj): """Deserialize the Region from a dictionary.""" diff --git a/tests/test_serialization.py b/tests/test_serialization.py index adb5729d..89c2e7d7 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -4,7 +4,7 @@ import pytest from pydantic import ValidationError -from scanspec.regions import Circle, Rectangle, UnionOf +from scanspec.regions import Circle, Rectangle, Region, UnionOf from scanspec.specs import Line, Mask, Spec, Spiral @@ -15,6 +15,20 @@ def test_line_serializes() -> None: assert Spec.deserialize(serialized) == ob +def test_circle_serializes() -> None: + ob = Circle("x", "y", x_middle=0, y_middle=1, radius=4) + serialized = { + "x_axis": "x", + "y_axis": "y", + "x_middle": 0.0, + "y_middle": 1.0, + "radius": 4.0, + "type": "Circle", + } + assert ob.serialize() == serialized + assert Region.deserialize(serialized) == ob + + def test_masked_circle_serializes() -> None: ob = Mask(Line("x", 0, 1, 4), Circle("x", "y", x_middle=0, y_middle=1, radius=4)) serialized = {