Skip to content

Commit

Permalink
added more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ZohebShaikh committed Aug 7, 2024
1 parent a7c804b commit ec6be11
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 5 deletions.
3 changes: 1 addition & 2 deletions src/scanspec/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,7 @@ def add_member(self, cls: type):
# called muliple times for the same member, do no process if it wouldn't
# change the member list
return
if cls is self._base_class:
return

self._members.append(cls)
union = self._make_union()
if union:
Expand Down
8 changes: 6 additions & 2 deletions src/scanspec/regions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

from collections.abc import Iterator
from dataclasses import is_dataclass
from typing import Generic
from dataclasses import asdict, is_dataclass
from typing import Any, Generic, Mapping

import numpy as np
from pydantic import BaseModel, Field
Expand Down Expand Up @@ -66,6 +66,10 @@ def __sub__(self, other) -> DifferenceOf[Axis]:
def __xor__(self, other) -> SymmetricDifferenceOf[Axis]:
return if_instance_do(other, Region, lambda o: SymmetricDifferenceOf(self, o))

def serialize(self) -> Mapping[str, Any]:
"""Serialize the Region to a dictionary."""
return asdict(self) # type: ignore

@staticmethod
def deserialize(obj):
"""Deserialize the Region from a dictionary."""
Expand Down
16 changes: 15 additions & 1 deletion tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
from pydantic import ValidationError

from scanspec.regions import Circle, Rectangle, UnionOf
from scanspec.regions import Circle, Rectangle, Region, UnionOf
from scanspec.specs import Line, Mask, Spec, Spiral


Expand All @@ -15,6 +15,20 @@ def test_line_serializes() -> None:
assert Spec.deserialize(serialized) == ob


def test_circle_serializes() -> None:
ob = Circle("x", "y", x_middle=0, y_middle=1, radius=4)
serialized = {
"x_axis": "x",
"y_axis": "y",
"x_middle": 0.0,
"y_middle": 1.0,
"radius": 4.0,
"type": "Circle",
}
assert ob.serialize() == serialized
assert Region.deserialize(serialized) == ob


def test_masked_circle_serializes() -> None:
ob = Mask(Line("x", 0, 1, 4), Circle("x", "y", x_middle=0, y_middle=1, radius=4))
serialized = {
Expand Down

0 comments on commit ec6be11

Please sign in to comment.