diff --git a/pyproject.toml b/pyproject.toml index 2b7722ca..c1663382 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ dependencies = [ "pvi~=0.10.0", "pytango", "softioc>=4.5.0", + "strawberry-graphql", ] dynamic = ["version"] license.file = "LICENSE" @@ -63,7 +64,7 @@ version_file = "src/fastcs/_version.py" [tool.pyright] typeCheckingMode = "standard" -reportMissingImports = false # Ignore missing stubs in imported modules +reportMissingImports = false # Ignore missing stubs in imported modules [tool.pytest.ini_options] # Run pytest with all our checkers, and don't spam us with massive tracebacks on error diff --git a/src/fastcs/launch.py b/src/fastcs/launch.py index 04828578..d1eb12dd 100644 --- a/src/fastcs/launch.py +++ b/src/fastcs/launch.py @@ -14,11 +14,12 @@ from .exceptions import LaunchError from .transport.adapter import TransportAdapter from .transport.epics.options import EpicsOptions +from .transport.graphQL.options import GraphQLOptions from .transport.rest.options import RestOptions from .transport.tango.options import TangoOptions # Define a type alias for transport options -TransportOptions: TypeAlias = EpicsOptions | TangoOptions | RestOptions +TransportOptions: TypeAlias = EpicsOptions | TangoOptions | RestOptions | GraphQLOptions class FastCS: @@ -38,6 +39,13 @@ def __init__( self._backend.dispatcher, transport_options, ) + case GraphQLOptions(): + from .transport.graphQL.adapter import GraphQLTransport + + self._transport = GraphQLTransport( + controller, + transport_options, + ) case TangoOptions(): from .transport.tango.adapter import TangoTransport diff --git a/src/fastcs/transport/__init__.py b/src/fastcs/transport/__init__.py index 0ca90d43..36f3470e 100644 --- a/src/fastcs/transport/__init__.py +++ b/src/fastcs/transport/__init__.py @@ -2,6 +2,8 @@ from .epics.options import EpicsGUIOptions as EpicsGUIOptions from .epics.options import EpicsIOCOptions as EpicsIOCOptions from .epics.options import EpicsOptions as EpicsOptions +from .graphQL.options import GraphQLOptions as GraphQLOptions +from .graphQL.options import GraphQLServerOptions as GraphQLServerOptions from .rest.options import RestOptions as RestOptions from .rest.options import RestServerOptions as RestServerOptions from .tango.options import TangoDSROptions as TangoDSROptions diff --git a/src/fastcs/transport/graphQL/__init__.py b/src/fastcs/transport/graphQL/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/fastcs/transport/graphQL/adapter.py b/src/fastcs/transport/graphQL/adapter.py new file mode 100644 index 00000000..5b573c02 --- /dev/null +++ b/src/fastcs/transport/graphQL/adapter.py @@ -0,0 +1,24 @@ +from fastcs.controller import Controller +from fastcs.transport.adapter import TransportAdapter + +from .graphQL import GraphQLServer +from .options import GraphQLOptions + + +class GraphQLTransport(TransportAdapter): + def __init__( + self, + controller: Controller, + options: GraphQLOptions | None = None, + ): + self.options = options or GraphQLOptions() + self._server = GraphQLServer(controller) + + def create_docs(self) -> None: + raise NotImplementedError + + def create_gui(self) -> None: + raise NotImplementedError + + def run(self) -> None: + self._server.run(self.options.gql) diff --git a/src/fastcs/transport/graphQL/graphQL.py b/src/fastcs/transport/graphQL/graphQL.py new file mode 100644 index 00000000..85bde07b --- /dev/null +++ b/src/fastcs/transport/graphQL/graphQL.py @@ -0,0 +1,172 @@ +from collections.abc import Awaitable, Callable, Coroutine +from typing import Any + +import strawberry +import uvicorn +from strawberry.asgi import GraphQL +from strawberry.tools import create_type +from strawberry.types.field import StrawberryField + +from fastcs.attributes import AttrR, AttrRW, AttrW, T +from fastcs.controller import ( + BaseController, + Controller, + SingleMapping, + _get_single_mapping, +) +from fastcs.exceptions import FastCSException + +from .options import GraphQLServerOptions + + +class GraphQLServer: + def __init__(self, controller: Controller): + self._controller = controller + self._app = self._create_app() + + def _create_app(self) -> GraphQL: + api = GraphQLAPI(self._controller) + schema = api.create_schema() + app = GraphQL(schema) + + return app + + def run(self, options: GraphQLServerOptions | None = None) -> None: + if options is None: + options = GraphQLServerOptions() + + uvicorn.run( + self._app, + host=options.host, + port=options.port, + log_level=options.log_level, + ) + + +class GraphQLAPI: + """A Strawberry API built dynamically from a Controller""" + + def __init__(self, controller: BaseController): + self.queries: list[StrawberryField] = [] + self.mutations: list[StrawberryField] = [] + + api = _get_single_mapping(controller) + + self._process_attributes(api) + self._process_commands(api) + self._process_sub_controllers(api) + + def _process_attributes(self, api: SingleMapping): + """Create queries and mutations from api attributes.""" + for attr_name, attribute in api.attributes.items(): + match attribute: + # mutation for server changes https://graphql.org/learn/queries/ + case AttrRW(): + self.queries.append( + strawberry.field(_wrap_attr_get(attr_name, attribute)) + ) + self.mutations.append( + strawberry.mutation(_wrap_attr_set(attr_name, attribute)) + ) + case AttrR(): + self.queries.append( + strawberry.field(_wrap_attr_get(attr_name, attribute)) + ) + case AttrW(): + self.mutations.append( + strawberry.mutation(_wrap_attr_set(attr_name, attribute)) + ) + + def _process_commands(self, api: SingleMapping): + """Create mutations from api commands""" + for cmd_name, method in api.command_methods.items(): + self.mutations.append( + strawberry.mutation(_wrap_command(cmd_name, method.fn, api.controller)) + ) + + def _process_sub_controllers(self, api: SingleMapping): + """Recursively add fields from the queries and mutations of sub controllers""" + for sub_controller in api.controller.get_sub_controllers().values(): + name = "".join(sub_controller.path) + child_tree = GraphQLAPI(sub_controller) + if child_tree.queries: + self.queries.append( + _wrap_as_field( + name, create_type(f"{name}Query", child_tree.queries) + ) + ) + if child_tree.mutations: + self.mutations.append( + _wrap_as_field( + name, create_type(f"{name}Mutation", child_tree.mutations) + ) + ) + + def create_schema(self) -> strawberry.Schema: + """Create a Strawberry Schema to load into a GraphQL application.""" + if not self.queries: + raise FastCSException( + "Can't create GraphQL transport from Controller with no read attributes" + ) + + query = create_type("Query", self.queries) + mutation = create_type("Mutation", self.mutations) if self.mutations else None + + return strawberry.Schema(query=query, mutation=mutation) + + +def _wrap_attr_set( + attr_name: str, attribute: AttrW[T] +) -> Callable[[T], Coroutine[Any, Any, None]]: + """Wrap an attribute in a function with annotations for strawberry""" + + async def _dynamic_f(value): + await attribute.process(value) + return value + + # Add type annotations for validation, schema, conversions + _dynamic_f.__name__ = attr_name + _dynamic_f.__annotations__["value"] = attribute.datatype.dtype + _dynamic_f.__annotations__["return"] = attribute.datatype.dtype + + return _dynamic_f + + +def _wrap_attr_get( + attr_name: str, attribute: AttrR[T] +) -> Callable[[], Coroutine[Any, Any, Any]]: + """Wrap an attribute in a function with annotations for strawberry""" + + async def _dynamic_f() -> Any: + return attribute.get() + + _dynamic_f.__name__ = attr_name + _dynamic_f.__annotations__["return"] = attribute.datatype.dtype + + return _dynamic_f + + +def _wrap_as_field(field_name: str, operation: type) -> StrawberryField: + """Wrap a strawberry type as a field of a parent type""" + + def _dynamic_field(): + return operation() + + _dynamic_field.__name__ = field_name + _dynamic_field.__annotations__["return"] = operation + + return strawberry.field(_dynamic_field) + + +def _wrap_command( + method_name: str, method: Callable, controller: BaseController +) -> Callable[..., Awaitable[bool]]: + """Wrap a command in a function with annotations for strawberry""" + + async def _dynamic_f() -> bool: + await getattr(controller, method.__name__)() + return True + + _dynamic_f.__name__ = method_name + + return _dynamic_f diff --git a/src/fastcs/transport/graphQL/options.py b/src/fastcs/transport/graphQL/options.py new file mode 100644 index 00000000..b1ce2e83 --- /dev/null +++ b/src/fastcs/transport/graphQL/options.py @@ -0,0 +1,13 @@ +from dataclasses import dataclass, field + + +@dataclass +class GraphQLServerOptions: + host: str = "localhost" + port: int = 8080 + log_level: str = "info" + + +@dataclass +class GraphQLOptions: + gql: GraphQLServerOptions = field(default_factory=GraphQLServerOptions) diff --git a/tests/transport/graphQL/test_graphQL.py b/tests/transport/graphQL/test_graphQL.py new file mode 100644 index 00000000..05ba00e0 --- /dev/null +++ b/tests/transport/graphQL/test_graphQL.py @@ -0,0 +1,158 @@ +import copy +import json +from typing import Any + +import pytest +from fastapi.testclient import TestClient + +from fastcs.transport.graphQL.adapter import GraphQLTransport + + +def nest_query(path: list[str]) -> str: + queue = copy.deepcopy(path) + field = queue.pop(0) + + if queue: + nesting = nest_query(queue) + return f"{field} {{ {nesting} }} " + else: + return field + + +def nest_mutation(path: list[str], value: Any) -> str: + queue = copy.deepcopy(path) + field = queue.pop(0) + + if queue: + nesting = nest_query(queue) + return f"{field} {{ {nesting} }} " + else: + return f"{field}(value: {json.dumps(value)})" + + +def nest_responce(path: list[str], value: Any) -> dict: + queue = copy.deepcopy(path) + field = queue.pop(0) + + if queue: + nesting = nest_responce(queue, value) + return {field: nesting} + else: + return {field: value} + + +class TestGraphQLServer: + @pytest.fixture(scope="class") + def client(self, assertable_controller): + app = GraphQLTransport(assertable_controller)._server._app + return TestClient(app) + + def test_read_int(self, assertable_controller, client): + expect = 0 + path = ["readInt"] + query = f"query {{ {nest_query(path)} }}" + with assertable_controller.assert_read_here(["read_int"]): + response = client.post("/graphql", json={"query": query}) + assert response.status_code == 200 + assert response.json()["data"] == nest_responce(path, expect) + + def test_read_write_int(self, assertable_controller, client): + expect = 0 + path = ["readWriteInt"] + query = f"query {{ {nest_query(path)} }}" + with assertable_controller.assert_read_here(["read_write_int"]): + response = client.post("/graphql", json={"query": query}) + assert response.status_code == 200 + assert response.json()["data"] == nest_responce(path, expect) + + new = 9 + mutation = f"mutation {{ {nest_mutation(path, new)} }}" + with assertable_controller.assert_write_here(["read_write_int"]): + response = client.post("/graphql", json={"query": mutation}) + assert response.status_code == 200 + assert response.json()["data"] == nest_responce(path, new) + + def test_read_write_float(self, assertable_controller, client): + expect = 0 + path = ["readWriteFloat"] + query = f"query {{ {nest_query(path)} }}" + with assertable_controller.assert_read_here(["read_write_float"]): + response = client.post("/graphql", json={"query": query}) + assert response.status_code == 200 + assert response.json()["data"] == nest_responce(path, expect) + + new = 0.5 + mutation = f"mutation {{ {nest_mutation(path, new)} }}" + with assertable_controller.assert_write_here(["read_write_float"]): + response = client.post("/graphql", json={"query": mutation}) + assert response.status_code == 200 + assert response.json()["data"] == nest_responce(path, new) + + def test_read_bool(self, assertable_controller, client): + expect = False + path = ["readBool"] + query = f"query {{ {nest_query(path)} }}" + with assertable_controller.assert_read_here(["read_bool"]): + response = client.post("/graphql", json={"query": query}) + assert response.status_code == 200 + assert response.json()["data"] == nest_responce(path, expect) + + def test_write_bool(self, assertable_controller, client): + value = True + path = ["writeBool"] + mutation = f"mutation {{ {nest_mutation(path, value)} }}" + with assertable_controller.assert_write_here(["write_bool"]): + response = client.post("/graphql", json={"query": mutation}) + assert response.status_code == 200 + assert response.json()["data"] == nest_responce(path, value) + + def test_string_enum(self, assertable_controller, client): + expect = "" + path = ["stringEnum"] + query = f"query {{ {nest_query(path)} }}" + with assertable_controller.assert_read_here(["string_enum"]): + response = client.post("/graphql", json={"query": query}) + assert response.status_code == 200 + assert response.json()["data"] == nest_responce(path, expect) + + new = "new" + mutation = f"mutation {{ {nest_mutation(path, new)} }}" + with assertable_controller.assert_write_here(["string_enum"]): + response = client.post("/graphql", json={"query": mutation}) + assert response.status_code == 200 + assert response.json()["data"] == nest_responce(path, new) + + def test_big_enum(self, assertable_controller, client): + expect = 0 + path = ["bigEnum"] + query = f"query {{ {nest_query(path)} }}" + with assertable_controller.assert_read_here(["big_enum"]): + response = client.post("/graphql", json={"query": query}) + assert response.status_code == 200 + assert response.json()["data"] == nest_responce(path, expect) + + def test_go(self, assertable_controller, client): + path = ["go"] + mutation = f"mutation {{ {nest_query(path)} }}" + with assertable_controller.assert_execute_here(path): + response = client.post("/graphql", json={"query": mutation}) + assert response.status_code == 200 + assert response.json()["data"] == {path[-1]: True} + + def test_read_child1(self, assertable_controller, client): + expect = 0 + path = ["SubController01", "readInt"] + query = f"query {{ {nest_query(path)} }}" + with assertable_controller.assert_read_here(["SubController01", "read_int"]): + response = client.post("/graphql", json={"query": query}) + assert response.status_code == 200 + assert response.json()["data"] == nest_responce(path, expect) + + def test_read_child2(self, assertable_controller, client): + expect = 0 + path = ["SubController02", "readInt"] + query = f"query {{ {nest_query(path)} }}" + with assertable_controller.assert_read_here(["SubController02", "read_int"]): + response = client.post("/graphql", json={"query": query}) + assert response.status_code == 200 + assert response.json()["data"] == nest_responce(path, expect)