Skip to content

Commit

Permalink
changes to go along with tidymodels/finetune#88 (#783)
Browse files Browse the repository at this point in the history
  • Loading branch information
topepo authored Dec 13, 2023
1 parent fc2a6b1 commit 3aa7075
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions R/tune_bayes.R
Original file line number Diff line number Diff line change
Expand Up @@ -274,11 +274,11 @@ tune_bayes_workflow <-

metrics <- check_metrics_arg(metrics, object, call = call)
opt_metric <- first_metric(metrics)
metrics_name <- opt_metric$metric
opt_metric_name <- opt_metric$metric
maximize <- opt_metric$direction == "maximize"

eval_time <- check_eval_time_arg(eval_time, metrics, call = call)
metrics_time <- first_eval_time(metrics, metrics_name, eval_time)
opt_metric_time <- first_eval_time(metrics, opt_metric_name, eval_time)

if (is.null(param_info)) {
param_info <- hardhat::extract_parameter_set_dials(object)
Expand Down Expand Up @@ -306,7 +306,7 @@ tune_bayes_workflow <-
parameters = param_info,
metrics = metrics,
eval_time = eval_time,
eval_time_target = metrics_time,
eval_time_target = opt_metric_time,
outcomes = outcomes,
rset_info = rset_info,
workflow = NULL
Expand All @@ -332,12 +332,12 @@ tune_bayes_workflow <-

check_time(start_time, control$time_limit)

score_card <- initial_info(mean_stats, metrics_name, maximize, metrics_time)
score_card <- initial_info(mean_stats, opt_metric_name, maximize, opt_metric_time)

if (control$verbose_iter) {
msg <- paste("Optimizing", metrics_name, "using", objective$label)
if (!is.null(metrics_time)) {
msg <- paste(msg, "at evaluation time", format(metrics_time, digits = 3))
msg <- paste("Optimizing", opt_metric_name, "using", objective$label)
if (!is.null(opt_metric_time)) {
msg <- paste(msg, "at evaluation time", format(opt_metric_time, digits = 3))
}
message_wrap(msg)
}
Expand All @@ -361,8 +361,8 @@ tune_bayes_workflow <-
fit_gp(
mean_stats %>% dplyr::select(-.iter),
pset = param_info,
metric = metrics_name,
eval_time = metrics_time,
metric = opt_metric_name,
eval_time = opt_metric_time,
control = control,
...
),
Expand Down Expand Up @@ -451,7 +451,7 @@ tune_bayes_workflow <-
mean_stats <- dplyr::bind_rows(mean_stats, rs_estimate %>% dplyr::mutate(.iter = i))
score_card <- update_score_card(score_card, i, tmp_res)
log_progress(control, x = mean_stats, maximize = maximize,
objective = metrics_name, eval_time = metrics_time)
objective = opt_metric_name, eval_time = opt_metric_time)
} else {
if (all_bad) {
tune_log(control, split = NULL, task = "All models failed", type = "danger")
Expand Down Expand Up @@ -480,7 +480,7 @@ tune_bayes_workflow <-
parameters = param_info,
metrics = metrics,
eval_time = eval_time,
eval_time_target = metrics_time,
eval_time_target = opt_metric_time,
outcomes = outcomes,
rset_info = rset_info,
workflow = workflow_output
Expand Down Expand Up @@ -686,7 +686,7 @@ update_score_card <- function(info, iter, results, control) {
# ------------------------------------------------------------------------------


# save metrics_name and maximize to simplify!!!!!!!!!!!!!!!
# save opt_metric_name and maximize to simplify!!!!!!!!!!!!!!!
initial_info <- function(stats, metrics, maximize, eval_time) {
best_res <-
stats %>%
Expand Down

0 comments on commit 3aa7075

Please sign in to comment.