Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed similarity_percentile with sdpm chunker + added test #65

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 14 additions & 17 deletions src/chonkie/chunker/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(
"Cannot specify both similarity_threshold and similarity_percentile"
)
if similarity_threshold is None and similarity_percentile is None:
similarity_percentile = 0.8
similarity_percentile = 80
raise Warning(
"No similarity threshold specified. Defaulting to 80th percentile."
) #TODO: Change this to be a non-blocking warning
Expand Down Expand Up @@ -211,6 +211,18 @@ def _prepare_sentences(self, text: str) -> List[Sentence]:
)
]

# Get or compute similarity threshold
if self.similarity_threshold is None:
# Compute all pairwise similarities
all_similarities = [
self._get_semantic_similarity(
sentences[i].embedding, sentences[i + 1].embedding
)
for i in range(len(sentences) - 1)
]
self.similarity_threshold = float(
np.percentile(all_similarities, self.similarity_percentile)
)
return sentences

def _get_semantic_similarity(
Expand Down Expand Up @@ -240,21 +252,6 @@ def _group_sentences(self, sentences: List[Sentence]) -> List[List[Sentence]]:
if len(sentences) <= self.initial_sentences:
return [sentences]

# Get or compute similarity threshold
if self.similarity_percentile is not None:
# Compute all pairwise similarities
all_similarities = [
self._get_semantic_similarity(
sentences[i].embedding, sentences[i + 1].embedding
)
for i in range(len(sentences) - 1)
]
similarity_threshold = float(
np.percentile(all_similarities, self.similarity_percentile)
)
else:
similarity_threshold = self.similarity_threshold

groups = []
current_group = sentences[: self.initial_sentences]
current_embedding = self._compute_group_embedding(current_group)
Expand All @@ -265,7 +262,7 @@ def _group_sentences(self, sentences: List[Sentence]) -> List[List[Sentence]]:
current_embedding, sentence.embedding
)

if similarity >= similarity_threshold:
if similarity >= self.similarity_threshold:
# Add to current group
current_group.append(sentence)
# Update mean embedding
Expand Down
22 changes: 21 additions & 1 deletion tests/chunker/test_sdpm_chunker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@ def sample_text():
def embedding_model():
return SentenceTransformerEmbeddings("all-MiniLM-L6-v2")


@pytest.fixture
def sample_complex_markdown_text():
text = """# Heading 1
text = """# Heading 1
This is a paragraph with some **bold text** and _italic text_.
## Heading 2
- Bullet point 1
Expand Down Expand Up @@ -70,6 +71,25 @@ def test_spdm_chunker_chunking(embedding_model, sample_text):
assert all([chunk.sentences is not None for chunk in chunks])


def test_spdm_chunker_percentile_mode(embedding_model, sample_complex_markdown_text):
    """Check that SDPMChunker yields well-formed chunks in percentile mode.

    The chunker is built with ``similarity_percentile=50`` (and no explicit
    similarity threshold) and run over the complex markdown fixture.
    """
    max_tokens = 512
    chunker = SDPMChunker(
        embedding_model=embedding_model,
        chunk_size=max_tokens,
        similarity_percentile=50,
    )
    chunks = chunker.chunk(sample_complex_markdown_text)

    # Something must come back, and it must be the semantic chunk type.
    assert len(chunks) > 0
    assert isinstance(chunks[0], SemanticChunk)

    # Each chunk is non-empty and stays within the configured token budget.
    assert all(chunk.token_count <= max_tokens for chunk in chunks)
    assert all(chunk.token_count > 0 for chunk in chunks)

    # Each chunk has its text, character offsets and sentences populated.
    assert all(chunk.text is not None for chunk in chunks)
    assert all(chunk.start_index is not None for chunk in chunks)
    assert all(chunk.end_index is not None for chunk in chunks)
    assert all(chunk.sentences is not None for chunk in chunks)


def test_spdm_chunker_empty_text(embedding_model):
"""Test that the SPDMChunker can handle empty text input."""
chunker = SDPMChunker(
Expand Down