Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
GDYendell committed Dec 9, 2024
1 parent ef2c5c4 commit f5e231a
Showing 1 changed file with 72 additions and 121 deletions.
193 changes: 72 additions & 121 deletions src/fastcs/transport/graphQL/graphQL.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,18 @@
from strawberry.types.field import StrawberryField

from fastcs.attributes import AttrR, AttrRW, AttrW, T
from fastcs.controller import BaseController, Controller
from fastcs.controller import BaseController, Controller, _get_single_mapping

from .options import GraphQLServerOptions


class GraphQLServer:
def __init__(self, controller: Controller):
self._controller = controller
self._fields_tree: FieldTree = FieldTree("")
self._api: GraphQLAPI = GraphQLAPI(controller)
self._app = self._create_app()

def _create_app(self) -> GraphQL:
_add_attribute_operations(self._fields_tree, self._controller)
_add_command_mutations(self._fields_tree, self._controller)

schema_kwargs = {}
for key in ["query", "mutation"]:
if s_type := self._fields_tree.create_strawberry_type(key):
schema_kwargs[key] = s_type
schema = strawberry.Schema(**schema_kwargs) # type: ignore
schema = strawberry.Schema(query=self._api.query, mutation=self._api.mutation)
app = GraphQL(schema)

return app
Expand All @@ -44,9 +36,72 @@ def run(self, options: GraphQLServerOptions | None = None) -> None:
)


class GraphQLAPI:
"""A Strawberry API built dynamically from a Controller"""

def __init__(self, controller: BaseController):
self.queries: list[StrawberryField] = []
self.mutations: list[StrawberryField] = []

api = _get_single_mapping(controller)

for attr_name, attribute in api.attributes.items():
match attribute:
# mutation for server changes https://graphql.org/learn/queries/
case AttrRW():
self.queries.append(
strawberry.field(_wrap_attr_get(attr_name, attribute))
)
self.mutations.append(
strawberry.mutation(_wrap_attr_set(attr_name, attribute))
)
case AttrR():
self.queries.append(
strawberry.field(_wrap_attr_get(attr_name, attribute))
)
case AttrW():
self.mutations.append(
strawberry.mutation(_wrap_attr_set(attr_name, attribute))
)

for cmd_name, method in api.command_methods.items():
self.mutations.append(
strawberry.mutation(_wrap_command(cmd_name, method.fn, controller))
)

for sub_controller in controller.get_sub_controllers().values():
name = "".join(sub_controller.path)
child_tree = GraphQLAPI(sub_controller)
if child_tree.queries:
self.queries.append(
_wrap_as_field(
name, create_type(f"{name}Query", child_tree.queries)
)
)
if child_tree.mutations:
self.mutations.append(
_wrap_as_field(
name, create_type(f"{name}Mutation", child_tree.mutations)
)
)

@property
def query(self) -> type:
if not self.queries:
raise ValueError(
"Can't create GraphQL transport from Controller with no read attributes"
)

return create_type("Query", self.queries)

@property
def mutation(self) -> type | None:
if self.mutations:
return create_type("Mutation", self.mutations)


def _wrap_attr_set(
attr_name: str,
attribute: AttrW[T],
attr_name: str, attribute: AttrW[T]
) -> Callable[[T], Coroutine[Any, Any, None]]:
async def _dynamic_f(value):
await attribute.process(value)
Expand All @@ -61,8 +116,7 @@ async def _dynamic_f(value):


def _wrap_attr_get(
attr_name: str,
attribute: AttrR[T],
attr_name: str, attribute: AttrR[T]
) -> Callable[[], Coroutine[Any, Any, Any]]:
async def _dynamic_f() -> Any:
return attribute.get()
Expand All @@ -73,98 +127,16 @@ async def _dynamic_f() -> Any:
return _dynamic_f


def _wrap_as_field(
field_name: str,
strawberry_type: type,
) -> StrawberryField:
def _wrap_as_field(field_name: str, operation: type) -> StrawberryField:
def _dynamic_field():
return strawberry_type()
return operation()

_dynamic_field.__name__ = field_name
_dynamic_field.__annotations__["return"] = strawberry_type
_dynamic_field.__annotations__["return"] = operation

return strawberry.field(_dynamic_field)


class FieldTree:
def __init__(self, name: str):
self.name = name
self.children: dict[str, FieldTree] = {}
self.fields: dict[str, list[StrawberryField]] = {
"query": [],
"mutation": [],
}

def insert(self, path: list[str]) -> "FieldTree":
# Create child if not exist
name = path.pop(0)
if child := self.get_child(name):
pass
else:
child = FieldTree(name)
self.children[name] = child

# Recurse if needed
if path:
return child.insert(path)
else:
return child

def get_child(self, name: str) -> "FieldTree | None":
if name in self.children:
return self.children[name]
else:
return None

def create_strawberry_type(self, strawberry_type: str) -> type | None:
for child in self.children.values():
if new_type := child.create_strawberry_type(strawberry_type):
child_field = _wrap_as_field(
child.name,
new_type,
)
self.fields[strawberry_type].append(child_field)

if self.fields[strawberry_type]:
return create_type(
f"{self.name}{strawberry_type}", self.fields[strawberry_type]
)
else:
return None


def _add_attribute_operations(
fields_tree: FieldTree,
controller: Controller,
) -> None:
for single_mapping in controller.get_controller_mappings():
path = single_mapping.controller.path
if path:
node = fields_tree.insert(path)
else:
node = fields_tree

if node is not None:
for attr_name, attribute in single_mapping.attributes.items():
match attribute:
# mutation for server changes https://graphql.org/learn/queries/
case AttrRW():
node.fields["query"].append(
strawberry.field(_wrap_attr_get(attr_name, attribute))
)
node.fields["mutation"].append(
strawberry.mutation(_wrap_attr_set(attr_name, attribute))
)
case AttrR():
node.fields["query"].append(
strawberry.field(_wrap_attr_get(attr_name, attribute))
)
case AttrW():
node.fields["mutation"].append(
strawberry.mutation(_wrap_attr_set(attr_name, attribute))
)


def _wrap_command(
method_name: str, method: Callable, controller: BaseController
) -> Callable[..., Awaitable[bool]]:
Expand All @@ -175,24 +147,3 @@ async def _dynamic_f() -> bool:
_dynamic_f.__name__ = method_name

return _dynamic_f


def _add_command_mutations(fields_tree: FieldTree, controller: Controller) -> None:
for single_mapping in controller.get_controller_mappings():
path = single_mapping.controller.path
if path:
node = fields_tree.insert(path)
else:
node = fields_tree

if node is not None:
for cmd_name, method in single_mapping.command_methods.items():
node.fields["mutation"].append(
strawberry.mutation(
_wrap_command(
cmd_name,
method.fn,
single_mapping.controller,
)
)
)

0 comments on commit f5e231a

Please sign in to comment.