Skip to content

Commit

Permalink
refactor: Use tree-sitter rather than pyparsing
Browse files Browse the repository at this point in the history
  • Loading branch information
caksoylar committed Nov 16, 2024
1 parent 0943c95 commit 20fd314
Show file tree
Hide file tree
Showing 5 changed files with 223 additions and 115 deletions.
223 changes: 136 additions & 87 deletions keymap_drawer/dts.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
"""
Helper module to parse ZMK keymap-like DT syntax into a tree,
while keeping track of "compatible" values and utilities to parse
bindings fields.
Helper module to parse ZMK keymap-like DT syntax to extract nodes with
given "compatible" values, and utilities to extract their properties and
child nodes.
The implementation is based on a nested expression parser for curly braces
through pyparsing with some additions on top to clean up comments and run the
C preprocessor using pcpp.
The implementation is based on pcpp to run the C preprocessor, and then
tree-sitter-devicetree to run queries to find compatible nodes and extract properties.
Node overrides via node references are supported in a limited capacity.
"""

import re
from collections import defaultdict
from io import StringIO
from itertools import chain

import pyparsing as pp
import tree_sitter_devicetree as ts
from pcpp.preprocessor import Action, OutputDirective, Preprocessor # type: ignore
from tree_sitter import Language, Node, Parser, Tree

TS_LANG = Language(ts.language())


class DTNode:
Expand All @@ -24,45 +26,56 @@ class DTNode:
label: str | None
content: str
children: list["DTNode"]
label_refs: dict[str, "DTNode"]

def __init__(self, name: str, parse: pp.ParseResults):
def __init__(self, node: Node, text_buf: bytes, override_nodes: list["DTNode"] | None = None):
"""
Initialize a node from its name (which may be in the form of `label:name`)
and `parse` which contains the node itself.
"""

if ":" in name:
self.label, self.name = name.split(":", maxsplit=1)
else:
self.label, self.name = None, name

self.content = " ".join(elt for elt in parse if isinstance(elt, str))
self.children = [
DTNode(name=elt_p, parse=elt_n)
for elt_p, elt_n in zip(parse[:-1], parse[1:])
if isinstance(elt_p, str) and isinstance(elt_n, pp.ParseResults)
]

# keep track of labeled nodes
self.label_refs = {self.label: self} if self.label else {}
for child in self.children:
self.label_refs |= child.label_refs
child.label_refs = {}
self.node = node
self.text_buf = text_buf
name_node = node.child_by_field_name("name")
assert name_node is not None
self.name = self._get_content(name_node)
self.label = self._get_content(v) if (v := node.child_by_field_name("label")) is not None else None
self.children = sorted(
(DTNode(child, text_buf, override_nodes) for child in node.children if child.type == "node"),
key=lambda x: x.node.start_byte,
)
self.overrides = []
if override_nodes and self.label is not None:
# consider pre-compiling nodes by label for performance
self.overrides = [node for node in override_nodes if self.label == node.name.lstrip("&")]

def _get_content(self, node: Node) -> str:
    """Return the source text spanned by `node`, decoded as UTF-8 with newlines turned into spaces."""
    raw = self.text_buf[node.start_byte : node.end_byte]
    text = raw.decode("utf-8")
    return text.replace("\n", " ")

def _get_property(self, property_re: str) -> list[Node] | None:
children = [node for node in self.node.children if node.type == "property"]
for override_node in self.overrides:
children += [node for node in override_node.node.children if node.type == "property"]
for child in children[::-1]:
name_node = child.child_by_field_name("name")
assert name_node is not None
if re.match(property_re, self._get_content(name_node)):
return child.children_by_field_name("value")
return None

def get_string(self, property_re: str) -> str | None:
"""Extract last defined value for a `string` type property matching the `property_re` regex."""
out = None
for m in re.finditer(rf'(?:^|\s)({property_re}) = "(.*?)"', self.content):
out = m.group(2)
return out
if (nodes := self._get_property(property_re)) is None:
return None
return self._get_content(nodes[0]).strip('"')

def get_array(self, property_re: str) -> list[str] | None:
"""Extract last defined values for a `array` type property matching the `property_re` regex."""
matches = list(re.finditer(rf"(?:^|\s){property_re} = (<.*?>( ?, ?<.*?>)*) ?;", self.content))
if not matches:
if (nodes := self._get_property(property_re)) is None:
return None
return list(chain.from_iterable(content.split(" ") for content in re.findall(r"<(.*?)>", matches[-1].group(1))))
return list(
chain.from_iterable(
self._get_content(node).strip("<>").split() for node in nodes if node.type == "integer_cells"
)
)

