Skip to content

Commit

Permalink
add unit tests for Doc2VecEmbedding and Doc2VecVectorStore
Browse files Browse the repository at this point in the history
  • Loading branch information
MichaelDecent committed Jan 7, 2025
1 parent 061ab84 commit 9580b05
Show file tree
Hide file tree
Showing 4 changed files with 365 additions and 217 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import List, Any, Optional, Literal
from typing import List, Any, Literal
from gensim.models.doc2vec import Doc2Vec, TaggedDocument
from swarmauri_standard.vectors.Vector import Vector
from swarmauri_base.embeddings.EmbeddingBase import EmbeddingBase
from swarmauri_core.ComponentBase import ComponentBase

@ComponentBase.register_type(EmbeddingBase, 'Doc2VecEmbedding')

@ComponentBase.register_type(EmbeddingBase, "Doc2VecEmbedding")
class Doc2VecEmbedding(EmbeddingBase):
_model: Doc2Vec
type: Literal["Doc2VecEmbedding"] = "Doc2VecEmbedding"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import pytest
from swarmauri_doc2vec_vectorstore.Doc2VecEmbedding import Doc2VecEmbedding


@pytest.mark.unit
def test_ubc_resource():
assert Doc2VecEmbedding().resource == "Embedding"


@pytest.mark.unit
def test_ubc_type():
assert Doc2VecEmbedding().type == "Doc2VecEmbedding"


@pytest.mark.unit
def test_serialization():
embedder = Doc2VecEmbedding()
assert (
embedder.id
== Doc2VecEmbedding.model_validate_json(embedder.model_dump_json()).id
)


@pytest.mark.unit
def test_fit_transform():
embedder = Doc2VecEmbedding()
documents = ["test", "cat", "banana"]
embedder.fit_transform(documents)
assert ["banana", "cat", "test"] == embedder.extract_features()
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import pytest
from swarmauri_standard.documents.Document import Document
from swarmauri_doc2vec_vectorstore.Doc2VecVectorStore import Doc2VecVectorStore


@pytest.mark.unit
def test_ubc_resource():
vs = Doc2VecVectorStore()
assert vs.resource == "VectorStore"
assert vs.embedder.resource == "Embedding"


@pytest.mark.unit
def test_ubc_type():
vs = Doc2VecVectorStore()
assert vs.type == "Doc2VecVectorStore"


@pytest.mark.unit
def test_serialization():
vs = Doc2VecVectorStore()
assert vs.id == Doc2VecVectorStore.model_validate_json(vs.model_dump_json()).id


@pytest.mark.unit
def test_top_k():
vs = Doc2VecVectorStore()
documents = [
Document(content="test"),
Document(content="test1"),
Document(content="test2"),
Document(content="test3"),
]

vs.add_documents(documents)
assert len(vs.retrieve(query="test", top_k=2)) == 2


@pytest.mark.unit
def test_adding_more_doc():
vs = Doc2VecVectorStore()
documents_batch_1 = [
Document(content="test"),
Document(content="test1"),
Document(content="test2"),
Document(content="test3"),
]
documents_batch_2 = [
Document(content="This is a test. Test number 4"),
Document(content="This is a test. Test number 5"),
Document(content="This is a test. Test number 6"),
Document(content="This is a test. Test number 7"),
]
doc_count = len(documents_batch_1) + len(documents_batch_2)

vs.add_documents(documents_batch_1)
vs.add_documents(documents_batch_2)
assert len(vs.retrieve(query="test", top_k=doc_count)) == doc_count


@pytest.mark.unit
def test_oov():
"""Test for Out Of Vocabulary (OOV) words"""
vs = Doc2VecVectorStore()
documents = [
Document(content="test"),
Document(content="test1"),
Document(content="test2"),
Document(content="test3"),
]
vs.add_documents(documents)
assert len(vs.retrieve(query="what is test 4", top_k=2)) == 2
Loading

0 comments on commit 9580b05

Please sign in to comment.