diff --git a/rosdistro_reviewer/yaml_lines.py b/rosdistro_reviewer/yaml_lines.py
new file mode 100644
index 0000000..076a156
--- /dev/null
+++ b/rosdistro_reviewer/yaml_lines.py
@@ -0,0 +1,122 @@
+# Copyright 2024 Open Source Robotics Foundation, Inc.
+# Licensed under the Apache License, Version 2.0
+
+import yaml
+
+
+def _extend_lines(data, items):
+    """
+    Grow the '__lines__' of a container to cover the given items.
+
+    Only the end of the range is moved, and only ever to a later line.
+    Items which have no '__lines__' attribute are ignored.
+    """
+    stop = data.__lines__.stop
+    for item in items:
+        item_lines = getattr(item, '__lines__', None)
+        if item_lines is not None and item_lines.stop > stop:
+            stop = item_lines.stop
+    data.__lines__ = range(data.__lines__.start, stop)
+
+
+class AnnotatedSafeLoader(yaml.SafeLoader):
+    """
+    YAML loader that adds '__lines__' attributes to some of the parsed data.
+
+    This extension of the PyYAML SafeLoader replaces some basic types with
+    derived types that include a '__lines__' attribute to determine where
+    the deserialized data can be found in the YAML file it was parsed from.
+
+    Each '__lines__' value is a 'range' of line numbers which start at 1.
+    """
+
+    class AnnotatedDict(dict):
+        """Implementation of 'dict' with '__lines__' attribute."""
+
+        __slots__ = ('__lines__',)
+
+        def __init__(self, *args, **kwargs):  # noqa: D107
+            super().__init__(*args, **kwargs)
+
+    class AnnotatedList(list):
+        """Implementation of 'list' with '__lines__' attribute."""
+
+        __slots__ = ('__lines__',)
+
+        def __init__(self, *args, **kwargs):  # noqa: D107
+            super().__init__(*args, **kwargs)
+
+    class AnnotatedStr(str):
+        """Implementation of 'str' with '__lines__' attribute."""
+
+        __slots__ = ('__lines__',)
+
+        def __new__(cls, *args, **kwargs):  # noqa: D102
+            return str.__new__(cls, *args, **kwargs)
+
+    def compose_node(self, parent, index):
+        """Compose the next node and record the lines of its event."""
+        if self.check_event(yaml.events.AliasEvent):
+            # An alias resolves to the node composed at its anchor. Keep
+            # the lines recorded there rather than replacing them with
+            # the position of the alias.
+            return super().compose_node(parent, index)
+        event = self.peek_event()
+        start_line = event.start_mark.line + 1
+        end_line = event.end_mark.line + 1
+        if end_line <= start_line:
+            # Always cover at least the line the node starts on.
+            end_line = start_line + 1
+        node = super().compose_node(parent, index)
+        node.__lines__ = range(start_line, end_line)
+        return node
+
+    def construct_annotated_map(self, node):
+        """Construct an 'AnnotatedDict' from a mapping node."""
+        data = AnnotatedSafeLoader.AnnotatedDict()
+        data.__lines__ = node.__lines__
+        # Hand out the empty container first so that it can be referenced
+        # while its content is still being constructed.
+        yield data
+        value = self.construct_mapping(node, deep=True)
+        # Stretch the range to the last line used by any key or value.
+        # The order of the items does not matter for that, so there is
+        # no need for reversed(), which rejects dict views before
+        # Python 3.8.
+        _extend_lines(data, value.keys())
+        _extend_lines(data, value.values())
+        data.update(value)
+
+    def construct_annotated_seq(self, node):
+        """Construct an 'AnnotatedList' from a sequence node."""
+        data = AnnotatedSafeLoader.AnnotatedList()
+        data.__lines__ = node.__lines__
+        yield data
+        value = self.construct_sequence(node, deep=True)
+        _extend_lines(data, value)
+        data.extend(value)
+
+    def construct_annotated_str(self, node):
+        """Construct an 'AnnotatedStr' from a string node."""
+        data = self.construct_yaml_str(node)
+        data = AnnotatedSafeLoader.AnnotatedStr(data)
+        data.__lines__ = node.__lines__
+        return data
+
+
+AnnotatedSafeLoader.add_constructor(
+    'tag:yaml.org,2002:map', AnnotatedSafeLoader.construct_annotated_map)
+AnnotatedSafeLoader.add_constructor(
+    'tag:yaml.org,2002:seq', AnnotatedSafeLoader.construct_annotated_seq)
+AnnotatedSafeLoader.add_constructor(
+    'tag:yaml.org,2002:str', AnnotatedSafeLoader.construct_annotated_str)
+
+yaml.add_representer(
+    AnnotatedSafeLoader.AnnotatedDict,
+    yaml.representer.SafeRepresenter.represent_dict)
+yaml.add_representer(
+    AnnotatedSafeLoader.AnnotatedList,
+    yaml.representer.SafeRepresenter.represent_list)
+yaml.add_representer(
+    AnnotatedSafeLoader.AnnotatedStr,
+    yaml.representer.SafeRepresenter.represent_str)
diff --git a/setup.cfg b/setup.cfg
index 1aa83a1..3dbbef5 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -22,6 +22,7 @@ python_requires = >=3.6
 install_requires =
   GitPython
   unidiff
