-
Notifications
You must be signed in to change notification settings - Fork 43
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add NmfEmbedding plugin with model implementation and tests
- Loading branch information
1 parent
cf4267c
commit b68938a
Showing
7 changed files
with
221 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Swarmauri Example Plugin |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
[tool.poetry] | ||
name = "swarmauri_embedding_nmf" | ||
version = "0.6.0.dev1" | ||
description = "This repository includes an example of a First Class Swarmauri Example." | ||
authors = ["Jacob Stewart <[email protected]>"] | ||
license = "Apache-2.0" | ||
readme = "README.md" | ||
repository = "http://github.com/swarmauri/swarmauri-sdk" | ||
classifiers = [ | ||
"License :: OSI Approved :: Apache Software License", | ||
"Programming Language :: Python :: 3.10", | ||
"Programming Language :: Python :: 3.11", | ||
"Programming Language :: Python :: 3.12" | ||
] | ||
|
||
[tool.poetry.dependencies] | ||
python = ">=3.10,<3.13" | ||
|
||
# Swarmauri | ||
swarmauri_core = { path = "../../core" } | ||
swarmauri_base = { path = "../../base" } | ||
|
||
[tool.poetry.group.dev.dependencies] | ||
flake8 = "^7.0" | ||
pytest = "^8.0" | ||
pytest-asyncio = ">=0.24.0" | ||
pytest-xdist = "^3.6.1" | ||
pytest-json-report = "^1.5.0" | ||
python-dotenv = "*" | ||
requests = "^2.32.3" | ||
|
||
[build-system] | ||
requires = ["poetry-core>=1.0.0"] | ||
build-backend = "poetry.core.masonry.api" | ||
|
||
[tool.pytest.ini_options] | ||
norecursedirs = ["combined", "scripts"] | ||
|
||
markers = [ | ||
"test: standard test", | ||
"unit: Unit tests", | ||
"integration: Integration tests", | ||
"acceptance: Acceptance tests", | ||
"experimental: Experimental tests" | ||
] | ||
log_cli = true | ||
log_cli_level = "INFO" | ||
log_cli_format = "%(asctime)s [%(levelname)s] %(message)s" | ||
log_cli_date_format = "%Y-%m-%d %H:%M:%S" | ||
asyncio_default_fixture_loop_scope = "function" | ||
|
||
[tool.poetry.plugins."swarmauri.embeddings"] | ||
NmfEmbedding = "swarmauri_embedding_nmf:NmfEmbedding" |
111 changes: 111 additions & 0 deletions
111
pkgs/standards/swarmauri_embedding_nmf/swarmauri_embedding_nmf/NmfEmbedding.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
import joblib | ||
from sklearn.decomposition import NMF | ||
from sklearn.feature_extraction.text import TfidfVectorizer | ||
|
||
from typing import List, Any, Literal | ||
from pydantic import PrivateAttr | ||
from swarmauri_standard.vectors.Vector import Vector | ||
from swarmauri_base.embeddings.EmbeddingBase import EmbeddingBase | ||
from swarmauri_core.ComponentBase import ComponentBase | ||
|
||
|
||
@ComponentBase.register_type(EmbeddingBase, "NmfEmbedding") | ||
class NmfEmbedding(EmbeddingBase): | ||
n_components: int = 10 | ||
_tfidf_vectorizer = PrivateAttr() | ||
_model = PrivateAttr() | ||
feature_names: List[Any] = [] | ||
|
||
type: Literal["NmfEmbedding"] = "NmfEmbedding" | ||
|
||
def __init__(self, **kwargs): | ||
|
||
super().__init__(**kwargs) | ||
# Initialize TF-IDF Vectorizer | ||
self._tfidf_vectorizer = TfidfVectorizer() | ||
# Initialize NMF with the desired number of components | ||
self._model = NMF(n_components=self.n_components) | ||
|
||
def fit(self, data): | ||
""" | ||
Fit the NMF model to data. | ||
Args: | ||
data (Union[str, Any]): The text data to fit. | ||
""" | ||
# Transform data into TF-IDF matrix | ||
tfidf_matrix = self._tfidf_vectorizer.fit_transform(data) | ||
# Fit the NMF model | ||
self._model.fit(tfidf_matrix) | ||
# Store feature names | ||
self.feature_names = self._tfidf_vectorizer.get_feature_names_out() | ||
|
||
def transform(self, data): | ||
""" | ||
Transform new data into NMF feature space. | ||
Args: | ||
data (Union[str, Any]): Text data to transform. | ||
Returns: | ||
List[IVector]: A list of vectors representing the transformed data. | ||
""" | ||
# Transform data into TF-IDF matrix | ||
tfidf_matrix = self._tfidf_vectorizer.transform(data) | ||
# Transform TF-IDF matrix into NMF space | ||
nmf_features = self._model.transform(tfidf_matrix) | ||
|
||
# Wrap NMF features in SimpleVector instances and return | ||
return [Vector(value=features.tolist()) for features in nmf_features] | ||
|
||
def fit_transform(self, data): | ||
""" | ||
Fit the model to data and then transform it. | ||
Args: | ||
data (Union[str, Any]): The text data to fit and transform. | ||
Returns: | ||
List[IVector]: A list of vectors representing the fitted and transformed data. | ||
""" | ||
self.fit(data) | ||
return self.transform(data) | ||
|
||
def infer_vector(self, data): | ||
""" | ||
Convenience method for transforming a single data point. | ||
Args: | ||
data (Union[str, Any]): Single text data to transform. | ||
Returns: | ||
IVector: A vector representing the transformed single data point. | ||
""" | ||
return self.transform([data])[0] | ||
|
||
def extract_features(self): | ||
""" | ||
Extract the feature names from the TF-IDF vectorizer. | ||
Returns: | ||
The feature names. | ||
""" | ||
return self.feature_names.tolist() | ||
|
||
def save_model(self, path: str) -> None: | ||
""" | ||
Saves the NMF model and TF-IDF vectorizer using joblib. | ||
""" | ||
# It might be necessary to save both tfidf_vectorizer and model | ||
# Consider using a directory for 'path' or appended identifiers for each model file | ||
joblib.dump(self._tfidf_vectorizer, f"{path}_tfidf.joblib") | ||
joblib.dump(self._model, f"{path}_nmf.joblib") | ||
|
||
def load_model(self, path: str) -> None: | ||
""" | ||
Loads the NMF model and TF-IDF vectorizer from paths using joblib. | ||
""" | ||
self._tfidf_vectorizer = joblib.load(f"{path}_tfidf.joblib") | ||
self._model = joblib.load(f"{path}_nmf.joblib") | ||
# Dependending on your implementation, you might need to refresh the feature_names | ||
self.feature_names = self._tfidf_vectorizer.get_feature_names_out() |
14 changes: 14 additions & 0 deletions
14
pkgs/standards/swarmauri_embedding_nmf/swarmauri_embedding_nmf/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from .NmfEmbedding import NmfEmbedding | ||
|
||
__version__ = "0.6.0.dev26" | ||
__long_desc__ = """ | ||
# Swarmauri Nmf Embedding Plugin | ||
This repository includes an Nmf Embedding of a Swarmauri Plugin. | ||
Visit us at: https://swarmauri.com | ||
Follow us at: https://github.com/swarmauri | ||
Star us at: https://github.com/swarmauri/swarmauri-sdk | ||
""" |
39 changes: 39 additions & 0 deletions
39
pkgs/standards/swarmauri_embedding_nmf/tests/unit/TfidfEmbedding_unit_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import pytest | ||
from swarmauri.embeddings.concrete.TfidfEmbedding import TfidfEmbedding | ||
|
||
|
||
@pytest.mark.unit | ||
def test_ubc_resource(): | ||
def test(): | ||
assert TfidfEmbedding().resource == "Embedding" | ||
|
||
test() | ||
|
||
|
||
@pytest.mark.unit | ||
def test_ubc_type(): | ||
assert TfidfEmbedding().type == "TfidfEmbedding" | ||
|
||
|
||
@pytest.mark.unit | ||
def test_serialization(): | ||
embedder = TfidfEmbedding() | ||
assert ( | ||
embedder.id == TfidfEmbedding.model_validate_json(embedder.model_dump_json()).id | ||
) | ||
|
||
|
||
@pytest.mark.unit | ||
def test_fit_transform(): | ||
embedder = TfidfEmbedding() | ||
documents = ["test", "test1", "test2"] | ||
embedder.fit_transform(documents) | ||
assert documents == embedder.extract_features() | ||
|
||
|
||
@pytest.mark.unit | ||
def test_infer_vector(): | ||
embedder = TfidfEmbedding() | ||
documents = ["test", "test1", "test2"] | ||
embedder.fit_transform(documents) | ||
assert embedder.infer_vector("hi", documents).value == [1.0, 0.0, 0.0, 0.0] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters