This repository has been archived by the owner on Mar 1, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 736
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add XML loader * Fix unit test, llama_hub link * Fix llama_hub link in readme
- Loading branch information
1 parent
d7b95f7
commit 57af091
Showing
7 changed files
with
189 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# XML Loader | ||
|
||
This loader extracts the text from a local XML file. A single local file is passed in each time you call `load_data`. | ||
|
||
## Usage | ||
|
||
To use this loader, you need to pass in a `Path` to a local file. | ||
|
||
```python | ||
from pathlib import Path | ||
from llama_index import download_loader | ||
|
||
XMLReader = download_loader("XMLReader") | ||
|
||
loader = XMLReader() | ||
documents = loader.load_data(file=Path('../example.xml')) | ||
``` | ||
|
||
This loader is designed to be used as a way to load data into [LlamaIndex](https://github.com/run-llama/llama_index/tree/main/llama_index) and/or subsequently used as a Tool in a [LangChain](https://github.com/hwchase17/langchain) Agent. See [here](https://github.com/run-llama/llama-hub/tree/main/llama_hub) for examples. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
"""Init file.""" | ||
from llama_hub.file.xml.base import ( | ||
XMLReader, | ||
) | ||
|
||
__all__ = ["XMLReader"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
"""JSON Reader.""" | ||
|
||
import re | ||
from pathlib import Path | ||
from typing import Dict, List, Optional | ||
|
||
from llama_index.readers.base import BaseReader | ||
from llama_index.readers.schema.base import Document | ||
import xml.etree.ElementTree as ET | ||
|
||
|
||
def _get_leaf_nodes_up_to_level(root: ET.Element, level: int) -> List[ET.Element]: | ||
"""Get collection of nodes up to certain level including leaf nodes | ||
Args: | ||
root (ET.Element): XML Root Element | ||
level (int): Levels to traverse in the tree | ||
Returns: | ||
List[ET.Element]: List of target nodes | ||
""" | ||
|
||
def traverse(current_node, current_level): | ||
if len(current_node) == 0 or level == current_level: | ||
# Keep leaf nodes and target level nodes | ||
nodes.append(current_node) | ||
elif current_level < level: | ||
# Move to the next level | ||
for child in current_node: | ||
traverse(child, current_level + 1) | ||
|
||
nodes = [] | ||
traverse(root, 0) | ||
return nodes | ||
|
||
|
||
class XMLReader(BaseReader): | ||
"""XML reader. | ||
Reads XML documents with options to help suss out relationships between nodes. | ||
Args: | ||
tree_level_split (int): From which level in the xml tree we split documents, | ||
the default level is the root which is level 0 | ||
""" | ||
|
||
def __init__(self, tree_level_split: Optional[int] = 0) -> None: | ||
"""Initialize with arguments.""" | ||
super().__init__() | ||
self.tree_level_split = tree_level_split | ||
|
||
def _parse_xmlelt_to_document( | ||
self, root: ET.Element, extra_info: Optional[Dict] = None | ||
) -> List[Document]: | ||
"""Parse the xml object into a list of Documents. | ||
Args: | ||
root: The XML Element to be converted. | ||
extra_info (Optional[Dict]): Additional information. Default is None. | ||
Returns: | ||
Document: The documents. | ||
""" | ||
nodes = _get_leaf_nodes_up_to_level(root, self.tree_level_split) | ||
documents = [] | ||
for node in nodes: | ||
content = ET.tostring(node, encoding="utf8").decode("utf-8") | ||
content = re.sub(r"^<\?xml.*", "", content) | ||
content = content.strip() | ||
documents.append(Document(text=content, extra_info=extra_info or {})) | ||
|
||
return documents | ||
|
||
def load_data( | ||
self, | ||
file: Path, | ||
extra_info: Optional[Dict] = None, | ||
) -> List[Document]: | ||
"""Load data from the input file. | ||
Args: | ||
file (Path): Path to the input file. | ||
extra_info (Optional[Dict]): Additional information. Default is None. | ||
Returns: | ||
List[Document]: List of documents. | ||
""" | ||
if not isinstance(file, Path): | ||
file = Path(file) | ||
|
||
tree = ET.parse(file) | ||
documents = self._parse_xmlelt_to_document(tree.getroot(), extra_info) | ||
|
||
return documents |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import pytest | ||
|
||
from llama_hub.file.xml import XMLReader | ||
import xml.etree.ElementTree as ET | ||
|
||
# Sample XML data for testing | ||
SAMPLE_XML = """<?xml version="1.0" encoding="UTF-8"?> | ||
<data> | ||
<item type="fruit"> | ||
<name>Apple</name> | ||
<color>Red</color> | ||
<price>1.20</price> | ||
</item> | ||
<item type="vegetable"> | ||
<name>Carrot</name> | ||
<color>Orange</color> | ||
<price>0.50</price> | ||
</item> | ||
<item type="fruit"> | ||
<name>Banana</name> | ||
<color>Yellow</color> | ||
<price>0.30</price> | ||
</item> | ||
<company> | ||
<name>Fresh Produce Ltd.</name> | ||
<address> | ||
<street>123 Green Lane</street> | ||
<city>Garden City</city> | ||
<state>Harvest</state> | ||
<zip>54321</zip> | ||
</address> | ||
</company> | ||
</data>""" | ||
|
||
|
||
# Fixture to create a temporary XML file | ||
@pytest.fixture | ||
def xml_file(tmp_path): | ||
file = tmp_path / "test.xml" | ||
with open(file, "w") as f: | ||
f.write(SAMPLE_XML) | ||
return file | ||
|
||
|
||
def test_xml_reader_init(): | ||
reader = XMLReader(tree_level_split=2) | ||
assert reader.tree_level_split == 2 | ||
|
||
|
||
def test_parse_xml_to_document(): | ||
reader = XMLReader(1) | ||
root = ET.fromstring(SAMPLE_XML) | ||
documents = reader._parse_xmlelt_to_document(root) | ||
assert "Fresh Produce Ltd." in documents[-1].text | ||
assert "fruit" in documents[0].text | ||
|
||
|
||
def test_load_data_xml(xml_file): | ||
reader = XMLReader() | ||
|
||
documents = reader.load_data(xml_file) | ||
assert len(documents) == 1 | ||
assert "Apple" in documents[0].text | ||
assert "Garden City" in documents[0].text |