diff --git a/src/chonkie/chunker/base.py b/src/chonkie/chunker/base.py
index 693eb41..a5ef8bc 100644
--- a/src/chonkie/chunker/base.py
+++ b/src/chonkie/chunker/base.py
@@ -183,11 +183,11 @@ def _decode(self, tokens) -> str:
     def _decode_batch(self, token_lists: List[List[int]]) -> List[str]:
         """Decode a batch of token lists using the backend tokenizer."""
         if self._tokenizer_backend == "transformers":
-            return [self.tokenizer.decode(tokens) for tokens in token_lists]
+            return self.tokenizer.batch_decode(token_lists, skip_special_tokens=True)
         elif self._tokenizer_backend == "tokenizers":
-            return [self.tokenizer.decode(tokens) for tokens in token_lists]
+            return self.tokenizer.decode_batch(token_lists)
         elif self._tokenizer_backend == "tiktoken":
-            return [self.tokenizer.decode(tokens) for tokens in token_lists]
+            return self.tokenizer.decode_batch(token_lists)
         elif self._tokenizer_backend == "callable":
             raise NotImplementedError(
                 "Callable tokenizer backend does not support batch decoding."
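Each backend already ships a native batched decode entry point, which is what the hunk above switches to; only the callable backend has none, hence the explicit NotImplementedError. A minimal sketch of the three calls, assuming gpt2 vocabularies purely for illustration:

```python
# Hedged sketch of the batched decode paths, one per backend (gpt2 assumed).
import tiktoken
from tokenizers import Tokenizer
from transformers import AutoTokenizer

enc = tiktoken.get_encoding("gpt2")
token_lists = [enc.encode("Hello, world"), enc.encode("chunk me")]

texts = enc.decode_batch(token_lists)  # tiktoken: one call, no Python-level loop

tok = Tokenizer.from_pretrained("gpt2")
texts = tok.decode_batch(token_lists)  # tokenizers: native batched API

hf = AutoTokenizer.from_pretrained("gpt2")
texts = hf.batch_decode(token_lists, skip_special_tokens=True)  # transformers
```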
diff --git a/src/chonkie/chunker/token.py b/src/chonkie/chunker/token.py
index defb05c..61bf1a3 100644
--- a/src/chonkie/chunker/token.py
+++ b/src/chonkie/chunker/token.py
@@ -52,17 +52,27 @@ def __init__(
     def _create_chunks(
         self,
         chunk_texts: List[str],
-        token_counts: List[int],
-        decoded_text: str,
+        token_groups: List[List[int]],
+        token_counts: List[int]
     ) -> List[Chunk]:
         """Create chunks from a list of texts."""
-        # package everything as Chunk objects and send out the result
+        # Find the overlap lengths for index calculation
+        if self.chunk_overlap > 0:
+            # the decoded overlap text tells us how many characters before this
+            # chunk's end the next chunk starts; short groups use the whole group
+            overlap_texts = self._decode_batch([token_group[-self.chunk_overlap:]
+                                                if (len(token_group) > self.chunk_overlap)
+                                                else token_group
+                                                for token_group in token_groups])
+            overlap_lengths = [len(overlap_text) for overlap_text in overlap_texts]
+        else:
+            overlap_lengths = [0] * len(token_groups)
+
+        # Create the chunks
         chunks = []
         current_index = 0
-        for chunk_text, token_count in zip(chunk_texts, token_counts):
-            start_index = decoded_text.find(
-                chunk_text, current_index
-            )  # Find needs to be run every single time because of unknown overlap length
+        for chunk_text, overlap_length, token_count in zip(chunk_texts, overlap_lengths, token_counts):
+            start_index = current_index
             end_index = start_index + len(chunk_text)
             chunks.append(
                 Chunk(
@@ -72,7 +82,8 @@ def _create_chunks(
                     token_count=token_count,
                 )
             )
-            current_index = end_index
+            current_index = end_index - overlap_length
+
         return chunks

     def chunk(self, text: str) -> List[Chunk]:
@@ -91,40 +102,24 @@ def chunk(self, text: str) -> List[Chunk]:
         # Encode full text
        text_tokens = self._encode(text)

-        # We decode the text because the tokenizer might result in a different output than text
-        decoded_text = self._decode(text_tokens)
-
-        # Calculate chunk positions
-        token_groups = [
-            text_tokens[
-                start_index : min(start_index + self.chunk_size, len(text_tokens))
-            ]
-            for start_index in range(
-                0, len(text_tokens), self.chunk_size - self.chunk_overlap
-            )
-        ]
-        token_counts = [
-            len(toks) for toks in token_groups
-        ]  # get the token counts; it's prolly chunk_size, but len doesn't take too long
-
-        chunk_texts = self._decode_batch(
-            token_groups
-        )  # decrease the time by decoding in one go (?)
+        token_groups = [text_tokens[start_index : min(start_index + self.chunk_size, len(text_tokens))]
+                        for start_index in range(0, len(text_tokens), self.chunk_size - self.chunk_overlap)]
+        token_counts = [len(toks) for toks in token_groups]
+
+        # decode the token groups into the chunk texts
+        chunk_texts = self._decode_batch(token_groups)

-        chunks = self._create_chunks(chunk_texts, token_counts, decoded_text)
+        # Create the chunks from the token groups and token counts
+        chunks = self._create_chunks(chunk_texts, token_groups, token_counts)

         return chunks

-    def _chunk_generator(
-        self, tokens: List[int]
-    ) -> Generator[Tuple[List[int], int, int], None, None]:
+    def _token_group_generator(self, tokens: List[int]) -> Generator[List[int], None, None]:
         """Generate chunks from a list of tokens."""
-        stride = self.chunk_size - self.chunk_overlap
-        for start in range(0, len(tokens), stride):
+        for start in range(0, len(tokens), self.chunk_size - self.chunk_overlap):
             end = min(start + self.chunk_size, len(tokens))
-            yield tokens[start:end], start, end
-            if end == len(tokens):
-                break
+            yield tokens[start:end]

     def _process_batch(self,
                        chunks: List[Tuple[List[int], int, int]],
@@ -148,22 +143,28 @@ def _process_batch(self,
     def _process_text_batch(self, texts: List[str]) -> List[List[Chunk]]:
         """Process a batch of texts."""
+        # encode the texts into tokens in a batch
         tokens_list = self._encode_batch(texts)
-        decoded_texts = self._decode_batch(tokens_list)
         result = []

-        for tokens, text in zip(tokens_list, decoded_texts):
+        for tokens in tokens_list:
             if not tokens:
                 result.append([])
                 continue

-            chunks = []
-            chunk_batch = []
+            # get the token groups
+            token_groups = []
+            for token_group in self._token_group_generator(tokens):
+                token_groups.append(token_group)
+
+            # get the token counts
+            token_counts = [len(token_group) for token_group in token_groups]

-            for chunk_data in self._chunk_generator(tokens):
-                chunk_batch.append(chunk_data)
+            # decode the token groups into the chunk texts
+            chunk_texts = self._decode_batch(token_groups)

-            chunks.extend(self._process_batch(chunk_batch, text))
+            # create the chunks from the token groups and token counts
+            chunks = self._create_chunks(chunk_texts, token_groups, token_counts)

             result.append(chunks)

         return result
@@ -181,6 +182,7 @@ def chunk_batch(
             List of lists of Chunk objects containing the chunked text and metadata

         """
+        # if batch_size is not None, process the texts in mini-batches to avoid memory issues
         if batch_size is not None:
             chunks = []
             for i in range(0, len(texts), batch_size):
@@ -193,6 +195,7 @@
     def __repr__(self) -> str:
         """Return a string representation of the TokenChunker."""
         return (
-            f"TokenChunker(chunk_size={self.chunk_size}, "
+            f"TokenChunker(tokenizer={self.tokenizer}, "
+            f"chunk_size={self.chunk_size}, "
             f"chunk_overlap={self.chunk_overlap})"
         )
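The payoff of passing token_groups into _create_chunks is that chunk offsets become pure arithmetic: the next chunk starts exactly len(decode(overlap_tokens)) characters before the end of the current one, so the repeated decoded_text.find() scan disappears. A worked sketch of that invariant, assuming a tiktoken backend and ASCII text so that token slices decode to exact substrings:

```python
# Illustrative check of the index arithmetic used by the new _create_chunks.
import tiktoken

enc = tiktoken.get_encoding("gpt2")
text = "the quick brown fox jumps over the lazy dog " * 40
tokens = enc.encode(text)

chunk_size, chunk_overlap = 16, 4
stride = chunk_size - chunk_overlap

token_groups = [tokens[i : i + chunk_size] for i in range(0, len(tokens), stride)]
chunk_texts = enc.decode_batch(token_groups)
# character length of each group's trailing overlap tokens
overlap_lengths = [len(enc.decode(group[-chunk_overlap:])) for group in token_groups]

current = 0
for chunk_text, overlap_length in zip(chunk_texts, overlap_lengths):
    start, end = current, current + len(chunk_text)
    # holds here because byte-level BPE decodes token slices to exact substrings
    assert text[start:end] == chunk_text
    current = end - overlap_length  # next chunk begins inside this one's overlap
```

This makes position tracking O(1) per chunk where str.find was a fresh scan of the decoded text per chunk, which appears to be the point of the change.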
diff --git a/tests/chunker/test_token_chunker.py b/tests/chunker/test_token_chunker.py
index b990bcc..f0c9b5e 100644
--- a/tests/chunker/test_token_chunker.py
+++ b/tests/chunker/test_token_chunker.py
@@ -152,9 +152,9 @@ def test_token_chunker_initialization_tik(tiktokenizer):
     assert chunker.chunk_overlap == 128


-def test_token_chunker_chunking(tokenizer, sample_text):
+def test_token_chunker_chunking(tiktokenizer, sample_text):
     """Test that the TokenChunker can chunk a sample text into tokens."""
-    chunker = TokenChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128)
+    chunker = TokenChunker(tokenizer=tiktokenizer, chunk_size=512, chunk_overlap=128)
     chunks = chunker.chunk(sample_text)

     assert len(chunks) > 0
@@ -196,9 +196,9 @@ def test_token_chunker_chunking_tik(tiktokenizer, sample_text):
     assert all([chunk.end_index is not None for chunk in chunks])


-def test_token_chunker_empty_text(tokenizer):
+def test_token_chunker_empty_text(tiktokenizer):
     """Test that the TokenChunker can handle empty text input."""
-    chunker = TokenChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128)
+    chunker = TokenChunker(tokenizer=tiktokenizer, chunk_size=512, chunk_overlap=128)
     chunks = chunker.chunk("")

     assert len(chunks) == 0
@@ -246,9 +246,9 @@ def test_token_chunker_single_chunk_text(tokenizer):
     assert chunks[0].text == "Hello, how are you?"


-def test_token_chunker_batch_chunking(tokenizer, sample_batch):
+def test_token_chunker_batch_chunking(tiktokenizer, sample_batch):
     """Test that the TokenChunker can chunk a batch of texts into tokens."""
-    chunker = TokenChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128)
+    chunker = TokenChunker(tokenizer=tiktokenizer, chunk_size=512, chunk_overlap=128)
     chunks = chunker.chunk_batch(sample_batch)

     assert len(chunks) > 0
@@ -267,16 +267,16 @@ def test_token_chunker_batch_chunking(tokenizer, sample_batch):
     )


-def test_token_chunker_repr(tokenizer):
+def test_token_chunker_repr(tiktokenizer):
     """Test that the TokenChunker has a string representation."""
-    chunker = TokenChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128)
+    chunker = TokenChunker(tokenizer=tiktokenizer, chunk_size=512, chunk_overlap=128)

-    assert repr(chunker) == "TokenChunker(chunk_size=512, chunk_overlap=128)"
+    assert repr(chunker) == f"TokenChunker(tokenizer={tiktokenizer}, chunk_size=512, chunk_overlap=128)"


-def test_token_chunker_call(tokenizer, sample_text):
+def test_token_chunker_call(tiktokenizer, sample_text):
     """Test that the TokenChunker can be called directly."""
-    chunker = TokenChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128)
+    chunker = TokenChunker(tokenizer=tiktokenizer, chunk_size=512, chunk_overlap=128)
     chunks = chunker(sample_text)

     assert len(chunks) > 0
@@ -305,7 +305,6 @@ def verify_chunk_indices(chunks: List[Chunk], original_text: str):
     )


-def test_token_chunker_indices(sample_text):
+def test_token_chunker_indices(tiktokenizer, sample_text):
     """Test that TokenChunker's indices correctly map to original text."""
-    tokenizer = Tokenizer.from_pretrained("gpt2")
-    chunker = TokenChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128)
+    chunker = TokenChunker(tokenizer=tiktokenizer, chunk_size=512, chunk_overlap=128)
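These fixture swaps from tokenizer to tiktokenizer matter for the index tests: verify_chunk_indices compares chunk.text against original_text[start:end], which only works when the tokenizer round-trips text losslessly. That is a property of GPT-2's byte-level BPE as exposed by tiktoken (my reading of the fixture switch, not stated in the patch):

```python
# Sanity check of the lossless round-trip the index assertions rely on.
import tiktoken

enc = tiktoken.get_encoding("gpt2")
sample = "Hello, how are you?"
assert enc.decode(enc.encode(sample)) == sample
```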
@@ -321,19 +321,19 @@ def test_token_chunker_indices_complex_md(sample_complex_markdown_text):
     verify_chunk_indices(chunks, sample_complex_markdown_text)


-def test_token_chunker_token_counts(tokenizer, sample_text):
+def test_token_chunker_token_counts(tiktokenizer, sample_text):
     """Test that the TokenChunker correctly calculates token counts."""
-    chunker = TokenChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128)
+    chunker = TokenChunker(tokenizer=tiktokenizer, chunk_size=512, chunk_overlap=128)
     chunks = chunker.chunk(sample_text)

     assert all([chunk.token_count > 0 for chunk in chunks]), "All chunks must have a positive token count"
     assert all([chunk.token_count <= 512 for chunk in chunks]), "All chunks must have a token count less than or equal to 512"

-    token_counts = [len(tokenizer.encode(chunk.text)) for chunk in chunks]
+    token_counts = [len(tiktokenizer.encode(chunk.text)) for chunk in chunks]
     assert all([chunk.token_count == token_count for chunk, token_count in zip(chunks, token_counts)]), "All chunks must have a token count equal to the length of the encoded text"


-def test_token_chunker_indices_batch(tokenizer, sample_text):
+def test_token_chunker_indices_batch(tiktokenizer, sample_text):
     """Test that TokenChunker's indices correctly map to original text."""
-    chunker = TokenChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128)
+    chunker = TokenChunker(tokenizer=tiktokenizer, chunk_size=512, chunk_overlap=128)
     chunks = chunker.chunk_batch([sample_text]*10)[-1]
     verify_chunk_indices(chunks, sample_text)
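Taken together, the patch makes chunk indices direct offsets into the input text rather than results of a substring search. An end-to-end sketch of the behavior the tests pin down, assuming the public `from chonkie import TokenChunker` import path:

```python
# End-to-end sketch: indices map straight back into the original text.
import tiktoken
from chonkie import TokenChunker

chunker = TokenChunker(
    tokenizer=tiktoken.get_encoding("gpt2"), chunk_size=512, chunk_overlap=128
)

text = "some long ASCII document to be chunked " * 400
for chunk in chunker.chunk(text):
    assert text[chunk.start_index : chunk.end_index] == chunk.text
    assert 0 < chunk.token_count <= 512
```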