def get_phandle_array(self, property_re: str) -> list[str] | None:
"""Extract last defined values for a `phandle-array` type property matching the `property_re` regex."""
Expand All @@ -79,26 +92,24 @@ def get_path(self, property_re: str) -> str | None:
Extract last defined value for a `path` type property matching the `property_re` regex.
Only supports phandle paths `&p` rather than path types `"/a/b"` right now.
"""
out = None
for m in re.finditer(rf"(?:^|\s){property_re} = &(.*?);", self.content):
out = m.group(1)
return out
if (nodes := self._get_property(property_re)) is None:
return None
return self._get_content(nodes[0]).lstrip("&")

def __repr__(self) -> str:
    """Return a debug representation with the node's name, label, own content and child names."""
    # only the node's own items are shown as content; child nodes are listed by name instead
    content = " ".join(self._get_content(child) for child in self.node.children if child.type != "node")
    child_names = [child.name for child in self.children]
    return f"DTNode(name={self.name!r}, label={self.label!r}, content={content!r}, children={child_names})"


class DeviceTree:
"""
Class that parses a DTS file (optionally preprocessed by the C preprocessor)
and represents it as a DT tree, with some helpful methods.
and provides methods to extract `compatible` and `chosen` nodes as DTNode's.
"""

_nodelabel_re = re.compile(r"([\w-]+) *: *([\w-]+) *{")
_compatible_re = re.compile(r'compatible = "(.*?)"')
_custom_data_header = "__keymap_drawer_data__"

