Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unit tests for BERTEmbeddingParser #1112

Open
wants to merge 3 commits into
base: mono/0.6.0.dev1
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions pkgs/base/swarmauri_base/document_stores/DocumentStoreBase.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,26 @@
from abc import abstractmethod
from typing import List, Optional
from typing import List, Literal, Optional
import json

from pydantic import Field

from swarmauri_core.ComponentBase import ComponentBase, ResourceTypes
from swarmauri_core.documents.IDocument import IDocument
from swarmauri_core.document_stores.IDocumentStore import IDocumentStore

class DocumentStoreBase(IDocumentStore):

@ComponentBase.register_model()
class DocumentStoreBase(IDocumentStore, ComponentBase):
"""
Abstract base class for document stores, implementing the IDocumentStore interface.

This class provides a standard API for adding, updating, getting, and deleting documents in a store.
The specifics of storing (e.g., in a database, in-memory, or file system) are to be implemented by concrete subclasses.
"""

resource: Optional[str] = Field(default=ResourceTypes.DOCUMENT_STORE.value, frozen=True)
type: Literal["DocumentStoreBase"] = "DocumentStoreBase"

@abstractmethod
def add_document(self, document: IDocument) -> None:
"""
Expand Down Expand Up @@ -76,14 +84,14 @@ def delete_document(self, doc_id: str) -> None:
- doc_id (str): The unique identifier of the document to delete.
"""
pass

def document_count(self):
    """Return the number of documents currently held in the store."""
    return len(self.documents)

def dump(self, file_path):
    """Serialize all documents to *file_path* as a JSON array.

    Each document is represented by its ``__dict__``. Uses ``json.dump``
    (not ``json.dumps``) so the data is actually written to the file —
    ``json.dumps`` takes no file object and would silently discard it.

    Args:
        file_path: Destination path; the file is created or truncated.
    """
    with open(file_path, "w") as f:
        json.dump([each.__dict__ for each in self.documents], f, indent=4)

def load(self, file_path):
    """Restore ``self.documents`` from the JSON file at *file_path*.

    Uses ``json.load`` on the open file handle; ``json.loads`` expects a
    string and raises ``TypeError`` when given a file object.

    Args:
        file_path: Path to a JSON file previously written by :meth:`dump`.
    """
    with open(file_path, "r") as f:
        self.documents = json.load(f)
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,19 @@
from swarmauri_core.documents.IDocument import IDocument
from swarmauri_base.document_stores.DocumentStoreBase import DocumentStoreBase


class DocumentStoreRetrieveBase(DocumentStoreBase, IDocumentRetrieve):
    """Document store base that additionally supports relevance-based retrieval."""

    @abstractmethod
    def retrieve(self, query: str, top_k: int = 5) -> List[IDocument]:
        """Return the ``top_k`` documents most relevant to ``query``.

        Args:
            query: Query string used to rank documents.
            top_k: Maximum number of documents to return.

        Returns:
            List[IDocument]: The ``top_k`` best-matching documents,
            ordered by relevance.
        """
        pass
3 changes: 2 additions & 1 deletion pkgs/base/swarmauri_base/documents/DocumentBase.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Dict, Optional, Literal
import numpy as np
from pydantic import Field, ConfigDict

from swarmauri_core.documents.IDocument import IDocument
Expand All @@ -7,7 +8,7 @@

@ComponentBase.register_model()
class DocumentBase(IDocument, ComponentBase):
content: str
content: str | np.ndarray
metadata: Dict = {}
embedding: Optional[Vector] = None
model_config = ConfigDict(extra='forbid', arbitrary_types_allowed=True)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,33 @@
from typing import List, Optional
from typing import List, Literal, Optional

from pydantic import PrivateAttr
from swarmauri_core.ComponentBase import ComponentBase
from swarmauri_base.document_stores.DocumentStoreBase import DocumentStoreBase
from swarmauri_core.documents.IDocument import IDocument
from swarmauri_standard.documents.Document import Document
import redis
import json


@ComponentBase.register_type(DocumentStoreBase, "RedisDocumentStore")
class RedisDocumentStore(DocumentStoreBase):
def __init__(self, host, password, port, db):
"""Store connection details without initializing the Redis client."""
# Public fields
type: Literal["RedisDocumentStore"] = "RedisDocumentStore"

# Private attributes
_host: str = PrivateAttr()
_password: str = PrivateAttr()
_port: int = PrivateAttr()
_db: int = PrivateAttr()
_redis_client: Optional[redis.Redis] = PrivateAttr(default=None)

