[FIX] start_index incorrect when chunk_overlap is not 0 (#116) #132

Merged: 11 commits, Jan 4, 2025
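Summary of the bug: with a non-zero chunk_overlap, consecutive chunks share text, so each chunk after the first starts *before* the previous chunk ends. The old _create_chunks located each chunk with decoded_text.find(chunk_text, current_index), searching only from the previous chunk's end, so it could never land on the true (earlier) start. A minimal sketch of the failure mode, using hypothetical toy strings rather than real tokenizer output:

# Hypothetical repro of issue #116 (toy strings, not the chunker's real output).
text = "the cat sat on the mat and the cat ran away from the mat again"

# Two overlapping chunks; they share the 7-character span "the cat",
# so chunk_2 truly starts at index 27, *before* chunk_1 ends at index 34.
chunk_1 = "the cat sat on the mat and the cat"
chunk_2 = "the cat ran away from the mat again"

# Old logic: search for each chunk starting at the previous chunk's end.
current_index = 0
start_1 = text.find(chunk_1, current_index)  # 0 -- correct
current_index = start_1 + len(chunk_1)       # 34

start_2 = text.find(chunk_2, current_index)  # -1 -- wrong: the true start (27)
                                             # lies before the search position

The fix below drops find() entirely and instead computes each start index from the character length of the decoded overlap.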
6 changes: 3 additions & 3 deletions src/chonkie/chunker/base.py
@@ -183,11 +183,11 @@ def _decode(self, tokens) -> str:
     def _decode_batch(self, token_lists: List[List[int]]) -> List[str]:
         """Decode a batch of token lists using the backend tokenizer."""
         if self._tokenizer_backend == "transformers":
-            return [self.tokenizer.decode(tokens) for tokens in token_lists]
+            return self.tokenizer.batch_decode(token_lists, skip_special_tokens=True)
         elif self._tokenizer_backend == "tokenizers":
-            return [self.tokenizer.decode(tokens) for tokens in token_lists]
+            return self.tokenizer.decode_batch(token_lists)
         elif self._tokenizer_backend == "tiktoken":
-            return [self.tokenizer.decode(tokens) for tokens in token_lists]
+            return self.tokenizer.decode_batch(token_lists)
         elif self._tokenizer_backend == "callable":
             raise NotImplementedError(
                 "Callable tokenizer backend does not support batch decoding."
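The rewritten _decode_batch defers to each backend's native batch API instead of looping over decode() in Python. A quick sanity check of one such path, a sketch assuming tiktoken is installed ("gpt2" is just an example encoding):

# Demonstrates the tiktoken batch-decode call the new code relies on.
import tiktoken

enc = tiktoken.get_encoding("gpt2")
token_lists = [enc.encode("hello world"), enc.encode("chunk me")]

# One native batch call instead of a per-list decode() loop:
print(enc.decode_batch(token_lists))  # ['hello world', 'chunk me']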
87 changes: 45 additions & 42 deletions src/chonkie/chunker/token.py
@@ -52,17 +52,27 @@ def __init__(
     def _create_chunks(
         self,
         chunk_texts: List[str],
-        token_counts: List[int],
-        decoded_text: str,
+        token_groups: List[List[int]],
+        token_counts: List[int]
     ) -> List[Chunk]:
         """Create chunks from a list of texts."""
-        # package everything as Chunk objects and send out the result
+        # Find the overlap lengths for index calculation
+        if self.chunk_overlap > 0:
+            # we get the overlap texts, that gives you the start_index for the next chunk
+            # if the token group is smaller than the overlap, we just use the whole token group
+            overlap_texts = self._decode_batch([token_group[-self.chunk_overlap:]
+                                                if (len(token_group) > self.chunk_overlap)
+                                                else token_group
+                                                for token_group in token_groups])
+            overlap_lengths = [len(overlap_text) for overlap_text in overlap_texts]
+        else:
+            overlap_lengths = [0] * len(token_groups)
+
+        # Create the chunks
         chunks = []
         current_index = 0
-        for chunk_text, token_count in zip(chunk_texts, token_counts):
-            start_index = decoded_text.find(
-                chunk_text, current_index
-            )  # Find needs to be run every single time because of unknown overlap length
+        for chunk_text, overlap_length, token_count in zip(chunk_texts, overlap_lengths, token_counts):
+            start_index = current_index
             end_index = start_index + len(chunk_text)
             chunks.append(
                 Chunk(
@@ -72,7 +82,8 @@ def _create_chunks(
                     token_count=token_count,
                 )
             )
-            current_index = end_index
+            current_index = end_index - overlap_length

         return chunks
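The cursor arithmetic above is the heart of the fix: each chunk starts at the running cursor, and the cursor then advances by len(chunk_text) minus the character length of the decoded overlap (the last chunk_overlap tokens of the group). A standalone sketch of that index bookkeeping, with hypothetical inputs; it assumes the decoded overlap's characters exactly match the tail of the chunk text, which holds for typical BPE tokenizers:

from typing import List, Tuple

def index_chunks(chunk_texts: List[str], overlap_lengths: List[int]) -> List[Tuple[int, int]]:
    """Mirror the patched loop: start each chunk at the cursor, then step
    the cursor back by the character length of the decoded overlap."""
    spans = []
    current_index = 0
    for chunk_text, overlap_length in zip(chunk_texts, overlap_lengths):
        start_index = current_index
        end_index = start_index + len(chunk_text)
        spans.append((start_index, end_index))
        current_index = end_index - overlap_length
    return spans

# Toy values: two chunks sharing the 7-character overlap "the cat".
print(index_chunks(["the cat sat on the mat and the cat",
                    "the cat ran away from the mat again"], [7, 7]))
# [(0, 34), (27, 62)] -- the second chunk correctly starts at 27, not -1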

     def chunk(self, text: str) -> List[Chunk]:
@@ -91,40 +102,24 @@ def chunk(self, text: str) -> List[Chunk]:
         # Encode full text
         text_tokens = self._encode(text)

-        # We decode the text because the tokenizer might result in a different output than text
-        decoded_text = self._decode(text_tokens)
-
         # Calculate chunk positions
-        token_groups = [
-            text_tokens[
-                start_index : min(start_index + self.chunk_size, len(text_tokens))
-            ]
-            for start_index in range(
-                0, len(text_tokens), self.chunk_size - self.chunk_overlap
-            )
-        ]
-        token_counts = [
-            len(toks) for toks in token_groups
-        ]  # get the token counts; it's prolly chunk_size, but len doesn't take too long
+        token_groups = [text_tokens[start_index : min(start_index + self.chunk_size, len(text_tokens))]
+                        for start_index in range(0, len(text_tokens), self.chunk_size - self.chunk_overlap)]
+        token_counts = [len(toks) for toks in token_groups]

-        chunk_texts = self._decode_batch(
-            token_groups
-        )  # decrease the time by decoding in one go (?)
+        # decode the token groups into the chunk texts
+        chunk_texts = self._decode_batch(token_groups)

-        chunks = self._create_chunks(chunk_texts, token_counts, decoded_text)
+        # Create the chunks from the token groups and token counts
+        chunks = self._create_chunks(chunk_texts, token_groups, token_counts)

         return chunks
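For intuition about the grouping comprehension above: it slides a chunk_size window over the tokens with stride chunk_size - chunk_overlap. A quick sketch with illustrative numbers, not values from the PR:

tokens = list(range(10))          # stand-ins for real token ids
chunk_size, chunk_overlap = 5, 2  # stride = 3

token_groups = [tokens[start : min(start + chunk_size, len(tokens))]
                for start in range(0, len(tokens), chunk_size - chunk_overlap)]
print(token_groups)
# [[0, 1, 2, 3, 4], [3, 4, 5, 6, 7], [6, 7, 8, 9], [9]]

Note the final one-token group, which is fully contained in the previous window; the rewritten _token_group_generator below yields exactly the same windows, keeping chunk() and the batch path consistent.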

-    def _chunk_generator(
-        self, tokens: List[int]
-    ) -> Generator[Tuple[List[int], int, int], None, None]:
+    def _token_group_generator(self, tokens: List[int]) -> Generator[List[int], None, None]:
         """Generate chunks from a list of tokens."""
-        stride = self.chunk_size - self.chunk_overlap
-        for start in range(0, len(tokens), stride):
+        for start in range(0, len(tokens), self.chunk_size - self.chunk_overlap):
             end = min(start + self.chunk_size, len(tokens))
-            yield tokens[start:end], start, end
-            if end == len(tokens):
-                break
+            yield tokens[start:end]

     def _process_batch(self,
                        chunks: List[Tuple[List[int], int, int]],
@@ -148,22 +143,28 @@ def _process_batch(self,

     def _process_text_batch(self, texts: List[str]) -> List[List[Chunk]]:
         """Process a batch of texts."""
+        # encode the texts into tokens in a batch
         tokens_list = self._encode_batch(texts)
-        decoded_texts = self._decode_batch(tokens_list)
         result = []

-        for tokens, text in zip(tokens_list, decoded_texts):
+        for tokens in tokens_list:
             if not tokens:
                 result.append([])
                 continue

-            chunks = []
-            chunk_batch = []
+            # get the token groups
+            token_groups = []
+            for token_group in self._token_group_generator(tokens):
+                token_groups.append(token_group)
+
+            # get the token counts
+            token_counts = [len(token_group) for token_group in token_groups]

-            for chunk_data in self._chunk_generator(tokens):
-                chunk_batch.append(chunk_data)
+            # decode the token groups into the chunk texts
+            chunk_texts = self._decode_batch(token_groups)

-            chunks.extend(self._process_batch(chunk_batch, text))
+            # create the chunks from the token groups and token counts
+            chunks = self._create_chunks(chunk_texts, token_groups, token_counts)
             result.append(chunks)

         return result
@@ -181,6 +182,7 @@ def chunk_batch(
         List of lists of Chunk objects containing the chunked text and metadata

         """
+        # if batch_size is not None, we process the texts in mini-batches to avoid memory issues
         if batch_size is not None:
             chunks = []
             for i in range(0, len(texts), batch_size):
@@ -193,6 +195,7 @@ def chunk_batch(
     def __repr__(self) -> str:
         """Return a string representation of the TokenChunker."""
         return (
-            f"TokenChunker(chunk_size={self.chunk_size}, "
+            f"TokenChunker(tokenizer={self.tokenizer}, "
+            f"chunk_size={self.chunk_size}, "
             f"chunk_overlap={self.chunk_overlap})"
         )
34 changes: 17 additions & 17 deletions tests/chunker/test_token_chunker.py
@@ -152,9 +152,9 @@ def test_token_chunker_initialization_tik(tiktokenizer):
     assert chunker.chunk_overlap == 128


-def test_token_chunker_chunking(tokenizer, sample_text):
+def test_token_chunker_chunking(tiktokenizer, sample_text):
     """Test that the TokenChunker can chunk a sample text into tokens."""
-    chunker = TokenChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128)
+    chunker = TokenChunker(tokenizer=tiktokenizer, chunk_size=512, chunk_overlap=128)
     chunks = chunker.chunk(sample_text)

     assert len(chunks) > 0
@@ -196,9 +196,9 @@ def test_token_chunker_chunking_tik(tiktokenizer, sample_text):
     assert all([chunk.end_index is not None for chunk in chunks])


-def test_token_chunker_empty_text(tokenizer):
+def test_token_chunker_empty_text(tiktokenizer):
     """Test that the TokenChunker can handle empty text input."""
-    chunker = TokenChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128)
+    chunker = TokenChunker(tokenizer=tiktokenizer, chunk_size=512, chunk_overlap=128)
     chunks = chunker.chunk("")

     assert len(chunks) == 0
@@ -246,9 +246,9 @@ def test_token_chunker_single_chunk_text(tokenizer):
     assert chunks[0].text == "Hello, how are you?"


-def test_token_chunker_batch_chunking(tokenizer, sample_batch):
+def test_token_chunker_batch_chunking(tiktokenizer, sample_batch):
     """Test that the TokenChunker can chunk a batch of texts into tokens."""
-    chunker = TokenChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128)
+    chunker = TokenChunker(tokenizer=tiktokenizer, chunk_size=512, chunk_overlap=128)
     chunks = chunker.chunk_batch(sample_batch)

     assert len(chunks) > 0
@@ -267,16 +267,16 @@ def test_token_chunker_batch_chunking(tokenizer, sample_batch):
     )


-def test_token_chunker_repr(tokenizer):
+def test_token_chunker_repr(tiktokenizer):
     """Test that the TokenChunker has a string representation."""
-    chunker = TokenChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128)
+    chunker = TokenChunker(tokenizer=tiktokenizer, chunk_size=512, chunk_overlap=128)

-    assert repr(chunker) == "TokenChunker(chunk_size=512, chunk_overlap=128)"
+    assert repr(chunker) == "TokenChunker(tokenizer=<Encoding 'gpt2'>, chunk_size=512, chunk_overlap=128)"


-def test_token_chunker_call(tokenizer, sample_text):
+def test_token_chunker_call(tiktokenizer, sample_text):
     """Test that the TokenChunker can be called directly."""
-    chunker = TokenChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128)
+    chunker = TokenChunker(tokenizer=tiktokenizer, chunk_size=512, chunk_overlap=128)
     chunks = chunker(sample_text)

     assert len(chunks) > 0
@@ -305,7 +305,7 @@ def verify_chunk_indices(chunks: List[Chunk], original_text: str):
     )


-def test_token_chunker_indices(sample_text):
+def test_token_chunker_indices(tiktokenizer, sample_text):
     """Test that TokenChunker's indices correctly map to original text."""
     tokenizer = Tokenizer.from_pretrained("gpt2")
     chunker = TokenChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128)
@@ -321,19 +321,19 @@ def test_token_chunker_indices_complex_md(sample_complex_markdown_text):
     verify_chunk_indices(chunks, sample_complex_markdown_text)


-def test_token_chunker_token_counts(tokenizer, sample_text):
+def test_token_chunker_token_counts(tiktokenizer, sample_text):
     """Test that the TokenChunker correctly calculates token counts."""
-    chunker = TokenChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128)
+    chunker = TokenChunker(tokenizer=tiktokenizer, chunk_size=512, chunk_overlap=128)
     chunks = chunker.chunk(sample_text)
     assert all([chunk.token_count > 0 for chunk in chunks]), "All chunks must have a positive token count"
     assert all([chunk.token_count <= 512 for chunk in chunks]), "All chunks must have a token count less than or equal to 512"

-    token_counts = [len(tokenizer.encode(chunk.text)) for chunk in chunks]
+    token_counts = [len(tiktokenizer.encode(chunk.text)) for chunk in chunks]
     assert all([chunk.token_count == token_count for chunk, token_count in zip(chunks, token_counts)]), "All chunks must have a token count equal to the length of the encoded text"

-def test_token_chunker_indices_batch(tokenizer, sample_text):
+def test_token_chunker_indices_batch(tiktokenizer, sample_text):
     """Test that TokenChunker's indices correctly map to original text."""
-    chunker = TokenChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128)
+    chunker = TokenChunker(tokenizer=tiktokenizer, chunk_size=512, chunk_overlap=128)
     chunks = chunker.chunk_batch([sample_text]*10)[-1]
     verify_chunk_indices(chunks, sample_text)
