Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend import functionality #22

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion src/hsd/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from hsd.common import HSD_ATTRIB_NAME, np, ATTRIB_SUFFIX, HSD_ATTRIB_SUFFIX, HsdError,\
QUOTING_CHARS, SPECIAL_CHARS
from hsd.eventhandler import HsdEventHandler, HsdEventPrinter
from hsd.interrupts import Interrupt

_ItemType = Union[float, complex, int, bool, str]

Expand Down Expand Up @@ -67,7 +68,7 @@ class HsdDictBuilder(HsdEventHandler):

Args:
flatten_data: Whether multiline data in the HSD input should be
flattened into a single list. Othewise a list of lists is created, with one list for
flattened into a single list. Otherwise a list of lists is created, with one list for
every line (default).
lower_tag_names: Whether tag names should be all converted to lower case (to ease case
insensitive processing). Default: False. If set and include_hsd_attribs is also set,
Expand Down Expand Up @@ -161,6 +162,13 @@ def add_text(self, text):
self._data = self._text_to_data(text)


def add_interrupt(self, interrupt):
if self._curblock or self._data is not None:
msg = "Data appeared in an invalid context"
raise HsdError(msg)
self._data = interrupt


def _text_to_data(self, txt: str) -> _DataType:
data = []
for line in txt.split("\n"):
Expand Down Expand Up @@ -242,6 +250,11 @@ def walk(self, dictobj):
self.walk(item)
self._eventhandler.close_tag(key)

elif isinstance(value, Interrupt):
self._eventhandler.open_tag(key, attrib, hsdattrib)
self._eventhandler.add_interrupt(value)
self._eventhandler.close_tag(key)

else:
self._eventhandler.open_tag(key, attrib, hsdattrib)
self._eventhandler.add_text(_to_text(value))
Expand Down
13 changes: 13 additions & 0 deletions src/hsd/eventhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from abc import ABC, abstractmethod
from typing import Optional
from hsd.interrupts import Interrupt


class HsdEventHandler(ABC):
Expand Down Expand Up @@ -43,6 +44,13 @@ def add_text(self, text: str):
text: Text in the current tag.
"""

@abstractmethod
def add_interrupt(self, interrupt: Interrupt):
"""Adds interrupts to the current tag.

Args:
interrupt: Instance of the Interrupt class or its children.
"""


class HsdEventPrinter(HsdEventHandler):
Expand Down Expand Up @@ -75,3 +83,8 @@ def close_tag(self, tagname: str):
def add_text(self, text: str):
indentstr = self._indentlevel * self._indentstr
print(f"{indentstr}Received text: {text}")

def add_interrupt(self, interrupt: Interrupt):
indentstr = self._indentlevel * self._indentstr
print(f"{indentstr}Received interrupt: type '{type(interrupt)}' to "
f"file '{interrupt.file}'")
33 changes: 30 additions & 3 deletions src/hsd/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,16 @@
"""

from typing import List, TextIO, Union
from hsd.common import HSD_ATTRIB_EQUAL, HSD_ATTRIB_NAME
from hsd.common import HSD_ATTRIB_EQUAL, HSD_ATTRIB_NAME, HsdError
from hsd.eventhandler import HsdEventHandler
from hsd.interrupts import Interrupt, IncludeHsd, IncludeText


_INDENT_STR = " "


class HsdFormatter(HsdEventHandler):
"""Implements an even driven HSD formatter.
"""Implements an event driven HSD formatter.

Args:
fobj: File like object to write the formatted output to.
Expand Down Expand Up @@ -106,10 +107,36 @@ def add_text(self, text: str):
self._followed_by_equal[-1] = True
else:
self._indent_level += 1
indentstr = self._indent_level * _INDENT_STR
indentstr = self._indent_level * _INDENT_STR
self._fobj.write(f" {{\n{indentstr}")
text = text.replace("\n", "\n" + indentstr)

self._fobj.write(text)
self._fobj.write("\n")
self._nr_children[-1] += 1


def add_interrupt(self, interrupt: Interrupt):

if isinstance(interrupt, IncludeHsd):
operator = "<<+"
elif isinstance(interrupt, IncludeText):
operator = "<<<"
else:
msg = ("The 'HsdFormatter' does not support Interrupts of "
f"type '{type(interrupt)}' !")
raise HsdError(msg)

equal = self._followed_by_equal[-1]
if equal:
self._fobj.write(" = ")
self._followed_by_equal[-1] = True
else:
self._indent_level += 1
indentstr = self._indent_level * _INDENT_STR
self._fobj.write(f" {{\n{indentstr}")

text = operator + ' "' + interrupt.file + '"'
self._fobj.write(text)
self._fobj.write("\n")
self._nr_children[-1] += 1
27 changes: 27 additions & 0 deletions src/hsd/interrupts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#--------------------------------------------------------------------------------------------------#
# hsd-python: package for manipulating HSD-formatted data in Python #
# Copyright (C) 2011 - 2023 DFTB+ developers group #
# Licensed under the BSD 2-clause license. #
#--------------------------------------------------------------------------------------------------#
#
"""
Contains hsd interrupts
"""

from hsd.common import unquote