def __init__(
    self, host: str, password: str = "", port: int = 6379, db: int = 0, **data
):
    """Record Redis connection settings without opening a connection.

    Args:
        host: Redis server hostname or IP address.
        password: Password for AUTH; empty string means no authentication.
        port: Redis server port (default 6379).
        db: Numeric Redis database index (default 0).
        **data: Extra fields forwarded to the pydantic base class.
    """
    super().__init__(**data)
    self._host = host
    self._password = password
    self._port = port
    self._db = db
    self._redis_client = None  # Delayed initialization

@property
def redis_client(self):
Expand All @@ -25,26 +40,27 @@ def redis_client(self):
print("there")
return self._redis_client

def add_document(self, document: Document) -> None:
    """Persist a single document under its own ``id`` as the Redis key."""
    payload = document.model_dump()
    key = payload.pop("id")  # the id becomes the Redis key, not part of the value
    self.redis_client.json().set(key, "$", json.dumps(payload))

def add_documents(self, documents: List[Document]) -> None:
    """Persist many documents in a single round trip via a pipeline."""
    with self.redis_client.pipeline() as pipe:
        for doc in documents:
            payload = doc.model_dump()
            key = payload.pop("id")
            pipe.json().set(key, "$", json.dumps(payload))
        pipe.execute()

def get_document(self, doc_id: str) -> Optional[Document]:
    """Fetch the document stored under ``doc_id``; ``None`` when absent."""
    raw = self.redis_client.json().get(doc_id)
    return json.loads(raw) if raw else None

def get_all_documents(self) -> List[IDocument]:
def get_all_documents(self) -> List[Document]:
keys = self.redis_client.keys("*")
documents = []
for key in keys:
Expand All @@ -53,8 +69,10 @@ def get_all_documents(self) -> List[IDocument]:
documents.append(json.loads(document_data))
return documents

def update_document(self, doc_id: str, document: Document) -> None:
    """Overwrite the document stored under ``doc_id`` with ``document``."""
    payload = document.model_dump()
    payload.pop("id")  # the key comes from doc_id, not from the document body
    self.redis_client.json().set(doc_id, "$", json.dumps(payload))

def delete_document(self, doc_id: str) -> None:
    # Remove the key holding this document; Redis DEL is a no-op for missing keys.
    self.redis_client.delete(doc_id)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import pytest
from swarmauri_documentstore_redis.RedisDocumentStore import RedisDocumentStore
from swarmauri_standard.documents.Document import Document
from unittest.mock import MagicMock, patch
import json


@pytest.fixture(scope="module")
def redis_document_store():
    """Module-wide store instance; the Redis client itself is created lazily."""
    return RedisDocumentStore(host="localhost", password="", port=6379, db=0)


@pytest.fixture(scope="module")
def mock_redis():
    # Patch redis.Redis for the whole module so the lazily-created client in
    # RedisDocumentStore resolves to a MagicMock instead of a real connection.
    # NOTE(review): module scope means mock call counts accumulate across tests,
    # so assert_called_once_with in individual tests is order-sensitive — confirm.
    with patch("redis.Redis") as mock:
        mock_instance = MagicMock()
        mock.return_value = mock_instance
        mock_instance.json.return_value = MagicMock()
        yield mock_instance


@pytest.mark.unit
def test_ubc_resource(redis_document_store):
    # Resource kind is supplied by the DocumentStoreBase `resource` field.
    assert redis_document_store.resource == "DocumentStore"


@pytest.mark.unit
def test_ubc_typeredis_document_store(redis_document_store):
    # `type` is the Literal discriminator declared on RedisDocumentStore.
    assert redis_document_store.type == "RedisDocumentStore"


@pytest.mark.unit
def test_serialization(redis_document_store):
    """Round-tripping through JSON preserves the component id."""
    dumped = redis_document_store.model_dump_json()
    restored = RedisDocumentStore.model_validate_json(dumped)
    assert restored.id == redis_document_store.id


@pytest.mark.unit
def test_add_document(redis_document_store, mock_redis):
    """A single add issues one JSON SET keyed by the document id."""
    document = Document(id="test1", content="test content")
    redis_document_store.add_document(document)
    payload = document.model_dump()
    payload.pop("id")
    mock_redis.json.return_value.set.assert_called_once_with(
        "test1", "$", json.dumps(payload)
    )


@pytest.mark.unit
def test_get_document(redis_document_store, mock_redis):
    """A stored JSON payload is decoded and returned as a dict."""
    stored = {"content": "test content", "type": "Document"}
    mock_redis.json.return_value.get.return_value = json.dumps(stored)
    fetched = redis_document_store.get_document("test1")
    assert fetched["content"] == "test content"
    assert fetched["type"] == "Document"


