diff --git a/pyproject.toml b/pyproject.toml index d897d7d..a713649 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,8 @@ pythonpath = "src" package-dir = {"" = "src"} packages = ["chonkie", "chonkie.chunker", - "chonkie.embeddings"] + "chonkie.embeddings", + "chonkie.refinery"] [tool.ruff] select = ["F", "I", "D", "DOC"] diff --git a/src/chonkie/__init__.py b/src/chonkie/__init__.py index ea4eba5..d69ee29 100644 --- a/src/chonkie/__init__.py +++ b/src/chonkie/__init__.py @@ -1,3 +1,7 @@ +"""Main package for Chonkie.""" + +from .context import Context + from .chunker import ( BaseChunker, Chunk, @@ -19,6 +23,11 @@ SentenceTransformerEmbeddings, ) +from .refinery import ( + BaseRefinery, + OverlapRefinery, +) + __version__ = "0.2.1.post1" __name__ = "chonkie" __author__ = "Bhavnick Minhas" @@ -32,6 +41,7 @@ # Add all data classes to __all__ __all__ += [ + "Context", "Chunk", "SentenceChunk", "SemanticChunk", @@ -57,3 +67,9 @@ "OpenAIEmbeddings", "AutoEmbeddings", ] + +# Add all refinery classes to __all__ +__all__ += [ + "BaseRefinery", + "OverlapRefinery", +] diff --git a/src/chonkie/chunker/base.py b/src/chonkie/chunker/base.py index f27f3cd..af0fe3a 100644 --- a/src/chonkie/chunker/base.py +++ b/src/chonkie/chunker/base.py @@ -1,11 +1,15 @@ +"""Base classes for chunking text.""" + import importlib + +import inspect import warnings from abc import ABC, abstractmethod from dataclasses import dataclass from multiprocessing import Pool, cpu_count -from typing import Any, Callable, List, Union +from typing import Any, Callable, List, Optional, Union -import inspect +from chonkie.context import Context @dataclass @@ -19,6 +23,7 @@ class Chunk: start_index: The starting index of the chunk in the original text end_index: The ending index of the chunk in the original text token_count: The number of tokens in the chunk + context: The context of the chunk, useful for refinery classes """ @@ -26,7 +31,47 @@ class Chunk: start_index: int end_index: int token_count: int - __slots__ = ["text", "start_index", "end_index", "token_count"] + context: Optional[Context] = None + + def __str__(self) -> str: + """Return string representation of the chunk.""" + return self.text + + def __len__(self) -> int: + """Return the length of the chunk.""" + return len(self.text) + + def __repr__(self) -> str: + """Return string representation of the chunk.""" + if self.context is not None: + return ( + f"Chunk(text={self.text}, start_index={self.start_index}, " + f"end_index={self.end_index}, token_count={self.token_count})" + ) + else: + return ( + f"Chunk(text={self.text}, start_index={self.start_index}, " + f"end_index={self.end_index}, token_count={self.token_count}, " + f"context={self.context})" + ) + + def __iter__(self): + """Return an iterator over the chunk.""" + return iter(self.text) + + def __getitem__(self, index: int): + """Return the item at the given index.""" + return self.text[index] + + def copy(self) -> "Chunk": + """Return a deep copy of the chunk.""" + return Chunk( + text=self.text, + start_index=self.start_index, + end_index=self.end_index, + token_count=self.token_count, + ) + class BaseChunker(ABC): @@ -67,7 +112,10 @@ def _get_tokenizer_backend(self): elif "tiktoken" in str(type(self.tokenizer)): return "tiktoken" else: - raise ValueError(f"Tokenizer backend {str(type(self.tokenizer))} not supported") + raise ValueError( + f"Tokenizer backend {str(type(self.tokenizer))} not supported" + ) + def _load_tokenizer(self, tokenizer_name: str): """Load a tokenizer based on the backend.""" @@ -134,7 +182,10 @@ def _encode(self, text: str): elif self._tokenizer_backend == "tiktoken": return self.tokenizer.encode(text) else: - raise ValueError(f"Tokenizer backend {self._tokenizer_backend} not supported.") + raise ValueError( + f"Tokenizer backend {self._tokenizer_backend} not supported." + ) + def _encode_batch(self, texts: List[str]): """Encode a batch of texts using the backend tokenizer.""" @@ -150,7 +201,10 @@ def _encode_batch(self, texts: List[str]): elif self._tokenizer_backend == "tiktoken": return self.tokenizer.encode_batch(texts) else: - raise ValueError(f"Tokenizer backend {self._tokenizer_backend} not supported.") + raise ValueError( + f"Tokenizer backend {self._tokenizer_backend} not supported." + ) + def _decode(self, tokens) -> str: """Decode tokens using the backend tokenizer.""" @@ -161,7 +215,10 @@ def _decode(self, tokens) -> str: elif self._tokenizer_backend == "tiktoken": return self.tokenizer.decode(tokens) else: - raise ValueError(f"Tokenizer backend {self._tokenizer_backend} not supported.") + raise ValueError( + f"Tokenizer backend {self._tokenizer_backend} not supported." + ) + def _decode_batch(self, token_lists: List[List[int]]) -> List[str]: """Decode a batch of token lists using the backend tokenizer.""" @@ -172,7 +229,10 @@ def _decode_batch(self, token_lists: List[List[int]]) -> List[str]: elif self._tokenizer_backend == "tiktoken": return [self.tokenizer.decode(tokens) for tokens in token_lists] else: - raise ValueError(f"Tokenizer backend {self._tokenizer_backend} not supported.") + raise ValueError( + f"Tokenizer backend {self._tokenizer_backend} not supported." + ) + def _count_tokens(self, text: str) -> int: """Count tokens in text using the backend tokenizer.""" diff --git a/src/chonkie/chunker/semantic.py b/src/chonkie/chunker/semantic.py index 8a7ee15..025bea6 100644 --- a/src/chonkie/chunker/semantic.py +++ b/src/chonkie/chunker/semantic.py @@ -24,25 +24,7 @@ class SemanticSentence(Sentence): """ - embedding: Optional[np.ndarray] - - # Only define new slots, not the ones inherited from Sentence - __slots__ = [ - "embedding", - ] - - def __init__( - self, - text: str, - start_index: int, - end_index: int, - token_count: int, - embedding: Optional[np.ndarray] = None, - ): - super().__init__(text, start_index, end_index, token_count) - object.__setattr__( - self, "embedding", embedding if embedding is not None else None - ) + embedding: Optional[np.ndarray] = field(default=None) @dataclass @@ -59,25 +41,9 @@ class SemanticChunk(SentenceChunk): sentences: List of SemanticSentence objects in the chunk """ - + sentences: List[SemanticSentence] = field(default_factory=list) - # No new slots needed since we're just overriding the sentences field - __slots__ = [] - - def __init__( - self, - text: str, - start_index: int, - end_index: int, - token_count: int, - sentences: List[SemanticSentence] = None, - ): - super().__init__(text, start_index, end_index, token_count) - object.__setattr__( - self, "sentences", sentences if sentences is not None else [] - ) - class SemanticChunker(BaseChunker): """Chunker that splits text into semantically coherent chunks using embeddings. @@ -160,6 +126,12 @@ def __init__( "embedding_model must be a string or BaseEmbeddings instance" ) + # Probably the dependency is not installed + if self.embedding_model is None: + raise ImportError("embedding_model is not a valid embedding model", + "Please install the `semantic` extra to use this feature") + + # Keeping the tokenizer the same as the sentence model is important # for the group semantic meaning to be calculated properly super().__init__(self.embedding_model.get_tokenizer_or_token_counter()) @@ -299,11 +271,11 @@ def _group_sentences(self, sentences: List[Sentence]) -> List[List[Sentence]]: ) for i in range(len(sentences) - 1) ] - similarity_threshold = float( + self.similarity_threshold = float( np.percentile(all_similarities, self.similarity_percentile) ) else: - similarity_threshold = self.similarity_threshold + self.similarity_threshold = self.similarity_threshold groups = [] current_group = sentences[: self.initial_sentences] @@ -315,7 +287,7 @@ def _group_sentences(self, sentences: List[Sentence]) -> List[List[Sentence]]: current_embedding, sentence.embedding ) - if similarity >= similarity_threshold: + if similarity >= self.similarity_threshold: # Add to current group current_group.append(sentence) # Update mean embedding diff --git a/src/chonkie/chunker/sentence.py b/src/chonkie/chunker/sentence.py index 13e1a8d..174044a 100644 --- a/src/chonkie/chunker/sentence.py +++ b/src/chonkie/chunker/sentence.py @@ -1,5 +1,6 @@ from bisect import bisect_left -from dataclasses import dataclass +from dataclasses import dataclass, field + from itertools import accumulate from typing import Any, List, Union @@ -24,8 +25,6 @@ class Sentence: start_index: int end_index: int token_count: int - __slots__ = ["text", "start_index", "end_index", "token_count"] - @dataclass class SentenceChunk(Chunk): @@ -41,25 +40,8 @@ class SentenceChunk(Chunk): sentences: List of Sentence objects in the chunk """ - # Don't redeclare inherited fields - sentences: List[Sentence] - - __slots__ = ["sentences"] - - def __init__( - self, - text: str, - start_index: int, - end_index: int, - token_count: int, - sentences: List[Sentence] = None, - ): - super().__init__(text, start_index, end_index, token_count) - object.__setattr__( - self, "sentences", sentences if sentences is not None else [] - ) - + sentences: List[Sentence] = field(default_factory=list) class SentenceChunker(BaseChunker): """SentenceChunker splits the sentences in a text based on token limits and sentence boundaries. diff --git a/src/chonkie/context.py b/src/chonkie/context.py new file mode 100644 index 0000000..d532b72 --- /dev/null +++ b/src/chonkie/context.py @@ -0,0 +1,67 @@ +"""Context class for storing contextual information for chunk refinement. + +This class is used to store contextual information for chunk refinement. +It can represent context that comes before a chunk at the moment. + +By default, the context has no start and end indices, meaning it is not +bound to any specific text. The start and end indices are only set if the +context is part of the same text as the chunk. +""" + +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class Context: + """A dataclass representing contextual information for chunk refinement. + + This class stores text and token count information that can be used to add + context to chunks during the refinement process. It can represent context + that comes before or after a chunk. + + Attributes: + text (str): The context text + token_count (int): Number of tokens in the context text + start_index (Optional[int]): Starting position of context in original text + end_index (Optional[int]): Ending position of context in original text + + Example: + context = Context( + text="This is some context.", + token_count=5, + start_index=0, + end_index=20 + ) + + """ + + text: str + token_count: int + start_index: Optional[int] = None + end_index: Optional[int] = None + + def __post_init__(self): + """Validate the Context attributes after initialization.""" + if not isinstance(self.text, str): + raise ValueError("text must be a string") + if not isinstance(self.token_count, int): + raise ValueError("token_count must be an integer") + if self.token_count < 0: + raise ValueError("token_count must be non-negative") + if (self.start_index is not None and self.end_index is not None and + self.start_index > self.end_index): + raise ValueError("start_index must be less than or equal to end_index") + + def __len__(self) -> int: + """Return the length of the context text.""" + return len(self.text) + + def __str__(self) -> str: + """Return a string representation of the Context.""" + return self.text + + def __repr__(self) -> str: + """Return a detailed string representation of the Context.""" + return (f"Context(text='{self.text}', token_count={self.token_count}, " + f"start_index={self.start_index}, end_index={self.end_index})") diff --git a/src/chonkie/refinery/__init__.py b/src/chonkie/refinery/__init__.py new file mode 100644 index 0000000..dae44df --- /dev/null +++ b/src/chonkie/refinery/__init__.py @@ -0,0 +1,6 @@ +from .base import BaseRefinery +from .overlap import OverlapRefinery + +# Include all the refinery classes in the __all__ list +__all__ = ["BaseRefinery", "OverlapRefinery"] + diff --git a/src/chonkie/refinery/base.py b/src/chonkie/refinery/base.py new file mode 100644 index 0000000..b96b462 --- /dev/null +++ b/src/chonkie/refinery/base.py @@ -0,0 +1,38 @@ +from abc import ABC, abstractmethod +from typing import Any, List + +from chonkie.chunker import Chunk + +class BaseRefinery(ABC): + """Base class for all Refinery classes. + + Refinery classes are used to refine the Chunks generated from the + Chunkers. These classes take in chunks and return refined chunks. + Most refinery classes would be used to add additional context to the + chunks generated by the chunkers. + """ + + def __init__(self, context_size: int = 0) -> None: + """Initialize the Refinery.""" + if context_size < 0: + raise ValueError("context_size must be non-negative") + self.context_size = context_size + + @abstractmethod + def refine(self, chunks: List[Chunk]) -> List[Chunk]: + """Refine the given list of chunks and return the refined list.""" + pass + + @classmethod + @abstractmethod + def is_available(cls) -> bool: + """Check if the Refinery is available.""" + pass + + def __repr__(self) -> str: + """Representation of the Refinery.""" + return f"{self.__class__.__name__}(context_size={self.context_size})" + + def __call__(self, chunks: List[Chunk]) -> List[Chunk]: + """Call the Refinery.""" + return self.refine(chunks) \ No newline at end of file diff --git a/src/chonkie/refinery/overlap.py b/src/chonkie/refinery/overlap.py new file mode 100644 index 0000000..0e1d8a6 --- /dev/null +++ b/src/chonkie/refinery/overlap.py @@ -0,0 +1,298 @@ +from typing import Any, List, Optional +from dataclasses import dataclass + +from chonkie.chunker import Chunk, SentenceChunk, SemanticChunk +from chonkie.refinery.base import BaseRefinery + +from chonkie.context import Context + +class OverlapRefinery(BaseRefinery): + """Refinery class which adds overlap as context to chunks. + + This refinery provides two methods for calculating overlap: + 1. Exact: Uses a tokenizer to precisely determine token boundaries + 2. Approximate: Estimates tokens based on text length ratios + + It can handle different types of chunks (basic Chunks, SentenceChunks, + and SemanticChunks) and can optionally update the chunk text to include + the overlap content. + """ + + def __init__( + self, + context_size: int = 128, + tokenizer: Any = None, + merge_context: bool = True, + inplace: bool = True, + approximate: bool = True + ) -> None: + """Initialize the OverlapRefinery class. + + Args: + context_size: Number of tokens to include in context + tokenizer: Optional tokenizer for exact token counting + merge_context: Whether to merge context with chunk text + inplace: Whether to update chunks in place + approximate: Whether to use approximate token counting + + """ + super().__init__(context_size) + self.merge_context = merge_context + self.inplace = inplace + + # If tokenizer provided, we can do exact token counting + if tokenizer is not None: + self.tokenizer = tokenizer + self.approximate = approximate + else: + # Without tokenizer, must use approximate method + self.approximate = True + + def _get_refined_chunks(self, chunks: List[Chunk], inplace: bool = True) -> List[Chunk]: + """Convert regular chunks to refined chunks with progressive memory cleanup. + + This method takes regular chunks and converts them to RefinedChunks one at a + time. When inplace is True, it progressively removes chunks from the input + list to minimize memory usage. + + The conversion preserves all relevant information from the original chunks, + including sentences and embeddings if they exist. This allows us to maintain + the full capabilities of semantic chunks while adding refinement features. + + Args: + chunks: List of original chunks to convert + inplace: Whether to modify the input list during conversion + + Returns: + List of RefinedChunks without any context (context is added later) + + Example: + For memory efficiency with large datasets: + ``` + chunks = load_large_dataset() # Many chunks + refined = refinery._get_refined_chunks(chunks, inplace=True) + # chunks is now empty, memory is freed + ``` + + """ + if not chunks: + return [] + + refined_chunks = [] + + # Use enumerate to track position without modifying list during iteration + for i in range(len(chunks)): + if inplace: + # Get and remove the first chunk + chunk = chunks.pop(0) + else: + # Just get a reference if not modifying in place + chunk = chunks[i] + + # Create refined version preserving appropriate attributes + refined_chunk = SemanticChunk( + text=chunk.text, + start_index=chunk.start_index, + end_index=chunk.end_index, + token_count=chunk.token_count, + # Preserve sentences and embeddings if they exist + sentences=chunk.sentences if isinstance(chunk, (SentenceChunk, SemanticChunk)) else None, + embedding=chunk.embedding if isinstance(chunk, SemanticChunk) else None, + context=None # Context is added later in the refinement process + ) + + refined_chunks.append(refined_chunk) + + if inplace: + # Clear the input list to free memory + chunks.clear() + chunks += refined_chunks + + return refined_chunks + + def _overlap_token_exact(self, chunk: Chunk) -> Optional[Context]: + """Calculate precise token-based overlap context using tokenizer. + + Takes a larger window of text from the chunk end, tokenizes it, + and selects exactly context_size tokens worth of text. + + Args: + chunk: Chunk to extract context from + + Returns: + Context object with precise token boundaries, or None if no tokenizer + + """ + if not hasattr(self, 'tokenizer'): + return None + + # Take 6x context_size characters to ensure enough tokens + char_window = min(len(chunk.text), self.context_size * 6) + text_portion = chunk.text[-char_window:] + + # Get exact token boundaries + tokens = self.tokenizer.encode(text_portion) + context_tokens = min(self.context_size, len(tokens)) + context_tokens_ids = tokens[-context_tokens:] + context_text = self.tokenizer.decode(context_tokens_ids) + + # Find where context text starts in chunk + try: + context_start = chunk.text.rindex(context_text) + start_index = chunk.start_index + context_start + + return Context( + text=context_text, + token_count=context_tokens, + start_index=start_index, + end_index=chunk.end_index + ) + except ValueError: + # If context text can't be found (e.g., due to special tokens), fall back to approximate + return self._overlap_token_approximate(chunk) + + def _overlap_token_approximate(self, chunk: Chunk) -> Optional[Context]: + """Calculate approximate token-based overlap context. + + Estimates token positions based on character length ratios. + + Args: + chunk: Chunk to extract context from + + Returns: + Context object with estimated token boundaries + + """ + # Calculate desired context size + context_tokens = min(self.context_size, chunk.token_count) + + # Estimate text length based on token ratio + context_ratio = context_tokens / chunk.token_count + char_length = int(len(chunk.text) * context_ratio) + + # Extract context text from end + context_text = chunk.text[-char_length:] + + return Context( + text=context_text, + token_count=context_tokens, + start_index=chunk.end_index - char_length, + end_index=chunk.end_index + ) + + + def _overlap_token(self, chunk: Chunk) -> Optional[Context]: + """Choose between exact or approximate token overlap calculation. + + Args: + chunk: Chunk to process + + Returns: + Context object from either exact or approximate calculation + + """ + if self.approximate: + return self._overlap_token_approximate(chunk) + return self._overlap_token_exact(chunk) + + def _overlap_sentence(self, chunk: SentenceChunk) -> Optional[Context]: + """Calculate overlap context based on sentences. + + Takes sentences from the end of the chunk up to context_size tokens. + + Args: + chunk: SentenceChunk to process + + Returns: + Context object containing complete sentences + + """ + if not chunk.sentences: + return None + + context_sentences = [] + total_tokens = 0 + + # Add sentences from the end until we hit context_size + for sentence in reversed(chunk.sentences): + if total_tokens + sentence.token_count <= self.context_size: + context_sentences.insert(0, sentence) + total_tokens += sentence.token_count + else: + break + # If no sentences were added, add the last sentence + if not context_sentences: + context_sentences.append(chunk.sentences[-1]) + total_tokens = chunk.sentences[-1].token_count + + return Context( + text=" ".join(s.text for s in context_sentences), + token_count=total_tokens, + start_index=context_sentences[0].start_index, + end_index=context_sentences[-1].end_index + ) + + def _get_overlap_context(self, chunk: Chunk) -> Optional[Context]: + """Get appropriate overlap context based on chunk type.""" + if isinstance(chunk, SemanticChunk): + return self._overlap_sentence(chunk) + elif isinstance(chunk, SentenceChunk): + return self._overlap_sentence(chunk) + elif isinstance(chunk, Chunk): + return self._overlap_token(chunk) + else: + raise ValueError(f"Unsupported chunk type: {type(chunk)}") + + def refine(self, chunks: List[Chunk]) -> List[Chunk]: + """Refine chunks by adding overlap context. + + For each chunk after the first, adds context from the previous chunk. + Can optionally update the chunk text to include the context. + + Args: + chunks: List of chunks to refine + + Returns: + List of refined chunks with added context + + """ + if not chunks: + return chunks + + # Validate chunk types + if len(set(type(chunk) for chunk in chunks)) > 1: + raise ValueError("All chunks must be of the same type") + + if not self.inplace: + refined_chunks = [chunk.copy() for chunk in chunks] + else: + refined_chunks = chunks + + # Process remaining chunks + for i in range(1, len(refined_chunks)): + # Get context from previous chunk + context = self._get_overlap_context(chunks[i-1]) + setattr(refined_chunks[i], 'context', context) + + # Optionally update chunk text to include context + if self.merge_context and context: + refined_chunks[i].text = context.text + refined_chunks[i].text + refined_chunks[i].start_index = context.start_index + # Update token count to include context and space + # Calculate new token count + if hasattr(self, 'tokenizer') and not self.approximate: + # Use exact token count if we have a tokenizer + refined_chunks[i].token_count = len(self.tokenizer.encode(refined_chunks[i].text)) + else: + # Otherwise use approximate by adding context tokens plus one for space + refined_chunks[i].token_count = refined_chunks[i].token_count + context.token_count + 1 + + return refined_chunks + + @classmethod + def is_available(cls) -> bool: + """Check if the OverlapRefinery is available. + + Always returns True as this refinery has no external dependencies. + """ + return True \ No newline at end of file diff --git a/src/chonkie/token_factory.py b/src/chonkie/token_factory.py new file mode 100644 index 0000000..100230d --- /dev/null +++ b/src/chonkie/token_factory.py @@ -0,0 +1,21 @@ +"""Factory class for creating and managing tokenizers. + +This factory class is used to create and manage tokenizers for the Chonkie +package. It provides a simple interface for initializing, encoding, decoding, +and counting tokens using different tokenizer backends. + +This is used in the Chunker and Refinery classes to ensure consistent tokenization +across different parts of the pipeline. +""" + +from typing import Callable, List, TYPE_CHECKING + + +if TYPE_CHECKING: + import tiktoken + from transformers import AutoTokenizer + from tokenizers import Tokenizer + +class TokenFactory: + """Factory class for creating and managing tokenizers.""" + pass \ No newline at end of file diff --git a/tests/chunker/test_sdpm_chunker.py b/tests/chunker/test_sdpm_chunker.py index 6dee6ce..1ef0b7a 100644 --- a/tests/chunker/test_sdpm_chunker.py +++ b/tests/chunker/test_sdpm_chunker.py @@ -140,6 +140,24 @@ def test_spdm_chunker_repr(embedding_model): ) assert repr(chunker) == expected +def test_spdm_chunker_percentile_mode(embedding_model, sample_complex_markdown_text): + """Test the SPDMChunker works with percentile-based similarity.""" + chunker = SDPMChunker( + embedding_model=embedding_model, + chunk_size=512, + similarity_percentile=50, + ) + chunks = chunker.chunk(sample_complex_markdown_text) + + assert len(chunks) > 0 + assert isinstance(chunks[0], SemanticChunk) + assert all([chunk.token_count <= 512 for chunk in chunks]) + assert all([chunk.token_count > 0 for chunk in chunks]) + assert all([chunk.text is not None for chunk in chunks]) + assert all([chunk.start_index is not None for chunk in chunks]) + assert all([chunk.end_index is not None for chunk in chunks]) + assert all([chunk.sentences is not None for chunk in chunks]) + if __name__ == "__main__": pytest.main() diff --git a/tests/refinery/test_overlap_refinery.py b/tests/refinery/test_overlap_refinery.py new file mode 100644 index 0000000..b0371cd --- /dev/null +++ b/tests/refinery/test_overlap_refinery.py @@ -0,0 +1,221 @@ +import pytest +from typing import List +from dataclasses import dataclass +from transformers import AutoTokenizer + +from chonkie.chunker import Chunk, SentenceChunk, SemanticChunk, Sentence +from chonkie.refinery import OverlapRefinery +from chonkie.context import Context + +@pytest.fixture +def tokenizer(): + """Fixture providing a GPT-2 tokenizer for testing.""" + return AutoTokenizer.from_pretrained("gpt2") + +@pytest.fixture +def basic_chunks() -> List[Chunk]: + """Fixture providing a list of basic Chunks for testing.""" + return [ + Chunk( + text="This is the first chunk of text.", + start_index=0, + end_index=30, + token_count=8 + ), + Chunk( + text="This is the second chunk of text.", + start_index=31, + end_index=62, + token_count=8 + ), + Chunk( + text="This is the third chunk of text.", + start_index=63, + end_index=93, + token_count=8 + ) + ] + +@pytest.fixture +def sentence_chunks() -> List[SentenceChunk]: + """Fixture providing a list of SentenceChunks for testing.""" + sentences1 = [ + Sentence(text="First sentence.", start_index=0, end_index=14, token_count=3), + Sentence(text="Second sentence.", start_index=15, end_index=30, token_count=3) + ] + sentences2 = [ + Sentence(text="Third sentence.", start_index=31, end_index=45, token_count=3), + Sentence(text="Fourth sentence.", start_index=46, end_index=62, token_count=3) + ] + return [ + SentenceChunk( + text="First sentence. Second sentence.", + start_index=0, + end_index=30, + token_count=6, + sentences=sentences1 + ), + SentenceChunk( + text="Third sentence. Fourth sentence.", + start_index=31, + end_index=62, + token_count=6, + sentences=sentences2 + ) + ] + +def test_overlap_refinery_initialization(): + """Test that OverlapRefinery initializes correctly with different parameters.""" + # Test default initialization + refinery = OverlapRefinery() + assert refinery.context_size == 128 + assert refinery.merge_context is True + assert refinery.approximate is True + assert not hasattr(refinery, 'tokenizer') + + # Test initialization with tokenizer + tokenizer = AutoTokenizer.from_pretrained("gpt2") + refinery = OverlapRefinery( + context_size=64, + tokenizer=tokenizer, + merge_context=False, + approximate=False + ) + assert refinery.context_size == 64 + assert refinery.merge_context is False + assert refinery.approximate is False + assert hasattr(refinery, 'tokenizer') + assert refinery.tokenizer == tokenizer + +def test_overlap_refinery_empty_input(): + """Test that OverlapRefinery handles empty input correctly.""" + refinery = OverlapRefinery() + assert refinery.refine([]) == [] + +def test_overlap_refinery_single_chunk(): + """Test that OverlapRefinery handles single chunk input correctly.""" + refinery = OverlapRefinery() + chunk = Chunk(text="Single chunk.", start_index=0, end_index=12, token_count=3) + refined = refinery.refine([chunk]) + assert len(refined) == 1 + assert refined[0].context is None + +def test_overlap_refinery_basic_chunks_approximate(basic_chunks): + """Test approximate overlap calculation with basic Chunks.""" + refinery = OverlapRefinery(context_size=4) # Small context for testing + refined = refinery.refine(basic_chunks) + + # First chunk should have no context + assert refined[0].context is None + + # Subsequent chunks should have context from previous chunks + for i in range(1, len(refined)): + assert refined[i].context is not None + assert isinstance(refined[i].context, Context) + assert refined[i].context.token_count <= 4 + +def test_overlap_refinery_basic_chunks_exact(basic_chunks, tokenizer): + """Test exact overlap calculation with basic Chunks using tokenizer.""" + refinery = OverlapRefinery( + context_size=4, + tokenizer=tokenizer, + approximate=False + ) + refined = refinery.refine(basic_chunks) + + # Check context for subsequent chunks + for i in range(1, len(refined)): + assert refined[i].context is not None + assert isinstance(refined[i].context, Context) + # Verify exact token count using tokenizer + actual_tokens = len(tokenizer.encode(refined[i].context.text)) + assert actual_tokens <= 4 + +def test_overlap_refinery_sentence_chunks(sentence_chunks): + """Test overlap calculation with SentenceChunks.""" + refinery = OverlapRefinery(context_size=4) + refined = refinery.refine(sentence_chunks) + + # Check context for second chunk + assert refined[1].context is not None + assert isinstance(refined[1].context, Context) + assert refined[1].context.token_count <= 4 + +def test_overlap_refinery_no_merge_context(basic_chunks): + """Test behavior when merge_context is False.""" + refinery = OverlapRefinery(context_size=4, merge_context=False) + refined = refinery.refine(basic_chunks) + + # Chunks should maintain original text + for i in range(len(refined)): + assert refined[i].text == basic_chunks[i].text + assert refined[i].token_count == basic_chunks[i].token_count + +def test_overlap_refinery_context_size_limits(basic_chunks): + """Test that context size limits are respected.""" + refinery = OverlapRefinery(context_size=2) # Very small context + refined = refinery.refine(basic_chunks) + + # Check that no context exceeds size limit + for chunk in refined[1:]: # Skip first chunk + assert chunk.context.token_count <= 2 + +def test_overlap_refinery_merge_context(basic_chunks, tokenizer): + """Test merging context into chunk text.""" + refinery = OverlapRefinery( + context_size=4, + tokenizer=tokenizer, + merge_context=True, + approximate=False + ) + + # Create a deep copy to preserve originals + chunks_copy = [ + Chunk( + text=chunk.text, + start_index=chunk.start_index, + end_index=chunk.end_index, + token_count=chunk.token_count + ) for chunk in basic_chunks + ] + + refined = refinery.refine(chunks_copy) + + # First chunk should be unchanged + assert refined[0].text == basic_chunks[0].text + assert refined[0].token_count == basic_chunks[0].token_count + + # Subsequent chunks should have context prepended + for i in range(1, len(refined)): + assert refined[i].context is not None + assert refined[i].text.startswith(refined[i].context.text) + # Verify token count increase + original_tokens = len(tokenizer.encode(basic_chunks[i].text)) + new_tokens = len(tokenizer.encode(refined[i].text)) + assert new_tokens > original_tokens + +def test_overlap_refinery_mixed_chunk_types(): + """Test that refinery raises error for mixed chunk types.""" + # Create chunks of different types + chunks = [ + Chunk( + text="Basic chunk.", + start_index=0, + end_index=12, + token_count=3 + ), + SentenceChunk( + text="Sentence chunk.", + start_index=13, + end_index=27, + token_count=3, + sentences=[] + ) + ] + + refinery = OverlapRefinery() + with pytest.raises(ValueError, match="All chunks must be of the same type"): + refinery.refine(chunks) + +if __name__ == "__main__": + pytest.main() \ No newline at end of file