[fix] Infinite loop issue
bhavnicksm committed Dec 26, 2024
1 parent cd253f2 commit 968dc11
Showing 2 changed files with 46 additions and 24 deletions.
57 changes: 39 additions & 18 deletions src/chonkie/chunker/recursive.py
@@ -1,9 +1,13 @@
"""Recursive chunker."""
from bisect import bisect_left
from dataclasses import dataclass
from typing import Any, List, Union, Optional
from functools import lru_cache
from itertools import accumulate
from typing import Any, List, Optional, Union

from chonkie.chunker.base import BaseChunker
from chonkie.types import Chunk
from bisect import bisect_left
from functools import lru_cache

@dataclass
class RecursiveLevel:
"""Configuration for a single level of recursive chunking.
@@ -42,7 +46,7 @@ def __post_init__(self):
sentence_level = RecursiveLevel(delimiters=[".", "?", "!"],
whitespace=False)
# Third level should be words
word_level = RecursiveLevel(delimiters=[" "],
word_level = RecursiveLevel(delimiters=None,
whitespace=True)
# Fourth level should be tokens
# NOTE: When delimiters is None, the level will use tokens to determine chunk boundaries.
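As a rough sketch of how the four levels described above could be assembled: the sentence, word, and token levels follow the comments in this hunk, while the paragraph delimiters and the `levels=` keyword on RecursiveRules are assumptions for illustration only, not shown in this diff.

# Hypothetical assembly of the default levels (paragraph delimiters and
# RecursiveRules(levels=...) are assumed, not shown in this diff).
paragraph_level = RecursiveLevel(delimiters=["\n\n"], whitespace=False)  # assumed delimiter
sentence_level = RecursiveLevel(delimiters=[".", "?", "!"], whitespace=False)
word_level = RecursiveLevel(delimiters=None, whitespace=True)    # split on whitespace
token_level = RecursiveLevel(delimiters=None, whitespace=False)  # falls back to tokens
rules = RecursiveRules(levels=[paragraph_level, sentence_level, word_level, token_level])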
@@ -93,7 +97,7 @@ def __str__(self) -> str:
f"token_count={self.token_count}, "
f"level={self.level})")

-class RecursiveChunker:
+class RecursiveChunker(BaseChunker):
"""Chunker that uses recursive rules to chunk text.
Attributes:
@@ -104,20 +108,23 @@ class RecursiveChunker:

def __init__(self,
tokenizer: Union[str, Any] = "gpt2",
-rules: RecursiveRules = RecursiveRules(),
-chunk_size: int = 512
+chunk_size: int = 512,
+rules: RecursiveRules = RecursiveRules(),
+min_characters_per_chunk: int = 12
) -> None:
"""Initialize the recursive chunker.
Args:
tokenizer: The tokenizer to use for encoding/decoding.
-rules: The rules to use for chunking.
chunk_size: The size of the chunks to return.
+rules: The rules to use for chunking.
+min_characters_per_chunk: The minimum number of characters per chunk.
"""
super().__init__(tokenizer)
self.rules = rules
self.chunk_size = chunk_size
+self.min_characters_per_chunk = min_characters_per_chunk
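For context, a minimal usage sketch of the new constructor signature; the sample text is invented, and the defaults are the ones shown in this diff.

from chonkie.chunker.recursive import RecursiveChunker, RecursiveRules

chunker = RecursiveChunker(
    tokenizer="gpt2",             # default tokenizer
    chunk_size=512,               # maximum tokens per chunk
    rules=RecursiveRules(),       # default recursive rules
    min_characters_per_chunk=12,  # parameter introduced in this commit
)
chunks = chunker.chunk("Some long document text...")  # returns List[Chunk]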

def _split_text(self,
text: str,
@@ -131,6 +138,17 @@ def _split_text(self,

# Split the text at the sep
splits = [s for s in text.split(sep) if s != ""]

+# Usually a good idea to check if there are any splits that are too short in characters
+# and then merge them
+merged_splits = []
+for split in splits:
+    if len(split) < self.min_characters_per_chunk:
+        merged_splits[-1] += split
+    else:
+        merged_splits.append(split)
+splits = merged_splits

elif rule.whitespace:
splits = self._split_at_whitespace(text)
else:
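The merge pass added above folds any split shorter than min_characters_per_chunk into its predecessor. A standalone sketch of the same logic with invented inputs; note it assumes the first split meets the minimum, since merged_splits[-1] would raise an IndexError on an empty list.

MIN_CHARS = 12  # stands in for self.min_characters_per_chunk
splits = ["A reasonably long split", "ok", "another long enough split"]
merged_splits = []
for split in splits:
    if len(split) < MIN_CHARS:
        merged_splits[-1] += split  # fold the short split into the previous one
    else:
        merged_splits.append(split)
# merged_splits == ["A reasonably long splitok", "another long enough split"]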
@@ -164,7 +182,7 @@ def _merge_splits(self,
"""Merge splits that are too short."""
# If there are no splits or token counts, return an empty list
if not splits or not token_counts:
-return []
+return [], []

# If the number of splits and token counts does not match, raise an error
if len(splits) != len(token_counts):
@@ -173,25 +191,28 @@ def _merge_splits(self,
# Usually the splits can be smaller than the chunk size; if not,
# we can just return the splits
if all(tc > self.chunk_size for tc in token_counts):
-return splits
+return splits, token_counts

# If the splits are too short, merge them
merged = []

if not combine_with_whitespace:
-cumulative_token_counts = list(accumulate(token_counts, lambda x, y: x + y))
-else:
-cumulative_token_counts = list(accumulate(token_counts, lambda x, y: x + y + 1)) # Add 1 for the whitespace
+cumulative_token_counts = list(accumulate([0] + token_counts, lambda x, y: x + y))
+else:
+cumulative_token_counts = list(accumulate([0] + token_counts, lambda x, y: x + y + 1)) # Add 1 for the whitespace

current_index = 0
merged_token_counts = []

# Use bisect_left to find the index to merge at
while current_index < len(splits):
current_token_count = cumulative_token_counts[current_index]
required_token_count = current_token_count + self.chunk_size

+# print(current_index, current_token_count, required_token_count)

# Find the index to merge at
-index = bisect_left(cumulative_token_counts, required_token_count, lo=current_index)
+index = min(bisect_left(cumulative_token_counts, required_token_count, lo=current_index) - 1, len(splits))
+# print(f"index: {index}\n")

# Merge the splits at the index
if combine_with_whitespace:
@@ -200,7 +221,8 @@ def _merge_splits(self,
merged.append("".join(splits[current_index:index]))

# Add the token count of the merged split
-merged_token_counts.append(cumulative_token_counts[index] - current_token_count)
+merged_token_counts.append(cumulative_token_counts[min(index, len(splits))] - current_token_count)
+# print(f"merged_token_counts: {merged_token_counts}\n")

# Update the current index
current_index = index
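To see how the new prefix sums and bisect_left interact, a small worked sketch with invented numbers: prepending 0 makes cumulative[i] the token total of splits[:i], and stepping back one position from bisect_left picks the largest window that still fits the token budget.

from bisect import bisect_left
from itertools import accumulate

chunk_size = 10
token_counts = [4, 4, 4, 4]
cumulative = list(accumulate([0] + token_counts))  # [0, 4, 8, 12, 16]

current_index = 0
required = cumulative[current_index] + chunk_size  # 0 + 10 = 10
# The first prefix sum >= 10 sits at position 3 (value 12); one step back
# gives the largest window within budget.
index = min(bisect_left(cumulative, required, lo=current_index) - 1, len(token_counts))
assert index == 2                                          # merge splits[0:2]
assert cumulative[index] - cumulative[current_index] == 8  # merged token count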
@@ -291,4 +313,3 @@ def _recursive_chunk(self,
def chunk(self, text: str) -> List[Chunk]:
"""Chunk the text."""
return self._recursive_chunk(text)

13 changes: 7 additions & 6 deletions src/chonkie/chunker/sentence.py
@@ -1,3 +1,4 @@
"""Sentence chunker."""
from bisect import bisect_left
from itertools import accumulate
from typing import Any, List, Union
@@ -200,16 +201,16 @@ def _get_token_counts(self, sentences: List[str]) -> List[int]:
encoded_sentences = self._encode_batch(sentences)
return [len(encoded) for encoded in encoded_sentences]

-def _estimate_token_counts(self, text: str) -> int:
+def _estimate_token_counts(self, sentences: List[str]) -> int:
"""Estimate token count using character length."""
CHARS_PER_TOKEN = 6.0 # Avg. char per token for llama3 is b/w 6-7
-if type(text) is str:
-    return max(1, len(text) // CHARS_PER_TOKEN)
-elif type(text) is list and type(text[0]) is str:
-    return [max(1, len(t) // CHARS_PER_TOKEN) for t in text]
+if type(sentences) is str:
+    return max(1, len(sentences) // CHARS_PER_TOKEN)
+elif type(sentences) is list and type(sentences[0]) is str:
+    return [max(1, len(t) // CHARS_PER_TOKEN) for t in sentences]
else:
raise ValueError(
f"Unknown type passed to _estimate_token_count: {type(text)}"
f"Unknown type passed to _estimate_token_count: {type(sentences)}"
)
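A quick illustration of the character-based estimate above; note that floor-dividing by the float 6.0 yields a float despite the int annotation, so a caller needing an int must cast.

CHARS_PER_TOKEN = 6.0  # avg. characters per token, per the heuristic above
sentence = "Chonkie chunks text recursively."  # 32 characters
estimate = max(1, len(sentence) // CHARS_PER_TOKEN)
print(estimate)       # 5.0 -- a float, because CHARS_PER_TOKEN is 6.0
print(int(estimate))  # 5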

def _get_feedback(self, estimate: int, actual: int) -> float:
