From b68938a5ac97f76055ff9d45b11a4ec6c006d8c1 Mon Sep 17 00:00:00 2001 From: MichaelDecent Date: Wed, 8 Jan 2025 11:32:28 +0100 Subject: [PATCH] feat: add NmfEmbedding plugin with model implementation and tests --- .../swarmauri_embedding_nmf/README.md | 1 + .../swarmauri_embedding_nmf/pyproject.toml | 53 +++++++++ .../swarmauri_embedding_nmf/NmfEmbedding.py | 111 ++++++++++++++++++ .../swarmauri_embedding_nmf/__init__.py | 14 +++ .../tests/unit/TfidfEmbedding_unit_test.py | 39 ++++++ pkgs/swarmauri/pyproject.toml | 1 + .../swarmauri/plugin_citizenship_registry.py | 4 +- 7 files changed, 221 insertions(+), 2 deletions(-) create mode 100644 pkgs/standards/swarmauri_embedding_nmf/README.md create mode 100644 pkgs/standards/swarmauri_embedding_nmf/pyproject.toml create mode 100644 pkgs/standards/swarmauri_embedding_nmf/swarmauri_embedding_nmf/NmfEmbedding.py create mode 100644 pkgs/standards/swarmauri_embedding_nmf/swarmauri_embedding_nmf/__init__.py create mode 100644 pkgs/standards/swarmauri_embedding_nmf/tests/unit/TfidfEmbedding_unit_test.py diff --git a/pkgs/standards/swarmauri_embedding_nmf/README.md b/pkgs/standards/swarmauri_embedding_nmf/README.md new file mode 100644 index 00000000..24ded9c4 --- /dev/null +++ b/pkgs/standards/swarmauri_embedding_nmf/README.md @@ -0,0 +1 @@ +# Swarmauri Example Plugin \ No newline at end of file diff --git a/pkgs/standards/swarmauri_embedding_nmf/pyproject.toml b/pkgs/standards/swarmauri_embedding_nmf/pyproject.toml new file mode 100644 index 00000000..fdd82c00 --- /dev/null +++ b/pkgs/standards/swarmauri_embedding_nmf/pyproject.toml @@ -0,0 +1,53 @@ +[tool.poetry] +name = "swarmauri_embedding_nmf" +version = "0.6.0.dev1" +description = "This repository includes an example of a First Class Swarmauri Example." +authors = ["Jacob Stewart "] +license = "Apache-2.0" +readme = "README.md" +repository = "http://github.com/swarmauri/swarmauri-sdk" +classifiers = [ + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12" +] + +[tool.poetry.dependencies] +python = ">=3.10,<3.13" + +# Swarmauri +swarmauri_core = { path = "../../core" } +swarmauri_base = { path = "../../base" } + +[tool.poetry.group.dev.dependencies] +flake8 = "^7.0" +pytest = "^8.0" +pytest-asyncio = ">=0.24.0" +pytest-xdist = "^3.6.1" +pytest-json-report = "^1.5.0" +python-dotenv = "*" +requests = "^2.32.3" + +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" + +[tool.pytest.ini_options] +norecursedirs = ["combined", "scripts"] + +markers = [ + "test: standard test", + "unit: Unit tests", + "integration: Integration tests", + "acceptance: Acceptance tests", + "experimental: Experimental tests" +] +log_cli = true +log_cli_level = "INFO" +log_cli_format = "%(asctime)s [%(levelname)s] %(message)s" +log_cli_date_format = "%Y-%m-%d %H:%M:%S" +asyncio_default_fixture_loop_scope = "function" + +[tool.poetry.plugins."swarmauri.embeddings"] +NmfEmbedding = "swarmauri_embedding_nmf:NmfEmbedding" diff --git a/pkgs/standards/swarmauri_embedding_nmf/swarmauri_embedding_nmf/NmfEmbedding.py b/pkgs/standards/swarmauri_embedding_nmf/swarmauri_embedding_nmf/NmfEmbedding.py new file mode 100644 index 00000000..532de021 --- /dev/null +++ b/pkgs/standards/swarmauri_embedding_nmf/swarmauri_embedding_nmf/NmfEmbedding.py @@ -0,0 +1,111 @@ +import joblib +from sklearn.decomposition import NMF +from sklearn.feature_extraction.text import TfidfVectorizer + +from typing import List, Any, Literal +from pydantic import PrivateAttr +from swarmauri_standard.vectors.Vector import Vector +from swarmauri_base.embeddings.EmbeddingBase import EmbeddingBase +from swarmauri_core.ComponentBase import ComponentBase + + +@ComponentBase.register_type(EmbeddingBase, "NmfEmbedding") +class NmfEmbedding(EmbeddingBase): + n_components: int = 10 + _tfidf_vectorizer = PrivateAttr() + _model = PrivateAttr() + feature_names: List[Any] = [] + + type: Literal["NmfEmbedding"] = "NmfEmbedding" + + def __init__(self, **kwargs): + + super().__init__(**kwargs) + # Initialize TF-IDF Vectorizer + self._tfidf_vectorizer = TfidfVectorizer() + # Initialize NMF with the desired number of components + self._model = NMF(n_components=self.n_components) + + def fit(self, data): + """ + Fit the NMF model to data. + + Args: + data (Union[str, Any]): The text data to fit. + """ + # Transform data into TF-IDF matrix + tfidf_matrix = self._tfidf_vectorizer.fit_transform(data) + # Fit the NMF model + self._model.fit(tfidf_matrix) + # Store feature names + self.feature_names = self._tfidf_vectorizer.get_feature_names_out() + + def transform(self, data): + """ + Transform new data into NMF feature space. + + Args: + data (Union[str, Any]): Text data to transform. + + Returns: + List[IVector]: A list of vectors representing the transformed data. + """ + # Transform data into TF-IDF matrix + tfidf_matrix = self._tfidf_vectorizer.transform(data) + # Transform TF-IDF matrix into NMF space + nmf_features = self._model.transform(tfidf_matrix) + + # Wrap NMF features in SimpleVector instances and return + return [Vector(value=features.tolist()) for features in nmf_features] + + def fit_transform(self, data): + """ + Fit the model to data and then transform it. + + Args: + data (Union[str, Any]): The text data to fit and transform. + + Returns: + List[IVector]: A list of vectors representing the fitted and transformed data. + """ + self.fit(data) + return self.transform(data) + + def infer_vector(self, data): + """ + Convenience method for transforming a single data point. + + Args: + data (Union[str, Any]): Single text data to transform. + + Returns: + IVector: A vector representing the transformed single data point. + """ + return self.transform([data])[0] + + def extract_features(self): + """ + Extract the feature names from the TF-IDF vectorizer. + + Returns: + The feature names. + """ + return self.feature_names.tolist() + + def save_model(self, path: str) -> None: + """ + Saves the NMF model and TF-IDF vectorizer using joblib. + """ + # It might be necessary to save both tfidf_vectorizer and model + # Consider using a directory for 'path' or appended identifiers for each model file + joblib.dump(self._tfidf_vectorizer, f"{path}_tfidf.joblib") + joblib.dump(self._model, f"{path}_nmf.joblib") + + def load_model(self, path: str) -> None: + """ + Loads the NMF model and TF-IDF vectorizer from paths using joblib. + """ + self._tfidf_vectorizer = joblib.load(f"{path}_tfidf.joblib") + self._model = joblib.load(f"{path}_nmf.joblib") + # Dependending on your implementation, you might need to refresh the feature_names + self.feature_names = self._tfidf_vectorizer.get_feature_names_out() diff --git a/pkgs/standards/swarmauri_embedding_nmf/swarmauri_embedding_nmf/__init__.py b/pkgs/standards/swarmauri_embedding_nmf/swarmauri_embedding_nmf/__init__.py new file mode 100644 index 00000000..3e6e9219 --- /dev/null +++ b/pkgs/standards/swarmauri_embedding_nmf/swarmauri_embedding_nmf/__init__.py @@ -0,0 +1,14 @@ +from .NmfEmbedding import NmfEmbedding + +__version__ = "0.6.0.dev26" +__long_desc__ = """ + +# Swarmauri Nmf Embedding Plugin + +This repository includes an Nmf Embedding of a Swarmauri Plugin. + +Visit us at: https://swarmauri.com +Follow us at: https://github.com/swarmauri +Star us at: https://github.com/swarmauri/swarmauri-sdk + +""" diff --git a/pkgs/standards/swarmauri_embedding_nmf/tests/unit/TfidfEmbedding_unit_test.py b/pkgs/standards/swarmauri_embedding_nmf/tests/unit/TfidfEmbedding_unit_test.py new file mode 100644 index 00000000..d61563ba --- /dev/null +++ b/pkgs/standards/swarmauri_embedding_nmf/tests/unit/TfidfEmbedding_unit_test.py @@ -0,0 +1,39 @@ +import pytest +from swarmauri.embeddings.concrete.TfidfEmbedding import TfidfEmbedding + + +@pytest.mark.unit +def test_ubc_resource(): + def test(): + assert TfidfEmbedding().resource == "Embedding" + + test() + + +@pytest.mark.unit +def test_ubc_type(): + assert TfidfEmbedding().type == "TfidfEmbedding" + + +@pytest.mark.unit +def test_serialization(): + embedder = TfidfEmbedding() + assert ( + embedder.id == TfidfEmbedding.model_validate_json(embedder.model_dump_json()).id + ) + + +@pytest.mark.unit +def test_fit_transform(): + embedder = TfidfEmbedding() + documents = ["test", "test1", "test2"] + embedder.fit_transform(documents) + assert documents == embedder.extract_features() + + +@pytest.mark.unit +def test_infer_vector(): + embedder = TfidfEmbedding() + documents = ["test", "test1", "test2"] + embedder.fit_transform(documents) + assert embedder.infer_vector("hi", documents).value == [1.0, 0.0, 0.0, 0.0] diff --git a/pkgs/swarmauri/pyproject.toml b/pkgs/swarmauri/pyproject.toml index f5e5f6a1..fbd1e0dd 100644 --- a/pkgs/swarmauri/pyproject.toml +++ b/pkgs/swarmauri/pyproject.toml @@ -53,6 +53,7 @@ doc2vecvectorstore = ["swarmauri_vectorstore_doc2vec"] matplotlib_tool = ["swarmauri_tool_matplotlib"] keywordextractor_parser = ["swarmauri_parser_keywordextractor"] tfidf_vectorstore = ["swarmauri_vectorstore_tfidf"] +nmf_embedding = ["swarmauri_embedding_nmf"] [tool.setuptools] namespace_packages = ["swarmauri"] diff --git a/pkgs/swarmauri/swarmauri/plugin_citizenship_registry.py b/pkgs/swarmauri/swarmauri/plugin_citizenship_registry.py index a857f4c0..c96c0ce3 100644 --- a/pkgs/swarmauri/swarmauri/plugin_citizenship_registry.py +++ b/pkgs/swarmauri/swarmauri/plugin_citizenship_registry.py @@ -81,7 +81,7 @@ class PluginCitizenshipRegistry: "swarmauri.embeddings.CohereEmbedding": "swarmauri_standard.embeddings.CohereEmbedding", "swarmauri.embeddings.GeminiEmbedding": "swarmauri_standard.embeddings.GeminiEmbedding", "swarmauri.embeddings.MistralEmbedding": "swarmauri_standard.embeddings.MistralEmbedding", - "swarmauri.embeddings.NmfEmbedding": "swarmauri_standard.embeddings.NmfEmbedding", + # "swarmauri.embeddings.NmfEmbedding": "swarmauri_standard.embeddings.NmfEmbedding", "swarmauri.embeddings.OpenAIEmbedding": "swarmauri_standard.embeddings.OpenAIEmbedding", # "swarmauri.embeddings.TfidfEmbedding": "swarmauri_standard.embeddings.TfidfEmbedding", "swarmauri.embeddings.VoyageEmbedding": "swarmauri_standard.embeddings.VoyageEmbedding", @@ -229,7 +229,7 @@ class PluginCitizenshipRegistry: "swarmauri.parsers.KeywordExtractorParser": "swarmauri_parser_keywordextractor.KeywordExtractorParser", "swarmauri.vector_stores.TfidfVectorStore": "swarmauri_vectorstore_tfidf.TfidfVectorStore", "swarmauri.embeddings.TfidfEmbedding": "swarmauri_vectorstore_tfidf.TfidfEmbedding", - + "swarmauri.embeddings.NmfEmbedding": "swarmauri_embedding_nmf.NmfEmbedding", } SECOND_CLASS_REGISTRY: Dict[str, str] = {} THIRD_CLASS_REGISTRY: Dict[str, str] = {}