
[Fix] Add fix for #92: Support class.method as a Tokenizer for CustomEmbedding + minor changes #128

Merged
merged 9 commits on Jan 2, 2025
2 changes: 1 addition & 1 deletion .github/workflows/python-test-push.yml
@@ -8,7 +8,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"]

steps:
- uses: actions/checkout@v4
30 changes: 24 additions & 6 deletions src/chonkie/chunker/base.py
@@ -30,12 +30,6 @@ def __init__(
if isinstance(tokenizer_or_token_counter, str):
self.tokenizer = self._load_tokenizer(tokenizer_or_token_counter)
self.token_counter = self._get_tokenizer_counter()
# Then check if the tokenizer_or_token_counter is a function via inspect
elif inspect.isfunction(tokenizer_or_token_counter):
self.tokenizer = None
self._tokenizer_backend = "callable"
self.token_counter = tokenizer_or_token_counter
# If not function or string, then assume it's a tokenizer object
else:
self.tokenizer = tokenizer_or_token_counter
self._tokenizer_backend = self._get_tokenizer_backend()
@@ -49,6 +43,12 @@ def _get_tokenizer_backend(self):
return "tokenizers"
elif "tiktoken" in str(type(self.tokenizer)):
return "tiktoken"
elif (
callable(self.tokenizer)
or inspect.isfunction(self.tokenizer)
or inspect.ismethod(self.tokenizer)
):
return "callable"
else:
raise ValueError(
f"Tokenizer backend {str(type(self.tokenizer))} not supported"
@@ -107,6 +107,8 @@ def _get_tokenizer_counter(self) -> Callable[[str], int]:
return self._tokenizers_token_counter
elif self._tokenizer_backend == "tiktoken":
return self._tiktoken_token_counter
elif self._tokenizer_backend == "callable":
return self.tokenizer
else:
raise ValueError("Tokenizer backend not supported for token counting")

@@ -130,6 +132,10 @@ def _encode(self, text: str) -> List[int]:
return self.tokenizer.encode(text, add_special_tokens=False).ids
elif self._tokenizer_backend == "tiktoken":
return self.tokenizer.encode(text)
elif self._tokenizer_backend == "callable":
raise NotImplementedError(
"Callable tokenizer backend does not support encoding."
)
else:
raise ValueError(
f"Tokenizer backend {self._tokenizer_backend} not supported."
@@ -148,6 +154,10 @@ def _encode_batch(self, texts: List[str]) -> List[List[int]]:
]
elif self._tokenizer_backend == "tiktoken":
return self.tokenizer.encode_batch(texts)
elif self._tokenizer_backend == "callable":
raise NotImplementedError(
"Callable tokenizer backend does not support batch encoding."
)
else:
raise ValueError(
f"Tokenizer backend {self._tokenizer_backend} not supported."
@@ -161,6 +171,10 @@ def _decode(self, tokens) -> str:
return self.tokenizer.decode(tokens)
elif self._tokenizer_backend == "tiktoken":
return self.tokenizer.decode(tokens)
elif self._tokenizer_backend == "callable":
raise NotImplementedError(
"Callable tokenizer backend does not support decoding."
)
else:
raise ValueError(
f"Tokenizer backend {self._tokenizer_backend} not supported."
@@ -174,6 +188,10 @@ def _decode_batch(self, token_lists: List[List[int]]) -> List[str]:
return [self.tokenizer.decode(tokens) for tokens in token_lists]
elif self._tokenizer_backend == "tiktoken":
return [self.tokenizer.decode(tokens) for tokens in token_lists]
elif self._tokenizer_backend == "callable":
raise NotImplementedError(
"Callable tokenizer backend does not support batch decoding."
)
else:
raise ValueError(
f"Tokenizer backend {self._tokenizer_backend} not supported."
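The crux of #92 is visible in the hunks above: `inspect.isfunction` returns `False` for bound methods, so a `class.method` passed as `tokenizer_or_token_counter` fell through to the tokenizer-object branch and failed. A minimal standalone sketch (not part of this diff, with a hypothetical class) of the distinction the new `callable`/`ismethod` checks rely on:

```python
import inspect

class MyEmbeddings:
    """Hypothetical stand-in for a user's CustomEmbeddings class."""

    def count_tokens(self, text: str) -> int:
        return len(text.split())

def count_tokens(text: str) -> int:
    return len(text.split())

emb = MyEmbeddings()
print(inspect.isfunction(count_tokens))      # True:  a plain function
print(inspect.isfunction(emb.count_tokens))  # False: bound methods are not functions
print(inspect.ismethod(emb.count_tokens))    # True:  the new check catches this
print(callable(emb.count_tokens))            # True:  broadest of the three
```

Note that the `callable` backend supports token counting only: `_encode`, `_encode_batch`, `_decode`, and `_decode_batch` all raise `NotImplementedError`, so it fits chunkers that never need token ids back.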
4 changes: 1 addition & 3 deletions src/chonkie/chunker/recursive.py
@@ -1,13 +1,11 @@
"""Recursive chunker."""
from bisect import bisect_left
from dataclasses import dataclass
from functools import lru_cache
from itertools import accumulate
from typing import Any, List, Optional, Union

from chonkie.chunker.base import BaseChunker
from chonkie.types import Chunk, RecursiveChunk, RecursiveRules, RecursiveLevel

from chonkie.types import Chunk, RecursiveChunk, RecursiveLevel, RecursiveRules


class RecursiveChunker(BaseChunker):
1 change: 0 additions & 1 deletion src/chonkie/chunker/token.py
@@ -1,6 +1,5 @@
"""Token-based chunking."""

from itertools import accumulate
from typing import Any, Generator, List, Tuple, Union

from chonkie.types import Chunk
7 changes: 1 addition & 6 deletions src/chonkie/embeddings/base.py
@@ -43,7 +43,6 @@ def embed(self, text: str) -> "np.ndarray":
"""
raise NotImplementedError

@abstractmethod
def embed_batch(self, texts: List[str]) -> List["np.ndarray"]:
"""Embed a list of text strings into vector representations.

@@ -76,7 +75,6 @@ def count_tokens(self, text: str) -> int:
"""
raise NotImplementedError

@abstractmethod
def count_tokens_batch(self, texts: List[str]) -> List[int]:
"""Count the number of tokens in a list of text strings.

@@ -89,7 +87,6 @@ def count_tokens_batch(self, texts: List[str]) -> List[int]:
"""
return [self.count_tokens(text) for text in texts]

@abstractmethod
def similarity(self, u: "np.ndarray", v: "np.ndarray") -> float:
"""Compute the similarity between two embeddings.

@@ -106,9 +103,7 @@ def similarity(self, u: "np.ndarray", v: "np.ndarray") -> float:
float: Similarity score between the two embeddings

"""
return np.dot(u, v) / (
np.linalg.norm(u) * np.linalg.norm(v)
) # cosine similarity
return float(np.dot(u, v.T) / (np.linalg.norm(u) * np.linalg.norm(v))) # cosine similarity

@property
@abstractmethod
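Two changes here: `embed_batch`, `count_tokens_batch`, and `similarity` lose their `@abstractmethod` decorators, so subclasses inherit the default implementations instead of being forced to override them (the test file below relies on the inherited `embed_batch`), and `similarity` now transposes `v` and casts the result to a built-in `float`. A quick standalone check of the updated formula; for 1-D arrays `v.T` is a no-op, so the common case is unchanged:

```python
import numpy as np

u = np.array([1.0, 0.0, 0.0])
v = np.array([0.6, 0.8, 0.0])  # unit norm: 0.36 + 0.64 = 1.0

# cosine similarity, as in the updated base class
sim = float(np.dot(u, v.T) / (np.linalg.norm(u) * np.linalg.norm(v)))
assert isinstance(sim, float)  # float() guarantees a built-in float
assert abs(sim - 0.6) < 1e-9   # cos of the angle between u and v
```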
3 changes: 2 additions & 1 deletion tests/chunker/test_recursive_chunker.py
@@ -12,7 +12,8 @@
"""

import pytest
from chonkie.chunker.recursive import RecursiveChunker, RecursiveRules, RecursiveLevel

from chonkie.chunker.recursive import RecursiveChunker, RecursiveLevel, RecursiveRules
from chonkie.types import Chunk


3 changes: 1 addition & 2 deletions tests/chunker/test_word_chunker.py
@@ -12,9 +12,9 @@
"""

from typing import List
from datasets import load_dataset

import pytest
from datasets import load_dataset
from tokenizers import Tokenizer

from chonkie import WordChunker
@@ -126,7 +126,6 @@ def test_word_chunker_single_chunk_text(tokenizer):

def test_word_chunker_batch_chunking(tokenizer, sample_batch):
"""Test that the WordChunker can chunk a batch of texts."""

# this is to avoid the following
# DeprecationWarning: This process (pid=<SOME-PID>) is multi-threaded,
# use of fork() may lead to deadlocks in the child.
86 changes: 86 additions & 0 deletions tests/embeddings/test_custom_embeddings.py
@@ -0,0 +1,86 @@
"""Contains test cases for the CustomEmbeddings class.

The tests verify:

- Initialization with a specified dimension
- Embedding a single text string
- Embedding a batch of text strings
- Token counting
- Similarity calculation
"""
import numpy as np
import pytest

from chonkie.embeddings.base import BaseEmbeddings


class CustomEmbeddings(BaseEmbeddings):
"""Custom embeddings class."""

def __init__(self, dimension=4):
"""Initialize the CustomEmbeddings class."""
super().__init__()
self._dimension = dimension

def embed(self, text: str) -> "np.ndarray":
"""Embed a single text string into a vector representation."""
# For demonstration, returns a random vector
return np.random.rand(self._dimension)

def count_tokens(self, text: str) -> int:
"""Count the number of tokens in a text string."""
# Very naive token counting—split by whitespace
return len(text.split())

def similarity(self, u: "np.ndarray", v: "np.ndarray") -> float:
"""Calculate the cosine similarity between two vectors."""
return float(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v)))

@property
def dimension(self) -> int:
"""Return the dimension of the embeddings."""
return self._dimension

def test_custom_embeddings_initialization():
"""Test the initialization of the CustomEmbeddings class."""
embeddings = CustomEmbeddings(dimension=4)
assert isinstance(embeddings, BaseEmbeddings)
assert embeddings.dimension == 4

def test_custom_embeddings_single_text():
"""Test the embedding of a single text string."""
embeddings = CustomEmbeddings(dimension=4)
text = "Test string"
vector = embeddings.embed(text)
assert isinstance(vector, np.ndarray)
assert vector.shape == (4, )

def test_custom_embeddings_batch_text():
"""Test the embedding of a batch of text strings."""
embeddings = CustomEmbeddings(dimension=4)
texts = ["Test string one", "Test string two"]
vectors = embeddings.embed_batch(texts)
assert len(vectors) == 2
for vec in vectors:
assert isinstance(vec, np.ndarray)
assert vec.shape == (4,)

def test_custom_embeddings_token_count():
"""Test the token counting functionality."""
embeddings = CustomEmbeddings()
text = "Test string for counting tokens"
count = embeddings.count_tokens(text)
assert isinstance(count, int)
assert count == len(text.split())

def test_custom_embeddings_similarity():
"""Test the similarity calculation."""
embeddings = CustomEmbeddings(dimension=4)
vec1 = embeddings.embed("Text A")
vec2 = embeddings.embed("Text B")
sim = embeddings.similarity(vec1, vec2)
# Cosine similarity is in [-1, 1]—random vectors often produce a small positive or negative value
assert -1.0 <= sim <= 1.0

if __name__ == "__main__":
pytest.main()
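With the backend fix in `chunker/base.py` above, the pattern in the PR title works end to end: a bound method of an embeddings object can serve as a chunker's token counter. A hedged sketch; the chunker choice and its positional signature are assumptions, not taken from this diff:

```python
from chonkie import WordChunker  # assumed: a chunker that only needs token counts

embeddings = CustomEmbeddings(dimension=4)  # the class defined in the test above

# Pass the bound method -- "class.method" -- as tokenizer_or_token_counter.
chunker = WordChunker(embeddings.count_tokens)
chunks = chunker.chunk("Chonkie splits long documents into model-sized pieces.")
print([chunk.text for chunk in chunks])
```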
2 changes: 1 addition & 1 deletion tests/refinery/test_overlap_refinery.py
@@ -3,9 +3,9 @@
import pytest
from transformers import AutoTokenizer

from chonkie import TokenChunker
from chonkie.refinery import OverlapRefinery
from chonkie.types import Chunk, Context, Sentence, SentenceChunk
from chonkie import TokenChunker


@pytest.fixture