Skip to content

Commit

Permalink
WIP: nearly done with waveform, rest api causing problems
Browse files Browse the repository at this point in the history
  • Loading branch information
evalott100 committed Dec 12, 2024
1 parent d9fd87c commit 2fc7114
Show file tree
Hide file tree
Showing 11 changed files with 141 additions and 22 deletions.
14 changes: 12 additions & 2 deletions src/fastcs/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from dataclasses import dataclass, field
from functools import cached_property
from typing import Generic, TypeVar

import numpy as np

T = TypeVar("T", int, float, bool, str, enum.IntEnum, np.ndarray)
Expand Down Expand Up @@ -162,21 +163,30 @@ def initial_value(self) -> enum.IntEnum:
@dataclass(frozen=True)
class WaveForm(DataType[np.ndarray]):
array_dtype: np.typing.DTypeLike
array_shape: tuple[int, ...] = (2000,)
shape: tuple[int, ...] = (2000,)

@property
def dtype(self) -> type[np.ndarray]:
return np.ndarray

@property
def initial_value(self) -> np.ndarray:
return np.ndarray(self.array_shape, dtype=self.array_dtype)
return np.zeros(self.shape, dtype=self.array_dtype)

def validate(self, value: np.ndarray) -> np.ndarray:
print("VALIDATING", self)
super().validate(value)
if self.array_dtype != value.dtype:
raise ValueError(
f"Value dtype {value.dtype} is not the same as the array dtype "
f"{self.array_dtype}"
)
if len(self.shape) != len(value.shape) or any(
shape1 > shape2
for shape1, shape2 in zip(value.shape, self.shape, strict=True)
):
raise ValueError(
f"Value shape {value.shape} exceeeds the shape maximum shape "
f"{self.shape}"
)
return value
1 change: 0 additions & 1 deletion src/fastcs/transport/epics/gui.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import enum

