Skip to content

Commit

Permalink
Merge pull request #77 from bhavnicksm/refinery
Browse files Browse the repository at this point in the history
[FEAT] Add BaseRefinery and OverlapRefinery support
  • Loading branch information
bhavnicksm authored Dec 4, 2024
2 parents aa82668 + aa1fe0a commit 71a9d5d
Show file tree
Hide file tree
Showing 12 changed files with 769 additions and 69 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ pythonpath = "src"
package-dir = {"" = "src"}
packages = ["chonkie",
"chonkie.chunker",
"chonkie.embeddings"]
"chonkie.embeddings",
"chonkie.refinery"]

[tool.ruff]
select = ["F", "I", "D", "DOC"]
16 changes: 16 additions & 0 deletions src/chonkie/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""Main package for Chonkie."""

from .context import Context

from .chunker import (
BaseChunker,
Chunk,
Expand All @@ -19,6 +23,11 @@
SentenceTransformerEmbeddings,
)

from .refinery import (
BaseRefinery,
OverlapRefinery,
)

__version__ = "0.2.1.post1"
__name__ = "chonkie"
__author__ = "Bhavnick Minhas"
Expand All @@ -32,6 +41,7 @@

# Add all data classes to __all__
__all__ += [
"Context",
"Chunk",
"SentenceChunk",
"SemanticChunk",
Expand All @@ -57,3 +67,9 @@
"OpenAIEmbeddings",
"AutoEmbeddings",
]

# Add all refinery classes to __all__
__all__ += [
"BaseRefinery",
"OverlapRefinery",
]
76 changes: 68 additions & 8 deletions src/chonkie/chunker/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
"""Base classes for chunking text."""

import importlib

import inspect
import warnings
from abc import ABC, abstractmethod
from dataclasses import dataclass
from multiprocessing import Pool, cpu_count
from typing import Any, Callable, List, Union
from typing import Any, Callable, List, Optional, Union

import inspect
from chonkie.context import Context


@dataclass
Expand All @@ -19,14 +23,55 @@ class Chunk:
start_index: The starting index of the chunk in the original text
end_index: The ending index of the chunk in the original text
token_count: The number of tokens in the chunk
context: The context of the chunk, useful for refinery classes
"""

text: str
start_index: int
end_index: int
token_count: int
__slots__ = ["text", "start_index", "end_index", "token_count"]
context: Optional[Context] = None

def __str__(self) -> str:
"""Return string representation of the chunk."""
return self.text

def __len__(self) -> int:
"""Return the length of the chunk."""
return len(self.text)

def __repr__(self) -> str:
"""Return string representation of the chunk."""
if self.context is not None:
return (
f"Chunk(text={self.text}, start_index={self.start_index}, "
f"end_index={self.end_index}, token_count={self.token_count})"
)
else:
return (
f"Chunk(text={self.text}, start_index={self.start_index}, "
f"end_index={self.end_index}, token_count={self.token_count}, "
f"context={self.context})"
)

def __iter__(self):
"""Return an iterator over the chunk."""
return iter(self.text)

def __getitem__(self, index: int):
"""Return the item at the given index."""
return self.text[index]

def copy(self) -> "Chunk":
"""Return a deep copy of the chunk."""
return Chunk(
text=self.text,
start_index=self.start_index,
end_index=self.end_index,
token_count=self.token_count,
)



class BaseChunker(ABC):
Expand Down Expand Up @@ -67,7 +112,10 @@ def _get_tokenizer_backend(self):
elif "tiktoken" in str(type(self.tokenizer)):
return "tiktoken"
else:
raise ValueError(f"Tokenizer backend {str(type(self.tokenizer))} not supported")
raise ValueError(
f"Tokenizer backend {str(type(self.tokenizer))} not supported"
)


