diff --git a/tests/functional/test_chunkers.py b/tests/functional/test_chunkers.py
index d364ac48..aecbf9ce 100644
--- a/tests/functional/test_chunkers.py
+++ b/tests/functional/test_chunkers.py
@@ -1,6 +1,6 @@
 # Standard
-from pathlib import Path
 import os
+from pathlib import Path
 
 # Third Party
 import pytest
@@ -8,56 +8,43 @@
 # First Party
 from instructlab.sdg.utils.chunkers import DocumentChunker
 
+# Constants for Test Directory and Test Documents
 TEST_DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "testdata")
+TEST_DOCUMENTS = {
+    "pdf": "sample_documents/phoenix.pdf",
+    "md": "sample_documents/phoenix.md",
+}
+
+@pytest.fixture(scope="module")
+def test_paths():
+    """Fixture to return paths to test documents."""
+    return {doc_type: Path(os.path.join(TEST_DATA_DIR, path)) for doc_type, path in TEST_DOCUMENTS.items()}
 
 
 @pytest.fixture
 def tokenizer_model_name():
+    """Fixture to return the path to the tokenizer model."""
     return os.path.join(TEST_DATA_DIR, "models/instructlab/granite-7b-lab")
 
-
-def test_chunk_pdf(tmp_path, tokenizer_model_name):
-    pdf_path = Path(os.path.join(TEST_DATA_DIR, "sample_documents", "phoenix.pdf"))
-    leaf_node = [
-        {
-            "documents": ["Lorem ipsum"],
-            "filepaths": [pdf_path],
-            "taxonomy_path": "knowledge",
-        }
-    ]
+@pytest.mark.parametrize("doc_type, expected_chunks, contains_text", [
+    ("pdf", 9, "Phoenix is a minor constellation"),
+    ("md", 7, None),  # no specific text to check for in the Markdown sample
+])
+def test_chunk_documents(tmp_path, tokenizer_model_name, test_paths, doc_type, expected_chunks, contains_text):
+    """
+    Generalized test function for chunking documents.
+    """
+    document_path = test_paths[doc_type]
     chunker = DocumentChunker(
-        document_paths=[pdf_path],
+        document_paths=[document_path],
         output_dir=tmp_path,
         tokenizer_model_name=tokenizer_model_name,
         server_ctx_size=4096,
         chunk_word_count=500,
     )
     chunks = chunker.chunk_documents()
-    assert len(chunks) > 9
-    assert "Phoenix is a minor constellation" in chunks[0]
+    assert len(chunks) > expected_chunks
+    if contains_text:
+        assert contains_text in chunks[0]
     for chunk in chunks:
-        # inexact sanity-checking of chunk max length
         assert len(chunk) < 2500
-
-
-def test_chunk_md(tmp_path, tokenizer_model_name):
-    markdown_path = Path(os.path.join(TEST_DATA_DIR, "sample_documents", "phoenix.md"))
-    leaf_node = [
-        {
-            "documents": [markdown_path.read_text(encoding="utf-8")],
-            "filepaths": [markdown_path],
-            "taxonomy_path": "knowledge",
-        }
-    ]
-    chunker = DocumentChunker(
-        document_paths=[markdown_path],
-        output_dir=tmp_path,
-        tokenizer_model_name=tokenizer_model_name,
-        server_ctx_size=4096,
-        chunk_word_count=500,
-    )
-    chunks = chunker.chunk_documents()
-    assert len(chunks) > 7
-    for chunk in chunks:
-        # inexact sanity-checking of chunk max length
-        assert len(chunk) < 2500
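
For reviewers who want to exercise the chunker outside pytest, here is a minimal standalone sketch of the DocumentChunker call the test wraps. It assumes the same constructor arguments the test uses; the paths are illustrative placeholders, not guaranteed repo locations.

    # Minimal sketch of the DocumentChunker usage exercised by the test above.
    # Paths are illustrative; point them at real documents and a real tokenizer model.
    from pathlib import Path

    from instructlab.sdg.utils.chunkers import DocumentChunker

    chunker = DocumentChunker(
        document_paths=[Path("testdata/sample_documents/phoenix.pdf")],
        output_dir=Path("/tmp/chunk-output"),
        tokenizer_model_name="testdata/models/instructlab/granite-7b-lab",
        server_ctx_size=4096,    # context window available on the serving backend
        chunk_word_count=500,    # target word budget per chunk
    )
    chunks = chunker.chunk_documents()  # list of chunk strings; the test bounds their length

Running `pytest tests/functional/test_chunkers.py -v` exercises both parametrized cases; `-k pdf` or `-k md` selects a single one by its generated test ID.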