diff --git a/pkgs/core/swarmauri_core/ComponentBase.py b/pkgs/core/swarmauri_core/ComponentBase.py index df3feb1a..c8d4fd0c 100644 --- a/pkgs/core/swarmauri_core/ComponentBase.py +++ b/pkgs/core/swarmauri_core/ComponentBase.py @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) T = TypeVar("T", bound="ComponentBase") -class SubclassUnion(Generic[T], type): +class SubclassUnion(type): """ A generic class to create discriminated unions based on resource types. """ @@ -99,7 +99,7 @@ class ComponentBase(BaseModel): # Class-level registry mapping resource types to their type mappings TYPE_REGISTRY: ClassVar[Dict[Type['ComponentBase'], Dict[str, Type['ComponentBase']]]] = {} # Model registry mapping models to their resource types - MODEL_REGISTRY: ClassVar[Dict[Type[BaseModel], Type['ComponentBase']]] = {} + MODEL_REGISTRY: ClassVar[Dict[Type[BaseModel], List[Type['ComponentBase']]]] = {} _lock: ClassVar[Lock] = Lock() name: Optional[str] = None @@ -194,9 +194,9 @@ def decorator(model_cls: Type[BaseModel]): if model_cls not in cls.MODEL_REGISTRY: cls.MODEL_REGISTRY[model_cls] = [] - # Inspect all fields to find SubclassUnion annotations + # Inspect all fields to find SubclassUnion annotations, including inherited fields for field_name, field in model_cls.__fields__.items(): - field_annotation = model_cls.__annotations__.get(field_name) + field_annotation = cls.get_field_annotation(model_cls, field_name) if not field_annotation: continue @@ -208,10 +208,29 @@ def decorator(model_cls: Type[BaseModel]): if resource_type not in cls.MODEL_REGISTRY[model_cls]: cls.MODEL_REGISTRY[model_cls].append(resource_type) logger.info(f"Registered model '{model_cls.__name__}' for resource '{resource_type.__name__}'") - cls.recreate_models() return model_cls return decorator + @classmethod + def get_field_annotation(cls, model_cls: Type[BaseModel], field_name: str): + """ + Retrieve the field annotation for a given field, including inherited fields. + + Parameters: + - model_cls: The Pydantic model class. + - field_name: The name of the field. + + Returns: + - The type annotation of the field, or None if not found. + """ + for base in inspect.getmro(model_cls): + if base == object: + continue + annotations = getattr(base, '__annotations__', {}) + if field_name in annotations: + return annotations[field_name] + return None + @classmethod def get_class_by_type(cls, resource_type: Type[T], type_name: str) -> Type['ComponentBase']: """ @@ -248,7 +267,6 @@ def field_contains_subclass_union(cls, field_annotation) -> bool: return any(cls.field_contains_subclass_union(arg) for arg in args) return False - @classmethod def extract_resource_types_from_field(cls, field_annotation) -> List[Type['ComponentBase']]: """ @@ -292,7 +310,6 @@ def extract_resource_types_from_field(cls, field_annotation) -> List[Type['Compo return resource_types - @classmethod def determine_new_type(cls, field_annotation, resource_type): """ @@ -367,20 +384,28 @@ def determine_new_type(cls, field_annotation, resource_type): return new_type - @classmethod def generate_models_with_fields(cls) -> Dict[Type[BaseModel], Dict[str, Any]]: """ - Automatically generate the models_with_fields dictionary based on registered models. - + Automatically generate the models_with_fields dictionary based on registered models, + including subclasses that inherit fields using SubclassUnion. + Returns: - A dictionary mapping model classes to their fields and corresponding resource types. """ models_with_fields = {} - for model_cls, resource_types in cls.MODEL_REGISTRY.items(): + + # Collect all models, including subclasses of registered models + all_models = set(cls.MODEL_REGISTRY.keys()) + for model_cls in list(all_models): + for subclass in model_cls.__subclasses__(): + if issubclass(subclass, BaseModel): + all_models.add(subclass) + + for model_cls in all_models: models_with_fields[model_cls] = {} - for field_name, field in model_cls.__fields__.items(): - field_annotation = model_cls.__annotations__.get(field_name) + for field_name in model_cls.__fields__: + field_annotation = cls.get_field_annotation(model_cls, field_name) if not field_annotation: continue @@ -396,6 +421,7 @@ def generate_models_with_fields(cls) -> Dict[Type[BaseModel], Dict[str, Any]]: return models_with_fields + @classmethod def recreate_models(cls): """ @@ -409,8 +435,11 @@ def recreate_models(cls): model_class.model_fields[field_name].annotation = new_type else: raise ValueError(f"Field '{field_name}' does not exist in model '{model_class.__name__}'") - if model_class.model_rebuild(force=True): + try: + model_class.model_rebuild(force=True) logger.debug(f"'{model_class}' has been successfully recreated.") - else: - logger.debug(f"'{model_class}' recreation has failed.") - logger.info("All models have been successfully recreated.") \ No newline at end of file + except ValidationError as ve: + logger.error(f"Validation error while rebuilding model '{model_class.__name__}': {ve}") + except Exception as e: + logger.error(f"Error while rebuilding model '{model_class.__name__}': {e}") + logger.info("All models have been successfully recreated.")