def _load_tokenizer(self, tokenizer_name: str):
"""Load a tokenizer based on the backend."""
Expand Down Expand Up @@ -134,7 +182,10 @@ def _encode(self, text: str):
elif self._tokenizer_backend == "tiktoken":
return self.tokenizer.encode(text)
else:
raise ValueError(f"Tokenizer backend {self._tokenizer_backend} not supported.")
raise ValueError(
f"Tokenizer backend {self._tokenizer_backend} not supported."
)


def _encode_batch(self, texts: List[str]):
"""Encode a batch of texts using the backend tokenizer."""
Expand All @@ -150,7 +201,10 @@ def _encode_batch(self, texts: List[str]):
elif self._tokenizer_backend == "tiktoken":
return self.tokenizer.encode_batch(texts)
else:
raise ValueError(f"Tokenizer backend {self._tokenizer_backend} not supported.")
raise ValueError(
f"Tokenizer backend {self._tokenizer_backend} not supported."
)


def _decode(self, tokens) -> str:
"""Decode tokens using the backend tokenizer."""
Expand All @@ -161,7 +215,10 @@ def _decode(self, tokens) -> str:
elif self._tokenizer_backend == "tiktoken":
return self.tokenizer.decode(tokens)
else:
raise ValueError(f"Tokenizer backend {self._tokenizer_backend} not supported.")
raise ValueError(
f"Tokenizer backend {self._tokenizer_backend} not supported."
)


def _decode_batch(self, token_lists: List[List[int]]) -> List[str]:
"""Decode a batch of token lists using the backend tokenizer."""
Expand All @@ -172,7 +229,10 @@ def _decode_batch(self, token_lists: List[List[int]]) -> List[str]:
elif self._tokenizer_backend == "tiktoken":
return [self.tokenizer.decode(tokens) for tokens in token_lists]
else:
raise ValueError(f"Tokenizer backend {self._tokenizer_backend} not supported.")
raise ValueError(
f"Tokenizer backend {self._tokenizer_backend} not supported."
)


def _count_tokens(self, text: str) -> int:
"""Count tokens in text using the backend tokenizer."""
Expand Down
50 changes: 11 additions & 39 deletions src/chonkie/chunker/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,7 @@ class SemanticSentence(Sentence):
"""

embedding: Optional[np.ndarray]

# Only define new slots, not the ones inherited from Sentence
__slots__ = [
"embedding",
]

def __init__(
self,
text: str,
start_index: int,
end_index: int,
token_count: int,
embedding: Optional[np.ndarray] = None,
):
super().__init__(text, start_index, end_index, token_count)
object.__setattr__(
self, "embedding", embedding if embedding is not None else None
)
embedding: Optional[np.ndarray] = field(default=None)


@dataclass
Expand All @@ -59,25 +41,9 @@ class SemanticChunk(SentenceChunk):
sentences: List of SemanticSentence objects in the chunk
"""

sentences: List[SemanticSentence] = field(default_factory=list)

# No new slots needed since we're just overriding the sentences field
__slots__ = []

def __init__(
self,
text: str,
start_index: int,
end_index: int,
token_count: int,
sentences: List[SemanticSentence] = None,
):
super().__init__(text, start_index, end_index, token_count)
object.__setattr__(
self, "sentences", sentences if sentences is not None else []
)


