-
Notifications
You must be signed in to change notification settings - Fork 43
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1058 from MichaelDecent/pkg3
Add Swarmauri Tfidf Vector Store plugins
- Loading branch information
Showing
10 changed files
with
293 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Swarmauri Example Plugin |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
[tool.poetry] | ||
name = "swarmauri_vectorstore_tfidf" | ||
version = "0.6.0.dev1" | ||
description = "This repository includes an example of a First Class Swarmauri Example." | ||
authors = ["Jacob Stewart <[email protected]>"] | ||
license = "Apache-2.0" | ||
readme = "README.md" | ||
repository = "http://github.com/swarmauri/swarmauri-sdk" | ||
classifiers = [ | ||
"License :: OSI Approved :: Apache Software License", | ||
"Programming Language :: Python :: 3.10", | ||
"Programming Language :: Python :: 3.11", | ||
"Programming Language :: Python :: 3.12" | ||
] | ||
|
||
[tool.poetry.dependencies] | ||
python = ">=3.10,<3.13" | ||
|
||
# Swarmauri | ||
swarmauri_core = { path = "../../core" } | ||
swarmauri_base = { path = "../../base" } | ||
|
||
[tool.poetry.group.dev.dependencies] | ||
flake8 = "^7.0" | ||
pytest = "^8.0" | ||
pytest-asyncio = ">=0.24.0" | ||
pytest-xdist = "^3.6.1" | ||
pytest-json-report = "^1.5.0" | ||
python-dotenv = "*" | ||
requests = "^2.32.3" | ||
|
||
[build-system] | ||
requires = ["poetry-core>=1.0.0"] | ||
build-backend = "poetry.core.masonry.api" | ||
|
||
[tool.pytest.ini_options] | ||
norecursedirs = ["combined", "scripts"] | ||
|
||
markers = [ | ||
"test: standard test", | ||
"unit: Unit tests", | ||
"integration: Integration tests", | ||
"acceptance: Acceptance tests", | ||
"experimental: Experimental tests" | ||
] | ||
log_cli = true | ||
log_cli_level = "INFO" | ||
log_cli_format = "%(asctime)s [%(levelname)s] %(message)s" | ||
log_cli_date_format = "%Y-%m-%d %H:%M:%S" | ||
asyncio_default_fixture_loop_scope = "function" | ||
|
||
[tool.poetry.plugins."swarmauri.vector_stores"] | ||
TfidfVectorStore = "swarmauri_vectorstore_tfidf:TfidfVectorStore" | ||
|
||
[tool.poetry.plugins."swarmauri.embeddings"] | ||
TfidfEmbedding = "swarmauri_vectorstore_tfidf:TfidfEmbedding" | ||
|
54 changes: 54 additions & 0 deletions
54
pkgs/standards/swarmauri_vectorstore_tfidf/swarmauri_vectorstore_tfidf/TfidfEmbedding.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from typing import List, Union, Any, Literal | ||
import joblib | ||
from pydantic import PrivateAttr | ||
from sklearn.feature_extraction.text import TfidfVectorizer as SklearnTfidfVectorizer | ||
|
||
from swarmauri_base.embeddings.EmbeddingBase import EmbeddingBase | ||
from swarmauri_standard.vectors.Vector import Vector | ||
from swarmauri_core.ComponentBase import ComponentBase | ||
|
||
|
||
@ComponentBase.register_type(EmbeddingBase, "TfidfEmbedding") | ||
class TfidfEmbedding(EmbeddingBase): | ||
_model = PrivateAttr() | ||
_fit_matrix = PrivateAttr() | ||
type: Literal["TfidfEmbedding"] = "TfidfEmbedding" | ||
|
||
def __init__(self, **kwargs): | ||
super().__init__(**kwargs) | ||
self._model = SklearnTfidfVectorizer() | ||
|
||
def extract_features(self): | ||
return self._model.get_feature_names_out().tolist() | ||
|
||
def fit(self, documents: List[str]) -> None: | ||
self._fit_matrix = self._model.fit_transform(documents) | ||
|
||
def fit_transform(self, documents: List[str]) -> List[Vector]: | ||
self._fit_matrix = self._model.fit_transform(documents) | ||
# Convert the sparse matrix rows into Vector instances | ||
vectors = [ | ||
Vector(value=vector.toarray().flatten()) for vector in self._fit_matrix | ||
] | ||
return vectors | ||
|
||
def transform(self, data: Union[str, Any], documents: List[str]) -> List[Vector]: | ||
raise NotImplementedError("Transform not implemented on TFIDFVectorizer.") | ||
|
||
def infer_vector(self, data: str, documents: List[str]) -> Vector: | ||
documents.append(data) | ||
tmp_tfidf_matrix = self.fit_transform(documents) | ||
query_vector = tmp_tfidf_matrix[-1] | ||
return query_vector | ||
|
||
def save_model(self, path: str) -> None: | ||
""" | ||
Saves the TF-IDF model to the specified path using joblib. | ||
""" | ||
joblib.dump(self._model, path) | ||
|
||
def load_model(self, path: str) -> None: | ||
""" | ||
Loads a TF-IDF model from the specified path using joblib. | ||
""" | ||
self._model = joblib.load(path) |
76 changes: 76 additions & 0 deletions
76
pkgs/standards/swarmauri_vectorstore_tfidf/swarmauri_vectorstore_tfidf/TfidfVectorStore.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
from typing import List, Union, Literal | ||
from swarmauri_standard.documents.Document import Document | ||
from swarmauri_standard.embeddings.TfidfEmbedding import TfidfEmbedding | ||
from swarmauri_standard.distances.CosineDistance import CosineDistance | ||
|
||
from swarmauri_base.vector_stores.VectorStoreBase import VectorStoreBase | ||
from swarmauri_base.vector_stores.VectorStoreRetrieveMixin import ( | ||
VectorStoreRetrieveMixin, | ||
) | ||
from swarmauri_base.vector_stores.VectorStoreSaveLoadMixin import ( | ||
VectorStoreSaveLoadMixin, | ||
) | ||
from swarmauri_core.ComponentBase import ComponentBase | ||
|
||
|
||
@ComponentBase.register_type(VectorStoreBase, "TfidfVectorStore") | ||
class TfidfVectorStore( | ||
VectorStoreSaveLoadMixin, VectorStoreRetrieveMixin, VectorStoreBase | ||
): | ||
type: Literal["TfidfVectorStore"] = "TfidfVectorStore" | ||
|
||
def __init__(self, **kwargs): | ||
super().__init__(**kwargs) | ||
self._embedder = TfidfEmbedding() | ||
self._distance = CosineDistance() | ||
self.documents = [] | ||
|
||
def add_document(self, document: Document) -> None: | ||
self.documents.append(document) | ||
# Recalculate TF-IDF matrix for the current set of documents | ||
self._embedder.fit([doc.content for doc in self.documents]) | ||
|
||
def add_documents(self, documents: List[Document]) -> None: | ||
self.documents.extend(documents) | ||
# Recalculate TF-IDF matrix for the current set of documents | ||
self._embedder.fit([doc.content for doc in self.documents]) | ||
|
||
def get_document(self, id: str) -> Union[Document, None]: | ||
for document in self.documents: | ||
if document.id == id: | ||
return document | ||
return None | ||
|
||
def get_all_documents(self) -> List[Document]: | ||
return self.documents | ||
|
||
def delete_document(self, id: str) -> None: | ||
self.documents = [doc for doc in self.documents if doc.id != id] | ||
# Recalculate TF-IDF matrix for the current set of documents | ||
self._embedder.fit([doc.content for doc in self.documents]) | ||
|
||
def update_document(self, id: str, updated_document: Document) -> None: | ||
for i, document in enumerate(self.documents): | ||
if document.id == id: | ||
self.documents[i] = updated_document | ||
break | ||
|
||
# Recalculate TF-IDF matrix for the current set of documents | ||
self._embedder.fit([doc.content for doc in self.documents]) | ||
|
||
def retrieve(self, query: str, top_k: int = 5) -> List[Document]: | ||
documents = [query] | ||
documents.extend([doc.content for doc in self.documents]) | ||
transform_matrix = self._embedder.fit_transform(documents) | ||
|
||
# The inferred vector is the last vector in the transformed_matrix | ||
# The rest of the matrix is what we are comparing | ||
distances = self._distance.distances( | ||
transform_matrix[-1], transform_matrix[:-1] | ||
) | ||
|
||
# Get the indices of the top_k most similar (least distant) documents | ||
top_k_indices = sorted(range(len(distances)), key=lambda i: distances[i])[ | ||
:top_k | ||
] | ||
return [self.documents[i] for i in top_k_indices] |
15 changes: 15 additions & 0 deletions
15
pkgs/standards/swarmauri_vectorstore_tfidf/swarmauri_vectorstore_tfidf/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from .TfidfEmbedding import TfidfEmbedding | ||
from .TfidfVectorStore import TfidfVectorStore | ||
|
||
__version__ = "0.6.0.dev26" | ||
__long_desc__ = """ | ||
# Swarmauri Tfidf Plugin | ||
This repository includes an Tfidf of a Swarmauri Plugin. | ||
Visit us at: https://swarmauri.com | ||
Follow us at: https://github.com/swarmauri | ||
Star us at: https://github.com/swarmauri/swarmauri-sdk | ||
""" |
39 changes: 39 additions & 0 deletions
39
pkgs/standards/swarmauri_vectorstore_tfidf/tests/unit/TfidfEmbedding_unit_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import pytest | ||
from swarmauri_vectorstore_tfidf.TfidfEmbedding import TfidfEmbedding | ||
|
||
|
||
@pytest.mark.unit | ||
def test_ubc_resource(): | ||
def test(): | ||
assert TfidfEmbedding().resource == "Embedding" | ||
|
||
test() | ||
|
||
|
||
@pytest.mark.unit | ||
def test_ubc_type(): | ||
assert TfidfEmbedding().type == "TfidfEmbedding" | ||
|
||
|
||
@pytest.mark.unit | ||
def test_serialization(): | ||
embedder = TfidfEmbedding() | ||
assert ( | ||
embedder.id == TfidfEmbedding.model_validate_json(embedder.model_dump_json()).id | ||
) | ||
|
||
|
||
@pytest.mark.unit | ||
def test_fit_transform(): | ||
embedder = TfidfEmbedding() | ||
documents = ["test", "test1", "test2"] | ||
embedder.fit_transform(documents) | ||
assert documents == embedder.extract_features() | ||
|
||
|
||
@pytest.mark.unit | ||
def test_infer_vector(): | ||
embedder = TfidfEmbedding() | ||
documents = ["test", "test1", "test2"] | ||
embedder.fit_transform(documents) | ||
assert embedder.infer_vector("hi", documents).value == [1.0, 0.0, 0.0, 0.0] |
36 changes: 36 additions & 0 deletions
36
pkgs/standards/swarmauri_vectorstore_tfidf/tests/unit/TfidfVectorStore_unit_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import pytest | ||
from swarmauri.documents.concrete.Document import Document | ||
from swarmauri_vectorstore_tfidf.TfidfVectorStore import TfidfVectorStore | ||
|
||
|
||
@pytest.mark.unit | ||
def test_ubc_resource(): | ||
vs = TfidfVectorStore() | ||
assert vs.resource == "VectorStore" | ||
assert vs.embedder.resource == "Embedding" | ||
|
||
|
||
@pytest.mark.unit | ||
def test_ubc_type(): | ||
vs = TfidfVectorStore() | ||
assert vs.type == "TfidfVectorStore" | ||
|
||
|
||
@pytest.mark.unit | ||
def test_serialization(): | ||
vs = TfidfVectorStore() | ||
assert vs.id == TfidfVectorStore.model_validate_json(vs.model_dump_json()).id | ||
|
||
|
||
@pytest.mark.unit | ||
def test_top_k(): | ||
vs = TfidfVectorStore() | ||
documents = [ | ||
Document(content="test"), | ||
Document(content="test1"), | ||
Document(content="test2"), | ||
Document(content="test3"), | ||
] | ||
|
||
vs.add_documents(documents) | ||
assert len(vs.retrieve(query="test", top_k=2)) == 2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters