-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
bdab4fa
commit 8d61dde
Showing
6 changed files
with
434 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from fastcs.backend import Backend | ||
from fastcs.controller import Controller | ||
|
||
from .graphQL import GraphQLServer | ||
|
||
|
||
class GraphQLBackend(Backend): | ||
def __init__(self, controller: Controller): | ||
super().__init__(controller) | ||
|
||
self._server = GraphQLServer(self._mapping) | ||
|
||
def _run(self): | ||
self._server.run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,226 @@ | ||
from collections.abc import Awaitable, Callable, Coroutine | ||
from dataclasses import dataclass | ||
from typing import Any | ||
|
||
import strawberry | ||
import uvicorn | ||
from fastapi import FastAPI | ||
from strawberry.asgi import GraphQL | ||
from strawberry.tools import create_type | ||
from strawberry.types.field import StrawberryField | ||
|
||
from fastcs.attributes import AttrR, AttrRW, AttrW, T | ||
from fastcs.controller import BaseController | ||
from fastcs.mapping import Mapping | ||
|
||
|
||
@dataclass | ||
class GraphQLServerOptions: | ||
host: str = "localhost" | ||
port: int = 8080 | ||
log_level: str = "info" | ||
|
||
|
||
class GraphQLServer: | ||
def __init__(self, mapping: Mapping): | ||
self._mapping = mapping | ||
self._fields_tree: FieldsTree = FieldsTree("") | ||
self._app = self._create_app() | ||
|
||
def _create_app(self) -> FastAPI: | ||
_add_dev_attributes(self._fields_tree, self._mapping) | ||
_add_dev_commands(self._fields_tree, self._mapping) | ||
|
||
schema_kwargs = {} | ||
for key in ["query", "mutation"]: | ||
if s_type := self._fields_tree.create_type(key): | ||
schema_kwargs[key] = s_type | ||
schema = strawberry.Schema(**schema_kwargs) # type: ignore | ||
graphql_app: GraphQL = GraphQL(schema) | ||
|
||
app = FastAPI() | ||
app.add_route("/graphql", graphql_app) # type: ignore | ||
app.add_websocket_route("/graphql", graphql_app) # type: ignore | ||
|
||
return app | ||
|
||
def run(self, options: GraphQLServerOptions | None = None) -> None: | ||
if options is None: | ||
options = GraphQLServerOptions() | ||
|
||
uvicorn.run( | ||
self._app, | ||
host=options.host, | ||
port=options.port, | ||
log_level=options.log_level, | ||
) | ||
|
||
|
||
def _wrap_attr_set( | ||
d_attr_name: str, | ||
attribute: AttrW[T], | ||
) -> Callable[[T], Coroutine[Any, Any, None]]: | ||
async def _dynamic_f(value): | ||
await attribute.process(value) | ||
return value | ||
|
||
# Add type annotations for validation, schema, conversions | ||
_dynamic_f.__name__ = d_attr_name | ||
_dynamic_f.__annotations__["value"] = attribute.datatype.dtype | ||
_dynamic_f.__annotations__["return"] = attribute.datatype.dtype | ||
|
||
return _dynamic_f | ||
|
||
|
||
def _wrap_attr_get( | ||
d_attr_name: str, | ||
attribute: AttrR[T], | ||
) -> Callable[[], Coroutine[Any, Any, Any]]: | ||
async def _dynamic_f() -> Any: | ||
return attribute.get() | ||
|
||
_dynamic_f.__name__ = d_attr_name | ||
_dynamic_f.__annotations__["return"] = attribute.datatype.dtype | ||
|
||
return _dynamic_f | ||
|
||
|
||
def _wrap_as_field( | ||
field_name: str, | ||
strawberry_type: type, | ||
) -> StrawberryField: | ||
def _dynamic_field(): | ||
return strawberry_type() | ||
|
||
_dynamic_field.__name__ = field_name | ||
_dynamic_field.__annotations__["return"] = strawberry_type | ||
|
||
return strawberry.field(_dynamic_field) | ||
|
||
|
||
class NodeNotFoundError(Exception): | ||
pass | ||
|
||
|
||
class FieldsTree: | ||
def __init__(self, name: str): | ||
self.name = name | ||
self.children: list[FieldsTree] = [] | ||
self.fields_dict: dict[str, list[StrawberryField]] = { | ||
"query": [], | ||
"mutation": [], | ||
} | ||
|
||
def insert(self, path: list[str]) -> "FieldsTree": | ||
# Create child if not exist | ||
name = path.pop(0) | ||
if self.is_child(name): | ||
child = self.get_child(name) | ||
else: | ||
child = FieldsTree(name) | ||
self.children.append(child) | ||
|
||
# Recurse if needed | ||
if path: | ||
return child.insert(path) # type: ignore | ||
else: | ||
return child | ||
|
||
def is_child(self, name: str) -> bool: | ||
for child in self.children: | ||
if child.name == name: | ||
return True | ||
return False | ||
|
||
def get_child(self, name: str) -> "FieldsTree": | ||
for child in self.children: | ||
if child.name == name: | ||
return child | ||
raise NodeNotFoundError | ||
|
||
def create_type(self, strawberry_type: str) -> type | None: | ||
for child in self.children: | ||
if new_type := child.create_type(strawberry_type): | ||
child_field = _wrap_as_field( | ||
child.name, | ||
new_type, | ||
) | ||
self.fields_dict[strawberry_type].append(child_field) | ||
|
||
if self.fields_dict[strawberry_type]: | ||
return create_type( | ||
f"{self.name}{strawberry_type}", self.fields_dict[strawberry_type] | ||
) | ||
else: | ||
return None | ||
|
||
|
||
def _add_dev_attributes( | ||
fields_tree: FieldsTree, | ||
mapping: Mapping, | ||
) -> None: | ||
for single_mapping in mapping.get_controller_mappings(): | ||
path = single_mapping.controller.path | ||
if path: | ||
node = fields_tree.insert(path) | ||
else: | ||
node = fields_tree | ||
|
||
if node is not None: | ||
for attr_name, attribute in single_mapping.attributes.items(): | ||
attr_name = attr_name.title().replace("_", "") | ||
|
||
match attribute: | ||
# mutation for server changes https://graphql.org/learn/queries/ | ||
case AttrRW(): | ||
node.fields_dict["query"].append( | ||
strawberry.field(_wrap_attr_get(attr_name, attribute)) | ||
) | ||
node.fields_dict["mutation"].append( | ||
strawberry.mutation(_wrap_attr_set(attr_name, attribute)) | ||
) | ||
case AttrR(): | ||
node.fields_dict["query"].append( | ||
strawberry.field(_wrap_attr_get(attr_name, attribute)) | ||
) | ||
case AttrW(): | ||
node.fields_dict["mutation"].append( | ||
strawberry.mutation(_wrap_attr_set(attr_name, attribute)) | ||
) | ||
|
||
|
||
def _wrap_command( | ||
method_name: str, method: Callable, controller: BaseController | ||
) -> Callable[..., Awaitable[bool]]: | ||
async def _dynamic_f() -> bool: | ||
await getattr(controller, method.__name__)() | ||
return True | ||
|
||
_dynamic_f.__name__ = method_name | ||
|
||
return _dynamic_f | ||
|
||
|
||
def _add_dev_commands( | ||
fields_tree: FieldsTree, | ||
mapping: Mapping, | ||
) -> None: | ||
for single_mapping in mapping.get_controller_mappings(): | ||
path = single_mapping.controller.path | ||
if path: | ||
node = fields_tree.insert(path) | ||
else: | ||
node = fields_tree | ||
|
||
if node is not None: | ||
for name, method in single_mapping.command_methods.items(): | ||
cmd_name = name.title().replace("_", "") | ||
node.fields_dict["mutation"].append( | ||
strawberry.mutation( | ||
_wrap_command( | ||
cmd_name, | ||
method.fn, | ||
single_mapping.controller, | ||
) | ||
) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
import copy | ||
import json | ||
import re | ||
from typing import Any | ||
|
||
import pytest | ||
from fastapi.testclient import TestClient | ||
|
||
from fastcs.attributes import AttrR | ||
from fastcs.backends.graphQL.backend import GraphQLBackend | ||
from fastcs.datatypes import Bool, Float, Int | ||
|
||
|
||
def pascal_2_snake(input: list[str]) -> list[str]: | ||
snake_list = copy.deepcopy(input) | ||
snake_list[-1] = re.sub(r"(?<!^)(?=[A-Z])", "_", snake_list[-1]).lower() | ||
return snake_list | ||
|
||
|
||
def nest_query(path: list[str]) -> str: | ||
queue = copy.deepcopy(path) | ||
field = queue.pop(0) | ||
|
||
if queue: | ||
nesting = nest_query(queue) | ||
return f"{field} {{ {nesting} }} " | ||
else: | ||
return field | ||
|
||
|
||
def nest_mutation(path: list[str], value: Any) -> str: | ||
queue = copy.deepcopy(path) | ||
field = queue.pop(0) | ||
|
||
if queue: | ||
nesting = nest_query(queue) | ||
return f"{field} {{ {nesting} }} " | ||
else: | ||
return f"{field}(value: {json.dumps(value)})" | ||
|
||
|
||
def nest_responce(path: list[str], value: Any) -> dict: | ||
queue = copy.deepcopy(path) | ||
field = queue.pop(0) | ||
|
||
if queue: | ||
nesting = nest_responce(queue, value) | ||
return {field: nesting} | ||
else: | ||
return {field: value} | ||
|
||
|
||
class TestGraphQLServer: | ||
@pytest.fixture(scope="class", autouse=True) | ||
def setup_class(self, assertable_controller): | ||
self.controller = assertable_controller | ||
|
||
@pytest.fixture(scope="class") | ||
def client(self): | ||
app = GraphQLBackend(self.controller)._server._app | ||
return TestClient(app) | ||
|
||
@pytest.fixture(scope="class") | ||
def client_read(self, client): | ||
def _client_read(path: list[str], expected: Any): | ||
query = f"query {{ {nest_query(path)} }}" | ||
with self.controller.assertPerformed(pascal_2_snake(path), "READ"): | ||
response = client.post("/graphql", json={"query": query}) | ||
assert response.status_code == 200 | ||
assert response.json()["data"] == nest_responce(path, expected) | ||
|
||
return _client_read | ||
|
||
@pytest.fixture(scope="class") | ||
def client_write(self, client): | ||
def _client_write(path: list[str], value: Any): | ||
mutation = f"mutation {{ {nest_mutation(path, value)} }}" | ||
with self.controller.assertPerformed(pascal_2_snake(path), "WRITE"): | ||
response = client.post("/graphql", json={"query": mutation}) | ||
assert response.status_code == 200 | ||
assert response.json()["data"] == nest_responce(path, value) | ||
|
||
return _client_write | ||
|
||
@pytest.fixture(scope="class") | ||
def client_exec(self, client): | ||
def _client_exec(path: list[str]): | ||
mutation = f"mutation {{ {nest_query(path)} }}" | ||
with self.controller.assertPerformed(pascal_2_snake(path), "EXECUTE"): | ||
response = client.post("/graphql", json={"query": mutation}) | ||
assert response.status_code == 200 | ||
assert response.json()["data"] == {path[-1]: True} | ||
|
||
return _client_exec | ||
|
||
def test_read_int(self, client_read): | ||
client_read(["ReadInt"], AttrR(Int())._value) | ||
|
||
def test_read_write_int(self, client_read, client_write): | ||
client_read(["ReadWriteInt"], AttrR(Int())._value) | ||
client_write(["ReadWriteInt"], AttrR(Int())._value) | ||
|
||
def test_read_write_float(self, client_read, client_write): | ||
client_read(["ReadWriteFloat"], AttrR(Float())._value) | ||
client_write(["ReadWriteFloat"], AttrR(Float())._value) | ||
|
||
def test_read_bool(self, client_read): | ||
client_read(["ReadBool"], AttrR(Bool())._value) | ||
|
||
def test_write_bool(self, client_write): | ||
client_write(["WriteBool"], AttrR(Bool())._value) | ||
|
||
# # We need to discuss enums | ||
# def test_string_enum(self, client_read, client_write): | ||
|
||
def test_big_enum(self, client_read): | ||
client_read(["BigEnum"], AttrR(Int(), allowed_values=list(range(1, 18)))._value) | ||
|
||
def test_go(self, client_exec): | ||
client_exec(["Go"]) | ||
|
||
def test_read_child1(self, client_read): | ||
client_read(["SubController01", "ReadInt"], AttrR(Int())._value) | ||
|
||
def test_read_child2(self, client_read): | ||
client_read(["SubController02", "ReadInt"], AttrR(Int())._value) |
Oops, something went wrong.