class SemanticChunker(BaseChunker):
"""Chunker that splits text into semantically coherent chunks using embeddings.
Expand Down Expand Up @@ -160,6 +126,12 @@ def __init__(
"embedding_model must be a string or BaseEmbeddings instance"
)

# Probably the dependency is not installed
if self.embedding_model is None:
raise ImportError("embedding_model is not a valid embedding model",
"Please install the `semantic` extra to use this feature")


# Keeping the tokenizer the same as the sentence model is important
# for the group semantic meaning to be calculated properly
super().__init__(self.embedding_model.get_tokenizer_or_token_counter())
Expand Down Expand Up @@ -299,11 +271,11 @@ def _group_sentences(self, sentences: List[Sentence]) -> List[List[Sentence]]:
)
for i in range(len(sentences) - 1)
]
similarity_threshold = float(
self.similarity_threshold = float(
np.percentile(all_similarities, self.similarity_percentile)
)
else:
similarity_threshold = self.similarity_threshold
self.similarity_threshold = self.similarity_threshold

groups = []
current_group = sentences[: self.initial_sentences]
Expand All @@ -315,7 +287,7 @@ def _group_sentences(self, sentences: List[Sentence]) -> List[List[Sentence]]:
current_embedding, sentence.embedding
)

if similarity >= similarity_threshold:
if similarity >= self.similarity_threshold:
# Add to current group
current_group.append(sentence)
# Update mean embedding
Expand Down
24 changes: 3 additions & 21 deletions src/chonkie/chunker/sentence.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from bisect import bisect_left
from dataclasses import dataclass
from dataclasses import dataclass, field

from itertools import accumulate
from typing import Any, List, Union

Expand All @@ -24,8 +25,6 @@ class Sentence:
start_index: int
end_index: int
token_count: int
__slots__ = ["text", "start_index", "end_index", "token_count"]


@dataclass
class SentenceChunk(Chunk):
Expand All @@ -41,25 +40,8 @@ class SentenceChunk(Chunk):
sentences: List of Sentence objects in the chunk
"""

# Don't redeclare inherited fields
sentences: List[Sentence]

__slots__ = ["sentences"]

def __init__(
self,
text: str,
start_index: int,
end_index: int,
token_count: int,
sentences: List[Sentence] = None,
):
super().__init__(text, start_index, end_index, token_count)
object.__setattr__(
self, "sentences", sentences if sentences is not None else []
)

sentences: List[Sentence] = field(default_factory=list)

class SentenceChunker(BaseChunker):
"""SentenceChunker splits the sentences in a text based on token limits and sentence boundaries.
Expand Down
67 changes: 67 additions & 0 deletions src/chonkie/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""Context class for storing contextual information for chunk refinement.
This class is used to store contextual information for chunk refinement.
It can represent context that comes before a chunk at the moment.
By default, the context has no start and end indices, meaning it is not
bound to any specific text. The start and end indices are only set if the
context is part of the same text as the chunk.
"""

from dataclasses import dataclass
from typing import Optional


@dataclass
class Context:
"""A dataclass representing contextual information for chunk refinement.
This class stores text and token count information that can be used to add
context to chunks during the refinement process. It can represent context
that comes before or after a chunk.
Attributes:
text (str): The context text
token_count (int): Number of tokens in the context text
start_index (Optional[int]): Starting position of context in original text
end_index (Optional[int]): Ending position of context in original text
Example:
context = Context(
text="This is some context.",
token_count=5,
start_index=0,
end_index=20
)
"""

text: str
token_count: int
start_index: Optional[int] = None
end_index: Optional[int] = None

def __post_init__(self):
"""Validate the Context attributes after initialization."""
if not isinstance(self.text, str):
raise ValueError("text must be a string")
if not isinstance(self.token_count, int):
raise ValueError("token_count must be an integer")
if self.token_count < 0:
raise ValueError("token_count must be non-negative")
if (self.start_index is not None and self.end_index is not None and
self.start_index > self.end_index):
raise ValueError("start_index must be less than or equal to end_index")

def __len__(self) -> int:
"""Return the length of the context text."""
return len(self.text)

def __str__(self) -> str:
"""Return a string representation of the Context."""
return self.text

def __repr__(self) -> str:
"""Return a detailed string representation of the Context."""
return (f"Context(text='{self.text}', token_count={self.token_count}, "
f"start_index={self.start_index}, end_index={self.end_index})")
6 changes: 6 additions & 0 deletions src/chonkie/refinery/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from .base import BaseRefinery
from .overlap import OverlapRefinery

# Include all the refinery classes in the __all__ list
__all__ = ["BaseRefinery", "OverlapRefinery"]

Loading

0 comments on commit 71a9d5d

Please sign in to comment.