diff --git a/.github/workflows/python-test-push.yml b/.github/workflows/python-test-push.yml
index 01d356f..2717793 100644
--- a/.github/workflows/python-test-push.yml
+++ b/.github/workflows/python-test-push.yml
@@ -8,7 +8,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.9", "3.10", "3.11", "3.12"]
+        python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"]
 
     steps:
     - uses: actions/checkout@v4
diff --git a/src/chonkie/chunker/base.py b/src/chonkie/chunker/base.py
index 5028b97..693eb41 100644
--- a/src/chonkie/chunker/base.py
+++ b/src/chonkie/chunker/base.py
@@ -30,12 +30,6 @@ def __init__(
         if isinstance(tokenizer_or_token_counter, str):
             self.tokenizer = self._load_tokenizer(tokenizer_or_token_counter)
             self.token_counter = self._get_tokenizer_counter()
-        # Then check if the tokenizer_or_token_counter is a function via inspect
-        elif inspect.isfunction(tokenizer_or_token_counter):
-            self.tokenizer = None
-            self._tokenizer_backend = "callable"
-            self.token_counter = tokenizer_or_token_counter
-        # If not function or string, then assume it's a tokenizer object
         else:
             self.tokenizer = tokenizer_or_token_counter
             self._tokenizer_backend = self._get_tokenizer_backend()
@@ -49,6 +43,12 @@ def _get_tokenizer_backend(self):
             return "tokenizers"
         elif "tiktoken" in str(type(self.tokenizer)):
             return "tiktoken"
+        elif (
+            callable(self.tokenizer)
+            or inspect.isfunction(self.tokenizer)
+            or inspect.ismethod(self.tokenizer)
+        ):
+            return "callable"
         else:
             raise ValueError(
                 f"Tokenizer backend {str(type(self.tokenizer))} not supported"
@@ -107,6 +107,8 @@ def _get_tokenizer_counter(self) -> Callable[[str], int]:
             return self._tokenizers_token_counter
         elif self._tokenizer_backend == "tiktoken":
             return self._tiktoken_token_counter
+        elif self._tokenizer_backend == "callable":
+            return self.tokenizer
         else:
             raise ValueError("Tokenizer backend not supported for token counting")
 
@@ -130,6 +132,10 @@ def _encode(self, text: str) -> List[int]:
             return self.tokenizer.encode(text, add_special_tokens=False).ids
         elif self._tokenizer_backend == "tiktoken":
             return self.tokenizer.encode(text)
+        elif self._tokenizer_backend == "callable":
+            raise NotImplementedError(
+                "Callable tokenizer backend does not support encoding."
+            )
         else:
             raise ValueError(
                 f"Tokenizer backend {self._tokenizer_backend} not supported."
@@ -148,6 +154,10 @@ def _encode_batch(self, texts: List[str]) -> List[List[int]]:
             ]
         elif self._tokenizer_backend == "tiktoken":
             return self.tokenizer.encode_batch(texts)
+        elif self._tokenizer_backend == "callable":
+            raise NotImplementedError(
+                "Callable tokenizer backend does not support batch encoding."
+            )
         else:
             raise ValueError(
                 f"Tokenizer backend {self._tokenizer_backend} not supported."
@@ -161,6 +171,10 @@ def _decode(self, tokens) -> str:
             return self.tokenizer.decode(tokens)
         elif self._tokenizer_backend == "tiktoken":
             return self.tokenizer.decode(tokens)
+        elif self._tokenizer_backend == "callable":
+            raise NotImplementedError(
+                "Callable tokenizer backend does not support decoding."
+            )
         else:
             raise ValueError(
                 f"Tokenizer backend {self._tokenizer_backend} not supported."
@@ -174,6 +188,10 @@ def _decode_batch(self, token_lists: List[List[int]]) -> List[str]:
             return [self.tokenizer.decode(tokens) for tokens in token_lists]
         elif self._tokenizer_backend == "tiktoken":
             return [self.tokenizer.decode(tokens) for tokens in token_lists]
+        elif self._tokenizer_backend == "callable":
+            raise NotImplementedError(
+                "Callable tokenizer backend does not support batch decoding."
+            )
         else:
             raise ValueError(
                 f"Tokenizer backend {self._tokenizer_backend} not supported."
diff --git a/src/chonkie/chunker/recursive.py b/src/chonkie/chunker/recursive.py
index 320b8d8..c390596 100644
--- a/src/chonkie/chunker/recursive.py
+++ b/src/chonkie/chunker/recursive.py
@@ -1,13 +1,11 @@
 """Recursive chunker."""
 from bisect import bisect_left
-from dataclasses import dataclass
 from functools import lru_cache
 from itertools import accumulate
 from typing import Any, List, Optional, Union
 
 from chonkie.chunker.base import BaseChunker
-from chonkie.types import Chunk, RecursiveChunk, RecursiveRules, RecursiveLevel
-
+from chonkie.types import Chunk, RecursiveChunk, RecursiveLevel, RecursiveRules
 
 
 class RecursiveChunker(BaseChunker):
diff --git a/src/chonkie/chunker/token.py b/src/chonkie/chunker/token.py
index 1912163..defb05c 100644
--- a/src/chonkie/chunker/token.py
+++ b/src/chonkie/chunker/token.py
@@ -1,6 +1,5 @@
 """Token-based chunking."""
-from itertools import accumulate
 from typing import Any, Generator, List, Tuple, Union
 
 from chonkie.types import Chunk
 
diff --git a/src/chonkie/embeddings/base.py b/src/chonkie/embeddings/base.py
index 1e3763f..a6b69e6 100644
--- a/src/chonkie/embeddings/base.py
+++ b/src/chonkie/embeddings/base.py
@@ -43,7 +43,6 @@ def embed(self, text: str) -> "np.ndarray":
         """
         raise NotImplementedError
 
-    @abstractmethod
     def embed_batch(self, texts: List[str]) -> List["np.ndarray"]:
         """Embed a list of text strings into vector representations.
 
@@ -76,7 +75,6 @@ def count_tokens(self, text: str) -> int:
         """
         raise NotImplementedError
 
-    @abstractmethod
     def count_tokens_batch(self, texts: List[str]) -> List[int]:
         """Count the number of tokens in a list of text strings.
 
@@ -89,7 +87,6 @@ def count_tokens_batch(self, texts: List[str]) -> List[int]:
         """
         return [self.count_tokens(text) for text in texts]
 
-    @abstractmethod
     def similarity(self, u: "np.ndarray", v: "np.ndarray") -> float:
         """Compute the similarity between two embeddings.
 
@@ -106,9 +103,7 @@ def similarity(self, u: "np.ndarray", v: "np.ndarray") -> float:
             float: Similarity score between the two embeddings
 
         """
-        return np.dot(u, v) / (
-            np.linalg.norm(u) * np.linalg.norm(v)
-        )  # cosine similarity
+        return float(np.dot(u, v.T) / (np.linalg.norm(u) * np.linalg.norm(v)))  # cosine similarity
 
     @property
     @abstractmethod
diff --git a/tests/chunker/test_recursive_chunker.py b/tests/chunker/test_recursive_chunker.py
index ffda7f7..b06a089 100644
--- a/tests/chunker/test_recursive_chunker.py
+++ b/tests/chunker/test_recursive_chunker.py
@@ -12,7 +12,8 @@
 """
 
 import pytest
-from chonkie.chunker.recursive import RecursiveChunker, RecursiveRules, RecursiveLevel
+
+from chonkie.chunker.recursive import RecursiveChunker, RecursiveLevel, RecursiveRules
 from chonkie.types import Chunk
 
diff --git a/tests/chunker/test_word_chunker.py b/tests/chunker/test_word_chunker.py
index 420f7fa..15bc836 100644
--- a/tests/chunker/test_word_chunker.py
+++ b/tests/chunker/test_word_chunker.py
@@ -12,9 +12,9 @@
 """
 
 from typing import List
-from datasets import load_dataset
 
 import pytest
+from datasets import load_dataset
 from tokenizers import Tokenizer
 
 from chonkie import WordChunker
@@ -126,7 +126,6 @@ def test_word_chunker_single_chunk_text(tokenizer):
 
 def test_word_chunker_batch_chunking(tokenizer, sample_batch):
     """Test that the WordChunker can chunk a batch of texts."""
-
     # this is to avoid the following
     # DeprecationWarning: This process (pid=) is multi-threaded,
     # use of fork() may lead to deadlocks in the child.
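The base.py changes above fold callable support into backend detection itself: instead of a special inspect.isfunction branch in __init__, any str -> int callable is stored as the tokenizer, reported as the "callable" backend, and returned directly by _get_tokenizer_counter, while _encode/_decode and their batch variants raise NotImplementedError for it. A minimal sketch of what this enables, assuming a count-based chunker such as RecursiveChunker accepts BaseChunker's tokenizer_or_token_counter argument positionally and a chunk_size keyword (word_counter is an illustrative counter, not part of this diff):

    from chonkie.chunker.recursive import RecursiveChunker

    # Hypothetical counter: any callable mapping str -> int now qualifies,
    # including bound methods and lambdas, not just plain functions.
    def word_counter(text: str) -> int:
        return len(text.split())

    chunker = RecursiveChunker(word_counter, chunk_size=64)
    chunks = chunker.chunk("Some long document text to split by token count...")

    # Encode/decode paths are unavailable for this backend and raise
    # NotImplementedError, so encode-based chunkers still need a real tokenizer.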
diff --git a/tests/embeddings/test_custom_embeddings.py b/tests/embeddings/test_custom_embeddings.py
new file mode 100644
index 0000000..8a75a27
--- /dev/null
+++ b/tests/embeddings/test_custom_embeddings.py
@@ -0,0 +1,86 @@
+"""Contains test cases for the CustomEmbeddings class.
+
+The tests verify:
+
+- Initialization with a specified dimension
+- Embedding a single text string
+- Embedding a batch of text strings
+- Token counting
+- Similarity calculation
+"""
+import numpy as np
+import pytest
+
+from chonkie.embeddings.base import BaseEmbeddings
+
+
+class CustomEmbeddings(BaseEmbeddings):
+    """Custom embeddings class."""
+
+    def __init__(self, dimension=4):
+        """Initialize the CustomEmbeddings class."""
+        super().__init__()
+        self._dimension = dimension
+
+    def embed(self, text: str) -> "np.ndarray":
+        """Embed a single text string into a vector representation."""
+        # For demonstration, returns a random vector
+        return np.random.rand(self._dimension)
+
+    def count_tokens(self, text: str) -> int:
+        """Count the number of tokens in a text string."""
+        # Very naive token counting—split by whitespace
+        return len(text.split())
+
+    def similarity(self, u: "np.ndarray", v: "np.ndarray") -> float:
+        """Calculate the cosine similarity between two vectors."""
+        return float(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v)))
+
+    @property
+    def dimension(self) -> int:
+        """Return the dimension of the embeddings."""
+        return self._dimension
+
+def test_custom_embeddings_initialization():
+    """Test the initialization of the CustomEmbeddings class."""
+    embeddings = CustomEmbeddings(dimension=4)
+    assert isinstance(embeddings, BaseEmbeddings)
+    assert embeddings.dimension == 4
+
+def test_custom_embeddings_single_text():
+    """Test the embedding of a single text string."""
+    embeddings = CustomEmbeddings(dimension=4)
+    text = "Test string"
+    vector = embeddings.embed(text)
+    assert isinstance(vector, np.ndarray)
+    assert vector.shape == (4, )
+
+def test_custom_embeddings_batch_text():
+    """Test the embedding of a batch of text strings."""
+    embeddings = CustomEmbeddings(dimension=4)
+    texts = ["Test string one", "Test string two"]
+    vectors = embeddings.embed_batch(texts)
+    assert len(vectors) == 2
+    for vec in vectors:
+        assert isinstance(vec, np.ndarray)
+        assert vec.shape == (4,)
+
+def test_custom_embeddings_token_count():
+    """Test the token counting functionality."""
+    embeddings = CustomEmbeddings()
+    text = "Test string for counting tokens"
+    count = embeddings.count_tokens(text)
+    assert isinstance(count, int)
+    assert count == len(text.split())
+
+def test_custom_embeddings_similarity():
+    """Test the similarity calculation."""
+    embeddings = CustomEmbeddings(dimension=4)
+    vec1 = embeddings.embed("Text A")
+    vec2 = embeddings.embed("Text B")
+    sim = embeddings.similarity(vec1, vec2)
+    # Cosine similarity is in [-1, 1]—random vectors often produce a small positive or negative value
+    assert -1.0 <= sim <= 1.0
+
+if __name__ == "__main__":
+    pytest.main()
\ No newline at end of file
diff --git a/tests/refinery/test_overlap_refinery.py b/tests/refinery/test_overlap_refinery.py
index f93850a..3026947 100644
--- a/tests/refinery/test_overlap_refinery.py
+++ b/tests/refinery/test_overlap_refinery.py
@@ -3,9 +3,9 @@
 import pytest
 from transformers import AutoTokenizer
 
+from chonkie import TokenChunker
 from chonkie.refinery import OverlapRefinery
 from chonkie.types import Chunk, Context, Sentence, SentenceChunk
-from chonkie import TokenChunker
 
 
 @pytest.fixture
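With the @abstractmethod decorators removed in embeddings/base.py, embed_batch, count_tokens_batch, and similarity become inherited defaults rather than required overrides, which is exactly what the new test file exercises: CustomEmbeddings defines only embed, count_tokens, similarity, and dimension, yet the batch tests pass. A short sketch of the resulting subclassing contract, reusing the illustrative CustomEmbeddings from the test above (the embed_batch default is assumed to loop over embed, as the passing batch test implies; count_tokens_batch's loop is shown verbatim in this diff):

    emb = CustomEmbeddings(dimension=4)

    # Inherited defaults: the batch methods fall back to the per-item ones,
    # so subclasses no longer have to override them.
    vectors = emb.embed_batch(["first text", "second text"])
    counts = emb.count_tokens_batch(["first text", "second text"])  # -> [2, 2]

    # similarity is overridden in the test, but a subclass could now inherit the
    # base implementation, which this diff changes to return a plain Python float.
    score = emb.similarity(vectors[0], vectors[1])
    assert isinstance(score, float) and -1.0 <= score <= 1.0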