class Interrupt:
"""General class for interrupts"""

def __init__(self, file):
self.file = unquote(file.strip())


class IncludeText(Interrupt):
"""class for dealing with text interrupts"""
pass

class IncludeHsd(Interrupt):
"""class for dealing with hsd interrupts"""
pass
14 changes: 9 additions & 5 deletions src/hsd/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@


def load(hsdfile: Union[TextIO, str], lower_tag_names: bool = False,
include_hsd_attribs: bool = False, flatten_data: bool = False) -> dict:
include_hsd_attribs: bool = False, flatten_data: bool = False,
include_file: bool = True) -> dict:
"""Loads a file with HSD-formatted data into a Python dictionary
Args:
Expand All @@ -36,6 +37,7 @@ def load(hsdfile: Union[TextIO, str], lower_tag_names: bool = False,
flatten_data: Whether multiline data in the HSD input should be
flattened into a single list. Othewise a list of lists is created,
with one list for every line (default).
include_file: Whether files via "<<<"/"<<+" should be included or not
Returns:
Dictionary representing the HSD data.
Expand All @@ -45,7 +47,7 @@ def load(hsdfile: Union[TextIO, str], lower_tag_names: bool = False,
"""
dictbuilder = HsdDictBuilder(lower_tag_names=lower_tag_names, flatten_data=flatten_data,
include_hsd_attribs=include_hsd_attribs)
parser = HsdParser(eventhandler=dictbuilder)
parser = HsdParser(eventhandler=dictbuilder, include_file=include_file)
if isinstance(hsdfile, str):
with open(hsdfile, "r") as hsddescr:
parser.parse(hsddescr)
Expand All @@ -56,8 +58,8 @@ def load(hsdfile: Union[TextIO, str], lower_tag_names: bool = False,

def load_string(
hsdstr: str, lower_tag_names: bool = False,
include_hsd_attribs: bool = False, flatten_data: bool = False
) -> dict:
include_hsd_attribs: bool = False, flatten_data: bool = False,
include_file: bool = True) -> dict:
"""Loads a string with HSD-formatted data into a Python dictionary.
Args:
Expand All @@ -75,6 +77,7 @@ def load_string(
flatten_data: Whether multiline data in the HSD input should be
flattened into a single list. Othewise a list of lists is created,
with one list for every line (default).
include_file: Whether files via "<<<"/"<<+" should be included or not
Returns:
Dictionary representing the HSD data.
Expand Down Expand Up @@ -130,7 +133,8 @@ def load_string(
"""
fobj = io.StringIO(hsdstr)
return load(fobj, lower_tag_names, include_hsd_attribs, flatten_data)
return load(fobj, lower_tag_names, include_hsd_attribs, flatten_data,
include_file)


def dump(data: dict, hsdfile: Union[TextIO, str],
Expand Down
20 changes: 16 additions & 4 deletions src/hsd/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import Optional, TextIO, Union
from hsd import common
from hsd.eventhandler import HsdEventHandler, HsdEventPrinter

from hsd.interrupts import IncludeHsd, IncludeText

SYNTAX_ERROR = 1
UNCLOSED_TAG_ERROR = 2
Expand Down Expand Up @@ -50,11 +50,13 @@ class HsdParser:
{'Temperature': 100, 'Temperature.attrib': 'Kelvin'}}}}}
"""

def __init__(self, eventhandler: Optional[HsdEventHandler] = None):
def __init__(self, eventhandler: Optional[HsdEventHandler] = None,
include_file: bool = True):
"""Initializes the parser.

Args:
eventhandler: Instance of the HsdEventHandler class or its children.
include_file: Whether files via "<<<"/"<<+" should be included or not
"""
if eventhandler is None:
self._eventhandler = HsdEventPrinter()
Expand All @@ -75,6 +77,7 @@ def __init__(self, eventhandler: Optional[HsdEventHandler] = None):
self._has_child = True # Whether current node has a child already
self._has_text = False # whether current node contains text already
self._oldbefore = "" # buffer for tagname
self._include_file = include_file # Whether files via "<<<"/"<<+" should be included or not


def parse(self, fobj: Union[TextIO, str]):
Expand Down Expand Up @@ -216,10 +219,19 @@ def _parse(self, line):
if txtinc:
self._text("".join(self._buffer) + before)
self._buffer = []
self._eventhandler.add_text(self._include_txt(after[2:]))
if self._include_file:
text = self._include_txt(after[2:])
self._eventhandler.add_text(text)
else:
interrupt = IncludeText(after[2:])
self._eventhandler.add_interrupt(interrupt)
break
if hsdinc:
self._include_hsd(after[2:])
if self._include_file:
self._include_hsd(after[2:])
else:
interrupt = IncludeHsd(after[2:])
self._eventhandler.add_interrupt(interrupt)
break
self._buffer.append(before + sign)

Expand Down
2 changes: 2 additions & 0 deletions test/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ def close_tag(self, tagname):
def add_text(self, text):
self.events.append((_ADD_TEXT_EVENT, text))

def add_interrupt(self, interrupt):
pass

@pytest.mark.parametrize(
"hsd_input,expected_events",
Expand Down