from pvi._format.dls import DLSFormatter
from pvi.device import (
Expand Down
4 changes: 1 addition & 3 deletions src/fastcs/transport/epics/ioc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import warnings
from collections.abc import Callable
from types import MethodType
from typing import Any, Literal
Expand All @@ -7,10 +6,9 @@
from softioc.asyncio_dispatcher import AsyncioDispatcher
from softioc.pythonSoftIoc import RecordWrapper

from fastcs import attributes
from fastcs.attributes import AttrR, AttrRW, AttrW
from fastcs.controller import BaseController, Controller
from fastcs.datatypes import Bool, DataType, Enum, Float, Int, String, WaveForm, T
from fastcs.datatypes import Bool, DataType, Enum, Float, Int, String, T, WaveForm
from fastcs.exceptions import FastCSException
from fastcs.transport.epics.util import (
MBB_MAX_CHOICES,
Expand Down
4 changes: 2 additions & 2 deletions src/fastcs/transport/epics/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from dataclasses import asdict

from fastcs.attributes import Attribute
from fastcs.datatypes import Bool, DataType, Enum, Float, Int, String, WaveForm, T
from fastcs.datatypes import Bool, DataType, Enum, Float, Int, String, T, WaveForm

_MBB_FIELD_PREFIXES = (
"ZR",
Expand Down Expand Up @@ -39,7 +39,7 @@
"max_alarm": "HOPR",
"znam": "ZNAM",
"onam": "ONAM",
"array_shape": "length",
"shape": "length",
}


Expand Down
21 changes: 16 additions & 5 deletions src/fastcs/transport/rest/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
from fastapi import FastAPI
from pydantic import create_model

from .util import (
convert_datatype,
get_cast_method_to_rest_type,
get_cast_method_from_rest_type,
)
from fastcs.attributes import AttrR, AttrRW, AttrW, T
from fastcs.controller import BaseController, Controller

Expand Down Expand Up @@ -38,19 +43,22 @@ def _put_request_body(attribute: AttrW[T]):
Creates a pydantic model for each datatype which defines the schema
of the PUT request body
"""
converted_datatype = convert_datatype(attribute.datatype)
type_name = str(attribute.datatype.dtype.__name__).title()
# key=(type, ...) to declare a field without default value
return create_model(
f"Put{type_name}Value",
value=(attribute.datatype.dtype, ...),
value=(converted_datatype, ...),
)


def _wrap_attr_put(
attribute: AttrW[T],
) -> Callable[[T], Coroutine[Any, Any, None]]:
cast_method = get_cast_method_from_rest_type(attribute.datatype)

async def attr_set(request):
await attribute.process(request.value)
await attribute.process(cast_method(request.value))

# Fast api uses type annotations for validation, schema, conversions
attr_set.__annotations__["request"] = _put_request_body(attribute)
Expand All @@ -63,20 +71,23 @@ def _get_response_body(attribute: AttrR[T]):
Creates a pydantic model for each datatype which defines the schema
of the GET request body
"""
type_name = str(attribute.datatype.dtype.__name__).title()
converted_datatype = convert_datatype(attribute.datatype)
type_name = str(converted_datatype.__name__).title()
# key=(type, ...) to declare a field without default value
return create_model(
f"Get{type_name}Value",
value=(attribute.datatype.dtype, ...),
value=(converted_datatype, ...),
)


def _wrap_attr_get(
attribute: AttrR[T],
) -> Callable[[], Coroutine[Any, Any, Any]]:
cast_method = get_cast_method_to_rest_type(attribute.datatype)

async def attr_get() -> Any: # Must be any as response_model is set
value = attribute.get() # type: ignore
return {"value": value}
return {"value": cast_method(value)}

return attr_get

Expand Down
45 changes: 45 additions & 0 deletions src/fastcs/transport/rest/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import numpy as np
from typing import Callable
from fastcs.datatypes import Bool, DataType, Enum, Float, Int, String, T, WaveForm

REST_ALLOWED_DATATYPES = (Bool, DataType, Enum, Float, Int, String)


def convert_datatype(datatype: DataType[T]) -> type:
match datatype:
case WaveForm():
return list
case _:
return datatype.dtype


def get_cast_method_to_rest_type(datatype: DataType[T]) -> Callable[[T], object]:
match datatype:
case WaveForm():

def cast_to_rest_type(value) -> list:
return value.tolist()
case datatype if issubclass(type(datatype), REST_ALLOWED_DATATYPES):

def cast_to_rest_type(value):
return datatype.validate(value)
case _:
raise ValueError(f"Unsupported datatype {datatype}")

return cast_to_rest_type


def get_cast_method_from_rest_type(datatype: DataType[T]) -> Callable[[object], T]:
match datatype:
case WaveForm():

def cast_from_rest_type(value) -> T:
return datatype.validate(np.array(value, dtype=datatype.array_dtype))
case datatype if issubclass(type(datatype), REST_ALLOWED_DATATYPES):

def cast_from_rest_type(value) -> T:
return datatype.validate(value)
case _:
raise ValueError(f"Unsupported datatype {datatype}")

return cast_from_rest_type
3 changes: 1 addition & 2 deletions src/fastcs/transport/tango/dsr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
from tango import AttrWriteType, Database, DbDevInfo, DevState, server
from tango.server import Device

from fastcs.attributes import Attribute, AttrR, AttrRW, AttrW
from fastcs.attributes import AttrR, AttrRW, AttrW
from fastcs.controller import BaseController
from fastcs.datatypes import Float

from .options import TangoDSROptions
from .util import (
Expand Down
22 changes: 16 additions & 6 deletions src/fastcs/transport/tango/util.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from collections.abc import Callable
from dataclasses import asdict
from typing import Any

from fastcs.attributes import Attribute
from fastcs.datatypes import Bool, DataType, Enum, Float, Int, String, WaveForm, T
from typing import Any
from fastcs.datatypes import Bool, DataType, Enum, Float, Int, String, T, WaveForm
from tango import AttrDataFormat

TANGO_ALLOWED_DATATYPES = (Bool, DataType, Enum, Float, Int, String, WaveForm)

DATATYPE_FIELD_TO_SERVER_FIELD = {
"prec": "format",
"units": "unit",
"min": "min_value",
"max": "max_value",
Expand Down Expand Up @@ -36,7 +36,19 @@ def get_server_metadata_from_datatype(datatype: DataType[T]) -> dict[str, str]:

match datatype:
case WaveForm():
dtype = (datatype.array_dtype for _ in datatype.array_shape)
dtype = datatype.array_dtype
match len(datatype.shape):
case 1:
arguments["dformat"] = AttrDataFormat.SPECTRUM
arguments["max_dim_x"] = datatype.shape[0]
case 2:
arguments["dformat"] = AttrDataFormat.IMAGE
arguments["max_dim_x"], arguments["max_dim_y"] = datatype.shape
case _:
raise ValueError(
"Waveform has to be 1D or 2D in tango, received shape of "
f"{datatype.shape}"
)
case Float():
arguments["format"] = f"%.{datatype.prec}"

Expand All @@ -45,8 +57,6 @@ def get_server_metadata_from_datatype(datatype: DataType[T]) -> dict[str, str]:
if value is None:
arguments[argument] = ""

raise RuntimeError(arguments)

return arguments


Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __init__(self) -> None:
write_bool: AttrW = AttrW(Bool(), handler=TestSender())
read_string: AttrRW = AttrRW(String())
enum: AttrRW = AttrRW(Enum(enum.IntEnum("Enum", {"RED": 0, "GREEN": 1, "BLUE": 2})))
ond_d_waveform: AttrRW = AttrRW(WaveForm(np.int32, (10,)))
one_d_waveform: AttrRW = AttrRW(WaveForm(np.int32, (10,)))
two_d_waveform: AttrRW = AttrRW(WaveForm(np.int32, (10, 10)))
big_enum: AttrR = AttrR(
Int(
Expand Down
24 changes: 24 additions & 0 deletions tests/transport/rest/test_rest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from fastapi import responses
import pytest
import numpy as np
from fastapi.testclient import TestClient

from fastcs.transport.rest.adapter import RestTransport
Expand Down Expand Up @@ -75,6 +77,28 @@ def test_big_enum(self, assertable_controller, client):
assert response.status_code == 200
assert response.json()["value"] == expect

def test_1d_waveform(self, assertable_controller, client):
expect = np.zeros((10,), dtype=np.int32)
with assertable_controller.assert_read_here(["one_d_waveform"]):
response = client.get("one-d-waveform")
assert np.array_equal(response.json()["value"], expect)
new = np.array([1, 2, 3], dtype=np.int32)
with assertable_controller.assert_write_here(["one_d_waveform"]):
client.put("/one-d-waveform", new)
assert np.array_equal(client.get("/one-d-waveform").json()["value"], new)

def test_2d_waveform(self, assertable_controller, client):
expect = np.zeros((10, 10), dtype=np.int32)
with assertable_controller.assert_read_here(["two_d_waveform"]):
result = client.get("/two-d-waveform")
assert np.array_equal(result.json()["value"], expect)
new = np.array([[1, 2, 3]], dtype=np.int32)
# with assertable_controller.assert_write_here(["two_d_waveform"]):
client.put("/two_d_waveform", json={"value": new.tolist()})
print(new)
print(client.get("/two-d-waveform").json()["value"])
assert np.array_equal(client.get("/two-d-waveform").json()["value"], new)

def test_go(self, assertable_controller, client):
with assertable_controller.assert_execute_here(["go"]):
response = client.put("/go")
Expand Down
23 changes: 23 additions & 0 deletions tests/transport/tango/test_dsr.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
import numpy as np
from tango import DevState
from tango.test_context import DeviceTestContext

Expand All @@ -17,11 +18,13 @@ def test_list_attributes(self, tango_context):
assert list(tango_context.get_attribute_list()) == [
"BigEnum",
"Enum",
"OneDWaveform",
"ReadBool",
"ReadInt",
"ReadString",
"ReadWriteFloat",
"ReadWriteInt",
"TwoDWaveform",
"WriteBool",
"SubController01_ReadInt",
"SubController02_ReadInt",
Expand Down Expand Up @@ -102,6 +105,26 @@ def test_big_enum(self, assertable_controller, tango_context):
result = tango_context.read_attribute("BigEnum").value
assert result == expect

def test_1d_waveform(self, assertable_controller, tango_context):
expect = np.zeros((10,), dtype=np.int32)
with assertable_controller.assert_read_here(["one_d_waveform"]):
result = tango_context.read_attribute("OneDWaveform").value
assert np.array_equal(result, expect)
new = np.array([1, 2, 3], dtype=np.int32)
with assertable_controller.assert_write_here(["one_d_waveform"]):
tango_context.write_attribute("OneDWaveform", new)
assert np.array_equal(tango_context.read_attribute("OneDWaveform").value, new)

def test_2d_waveform(self, assertable_controller, tango_context):
expect = np.zeros((10, 10), dtype=np.int32)
with assertable_controller.assert_read_here(["two_d_waveform"]):
result = tango_context.read_attribute("TwoDWaveform").value
assert np.array_equal(result, expect)
new = np.array([[1, 2, 3]], dtype=np.int32)
with assertable_controller.assert_write_here(["two_d_waveform"]):
tango_context.write_attribute("TwoDWaveform", new)
assert np.array_equal(tango_context.read_attribute("TwoDWaveform").value, new)

def test_go(self, assertable_controller, tango_context):
with assertable_controller.assert_execute_here(["go"]):
tango_context.command_inout("Go")
Expand Down

0 comments on commit 2fc7114

Please sign in to comment.