From 8dd5ae01eb80df061ed5783553f204803a979963 Mon Sep 17 00:00:00 2001 From: tomaskontrimas Date: Mon, 19 Feb 2024 14:08:39 -0600 Subject: [PATCH 1/2] Implement fall back to ParameterGrid if no objects were specified and sequence/instance check. --- skyllh/core/parameters.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/skyllh/core/parameters.py b/skyllh/core/parameters.py index 684dc6b028..a1ce25656c 100644 --- a/skyllh/core/parameters.py +++ b/skyllh/core/parameters.py @@ -1545,9 +1545,19 @@ def __init__( The ParameterGrid instances this instance of ParameterGridSet should get initialized with. """ + # Infer `obj_type` from the `param_grids` argument. + if (param_grids is None) or issequenceof(param_grids, type(None)): + # Fall back to the default `ParameterGrid` type. + obj_type = ParameterGrid + else: + if issequence(param_grids): + obj_type = type(param_grids[0]) + else: + obj_type = type(param_grids) + super().__init__( objs=param_grids, - obj_type=type(param_grids[0]), + obj_type=obj_type, **kwargs) @property From 66d1c6fc218cb29c713758f6dd91cd2501d09947 Mon Sep 17 00:00:00 2001 From: Tomas Kontrimas <52071038+tomaskontrimas@users.noreply.github.com> Date: Tue, 20 Feb 2024 12:44:30 +0100 Subject: [PATCH 2/2] Update according to Martin's suggestions --- skyllh/core/parameters.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/skyllh/core/parameters.py b/skyllh/core/parameters.py index a1ce25656c..4548e4d491 100644 --- a/skyllh/core/parameters.py +++ b/skyllh/core/parameters.py @@ -1546,14 +1546,14 @@ def __init__( get initialized with. """ # Infer `obj_type` from the `param_grids` argument. - if (param_grids is None) or issequenceof(param_grids, type(None)): + if issequence(param_grids): + obj_type = type(param_grids[0]) + else: + obj_type = type(param_grids) + + if obj_type is type(None): # Fall back to the default `ParameterGrid` type. obj_type = ParameterGrid - else: - if issequence(param_grids): - obj_type = type(param_grids[0]) - else: - obj_type = type(param_grids) super().__init__( objs=param_grids,