Skip to content

Commit

Permalink
removed allowed_values and improved Enum code
Browse files Browse the repository at this point in the history
In both `EPICS` and `Tango`, now the index of the `Enum` member will be used instead of the `value`. The labels will always be the enum member `name`.
  • Loading branch information
evalott100 committed Jan 17, 2025
1 parent 0e7a5f2 commit 00f6389
Show file tree
Hide file tree
Showing 13 changed files with 169 additions and 351 deletions.
7 changes: 4 additions & 3 deletions src/fastcs/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,10 @@ def __init__(
handler: Any = None,
description: str | None = None,
) -> None:
assert issubclass(
datatype.dtype, ATTRIBUTE_TYPES
), f"Attr type must be one of {ATTRIBUTE_TYPES}, received type {datatype.dtype}"
assert issubclass(datatype.dtype, ATTRIBUTE_TYPES), (
f"Attr type must be one of {ATTRIBUTE_TYPES}, "
"received type {datatype.dtype}"
)
self._datatype: DataType[T] = datatype
self._access_mode: AttrMode = access_mode
self._group = group
Expand Down
6 changes: 3 additions & 3 deletions src/fastcs/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,9 @@ def _link_attribute_sender_class(single_mapping: SingleMapping) -> None:
for attr_name, attribute in single_mapping.attributes.items():
match attribute:
case AttrW(sender=Sender()):
assert (
not attribute.has_process_callback()
), f"Cannot assign both put method and Sender object to {attr_name}"
assert not attribute.has_process_callback(), (
f"Cannot assign both put method and Sender object to {attr_name}"
)

callback = _create_sender_callback(attribute, single_mapping.controller)
attribute.set_process_callback(callback)
Expand Down
52 changes: 14 additions & 38 deletions src/fastcs/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
import enum
from abc import abstractmethod
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from dataclasses import dataclass
from functools import cached_property
from typing import Generic, TypeVar

import numpy as np

T = TypeVar("T", int, float, bool, str, enum.IntEnum, np.ndarray)
T = TypeVar("T", int, float, bool, str, enum.Enum, np.ndarray)

ATTRIBUTE_TYPES: tuple[type] = T.__constraints__ # type: ignore

Expand All @@ -21,10 +21,6 @@
class DataType(Generic[T]):
"""Generic datatype mapping to a python type, with additional metadata."""

# We move this to each datatype so that we can have positional
# args in subclasses.
allowed_values: list[T] | None = field(init=False, default=None)

@property
@abstractmethod
def dtype(self) -> type[T]: # Using property due to lack of Generic ClassVars
Expand All @@ -34,15 +30,7 @@ def validate(self, value: T) -> T:
"""Validate a value against fields in the datatype."""
if not isinstance(value, self.dtype):
raise ValueError(f"Value {value} is not of type {self.dtype}")
if (
hasattr(self, "allowed_values")
and self.allowed_values is not None
and value not in self.allowed_values
):
raise ValueError(
f"Value {value} is not in the allowed values for this "
f"datatype {self.allowed_values}."
)

return value

@property
Expand Down Expand Up @@ -79,8 +67,6 @@ def initial_value(self) -> T_Numerical:
class Int(_Numerical[int]):
"""`DataType` mapping to builtin ``int``."""

allowed_values: list[int] | None = None

@property
def dtype(self) -> type[int]:
return int
Expand All @@ -91,7 +77,6 @@ class Float(_Numerical[float]):
"""`DataType` mapping to builtin ``float``."""

prec: int = 2
allowed_values: list[float] | None = None

@property
def dtype(self) -> type[float]:
Expand All @@ -102,10 +87,6 @@ def dtype(self) -> type[float]:
class Bool(DataType[bool]):
"""`DataType` mapping to builtin ``bool``."""

znam: str = "OFF"
onam: str = "ON"
allowed_values: list[bool] | None = None

@property
def dtype(self) -> type[bool]:
return bool
Expand All @@ -119,8 +100,6 @@ def initial_value(self) -> bool:
class String(DataType[str]):
"""`DataType` mapping to builtin ``str``."""

allowed_values: list[str] | None = None

@property
def dtype(self) -> type[str]:
return str
Expand All @@ -130,33 +109,30 @@ def initial_value(self) -> str:
return ""


T_Enum = TypeVar("T_Enum", bound=enum.IntEnum)
T_Enum = TypeVar("T_Enum", bound=enum.Enum)


