-
Notifications
You must be signed in to change notification settings - Fork 43
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1055 from MichaelDecent/doc2vec
Doc2vec pkg update
- Loading branch information
Showing
4 changed files
with
365 additions
and
217 deletions.
There are no files selected for viewing
5 changes: 3 additions & 2 deletions
5
...standards/swarmauri_doc2vec_vectorstore/swarmauri_doc2vec_vectorstore/Doc2VecEmbedding.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
29 changes: 29 additions & 0 deletions
29
pkgs/standards/swarmauri_doc2vec_vectorstore/tests/unit/Doc2VecEmbedding_unit_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import pytest | ||
from swarmauri_doc2vec_vectorstore.Doc2VecEmbedding import Doc2VecEmbedding | ||
|
||
|
||
@pytest.mark.unit
def test_ubc_resource():
    """Check the ``resource`` reported by a default Doc2VecEmbedding."""
    embedder = Doc2VecEmbedding()
    assert embedder.resource == "Embedding"
|
||
|
||
@pytest.mark.unit
def test_ubc_type():
    """Check the ``type`` reported by a default Doc2VecEmbedding."""
    embedder = Doc2VecEmbedding()
    assert embedder.type == "Doc2VecEmbedding"
|
||
|
||
@pytest.mark.unit
def test_serialization():
    """Check that ``id`` survives a JSON dump / validate round trip."""
    original = Doc2VecEmbedding()
    payload = original.model_dump_json()
    restored = Doc2VecEmbedding.model_validate_json(payload)
    assert original.id == restored.id
|
||
|
||
@pytest.mark.unit
def test_fit_transform():
    """Check that fitting on three one-word documents exposes those words as features.

    The comparison is order-insensitive: this test feeds the words in the
    order "test", "cat", "banana" and has no control over the order in which
    ``extract_features`` lists them afterwards, so pinning an exact sequence
    would make the test depend on the backend's internal vocabulary ordering
    (presumably gensim's -- confirm against the embedder implementation).
    """
    embedder = Doc2VecEmbedding()
    documents = ["test", "cat", "banana"]
    embedder.fit_transform(documents)

    # Sort before comparing so only the *set* of features is asserted.
    assert sorted(embedder.extract_features()) == ["banana", "cat", "test"]
72 changes: 72 additions & 0 deletions
72
pkgs/standards/swarmauri_doc2vec_vectorstore/tests/unit/Doc2VecVectorStore_unit_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import pytest | ||
from swarmauri_standard.documents.Document import Document | ||
from swarmauri_doc2vec_vectorstore.Doc2VecVectorStore import Doc2VecVectorStore | ||
|
||
|
||
@pytest.mark.unit
def test_ubc_resource():
    """Check the ``resource`` of the store and of its embedder."""
    store = Doc2VecVectorStore()
    assert store.resource == "VectorStore"
    assert store.embedder.resource == "Embedding"
|
||
|
||
@pytest.mark.unit
def test_ubc_type():
    """Check the ``type`` reported by a default Doc2VecVectorStore."""
    store = Doc2VecVectorStore()
    assert store.type == "Doc2VecVectorStore"
|
||
|
||
@pytest.mark.unit
def test_serialization():
    """Check that the store's ``id`` survives a JSON dump / validate round trip."""
    original = Doc2VecVectorStore()
    payload = original.model_dump_json()
    restored = Doc2VecVectorStore.model_validate_json(payload)
    assert original.id == restored.id
|
||
|
||
@pytest.mark.unit
def test_top_k():
    """Check that ``retrieve`` returns exactly ``top_k`` results."""
    store = Doc2VecVectorStore()
    contents = ("test", "test1", "test2", "test3")
    store.add_documents([Document(content=text) for text in contents])

    results = store.retrieve(query="test", top_k=2)
    assert len(results) == 2
|
||
|
||
@pytest.mark.unit
def test_adding_more_doc():
    """Check that documents from two successive ``add_documents`` calls are all retrievable."""
    store = Doc2VecVectorStore()
    first_contents = ("test", "test1", "test2", "test3")
    second_contents = (
        "This is a test. Test number 4",
        "This is a test. Test number 5",
        "This is a test. Test number 6",
        "This is a test. Test number 7",
    )
    first_batch = [Document(content=text) for text in first_contents]
    second_batch = [Document(content=text) for text in second_contents]
    expected_total = len(first_batch) + len(second_batch)

    # Add the batches in two separate calls, then ask for as many results as
    # there are documents in total.
    store.add_documents(first_batch)
    store.add_documents(second_batch)

    results = store.retrieve(query="test", top_k=expected_total)
    assert len(results) == expected_total
|
||
|
||
@pytest.mark.unit
def test_oov():
    """Check retrieval when the query has Out Of Vocabulary (OOV) words.

    The query includes words ("what", "is", "4") that appear in none of the
    stored documents; ``retrieve`` must still return ``top_k`` results.
    """
    store = Doc2VecVectorStore()
    contents = ("test", "test1", "test2", "test3")
    store.add_documents([Document(content=text) for text in contents])

    results = store.retrieve(query="what is test 4", top_k=2)
    assert len(results) == 2
Oops, something went wrong.