From 73e0b892baa935cdd8099e2c0ce4d9ac060536b3 Mon Sep 17 00:00:00 2001 From: Eva Lott Date: Tue, 10 Dec 2024 11:51:02 +0000 Subject: [PATCH] Introduced enum to tango and rest --- src/fastcs/datatypes.py | 28 ++++++++++------------ src/fastcs/transport/tango/dsr.py | 14 ++++++----- src/fastcs/transport/tango/util.py | 37 ++++++++++++++++++++++++++++++ tests/conftest.py | 10 ++++---- tests/transport/epics/test_ioc.py | 28 +++++++++++----------- tests/transport/rest/test_rest.py | 22 +++++++++++------- tests/transport/tango/test_dsr.py | 27 ++++++++++++++-------- 7 files changed, 106 insertions(+), 60 deletions(-) create mode 100644 src/fastcs/transport/tango/util.py diff --git a/src/fastcs/datatypes.py b/src/fastcs/datatypes.py index a81062e4..8f119f38 100644 --- a/src/fastcs/datatypes.py +++ b/src/fastcs/datatypes.py @@ -3,11 +3,11 @@ import enum from abc import abstractmethod from collections.abc import Awaitable, Callable -from dataclasses import MISSING, dataclass, field +from dataclasses import dataclass, field from functools import cached_property from typing import Generic, TypeVar -T = TypeVar("T", int, float, bool, str, enum.Enum) +T = TypeVar("T", int, float, bool, str, enum.IntEnum) ATTRIBUTE_TYPES: tuple[type] = T.__constraints__ # type: ignore @@ -115,35 +115,31 @@ def dtype(self) -> type[str]: return str -T_Enum = TypeVar("T_Enum", bound=enum.Enum) +T_Enum = TypeVar("T_Enum", bound=enum.IntEnum) @dataclass(frozen=True) -class Enum(DataType[enum.Enum]): - enum_cls: type[enum.Enum] +class Enum(DataType[enum.IntEnum]): + enum_cls: type[enum.IntEnum] @cached_property def is_string_enum(self) -> bool: return all(isinstance(member.value, str) for member in self.members) - @cached_property - def is_int_enum(self) -> bool: - return all(isinstance(member.value, int) for member in self.members) - def __post_init__(self): - if not issubclass(self.enum_cls, enum.Enum): - raise ValueError("Enum class has to take an enum.") - if not (self.is_string_enum or self.is_int_enum): - raise ValueError("All enum values must be of type str or int.") + if not issubclass(self.enum_cls, enum.IntEnum): + raise ValueError("Enum class has to take an IntEnum.") + if {member.value for member in self.members} != set(range(len(self.members))): + raise ValueError("Enum values must be contiguous.") @cached_property - def members(self) -> list[enum.Enum]: + def members(self) -> list[enum.IntEnum]: return list(self.enum_cls) @property - def dtype(self) -> type[enum.Enum]: + def dtype(self) -> type[enum.IntEnum]: return self.enum_cls @property - def initial_value(self) -> enum.Enum: + def initial_value(self) -> enum.IntEnum: return self.members[0] diff --git a/src/fastcs/transport/tango/dsr.py b/src/fastcs/transport/tango/dsr.py index 9b7f6f71..493fbb5c 100644 --- a/src/fastcs/transport/tango/dsr.py +++ b/src/fastcs/transport/tango/dsr.py @@ -10,14 +10,17 @@ from fastcs.datatypes import Float from .options import TangoDSROptions +from .util import get_cast_method_from_tango_type, get_cast_method_to_tango_type def _wrap_updater_fget( attr_name: str, attribute: AttrR, controller: BaseController ) -> Callable[[Any], Any]: + cast_method = get_cast_method_to_tango_type(attribute.datatype) + async def fget(tango_device: Device): tango_device.info_stream(f"called fget method: {attr_name}") - return attribute.get() + return cast_method(attribute.get()) return fget @@ -33,9 +36,11 @@ def _tango_display_format(attribute: Attribute) -> str: def _wrap_updater_fset( attr_name: str, attribute: AttrW, controller: BaseController ) -> Callable[[Any, Any], Any]: + cast_method = get_cast_method_from_tango_type(attribute.datatype) + async def fset(tango_device: Device, val): tango_device.info_stream(f"called fset method: {attr_name}") - await attribute.process(val) + await attribute.process(cast_method(val)) return fset @@ -179,10 +184,7 @@ def run(self, options: TangoDSROptions | None = None) -> None: def register_dev(dev_name: str, dev_class: str, dsr_instance: str) -> None: dsr_name = f"{dev_class}/{dsr_instance}" - dev_info = DbDevInfo() - dev_info.name = dev_name - dev_info._class = dev_class # noqa - dev_info.server = dsr_name + dev_info = DbDevInfo(dev_name, dev_class, dsr_name) db = Database() db.delete_device(dev_name) # Remove existing device entry diff --git a/src/fastcs/transport/tango/util.py b/src/fastcs/transport/tango/util.py new file mode 100644 index 00000000..d7166d89 --- /dev/null +++ b/src/fastcs/transport/tango/util.py @@ -0,0 +1,37 @@ +from collections.abc import Callable + +from fastcs.datatypes import Bool, DataType, Enum, Float, Int, String, T + +TANGO_ALLOWED_DATATYPES = (Bool, DataType, Enum, Float, Int, String) + + +def get_cast_method_to_tango_type(datatype: DataType[T]) -> Callable[[T], object]: + match datatype: + case Enum(): + + def cast_to_tango_type(value) -> int: + return datatype.validate(value).value + case datatype if issubclass(type(datatype), TANGO_ALLOWED_DATATYPES): + + def cast_to_tango_type(value) -> object: + return datatype.validate(value) + case _: + raise ValueError(f"Unsupported datatype {datatype}") + return cast_to_tango_type + + +def get_cast_method_from_tango_type(datatype: DataType[T]) -> Callable[[object], T]: + match datatype: + case Enum(enum_cls): + + def cast_from_tango_type(value: object) -> T: + return datatype.validate(enum_cls(value)) + + case datatype if issubclass(type(datatype), TANGO_ALLOWED_DATATYPES): + + def cast_from_tango_type(value) -> T: + return datatype.validate(value) + case _: + raise ValueError(f"Unsupported datatype {datatype}") + + return cast_from_tango_type diff --git a/tests/conftest.py b/tests/conftest.py index 5d1c18cc..b8c86560 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,10 @@ import copy +import enum import os import random import string import subprocess import time -import enum from contextlib import contextmanager from pathlib import Path from typing import Any, Literal @@ -15,7 +15,7 @@ from fastcs.attributes import AttrR, AttrRW, AttrW, Handler, Sender, Updater from fastcs.controller import Controller, SubController -from fastcs.datatypes import Bool, Float, Int, String, Enum +from fastcs.datatypes import Bool, Enum, Float, Int, String from fastcs.wrappers import command, scan DATA_PATH = Path(__file__).parent / "data" @@ -80,12 +80,10 @@ def __init__(self) -> None: read_bool: AttrR = AttrR(Bool()) write_bool: AttrW = AttrW(Bool(), handler=TestSender()) read_string: AttrRW = AttrRW(String()) - enum: AttrRW = AttrRW( - Enum(enum.Enum("Enum", {"RED": "red", "GREEN": "green", "BLUE": "blue"})) - ) + enum: AttrRW = AttrRW(Enum(enum.IntEnum("Enum", {"RED": 0, "GREEN": 1, "BLUE": 2}))) big_enum: AttrR = AttrR( Int( - allowed_values=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], + allowed_values=list(range(17)), ), ) diff --git a/tests/transport/epics/test_ioc.py b/tests/transport/epics/test_ioc.py index 04f9b4be..bf1f789e 100644 --- a/tests/transport/epics/test_ioc.py +++ b/tests/transport/epics/test_ioc.py @@ -22,8 +22,8 @@ ) from fastcs.transport.epics.util import ( MBB_STATE_FIELDS, - get_record_metadata_from_datatype, get_record_metadata_from_attribute, + get_record_metadata_from_datatype, ) DEVICE = "DEVICE" @@ -31,14 +31,14 @@ SEVENTEEN_VALUES = [str(i) for i in range(1, 18)] -class OnOffStates(enum.Enum): - DISABLED = "disabled" - ENABLED = "enabled" +class OnOffStates(enum.IntEnum): + DISABLED = 0 + ENABLED = 1 -def record_input_from_enum(enum_cls: type[enum.Enum]) -> dict[str, str]: +def record_input_from_enum(enum_cls: type[enum.IntEnum]) -> dict[str, str]: return dict( - zip(MBB_STATE_FIELDS, [member.value for member in enum_cls], strict=False) + zip(MBB_STATE_FIELDS, [member.name for member in enum_cls], strict=False) ) @@ -64,10 +64,10 @@ async def test_create_and_link_read_pv(mocker: MockerFixture): record.set.assert_called_once_with(1) -class StringEnum(enum.Enum): - RED = "RED" - GREEN = "GREEN" - BLUE = "BLUE" +class ColourEnum(enum.IntEnum): + RED = 0 + GREEN = 1 + BLUE = 2 @pytest.mark.parametrize( @@ -75,17 +75,17 @@ class StringEnum(enum.Enum): ( (AttrR(String()), "longStringIn", {}), ( - AttrR(String(allowed_values=[member.value for member in StringEnum])), + AttrR(String(allowed_values=[member.name for member in list(ColourEnum)])), "longStringIn", {}, ), ( - AttrR(Enum(StringEnum)), + AttrR(Enum(ColourEnum)), "mbbIn", {"ZRST": "RED", "ONST": "GREEN", "TWST": "BLUE"}, ), ( - AttrR(Enum(enum.Enum("ONOFF_STATES", {"DISABLED": 0, "ENABLED": 1}))), + AttrR(Enum(enum.IntEnum("ONOFF_STATES", {"DISABLED": 0, "ENABLED": 1}))), "mbbIn", {"ZRST": "DISABLED", "ONST": "ENABLED"}, ), @@ -151,7 +151,7 @@ async def test_create_and_link_write_pv(mocker: MockerFixture): "attribute,record_type,kwargs", ( ( - AttrR(Enum(enum.Enum("ONOFF_STATES", {"DISABLED": 0, "ENABLED": 1}))), + AttrR(Enum(enum.IntEnum("ONOFF_STATES", {"DISABLED": 0, "ENABLED": 1}))), "mbbOut", {"ZRST": "DISABLED", "ONST": "ENABLED"}, ), diff --git a/tests/transport/rest/test_rest.py b/tests/transport/rest/test_rest.py index cfbb7dc2..de75fc4b 100644 --- a/tests/transport/rest/test_rest.py +++ b/tests/transport/rest/test_rest.py @@ -51,16 +51,22 @@ def test_write_bool(self, assertable_controller, client): with assertable_controller.assert_write_here(["write_bool"]): client.put("/write-bool", json={"value": True}) - def test_string_enum(self, assertable_controller, client): - expect = "" - with assertable_controller.assert_read_here(["string_enum"]): - response = client.get("/string-enum") + def test_enum(self, assertable_controller, client): + enum_attr = assertable_controller.attributes["enum"] + enum_cls = enum_attr.datatype.dtype + assert isinstance(enum_attr.get(), enum_cls) + assert enum_attr.get() == enum_cls(0) + expect = 0 + with assertable_controller.assert_read_here(["enum"]): + response = client.get("/enum") assert response.status_code == 200 assert response.json()["value"] == expect - new = "new" - with assertable_controller.assert_write_here(["string_enum"]): - response = client.put("/string-enum", json={"value": new}) - assert client.get("/string-enum").json()["value"] == new + new = 2 + with assertable_controller.assert_write_here(["enum"]): + response = client.put("/enum", json={"value": new}) + assert client.get("/enum").json()["value"] == new + assert isinstance(enum_attr.get(), enum_cls) + assert enum_attr.get() == enum_cls(2) def test_big_enum(self, assertable_controller, client): expect = 0 diff --git a/tests/transport/tango/test_dsr.py b/tests/transport/tango/test_dsr.py index 3fbe8931..bd8eead7 100644 --- a/tests/transport/tango/test_dsr.py +++ b/tests/transport/tango/test_dsr.py @@ -5,7 +5,7 @@ from fastcs.transport.tango.adapter import TangoTransport -class TestTangoDevice: +class TestTangoContext: @pytest.fixture(scope="class") def tango_context(self, assertable_controller): # https://tango-controls.readthedocs.io/projects/pytango/en/v9.5.1/testing/test_context.html @@ -16,11 +16,12 @@ def tango_context(self, assertable_controller): def test_list_attributes(self, tango_context): assert list(tango_context.get_attribute_list()) == [ "BigEnum", + "Enum", "ReadBool", "ReadInt", + "ReadString", "ReadWriteFloat", "ReadWriteInt", - "StringEnum", "WriteBool", "SubController01_ReadInt", "SubController02_ReadInt", @@ -79,15 +80,21 @@ def test_write_bool(self, assertable_controller, tango_context): with assertable_controller.assert_write_here(["write_bool"]): tango_context.write_attribute("WriteBool", True) - def test_string_enum(self, assertable_controller, tango_context): - expect = "" - with assertable_controller.assert_read_here(["string_enum"]): - result = tango_context.read_attribute("StringEnum").value + def test_enum(self, assertable_controller, tango_context): + enum_attr = assertable_controller.attributes["enum"] + enum_cls = enum_attr.datatype.dtype + assert isinstance(enum_attr.get(), enum_cls) + assert enum_attr.get() == enum_cls(0) + expect = 0 + with assertable_controller.assert_read_here(["enum"]): + result = tango_context.read_attribute("Enum").value assert result == expect - new = "new" - with assertable_controller.assert_write_here(["string_enum"]): - tango_context.write_attribute("StringEnum", new) - assert tango_context.read_attribute("StringEnum").value == new + new = 1 + with assertable_controller.assert_write_here(["enum"]): + tango_context.write_attribute("Enum", new) + assert tango_context.read_attribute("Enum").value == new + assert isinstance(enum_attr.get(), enum_cls) + assert enum_attr.get() == enum_cls(1) def test_big_enum(self, assertable_controller, tango_context): expect = 0