Skip to content

Commit

Permalink
Update ComponentBase.py
Browse files Browse the repository at this point in the history
  • Loading branch information
cobycloud committed Jan 4, 2025
1 parent 77e91f1 commit 74c8081
Showing 1 changed file with 46 additions and 17 deletions.
63 changes: 46 additions & 17 deletions pkgs/core/swarmauri_core/ComponentBase.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
logger = logging.getLogger(__name__)
T = TypeVar("T", bound="ComponentBase")

class SubclassUnion(Generic[T], type):
class SubclassUnion(type):
"""
A generic class to create discriminated unions based on resource types.
"""
Expand Down Expand Up @@ -99,7 +99,7 @@ class ComponentBase(BaseModel):
# Class-level registry mapping resource types to their type mappings
TYPE_REGISTRY: ClassVar[Dict[Type['ComponentBase'], Dict[str, Type['ComponentBase']]]] = {}
# Model registry mapping models to their resource types
MODEL_REGISTRY: ClassVar[Dict[Type[BaseModel], Type['ComponentBase']]] = {}
MODEL_REGISTRY: ClassVar[Dict[Type[BaseModel], List[Type['ComponentBase']]]] = {}
_lock: ClassVar[Lock] = Lock()

name: Optional[str] = None
Expand Down Expand Up @@ -194,9 +194,9 @@ def decorator(model_cls: Type[BaseModel]):
if model_cls not in cls.MODEL_REGISTRY:
cls.MODEL_REGISTRY[model_cls] = []

# Inspect all fields to find SubclassUnion annotations
# Inspect all fields to find SubclassUnion annotations, including inherited fields
for field_name, field in model_cls.__fields__.items():
field_annotation = model_cls.__annotations__.get(field_name)
field_annotation = cls.get_field_annotation(model_cls, field_name)
if not field_annotation:
continue

Expand All @@ -208,10 +208,29 @@ def decorator(model_cls: Type[BaseModel]):
if resource_type not in cls.MODEL_REGISTRY[model_cls]:
cls.MODEL_REGISTRY[model_cls].append(resource_type)
logger.info(f"Registered model '{model_cls.__name__}' for resource '{resource_type.__name__}'")
cls.recreate_models()
return model_cls
return decorator

@classmethod
def get_field_annotation(cls, model_cls: Type[BaseModel], field_name: str):
"""
Retrieve the field annotation for a given field, including inherited fields.
Parameters:
- model_cls: The Pydantic model class.
- field_name: The name of the field.
Returns:
- The type annotation of the field, or None if not found.
"""
for base in inspect.getmro(model_cls):
if base == object:
continue
annotations = getattr(base, '__annotations__', {})
if field_name in annotations:
return annotations[field_name]
return None

@classmethod
def get_class_by_type(cls, resource_type: Type[T], type_name: str) -> Type['ComponentBase']:
"""
Expand Down Expand Up @@ -248,7 +267,6 @@ def field_contains_subclass_union(cls, field_annotation) -> bool:
return any(cls.field_contains_subclass_union(arg) for arg in args)
return False


@classmethod
def extract_resource_types_from_field(cls, field_annotation) -> List[Type['ComponentBase']]:
"""
Expand Down Expand Up @@ -292,7 +310,6 @@ def extract_resource_types_from_field(cls, field_annotation) -> List[Type['Compo

return resource_types


@classmethod
def determine_new_type(cls, field_annotation, resource_type):
"""
Expand Down Expand Up @@ -367,20 +384,28 @@ def determine_new_type(cls, field_annotation, resource_type):

return new_type


@classmethod
def generate_models_with_fields(cls) -> Dict[Type[BaseModel], Dict[str, Any]]:
"""
Automatically generate the models_with_fields dictionary based on registered models.
Automatically generate the models_with_fields dictionary based on registered models,
including subclasses that inherit fields using SubclassUnion.
Returns:
- A dictionary mapping model classes to their fields and corresponding resource types.
"""
models_with_fields = {}
for model_cls, resource_types in cls.MODEL_REGISTRY.items():

# Collect all models, including subclasses of registered models
all_models = set(cls.MODEL_REGISTRY.keys())
for model_cls in list(all_models):
for subclass in model_cls.__subclasses__():
if issubclass(subclass, BaseModel):
all_models.add(subclass)

for model_cls in all_models:
models_with_fields[model_cls] = {}
for field_name, field in model_cls.__fields__.items():
field_annotation = model_cls.__annotations__.get(field_name)
for field_name in model_cls.__fields__:
field_annotation = cls.get_field_annotation(model_cls, field_name)
if not field_annotation:
continue

Expand All @@ -396,6 +421,7 @@ def generate_models_with_fields(cls) -> Dict[Type[BaseModel], Dict[str, Any]]:

return models_with_fields


@classmethod
def recreate_models(cls):
"""
Expand All @@ -409,8 +435,11 @@ def recreate_models(cls):
model_class.model_fields[field_name].annotation = new_type
else:
raise ValueError(f"Field '{field_name}' does not exist in model '{model_class.__name__}'")
if model_class.model_rebuild(force=True):
try:
model_class.model_rebuild(force=True)
logger.debug(f"'{model_class}' has been successfully recreated.")
else:
logger.debug(f"'{model_class}' recreation has failed.")
logger.info("All models have been successfully recreated.")
except ValidationError as ve:
logger.error(f"Validation error while rebuilding model '{model_class.__name__}': {ve}")
except Exception as e:
logger.error(f"Error while rebuilding model '{model_class.__name__}': {e}")
logger.info("All models have been successfully recreated.")

0 comments on commit 74c8081

Please sign in to comment.