[chore] Use ruff for formatting
bhavnicksm committed Dec 6, 2024
1 parent 51bee2d commit 2e4bf86
Showing 13 changed files with 97 additions and 57 deletions.
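
Note: the commit itself does not include the ruff configuration or the exact command that produced these changes, but the edits below (isort-style import reordering plus line-length reflow) are consistent with a standard ruff workflow. A minimal sketch, assuming default settings and that only src/ and tests/ were targeted:

    ruff check --select I --fix src/ tests/   # re-sort and group imports (isort rules)
    ruff format src/ tests/                   # reformat code with ruff's formatter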
17 changes: 8 additions & 9 deletions src/chonkie/__init__.py
@@ -1,14 +1,5 @@
"""Main package for Chonkie."""

from .types import (
Context,
SemanticSentence,
Sentence,
Chunk,
SentenceChunk,
SemanticChunk,
)

from .chunker import (
BaseChunker,
SDPMChunker,
@@ -28,6 +19,14 @@
BaseRefinery,
OverlapRefinery,
)
from .types import (
Chunk,
Context,
SemanticChunk,
SemanticSentence,
Sentence,
SentenceChunk,
)

__version__ = "0.2.1.post1"
__name__ = "chonkie"
6 changes: 3 additions & 3 deletions src/chonkie/chunker/base.py
@@ -4,11 +4,11 @@
import inspect
import warnings
from abc import ABC, abstractmethod
from dataclasses import dataclass
from multiprocessing import Pool, cpu_count
from typing import Any, Callable, List, Optional, Union
from typing import Any, Callable, List, Union

from chonkie.types import Chunk

from chonkie.types import Chunk, Context

class BaseChunker(ABC):
"""Abstract base class for all chunker implementations.
8 changes: 5 additions & 3 deletions src/chonkie/chunker/sdpm.py
@@ -1,9 +1,11 @@
"""Semantic Double Pass Merge chunking using sentence embeddings."""

from typing import Any, List, Union

from .semantic import SemanticChunker
from chonkie.types import SemanticChunk, Sentence

from .semantic import SemanticChunker


class SDPMChunker(SemanticChunker):
"""Chunker implementation using the Semantic Document Partitioning Method (SDPM).
@@ -24,7 +26,7 @@ class SDPMChunker(SemanticChunker):
Methods:
chunk: Split text into chunks using the SDPM approach.
"""

def __init__(
@@ -134,7 +136,7 @@ def chunk(self, text: str) -> List[SemanticChunk]:
sentences = self._prepare_sentences(text)
if len(sentences) <= self.min_sentences:
return [self._create_chunk(sentences)]

# Calculate similarity threshold
self.similarity_threshold = self._calculate_similarity_threshold(sentences)

78 changes: 52 additions & 26 deletions src/chonkie/chunker/semantic.py
@@ -1,12 +1,13 @@
"""Semantic chunking using sentence embeddings."""

import warnings
from typing import List, Union

import numpy as np

from chonkie.chunker.base import BaseChunker
from chonkie.embeddings.base import BaseEmbeddings
from chonkie.types import Sentence, SemanticSentence, SemanticChunk
from chonkie.types import SemanticChunk, SemanticSentence, Sentence


class SemanticChunker(BaseChunker):
@@ -77,10 +78,10 @@ def __init__(
raise ValueError("threshold (float) must be between 0 and 1")
elif type(threshold) == int and (threshold < 1 or threshold > 100):
raise ValueError("threshold (int) must be between 1 and 100")

self.mode = mode
self.chunk_size = chunk_size
self.threshold = threshold
self.threshold = threshold
self.similarity_window = similarity_window
self.min_sentences = min_sentences
self.min_chunk_size = min_chunk_size
@@ -197,7 +198,7 @@ def _prepare_sentences(self, text: str) -> List[Sentence]:

# Batch compute embeddings for all sentences
# The embeddings are computed assuming a similarity window is applied
# There should be len(raw_sentences) number of similarity groups
# There should be len(raw_sentences) number of similarity groups
sentence_groups = []
for i in range(len(raw_sentences)):
group = []
@@ -206,8 +207,9 @@ def _prepare_sentences(self, text: str) -> List[Sentence]:
if j >= 0 and j < len(raw_sentences):
group.append(raw_sentences[j])
sentence_groups.append("".join(group))
assert len(sentence_groups) == len(raw_sentences),\
(f"Number of sentence groups ({len(sentence_groups)}) does not match number of raw sentences ({len(raw_sentences)})")
assert (
len(sentence_groups) == len(raw_sentences)
), f"Number of sentence groups ({len(sentence_groups)}) does not match number of raw sentences ({len(raw_sentences)})"
embeddings = self.embedding_model.embed_batch(sentence_groups)

# Batch compute token counts
@@ -243,41 +245,55 @@ def _compute_group_embedding(self, sentences: List[Sentence]) -> np.ndarray:
np.sum([sent.token_count for sent in sentences]),
dtype=np.float32,
)

def _compute_pairwise_similarities(self, sentences: List[Sentence]) -> List[float]:
"""Compute all pairwise similarities between sentences."""
return [
self._get_semantic_similarity(sentences[i].embedding, sentences[i + 1].embedding)
self._get_semantic_similarity(
sentences[i].embedding, sentences[i + 1].embedding
)
for i in range(len(sentences) - 1)
]

def _get_split_indices(self, similarities: List[float], threshold: float = None) -> List[int]:

def _get_split_indices(
self, similarities: List[float], threshold: float = None
) -> List[int]:
"""Get indices of sentences to split at."""
if threshold is None:
threshold = self.similarity_threshold if self.similarity_threshold is not None else 0.5
threshold = (
self.similarity_threshold
if self.similarity_threshold is not None
else 0.5
)

# get the indices of the sentences that are below the threshold
splits = [i+1 for i, s in enumerate(similarities) if s <= threshold and i+1 < len(similarities)]
splits = [
i + 1
for i, s in enumerate(similarities)
if s <= threshold and i + 1 < len(similarities)
]
# add the start and end of the text
splits = [0] + splits + [len(similarities)]
# check if the splits are valid (i.e. there are enough sentences between them)
i = 0
while i < len(splits) - 1:
if splits[i+1] - splits[i] < self.min_sentences:
splits.pop(i+1)
if splits[i + 1] - splits[i] < self.min_sentences:
splits.pop(i + 1)
else:
i += 1
return splits

def _calculate_threshold_via_binary_search(self, sentences: List[Sentence]) -> float:
def _calculate_threshold_via_binary_search(
self, sentences: List[Sentence]
) -> float:
"""Calculate similarity threshold via binary search."""
# Get the token counts and cumulative token counts
token_counts = [sent.token_count for sent in sentences]
cumulative_token_counts = np.cumsum(token_counts)

# Compute all pairwise similarities
similarities = self._compute_pairwise_similarities(sentences)

# get the median and the std for the similarities
median = np.median(similarities)
std = np.std(similarities)
@@ -299,7 +315,7 @@ def _calculate_threshold_via_binary_search(self, sentences: List[Sentence]) -> f
# median_split_token_count = np.median(split_token_counts)
# Check if the split respects the chunk size
# if self.min_chunk_size * 1.1 <= median_split_token_count <= 0.95 * self.chunk_size:
# break
# break
# elif median_split_token_count > 0.95 * self.chunk_size:
# The code is calculating the median of a list of token counts stored in the variable
# `split_token_counts` using the `np.median()` function from the NumPy library in Python.
@@ -316,16 +332,19 @@ def _calculate_threshold_via_binary_search(self, sentences: List[Sentence]) -> f
# check if any of the split token counts are less than the min chunk size
else:
high = threshold - self.threshold_step

iterations += 1
if iterations > 10:
warnings.warn("Too many iterations in threshold calculation, stopping...", stacklevel=2)
warnings.warn(
"Too many iterations in threshold calculation, stopping...",
stacklevel=2,
)
break

return threshold

def _calculate_threshold_via_percentile(self, sentences: List[Sentence]) -> float:
"""Calculate similarity threshold via percentile."""
"""Calculate similarity threshold via percentile."""
# Compute all pairwise similarities, since the embeddings are already computed
# The embeddings are computed assuming a similarity window is applied
all_similarities = self._compute_pairwise_similarities(sentences)
Expand All @@ -340,7 +359,9 @@ def _calculate_similarity_threshold(self, sentences: List[Sentence]) -> float:
else:
return self._calculate_threshold_via_binary_search(sentences)

def _group_sentences_cumulative(self, sentences: List[Sentence]) -> List[List[Sentence]]:
def _group_sentences_cumulative(
self, sentences: List[Sentence]
) -> List[List[Sentence]]:
"""Group sentences based on semantic similarity, ignoring token count limits.
Args:
Expand Down Expand Up @@ -377,12 +398,17 @@ def _group_sentences_cumulative(self, sentences: List[Sentence]) -> List[List[Se
groups.append(current_group)

return groups

def _group_sentences_window(self, sentences: List[Sentence]) -> List[List[Sentence]]:

def _group_sentences_window(
self, sentences: List[Sentence]
) -> List[List[Sentence]]:
"""Group sentences based on semantic similarity, respecting the similarity window."""
similarities = self._compute_pairwise_similarities(sentences)
split_indices = self._get_split_indices(similarities, self.similarity_threshold)
groups = [sentences[split_indices[i]:split_indices[i+1]] for i in range(len(split_indices) - 1)]
groups = [
sentences[split_indices[i] : split_indices[i + 1]]
for i in range(len(split_indices) - 1)
]
return groups

def _group_sentences(self, sentences: List[Sentence]) -> List[List[Sentence]]:
@@ -477,7 +503,7 @@ def chunk(self, text: str) -> List[SemanticChunk]:
sentences = self._prepare_sentences(text)
if len(sentences) <= self.min_sentences:
return [self._create_chunk(sentences)]

# Calculate similarity threshold
self.similarity_threshold = self._calculate_similarity_threshold(sentences)

5 changes: 3 additions & 2 deletions src/chonkie/chunker/sentence.py
@@ -1,11 +1,12 @@
from bisect import bisect_left
from dataclasses import dataclass, field
from itertools import accumulate
from typing import Any, List, Union

from .base import BaseChunker
from chonkie.types import Chunk, Sentence, SentenceChunk

from .base import BaseChunker


class SentenceChunker(BaseChunker):
"""SentenceChunker splits the sentences in a text based on token limits and sentence boundaries.
6 changes: 4 additions & 2 deletions src/chonkie/chunker/token.py
@@ -1,13 +1,15 @@
"""Token-based chunking."""

from typing import Any, Generator, List, Tuple, Union

from .base import BaseChunker
from chonkie.types import Chunk

from .base import BaseChunker


class TokenChunker(BaseChunker):
"""Chunker that splits text into chunks of a specified token size.
Args:
tokenizer: The tokenizer instance to use for encoding/decoding
chunk_size: Maximum number of tokens per chunk
3 changes: 2 additions & 1 deletion src/chonkie/chunker/word.py
@@ -1,9 +1,10 @@
import re
from typing import Any, List, Tuple, Union

from .base import BaseChunker
from chonkie.types import Chunk

from .base import BaseChunker


class WordChunker(BaseChunker):
def __init__(
3 changes: 2 additions & 1 deletion src/chonkie/refinery/overlap.py
@@ -2,8 +2,9 @@

from typing import Any, List, Optional

from chonkie.types import Chunk, Context, SentenceChunk, SemanticChunk
from chonkie.refinery.base import BaseRefinery
from chonkie.types import Chunk, Context, SemanticChunk, SentenceChunk


class OverlapRefinery(BaseRefinery):
"""Refinery class which adds overlap as context to chunks.
8 changes: 4 additions & 4 deletions src/chonkie/types.py
@@ -1,10 +1,12 @@
"""Dataclasses for Chonkie."""

from dataclasses import dataclass, field
from typing import List, Optional, TYPE_CHECKING
from typing import TYPE_CHECKING, List, Optional

if TYPE_CHECKING:
import numpy as np


@dataclass
class Context:
"""A dataclass representing contextual information for chunk refinement.
@@ -124,7 +126,7 @@ def copy(self) -> "Chunk":
end_index=self.end_index,
token_count=self.token_count,
)


@dataclass
class Sentence:
@@ -165,7 +167,6 @@ class SentenceChunk(Chunk):
sentences: List[Sentence] = field(default_factory=list)



@dataclass
class SemanticSentence(Sentence):
"""Dataclass representing a semantic sentence with metadata.
@@ -200,4 +201,3 @@ class SemanticChunk(SentenceChunk):
"""

sentences: List[SemanticSentence] = field(default_factory=list)

1 change: 1 addition & 0 deletions tests/chunker/test_sdpm_chunker.py
@@ -17,6 +17,7 @@
from chonkie.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from chonkie.types import SemanticChunk


@pytest.fixture
def sample_text():
"""Sample text for testing the SDPMChunker."""
1 change: 1 addition & 0 deletions tests/chunker/test_semantic_chunker.py
@@ -7,6 +7,7 @@
from chonkie.embeddings import Model2VecEmbeddings, OpenAIEmbeddings
from chonkie.types import Chunk, SemanticChunk


@pytest.fixture
def sample_text():
"""Sample text for testing the SemanticChunker.
2 changes: 1 addition & 1 deletion tests/embeddings/test_auto_embeddings.py
@@ -1,4 +1,5 @@
"""Tests for the AutoEmbeddings class."""

import pytest

from chonkie import AutoEmbeddings
@@ -25,7 +26,6 @@ def sentence_transformer_identifier_small():
return "all-minilm-l6-v2"



@pytest.fixture
def openai_identifier():
"""Fixture providing an OpenAI identifier."""
