Skip to content

Commit

Permalink
feat: add NmfEmbedding plugin with model implementation and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
MichaelDecent committed Jan 8, 2025
1 parent cf4267c commit b68938a
Show file tree
Hide file tree
Showing 7 changed files with 221 additions and 2 deletions.
1 change: 1 addition & 0 deletions pkgs/standards/swarmauri_embedding_nmf/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Swarmauri Example Plugin
53 changes: 53 additions & 0 deletions pkgs/standards/swarmauri_embedding_nmf/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
[tool.poetry]
name = "swarmauri_embedding_nmf"
version = "0.6.0.dev1"
description = "This repository includes an example of a First Class Swarmauri Example."
authors = ["Jacob Stewart <[email protected]>"]
license = "Apache-2.0"
readme = "README.md"
repository = "http://github.com/swarmauri/swarmauri-sdk"
classifiers = [
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12"
]

[tool.poetry.dependencies]
python = ">=3.10,<3.13"

# Swarmauri
swarmauri_core = { path = "../../core" }
swarmauri_base = { path = "../../base" }

[tool.poetry.group.dev.dependencies]
flake8 = "^7.0"
pytest = "^8.0"
pytest-asyncio = ">=0.24.0"
pytest-xdist = "^3.6.1"
pytest-json-report = "^1.5.0"
python-dotenv = "*"
requests = "^2.32.3"

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.pytest.ini_options]
norecursedirs = ["combined", "scripts"]

markers = [
"test: standard test",
"unit: Unit tests",
"integration: Integration tests",
"acceptance: Acceptance tests",
"experimental: Experimental tests"
]
log_cli = true
log_cli_level = "INFO"
log_cli_format = "%(asctime)s [%(levelname)s] %(message)s"
log_cli_date_format = "%Y-%m-%d %H:%M:%S"
asyncio_default_fixture_loop_scope = "function"

[tool.poetry.plugins."swarmauri.embeddings"]
NmfEmbedding = "swarmauri_embedding_nmf:NmfEmbedding"
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import joblib
from sklearn.decomposition import NMF
from sklearn.feature_extraction.text import TfidfVectorizer

from typing import List, Any, Literal
from pydantic import PrivateAttr
from swarmauri_standard.vectors.Vector import Vector
from swarmauri_base.embeddings.EmbeddingBase import EmbeddingBase
from swarmauri_core.ComponentBase import ComponentBase


@ComponentBase.register_type(EmbeddingBase, "NmfEmbedding")
class NmfEmbedding(EmbeddingBase):
n_components: int = 10
_tfidf_vectorizer = PrivateAttr()
_model = PrivateAttr()
feature_names: List[Any] = []

type: Literal["NmfEmbedding"] = "NmfEmbedding"

def __init__(self, **kwargs):

super().__init__(**kwargs)
# Initialize TF-IDF Vectorizer
self._tfidf_vectorizer = TfidfVectorizer()
# Initialize NMF with the desired number of components
self._model = NMF(n_components=self.n_components)

def fit(self, data):
"""
Fit the NMF model to data.
Args:
data (Union[str, Any]): The text data to fit.
"""
# Transform data into TF-IDF matrix
tfidf_matrix = self._tfidf_vectorizer.fit_transform(data)
# Fit the NMF model
self._model.fit(tfidf_matrix)
# Store feature names
self.feature_names = self._tfidf_vectorizer.get_feature_names_out()

def transform(self, data):
"""
Transform new data into NMF feature space.
Args:
data (Union[str, Any]): Text data to transform.
Returns:
List[IVector]: A list of vectors representing the transformed data.
"""
# Transform data into TF-IDF matrix
tfidf_matrix = self._tfidf_vectorizer.transform(data)
# Transform TF-IDF matrix into NMF space
nmf_features = self._model.transform(tfidf_matrix)

# Wrap NMF features in SimpleVector instances and return
return [Vector(value=features.tolist()) for features in nmf_features]

def fit_transform(self, data):
"""
Fit the model to data and then transform it.
Args:
data (Union[str, Any]): The text data to fit and transform.
Returns:
List[IVector]: A list of vectors representing the fitted and transformed data.
"""
self.fit(data)
return self.transform(data)

def infer_vector(self, data):
"""
Convenience method for transforming a single data point.
Args:
data (Union[str, Any]): Single text data to transform.
Returns:
IVector: A vector representing the transformed single data point.
"""
return self.transform([data])[0]

def extract_features(self):
"""
Extract the feature names from the TF-IDF vectorizer.
Returns:
The feature names.
"""
return self.feature_names.tolist()

