diff --git a/pkgs/base/swarmauri_base/document_stores/DocumentStoreBase.py b/pkgs/base/swarmauri_base/document_stores/DocumentStoreBase.py index f2710a51..5b463c82 100644 --- a/pkgs/base/swarmauri_base/document_stores/DocumentStoreBase.py +++ b/pkgs/base/swarmauri_base/document_stores/DocumentStoreBase.py @@ -1,11 +1,16 @@ from abc import abstractmethod -from typing import List, Optional +from typing import List, Literal, Optional import json +from pydantic import Field + +from swarmauri_core.ComponentBase import ComponentBase, ResourceTypes from swarmauri_core.documents.IDocument import IDocument from swarmauri_core.document_stores.IDocumentStore import IDocumentStore -class DocumentStoreBase(IDocumentStore): + +@ComponentBase.register_model() +class DocumentStoreBase(IDocumentStore, ComponentBase): """ Abstract base class for document stores, implementing the IDocumentStore interface. @@ -13,6 +18,9 @@ class DocumentStoreBase(IDocumentStore): The specifics of storing (e.g., in a database, in-memory, or file system) are to be implemented by concrete subclasses. """ + resource: Optional[str] = Field(default=ResourceTypes.DOCUMENT_STORE.value, frozen=True) + type: Literal["DocumentStoreBase"] = "DocumentStoreBase" + @abstractmethod def add_document(self, document: IDocument) -> None: """ @@ -76,14 +84,14 @@ def delete_document(self, doc_id: str) -> None: - doc_id (str): The unique identifier of the document to delete. """ pass - + def document_count(self): return len(self.documents) - + def dump(self, file_path): - with open(file_path, 'w') as f: + with open(file_path, "w") as f: - json.dumps([each.__dict__ for each in self.documents], f, indent=4) + json.dump([each.__dict__ for each in self.documents], f, indent=4) - + def load(self, file_path): - with open(file_path, 'r') as f: - self.documents = json.loads(f) \ No newline at end of file + with open(file_path, "r") as f: + self.documents = json.load(f) diff --git a/pkgs/base/swarmauri_base/document_stores/DocumentStoreRetrieveBase.py b/pkgs/base/swarmauri_base/document_stores/DocumentStoreRetrieveBase.py index aa94b172..e1e9a2f3 100644 --- a/pkgs/base/swarmauri_base/document_stores/DocumentStoreRetrieveBase.py +++ b/pkgs/base/swarmauri_base/document_stores/DocumentStoreRetrieveBase.py @@ -5,18 +5,19 @@ from swarmauri_core.documents.IDocument import IDocument from swarmauri_base.document_stores.DocumentStoreBase import DocumentStoreBase + class DocumentStoreRetrieveBase(DocumentStoreBase, IDocumentRetrieve): @abstractmethod def retrieve(self, query: str, top_k: int = 5) -> List[IDocument]: """ Retrieve the top_k most relevant documents based on the given query. - + Args: query (str): The query string used for document retrieval. top_k (int): The number of top relevant documents to retrieve. - + Returns: List[IDocument]: A list of the top_k most relevant documents.
""" - pass \ No newline at end of file + pass diff --git a/pkgs/base/swarmauri_base/documents/DocumentBase.py b/pkgs/base/swarmauri_base/documents/DocumentBase.py index d44e68c7..0e10ffa6 100644 --- a/pkgs/base/swarmauri_base/documents/DocumentBase.py +++ b/pkgs/base/swarmauri_base/documents/DocumentBase.py @@ -1,4 +1,5 @@ from typing import Dict, Optional, Literal +import numpy as np from pydantic import Field, ConfigDict from swarmauri_core.documents.IDocument import IDocument @@ -7,7 +8,7 @@ @ComponentBase.register_model() class DocumentBase(IDocument, ComponentBase): - content: str + content: str | np.ndarray metadata: Dict = {} embedding: Optional[Vector] = None model_config = ConfigDict(extra='forbid', arbitrary_types_allowed=True) diff --git a/pkgs/community/swarmauri_documentstore_redis/swarmauri_documentstore_redis/RedisDocumentStore.py b/pkgs/community/swarmauri_documentstore_redis/swarmauri_documentstore_redis/RedisDocumentStore.py index f2f84a88..4bc46546 100644 --- a/pkgs/community/swarmauri_documentstore_redis/swarmauri_documentstore_redis/RedisDocumentStore.py +++ b/pkgs/community/swarmauri_documentstore_redis/swarmauri_documentstore_redis/RedisDocumentStore.py @@ -1,18 +1,33 @@ -from typing import List, Optional +from typing import List, Literal, Optional + +from pydantic import PrivateAttr +from swarmauri_core.ComponentBase import ComponentBase from swarmauri_base.document_stores.DocumentStoreBase import DocumentStoreBase -from swarmauri_core.documents.IDocument import IDocument +from swarmauri_standard.documents.Document import Document import redis import json +@ComponentBase.register_type(DocumentStoreBase, "RedisDocumentStore") class RedisDocumentStore(DocumentStoreBase): - def __init__(self, host, password, port, db): - """Store connection details without initializing the Redis client.""" + # Public fields + type: Literal["RedisDocumentStore"] = "RedisDocumentStore" + + # Private attributes + _host: str = PrivateAttr() + _password: str = PrivateAttr() + _port: int = PrivateAttr() + _db: int = PrivateAttr() + _redis_client: Optional[redis.Redis] = PrivateAttr(default=None) + + def __init__( + self, host: str, password: str = "", port: int = 6379, db: int = 0, **data + ): + super().__init__(**data) self._host = host self._password = password self._port = port self._db = db - self._redis_client = None # Delayed initialization @property def redis_client(self): @@ -25,26 +40,27 @@ def redis_client(self): print("there") return self._redis_client - def add_document(self, document: IDocument) -> None: + def add_document(self, document: Document) -> None: - data = document.as_dict() - doc_id = data["id"] - del data["id"] + data = document.model_dump() + doc_id = data.pop("id") # Remove and get id self.redis_client.json().set(doc_id, "$", json.dumps(data)) - def add_documents(self, documents: List[IDocument]) -> None: + def add_documents(self, documents: List[Document]) -> None: with self.redis_client.pipeline() as pipe: for document in documents: - pipe.set(document.doc_id, document) + data = document.model_dump() + doc_id = data.pop("id") + pipe.json().set(doc_id, "$", json.dumps(data)) pipe.execute() - def get_document(self, doc_id: str) -> Optional[IDocument]: + def get_document(self, doc_id: str) -> Optional[Document]: result = self.redis_client.json().get(doc_id) if result: return json.loads(result) return None - def get_all_documents(self) -> List[IDocument]: + def get_all_documents(self) -> List[Document]: keys = self.redis_client.keys("*") documents = [] for key in 
keys: @@ -53,8 +69,10 @@ def get_all_documents(self) -> List[IDocument]: documents.append(json.loads(document_data)) return documents - def update_document(self, doc_id: str, updated_document: IDocument) -> None: - self.add_document(updated_document) + def update_document(self, doc_id: str, document: Document) -> None: + data = document.model_dump() + data.pop("id") # Remove id from data + self.redis_client.json().set(doc_id, "$", json.dumps(data)) def delete_document(self, doc_id: str) -> None: self.redis_client.delete(doc_id) diff --git a/pkgs/community/swarmauri_documentstore_redis/tests/unit/RedisDocumentStore_unit_test.py b/pkgs/community/swarmauri_documentstore_redis/tests/unit/RedisDocumentStore_unit_test.py new file mode 100644 index 00000000..faa58e74 --- /dev/null +++ b/pkgs/community/swarmauri_documentstore_redis/tests/unit/RedisDocumentStore_unit_test.py @@ -0,0 +1,108 @@ +import pytest +from swarmauri_documentstore_redis.RedisDocumentStore import RedisDocumentStore +from swarmauri_standard.documents.Document import Document +from unittest.mock import MagicMock, patch +import json + + +@pytest.fixture(scope="module") +def redis_document_store(): + return RedisDocumentStore("localhost", "", 6379, 0) + + +@pytest.fixture(scope="module") +def mock_redis(): + with patch("redis.Redis") as mock: + mock_instance = MagicMock() + mock.return_value = mock_instance + mock_instance.json.return_value = MagicMock() + yield mock_instance + + +@pytest.mark.unit +def test_ubc_resource(redis_document_store): + assert redis_document_store.resource == "DocumentStore" + + +@pytest.mark.unit +def test_ubc_typeredis_document_store(redis_document_store): + assert redis_document_store.type == "RedisDocumentStore" + + +@pytest.mark.unit +def test_serialization(redis_document_store): + assert ( + redis_document_store.id + == RedisDocumentStore.model_validate_json( + redis_document_store.model_dump_json() + ).id + ) + + +@pytest.mark.unit +def test_add_document(redis_document_store, mock_redis): + doc = Document(id="test1", content="test content") + redis_document_store.add_document(doc) + expected_data = doc.model_dump() + expected_data.pop("id") + mock_redis.json.return_value.set.assert_called_once_with( + "test1", "$", json.dumps(expected_data) + ) + + +@pytest.mark.unit +def test_get_document(redis_document_store, mock_redis): + mock_redis.json.return_value.get.return_value = json.dumps( + {"content": "test content", "type": "Document"} + ) + doc = redis_document_store.get_document("test1") + assert doc["content"] == "test content" + assert doc["type"] == "Document" + + +@pytest.mark.unit +def test_get_all_documents(redis_document_store, mock_redis): + mock_redis.keys.return_value = ["doc1", "doc2"] + mock_redis.get.side_effect = [ + json.dumps({"content": "content1", "type": "Document"}), + json.dumps({"content": "content2", "type": "Document"}), + ] + docs = redis_document_store.get_all_documents() + assert len(docs) == 2 + assert all(doc["type"] == "Document" for doc in docs) + + +@pytest.mark.unit +def test_update_document(redis_document_store, mock_redis): + updated_doc = Document(id="test1", content="updated content") + redis_document_store.update_document("test1", updated_doc) + expected_data = updated_doc.model_dump() + expected_data.pop("id") + mock_redis.json.return_value.set.assert_called_with( + "test1", "$", json.dumps(expected_data) + ) + + +@pytest.mark.unit +def test_delete_document(redis_document_store, mock_redis): + redis_document_store.delete_document("test1") + 
mock_redis.delete.assert_called_once_with("test1") + + +@pytest.mark.unit +def test_add_documents(redis_document_store, mock_redis): + docs = [ + Document(id="test1", content="content1"), + Document(id="test2", content="content2"), + ] + redis_document_store.add_documents(docs) + assert mock_redis.pipeline.called + pipeline = mock_redis.pipeline.return_value.__enter__.return_value + assert len(pipeline.json.return_value.set.mock_calls) == 2 + + +@pytest.mark.unit +def test_document_not_found(redis_document_store, mock_redis): + mock_redis.json.return_value.get.return_value = None + result = redis_document_store.get_document("nonexistent") + assert result is None diff --git a/pkgs/community/swarmauri_parser_bertembedding/swarmauri_parser_bertembedding/BERTEmbeddingParser.py b/pkgs/community/swarmauri_parser_bertembedding/swarmauri_parser_bertembedding/BERTEmbeddingParser.py index 18468023..662602df 100644 --- a/pkgs/community/swarmauri_parser_bertembedding/swarmauri_parser_bertembedding/BERTEmbeddingParser.py +++ b/pkgs/community/swarmauri_parser_bertembedding/swarmauri_parser_bertembedding/BERTEmbeddingParser.py @@ -1,12 +1,13 @@ from typing import List, Union, Any, Literal + from swarmauri_core.ComponentBase import ComponentBase from transformers import BertTokenizer, BertModel import torch -from pydantic import PrivateAttr -from swarmauri_core.documents.IDocument import IDocument +from pydantic import ConfigDict, PrivateAttr from swarmauri_standard.documents.Document import Document from swarmauri_base.parsers.ParserBase import ParserBase + @ComponentBase.register_type(ParserBase, "BERTEmbeddingParser") class BERTEmbeddingParser(ParserBase): """ @@ -19,6 +20,8 @@ class BERTEmbeddingParser(ParserBase): parser_model_name: str = "bert-base-uncased" _model: Any = PrivateAttr() type: Literal["BERTEmbeddingParser"] = "BERTEmbeddingParser" + _tokenizer: Any = PrivateAttr() + model_config = ConfigDict(arbitrary_types_allowed=True) def __init__(self, **kwargs): """ @@ -28,11 +31,11 @@ def __init__(self, **kwargs): - model_name (str): The name of the pre-trained BERT model to use. """ super().__init__(**kwargs) - self.tokenizer = BertTokenizer.from_pretrained(self.parser_model_name) + self._tokenizer = BertTokenizer.from_pretrained(self.parser_model_name) self._model = BertModel.from_pretrained(self.parser_model_name) self._model.eval() # Set model to evaluation mode - def parse(self, data: Union[str, Any]) -> List[IDocument]: + def parse(self, data: Union[str, Any]) -> List[Document]: """ Tokenizes input data and generates embeddings using a BERT model. @@ -42,9 +45,11 @@ def parse(self, data: Union[str, Any]) -> List[IDocument]: Returns: - List[IDocument]: A list containing a single IDocument instance with BERT embeddings as content. 
""" + if data is None or not data: + raise ValueError("Input data cannot be None.") # Tokenization - inputs = self.tokenizer( + inputs = self._tokenizer( data, return_tensors="pt", padding=True, truncation=True, max_length=512 ) @@ -63,9 +68,7 @@ def parse(self, data: Union[str, Any]) -> List[IDocument]: # Creating document object(s) documents = [ - Document( - doc_id=str(i), content=emb, metadata={"source": "BERTEmbeddingParser"} - ) + Document(id=str(i), content=emb, metadata={"source": "BERTEmbeddingParser"}) for i, emb in enumerate(doc_embeddings) ] diff --git a/pkgs/community/swarmauri_parser_bertembedding/tests/unit/BERTEmbeddingParser_unit_test.py b/pkgs/community/swarmauri_parser_bertembedding/tests/unit/BERTEmbeddingParser_unit_test.py new file mode 100644 index 00000000..2653ca7d --- /dev/null +++ b/pkgs/community/swarmauri_parser_bertembedding/tests/unit/BERTEmbeddingParser_unit_test.py @@ -0,0 +1,67 @@ +import pytest +from swarmauri_parser_bertembedding.BERTEmbeddingParser import ( + BERTEmbeddingParser as Parser, +) +import numpy as np + + +@pytest.fixture(scope="module") +def parser(): + """Fixture to provide a parser instance for tests.""" + return Parser() + + +@pytest.mark.unit +def test_ubc_resource(parser): + assert parser.resource == "Parser" + + +@pytest.mark.unit +def test_ubc_type(parser): + assert parser.type == "BERTEmbeddingParser" + + +@pytest.mark.unit +def test_serialization(parser): + assert parser.id == Parser.model_validate_json(parser.model_dump_json()).id + + +@pytest.mark.unit +def test_parse(parser): + # Test basic text parsing + text = "This is a test sentence." + documents = parser.parse(text) + + # Verify basic properties + assert len(documents) == 1 + doc = documents[0] + assert doc.resource == "Document" + assert isinstance(doc.content, np.ndarray) + assert doc.content.shape == (768,) # BERT base embedding size + assert doc.metadata["source"] == "BERTEmbeddingParser" + assert isinstance(doc.id, str) + + +@pytest.mark.unit +def test_parse_empty(parser): + # Test empty input + with pytest.raises(ValueError): + parser.parse("") + + +@pytest.mark.unit +def test_parse_long_text(parser): + # Test text longer than 512 tokens + long_text = " ".join(["word"] * 1000) + documents = parser.parse(long_text) + assert len(documents) == 1 + assert documents[0].content.shape == (768,) + + +@pytest.mark.unit +def test_parse_special_chars(parser): + # Test text with special characters + special_text = "Hello! 
@#$%^&* World 123" + documents = parser.parse(special_text) + assert len(documents) == 1 + assert documents[0].content.shape == (768,) diff --git a/pkgs/community/swarmauri_parser_entityrecognition/tests/unit/EntityRecognitionParser_unit_test.py b/pkgs/community/swarmauri_parser_entityrecognition/tests/unit/EntityRecognitionParser_unit_test.py new file mode 100644 index 00000000..0cceb87f --- /dev/null +++ b/pkgs/community/swarmauri_parser_entityrecognition/tests/unit/EntityRecognitionParser_unit_test.py @@ -0,0 +1,36 @@ +import pytest +from swarmauri_parser_entityrecognition.EntityRecognitionParser import ( + EntityRecognitionParser as Parser, +) + + +@pytest.fixture(scope="module") +def parser(): + """Fixture to provide a parser instance for tests.""" + return Parser() + + +@pytest.mark.unit +def test_ubc_resource(parser): + assert parser.resource == "Parser" + + +@pytest.mark.unit +def test_ubc_type(parser): + assert parser.type == "EntityRecognitionParser" + + +@pytest.mark.unit +def test_serialization(parser): + assert parser.id == Parser.model_validate_json(parser.model_dump_json()).id + + +@pytest.mark.unit +def test_parse(parser): + try: + documents = parser.parse("One more large chapula please.") + assert documents[0].resource == "Document" + assert documents[0].content == "One more large chapula please." + assert documents[0].metadata["noun_phrases"] == ["large chapula"] + except Exception as e: + pytest.fail(f"Parser failed with error: {str(e)}") diff --git a/pkgs/core/swarmauri_core/ComponentBase.py b/pkgs/core/swarmauri_core/ComponentBase.py index b1d4da75..dc386efd 100644 --- a/pkgs/core/swarmauri_core/ComponentBase.py +++ b/pkgs/core/swarmauri_core/ComponentBase.py @@ -30,17 +30,21 @@ logger = logging.getLogger(__name__) T = TypeVar("T", bound="ComponentBase") + class ResourceType: """ Metadata class to hold resource type information for Annotated fields. """ - def __init__(self, resource_type: Type['ComponentBase']): + + def __init__(self, resource_type: Type["ComponentBase"]): self.resource_type = resource_type + class SubclassUnion(type): """ A generic class to create discriminated unions based on resource types. """ + def __class_getitem__(cls, resource_type: Type[T]) -> type: """ Allows usage of SubclassUnion[ResourceType] to get the corresponding discriminated Union. @@ -51,13 +55,20 @@ def __class_getitem__(cls, resource_type: Type[T]) -> type: Returns: - An Annotated Union of all subclasses registered under the resource_type, with 'type' as the discriminator. """ - registered_classes = list(ComponentBase.TYPE_REGISTRY.get(resource_type, {}).values()) + registered_classes = list( + ComponentBase.TYPE_REGISTRY.get(resource_type, {}).values() + ) if not registered_classes: - logger.debug(f"No subclasses registered for resource type '{resource_type.__name__}'. Using 'Any' as a placeholder.") + logger.debug( + f"No subclasses registered for resource type '{resource_type.__name__}'. Using 'Any' as a placeholder."
+ ) return Annotated[Any, Field(...), ResourceType(resource_type)] else: union_type = Union[tuple(registered_classes)] - return Annotated[union_type, Field(discriminator='type'), ResourceType(resource_type)] + return Annotated[ + union_type, Field(discriminator="type"), ResourceType(resource_type) + ] + class ResourceTypes(Enum): UNIVERSAL_BASE = "ComponentBase" @@ -98,17 +109,22 @@ class ResourceTypes(Enum): CONTROL_PANEL = "ControlPanel" TASK_MGMT_STRATEGY = "TaskMgmtStrategy" + def generate_id() -> str: return str(uuid4()) + class ComponentBase(BaseModel): """ Base class for all components. """ + # Class-level registry mapping resource types to their type mappings - TYPE_REGISTRY: ClassVar[Dict[Type['ComponentBase'], Dict[str, Type['ComponentBase']]]] = {} + TYPE_REGISTRY: ClassVar[ + Dict[Type["ComponentBase"], Dict[str, Type["ComponentBase"]]] + ] = {} # Model registry mapping models to their resource types - MODEL_REGISTRY: ClassVar[Dict[Type[BaseModel], List[Type['ComponentBase']]]] = {} + MODEL_REGISTRY: ClassVar[Dict[Type[BaseModel], List[Type["ComponentBase"]]]] = {} _lock: ClassVar[Lock] = Lock() name: Optional[str] = None @@ -180,16 +196,22 @@ def register_type(cls, resource_type: Type[T], type_name: str): """ Decorator to register a component class with a specific type name under a resource type. """ - def decorator(subclass: Type['ComponentBase']): + + def decorator(subclass: Type["ComponentBase"]): if not issubclass(subclass, resource_type): - raise TypeError(f"Registered class '{subclass.__name__}' must be a subclass of {resource_type.__name__}") + raise TypeError( + f"Registered class '{subclass.__name__}' must be a subclass of {resource_type.__name__}" + ) if resource_type not in cls.TYPE_REGISTRY: cls.TYPE_REGISTRY[resource_type] = {} cls.TYPE_REGISTRY[resource_type][type_name] = subclass # Automatically recreate models after registering a new type cls.recreate_models() - logger.info(f"Registered type '{type_name}' for resource '{resource_type.__name__}' with subclass '{subclass.__name__}'") + logger.info( + f"Registered type '{type_name}' for resource '{resource_type.__name__}' with subclass '{subclass.__name__}'" + ) return subclass + return decorator @classmethod @@ -198,13 +220,14 @@ def register_model(cls): Decorator to register a Pydantic model by automatically detecting resource types from fields that use SubclassUnion. 
""" + def decorator(model_cls: Type[BaseModel]): # Initialize list if not present if model_cls not in cls.MODEL_REGISTRY: cls.MODEL_REGISTRY[model_cls] = [] # Inspect all fields to find SubclassUnion annotations - for field_name, field in model_cls.__fields__.items(): + for field_name, field in model_cls.model_fields.items(): field_annotation = model_cls.__annotations__.get(field_name) if not field_annotation: continue @@ -212,17 +235,24 @@ def decorator(model_cls: Type[BaseModel]): # Check if field uses SubclassUnion if cls.field_contains_subclass_union(field_annotation): # Extract resource types from SubclassUnion - resource_types = cls.extract_resource_types_from_field(field_annotation) + resource_types = cls.extract_resource_types_from_field( + field_annotation + ) for resource_type in resource_types: if resource_type not in cls.MODEL_REGISTRY[model_cls]: cls.MODEL_REGISTRY[model_cls].append(resource_type) - logger.info(f"Registered model '{model_cls.__name__}' for resource '{resource_type.__name__}'") + logger.info( + f"Registered model '{model_cls.__name__}' for resource '{resource_type.__name__}'" + ) cls.recreate_models() return model_cls + return decorator @classmethod - def get_class_by_type(cls, resource_type: Type[T], type_name: str) -> Type['ComponentBase']: + def get_class_by_type( + cls, resource_type: Type[T], type_name: str + ) -> Type["ComponentBase"]: """ Retrieve a component class based on its resource type and type name. @@ -246,7 +276,9 @@ def field_contains_subclass_union(cls, field_annotation) -> bool: Returns: - True if SubclassUnion or ResourceType is present, False otherwise. """ - logger.debug(f"Checking if field annotation '{field_annotation}' contains a SubclassUnion or ResourceType") + logger.debug( + f"Checking if field annotation '{field_annotation}' contains a SubclassUnion or ResourceType" + ) origin = get_origin(field_annotation) args = get_args(field_annotation) @@ -255,7 +287,9 @@ def field_contains_subclass_union(cls, field_annotation) -> bool: if origin is Annotated: for arg in args: if isinstance(arg, ResourceType): - logger.debug(f"Annotated field contains ResourceType metadata for resource '{arg.resource_type.__name__}'") + logger.debug( + f"Annotated field contains ResourceType metadata for resource '{arg.resource_type.__name__}'" + ) return True if cls.field_contains_subclass_union(arg): logger.debug(f"Annotated field contains SubclassUnion in '{arg}'") @@ -265,15 +299,20 @@ def field_contains_subclass_union(cls, field_annotation) -> bool: if origin in {Union, List, Dict, Set, Tuple, Optional}: for arg in args: if cls.field_contains_subclass_union(arg): - logger.debug(f"Container field '{field_annotation}' contains SubclassUnion in its arguments") + logger.debug( + f"Container field '{field_annotation}' contains SubclassUnion in its arguments" + ) return True - logger.debug(f"Field annotation '{field_annotation}' does not contain SubclassUnion or ResourceType") + logger.debug( + f"Field annotation '{field_annotation}' does not contain SubclassUnion or ResourceType" + ) return False - @classmethod - def extract_resource_types_from_field(cls, field_annotation) -> List[Type['ComponentBase']]: + def extract_resource_types_from_field( + cls, field_annotation + ) -> List[Type["ComponentBase"]]: """ Extracts all resource types from a field annotation using SubclassUnion. @@ -283,7 +322,9 @@ def extract_resource_types_from_field(cls, field_annotation) -> List[Type['Compo Returns: - A list of resource type classes. 
""" - logger.debug(f"Extracting resource types from field annotation '{field_annotation}'") + logger.debug( + f"Extracting resource types from field annotation '{field_annotation}'" + ) resource_types = [] try: origin = get_origin(field_annotation) @@ -293,9 +334,13 @@ def extract_resource_types_from_field(cls, field_annotation) -> List[Type['Compo for arg in args[1:]: # Skip the first argument which is the main type if isinstance(arg, ResourceType): resource_types.append(arg.resource_type) - logger.debug(f"Found ResourceType metadata with resource type '{arg.resource_type.__name__}'") + logger.debug( + f"Found ResourceType metadata with resource type '{arg.resource_type.__name__}'" + ) else: - resource_types.extend(cls.extract_resource_types_from_field(arg)) + resource_types.extend( + cls.extract_resource_types_from_field(arg) + ) elif origin in {Union, List, Dict, Set, Tuple, Optional}: for arg in args: @@ -303,10 +348,11 @@ def extract_resource_types_from_field(cls, field_annotation) -> List[Type['Compo return resource_types except TypeError as e: - logger.error(f"TypeError while extracting resource types from field annotation '{field_annotation}': {e}") + logger.error( + f"TypeError while extracting resource types from field annotation '{field_annotation}': {e}" + ) return resource_types - @classmethod def determine_new_type(cls, field_annotation, resource_type): """ @@ -319,7 +365,9 @@ def determine_new_type(cls, field_annotation, resource_type): Returns: - The updated type annotation incorporating SubclassUnion. """ - logger.debug(f"Determining new type for field annotation '{field_annotation}' with resource type '{resource_type.__name__}'") + logger.debug( + f"Determining new type for field annotation '{field_annotation}' with resource type '{resource_type.__name__}'" + ) try: origin = get_origin(field_annotation) args = get_args(field_annotation) @@ -336,15 +384,21 @@ def determine_new_type(cls, field_annotation, resource_type): is_optional = True else: # Multiple non-None types, complex Union - logger.warning(f"Field annotation '{field_annotation}' has multiple non-None Union types; optionality may not be preserved correctly.") + logger.warning( + f"Field annotation '{field_annotation}' has multiple non-None Union types; optionality may not be preserved correctly." 
+ ) # Handle Annotated if origin is Annotated: base_type = args[0] - metadata = [arg for arg in args[1:] if not isinstance(arg, ResourceType)] + metadata = [ + arg for arg in args[1:] if not isinstance(arg, ResourceType) + ] # Append the new ResourceType metadata.append(ResourceType(resource_type)) - logger.debug(f"Preserving existing metadata and adding ResourceType for resource '{resource_type.__name__}'") + logger.debug( + f"Preserving existing metadata and adding ResourceType for resource '{resource_type.__name__}'" + ) field_annotation = Annotated[tuple([base_type, *metadata])] # Construct the new type with SubclassUnion and discriminated Union @@ -359,11 +413,11 @@ def determine_new_type(cls, field_annotation, resource_type): logger.debug(f"New type for field: {new_type}") return new_type except TypeError as e: - logger.error(f"TypeError while determining new type for field annotation '{field_annotation}': {e}") + logger.error( + f"TypeError while determining new type for field annotation '{field_annotation}': {e}" + ) return field_annotation # Fallback to original type if error occurs - - @classmethod def generate_models_with_fields(cls) -> Dict[Type[BaseModel], Dict[str, Any]]: """ @@ -379,35 +433,44 @@ def generate_models_with_fields(cls) -> Dict[Type[BaseModel], Dict[str, Any]]: logging.debug(f"Processing model: {model_cls.__name__}") models_with_fields[model_cls] = {} - for field_name, field in model_cls.__fields__.items(): + for field_name, field in model_cls.model_fields.items(): logging.debug(f"Processing field: {field_name}") field_annotation = model_cls.__annotations__.get(field_name) if not field_annotation: - logging.debug(f"Field {field_name} in model {model_cls.__name__} has no annotation, skipping.") + logging.debug( + f"Field {field_name} in model {model_cls.__name__} has no annotation, skipping." + ) continue # Check if SubclassUnion is used in the field type if not cls.field_contains_subclass_union(field_annotation): - logging.debug(f"Field {field_name} does not contain SubclassUnion, skipping.") + logging.debug( + f"Field {field_name} does not contain SubclassUnion, skipping." + ) continue # Only process fields that use SubclassUnion # Extract all resource types from the field - field_resource_types = cls.extract_resource_types_from_field(field_annotation) - logging.debug(f"Extracted resource types for field {field_name}: {field_resource_types}") + field_resource_types = cls.extract_resource_types_from_field( + field_annotation + ) + logging.debug( + f"Extracted resource types for field {field_name}: {field_resource_types}" + ) for resource_type in field_resource_types: new_type = cls.determine_new_type(field_annotation, resource_type) - logging.debug(f"Determined new type for resource {resource_type}: {new_type}") + logging.debug( + f"Determined new type for resource {resource_type}: {new_type}" + ) models_with_fields[model_cls][field_name] = new_type logging.info("Completed generation of models_with_fields") return models_with_fields - @classmethod - def recreate_type_models(cls) -> Dict[Type['ComponentBase'], Dict[str, Any]]: + def recreate_type_models(cls) -> Dict[Type["ComponentBase"], Dict[str, Any]]: """ Generate a mapping of component types to their fields and updated type annotations. 
@@ -423,38 +486,55 @@ def recreate_type_models(cls) -> Dict[Type['ComponentBase'], Dict[str, Any]]: logging.debug(f"Processing component class: {component_cls.__name__}") type_models_with_fields[component_cls] = {} - for field_name, field in component_cls.__fields__.items(): + for field_name, field in component_cls.model_fields.items(): logging.debug(f"Processing field: {field_name}") field_annotation = component_cls.__annotations__.get(field_name) if not field_annotation: - logging.debug(f"Field '{field_name}' in component '{component_cls.__name__}' has no annotation, skipping.") + logging.debug( + f"Field '{field_name}' in component '{component_cls.__name__}' has no annotation, skipping." + ) continue # Check if SubclassUnion is used in the field type if not cls.field_contains_subclass_union(field_annotation): - logging.debug(f"Field '{field_name}' does not contain SubclassUnion, skipping.") + logging.debug( + f"Field '{field_name}' does not contain SubclassUnion, skipping." + ) continue # Only process fields that use SubclassUnion # Extract all resource types from the field - field_resource_types = cls.extract_resource_types_from_field(field_annotation) + field_resource_types = cls.extract_resource_types_from_field( + field_annotation + ) if not field_resource_types: - logging.warning(f"No resource types extracted for field '{field_name}' in component '{component_cls.__name__}'") + logging.warning( + f"No resource types extracted for field '{field_name}' in component '{component_cls.__name__}'" + ) continue - logging.debug(f"Extracted resource types for field '{field_name}': {[rt.__name__ for rt in field_resource_types]}") + logging.debug( + f"Extracted resource types for field '{field_name}': {[rt.__name__ for rt in field_resource_types]}" + ) for resource_type_in_field in field_resource_types: try: - new_type = cls.determine_new_type(field_annotation, resource_type_in_field) - logging.debug(f"Determined new type for resource '{resource_type_in_field.__name__}': {new_type}") - type_models_with_fields[component_cls][field_name] = new_type + new_type = cls.determine_new_type( + field_annotation, resource_type_in_field + ) + logging.debug( + f"Determined new type for resource '{resource_type_in_field.__name__}': {new_type}" + ) + type_models_with_fields[component_cls][ + field_name + ] = new_type except Exception as e: - logging.error(f"Error determining new type for field '{field_name}' in component '{component_cls.__name__}': {e}") + logging.error( + f"Error determining new type for field '{field_name}' in component '{component_cls.__name__}': {e}" + ) continue # Proceed with other resource types and fields logging.info("Completed generation of type models for TYPE_REGISTRY") return type_models_with_fields - @classmethod def recreate_models(cls): @@ -468,7 +548,10 @@ def recreate_models(cls): type_models_with_fields = cls.recreate_type_models() # Combine both dictionaries - combined_models_with_fields = {**models_with_fields, **type_models_with_fields} + combined_models_with_fields = { + **models_with_fields, + **type_models_with_fields, + } for model_class, fields in combined_models_with_fields.items(): for field_name, new_type in fields.items(): @@ -476,20 +559,32 @@ def recreate_models(cls): original_type = model_class.model_fields[field_name].annotation if original_type != new_type: model_class.model_fields[field_name].annotation = new_type - logger.debug(f"Updated field '{field_name}' in model '{model_class.__name__}' from '{original_type}' to '{new_type}'") + logger.debug( + 
f"Updated field '{field_name}' in model '{model_class.__name__}' from '{original_type}' to '{new_type}'" + ) else: - logger.debug(f"No change for field '{field_name}' in model '{model_class.__name__}'") + logger.debug( + f"No change for field '{field_name}' in model '{model_class.__name__}'" + ) else: - logger.error(f"Field '{field_name}' does not exist in model '{model_class.__name__}'") + logger.error( + f"Field '{field_name}' does not exist in model '{model_class.__name__}'" + ) continue # Skip to next field try: model_class.model_rebuild(force=True) - logger.debug(f"'{model_class.__name__}' has been successfully recreated.") + logger.debug( + f"'{model_class.__name__}' has been successfully recreated." + ) except ValidationError as ve: - logger.error(f"Validation error while rebuilding model '{model_class.__name__}': {ve}") + logger.error( + f"Validation error while rebuilding model '{model_class.__name__}': {ve}" + ) except Exception as e: - logger.error(f"Error while rebuilding model '{model_class.__name__}': {e}") + logger.error( + f"Error while rebuilding model '{model_class.__name__}': {e}" + ) logger.info("All models have been successfully recreated.") @classmethod @@ -497,7 +592,7 @@ def model_validate_yaml(cls, yaml_data: str): try: # Parse YAML into a Python dictionary yaml_content = yaml.safe_load(yaml_data) - + # Convert the dictionary to JSON and validate using Pydantic return cls.model_validate_json(json.dumps(yaml_content)) except yaml.YAMLError as e: @@ -516,8 +611,13 @@ def model_dump_yaml(self, fields_to_exclude=None, api_key_placeholder=None): def process_fields(data, fields_to_exclude): if isinstance(data, dict): return { - key: (api_key_placeholder if key == "api_key" and api_key_placeholder is not None else process_fields(value, fields_to_exclude)) - for key, value in data.items() if key not in fields_to_exclude + key: ( + api_key_placeholder + if key == "api_key" and api_key_placeholder is not None + else process_fields(value, fields_to_exclude) + ) + for key, value in data.items() + if key not in fields_to_exclude } elif isinstance(data, list): return [process_fields(item, fields_to_exclude) for item in data] @@ -528,4 +628,4 @@ def process_fields(data, fields_to_exclude): filtered_data = process_fields(json_data, fields_to_exclude) # Convert the filtered data into YAML - return yaml.dump(filtered_data, default_flow_style=False) \ No newline at end of file + return yaml.dump(filtered_data, default_flow_style=False) diff --git a/pkgs/standards/swarmauri_standard/tests/unit/agents/QAAgent_unit_test.py b/pkgs/standards/swarmauri_standard/tests/unit/agents/QAAgent_unit_test.py index 52865465..1da2450c 100644 --- a/pkgs/standards/swarmauri_standard/tests/unit/agents/QAAgent_unit_test.py +++ b/pkgs/standards/swarmauri_standard/tests/unit/agents/QAAgent_unit_test.py @@ -8,37 +8,35 @@ @pytest.fixture(scope="module") -def groq_model(): +def qa_agent(): API_KEY = os.getenv("GROQ_API_KEY") if not API_KEY: pytest.skip("Skipping due to environment variable not set") + llm = GroqModel(api_key=API_KEY) - return llm + agent = QAAgent(llm=llm) + return agent @pytest.mark.unit -def test_ubc_resource(groq_model): - agent = QAAgent(llm=groq_model) - assert agent.resource == "Agent" +def test_ubc_resource(qa_agent): + assert qa_agent.resource == "Agent" @pytest.mark.unit -def test_ubc_type(groq_model): - agent = QAAgent(llm=groq_model) - assert agent.type == "QAAgent" +def test_ubc_type(qa_agent): + assert qa_agent.type == "QAAgent" @pytest.mark.unit -def test_agent_exec(groq_model): 
- agent = QAAgent(llm=groq_model) - result = agent.exec("hello") +def test_agent_exec(qa_agent): + result = qa_agent.exec("hello") assert isinstance(result, str) @pytest.mark.unit -def test_serialization(groq_model): - agent = QAAgent(llm=groq_model) - assert agent.id == QAAgent.model_validate_json(agent.model_dump_json()).id +def test_serialization(qa_agent): + assert qa_agent.id == QAAgent.model_validate_json(qa_agent.model_dump_json()).id @pytest.mark.asyncio