@dataclass(frozen=True)
class Enum(DataType[enum.IntEnum]):
enum_cls: type[enum.IntEnum]

@cached_property
def is_string_enum(self) -> bool:
return all(isinstance(member.value, str) for member in self.members)
class Enum(Generic[T_Enum], DataType[T_Enum]):
enum_cls: type[T_Enum]

def __post_init__(self):
if not issubclass(self.enum_cls, enum.IntEnum):
raise ValueError("Enum class has to take an IntEnum.")
if {member.value for member in self.members} != set(range(len(self.members))):
raise ValueError("Enum values must be contiguous.")
if not issubclass(self.enum_cls, enum.Enum):
raise ValueError("Enum class has to take an Enum.")

def index_of(self, value: T_Enum) -> int:
return self.members.index(value)

@cached_property
def members(self) -> list[enum.IntEnum]:
def members(self) -> list[T_Enum]:
return list(self.enum_cls)

@property
def dtype(self) -> type[enum.IntEnum]:
def dtype(self) -> type[T_Enum]:
return self.enum_cls

@property
def initial_value(self) -> enum.IntEnum:
def initial_value(self) -> T_Enum:
return self.members[0]


Expand Down
182 changes: 29 additions & 153 deletions src/fastcs/transport/epics/ioc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,13 @@

from fastcs.attributes import AttrR, AttrRW, AttrW
from fastcs.controller import BaseController, Controller
from fastcs.datatypes import Bool, DataType, Enum, Float, Int, String, T, WaveForm
from fastcs.exceptions import FastCSException
from fastcs.datatypes import DataType, T
from fastcs.transport.epics.util import (
MBB_MAX_CHOICES,
MBB_STATE_FIELDS,
get_cast_method_from_epics_type,
get_cast_method_to_epics_type,
get_record_metadata_from_attribute,
get_record_metadata_from_datatype,
builder_callable_from_attribute,
get_callable_from_epics_type,
get_callable_to_epics_type,
record_metadata_from_attribute,
record_metadata_from_datatype,
)

from .options import EpicsIOCOptions
Expand Down Expand Up @@ -160,76 +158,34 @@ def _create_and_link_attribute_pvs(pv_prefix: str, controller: Controller) -> No
def _create_and_link_read_pv(
pv_prefix: str, pv_name: str, attr_name: str, attribute: AttrR[T]
) -> None:
cast_method = get_cast_method_to_epics_type(attribute.datatype)
cast_to_epics_type = get_callable_to_epics_type(attribute.datatype)

async def async_record_set(value: T):
record.set(cast_method(value))
record.set(cast_to_epics_type(value))

record = _get_input_record(f"{pv_prefix}:{pv_name}", attribute)
record = _make_record(f"{pv_prefix}:{pv_name}", attribute)
_add_attr_pvi_info(record, pv_prefix, attr_name, "r")

attribute.set_update_callback(async_record_set)


def _get_input_record(pv: str, attribute: AttrR) -> RecordWrapper:
match attribute.datatype:
case Bool():
record = builder.boolIn(
pv,
**get_record_metadata_from_datatype(attribute.datatype),
**get_record_metadata_from_attribute(attribute),
)
case Int():
record = builder.longIn(
pv,
**get_record_metadata_from_datatype(attribute.datatype),
**get_record_metadata_from_attribute(attribute),
)
case Float():
record = builder.aIn(
pv,
**get_record_metadata_from_datatype(attribute.datatype),
**get_record_metadata_from_attribute(attribute),
)
case String():
record = builder.longStringIn(
pv,
**get_record_metadata_from_datatype(attribute.datatype),
**get_record_metadata_from_attribute(attribute),
)
case Enum():
if len(attribute.datatype.members) > MBB_MAX_CHOICES:
raise RuntimeError(
f"Received an `Enum` datatype on attribute {attribute} "
f"with more elements than the epics limit `{MBB_MAX_CHOICES}` "
f"for `mbbIn`. Use an `Int or `String with `allowed_values`."
)
state_keys = dict(
zip(
MBB_STATE_FIELDS,
[member.name for member in attribute.datatype.members],
strict=False,
)
)
record = builder.mbbIn(
pv,
**state_keys,
**get_record_metadata_from_datatype(attribute.datatype),
**get_record_metadata_from_attribute(attribute),
)
case WaveForm():
record = builder.WaveformIn(
pv,
**get_record_metadata_from_datatype(attribute.datatype),
**get_record_metadata_from_attribute(attribute),
)
case _:
raise FastCSException(
f"Unsupported type {type(attribute.datatype)}: {attribute.datatype}"
)
def _make_record(
pv: str,
attribute: AttrR | AttrW | AttrRW,
on_update: Callable | None = None,
) -> RecordWrapper:
builder_callable = builder_callable_from_attribute(attribute, on_update is None)
datatype_record_metadata = record_metadata_from_datatype(attribute.datatype)
attribute_record_metadata = record_metadata_from_attribute(attribute)

