diff --git a/src/chonkie/chunker/semantic.py b/src/chonkie/chunker/semantic.py
index 1e1ccf9..64cb35e 100644
--- a/src/chonkie/chunker/semantic.py
+++ b/src/chonkie/chunker/semantic.py
@@ -94,7 +94,7 @@ def __init__(
                 "Cannot specify both similarity_threshold and similarity_percentile"
             )
         if similarity_threshold is None and similarity_percentile is None:
-            similarity_percentile = 0.8
+            similarity_percentile = 80
             raise Warning(
                 "No similarity threshold specified. Defaulting to 80th percentile."
             )  #TODO: Change this to be a non-blocking warning
@@ -211,6 +211,18 @@ def _prepare_sentences(self, text: str) -> List[Sentence]:
             )
         ]

+        # Get or compute similarity threshold
+        if self.similarity_threshold is None:
+            # Compute all pairwise similarities
+            all_similarities = [
+                self._get_semantic_similarity(
+                    sentences[i].embedding, sentences[i + 1].embedding
+                )
+                for i in range(len(sentences) - 1)
+            ]
+            self.similarity_threshold = float(
+                np.percentile(all_similarities, self.similarity_percentile)
+            )
         return sentences

     def _get_semantic_similarity(
@@ -240,21 +252,6 @@ def _group_sentences(self, sentences: List[Sentence]) -> List[List[Sentence]]:
         if len(sentences) <= self.initial_sentences:
             return [sentences]

-        # Get or compute similarity threshold
-        if self.similarity_percentile is not None:
-            # Compute all pairwise similarities
-            all_similarities = [
-                self._get_semantic_similarity(
-                    sentences[i].embedding, sentences[i + 1].embedding
-                )
-                for i in range(len(sentences) - 1)
-            ]
-            similarity_threshold = float(
-                np.percentile(all_similarities, self.similarity_percentile)
-            )
-        else:
-            similarity_threshold = self.similarity_threshold
-
         groups = []
         current_group = sentences[: self.initial_sentences]
         current_embedding = self._compute_group_embedding(current_group)
@@ -265,7 +262,7 @@ def _group_sentences(self, sentences: List[Sentence]) -> List[List[Sentence]]:
                 current_embedding, sentence.embedding
             )

-            if similarity >= similarity_threshold:
+            if similarity >= self.similarity_threshold:
                 # Add to current group
                 current_group.append(sentence)
                 # Update mean embedding
diff --git a/tests/chunker/test_sdpm_chunker.py b/tests/chunker/test_sdpm_chunker.py
index f796f50..4675b17 100644
--- a/tests/chunker/test_sdpm_chunker.py
+++ b/tests/chunker/test_sdpm_chunker.py
@@ -14,9 +14,10 @@ def sample_text():
 def embedding_model():
     return SentenceTransformerEmbeddings("all-MiniLM-L6-v2")

+
 @pytest.fixture
 def sample_complex_markdown_text():
-    text = """# Heading 1 
+    text = """# Heading 1
    This is a paragraph with some **bold text** and _italic text_.
    ## Heading 2
    - Bullet point 1
@@ -70,6 +71,25 @@ def test_spdm_chunker_chunking(embedding_model, sample_text):
     assert all([chunk.sentences is not None for chunk in chunks])


+def test_spdm_chunker_percentile_mode(embedding_model, sample_complex_markdown_text):
+    """Test the SPDMChunker works with percentile-based similarity."""
+    chunker = SDPMChunker(
+        embedding_model=embedding_model,
+        chunk_size=512,
+        similarity_percentile=50,
+    )
+    chunks = chunker.chunk(sample_complex_markdown_text)
+
+    assert len(chunks) > 0
+    assert isinstance(chunks[0], SemanticChunk)
+    assert all([chunk.token_count <= 512 for chunk in chunks])
+    assert all([chunk.token_count > 0 for chunk in chunks])
+    assert all([chunk.text is not None for chunk in chunks])
+    assert all([chunk.start_index is not None for chunk in chunks])
+    assert all([chunk.end_index is not None for chunk in chunks])
+    assert all([chunk.sentences is not None for chunk in chunks])
+
+
 def test_spdm_chunker_empty_text(embedding_model):
     """Test that the SPDMChunker can handle empty text input."""
     chunker = SDPMChunker(
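For reference, the substantive change in this diff is twofold: the default percentile is now expressed on the 0–100 scale that `np.percentile` expects (80 instead of 0.8), and the percentile-derived threshold is computed once in `_prepare_sentences` and cached on `self.similarity_threshold` rather than being recomputed on every `_group_sentences` call. Below is a minimal standalone sketch of that thresholding step under stated assumptions: the `cosine_similarity` helper, `percentile_threshold` function, and toy embeddings are illustrative only, not part of chonkie's API.

```python
import numpy as np


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity between two embedding vectors."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


def percentile_threshold(embeddings: list, percentile: float) -> float:
    """Derive a similarity threshold from the given percentile of
    adjacent-pair similarities, mirroring the logic the diff moves
    into _prepare_sentences."""
    similarities = [
        cosine_similarity(embeddings[i], embeddings[i + 1])
        for i in range(len(embeddings) - 1)
    ]
    # percentile is on a 0-100 scale, matching np.percentile (hence 0.8 -> 80).
    return float(np.percentile(similarities, percentile))


# Toy vectors standing in for sentence embeddings.
rng = np.random.default_rng(0)
embeddings = [rng.normal(size=8) for _ in range(5)]

# With percentile=80, only the most similar adjacent pairs would meet
# or exceed the resulting threshold during grouping.
print(percentile_threshold(embeddings, 80))
```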