Skip to content

Commit

Permalink
added check for min/max
Browse files Browse the repository at this point in the history
  • Loading branch information
franneck94 committed Sep 12, 2022
1 parent 4f82c35 commit c8d077d
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 10 deletions.
7 changes: 0 additions & 7 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,6 @@ strict_optional = true
no_implicit_optional = true
warn_no_return = true
warn_unreachable = true
plugins = pydantic.mypy

[pydantic-mypy]
init_forbid_extra = True
init_typed = True
warn_required_dynamic_aliases = True
warn_untyped_fields = True


[pylint.config]
Expand Down
8 changes: 6 additions & 2 deletions tensorcross/model_selection/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def _run_search(
kwargs (Any): Keyword arguments for the fit method of the
tf.keras.models.Model or tf.keras.models.Sequential model.
"""

maximize = True
tensorboard_callback = None
tensorboard_log_dir = ""

Expand Down Expand Up @@ -101,12 +101,16 @@ def _run_search(
if len(model.metrics) > 1:
val_score = model.evaluate(val_dataset, verbose=0)[-1]
else:
maximize = False
val_score = model.evaluate(val_dataset, verbose=0)
self.results_["val_scores"].append(val_score)
self.results_["params"].append(grid_combination)

logger.setLevel(tf_log_level) # Issue 30
best_run_idx = np.argmax(self.results_["val_scores"])
if maximize:
best_run_idx = np.argmax(self.results_["val_scores"])
else:
best_run_idx = np.argmin(self.results_["val_scores"])
self.results_["best_score"] = self.results_["val_scores"][best_run_idx]
self.results_["best_params"] = self.results_["params"][best_run_idx]

Expand Down
7 changes: 6 additions & 1 deletion tensorcross/model_selection/search_cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def _run_search(
kwargs (Any): Keyword arguments for the fit method of the
tf.keras.models.Model or tf.keras.models.Sequential model.
"""
maximize = True
tensorboard_callback = None
tensorboard_log_dir = ""

Expand Down Expand Up @@ -118,6 +119,7 @@ def _run_search(
if len(model.metrics) > 1:
val_score = model.evaluate(val_dataset, verbose=0)[-1]
else:
maximize = False
val_score = model.evaluate(val_dataset, verbose=0)
val_scores[fold] = val_score

Expand All @@ -127,7 +129,10 @@ def _run_search(
logger.setLevel(tf_log_level) # Issue 30

mean_val_scores = np.mean(self.results_["val_scores"], axis=0)
best_run_idx = np.argmax(mean_val_scores)
if maximize:
best_run_idx = np.argmax(mean_val_scores)
else:
best_run_idx = np.argmin(mean_val_scores)
self.results_["best_score"] = self.results_["val_scores"][best_run_idx]
self.results_["best_params"] = self.results_["params"][best_run_idx]

Expand Down

0 comments on commit c8d077d

Please sign in to comment.