update = {"always_update": True, "on_update": on_update} if on_update else {}

record = builder_callable(
pv, **update, **datatype_record_metadata, **attribute_record_metadata
)

def datatype_updater(datatype: DataType):
for name, value in get_record_metadata_from_datatype(datatype).items():
for name, value in record_metadata_from_datatype(datatype).items():
record.set_field(name, value)

attribute.add_update_datatype_callback(datatype_updater)
Expand All @@ -239,102 +195,22 @@ def datatype_updater(datatype: DataType):
def _create_and_link_write_pv(
pv_prefix: str, pv_name: str, attr_name: str, attribute: AttrW[T]
) -> None:
cast_method = get_cast_method_from_epics_type(attribute.datatype)
cast_from_epics_type = get_callable_from_epics_type(attribute.datatype)
cast_to_epics_type = get_callable_to_epics_type(attribute.datatype)

async def on_update(value):
await attribute.process_without_display_update(cast_method(value))
await attribute.process_without_display_update(cast_from_epics_type(value))

async def async_write_display(value: T):
record.set(cast_method(value), process=False)
record.set(cast_to_epics_type(value), process=False)

record = _get_output_record(
f"{pv_prefix}:{pv_name}", attribute, on_update=on_update
)
record = _make_record(f"{pv_prefix}:{pv_name}", attribute, on_update=on_update)

_add_attr_pvi_info(record, pv_prefix, attr_name, "w")

attribute.set_write_display_callback(async_write_display)


def _get_output_record(pv: str, attribute: AttrW, on_update: Callable) -> Any:
match attribute.datatype:
case Bool():
record = builder.boolOut(
pv,
always_update=True,
on_update=on_update,
**get_record_metadata_from_datatype(attribute.datatype),
**get_record_metadata_from_attribute(attribute),
)
case Int():
record = builder.longOut(
pv,
always_update=True,
on_update=on_update,
**get_record_metadata_from_datatype(attribute.datatype),
**get_record_metadata_from_attribute(attribute),
)
case Float():
record = builder.aOut(
pv,
always_update=True,
on_update=on_update,
**get_record_metadata_from_datatype(attribute.datatype),
**get_record_metadata_from_attribute(attribute),
)
case String():
record = builder.longStringOut(
pv,
always_update=True,
on_update=on_update,
**get_record_metadata_from_datatype(attribute.datatype),
**get_record_metadata_from_attribute(attribute),
)
case Enum():
if len(attribute.datatype.members) > MBB_MAX_CHOICES:
raise RuntimeError(
f"Received an `Enum` datatype on attribute {attribute} "
f"with more elements than the epics limit `{MBB_MAX_CHOICES}` "
f"for `mbbOut`. Use an `Int or `String with `allowed_values`."
)

state_keys = dict(
zip(
MBB_STATE_FIELDS,
[member.name for member in attribute.datatype.members],
strict=False,
)
)
record = builder.mbbOut(
pv,
**state_keys,
always_update=True,
on_update=on_update,
**get_record_metadata_from_datatype(attribute.datatype),
**get_record_metadata_from_attribute(attribute),
)
case WaveForm():
record = builder.WaveformOut(
pv,
always_update=True,
on_update=on_update,
**get_record_metadata_from_datatype(attribute.datatype),
**get_record_metadata_from_attribute(attribute),
)

case _:
raise FastCSException(
f"Unsupported type {type(attribute.datatype)}: {attribute.datatype}"
)

def datatype_updater(datatype: DataType):
for name, value in get_record_metadata_from_datatype(datatype).items():
record.set_field(name, value)

attribute.add_update_datatype_callback(datatype_updater)
return record


def _create_and_link_command_pvs(pv_prefix: str, controller: Controller) -> None:
for single_mapping in controller.get_controller_mappings():
path = single_mapping.controller.path
Expand Down
Loading

0 comments on commit 00f6389

Please sign in to comment.