diff --git a/README.md b/README.md index 82021da..ab91fff 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ [![Documentation](https://img.shields.io/badge/docs-chonkie.ai-blue.svg)](https://docs.chonkie.ai) ![Package size](https://img.shields.io/badge/size-11.2MB-blue) [![Downloads](https://static.pepy.tech/badge/chonkie)](https://pepy.tech/project/chonkie) -[![Discord](https://dcbadge.limes.pink/api/server/https://discord.gg/nMYNVyuB5Y?style=flat)](https://discord.gg/rYYp6DC4cv) +[![Discord](https://dcbadge.limes.pink/api/server/https://discord.gg/rYYp6DC4cv?style=flat)](https://discord.gg/rYYp6DC4cv) [![GitHub stars](https://img.shields.io/github/stars/bhavnicksm/chonkie.svg)](https://github.com/bhavnicksm/chonkie/stargazers) _The no-nonsense RAG chunking library that's lightweight, lightning-fast, and ready to CHONK your texts_ diff --git a/pyproject.toml b/pyproject.toml index 286c740..5926426 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "chonkie" -version = "0.4.0" +version = "0.4.1" description = "🦛 CHONK your texts with Chonkie ✨ - The no-nonsense RAG chunking library" readme = "README.md" requires-python = ">=3.9" diff --git a/src/chonkie/__init__.py b/src/chonkie/__init__.py index 49a46e1..c166c07 100644 --- a/src/chonkie/__init__.py +++ b/src/chonkie/__init__.py @@ -34,9 +34,9 @@ SentenceChunk, ) -__version__ = "0.4.0" +__version__ = "0.4.1" __name__ = "chonkie" -__author__ = "Bhavnick Minhas" +__author__ = "Chonkie AI" # Add basic package metadata to __all__ __all__ = [ diff --git a/src/chonkie/chunker/base.py b/src/chonkie/chunker/base.py index 35ab3ab..73c8be0 100644 --- a/src/chonkie/chunker/base.py +++ b/src/chonkie/chunker/base.py @@ -7,9 +7,11 @@ from multiprocessing import Pool, cpu_count from typing import Any, Callable, List, Union +from tqdm import tqdm + from chonkie.types import Chunk -from tqdm import tqdm + class BaseChunker(ABC): """Abstract base class for all chunker implementations. 
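For orientation: every chunker touched in this patch derives from BaseChunker, so they all share the chunk / chunk_batch interface that the hunks below adjust. A minimal usage sketch, with keyword names taken from the hunks in this patch (the sample strings are placeholders):

from chonkie import TokenChunker

# TokenChunker keeps the plain `tokenizer` keyword in this release.
chunker = TokenChunker(tokenizer="gpt2", chunk_size=512, chunk_overlap=128)

# Single document: returns a list of Chunk objects.
chunks = chunker.chunk("Some long document text...")

# Batch of documents: exercises the tqdm progress bar restyled below.
batches = chunker.chunk_batch(["first document...", "second document..."], show_progress_bar=True)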
@@ -246,11 +248,11 @@ def _process_batch_sequential(self, return [ self.chunk(t) for t in tqdm( texts, - desc="🦛 CHONKING", + desc="🦛", disable=not show_progress_bar, - unit="texts", - bar_format="{desc}: [{bar:20}] {percentage:3.0f}% • {n_fmt}/{total_fmt} texts chunked [{elapsed}<{remaining}, {rate_fmt}] 🌱", - ascii=' >=') + unit="doc", + bar_format="{desc} ch{bar:20}nk {percentage:3.0f}% • {n_fmt}/{total_fmt} docs chunked [{elapsed}<{remaining}, {rate_fmt}] 🌱", + ascii=' o') ] def _process_batch_multiprocessing(self, @@ -264,12 +266,12 @@ def _process_batch_multiprocessing(self, with Pool(processes=num_workers) as pool: results = [] with tqdm(total=total, - desc="🦛 CHONKING", + desc="🦛", disable=not show_progress_bar, - unit="texts", - bar_format="{desc}: [{bar:20}] {percentage:3.0f}% • {n_fmt}/{total_fmt} texts chunked [{elapsed}<{remaining}, {rate_fmt}] 🌱", - ascii=' >=') as pbar: - for result in pool.imap_unordered(self.chunk, texts, chunksize=chunksize): + unit="doc", + bar_format="{desc} ch{bar:20}nk {percentage:3.0f}% • {n_fmt}/{total_fmt} docs chunked [{elapsed}<{remaining}, {rate_fmt}] 🌱", + ascii=' o') as pbar: + for result in pool.imap(self.chunk, texts, chunksize=chunksize): results.append(result) pbar.update() return results diff --git a/src/chonkie/chunker/recursive.py b/src/chonkie/chunker/recursive.py index 7bc6437..107d0bb 100644 --- a/src/chonkie/chunker/recursive.py +++ b/src/chonkie/chunker/recursive.py @@ -2,7 +2,7 @@ from bisect import bisect_left from functools import lru_cache from itertools import accumulate -from typing import Any, List, Optional, Union +from typing import Any, Callable, List, Optional, Union, Literal from chonkie.chunker.base import BaseChunker from chonkie.types import Chunk, RecursiveChunk, RecursiveLevel, RecursiveRules @@ -18,21 +18,35 @@ class RecursiveChunker(BaseChunker): """ def __init__(self, - tokenizer: Union[str, Any] = "gpt2", + tokenizer_or_token_counter: Union[str, Callable, Any] = "gpt2", chunk_size: int = 512, + min_characters_per_chunk: int = 12, rules: RecursiveRules = RecursiveRules(), - min_characters_per_chunk: int = 12 + return_type: Literal["chunks", "texts"] = "chunks" ) -> None: """Initialize the recursive chunker. Args: - tokenizer: The tokenizer to use for encoding/decoding. + tokenizer_or_token_counter: The tokenizer or token counter to use for encoding/decoding. chunk_size: The size of the chunks to return. - rules: The rules to use for chunking. min_characters_per_chunk: The minimum number of characters per chunk. - + rules: The rules to use for chunking. + return_type: Whether to return chunks or texts. + + Raises: + ValueError: If parameters are invalid. + """ - super().__init__(tokenizer) + super().__init__(tokenizer_or_token_counter=tokenizer_or_token_counter) + + if chunk_size <= 0: + raise ValueError("chunk_size must be positive") + if min_characters_per_chunk < 1: + raise ValueError("min_characters_per_chunk must be at least 1") + if return_type not in ["chunks", "texts"]: + raise ValueError("Invalid return_type. 
Must be either 'chunks' or 'texts'.") + + self.return_type = return_type self.rules = rules self.chunk_size = chunk_size self.min_characters_per_chunk = min_characters_per_chunk @@ -194,7 +208,10 @@ def _recursive_chunk(self, # If level is out of bounds, return the text as a chunk if level >= len(self.rules): - return [self._create_chunk(text, self._get_token_count(text), level, full_text)] + if self.return_type == "chunks": + return [self._create_chunk(text, self._get_token_count(text), level, full_text)] + elif self.return_type == "texts": + return [text] # If full_text is not provided, use the text if full_text is None: @@ -227,30 +244,32 @@ def _recursive_chunk(self, if token_count > self.chunk_size: chunks.extend(self._recursive_chunk(split, level + 1, full_text)) else: - if rule.delimiters is None and not rule.whitespace: - # NOTE: This is a hack to get the decoded text, since merged = splits = token_splits - # And we don't want to encode/decode the text again, that would be inefficient - decoded_text = "".join(merged) - chunks.append(self._create_chunk(split, token_count, level, decoded_text)) - else: - chunks.append(self._create_chunk(split, token_count, level, full_text)) - + if self.return_type == "chunks": + if rule.delimiters is None and not rule.whitespace: + # NOTE: This is a hack to get the decoded text, since merged = splits = token_splits + # And we don't want to encode/decode the text again, that would be inefficient + decoded_text = "".join(merged) + chunks.append(self._create_chunk(split, token_count, level, decoded_text)) + else: + chunks.append(self._create_chunk(split, token_count, level, full_text)) + elif self.return_type == "texts": + chunks.append(split) return chunks - def chunk(self, text: str) -> List[Chunk]: """Chunk the text.""" return self._recursive_chunk(text, level=0, full_text=text) - def __repr__(self) -> str: """Get a string representation of the recursive chunker.""" return (f"RecursiveChunker(rules={self.rules}, " f"chunk_size={self.chunk_size}, " - f"min_characters_per_chunk={self.min_characters_per_chunk})") + f"min_characters_per_chunk={self.min_characters_per_chunk}, " + f"return_type={self.return_type})") def __str__(self) -> str: """Get a string representation of the recursive chunker.""" return (f"RecursiveChunker(rules={self.rules}, " f"chunk_size={self.chunk_size}, " - f"min_characters_per_chunk={self.min_characters_per_chunk})") + f"min_characters_per_chunk={self.min_characters_per_chunk}, " + f"return_type={self.return_type})") diff --git a/src/chonkie/chunker/sdpm.py b/src/chonkie/chunker/sdpm.py index 11eb19b..3203424 100644 --- a/src/chonkie/chunker/sdpm.py +++ b/src/chonkie/chunker/sdpm.py @@ -1,6 +1,6 @@ """Semantic Double Pass Merge chunking using sentence embeddings.""" -from typing import Any, List, Union +from typing import Any, List, Union, Literal from chonkie.types import SemanticChunk, Sentence @@ -17,15 +17,17 @@ class SDPMChunker(SemanticChunker): Args: embedding_model: Sentence embedding model to use - similarity_threshold: Minimum similarity score to consider sentences similar - similarity_percentile: Minimum similarity percentile to consider sentences similar + mode: Mode for grouping sentences, either "cumulative" or "window" + threshold: Threshold for semantic similarity (0-1) or percentile (1-100), defaults to "auto" chunk_size: Maximum token count for a chunk - initial_sentences: Number of sentences to consider for initial grouping - skip_window: Number of chunks to skip when looking for similarities + 
similarity_window: Number of sentences to consider for similarity threshold calculation + min_sentences: Minimum number of sentences per chunk min_chunk_size: Minimum number of tokens per sentence - - Methods: - chunk: Split text into chunks using the SDPM approach. + min_characters_per_sentence: Minimum number of characters per sentence + threshold_step: Step size for similarity threshold calculation + delim: Delimiters to split sentences on + skip_window: Number of chunks to skip when looking for similarities + return_type: Whether to return chunks or texts """ @@ -42,6 +44,7 @@ def __init__( threshold_step: float = 0.01, delim: Union[str, List[str]] = [".", "!", "?", "\n"], skip_window: int = 1, + return_type: Literal["chunks", "texts"] = "chunks", **kwargs ): """Initialize the SDPMChunker. @@ -58,6 +61,7 @@ def __init__( threshold_step: Step size for similarity threshold calculation delim: Delimiters to split sentences on skip_window: Number of chunks to skip when looking for similarities + return_type: Whether to return chunks or texts **kwargs: Additional keyword arguments """ @@ -72,6 +76,7 @@ def __init__( min_characters_per_sentence=min_characters_per_sentence, threshold_step=threshold_step, delim=delim, + return_type=return_type, **kwargs ) self.skip_window = skip_window diff --git a/src/chonkie/chunker/semantic.py b/src/chonkie/chunker/semantic.py index 9288372..db0c514 100644 --- a/src/chonkie/chunker/semantic.py +++ b/src/chonkie/chunker/semantic.py @@ -1,7 +1,7 @@ """Semantic chunking using sentence embeddings.""" import warnings -from typing import List, Union +from typing import List, Union, Literal import numpy as np @@ -24,7 +24,10 @@ class SemanticChunker(BaseChunker): min_chunk_size: Minimum number of tokens per sentence (defaults to 2) threshold_step: Step size for similarity threshold calculation delim: Delimiters to split sentences on + return_type: Whether to return chunks or texts + Raises: + ValueError: If parameters are invalid """ def __init__( @@ -39,6 +42,7 @@ def __init__( min_characters_per_sentence: int = 12, threshold_step: float = 0.01, delim: Union[str, List[str]] = [".", "!", "?", "\n"], + return_type: Literal["chunks", "texts"] = "chunks", **kwargs ): """Initialize the SemanticChunker. @@ -56,6 +60,7 @@ def __init__( min_chunk_size: Minimum number of tokens per chunk (and sentence, defaults to 2) threshold_step: Step size for similarity threshold calculation delim: Delimiters to split sentences on + return_type: Whether to return chunks or texts **kwargs: Additional keyword arguments Raises: @@ -85,6 +90,8 @@ def __init__( raise ValueError("threshold (float) must be between 0 and 1") elif type(threshold) == int and (threshold < 1 or threshold > 100): raise ValueError("threshold (int) must be between 1 and 100") + if return_type not in ["chunks", "texts"]: + raise ValueError("Invalid return_type. 
Must be either 'chunks' or 'texts'.") self.mode = mode self.chunk_size = chunk_size @@ -96,6 +103,7 @@ def __init__( self.threshold_step = threshold_step self.delim = delim self.sep = "🦛" + self.return_type = return_type if isinstance(threshold, float): self.similarity_threshold = threshold @@ -115,13 +123,13 @@ def __init__( self.embedding_model = AutoEmbeddings.get_embeddings(embedding_model, **kwargs) else: raise ValueError( - "embedding_model must be a string or BaseEmbeddings instance" + f"{embedding_model} is not a valid embedding model" ) # Probably the dependency is not installed if self.embedding_model is None: raise ImportError( - "embedding_model is not a valid embedding model", + f"{embedding_model} is not a valid embedding model", "Please install the `semantic` extra to use this feature", ) @@ -453,24 +461,27 @@ def _group_sentences(self, sentences: List[Sentence]) -> List[List[Sentence]]: return self._group_sentences_window(sentences) def _create_chunk( - self, sentences: List[Sentence], similarity_scores: List[float] = None + self, sentences: List[Sentence] ) -> SemanticChunk: """Create a chunk from a list of sentences.""" if not sentences: raise ValueError("Cannot create chunk from empty sentence list") - - # Compute chunk text and token count from sentences - text = "".join(sent.text for sent in sentences) - token_count = sum(sent.token_count for sent in sentences) - - return SemanticChunk( - text=text, - start_index=sentences[0].start_index, - end_index=sentences[-1].end_index, - token_count=token_count, - sentences=sentences, - ) - + if self.return_type == "chunks": + # Compute chunk text and token count from sentences + text = "".join(sent.text for sent in sentences) + token_count = sum(sent.token_count for sent in sentences) + return SemanticChunk( + text=text, + start_index=sentences[0].start_index, + end_index=sentences[-1].end_index, + token_count=token_count, + sentences=sentences, + ) + elif self.return_type == "texts": + return "".join(sent.text for sent in sentences) + else: + raise ValueError("Invalid return_type. Must be either 'chunks' or 'texts'.") + def _split_chunks( self, sentence_groups: List[List[Sentence]] ) -> List[SemanticChunk]: diff --git a/src/chonkie/chunker/sentence.py b/src/chonkie/chunker/sentence.py index 86b8964..c526df8 100644 --- a/src/chonkie/chunker/sentence.py +++ b/src/chonkie/chunker/sentence.py @@ -1,7 +1,7 @@ """Sentence chunker.""" from bisect import bisect_left from itertools import accumulate -from typing import Any, List, Union +from typing import Any, Callable, List, Union, Literal from chonkie.types import Chunk, Sentence, SentenceChunk @@ -26,14 +26,14 @@ class SentenceChunker(BaseChunker): def __init__( self, - tokenizer: Union[str, Any] = "gpt2", + tokenizer_or_token_counter: Union[str, Callable, Any] = "gpt2", chunk_size: int = 512, chunk_overlap: int = 128, min_sentences_per_chunk: int = 1, min_characters_per_sentence: int = 12, approximate: bool = True, delim: Union[str, List[str]] = [".", "!", "?", "\n"], - **kwargs + return_type: Literal["chunks", "texts"] = "chunks" ): """Initialize the SentenceChunker with configuration parameters. 
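The new return_type flag threads through _create_chunk above: "chunks" (the default) keeps SemanticChunk objects, while "texts" yields plain strings. A minimal sketch mirroring the test added near the end of this patch, assuming the `semantic` extra is installed (the model name is an assumption, not part of the patch):

from chonkie import SemanticChunker

chunker = SemanticChunker(
    embedding_model="all-MiniLM-L6-v2",  # assumed model; any supported embedding model works
    chunk_size=512,
    threshold=0.5,
    return_type="texts",
)
pieces = chunker.chunk("First sentence. Second sentence. A third one!")
assert all(isinstance(piece, str) for piece in pieces)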
@@ -48,11 +48,13 @@ def __init__( min_characters_per_sentence: Minimum number of characters per sentence approximate: Whether to use approximate token counting (defaults to True) delim: Delimiters to split sentences on + return_type: Whether to return chunks or texts + Raises: ValueError: If parameters are invalid """ - super().__init__(tokenizer) + super().__init__(tokenizer_or_token_counter=tokenizer_or_token_counter) if chunk_size <= 0: raise ValueError("chunk_size must be positive") @@ -62,6 +64,8 @@ def __init__( raise ValueError("min_sentences_per_chunk must be at least 1") if min_characters_per_sentence < 1: raise ValueError("min_characters_per_sentence must be at least 1") + if return_type not in ["chunks", "texts"]: + raise ValueError("Invalid return_type. Must be either 'chunks' or 'texts'.") self.chunk_size = chunk_size self.chunk_overlap = chunk_overlap @@ -70,6 +74,7 @@ def __init__( self.approximate = approximate self.delim = delim self.sep = "🦛" + self.return_type = return_type # TODO: This is a older method of sentence splitting that uses Regex # but since Regex in python via re is super slooooow we use a different method @@ -187,20 +192,6 @@ def _split_sentences(self, text: str) -> List[str]: return sentences - def _get_token_counts(self, sentences: List[str]) -> List[int]: - """Get token counts for a list of sentences in batch. - - Args: - sentences: List of sentences - - Returns: - List of token counts for each sentence - - """ - # Batch encode all sentences at once - encoded_sentences = self._encode_batch(sentences) - return [len(encoded) for encoded in encoded_sentences] - def _estimate_token_counts(self, sentences: List[str]) -> int: """Estimate token count using character length.""" CHARS_PER_TOKEN = 6.0 # Avg. char per token for llama3 is b/w 6-7 @@ -243,7 +234,7 @@ def _prepare_sentences(self, text: str) -> List[Sentence]: if not self.approximate: # Get accurate token counts in batch - token_counts = self._get_token_counts(sentence_texts) + token_counts = self._count_tokens_batch(sentence_texts) else: # Estimate token counts using character length token_counts = self._estimate_token_counts(sentence_texts) @@ -297,13 +288,16 @@ def _create_chunk(self, sentences: List[Sentence], token_count: int) -> Chunk: """ chunk_text = "".join([sentence.text for sentence in sentences]) - return SentenceChunk( - text=chunk_text, - start_index=sentences[0].start_index, - end_index=sentences[-1].end_index, - token_count=token_count, - sentences=sentences, - ) + if self.return_type == "texts": + return chunk_text + else: + return SentenceChunk( + text=chunk_text, + start_index=sentences[0].start_index, + end_index=sentences[-1].end_index, + token_count=token_count, + sentences=sentences, + ) def chunk(self, text: str) -> List[Chunk]: """Split text into overlapping chunks based on sentences while respecting token limits. 
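Since the constructor now accepts tokenizer_or_token_counter: Union[str, Callable, Any], a bare callable should work wherever a tokenizer name did. A sketch under that assumption (the whitespace counter is illustrative, not part of the patch):

from chonkie import SentenceChunker

def word_count(text: str) -> int:
    # Illustrative token counter: counts whitespace-separated words.
    return len(text.split())

chunker = SentenceChunker(
    tokenizer_or_token_counter=word_count,
    chunk_size=128,
    chunk_overlap=16,
)
chunks = chunker.chunk("One sentence here. Another sentence there.")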
@@ -357,7 +351,7 @@ def chunk(self, text: str) -> List[Chunk]: # Get candidate sentences and verify actual token count chunk_sentences = sentences[pos:split_idx] chunk_text = "".join(s.text for s in chunk_sentences) - actual = len(self._encode(chunk_text)) + actual = self._count_tokens(chunk_text) # Given the actual token_count and the estimate, get a feedback value for the next loop feedback = self._get_feedback(estimate, actual) @@ -371,8 +365,8 @@ def chunk(self, text: str) -> List[Chunk]: split_idx -= 1 chunk_sentences = sentences[pos:split_idx] chunk_text = "".join(s.text for s in chunk_sentences) - actual = len(self._encode(chunk_text)) - + actual = self._count_tokens(chunk_text) + chunks.append(self._create_chunk(chunk_sentences, actual)) # Calculate next position with overlap diff --git a/src/chonkie/chunker/token.py b/src/chonkie/chunker/token.py index 3816402..186d43f 100644 --- a/src/chonkie/chunker/token.py +++ b/src/chonkie/chunker/token.py @@ -1,12 +1,14 @@ """Token-based chunking.""" -from typing import Any, Generator, List, Tuple, Union +from typing import Any, Generator, List, Tuple, Union, Literal + +from tqdm import trange from chonkie.types import Chunk from .base import BaseChunker -from tqdm import trange + class TokenChunker(BaseChunker): """Chunker that splits text into chunks of a specified token size. @@ -22,6 +24,7 @@ def __init__( tokenizer: Union[str, Any] = "gpt2", chunk_size: int = 512, chunk_overlap: Union[int, float] = 128, + return_type: Literal["chunks", "texts"] = "chunks" ) -> None: """Initialize the TokenChunker with configuration parameters. @@ -29,6 +32,7 @@ def __init__( tokenizer: The tokenizer instance to use for encoding/decoding chunk_size: Maximum number of tokens per chunk chunk_overlap: Number of tokens to overlap between chunks + return_type: Whether to return chunks or texts Raises: ValueError: If chunk_size <= 0 or chunk_overlap >= chunk_size @@ -41,7 +45,10 @@ def __init__( raise ValueError("chunk_overlap must be less than chunk_size") if isinstance(chunk_overlap, float) and chunk_overlap >= 1: raise ValueError("chunk_overlap must be less than 1") + if return_type not in ["chunks", "texts"]: + raise ValueError("return_type must be either 'chunks' or 'texts'") + self.return_type = return_type self.chunk_size = chunk_size self.chunk_overlap = ( chunk_overlap @@ -114,37 +121,24 @@ def chunk(self, text: str) -> List[Chunk]: # Calculate token groups and counts token_groups = list(self._token_group_generator(text_tokens)) - token_counts = [len(toks) for toks in token_groups] - # decode the token groups into the chunk texts - chunk_texts = self._decode_batch(token_groups) + # if return_type is chunks, we need to decode the token groups into the chunk texts + if self.return_type == "chunks": + token_counts = [len(toks) for toks in token_groups] - # Create the chunks from the token groups and token counts - chunks = self._create_chunks(chunk_texts, token_groups, token_counts) + # decode the token groups into the chunk texts + chunk_texts = self._decode_batch(token_groups) - return chunks + # Create the chunks from the token groups and token counts + chunks = self._create_chunks(chunk_texts, token_groups, token_counts) - def _process_batch(self, - chunks: List[Tuple[List[int], int, int]], - full_text: str) -> List[Chunk]: - """Process a batch of chunks.""" - token_lists = [tokens for tokens, _, _ in chunks] - texts = self._decode_batch(token_lists) + return chunks + # if return_type is texts, we can just return the decoded token groups + elif 
self.return_type == "texts": + return self._decode_batch(token_groups) - index_pairs = [] - current_index = 0 - for text in texts: - start_index = full_text.find(text, current_index) - end_index = start_index + len(text) - index_pairs.append((start_index, end_index)) - current_index = end_index - - return [ - Chunk(text=text, start_index=start, end_index=end, token_count=len(tokens)) - for text, (start, end), tokens in zip(texts, index_pairs, token_lists) - ] - def _process_text_batch(self, texts: List[str]) -> List[List[Chunk]]: + def _process_batch(self, texts: List[str]) -> List[List[Chunk]]: """Process a batch of texts.""" # encode the texts into tokens in a batch tokens_list = self._encode_batch(texts) @@ -158,15 +152,20 @@ def _process_text_batch(self, texts: List[str]) -> List[List[Chunk]]: # get the token groups token_groups = list(self._token_group_generator(tokens)) - # get the token counts - token_counts = [len(token_group) for token_group in token_groups] + if self.return_type == "chunks": + # get the token counts + token_counts = [len(token_group) for token_group in token_groups] - # decode the token groups into the chunk texts - chunk_texts = self._decode_batch(token_groups) + # decode the token groups into the chunk texts + chunk_texts = self._decode_batch(token_groups) - # create the chunks from the token groups and token counts - chunks = self._create_chunks(chunk_texts, token_groups, token_counts) - result.append(chunks) + # create the chunks from the token groups and token counts + chunks = self._create_chunks(chunk_texts, token_groups, token_counts) + result.append(chunks) + elif self.return_type == "texts": + result.append(self._decode_batch(token_groups)) + else: + raise ValueError("Invalid return_type. Must be either 'chunks' or 'texts'.") return result @@ -191,13 +190,13 @@ def chunk_batch( for i in trange(0, len(texts), batch_size, - desc="🦛 CHONKING", + desc="🦛", disable=not show_progress_bar, - unit="batches", - bar_format="{desc}: [{bar:20}] {percentage:3.0f}% • {n_fmt}/{total_fmt} batches chunked [{elapsed}<{remaining}, {rate_fmt}] 🌱", - ascii=' >='): + unit="batch", + bar_format="{desc} ch{bar:20}nk {percentage:3.0f}% • {n_fmt}/{total_fmt} batches chunked [{elapsed}<{remaining}, {rate_fmt}] 🌱", + ascii=' o'): batch_texts = texts[i : min(i + batch_size, len(texts))] - chunks.extend(self._process_text_batch(batch_texts)) + chunks.extend(self._process_batch(batch_texts)) return chunks def __call__(self, diff --git a/src/chonkie/chunker/word.py b/src/chonkie/chunker/word.py index 656b2f9..e44ad5d 100644 --- a/src/chonkie/chunker/word.py +++ b/src/chonkie/chunker/word.py @@ -1,6 +1,6 @@ """Word-based chunker.""" import re -from typing import Any, List, Tuple, Union +from typing import Any, Callable, List, Tuple, Union, Literal from chonkie.types import Chunk @@ -22,31 +22,35 @@ class WordChunker(BaseChunker): def __init__( self, - tokenizer: Union[str, Any] = "gpt2", + tokenizer_or_token_counter: Union[str, Callable, Any] = "gpt2", chunk_size: int = 512, chunk_overlap: int = 128, + return_type: Literal["chunks", "texts"] = "chunks" ): """Initialize the WordChunker with configuration parameters. 
Args: - tokenizer: The tokenizer instance to use for encoding/decoding + tokenizer_or_token_counter: The tokenizer or token counter to use for encoding/decoding chunk_size: Maximum number of tokens per chunk chunk_overlap: Maximum number of tokens to overlap between chunks - mode: Tokenization mode - "heuristic" (space-based) or "advanced" (handles punctuation) + return_type: Whether to return chunks or texts Raises: - ValueError: If chunk_size <= 0 or chunk_overlap >= chunk_size or invalid mode + ValueError: If chunk_size <= 0 or chunk_overlap >= chunk_size or invalid return_type """ - super().__init__(tokenizer) + super().__init__(tokenizer_or_token_counter=tokenizer_or_token_counter) if chunk_size <= 0: raise ValueError("chunk_size must be positive") if chunk_overlap >= chunk_size: raise ValueError("chunk_overlap must be less than chunk_size") + if return_type not in ["chunks", "texts"]: + raise ValueError("Invalid return_type. Must be either 'chunks' or 'texts'.") self.chunk_size = chunk_size self.chunk_overlap = chunk_overlap + self.return_type = return_type def _split_into_words(self, text: str) -> List[str]: """Split text into words while preserving whitespace.""" @@ -104,8 +108,7 @@ def _get_word_list_token_counts(self, words: List[str]) -> List[int]: words = [ word for word in words if word != "" ] # Add space in the beginning because tokenizers usually split that differently - encodings = self._encode_batch(words) - return [len(encoding) for encoding in encodings] + return [self._count_tokens(word) for word in words] def chunk(self, text: str) -> List[Chunk]: """Split text into overlapping chunks based on words while respecting token limits. @@ -135,13 +138,18 @@ def chunk(self, text: str) -> List[Chunk]: current_chunk.append(word) current_chunk_length += length else: - chunk = self._create_chunk( - current_chunk, - text, - current_chunk_length, - current_index, - ) - chunks.append(chunk) + if self.return_type == "chunks": + chunk = self._create_chunk( + current_chunk, + text, + current_chunk_length, + current_index, + ) + chunks.append(chunk) + current_index = chunk.end_index + elif self.return_type == "texts": + chunks.append("".join(current_chunk)) + + # update the previous chunk length for the next overlap calculation previous_chunk_length = current_chunk_length - current_index = chunk.end_index @@ -167,8 +175,11 @@ def chunk(self, text: str) -> List[Chunk]: # Add the final chunk if it has any words if current_chunk: - chunk = self._create_chunk(current_chunk, text, current_chunk_length) - chunks.append(chunk) + if self.return_type == "chunks": + chunk = self._create_chunk(current_chunk, text, current_chunk_length) + chunks.append(chunk) + elif self.return_type == "texts": + chunks.append("".join(current_chunk)) return chunks def __repr__(self) -> str: diff --git a/src/chonkie/embeddings/auto.py b/src/chonkie/embeddings/auto.py index 7863ee5..2abfdb0 100644 --- a/src/chonkie/embeddings/auto.py +++ b/src/chonkie/embeddings/auto.py @@ -1,4 +1,5 @@ -import warnings +"""AutoEmbeddings is a factory class for automatically loading embeddings.""" + from typing import Any, Union from .base import BaseEmbeddings @@ -63,7 +64,7 @@ def get_embeddings( try: return embeddings_cls(model, **kwargs) except Exception as e: - warnings.warn(f"Failed to load {embeddings_cls.__name__}: {e}") + raise ValueError(f"Failed to load {embeddings_cls.__name__}: {e}") except Exception: # Fall back to SentenceTransformerEmbeddings if no matching implementation is found from .sentence_transformer import SentenceTransformerEmbeddings diff --git 
a/src/chonkie/embeddings/base.py b/src/chonkie/embeddings/base.py index a6b69e6..1d7d911 100644 --- a/src/chonkie/embeddings/base.py +++ b/src/chonkie/embeddings/base.py @@ -1,6 +1,9 @@ +"""Base class for all embeddings implementations.""" from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Callable, List, Union +import numpy as np + # for type checking if TYPE_CHECKING: import numpy as np diff --git a/tests/chunker/test_sdpm_chunker.py b/tests/chunker/test_sdpm_chunker.py index 3b29838..d75eab2 100644 --- a/tests/chunker/test_sdpm_chunker.py +++ b/tests/chunker/test_sdpm_chunker.py @@ -175,5 +175,13 @@ def test_spdm_chunker_token_counts(embedding_model, sample_text): token_counts = [chunker._count_tokens(chunk.text) for chunk in chunks] assert all([chunk.token_count == token_count for chunk, token_count in zip(chunks, token_counts)]), "All chunks must have a token count equal to the length of the encoded text" +def test_sdpm_chunker_return_type(embedding_model, sample_text): + """Test that SDPMChunker's return type is correctly set.""" + chunker = SDPMChunker(embedding_model=embedding_model, chunk_size=512, threshold=0.5, return_type="texts") + chunks = chunker.chunk(sample_text) + tokenizer = embedding_model.get_tokenizer_or_token_counter() + assert all([type(chunk) is str for chunk in chunks]) + assert all([len(tokenizer.encode(chunk)) <= 512 for chunk in chunks]) + if __name__ == "__main__": - pytest.main() + pytest.main() \ No newline at end of file diff --git a/tests/chunker/test_semantic_chunker.py b/tests/chunker/test_semantic_chunker.py index e9142f1..63c866d 100644 --- a/tests/chunker/test_semantic_chunker.py +++ b/tests/chunker/test_semantic_chunker.py @@ -287,6 +287,13 @@ def test_semantic_chunker_reconstruction_batch(embedding_model, sample_text): chunks = chunker.chunk_batch([sample_text]*10)[-1] assert sample_text == "".join([chunk.text for chunk in chunks]) +def test_semantic_chunker_return_type(embedding_model, sample_text): + """Test that SemanticChunker's return type is correctly set.""" + chunker = SemanticChunker(embedding_model=embedding_model, chunk_size=512, threshold=0.5, return_type="texts") + chunks = chunker.chunk(sample_text) + tokenizer = embedding_model.get_tokenizer_or_token_counter() + assert all([type(chunk) is str for chunk in chunks]) + assert all([len(tokenizer.encode(chunk)) <= 512 for chunk in chunks]) if __name__ == "__main__": - pytest.main() + pytest.main() \ No newline at end of file diff --git a/tests/chunker/test_sentence_chunker.py b/tests/chunker/test_sentence_chunker.py index b58c467..d7c61e2 100644 --- a/tests/chunker/test_sentence_chunker.py +++ b/tests/chunker/test_sentence_chunker.py @@ -45,7 +45,7 @@ def hello_world(): def test_sentence_chunker_initialization(tokenizer): """Test that the SentenceChunker can be initialized with a tokenizer.""" - chunker = SentenceChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128) + chunker = SentenceChunker(tokenizer_or_token_counter=tokenizer, chunk_size=512, chunk_overlap=128) assert chunker is not None assert chunker.tokenizer == tokenizer @@ -56,7 +56,7 @@ def test_sentence_chunker_initialization(tokenizer): def test_sentence_chunker_chunking(tokenizer, sample_text): """Test that the SentenceChunker can chunk a sample text into sentences.""" - chunker = SentenceChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128) + chunker = SentenceChunker(tokenizer_or_token_counter=tokenizer, chunk_size=512, chunk_overlap=128) chunks = chunker.chunk(sample_text) assert 
len(chunks) > 0 @@ -70,7 +70,7 @@ def test_sentence_chunker_chunking(tokenizer, sample_text): def test_sentence_chunker_empty_text(tokenizer): """Test that the SentenceChunker can handle empty text input.""" - chunker = SentenceChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128) + chunker = SentenceChunker(tokenizer_or_token_counter=tokenizer, chunk_size=512, chunk_overlap=128) chunks = chunker.chunk("") assert len(chunks) == 0 @@ -78,7 +78,7 @@ def test_sentence_chunker_empty_text(tokenizer): def test_sentence_chunker_single_sentence(tokenizer): """Test that the SentenceChunker can handle text with a single sentence.""" - chunker = SentenceChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128) + chunker = SentenceChunker(tokenizer_or_token_counter=tokenizer, chunk_size=512, chunk_overlap=128) chunks = chunker.chunk("This is a single sentence.") assert len(chunks) == 1 @@ -87,7 +87,7 @@ def test_sentence_chunker_single_sentence(tokenizer): def test_sentence_chunker_single_chunk_text(tokenizer): """Test that the SentenceChunker can handle text that fits within a single chunk.""" - chunker = SentenceChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128) + chunker = SentenceChunker(tokenizer_or_token_counter=tokenizer, chunk_size=512, chunk_overlap=128) chunks = chunker.chunk("Hello, how are you? I am doing well.") assert len(chunks) == 1 @@ -96,7 +96,7 @@ def test_sentence_chunker_single_chunk_text(tokenizer): def test_sentence_chunker_repr(tokenizer): """Test that the SentenceChunker has a string representation.""" - chunker = SentenceChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128) + chunker = SentenceChunker(tokenizer_or_token_counter=tokenizer, chunk_size=512, chunk_overlap=128) assert ( repr(chunker) @@ -106,7 +106,7 @@ def test_sentence_chunker_repr(tokenizer): def test_sentence_chunker_overlap(tokenizer, sample_text): """Test that the SentenceChunker creates overlapping chunks correctly.""" - chunker = SentenceChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128) + chunker = SentenceChunker(tokenizer_or_token_counter=tokenizer, chunk_size=512, chunk_overlap=128) chunks = chunker.chunk(sample_text) for i in range(1, len(chunks)): @@ -116,7 +116,7 @@ def test_sentence_chunker_overlap(tokenizer, sample_text): def test_sentence_chunker_min_sentences(tokenizer): """Test that the SentenceChunker respects minimum sentences per chunk.""" chunker = SentenceChunker( - tokenizer=tokenizer, + tokenizer_or_token_counter=tokenizer, chunk_size=512, chunk_overlap=128, min_sentences_per_chunk=2, @@ -150,20 +150,20 @@ def verify_chunk_indices(chunks: List[Chunk], original_text: str): def test_sentence_chunker_indices(tokenizer, sample_text): """Test that the SentenceChunker correctly maps chunk indices to the original text.""" - chunker = SentenceChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128) + chunker = SentenceChunker(tokenizer_or_token_counter=tokenizer, chunk_size=512, chunk_overlap=128) chunks = chunker.chunk(sample_text) verify_chunk_indices(chunks, sample_text) def test_sentence_chunker_indices_complex_md(tokenizer, sample_complex_markdown_text): """Test that the SentenceChunker correctly maps chunk indices to the original text.""" - chunker = SentenceChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128) + chunker = SentenceChunker(tokenizer_or_token_counter=tokenizer, chunk_size=512, chunk_overlap=128) chunks = chunker.chunk(sample_complex_markdown_text) verify_chunk_indices(chunks, sample_complex_markdown_text) 
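The test churn in this file is purely mechanical: the tokenizer keyword became tokenizer_or_token_counter, with all other arguments unchanged. A before/after sketch of the migration callers need:

from chonkie import SentenceChunker

# chonkie 0.4.0 and earlier:
# chunker = SentenceChunker(tokenizer="gpt2", chunk_size=512, chunk_overlap=128)

# chonkie 0.4.1:
chunker = SentenceChunker(tokenizer_or_token_counter="gpt2", chunk_size=512, chunk_overlap=128)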
def test_sentence_chunker_token_counts(tokenizer, sample_text): """Test that the SentenceChunker correctly calculates token counts.""" - chunker = SentenceChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128) + chunker = SentenceChunker(tokenizer_or_token_counter=tokenizer, chunk_size=512, chunk_overlap=128) chunks = chunker.chunk(sample_text) assert all([chunk.token_count > 0 for chunk in chunks]), "All chunks must have a positive token count" assert all([chunk.token_count <= 512 for chunk in chunks]), "All chunks must have a token count less than or equal to 512" @@ -171,6 +171,12 @@ def test_sentence_chunker_token_counts(tokenizer, sample_text): token_counts = [len(tokenizer.encode(chunk.text)) for chunk in chunks] assert all([chunk.token_count == token_count for chunk, token_count in zip(chunks, token_counts)]), "All chunks must have a token count equal to the length of the encoded text" +def test_sentence_chunker_return_type(tokenizer, sample_text): + """Test that SentenceChunker's return type is correctly set.""" + chunker = SentenceChunker(tokenizer_or_token_counter=tokenizer, chunk_size=512, chunk_overlap=128, return_type="texts") + chunks = chunker.chunk(sample_text) + assert all([type(chunk) is str for chunk in chunks]) + assert all([len(tokenizer.encode(chunk)) <= 512 for chunk in chunks]) if __name__ == "__main__": pytest.main() diff --git a/tests/chunker/test_token_chunker.py b/tests/chunker/test_token_chunker.py index f0c9b5e..967bda0 100644 --- a/tests/chunker/test_token_chunker.py +++ b/tests/chunker/test_token_chunker.py @@ -337,5 +337,12 @@ def test_token_chunker_indices_batch(tiktokenizer, sample_text): chunks = chunker.chunk_batch([sample_text]*10)[-1] verify_chunk_indices(chunks, sample_text) +def test_token_chunker_return_type(tiktokenizer, sample_text): + """Test that TokenChunker's return type is correctly set.""" + chunker = TokenChunker(tokenizer=tiktokenizer, chunk_size=512, chunk_overlap=128, return_type="texts") + chunks = chunker.chunk(sample_text) + assert all([type(chunk) is str for chunk in chunks]) + assert all([len(tiktokenizer.encode(chunk)) <= 512 for chunk in chunks]) + if __name__ == "__main__": pytest.main() diff --git a/tests/chunker/test_word_chunker.py b/tests/chunker/test_word_chunker.py index 15bc836..b0c239a 100644 --- a/tests/chunker/test_word_chunker.py +++ b/tests/chunker/test_word_chunker.py @@ -74,7 +74,7 @@ def hello_world(): def test_word_chunker_initialization(tokenizer): """Test that the WordChunker can be initialized with a tokenizer.""" - chunker = WordChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128) + chunker = WordChunker(tokenizer_or_token_counter=tokenizer, chunk_size=512, chunk_overlap=128) assert chunker is not None assert chunker.tokenizer == tokenizer @@ -84,7 +84,7 @@ def test_word_chunker_initialization(tokenizer): def test_word_chunker_chunking(tokenizer, sample_text): """Test that the WordChunker can chunk a sample text into words.""" - chunker = WordChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128) + chunker = WordChunker(tokenizer_or_token_counter=tokenizer, chunk_size=512, chunk_overlap=128) chunks = chunker.chunk(sample_text) assert len(chunks) > 0, print(f"Chunks: {chunks}") @@ -98,7 +98,7 @@ def test_word_chunker_chunking(tokenizer, sample_text): def test_word_chunker_empty_text(tokenizer): """Test that the WordChunker can handle empty text input.""" - chunker = WordChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128) + chunker = 
WordChunker(tokenizer_or_token_counter=tokenizer, chunk_size=512, chunk_overlap=128) chunks = chunker.chunk("") assert len(chunks) == 0 @@ -106,7 +106,7 @@ def test_word_chunker_empty_text(tokenizer): def test_word_chunker_single_word_text(tokenizer): """Test that the WordChunker can handle text with a single word.""" - chunker = WordChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128) + chunker = WordChunker(tokenizer_or_token_counter=tokenizer, chunk_size=512, chunk_overlap=128) chunks = chunker.chunk("Hello") assert len(chunks) == 1 @@ -116,7 +116,7 @@ def test_word_chunker_single_word_text(tokenizer): def test_word_chunker_single_chunk_text(tokenizer): """Test that the WordChunker can handle text that fits within a single chunk.""" - chunker = WordChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128) + chunker = WordChunker(tokenizer_or_token_counter=tokenizer, chunk_size=512, chunk_overlap=128) chunks = chunker.chunk("Hello, how are you?") assert len(chunks) == 1, print(f"Chunks: {chunks}") @@ -133,7 +133,7 @@ def test_word_chunker_batch_chunking(tokenizer, sample_batch): multiprocessing.set_start_method("spawn", force=True) - chunker = WordChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128) + chunker = WordChunker(tokenizer_or_token_counter=tokenizer, chunk_size=512, chunk_overlap=128) chunks = chunker.chunk_batch(sample_batch) assert len(chunks) == len(sample_batch) @@ -145,14 +145,14 @@ def test_word_chunker_batch_chunking(tokenizer, sample_batch): def test_word_chunker_repr(tokenizer): """Test that the WordChunker has a string representation.""" - chunker = WordChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128) + chunker = WordChunker(tokenizer_or_token_counter=tokenizer, chunk_size=512, chunk_overlap=128) assert repr(chunker) == "WordChunker(chunk_size=512, chunk_overlap=128)" def test_word_chunker_call(tokenizer, sample_text): """Test that the WordChunker can be called directly.""" - chunker = WordChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128) + chunker = WordChunker(tokenizer_or_token_counter=tokenizer, chunk_size=512, chunk_overlap=128) chunks = chunker(sample_text) assert len(chunks) > 0 @@ -166,7 +166,7 @@ def test_word_chunker_call(tokenizer, sample_text): def test_word_chunker_overlap(tokenizer, sample_text): """Test that the WordChunker creates overlapping chunks correctly.""" - chunker = WordChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128) + chunker = WordChunker(tokenizer_or_token_counter=tokenizer, chunk_size=512, chunk_overlap=128) chunks = chunker.chunk(sample_text) for i in range(1, len(chunks)): @@ -193,7 +193,7 @@ def verify_chunk_indices(chunks: List[Chunk], original_text: str): def test_word_chunker_indices(sample_text): """Test that WordChunker's indices correctly map to original text.""" tokenizer = Tokenizer.from_pretrained("gpt2") - chunker = WordChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128) + chunker = WordChunker(tokenizer_or_token_counter=tokenizer, chunk_size=512, chunk_overlap=128) chunks = chunker.chunk(sample_text) verify_chunk_indices(chunks, sample_text) @@ -206,7 +206,7 @@ def test_word_chunker_indices_complex_markdown(sample_complex_markdown_text): def test_word_chunker_token_counts(tokenizer, sample_text): """Test that the WordChunker correctly calculates token counts.""" - chunker = WordChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128) + chunker = WordChunker(tokenizer_or_token_counter=tokenizer, chunk_size=512, chunk_overlap=128) chunks 
= chunker.chunk(sample_text) assert all([chunk.token_count > 0 for chunk in chunks]), "All chunks must have a positive token count" assert all([chunk.token_count <= 512 for chunk in chunks]), "All chunks must have a token count less than or equal to 512" @@ -214,5 +214,12 @@ def test_word_chunker_token_counts(tokenizer, sample_text): token_counts = [len(tokenizer.encode(chunk.text)) for chunk in chunks] assert all([chunk.token_count == token_count for chunk, token_count in zip(chunks, token_counts)]), "All chunks must have a token count equal to the length of the encoded text" +def test_word_chunker_return_type(tokenizer, sample_text): + """Test that WordChunker's return type is correctly set.""" + chunker = WordChunker(tokenizer_or_token_counter=tokenizer, chunk_size=512, chunk_overlap=128, return_type="texts") + chunks = chunker.chunk(sample_text) + assert all([type(chunk) is str for chunk in chunks]) + assert all([len(tokenizer.encode(chunk)) <= 512 for chunk in chunks]) + if __name__ == "__main__": pytest.main() diff --git a/tests/embeddings/test_openai_embeddings.py b/tests/embeddings/test_openai_embeddings.py index 7ae2010..5898ca8 100644 --- a/tests/embeddings/test_openai_embeddings.py +++ b/tests/embeddings/test_openai_embeddings.py @@ -1,3 +1,4 @@ +"""Test suite for OpenAIEmbeddings.""" import os import numpy as np @@ -8,17 +9,20 @@ @pytest.fixture def embedding_model(): + """Fixture to create an OpenAIEmbeddings instance.""" api_key = os.environ.get("OPENAI_API_KEY") return OpenAIEmbeddings(model="text-embedding-3-small", api_key=api_key) @pytest.fixture def sample_text(): + """Fixture to create a sample text for testing.""" return "This is a sample text for testing." @pytest.fixture def sample_texts(): + """Fixture to create a list of sample texts for testing.""" return [ "This is the first sample text.", "Here is another example sentence.", @@ -31,6 +35,7 @@ def sample_texts(): reason="Skipping test because OPENAI_API_KEY is not defined", ) def test_initialization_with_model_name(): + """Test that OpenAIEmbeddings initializes with a model name.""" embeddings = OpenAIEmbeddings(model="text-embedding-3-small") assert embeddings.model == "text-embedding-3-small" assert embeddings.client is not None @@ -41,6 +46,7 @@ def test_initialization_with_model_name(): reason="Skipping test because OPENAI_API_KEY is not defined", ) def test_embed_single_text(embedding_model, sample_text): + """Test that OpenAIEmbeddings correctly embeds a single text.""" embedding = embedding_model.embed(sample_text) assert isinstance(embedding, np.ndarray) assert embedding.shape == (embedding_model.dimension,) @@ -87,6 +93,7 @@ def test_count_tokens_batch_texts(embedding_model, sample_texts): reason="Skipping test because OPENAI_API_KEY is not defined", ) def test_similarity(embedding_model, sample_texts): + """Test that OpenAIEmbeddings correctly calculates similarity between two embeddings.""" embeddings = embedding_model.embed_batch(sample_texts) similarity_score = embedding_model.similarity(embeddings[0], embeddings[1]) assert isinstance(similarity_score, float) @@ -98,11 +105,13 @@ def test_similarity(embedding_model, sample_texts): reason="Skipping test because OPENAI_API_KEY is not defined", ) def test_dimension_property(embedding_model): + """Test that OpenAIEmbeddings correctly calculates the dimension property.""" assert isinstance(embedding_model.dimension, int) assert embedding_model.dimension > 0 def test_is_available(): + """Test that OpenAIEmbeddings correctly checks if it is available.""" 
assert OpenAIEmbeddings.is_available() is True @@ -111,6 +120,7 @@ def test_is_available(): reason="Skipping test because OPENAI_API_KEY is not defined", ) def test_repr(embedding_model): + """Test that OpenAIEmbeddings correctly returns a string representation.""" repr_str = repr(embedding_model) assert isinstance(repr_str, str) assert repr_str.startswith("OpenAIEmbeddings")
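One behavioral note from the embeddings/auto.py hunk earlier in this patch: a provider that matches but fails to load now raises ValueError instead of emitting a warning, so callers should catch it. A minimal sketch (the import path follows the file location shown above; the model string is illustrative):

from chonkie.embeddings.auto import AutoEmbeddings

try:
    embeddings = AutoEmbeddings.get_embeddings("text-embedding-3-small")
except ValueError as error:
    # As of 0.4.1 this is raised when the matched provider fails to
    # initialize, e.g. because its optional dependency is missing.
    print(f"Could not load embeddings: {error}")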