From 66f46fa02d018f012ad91c1885943bf0d32be59f Mon Sep 17 00:00:00 2001 From: Marcell Nagy Date: Mon, 4 Nov 2024 13:34:18 +0000 Subject: [PATCH 1/2] Attempt graphql backend Apply upstream feedback Remove fastapi layer from gql Process review comments Rebase create_type -> create_strawberry_type --- pyproject.toml | 3 +- src/fastcs/launch.py | 10 +- src/fastcs/transport/__init__.py | 2 + src/fastcs/transport/graphQL/__init__.py | 0 src/fastcs/transport/graphQL/adapter.py | 24 +++ src/fastcs/transport/graphQL/graphQL.py | 198 +++++++++++++++++++++++ src/fastcs/transport/graphQL/options.py | 13 ++ tests/transport/graphQL/test_graphQL.py | 158 ++++++++++++++++++ 8 files changed, 406 insertions(+), 2 deletions(-) create mode 100644 src/fastcs/transport/graphQL/__init__.py create mode 100644 src/fastcs/transport/graphQL/adapter.py create mode 100644 src/fastcs/transport/graphQL/graphQL.py create mode 100644 src/fastcs/transport/graphQL/options.py create mode 100644 tests/transport/graphQL/test_graphQL.py diff --git a/pyproject.toml b/pyproject.toml index 2b7722ca..c1663382 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ dependencies = [ "pvi~=0.10.0", "pytango", "softioc>=4.5.0", + "strawberry-graphql", ] dynamic = ["version"] license.file = "LICENSE" @@ -63,7 +64,7 @@ version_file = "src/fastcs/_version.py" [tool.pyright] typeCheckingMode = "standard" -reportMissingImports = false # Ignore missing stubs in imported modules +reportMissingImports = false # Ignore missing stubs in imported modules [tool.pytest.ini_options] # Run pytest with all our checkers, and don't spam us with massive tracebacks on error diff --git a/src/fastcs/launch.py b/src/fastcs/launch.py index 04828578..d1eb12dd 100644 --- a/src/fastcs/launch.py +++ b/src/fastcs/launch.py @@ -14,11 +14,12 @@ from .exceptions import LaunchError from .transport.adapter import TransportAdapter from .transport.epics.options import EpicsOptions +from .transport.graphQL.options import GraphQLOptions from .transport.rest.options import RestOptions from .transport.tango.options import TangoOptions # Define a type alias for transport options -TransportOptions: TypeAlias = EpicsOptions | TangoOptions | RestOptions +TransportOptions: TypeAlias = EpicsOptions | TangoOptions | RestOptions | GraphQLOptions class FastCS: @@ -38,6 +39,13 @@ def __init__( self._backend.dispatcher, transport_options, ) + case GraphQLOptions(): + from .transport.graphQL.adapter import GraphQLTransport + + self._transport = GraphQLTransport( + controller, + transport_options, + ) case TangoOptions(): from .transport.tango.adapter import TangoTransport diff --git a/src/fastcs/transport/__init__.py b/src/fastcs/transport/__init__.py index 0ca90d43..36f3470e 100644 --- a/src/fastcs/transport/__init__.py +++ b/src/fastcs/transport/__init__.py @@ -2,6 +2,8 @@ from .epics.options import EpicsGUIOptions as EpicsGUIOptions from .epics.options import EpicsIOCOptions as EpicsIOCOptions from .epics.options import EpicsOptions as EpicsOptions +from .graphQL.options import GraphQLOptions as GraphQLOptions +from .graphQL.options import GraphQLServerOptions as GraphQLServerOptions from .rest.options import RestOptions as RestOptions from .rest.options import RestServerOptions as RestServerOptions from .tango.options import TangoDSROptions as TangoDSROptions diff --git a/src/fastcs/transport/graphQL/__init__.py b/src/fastcs/transport/graphQL/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/fastcs/transport/graphQL/adapter.py b/src/fastcs/transport/graphQL/adapter.py new file mode 100644 index 00000000..5b573c02 --- /dev/null +++ b/src/fastcs/transport/graphQL/adapter.py @@ -0,0 +1,24 @@ +from fastcs.controller import Controller +from fastcs.transport.adapter import TransportAdapter + +from .graphQL import GraphQLServer +from .options import GraphQLOptions + + +class GraphQLTransport(TransportAdapter): + def __init__( + self, + controller: Controller, + options: GraphQLOptions | None = None, + ): + self.options = options or GraphQLOptions() + self._server = GraphQLServer(controller) + + def create_docs(self) -> None: + raise NotImplementedError + + def create_gui(self) -> None: + raise NotImplementedError + + def run(self) -> None: + self._server.run(self.options.gql) diff --git a/src/fastcs/transport/graphQL/graphQL.py b/src/fastcs/transport/graphQL/graphQL.py new file mode 100644 index 00000000..f1a6777d --- /dev/null +++ b/src/fastcs/transport/graphQL/graphQL.py @@ -0,0 +1,198 @@ +from collections.abc import Awaitable, Callable, Coroutine +from typing import Any + +import strawberry +import uvicorn +from strawberry.asgi import GraphQL +from strawberry.tools import create_type +from strawberry.types.field import StrawberryField + +from fastcs.attributes import AttrR, AttrRW, AttrW, T +from fastcs.controller import BaseController, Controller + +from .options import GraphQLServerOptions + + +class GraphQLServer: + def __init__(self, controller: Controller): + self._controller = controller + self._fields_tree: FieldTree = FieldTree("") + self._app = self._create_app() + + def _create_app(self) -> GraphQL: + _add_attribute_operations(self._fields_tree, self._controller) + _add_command_mutations(self._fields_tree, self._controller) + + schema_kwargs = {} + for key in ["query", "mutation"]: + if s_type := self._fields_tree.create_strawberry_type(key): + schema_kwargs[key] = s_type + schema = strawberry.Schema(**schema_kwargs) # type: ignore + app = GraphQL(schema) + + return app + + def run(self, options: GraphQLServerOptions | None = None) -> None: + if options is None: + options = GraphQLServerOptions() + + uvicorn.run( + self._app, + host=options.host, + port=options.port, + log_level=options.log_level, + ) + + +def _wrap_attr_set( + attr_name: str, + attribute: AttrW[T], +) -> Callable[[T], Coroutine[Any, Any, None]]: + async def _dynamic_f(value): + await attribute.process(value) + return value + + # Add type annotations for validation, schema, conversions + _dynamic_f.__name__ = attr_name + _dynamic_f.__annotations__["value"] = attribute.datatype.dtype + _dynamic_f.__annotations__["return"] = attribute.datatype.dtype + + return _dynamic_f + + +def _wrap_attr_get( + attr_name: str, + attribute: AttrR[T], +) -> Callable[[], Coroutine[Any, Any, Any]]: + async def _dynamic_f() -> Any: + return attribute.get() + + _dynamic_f.__name__ = attr_name + _dynamic_f.__annotations__["return"] = attribute.datatype.dtype + + return _dynamic_f + + +def _wrap_as_field( + field_name: str, + strawberry_type: type, +) -> StrawberryField: + def _dynamic_field(): + return strawberry_type() + + _dynamic_field.__name__ = field_name + _dynamic_field.__annotations__["return"] = strawberry_type + + return strawberry.field(_dynamic_field) + + +class FieldTree: + def __init__(self, name: str): + self.name = name + self.children: dict[str, FieldTree] = {} + self.fields: dict[str, list[StrawberryField]] = { + "query": [], + "mutation": [], + } + + def insert(self, path: list[str]) -> "FieldTree": + # Create child if not exist + name = path.pop(0) + if child := self.get_child(name): + pass + else: + child = FieldTree(name) + self.children[name] = child + + # Recurse if needed + if path: + return child.insert(path) + else: + return child + + def get_child(self, name: str) -> "FieldTree | None": + if name in self.children: + return self.children[name] + else: + return None + + def create_strawberry_type(self, strawberry_type: str) -> type | None: + for child in self.children.values(): + if new_type := child.create_strawberry_type(strawberry_type): + child_field = _wrap_as_field( + child.name, + new_type, + ) + self.fields[strawberry_type].append(child_field) + + if self.fields[strawberry_type]: + return create_type( + f"{self.name}{strawberry_type}", self.fields[strawberry_type] + ) + else: + return None + + +def _add_attribute_operations( + fields_tree: FieldTree, + controller: Controller, +) -> None: + for single_mapping in controller.get_controller_mappings(): + path = single_mapping.controller.path + if path: + node = fields_tree.insert(path) + else: + node = fields_tree + + if node is not None: + for attr_name, attribute in single_mapping.attributes.items(): + match attribute: + # mutation for server changes https://graphql.org/learn/queries/ + case AttrRW(): + node.fields["query"].append( + strawberry.field(_wrap_attr_get(attr_name, attribute)) + ) + node.fields["mutation"].append( + strawberry.mutation(_wrap_attr_set(attr_name, attribute)) + ) + case AttrR(): + node.fields["query"].append( + strawberry.field(_wrap_attr_get(attr_name, attribute)) + ) + case AttrW(): + node.fields["mutation"].append( + strawberry.mutation(_wrap_attr_set(attr_name, attribute)) + ) + + +def _wrap_command( + method_name: str, method: Callable, controller: BaseController +) -> Callable[..., Awaitable[bool]]: + async def _dynamic_f() -> bool: + await getattr(controller, method.__name__)() + return True + + _dynamic_f.__name__ = method_name + + return _dynamic_f + + +def _add_command_mutations(fields_tree: FieldTree, controller: Controller) -> None: + for single_mapping in controller.get_controller_mappings(): + path = single_mapping.controller.path + if path: + node = fields_tree.insert(path) + else: + node = fields_tree + + if node is not None: + for cmd_name, method in single_mapping.command_methods.items(): + node.fields["mutation"].append( + strawberry.mutation( + _wrap_command( + cmd_name, + method.fn, + single_mapping.controller, + ) + ) + ) diff --git a/src/fastcs/transport/graphQL/options.py b/src/fastcs/transport/graphQL/options.py new file mode 100644 index 00000000..b1ce2e83 --- /dev/null +++ b/src/fastcs/transport/graphQL/options.py @@ -0,0 +1,13 @@ +from dataclasses import dataclass, field + + +@dataclass +class GraphQLServerOptions: + host: str = "localhost" + port: int = 8080 + log_level: str = "info" + + +@dataclass +class GraphQLOptions: + gql: GraphQLServerOptions = field(default_factory=GraphQLServerOptions) diff --git a/tests/transport/graphQL/test_graphQL.py b/tests/transport/graphQL/test_graphQL.py new file mode 100644 index 00000000..05ba00e0 --- /dev/null +++ b/tests/transport/graphQL/test_graphQL.py @@ -0,0 +1,158 @@ +import copy +import json +from typing import Any + +import pytest +from fastapi.testclient import TestClient + +from fastcs.transport.graphQL.adapter import GraphQLTransport + + +def nest_query(path: list[str]) -> str: + queue = copy.deepcopy(path) + field = queue.pop(0) + + if queue: + nesting = nest_query(queue) + return f"{field} {{ {nesting} }} " + else: + return field + + +def nest_mutation(path: list[str], value: Any) -> str: + queue = copy.deepcopy(path) + field = queue.pop(0) + + if queue: + nesting = nest_query(queue) + return f"{field} {{ {nesting} }} " + else: + return f"{field}(value: {json.dumps(value)})" + + +def nest_responce(path: list[str], value: Any) -> dict: + queue = copy.deepcopy(path) + field = queue.pop(0) + + if queue: + nesting = nest_responce(queue, value) + return {field: nesting} + else: + return {field: value} + + +class TestGraphQLServer: + @pytest.fixture(scope="class") + def client(self, assertable_controller): + app = GraphQLTransport(assertable_controller)._server._app + return TestClient(app) + + def test_read_int(self, assertable_controller, client): + expect = 0 + path = ["readInt"] + query = f"query {{ {nest_query(path)} }}" + with assertable_controller.assert_read_here(["read_int"]): + response = client.post("/graphql", json={"query": query}) + assert response.status_code == 200 + assert response.json()["data"] == nest_responce(path, expect) + + def test_read_write_int(self, assertable_controller, client): + expect = 0 + path = ["readWriteInt"] + query = f"query {{ {nest_query(path)} }}" + with assertable_controller.assert_read_here(["read_write_int"]): + response = client.post("/graphql", json={"query": query}) + assert response.status_code == 200 + assert response.json()["data"] == nest_responce(path, expect) + + new = 9 + mutation = f"mutation {{ {nest_mutation(path, new)} }}" + with assertable_controller.assert_write_here(["read_write_int"]): + response = client.post("/graphql", json={"query": mutation}) + assert response.status_code == 200 + assert response.json()["data"] == nest_responce(path, new) + + def test_read_write_float(self, assertable_controller, client): + expect = 0 + path = ["readWriteFloat"] + query = f"query {{ {nest_query(path)} }}" + with assertable_controller.assert_read_here(["read_write_float"]): + response = client.post("/graphql", json={"query": query}) + assert response.status_code == 200 + assert response.json()["data"] == nest_responce(path, expect) + + new = 0.5 + mutation = f"mutation {{ {nest_mutation(path, new)} }}" + with assertable_controller.assert_write_here(["read_write_float"]): + response = client.post("/graphql", json={"query": mutation}) + assert response.status_code == 200 + assert response.json()["data"] == nest_responce(path, new) + + def test_read_bool(self, assertable_controller, client): + expect = False + path = ["readBool"] + query = f"query {{ {nest_query(path)} }}" + with assertable_controller.assert_read_here(["read_bool"]): + response = client.post("/graphql", json={"query": query}) + assert response.status_code == 200 + assert response.json()["data"] == nest_responce(path, expect) + + def test_write_bool(self, assertable_controller, client): + value = True + path = ["writeBool"] + mutation = f"mutation {{ {nest_mutation(path, value)} }}" + with assertable_controller.assert_write_here(["write_bool"]): + response = client.post("/graphql", json={"query": mutation}) + assert response.status_code == 200 + assert response.json()["data"] == nest_responce(path, value) + + def test_string_enum(self, assertable_controller, client): + expect = "" + path = ["stringEnum"] + query = f"query {{ {nest_query(path)} }}" + with assertable_controller.assert_read_here(["string_enum"]): + response = client.post("/graphql", json={"query": query}) + assert response.status_code == 200 + assert response.json()["data"] == nest_responce(path, expect) + + new = "new" + mutation = f"mutation {{ {nest_mutation(path, new)} }}" + with assertable_controller.assert_write_here(["string_enum"]): + response = client.post("/graphql", json={"query": mutation}) + assert response.status_code == 200 + assert response.json()["data"] == nest_responce(path, new) + + def test_big_enum(self, assertable_controller, client): + expect = 0 + path = ["bigEnum"] + query = f"query {{ {nest_query(path)} }}" + with assertable_controller.assert_read_here(["big_enum"]): + response = client.post("/graphql", json={"query": query}) + assert response.status_code == 200 + assert response.json()["data"] == nest_responce(path, expect) + + def test_go(self, assertable_controller, client): + path = ["go"] + mutation = f"mutation {{ {nest_query(path)} }}" + with assertable_controller.assert_execute_here(path): + response = client.post("/graphql", json={"query": mutation}) + assert response.status_code == 200 + assert response.json()["data"] == {path[-1]: True} + + def test_read_child1(self, assertable_controller, client): + expect = 0 + path = ["SubController01", "readInt"] + query = f"query {{ {nest_query(path)} }}" + with assertable_controller.assert_read_here(["SubController01", "read_int"]): + response = client.post("/graphql", json={"query": query}) + assert response.status_code == 200 + assert response.json()["data"] == nest_responce(path, expect) + + def test_read_child2(self, assertable_controller, client): + expect = 0 + path = ["SubController02", "readInt"] + query = f"query {{ {nest_query(path)} }}" + with assertable_controller.assert_read_here(["SubController02", "read_int"]): + response = client.post("/graphql", json={"query": query}) + assert response.status_code == 200 + assert response.json()["data"] == nest_responce(path, expect) From ef5c8e3b919171de5d93742f4f0497c83193e181 Mon Sep 17 00:00:00 2001 From: Gary Yendell Date: Mon, 9 Dec 2024 12:52:56 +0000 Subject: [PATCH 2/2] Refactor --- src/fastcs/transport/graphQL/graphQL.py | 214 +++++++++++------------- 1 file changed, 94 insertions(+), 120 deletions(-) diff --git a/src/fastcs/transport/graphQL/graphQL.py b/src/fastcs/transport/graphQL/graphQL.py index f1a6777d..85bde07b 100644 --- a/src/fastcs/transport/graphQL/graphQL.py +++ b/src/fastcs/transport/graphQL/graphQL.py @@ -8,7 +8,13 @@ from strawberry.types.field import StrawberryField from fastcs.attributes import AttrR, AttrRW, AttrW, T -from fastcs.controller import BaseController, Controller +from fastcs.controller import ( + BaseController, + Controller, + SingleMapping, + _get_single_mapping, +) +from fastcs.exceptions import FastCSException from .options import GraphQLServerOptions @@ -16,18 +22,11 @@ class GraphQLServer: def __init__(self, controller: Controller): self._controller = controller - self._fields_tree: FieldTree = FieldTree("") self._app = self._create_app() def _create_app(self) -> GraphQL: - _add_attribute_operations(self._fields_tree, self._controller) - _add_command_mutations(self._fields_tree, self._controller) - - schema_kwargs = {} - for key in ["query", "mutation"]: - if s_type := self._fields_tree.create_strawberry_type(key): - schema_kwargs[key] = s_type - schema = strawberry.Schema(**schema_kwargs) # type: ignore + api = GraphQLAPI(self._controller) + schema = api.create_schema() app = GraphQL(schema) return app @@ -44,10 +43,83 @@ def run(self, options: GraphQLServerOptions | None = None) -> None: ) +class GraphQLAPI: + """A Strawberry API built dynamically from a Controller""" + + def __init__(self, controller: BaseController): + self.queries: list[StrawberryField] = [] + self.mutations: list[StrawberryField] = [] + + api = _get_single_mapping(controller) + + self._process_attributes(api) + self._process_commands(api) + self._process_sub_controllers(api) + + def _process_attributes(self, api: SingleMapping): + """Create queries and mutations from api attributes.""" + for attr_name, attribute in api.attributes.items(): + match attribute: + # mutation for server changes https://graphql.org/learn/queries/ + case AttrRW(): + self.queries.append( + strawberry.field(_wrap_attr_get(attr_name, attribute)) + ) + self.mutations.append( + strawberry.mutation(_wrap_attr_set(attr_name, attribute)) + ) + case AttrR(): + self.queries.append( + strawberry.field(_wrap_attr_get(attr_name, attribute)) + ) + case AttrW(): + self.mutations.append( + strawberry.mutation(_wrap_attr_set(attr_name, attribute)) + ) + + def _process_commands(self, api: SingleMapping): + """Create mutations from api commands""" + for cmd_name, method in api.command_methods.items(): + self.mutations.append( + strawberry.mutation(_wrap_command(cmd_name, method.fn, api.controller)) + ) + + def _process_sub_controllers(self, api: SingleMapping): + """Recursively add fields from the queries and mutations of sub controllers""" + for sub_controller in api.controller.get_sub_controllers().values(): + name = "".join(sub_controller.path) + child_tree = GraphQLAPI(sub_controller) + if child_tree.queries: + self.queries.append( + _wrap_as_field( + name, create_type(f"{name}Query", child_tree.queries) + ) + ) + if child_tree.mutations: + self.mutations.append( + _wrap_as_field( + name, create_type(f"{name}Mutation", child_tree.mutations) + ) + ) + + def create_schema(self) -> strawberry.Schema: + """Create a Strawberry Schema to load into a GraphQL application.""" + if not self.queries: + raise FastCSException( + "Can't create GraphQL transport from Controller with no read attributes" + ) + + query = create_type("Query", self.queries) + mutation = create_type("Mutation", self.mutations) if self.mutations else None + + return strawberry.Schema(query=query, mutation=mutation) + + def _wrap_attr_set( - attr_name: str, - attribute: AttrW[T], + attr_name: str, attribute: AttrW[T] ) -> Callable[[T], Coroutine[Any, Any, None]]: + """Wrap an attribute in a function with annotations for strawberry""" + async def _dynamic_f(value): await attribute.process(value) return value @@ -61,9 +133,10 @@ async def _dynamic_f(value): def _wrap_attr_get( - attr_name: str, - attribute: AttrR[T], + attr_name: str, attribute: AttrR[T] ) -> Callable[[], Coroutine[Any, Any, Any]]: + """Wrap an attribute in a function with annotations for strawberry""" + async def _dynamic_f() -> Any: return attribute.get() @@ -73,101 +146,23 @@ async def _dynamic_f() -> Any: return _dynamic_f -def _wrap_as_field( - field_name: str, - strawberry_type: type, -) -> StrawberryField: +def _wrap_as_field(field_name: str, operation: type) -> StrawberryField: + """Wrap a strawberry type as a field of a parent type""" + def _dynamic_field(): - return strawberry_type() + return operation() _dynamic_field.__name__ = field_name - _dynamic_field.__annotations__["return"] = strawberry_type + _dynamic_field.__annotations__["return"] = operation return strawberry.field(_dynamic_field) -class FieldTree: - def __init__(self, name: str): - self.name = name - self.children: dict[str, FieldTree] = {} - self.fields: dict[str, list[StrawberryField]] = { - "query": [], - "mutation": [], - } - - def insert(self, path: list[str]) -> "FieldTree": - # Create child if not exist - name = path.pop(0) - if child := self.get_child(name): - pass - else: - child = FieldTree(name) - self.children[name] = child - - # Recurse if needed - if path: - return child.insert(path) - else: - return child - - def get_child(self, name: str) -> "FieldTree | None": - if name in self.children: - return self.children[name] - else: - return None - - def create_strawberry_type(self, strawberry_type: str) -> type | None: - for child in self.children.values(): - if new_type := child.create_strawberry_type(strawberry_type): - child_field = _wrap_as_field( - child.name, - new_type, - ) - self.fields[strawberry_type].append(child_field) - - if self.fields[strawberry_type]: - return create_type( - f"{self.name}{strawberry_type}", self.fields[strawberry_type] - ) - else: - return None - - -def _add_attribute_operations( - fields_tree: FieldTree, - controller: Controller, -) -> None: - for single_mapping in controller.get_controller_mappings(): - path = single_mapping.controller.path - if path: - node = fields_tree.insert(path) - else: - node = fields_tree - - if node is not None: - for attr_name, attribute in single_mapping.attributes.items(): - match attribute: - # mutation for server changes https://graphql.org/learn/queries/ - case AttrRW(): - node.fields["query"].append( - strawberry.field(_wrap_attr_get(attr_name, attribute)) - ) - node.fields["mutation"].append( - strawberry.mutation(_wrap_attr_set(attr_name, attribute)) - ) - case AttrR(): - node.fields["query"].append( - strawberry.field(_wrap_attr_get(attr_name, attribute)) - ) - case AttrW(): - node.fields["mutation"].append( - strawberry.mutation(_wrap_attr_set(attr_name, attribute)) - ) - - def _wrap_command( method_name: str, method: Callable, controller: BaseController ) -> Callable[..., Awaitable[bool]]: + """Wrap a command in a function with annotations for strawberry""" + async def _dynamic_f() -> bool: await getattr(controller, method.__name__)() return True @@ -175,24 +170,3 @@ async def _dynamic_f() -> bool: _dynamic_f.__name__ = method_name return _dynamic_f - - -def _add_command_mutations(fields_tree: FieldTree, controller: Controller) -> None: - for single_mapping in controller.get_controller_mappings(): - path = single_mapping.controller.path - if path: - node = fields_tree.insert(path) - else: - node = fields_tree - - if node is not None: - for cmd_name, method in single_mapping.command_methods.items(): - node.fields["mutation"].append( - strawberry.mutation( - _wrap_command( - cmd_name, - method.fn, - single_mapping.controller, - ) - ) - )