def save_model(self, path: str) -> None:
"""
Saves the NMF model and TF-IDF vectorizer using joblib.
"""
# It might be necessary to save both tfidf_vectorizer and model
# Consider using a directory for 'path' or appended identifiers for each model file
joblib.dump(self._tfidf_vectorizer, f"{path}_tfidf.joblib")
joblib.dump(self._model, f"{path}_nmf.joblib")

def load_model(self, path: str) -> None:
"""
Loads the NMF model and TF-IDF vectorizer from paths using joblib.
"""
self._tfidf_vectorizer = joblib.load(f"{path}_tfidf.joblib")
self._model = joblib.load(f"{path}_nmf.joblib")
# Dependending on your implementation, you might need to refresh the feature_names
self.feature_names = self._tfidf_vectorizer.get_feature_names_out()
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from .NmfEmbedding import NmfEmbedding

__version__ = "0.6.0.dev26"
__long_desc__ = """
# Swarmauri Nmf Embedding Plugin
This repository includes an Nmf Embedding of a Swarmauri Plugin.
Visit us at: https://swarmauri.com
Follow us at: https://github.com/swarmauri
Star us at: https://github.com/swarmauri/swarmauri-sdk
"""
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import pytest
from swarmauri.embeddings.concrete.TfidfEmbedding import TfidfEmbedding


@pytest.mark.unit
def test_ubc_resource():
def test():
assert TfidfEmbedding().resource == "Embedding"

test()


@pytest.mark.unit
def test_ubc_type():
assert TfidfEmbedding().type == "TfidfEmbedding"


@pytest.mark.unit
def test_serialization():
embedder = TfidfEmbedding()
assert (
embedder.id == TfidfEmbedding.model_validate_json(embedder.model_dump_json()).id
)


@pytest.mark.unit
def test_fit_transform():
embedder = TfidfEmbedding()
documents = ["test", "test1", "test2"]
embedder.fit_transform(documents)
assert documents == embedder.extract_features()


@pytest.mark.unit
def test_infer_vector():
embedder = TfidfEmbedding()
documents = ["test", "test1", "test2"]
embedder.fit_transform(documents)
assert embedder.infer_vector("hi", documents).value == [1.0, 0.0, 0.0, 0.0]
1 change: 1 addition & 0 deletions pkgs/swarmauri/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ doc2vecvectorstore = ["swarmauri_vectorstore_doc2vec"]
matplotlib_tool = ["swarmauri_tool_matplotlib"]
keywordextractor_parser = ["swarmauri_parser_keywordextractor"]
tfidf_vectorstore = ["swarmauri_vectorstore_tfidf"]
nmf_embedding = ["swarmauri_embedding_nmf"]

[tool.setuptools]
namespace_packages = ["swarmauri"]
Expand Down
4 changes: 2 additions & 2 deletions pkgs/swarmauri/swarmauri/plugin_citizenship_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class PluginCitizenshipRegistry:
"swarmauri.embeddings.CohereEmbedding": "swarmauri_standard.embeddings.CohereEmbedding",
"swarmauri.embeddings.GeminiEmbedding": "swarmauri_standard.embeddings.GeminiEmbedding",
"swarmauri.embeddings.MistralEmbedding": "swarmauri_standard.embeddings.MistralEmbedding",
"swarmauri.embeddings.NmfEmbedding": "swarmauri_standard.embeddings.NmfEmbedding",
# "swarmauri.embeddings.NmfEmbedding": "swarmauri_standard.embeddings.NmfEmbedding",
"swarmauri.embeddings.OpenAIEmbedding": "swarmauri_standard.embeddings.OpenAIEmbedding",
# "swarmauri.embeddings.TfidfEmbedding": "swarmauri_standard.embeddings.TfidfEmbedding",
"swarmauri.embeddings.VoyageEmbedding": "swarmauri_standard.embeddings.VoyageEmbedding",
Expand Down Expand Up @@ -229,7 +229,7 @@ class PluginCitizenshipRegistry:
"swarmauri.parsers.KeywordExtractorParser": "swarmauri_parser_keywordextractor.KeywordExtractorParser",
"swarmauri.vector_stores.TfidfVectorStore": "swarmauri_vectorstore_tfidf.TfidfVectorStore",
"swarmauri.embeddings.TfidfEmbedding": "swarmauri_vectorstore_tfidf.TfidfEmbedding",

"swarmauri.embeddings.NmfEmbedding": "swarmauri_embedding_nmf.NmfEmbedding",
}
SECOND_CLASS_REGISTRY: Dict[str, str] = {}
THIRD_CLASS_REGISTRY: Dict[str, str] = {}
Expand Down

0 comments on commit b68938a

Please sign in to comment.