+  PyYAML
 packages = find:
 zip_safe = true
 
@@ -31,6 +32,7 @@ test =
   mypy
   pytest
   scspell3k>=2.2
+  types-PyYAML
 
 [options.packages.find]
 exclude =
diff --git a/test/resources/simple.yaml b/test/resources/simple.yaml
new file mode 100644
index 0000000..56fb47b
--- /dev/null
+++ 
b/test/resources/simple.yaml
@@ -0,0 +1,13 @@
+---
+foo:
+  bar: baz
+  qux: [quux]
+  corge:
+    - grault
+    - garply
+  waldo: >-
+    Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod
+    tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim
+    veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea
+    commodo consequat.
+  fred: null
diff --git a/test/spell_check.words b/test/spell_check.words
index ddc1c4f..9e66353 100644
--- a/test/spell_check.words
+++ b/test/spell_check.words
@@ -1,19 +1,26 @@
 addfinalizer
 apache
+corge
+deserialized
 diffs
+fred
 https
 iterdir
 linter
 mktemp
 mypy
+noqa
 patchset
 pathlib
 pycqa
 pytest
 rangeify
+representer
 returncode
 rosdistro
 scspell
 setuptools
 thomas
 unidiff
+waldo
+yaml
diff --git a/test/test_yaml_lines.py b/test/test_yaml_lines.py
new file mode 100644
index 0000000..1001f32
--- /dev/null
+++ b/test/test_yaml_lines.py
@@ -0,0 +1,65 @@
+# Copyright 2024 Open Source Robotics Foundation, Inc.
+# Licensed under the Apache License, Version 2.0
+
+from pathlib import Path
+
+from rosdistro_reviewer.yaml_lines import AnnotatedSafeLoader
+import yaml
+
+
+def _get_key_and_val(data, key):
+    """
+    Find the key object stored in a mapping along with its value.
+
+    Unlike a plain lookup, this returns the stored key itself so that the
+    attributes of that key can be checked. Both results are None if no
+    key of the mapping compares equal to the one requested.
+    """
+    found = ((k, v) for k, v in data.items() if k == key)
+    return next(found, (None, None))
+
+
+def test_line_numbers() -> None:
+    """Check the '__lines__' of each key and value in 'simple.yaml'."""
+    yaml_path = Path(__file__).parent / 'resources' / 'simple.yaml'
+    with yaml_path.open('r') as stream:
+        loaded = yaml.load(stream, Loader=AnnotatedSafeLoader)
+
+    foo, foo_val = _get_key_and_val(loaded, 'foo')
+    assert foo
+    assert foo.__lines__ == range(2, 3)
+    assert hasattr(foo_val, '__getitem__')
+    assert foo_val.__lines__ == range(3, 14)
+
+    bar, bar_val = _get_key_and_val(foo_val, 'bar')
+    assert bar
+    assert bar.__lines__ == range(3, 4)
+    assert bar_val == 'baz'
+    assert bar_val.__lines__ == range(3, 4)
+
+    qux, qux_val = _get_key_and_val(foo_val, 'qux')
+    assert qux
+    assert qux.__lines__ == range(4, 5)
+    assert hasattr(qux_val, '__iter__')
+    assert qux_val.__lines__ == range(4, 5)
+    for entry in qux_val:
+        assert entry.__lines__ == range(4, 5)
+
+    corge, corge_val = _get_key_and_val(foo_val, 'corge')
+    assert corge
+    assert corge.__lines__ == range(5, 6)
+    assert hasattr(corge_val, '__iter__')
+    assert corge_val.__lines__ == range(6, 8)
+    for line, entry in enumerate(corge_val, start=6):
+        assert entry.__lines__ == range(line, line + 1)
+
+    waldo, waldo_val = _get_key_and_val(foo_val, 'waldo')
+    assert waldo
+    assert waldo.__lines__ == range(8, 9)
+    assert len(waldo_val) == 231
+    assert waldo_val.__lines__ == range(8, 13)
+
+    fred, fred_val = _get_key_and_val(foo_val, 'fred')
+    assert fred
+    assert fred.__lines__ == range(13, 14)
+    assert fred_val is None