From 2a8510a77a9f3fc82f41fb6d95160cfdf3846598 Mon Sep 17 00:00:00 2001 From: perib Date: Thu, 2 Nov 2023 17:41:38 -0700 Subject: [PATCH] bug fixes --- .../graph_pipeline_individual/individual.py | 23 +++++++------------ 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/tpot2/individual_representations/graph_pipeline_individual/individual.py b/tpot2/individual_representations/graph_pipeline_individual/individual.py index 8ea3e0f2..bc2a883a 100644 --- a/tpot2/individual_representations/graph_pipeline_individual/individual.py +++ b/tpot2/individual_representations/graph_pipeline_individual/individual.py @@ -209,6 +209,7 @@ def __init__( self._crossover_swap_branch, ] + if self.inner_config_dict is not None: self.mutate_methods_list.append(self._mutate_insert_inner_node) self.crossover_methods_list.append(self._crossover_take_branch) #this is the only crossover method that can create inner nodes @@ -217,7 +218,7 @@ def __init__( self.mutate_methods_list.append(self._mutate_remove_edge) self.mutate_methods_list.append(self._mutate_add_edge) - if not linear_pipeline: + if not linear_pipeline and (self.leaf_config_dict is not None or self.inner_config_dict is not None): self.mutate_methods_list.append(self._mutate_insert_leaf) @@ -595,20 +596,12 @@ def _mutate_replace_node(self, rng_=None): for node in sorted_nodes_list: if isinstance(node,GraphIndividual): continue - node.method_class = rng.choice(list(self.select_config_dict(node).keys())) - if isinstance(self.select_config_dict(node)[node.method_class], dict): - hyperparameters = self.select_config_dict(node)[node.method_class] - node.hyperparameters = hyperparameters - else: - #hyperparameters = self.select_config_dict(node)[node.method_class](config.hyperparametersuggestor) - #get_hyperparameter(self.select_config_dict(node)[node.method_class], nodelabel=None, alpha=self.hyperparameter_alpha, hyperparameter_probability=self.hyperparameter_probability) - new_node = create_node(self.select_config_dict(node), rng_=rng) - #TODO cleanup - node.hyperparameters = new_node.hyperparameters - node.method_class = new_node.method_class - node.label = new_node.label - - return True + new_node = create_node(self.select_config_dict(node), rng_=rng) + #check if new node and old node are the same + #TODO: add attempts? + if node.method_class != new_node.method_class or node.hyperparameters != new_node.hyperparameters: + nx.relabel_nodes(self.graph, {new_node:node}, copy=False) + return True return False