Skip to content

Commit

Permalink
core - replace PlaceholderPlugin with Any; use `recreate_models()…
Browse files Browse the repository at this point in the history
…` only.
  • Loading branch information
cobycloud committed Jan 4, 2025
1 parent 60cbb34 commit 9ccc43d
Showing 1 changed file with 19 additions and 64 deletions.
83 changes: 19 additions & 64 deletions pkgs/core/swarmauri_core/ComponentBase.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ def __class_getitem__(cls, resource_type: Type[T]) -> type:
"""
registered_classes = list(ComponentBase.TYPE_REGISTRY.get(resource_type, {}).values())
if not registered_classes:
logger.warning(f"No subclasses registered for resource type '{resource_type.__name__}'. Using 'PlaceholderPlugin' as a placeholder.")
registered_classes = [PlaceholderPlugin]

union_type = Union[tuple(registered_classes)]
logger.warning(f"No subclasses registered for resource type '{resource_type.__name__}'. Using 'Any' as a placeholder.")
union_type = Any
else:
union_type = Union[tuple(registered_classes)]
return Annotated[union_type, Field(discriminator='type')]

class ResourceTypes(Enum):
Expand Down Expand Up @@ -181,7 +181,7 @@ def decorator(subclass: Type['ComponentBase']):
cls.TYPE_REGISTRY[resource_type] = {}
cls.TYPE_REGISTRY[resource_type][type_name] = subclass
# Automatically recreate models after registering a new type
cls.recreate_models_for_resource(resource_type)
cls.recreate_models()
logger.info(f"Registered type '{type_name}' for resource '{resource_type.__name__}' with subclass '{subclass.__name__}'")
return subclass
return decorator
Expand Down Expand Up @@ -211,11 +211,6 @@ def decorator(model_cls: Type[BaseModel]):
if resource_type not in cls.MODEL_REGISTRY[model_cls]:
cls.MODEL_REGISTRY[model_cls].append(resource_type)
logger.info(f"Registered model '{model_cls.__name__}' for resource '{resource_type.__name__}'")

# Recreate models for all associated resource types
for resource_type in cls.MODEL_REGISTRY[model_cls]:
cls.recreate_models_for_resource(resource_type)

return model_cls
return decorator

Expand All @@ -233,16 +228,16 @@ def get_class_by_type(cls, resource_type: Type[T], type_name: str) -> Type['Comp
"""
return cls.TYPE_REGISTRY.get(resource_type, {}).get(type_name)

@classmethod
@classmethod
def field_contains_subclass_union(cls, field_annotation) -> bool:
"""
Check if the field annotation contains a SubclassUnion or the placeholder PlaceholderPlugin.
Check if the field annotation contains a SubclassUnion.
Parameters:
- field_annotation: The type annotation of the field.
Returns:
- True if SubclassUnion or PlaceholderPlugin is present, False otherwise.
- True if SubclassUnion is present, False otherwise.
"""
if isinstance(field_annotation, type(SubclassUnion)):
return True
Expand All @@ -253,15 +248,13 @@ def field_contains_subclass_union(cls, field_annotation) -> bool:
elif origin in {list, List, dict, Dict, Union}:
args = get_args(field_annotation)
return any(cls.field_contains_subclass_union(arg) for arg in args)
elif inspect.isclass(field_annotation) and issubclass(field_annotation, PlaceholderPlugin):
return True
return False


@classmethod
def extract_resource_types_from_field(cls, field_annotation) -> List[Type['ComponentBase']]:
"""
Extracts all resource types from a field annotation that uses SubclassUnion or PlaceholderPlugin.
Extracts all resource types from a field annotation that uses SubclassUnion.
Parameters:
- field_annotation: The type annotation of the field.
Expand Down Expand Up @@ -298,9 +291,6 @@ def extract_resource_types_from_field(cls, field_annotation) -> List[Type['Compo
# Handle Dict[key_type, SubclassUnion[ResourceType]]
value_type = args[1]
resource_types.extend(cls.extract_resource_types_from_field(value_type))
elif inspect.isclass(field_annotation) and issubclass(field_annotation, PlaceholderPlugin):
# Placeholder resource type
resource_types.append(field_annotation)

return resource_types

Expand Down Expand Up @@ -367,10 +357,15 @@ def determine_new_type(cls, field_annotation, resource_type):
# Include None in the Union and maintain the discriminator
registered_classes = list(cls.TYPE_REGISTRY.get(resource_type, {}).values())
if not registered_classes:
# Use PlaceholderPlugin as a placeholder if no subclasses are registered
registered_classes = [PlaceholderPlugin]
union_with_none = Union[tuple(registered_classes + [type(None)])]
new_type = Annotated[union_with_none, Field(discriminator="type")]
# Use Any as a placeholder if no subclasses are registered
union_type = Any
else:
union_type = Union[tuple(registered_classes)]
union_with_none = Union[tuple([union_type, type(None)])]
new_type = Annotated[
union_with_none,
Field(discriminator="type")
]

return new_type

Expand Down Expand Up @@ -417,44 +412,4 @@ def recreate_models(cls):
else:
raise ValueError(f"Field '{field_name}' does not exist in model '{model_class.__name__}'")
model_class.model_rebuild(force=True)
logger.info("All models have been successfully recreated.")

@classmethod
def recreate_models_for_resource(cls, resource_type: Type[T]):
"""
Recreate only the models associated with the given resource type.
"""
with cls._lock:
models_with_fields = {}
for model_cls, res_types in cls.MODEL_REGISTRY.items():
if resource_type in res_types:
models_with_fields[model_cls] = {}
for field_name, field in model_cls.__fields__.items():
field_annotation = model_cls.__annotations__.get(field_name)
if not field_annotation:
continue

# Check if SubclassUnion is used in the field type
if not cls.field_contains_subclass_union(field_annotation):
continue # Only process fields that use SubclassUnion

# Determine the new type based on the annotation and resource_type
new_type = cls.determine_new_type(field_annotation, resource_type)

models_with_fields[model_cls][field_name] = new_type

for model_class, fields in models_with_fields.items():
for field_name, new_type in fields.items():
if field_name in model_class.model_fields:
model_class.model_fields[field_name].annotation = new_type
else:
raise ValueError(f"Field '{field_name}' does not exist in model '{model_class.__name__}'")
model_class.model_rebuild(force=True)
logger.info(f"Models associated with resource '{resource_type.__name__}' have been successfully recreated.")


class PlaceholderPlugin(ComponentBase):
"""
Placeholder base class for plugins when no subclasses are registered.
"""
pass
logger.info("All models have been successfully recreated.")

0 comments on commit 9ccc43d

Please sign in to comment.