diff --git a/skyllh/core/parameters.py b/skyllh/core/parameters.py index 684dc6b028..4548e4d491 100644 --- a/skyllh/core/parameters.py +++ b/skyllh/core/parameters.py @@ -1545,9 +1545,19 @@ def __init__( The ParameterGrid instances this instance of ParameterGridSet should get initialized with. """ + # Infer `obj_type` from the `param_grids` argument. + if issequence(param_grids): + obj_type = type(param_grids[0]) + else: + obj_type = type(param_grids) + + if obj_type is type(None): + # Fall back to the default `ParameterGrid` type. + obj_type = ParameterGrid + super().__init__( objs=param_grids, - obj_type=type(param_grids[0]), + obj_type=obj_type, **kwargs) @property