From 83ab507bbac181ce7e308a4256a7ecf88cf0e5d7 Mon Sep 17 00:00:00 2001 From: bhavnicksm Date: Tue, 31 Dec 2024 22:45:37 +0530 Subject: [PATCH 1/6] [fix] #93: Support passing a callable which may be a function or a method for token counting --- src/chonkie/chunker/base.py | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/src/chonkie/chunker/base.py b/src/chonkie/chunker/base.py index 5028b97..693eb41 100644 --- a/src/chonkie/chunker/base.py +++ b/src/chonkie/chunker/base.py @@ -30,12 +30,6 @@ def __init__( if isinstance(tokenizer_or_token_counter, str): self.tokenizer = self._load_tokenizer(tokenizer_or_token_counter) self.token_counter = self._get_tokenizer_counter() - # Then check if the tokenizer_or_token_counter is a function via inspect - elif inspect.isfunction(tokenizer_or_token_counter): - self.tokenizer = None - self._tokenizer_backend = "callable" - self.token_counter = tokenizer_or_token_counter - # If not function or string, then assume it's a tokenizer object else: self.tokenizer = tokenizer_or_token_counter self._tokenizer_backend = self._get_tokenizer_backend() @@ -49,6 +43,12 @@ def _get_tokenizer_backend(self): return "tokenizers" elif "tiktoken" in str(type(self.tokenizer)): return "tiktoken" + elif ( + callable(self.tokenizer) + or inspect.isfunction(self.tokenizer) + or inspect.ismethod(self.tokenizer) + ): + return "callable" else: raise ValueError( f"Tokenizer backend {str(type(self.tokenizer))} not supported" @@ -107,6 +107,8 @@ def _get_tokenizer_counter(self) -> Callable[[str], int]: return self._tokenizers_token_counter elif self._tokenizer_backend == "tiktoken": return self._tiktoken_token_counter + elif self._tokenizer_backend == "callable": + return self.tokenizer else: raise ValueError("Tokenizer backend not supported for token counting") @@ -130,6 +132,10 @@ def _encode(self, text: str) -> List[int]: return self.tokenizer.encode(text, add_special_tokens=False).ids elif self._tokenizer_backend == "tiktoken": return self.tokenizer.encode(text) + elif self._tokenizer_backend == "callable": + raise NotImplementedError( + "Callable tokenizer backend does not support encoding." + ) else: raise ValueError( f"Tokenizer backend {self._tokenizer_backend} not supported." @@ -148,6 +154,10 @@ def _encode_batch(self, texts: List[str]) -> List[List[int]]: ] elif self._tokenizer_backend == "tiktoken": return self.tokenizer.encode_batch(texts) + elif self._tokenizer_backend == "callable": + raise NotImplementedError( + "Callable tokenizer backend does not support batch encoding." + ) else: raise ValueError( f"Tokenizer backend {self._tokenizer_backend} not supported." @@ -161,6 +171,10 @@ def _decode(self, tokens) -> str: return self.tokenizer.decode(tokens) elif self._tokenizer_backend == "tiktoken": return self.tokenizer.decode(tokens) + elif self._tokenizer_backend == "callable": + raise NotImplementedError( + "Callable tokenizer backend does not support decoding." + ) else: raise ValueError( f"Tokenizer backend {self._tokenizer_backend} not supported." @@ -174,6 +188,10 @@ def _decode_batch(self, token_lists: List[List[int]]) -> List[str]: return [self.tokenizer.decode(tokens) for tokens in token_lists] elif self._tokenizer_backend == "tiktoken": return [self.tokenizer.decode(tokens) for tokens in token_lists] + elif self._tokenizer_backend == "callable": + raise NotImplementedError( + "Callable tokenizer backend does not support batch decoding." 
+ ) else: raise ValueError( f"Tokenizer backend {self._tokenizer_backend} not supported." From 513e1e315ff7beb349afc6e3dedc0e37545081a6 Mon Sep 17 00:00:00 2001 From: Shreyash Nigam <33201914+shreyashnigam@users.noreply.github.com> Date: Thu, 2 Jan 2025 14:49:02 +0530 Subject: [PATCH 2/6] Update CONTRIBUTING.md Hyperlink good first issue --- CONTRIBUTING.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 33a9340..b93cad4 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -125,7 +125,7 @@ src/ ### 1. Good First Issues -Look for issues labeled `good-first-issue`. These are great starting points for new contributors. +Look for issues labeled [`good-first-issue`](https://github.com/chonkie-ai/chonkie/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22). These are great starting points for new contributors. ### 2. Documentation From 3f0432326282f32086f67d2e871825145622fcc4 Mon Sep 17 00:00:00 2001 From: Shreyash Nigam Date: Thu, 2 Jan 2025 15:04:56 +0530 Subject: [PATCH 3/6] Fix numbering --- CONTRIBUTING.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index b93cad4..a4fe131 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -19,7 +19,7 @@ git clone https://github.com/your-username/chonkie.git cd chonkie ``` -1. Create a virtual environment and install dependencies: +2. Create a virtual environment and install dependencies: ```bash python -m venv venv @@ -105,7 +105,7 @@ feat: add batch processing to WordChunker - Update documentation ``` -1. **Dependencies**: If adding new dependencies: +3. **Dependencies**: If adding new dependencies: - Core dependencies go in `project.dependencies` - Optional features go in `project.optional-dependencies` - Development tools go in the `dev` optional dependency group From dd49161dffc9cc1115b7b21d7f1c948a435d7226 Mon Sep 17 00:00:00 2001 From: bhavnicksm Date: Thu, 2 Jan 2025 15:51:32 +0530 Subject: [PATCH 4/6] [fix] Add tests for Py38 and Py313 --- .github/workflows/python-test-push.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-test-push.yml b/.github/workflows/python-test-push.yml index 01d356f..2717793 100644 --- a/.github/workflows/python-test-push.yml +++ b/.github/workflows/python-test-push.yml @@ -8,7 +8,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.9", "3.10", "3.11", "3.12"] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"] steps: - uses: actions/checkout@v4 From 59b246ef1ec26a8e6ed0660f7b155fdaaf73a26e Mon Sep 17 00:00:00 2001 From: bhavnicksm Date: Thu, 2 Jan 2025 18:05:23 +0530 Subject: [PATCH 5/6] Refactor BaseEmbeddings class and add CustomEmbeddings tests - Removed abstractmethod decorators from `embed_batch`, `count_tokens_batch`, and `similarity` methods in the BaseEmbeddings class. - Updated the `similarity` method to ensure it returns a float. - Introduced a new CustomEmbeddings class implementing the abstract methods with functionality for embedding, token counting, and similarity calculation. - Added comprehensive unit tests for the CustomEmbeddings class, covering initialization, embedding single and batch texts, token counting, and similarity calculation. These changes enhance the flexibility and test coverage of the embeddings functionality. 
--- src/chonkie/embeddings/base.py | 7 +- tests/embeddings/test_custom_embeddings.py | 93 ++++++++++++++++++++++ 2 files changed, 94 insertions(+), 6 deletions(-) create mode 100644 tests/embeddings/test_custom_embeddings.py diff --git a/src/chonkie/embeddings/base.py b/src/chonkie/embeddings/base.py index 1e3763f..a6b69e6 100644 --- a/src/chonkie/embeddings/base.py +++ b/src/chonkie/embeddings/base.py @@ -43,7 +43,6 @@ def embed(self, text: str) -> "np.ndarray": """ raise NotImplementedError - @abstractmethod def embed_batch(self, texts: List[str]) -> List["np.ndarray"]: """Embed a list of text strings into vector representations. @@ -76,7 +75,6 @@ def count_tokens(self, text: str) -> int: """ raise NotImplementedError - @abstractmethod def count_tokens_batch(self, texts: List[str]) -> List[int]: """Count the number of tokens in a list of text strings. @@ -89,7 +87,6 @@ def count_tokens_batch(self, texts: List[str]) -> List[int]: """ return [self.count_tokens(text) for text in texts] - @abstractmethod def similarity(self, u: "np.ndarray", v: "np.ndarray") -> float: """Compute the similarity between two embeddings. @@ -106,9 +103,7 @@ def similarity(self, u: "np.ndarray", v: "np.ndarray") -> float: float: Similarity score between the two embeddings """ - return np.dot(u, v) / ( - np.linalg.norm(u) * np.linalg.norm(v) - ) # cosine similarity + return float(np.dot(u, v.T) / (np.linalg.norm(u) * np.linalg.norm(v))) # cosine similarity @property @abstractmethod diff --git a/tests/embeddings/test_custom_embeddings.py b/tests/embeddings/test_custom_embeddings.py new file mode 100644 index 0000000..6bc2b84 --- /dev/null +++ b/tests/embeddings/test_custom_embeddings.py @@ -0,0 +1,93 @@ +"""Contains test cases for the CustomEmbeddings class. + +The tests verify: + +- Initialization with a specified dimension +- Embedding a single text string +- Embedding a batch of text strings +- Token counting +- Similarity calculation +""" +import numpy as np +import pytest + +from chonkie.embeddings.base import BaseEmbeddings + +# 1. Define a custom embeddings class implementing the abstract methods. +class CustomEmbeddings(BaseEmbeddings): + """Custom embeddings class.""" + + def __init__(self, dimension=4): + """Initialize the CustomEmbeddings class.""" + super().__init__() + self._dimension = dimension + + def embed(self, text: str) -> "np.ndarray": + """Embed a single text string into a vector representation.""" + # For demonstration, returns a random vector + return np.random.rand(self._dimension) + + def embed_batch(self, texts): + """Embed a batch of text strings into a list of vector representations.""" + # Reuse the single-text embed for batch embeddings + return [self.embed(text) for text in texts] + + def count_tokens(self, text: str) -> int: + """Count the number of tokens in a text string.""" + # Very naive token counting—split by whitespace + return len(text.split()) + + def similarity(self, u: "np.ndarray", v: "np.ndarray") -> float: + """Calculate the cosine similarity between two vectors.""" + return float(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))) + + @property + def dimension(self) -> int: + """Return the dimension of the embeddings.""" + return self._dimension + + +# 2. Write tests to validate the custom embeddings. 
+def test_custom_embeddings_initialization(): + """Test the initialization of the CustomEmbeddings class.""" + embeddings = CustomEmbeddings(dimension=4) + assert isinstance(embeddings, BaseEmbeddings) + assert embeddings.dimension == 4 + +def test_custom_embeddings_single_text(): + """Test the embedding of a single text string.""" + embeddings = CustomEmbeddings(dimension=4) + text = "Test string" + vector = embeddings.embed(text) + assert isinstance(vector, np.ndarray) + assert vector.shape == (4,) + +def test_custom_embeddings_batch_text(): + """Test the embedding of a batch of text strings.""" + embeddings = CustomEmbeddings(dimension=4) + texts = ["Test string one", "Test string two"] + vectors = embeddings.embed_batch(texts) + assert len(vectors) == 2 + for vec in vectors: + assert isinstance(vec, np.ndarray) + assert vec.shape == (4,) + +def test_custom_embeddings_token_count(): + """Test the token counting functionality.""" + embeddings = CustomEmbeddings() + text = "Test string for counting tokens" + count = embeddings.count_tokens(text) + assert isinstance(count, int) + assert count == len(text.split()) + +def test_custom_embeddings_similarity(): + """Test the similarity calculation.""" + embeddings = CustomEmbeddings(dimension=4) + vec1 = embeddings.embed("Text A") + vec2 = embeddings.embed("Text B") + sim = embeddings.similarity(vec1, vec2) + # Cosine similarity always lies in [-1, 1]; these nonnegative random vectors give a value in [0, 1] + assert -1.0 <= sim <= 1.0 + +if __name__ == "__main__": + pytest.main() \ No newline at end of file From 242e1fb5b2ba2ac1d2f2d44a53c3f1b1eccd9a58 Mon Sep 17 00:00:00 2001 From: bhavnicksm Date: Thu, 2 Jan 2025 18:19:55 +0530 Subject: [PATCH 6/6] [chore] Apply ruff format fixes --- src/chonkie/chunker/recursive.py | 4 +--- src/chonkie/chunker/token.py | 1 - tests/chunker/test_recursive_chunker.py | 3 ++- tests/chunker/test_word_chunker.py | 3 +-- tests/embeddings/test_custom_embeddings.py | 9 +-------- tests/refinery/test_overlap_refinery.py | 2 +- 6 files changed, 6 insertions(+), 16 deletions(-) diff --git a/src/chonkie/chunker/recursive.py b/src/chonkie/chunker/recursive.py index 320b8d8..c390596 100644 --- a/src/chonkie/chunker/recursive.py +++ b/src/chonkie/chunker/recursive.py @@ -1,13 +1,11 @@ """Recursive chunker.""" from bisect import bisect_left -from dataclasses import dataclass from functools import lru_cache from itertools import accumulate from typing import Any, List, Optional, Union from chonkie.chunker.base import BaseChunker -from chonkie.types import Chunk, RecursiveChunk, RecursiveRules, RecursiveLevel - +from chonkie.types import Chunk, RecursiveChunk, RecursiveLevel, RecursiveRules class RecursiveChunker(BaseChunker): diff --git a/src/chonkie/chunker/token.py b/src/chonkie/chunker/token.py index 1912163..defb05c 100644 --- a/src/chonkie/chunker/token.py +++ b/src/chonkie/chunker/token.py @@ -1,6 +1,5 @@ """Token-based chunking.""" -from itertools import accumulate from typing import Any, Generator, List, Tuple, Union from chonkie.types import Chunk diff --git a/tests/chunker/test_recursive_chunker.py b/tests/chunker/test_recursive_chunker.py index ffda7f7..b06a089 100644 --- a/tests/chunker/test_recursive_chunker.py +++ b/tests/chunker/test_recursive_chunker.py @@ -12,7 +12,8 @@ """ import pytest -from chonkie.chunker.recursive import RecursiveChunker, RecursiveRules, RecursiveLevel + +from chonkie.chunker.recursive import RecursiveChunker, RecursiveLevel, RecursiveRules from chonkie.types import Chunk diff --git
a/tests/chunker/test_word_chunker.py b/tests/chunker/test_word_chunker.py index 420f7fa..15bc836 100644 --- a/tests/chunker/test_word_chunker.py +++ b/tests/chunker/test_word_chunker.py @@ -12,9 +12,9 @@ """ from typing import List -from datasets import load_dataset import pytest +from datasets import load_dataset from tokenizers import Tokenizer from chonkie import WordChunker @@ -126,7 +126,6 @@ def test_word_chunker_single_chunk_text(tokenizer): def test_word_chunker_batch_chunking(tokenizer, sample_batch): """Test that the WordChunker can chunk a batch of texts.""" - # this is to avoid the following # DeprecationWarning: This process (pid=) is multi-threaded, # use of fork() may lead to deadlocks in the child. diff --git a/tests/embeddings/test_custom_embeddings.py b/tests/embeddings/test_custom_embeddings.py index 6bc2b84..8a75a27 100644 --- a/tests/embeddings/test_custom_embeddings.py +++ b/tests/embeddings/test_custom_embeddings.py @@ -13,7 +13,7 @@ from chonkie.embeddings.base import BaseEmbeddings -# 1. Define a custom embeddings class implementing the abstract methods. + class CustomEmbeddings(BaseEmbeddings): """Custom embeddings class.""" @@ -27,11 +27,6 @@ def embed(self, text: str) -> "np.ndarray": # For demonstration, returns a random vector return np.random.rand(self._dimension) - def embed_batch(self, texts): - """Embed a batch of text strings into a list of vector representations.""" - # Reuse the single-text embed for batch embeddings - return [self.embed(text) for text in texts] - def count_tokens(self, text: str) -> int: """Count the number of tokens in a text string.""" # Very naive token counting—split by whitespace @@ -46,8 +41,6 @@ def dimension(self) -> int: """Return the dimension of the embeddings.""" return self._dimension - -# 2. Write tests to validate the custom embeddings. def test_custom_embeddings_initialization(): """Test the initialization of the CustomEmbeddings class.""" embeddings = CustomEmbeddings(dimension=4) diff --git a/tests/refinery/test_overlap_refinery.py b/tests/refinery/test_overlap_refinery.py index f93850a..3026947 100644 --- a/tests/refinery/test_overlap_refinery.py +++ b/tests/refinery/test_overlap_refinery.py @@ -3,9 +3,9 @@ import pytest from transformers import AutoTokenizer +from chonkie import TokenChunker from chonkie.refinery import OverlapRefinery from chonkie.types import Chunk, Context, Sentence, SentenceChunk -from chonkie import TokenChunker @pytest.fixture
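For reference, PATCH 1/6 wires a user-supplied callable into token counting only; the "callable" backend deliberately raises NotImplementedError for encoding and decoding. A minimal usage sketch, assuming a chunker that relies solely on token counts (WordChunker here, matching the import used in the test suite above) and assuming it forwards its tokenizer argument to BaseChunker.__init__; the keyword name and defaults below are illustrative, not verified against this series:

    from chonkie import WordChunker

    def count_whitespace_tokens(text: str) -> int:
        # Any str -> int callable is accepted; a bound method works too (see PATCH 1/6).
        return len(text.split())

    # Hypothetical invocation: the real parameter name and defaults may differ.
    chunker = WordChunker(tokenizer=count_whitespace_tokens, chunk_size=16, chunk_overlap=2)
    for chunk in chunker.chunk("Chonkie counts tokens with whatever callable you hand it."):
        print(chunk.token_count, repr(chunk.text))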
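Similarly, after PATCH 5/6 the embed_batch, count_tokens_batch and similarity methods carry concrete defaults on BaseEmbeddings, so a subclass only has to implement the members that remain abstract. A sketch under the assumption that embed, count_tokens and dimension are the full abstract set; any further abstract members in the real class would also need overriding:

    import numpy as np

    from chonkie.embeddings.base import BaseEmbeddings

    class SeededEmbeddings(BaseEmbeddings):
        """Toy embeddings: deterministic vectors seeded from the text bytes."""

        def __init__(self, dimension: int = 8):
            super().__init__()
            self._dimension = dimension

        def embed(self, text: str) -> np.ndarray:
            # Seed from the byte sum so repeated calls agree across runs.
            rng = np.random.default_rng(sum(text.encode()))
            return rng.random(self._dimension)

        def count_tokens(self, text: str) -> int:
            # Naive whitespace token count, mirroring the test suite above.
            return len(text.split())

        @property
        def dimension(self) -> int:
            return self._dimension

    emb = SeededEmbeddings()
    u, v = emb.embed_batch(["hello world", "hello chonkie"])  # inherited default
    print(emb.similarity(u, v))  # inherited cosine similarity, returned as a plain float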