def __init__(
Expand All @@ -110,8 +121,9 @@ def __init__(
additional_includes: list[str] | None = None,
):
"""
Given an input DTS string `in_str` and `file_name` it is read from, parse it into an internal
tree representation and track what "compatible" value each node has.
Given an input DTS string `in_str` and `file_name` it is read from, parse it to be
able to get `compatible` and `chosen` nodes.
For performance reasons, the whole tree isn't parsed into DTNode's.
If `preamble` is set to a non-empty string, prepend it to the read buffer.
"""
Expand All @@ -123,41 +135,65 @@ def __init__(

prepped = self._preprocess(self.raw_buffer, file_name, self.additional_includes) if preprocess else in_str

# make sure node labels and names are glued together and comments are removed,
# then parse with nested curly braces
self.root = DTNode(
"ROOT",
pp.nested_expr("{", "};")
.ignore("//" + pp.SkipTo(pp.lineEnd))
.ignore(pp.c_style_comment)
.parse_string("{ " + self._nodelabel_re.sub(r"\1:\2 {", prepped) + " };")[0],
)

# handle all node label-based overrides by appending their contents to the referred node's
override_nodes = [node for node in self.root.children if node.name.startswith("&")]
regular_nodes = [node for node in self.root.children if not node.name.startswith("&")]
for node in override_nodes:
if (label := node.name.removeprefix("&")) in self.root.label_refs:
self.root.label_refs[label].content += " " + node.content
self.root.children = regular_nodes
self.ts_buffer = prepped.encode("utf-8")
tree = Parser(TS_LANG).parse(self.ts_buffer)
self.root_nodes = self._find_root_ts_nodes(tree)
self.override_nodes = [DTNode(node, self.ts_buffer) for node in self._find_override_ts_nodes(tree)]
self.chosen_nodes = [DTNode(node, self.ts_buffer) for node in self._find_chosen_ts_nodes(tree)]

# parse through all nodes and hash according to "compatible" values
self.compatibles: defaultdict[str, list[DTNode]] = defaultdict(list)

def assign_compatibles(node: DTNode) -> None:
if m := self._compatible_re.search(node.content):
self.compatibles[m.group(1)].append(node)
for child in node.children:
assign_compatibles(child)
@staticmethod
def _find_root_ts_nodes(tree: Tree) -> list[Node]:
    """Return the nodes named `/` that sit directly under the document, sorted by start byte."""
    root_query = TS_LANG.query(
        """
        (document
            (node
                name: (identifier) @nodename
                (#eq? @nodename "/")
            ) @rootnode
        )
        """
    )
    found = root_query.captures(tree.root_node).get("rootnode", [])
    return sorted(found, key=lambda ts_node: ts_node.start_byte)

assign_compatibles(self.root)
@staticmethod
def _find_override_ts_nodes(tree: Tree) -> list[Node]:
    """
    Return the nodes directly under the document whose name is a label reference
    (i.e. `&label { ... };` style overrides), sorted by start byte.
    """
    override_query = TS_LANG.query(
        """
        (document
            (node
                name: (reference
                    label: (identifier)
                )
            ) @overridenode
        )
        """
    )
    found = override_query.captures(tree.root_node).get("overridenode", [])
    return sorted(found, key=lambda ts_node: ts_node.start_byte)

# find all chosen nodes and concatenate their content
self.chosen = DTNode("__chosen__", pp.ParseResults())
for root_child in self.root.children:
for node in root_child.children:
if node.name == "chosen":
self.chosen.content += " " + node.content
@staticmethod
def _find_chosen_ts_nodes(tree: Tree) -> list[Node]:
    """
    Return the nodes named `chosen`, sorted by start byte.

    The query is limited to a maximum start depth of 2 below the tree root, presumably
    so that only `chosen` nodes directly inside a root node are found.
    """
    chosen_query = TS_LANG.query(
        """
        (node
            name: (identifier) @nodename
            (#eq? @nodename "chosen")
        ) @chosennode
        """
    )
    depth_limited = chosen_query.set_max_start_depth(2)
    found = depth_limited.captures(tree.root_node).get("chosennode", [])
    return sorted(found, key=lambda ts_node: ts_node.start_byte)

@staticmethod
def _preprocess(in_str: str, file_name: str | None = None, additional_includes: list[str] | None = None) -> str:
Expand All @@ -179,11 +215,30 @@ def include_handler(*args): # type: ignore

def get_compatible_nodes(self, compatible_value: str) -> list[DTNode]:
    """
    Return a list of nodes that have the given compatible value.

    Every root node is searched for nodes that have a `compatible` property with the
    string value `compatible_value`. Matches are wrapped in DTNode's that are aware of
    the override nodes, and are returned sorted by their position in the source.
    """
    # the query does not depend on the root node being searched, so compile it only once;
    # NOTE: `compatible_value` is inserted verbatim, it must not contain quotes or backslashes
    compatible_query = TS_LANG.query(
        rf"""
        (node
            (property name: (identifier) @prop value: (string_literal) @propval)
            (#eq? @prop "compatible") (#eq? @propval "\"{compatible_value}\"")
        ) @node
        """
    )
    matches = chain.from_iterable(
        compatible_query.captures(root).get("node", []) for root in self.root_nodes
    )
    return sorted(
        (DTNode(match, self.ts_buffer, self.override_nodes) for match in matches),
        key=lambda dt_node: dt_node.node.start_byte,
    )

def get_chosen_property(self, property_name: str) -> str | None:
"""Return phandle for a given property in the /chosen node."""
return self.chosen.get_path(re.escape(property_name))
phandle = None
for node in self.chosen_nodes:
if (val := node.get_path(re.escape(property_name))) is not None:
phandle = val
return phandle

def preprocess_extra_data(self, data: str) -> str:
"""
Expand All @@ -202,9 +257,3 @@ def preprocess_extra_data(self, data: str) -> str:
"does not get modified by #define's"
)
return out[data_pos + len(self._custom_data_header) + 2 :]

def __repr__(self):
    """Concatenate the repr of every node of the tree, parents before their children, starting at the root."""
    pieces = []
    pending = [self.root]
    while pending:
        current = pending.pop()
        pieces.append(repr(current))
        # push children reversed so that they are popped in their original order
        pending.extend(reversed(list(current.children)))
    return "".join(pieces)
8 changes: 5 additions & 3 deletions keymap_drawer/parse/zmk.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@
import yaml

from keymap_drawer.config import ParseConfig
from keymap_drawer.keymap import ComboSpec, KeymapData, LayoutKey
from keymap_drawer.dts import DeviceTree
from keymap_drawer.keymap import ComboSpec, KeymapData, LayoutKey
from keymap_drawer.parse.parse import KeymapParser, ParseError

ZMK_LAYOUTS_PATH = Path(__file__).parent.parent.parent / "resources" / "zmk_keyboard_layouts.yaml"


class ZmkKeymapParser(KeymapParser):
"""Parser for ZMK devicetree keymaps, using C preprocessor and hacky pyparsing-based parsers."""
"""Parser for ZMK devicetree keymaps, using C preprocessor and tree-sitter-devicetree."""

_numbers_re = re.compile(r"N(UM(BER)?_)?(\d)")
_modifier_fn_to_std = {
Expand Down Expand Up @@ -142,6 +142,8 @@ def get_behavior_bindings(compatible_value: str, n_bindings: int) -> dict[str, l
raise ParseError(f'Cannot parse bindings for behavior "{node.name}"')
if node.label is None:
raise ParseError(f'Cannot find label for behavior "{node.name}"')
if len(bindings) < n_bindings:
raise ParseError(f'Could not find {n_bindings} bindings in definition of behavior "{node.name}"')
out[f"&{node.label}"] = bindings[:n_bindings]
return out

Expand Down Expand Up @@ -180,7 +182,7 @@ def _get_layers(self, dts: DeviceTree) -> dict[str, list[LayoutKey]]:
layers: dict[str, list[LayoutKey]] = {}
for layer_ind, node in enumerate(layer_nodes):
layer_name = self.layer_names[layer_ind]
if bindings := node.get_phandle_array(r"bindings"):
if bindings := node.get_phandle_array("bindings"):
layers[layer_name] = []
for ind, binding in enumerate(bindings):
try:
Expand Down
6 changes: 1 addition & 5 deletions keymap_drawer/physical_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,12 +537,8 @@ def parse_binding_params(bindings):
defined_layouts: dict[str | None, list[str] | None]
if nodes := dts.get_compatible_nodes("zmk,physical-layout"):
defined_layouts = {node.label or node.name: node.get_phandle_array("keys") for node in nodes}
elif keys_array := dts.root.get_phandle_array("keys"):
defined_layouts = {None: keys_array}
else:
raise ValueError(
'No `compatible = "zmk,physical-layout"` nodes nor a single `keys` property found in DTS layout'
)
raise ValueError('No `compatible = "zmk,physical-layout"` nodes found in DTS layout')

layouts = {}
for layout_name, position_bindings in defined_layouts.items():
Expand Down
Loading

0 comments on commit 20fd314

Please sign in to comment.