diff --git a/.semversioner/next-release/patch-20241220191518597340.json b/.semversioner/next-release/patch-20241220191518597340.json new file mode 100644 index 0000000000..d410fe7ce7 --- /dev/null +++ b/.semversioner/next-release/patch-20241220191518597340.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "unit tests for text_splitting" +} diff --git a/graphrag/index/operations/chunk_text/strategies.py b/graphrag/index/operations/chunk_text/strategies.py index 3fc8fc6f2f..2d6d1d870a 100644 --- a/graphrag/index/operations/chunk_text/strategies.py +++ b/graphrag/index/operations/chunk_text/strategies.py @@ -10,7 +10,10 @@ from graphrag.config.models.chunking_config import ChunkingConfig from graphrag.index.operations.chunk_text.typing import TextChunk -from graphrag.index.text_splitting.text_splitting import Tokenizer +from graphrag.index.text_splitting.text_splitting import ( + Tokenizer, + split_multiple_texts_on_tokens, +) from graphrag.logger.progress import ProgressTicker @@ -31,7 +34,7 @@ def encode(text: str) -> list[int]: def decode(tokens: list[int]) -> str: return enc.decode(tokens) - return _split_text_on_tokens( + return split_multiple_texts_on_tokens( input, Tokenizer( chunk_overlap=chunk_overlap, @@ -43,44 +46,6 @@ def decode(tokens: list[int]) -> str: ) -# Adapted from - https://github.com/langchain-ai/langchain/blob/77b359edf5df0d37ef0d539f678cf64f5557cb54/libs/langchain/langchain/text_splitter.py#L471 -# So we could have better control over the chunking process -def _split_text_on_tokens( - texts: list[str], enc: Tokenizer, tick: ProgressTicker -) -> list[TextChunk]: - """Split incoming text and return chunks.""" - result = [] - mapped_ids = [] - - for source_doc_idx, text in enumerate(texts): - encoded = enc.encode(text) - tick(1) - mapped_ids.append((source_doc_idx, encoded)) - - input_ids: list[tuple[int, int]] = [ - (source_doc_idx, id) for source_doc_idx, ids in mapped_ids for id in ids - ] - - start_idx = 0 - cur_idx = min(start_idx + enc.tokens_per_chunk, len(input_ids)) - chunk_ids = input_ids[start_idx:cur_idx] - while start_idx < len(input_ids): - chunk_text = enc.decode([id for _, id in chunk_ids]) - doc_indices = list({doc_idx for doc_idx, _ in chunk_ids}) - result.append( - TextChunk( - text_chunk=chunk_text, - source_doc_indices=doc_indices, - n_tokens=len(chunk_ids), - ) - ) - start_idx += enc.tokens_per_chunk - enc.chunk_overlap - cur_idx = min(start_idx + enc.tokens_per_chunk, len(input_ids)) - chunk_ids = input_ids[start_idx:cur_idx] - - return result - - def run_sentences( input: list[str], _config: ChunkingConfig, tick: ProgressTicker ) -> Iterable[TextChunk]: diff --git a/graphrag/index/text_splitting/text_splitting.py b/graphrag/index/text_splitting/text_splitting.py index 2f6201cab7..1632904637 100644 --- a/graphrag/index/text_splitting/text_splitting.py +++ b/graphrag/index/text_splitting/text_splitting.py @@ -3,19 +3,18 @@ """A module containing the 'Tokenizer', 'TextSplitter', 'NoopTextSplitter' and 'TokenTextSplitter' models.""" -import json import logging from abc import ABC, abstractmethod from collections.abc import Callable, Collection, Iterable from dataclasses import dataclass -from enum import Enum from typing import Any, Literal, cast import pandas as pd import tiktoken import graphrag.config.defaults as defs -from graphrag.index.utils.tokens import num_tokens_from_string +from graphrag.index.operations.chunk_text.typing import TextChunk +from graphrag.logger.progress import ProgressTicker EncodedText = list[int] DecodeFn = Callable[[EncodedText], str] @@ -123,10 +122,10 @@ def num_tokens(self, text: str) -> int: def split_text(self, text: str | list[str]) -> list[str]: """Split text method.""" - if cast("bool", pd.isna(text)) or text == "": - return [] if isinstance(text, list): text = " ".join(text) + elif cast("bool", pd.isna(text)) or text == "": + return [] if not isinstance(text, str): msg = f"Attempting to split a non-string value, actual is {type(text)}" raise TypeError(msg) @@ -138,108 +137,57 @@ def split_text(self, text: str | list[str]) -> list[str]: encode=lambda text: self.encode(text), ) - return split_text_on_tokens(text=text, tokenizer=tokenizer) - - -class TextListSplitterType(str, Enum): - """Enum for the type of the TextListSplitter.""" - - DELIMITED_STRING = "delimited_string" - JSON = "json" - - -class TextListSplitter(TextSplitter): - """Text list splitter class definition.""" - - def __init__( - self, - chunk_size: int, - splitter_type: TextListSplitterType = TextListSplitterType.JSON, - input_delimiter: str | None = None, - output_delimiter: str | None = None, - model_name: str | None = None, - encoding_name: str | None = None, - ): - """Initialize the TextListSplitter with a chunk size.""" - # Set the chunk overlap to 0 as we use full strings - super().__init__(chunk_size, chunk_overlap=0) - self._type = splitter_type - self._input_delimiter = input_delimiter - self._output_delimiter = output_delimiter or "\n" - self._length_function = lambda x: num_tokens_from_string( - x, model=model_name, encoding_name=encoding_name - ) - - def split_text(self, text: str | list[str]) -> Iterable[str]: - """Split a string list into a list of strings for a given chunk size.""" - if not text: - return [] - - result: list[str] = [] - current_chunk: list[str] = [] - - # Add the brackets - current_length: int = self._length_function("[]") + return split_single_text_on_tokens(text=text, tokenizer=tokenizer) - # Input should be a string list joined by a delimiter - string_list = self._load_text_list(text) - if len(string_list) == 1: - return string_list - - for item in string_list: - # Count the length of the item and add comma - item_length = self._length_function(f"{item},") +def split_single_text_on_tokens(text: str, tokenizer: Tokenizer) -> list[str]: + """Split a single text and return chunks using the tokenizer.""" + result = [] + input_ids = tokenizer.encode(text) - if current_length + item_length > self._chunk_size: - if current_chunk and len(current_chunk) > 0: - # Add the current chunk to the result - self._append_to_result(result, current_chunk) + start_idx = 0 + cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids)) + chunk_ids = input_ids[start_idx:cur_idx] - # Start a new chunk - current_chunk = [item] - # Add 2 for the brackets - current_length = item_length - else: - # Add the item to the current chunk - current_chunk.append(item) - # Add 1 for the comma - current_length += item_length + while start_idx < len(input_ids): + chunk_text = tokenizer.decode(list(chunk_ids)) + result.append(chunk_text) # Append chunked text as string + start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap + cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids)) + chunk_ids = input_ids[start_idx:cur_idx] - # Add the last chunk to the result - self._append_to_result(result, current_chunk) + return result - return result - def _load_text_list(self, text: str | list[str]): - """Load the text list based on the type.""" - if isinstance(text, list): - string_list = text - elif self._type == TextListSplitterType.JSON: - string_list = json.loads(text) - else: - string_list = text.split(self._input_delimiter) - return string_list +# Adapted from - https://github.com/langchain-ai/langchain/blob/77b359edf5df0d37ef0d539f678cf64f5557cb54/libs/langchain/langchain/text_splitter.py#L471 +# So we could have better control over the chunking process +def split_multiple_texts_on_tokens( + texts: list[str], tokenizer: Tokenizer, tick: ProgressTicker +) -> list[TextChunk]: + """Split multiple texts and return chunks with metadata using the tokenizer.""" + result = [] + mapped_ids = [] - def _append_to_result(self, chunk_list: list[str], new_chunk: list[str]): - """Append the current chunk to the result.""" - if new_chunk and len(new_chunk) > 0: - if self._type == TextListSplitterType.JSON: - chunk_list.append(json.dumps(new_chunk, ensure_ascii=False)) - else: - chunk_list.append(self._output_delimiter.join(new_chunk)) + for source_doc_idx, text in enumerate(texts): + encoded = tokenizer.encode(text) + if tick: + tick(1) # Track progress if tick callback is provided + mapped_ids.append((source_doc_idx, encoded)) + input_ids = [ + (source_doc_idx, id) for source_doc_idx, ids in mapped_ids for id in ids + ] -def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]: - """Split incoming text and return chunks using tokenizer.""" - splits: list[str] = [] - input_ids = tokenizer.encode(text) start_idx = 0 cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids)) chunk_ids = input_ids[start_idx:cur_idx] + while start_idx < len(input_ids): - splits.append(tokenizer.decode(chunk_ids)) + chunk_text = tokenizer.decode([id for _, id in chunk_ids]) + doc_indices = list({doc_idx for doc_idx, _ in chunk_ids}) + result.append(TextChunk(chunk_text, doc_indices, len(chunk_ids))) start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids)) chunk_ids = input_ids[start_idx:cur_idx] - return splits + + return result diff --git a/tests/unit/indexing/text_splitting/__init__.py b/tests/unit/indexing/text_splitting/__init__.py new file mode 100644 index 0000000000..0a3e38adfb --- /dev/null +++ b/tests/unit/indexing/text_splitting/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License diff --git a/tests/unit/indexing/text_splitting/test_text_splitting.py b/tests/unit/indexing/text_splitting/test_text_splitting.py new file mode 100644 index 0000000000..833d49fbb1 --- /dev/null +++ b/tests/unit/indexing/text_splitting/test_text_splitting.py @@ -0,0 +1,161 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +from unittest import mock +from unittest.mock import MagicMock + +import pytest + +from graphrag.index.text_splitting.text_splitting import ( + NoopTextSplitter, + Tokenizer, + TokenTextSplitter, + split_multiple_texts_on_tokens, + split_single_text_on_tokens, +) + + +def test_noop_text_splitter() -> None: + splitter = NoopTextSplitter() + + assert list(splitter.split_text("some text")) == ["some text"] + assert list(splitter.split_text(["some", "text"])) == ["some", "text"] + + +class MockTokenizer: + def encode(self, text): + return [ord(char) for char in text] + + def decode(self, token_ids): + return "".join(chr(id) for id in token_ids) + + +def test_split_text_str_empty(): + splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=2) + result = splitter.split_text("") + + assert result == [] + + +def test_split_text_str_bool(): + splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=2) + result = splitter.split_text(None) # type: ignore + + assert result == [] + + +def test_split_text_str_int(): + splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=2) + with pytest.raises(TypeError): + splitter.split_text(123) # type: ignore + + +@mock.patch("graphrag.index.text_splitting.text_splitting.split_single_text_on_tokens") +def test_split_text_large_input(mock_split): + large_text = "a" * 10_000 + mock_split.return_value = ["chunk"] * 2_000 + splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=2) + + result = splitter.split_text(large_text) + + assert len(result) == 2_000, "Large input was not split correctly" + mock_split.assert_called_once() + + +@mock.patch("graphrag.index.text_splitting.text_splitting.split_single_text_on_tokens") +@mock.patch("graphrag.index.text_splitting.text_splitting.Tokenizer") +def test_token_text_splitter(mock_tokenizer, mock_split_text): + text = "chunk1 chunk2 chunk3" + expected_chunks = ["chunk1", "chunk2", "chunk3"] + + mocked_tokenizer = MagicMock() + mock_tokenizer.return_value = mocked_tokenizer + mock_split_text.return_value = expected_chunks + + splitter = TokenTextSplitter() + + splitter.split_text(["chunk1", "chunk2", "chunk3"]) + + mock_split_text.assert_called_once_with(text=text, tokenizer=mocked_tokenizer) + + +def test_encode_basic(): + splitter = TokenTextSplitter() + result = splitter.encode("abc def") + + assert result == [13997, 711], "Encoding failed to return expected tokens" + + +def test_num_tokens_empty_input(): + splitter = TokenTextSplitter() + result = splitter.num_tokens("") + + assert result == 0, "Token count for empty input should be 0" + + +def test_model_name(): + splitter = TokenTextSplitter(model_name="gpt-4o") + result = splitter.encode("abc def") + + assert result == [26682, 1056], "Encoding failed to return expected tokens" + + +@mock.patch("tiktoken.encoding_for_model", side_effect=KeyError) +@mock.patch("tiktoken.get_encoding") +def test_model_name_exception(mock_get_encoding, mock_encoding_for_model): + mock_get_encoding.return_value = mock.MagicMock() + + TokenTextSplitter(model_name="mock_model", encoding_name="mock_encoding") + + mock_get_encoding.assert_called_once_with("mock_encoding") + mock_encoding_for_model.assert_called_once_with("mock_model") + + +def test_split_single_text_on_tokens(): + text = "This is a test text, meaning to be taken seriously by this test only." + mocked_tokenizer = MockTokenizer() + tokenizer = Tokenizer( + chunk_overlap=5, + tokens_per_chunk=10, + decode=mocked_tokenizer.decode, + encode=lambda text: mocked_tokenizer.encode(text), + ) + + expected_splits = [ + "This is a ", + "is a test ", + "test text,", + "text, mean", + " meaning t", + "ing to be ", + "o be taken", + "taken seri", # cspell:disable-line + " seriously", + "ously by t", # cspell:disable-line + " by this t", + "his test o", + "est only.", + "nly.", + ] + + result = split_single_text_on_tokens(text=text, tokenizer=tokenizer) + assert result == expected_splits + + +def test_split_multiple_texts_on_tokens(): + texts = [ + "This is a test text, meaning to be taken seriously by this test only.", + "This is th second text, meaning to be taken seriously by this test only.", + ] + + mocked_tokenizer = MockTokenizer() + mock_tick = MagicMock() + tokenizer = Tokenizer( + chunk_overlap=5, + tokens_per_chunk=10, + decode=mocked_tokenizer.decode, + encode=lambda text: mocked_tokenizer.encode(text), + ) + + split_multiple_texts_on_tokens(texts, tokenizer, tick=mock_tick) + mock_tick.assert_called()