@pytest.mark.unit
def test_get_all_documents(redis_document_store, mock_redis):
    """Every key reported by the store is fetched and decoded."""
    mock_redis.keys.return_value = ["doc1", "doc2"]
    payloads = [
        {"content": "content1", "type": "Document"},
        {"content": "content2", "type": "Document"},
    ]
    mock_redis.get.side_effect = [json.dumps(p) for p in payloads]
    fetched = redis_document_store.get_all_documents()
    assert len(fetched) == 2
    assert all(doc["type"] == "Document" for doc in fetched)


@pytest.mark.unit
def test_update_document(redis_document_store, mock_redis):
    """Updating re-serializes the document body under the given key."""
    doc = Document(id="test1", content="updated content")
    redis_document_store.update_document("test1", doc)
    body = doc.model_dump()
    body.pop("id")
    mock_redis.json.return_value.set.assert_called_with(
        "test1", "$", json.dumps(body)
    )


@pytest.mark.unit
def test_delete_document(redis_document_store, mock_redis):
    # Deletion forwards directly to Redis DEL with the document key.
    redis_document_store.delete_document("test1")
    mock_redis.delete.assert_called_once_with("test1")


@pytest.mark.unit
def test_add_documents(redis_document_store, mock_redis):
    """Bulk insert runs through a pipeline with one JSON SET per document."""
    batch = [
        Document(id="test1", content="content1"),
        Document(id="test2", content="content2"),
    ]
    redis_document_store.add_documents(batch)
    assert mock_redis.pipeline.called
    pipe = mock_redis.pipeline.return_value.__enter__.return_value
    assert len(pipe.json.return_value.set.mock_calls) == 2


@pytest.mark.unit
def test_document_not_found(redis_document_store, mock_redis):
    # A missing key yields None from the store rather than raising.
    mock_redis.json.return_value.get.return_value = None
    result = redis_document_store.get_document("nonexistent")
    assert result is None
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from typing import List, Union, Any, Literal

from swarmauri_core.ComponentBase import ComponentBase
from transformers import BertTokenizer, BertModel
import torch
from pydantic import PrivateAttr
from swarmauri_core.documents.IDocument import IDocument
from pydantic import ConfigDict, PrivateAttr
from swarmauri_standard.documents.Document import Document
from swarmauri_base.parsers.ParserBase import ParserBase


@ComponentBase.register_type(ParserBase, "BERTEmbeddingParser")
class BERTEmbeddingParser(ParserBase):
"""
Expand All @@ -19,6 +20,8 @@ class BERTEmbeddingParser(ParserBase):
parser_model_name: str = "bert-base-uncased"
_model: Any = PrivateAttr()
type: Literal["BERTEmbeddingParser"] = "BERTEmbeddingParser"
_tokenizer: Any = PrivateAttr()
model_config = ConfigDict(arbitrary_types_allowed=True)

def __init__(self, **kwargs):
"""
Expand All @@ -28,11 +31,11 @@ def __init__(self, **kwargs):
- model_name (str): The name of the pre-trained BERT model to use.
"""
super().__init__(**kwargs)
self.tokenizer = BertTokenizer.from_pretrained(self.parser_model_name)
self._tokenizer = BertTokenizer.from_pretrained(self.parser_model_name)
self._model = BertModel.from_pretrained(self.parser_model_name)
self._model.eval() # Set model to evaluation mode

def parse(self, data: Union[str, Any]) -> List[IDocument]:
def parse(self, data: Union[str, Any]) -> List[Document]:
"""
Tokenizes input data and generates embeddings using a BERT model.

Expand All @@ -42,9 +45,11 @@ def parse(self, data: Union[str, Any]) -> List[IDocument]:
Returns:
- List[IDocument]: A list containing a single IDocument instance with BERT embeddings as content.
"""
if data is None or not data:
raise ValueError("Input data cannot be None.")

# Tokenization
inputs = self.tokenizer(
inputs = self._tokenizer(
data, return_tensors="pt", padding=True, truncation=True, max_length=512
)

Expand All @@ -63,9 +68,7 @@ def parse(self, data: Union[str, Any]) -> List[IDocument]:

# Creating document object(s)
documents = [
Document(
doc_id=str(i), content=emb, metadata={"source": "BERTEmbeddingParser"}
)
Document(id=str(i), content=emb, metadata={"source": "BERTEmbeddingParser"})
for i, emb in enumerate(doc_embeddings)
]

Expand Down
Loading
Loading