diff --git a/NAMESPACE b/NAMESPACE index 4464ddd5c..7e7fedc6a 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -187,6 +187,8 @@ export(finalize_model) export(finalize_recipe) export(finalize_workflow) export(finalize_workflow_preprocessor) +export(first_eval_time) +export(first_metric) export(fit_best) export(fit_max_value) export(fit_resamples) diff --git a/R/metric-selection.R b/R/metric-selection.R new file mode 100644 index 000000000..880bc0b7f --- /dev/null +++ b/R/metric-selection.R @@ -0,0 +1,51 @@ +#' Tools for selecting metrics and evaluation times +#' +#' @param mtr_set A [yardstick::metric_set()]. +#' @param metric A character value for which metric is being used. +#' @param eval_time An optional vector of times to compute dynamic and/or +#' integrated metrics. +#' @keywords internal +#' @export +first_metric <- function(mtr_set) { + tibble::as_tibble(mtr_set)[1,] +} + +#' @rdname first_metric +#' @keywords internal +#' @export +first_eval_time <- function(mtr_set, metric = NULL, eval_time = NULL) { + num_times <- length(eval_time) + + if (is.null(metric)) { + mtr_info <- first_metric(mtr_set) + metric <- mtr_info$metric + } else { + mtr_info <- tibble::as_tibble(mtr_set) + mtr_info <- mtr_info[mtr_info$metric == metric,] + } + + # Not a survival metric + if (!any(grepl("_survival_", mtr_info$class))) { + return(NULL) + } + + # Not a metric that requires an eval_time + no_time_req <- c("static_survival_metric", "integrated_survival_metric") + if (mtr_info$class %in% no_time_req) { + if (num_times > 0) { + cli::cli_warn("Evaluation times are only required when dynmanic or integrated metrics are selected as the primary metric.") + } + return(NULL) + } + + # checks for dynamic metrics + if (num_times == 0) { + cli::cli_abort("A single evaluation time is required to use this metric.") + } else if ( num_times > 1 ) { + eval_time <- eval_time[1] + print_time <- format(eval_time, digits = 3) + cli::cli_warn("{num_times} evaluation times were available; the first ({print_time}) will be used.") + } + + eval_time +} diff --git a/R/select_best.R b/R/select_best.R index a9771b1e6..85305667d 100644 --- a/R/select_best.R +++ b/R/select_best.R @@ -76,6 +76,8 @@ show_best.default <- function(x, ...) { #' @export #' @rdname show_best show_best.tune_results <- function(x, metric = NULL, n = 5, eval_time = NULL, ...) { + # TODO should return the as_tibble(metric_set) results to get the class etc. + # TODO new function start metric <- choose_metric(metric, x) dots <- rlang::enquos(...) @@ -92,8 +94,12 @@ show_best.tune_results <- function(x, metric = NULL, n = 5, eval_time = NULL, .. metric <- metrics } + # TODO new function stop + # get estimates/summarise summary_res <- summary_res %>% dplyr::filter(.metric == metric) + + # TODO split selecting the req time and seeing if it is in the data summary_res <- choose_eval_time(summary_res, x, eval_time) if (nrow(summary_res) == 0) { @@ -349,7 +355,8 @@ middle_eval_time <- function(x) { eval_time } - +# NOTE this chooses the time and subsets the data; break it up to only select +# time choose_eval_time <- function(x, object, eval_time) { mtrs <- .get_tune_metrics(object) mtrs <- tibble::as_tibble(mtrs) diff --git a/man/first_metric.Rd b/man/first_metric.Rd new file mode 100644 index 000000000..169d9b423 --- /dev/null +++ b/man/first_metric.Rd @@ -0,0 +1,23 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/metric-selection.R +\name{first_metric} +\alias{first_metric} +\alias{first_eval_time} +\title{Tools for selecting metrics and evaluation times} +\usage{ +first_metric(mtr_set) + +first_eval_time(mtr_set, metric = NULL, eval_time = NULL) +} +\arguments{ +\item{mtr_set}{A \code{\link[yardstick:metric_set]{yardstick::metric_set()}}.} + +\item{metric}{A character value for which metric is being used.} + +\item{eval_time}{An optional vector of times to compute dynamic and/or +integrated metrics.} +} +\description{ +Tools for selecting metrics and evaluation times +} +\keyword{internal} diff --git a/tests/testthat/_snaps/eval-time-single-selection.md b/tests/testthat/_snaps/eval-time-single-selection.md new file mode 100644 index 000000000..5d122bb3c --- /dev/null +++ b/tests/testthat/_snaps/eval-time-single-selection.md @@ -0,0 +1,156 @@ +# selecting single eval time - pure metric sets + + Code + stc_one <- first_eval_time(met_stc, eval_time = times_1) + Condition + Warning: + Evaluation times are only required when dynmanic or integrated metrics are selected as the primary metric. + +--- + + Code + stc_multi <- first_eval_time(met_stc, eval_time = times_2) + Condition + Warning: + Evaluation times are only required when dynmanic or integrated metrics are selected as the primary metric. + +--- + + Code + first_eval_time(met_dyn, eval_time = NULL) + Condition + Error in `first_eval_time()`: + ! A single evaluation time is required to use this metric. + +--- + + Code + first_eval_time(met_dyn, "brier_survival", eval_time = NULL) + Condition + Error in `first_eval_time()`: + ! A single evaluation time is required to use this metric. + +--- + + Code + dyn_multi <- first_eval_time(met_dyn, eval_time = times_2) + Condition + Warning: + 2 evaluation times were available; the first (0.714) will be used. + +--- + + Code + int_1 <- first_eval_time(met_int, eval_time = times_1) + Condition + Warning: + Evaluation times are only required when dynmanic or integrated metrics are selected as the primary metric. + +--- + + Code + int_multi <- first_eval_time(met_int, eval_time = times_2) + Condition + Warning: + Evaluation times are only required when dynmanic or integrated metrics are selected as the primary metric. + +# selecting single eval time - mixed metric sets - static first + + Code + stc_1 <- first_eval_time(met_mix_stc, eval_time = times_1) + Condition + Warning: + Evaluation times are only required when dynmanic or integrated metrics are selected as the primary metric. + +--- + + Code + stc_multi <- first_eval_time(met_mix_stc, eval_time = times_2) + Condition + Warning: + Evaluation times are only required when dynmanic or integrated metrics are selected as the primary metric. + +--- + + Code + stc_1 <- first_eval_time(met_mix_stc_all, eval_time = times_1) + Condition + Warning: + Evaluation times are only required when dynmanic or integrated metrics are selected as the primary metric. + +--- + + Code + stc_multi <- first_eval_time(met_mix_stc_all, eval_time = times_2) + Condition + Warning: + Evaluation times are only required when dynmanic or integrated metrics are selected as the primary metric. + +# selecting single eval time - mixed metric sets - dynamic first + + Code + first_eval_time(met_mix_dyn, eval_time = NULL) + Condition + Error in `first_eval_time()`: + ! A single evaluation time is required to use this metric. + +--- + + Code + dyn_multi <- first_eval_time(met_mix_dyn, eval_time = times_2) + Condition + Warning: + 2 evaluation times were available; the first (0.714) will be used. + +--- + + Code + first_eval_time(met_mix_dyn_all, eval_time = NULL) + Condition + Error in `first_eval_time()`: + ! A single evaluation time is required to use this metric. + +--- + + Code + dyn_multi <- first_eval_time(met_mix_dyn_all, eval_time = times_2) + Condition + Warning: + 2 evaluation times were available; the first (0.714) will be used. + +# selecting single eval time - mixed metric sets - integrated first + + Code + first_eval_time(met_mix_int, eval_time = times_1) + Condition + Warning: + Evaluation times are only required when dynmanic or integrated metrics are selected as the primary metric. + Output + NULL + +--- + + Code + int_multi <- first_eval_time(met_mix_int, eval_time = times_2) + Condition + Warning: + Evaluation times are only required when dynmanic or integrated metrics are selected as the primary metric. + +--- + + Code + first_eval_time(met_mix_int_all, eval_time = times_1) + Condition + Warning: + Evaluation times are only required when dynmanic or integrated metrics are selected as the primary metric. + Output + NULL + +--- + + Code + int_multi <- first_eval_time(met_mix_int_all, eval_time = times_2) + Condition + Warning: + Evaluation times are only required when dynmanic or integrated metrics are selected as the primary metric. + diff --git a/tests/testthat/test-eval-time-single-selection.R b/tests/testthat/test-eval-time-single-selection.R new file mode 100644 index 000000000..c40621b18 --- /dev/null +++ b/tests/testthat/test-eval-time-single-selection.R @@ -0,0 +1,214 @@ +library(yardstick) + +test_that("selecting single eval time - non-survival case", { + met_reg <- metric_set(rmse) + + times_1 <- 1 / 3 + times_2 <- as.numeric(5:4) / 7 + + # eval time is not applicable outside of survival models; return null + + expect_null(first_eval_time(met_reg, eval_time = NULL)) + expect_null(first_eval_time(met_reg, eval_time = times_1)) + expect_null(first_eval_time(met_reg, eval_time = times_2)) + +}) + +test_that("selecting single eval time - pure metric sets", { + met_int <- metric_set(brier_survival_integrated) + met_dyn <- metric_set(brier_survival) + met_stc <- metric_set(concordance_survival) + + times_1 <- 1 / 3 + times_2 <- as.numeric(5:4) / 7 + + # all static; return NULL and add warning if times are given + + expect_null(first_eval_time(met_stc, eval_time = NULL)) + expect_null(first_eval_time(met_stc, "concordance_survival", eval_time = NULL)) + + expect_snapshot( + stc_one <- first_eval_time(met_stc, eval_time = times_1) + ) + expect_null(stc_one) + + expect_snapshot( + stc_multi <- first_eval_time(met_stc, eval_time = times_2) + ) + expect_null(stc_multi) + + # all dynamic; return a single time and warn if there are more and error if + # there are none + + expect_snapshot( + first_eval_time(met_dyn, eval_time = NULL), + error = TRUE + ) + expect_snapshot( + first_eval_time(met_dyn, "brier_survival", eval_time = NULL), + error = TRUE + ) + + expect_equal( + first_eval_time(met_dyn, eval_time = times_1), + times_1 + ) + + expect_snapshot( + dyn_multi <- first_eval_time(met_dyn, eval_time = times_2) + ) + expect_equal(dyn_multi, times_2[1]) + + # all integrated; return NULL and warn if there 1+ times + + expect_null(first_eval_time(met_int, eval_time = NULL)) + expect_null( + first_eval_time(met_int, "brier_survival_integrated", eval_time = NULL) + ) + + expect_snapshot( + int_1 <- first_eval_time(met_int, eval_time = times_1) + ) + expect_null(int_1) + + expect_snapshot( + int_multi <- first_eval_time(met_int, eval_time = times_2) + ) + expect_null(int_multi) + +}) + +test_that("selecting single eval time - mixed metric sets - static first", { + met_mix_stc <- metric_set(concordance_survival, brier_survival) + met_mix_stc_all <- metric_set(concordance_survival, brier_survival, brier_survival_integrated) + + times_1 <- 1 / 3 + times_2 <- as.numeric(5:4) / 7 + + # static is first but includes dynamic. Should return NULL and add warning + # if times are given + + expect_null( + first_eval_time(met_mix_stc, eval_time = NULL) + ) + + expect_snapshot( + stc_1 <- first_eval_time(met_mix_stc, eval_time = times_1) + ) + expect_null(stc_1) + + expect_snapshot( + stc_multi <- first_eval_time(met_mix_stc, eval_time = times_2) + ) + expect_null(stc_multi) + + # static is first but includes dynamic and integrated. Should return NULL and + # add warning if times are given + + expect_null( + first_eval_time(met_mix_stc_all, eval_time = NULL) + ) + + expect_snapshot( + stc_1 <- first_eval_time(met_mix_stc_all, eval_time = times_1) + ) + expect_null(stc_1) + + expect_snapshot( + stc_multi <- first_eval_time(met_mix_stc_all, eval_time = times_2) + ) + expect_null(stc_multi) +}) + +test_that("selecting single eval time - mixed metric sets - dynamic first", { + met_mix_dyn <- metric_set(brier_survival, concordance_survival) + met_mix_dyn_all <- + metric_set(brier_survival, + brier_survival_integrated, + concordance_survival) + + times_1 <- 1 / 3 + times_2 <- as.numeric(5:4) / 7 + + # dynamic is first but includes static. Should return single time and add warning + # if 2+ times are given + + expect_snapshot( + first_eval_time(met_mix_dyn, eval_time = NULL), + error = TRUE + ) + expect_equal( + first_eval_time(met_mix_dyn, eval_time = times_1), + times_1 + ) + expect_snapshot( + dyn_multi <- first_eval_time(met_mix_dyn, eval_time = times_2) + ) + expect_equal(dyn_multi, times_2[1]) + + # dynamic is first but includes static and integrated. Should return single + # time and add warning if 2+ times are given + + expect_snapshot( + first_eval_time(met_mix_dyn_all, eval_time = NULL), + error = TRUE + ) + expect_equal( + first_eval_time(met_mix_dyn_all, eval_time = times_1), + times_1 + ) + expect_snapshot( + dyn_multi <- first_eval_time(met_mix_dyn_all, eval_time = times_2) + ) + expect_equal(dyn_multi, times_2[1]) + +}) + + +test_that("selecting single eval time - mixed metric sets - integrated first", { + met_mix_int <- metric_set(brier_survival_integrated, concordance_survival) + met_mix_int_all <- + metric_set(brier_survival_integrated, + brier_survival, + concordance_survival) + + times_1 <- 1 / 3 + times_2 <- as.numeric(5:4) / 7 + + # integrated is first but includes static. Should return NULL and add warning + # if 1+ times are given + + expect_null(first_eval_time(met_mix_int, eval_time = NULL)) + + expect_snapshot( + first_eval_time(met_mix_int, eval_time = times_1) + ) + expect_snapshot( + int_multi <- first_eval_time(met_mix_int, eval_time = times_2) + ) + expect_null(int_multi) + + # integrated is first but includes static and dynamic. Should return NULL and + # add warning if 1+ times are given + + expect_null(first_eval_time(met_mix_int_all, eval_time = NULL)) + + expect_snapshot( + first_eval_time(met_mix_int_all, eval_time = times_1) + ) + expect_snapshot( + int_multi <- first_eval_time(met_mix_int_all, eval_time = times_2) + ) + expect_null(int_multi) +}) + + +test_that("selecting the first metric", { + met_1 <- metric_set(rmse) + tbl_1 <- tibble::as_tibble(met_1)[1,] + met_2 <- metric_set(rmse, ccc) + tbl_2 <- tibble::as_tibble(met_2)[1,] + + expect_equal(first_metric(met_1), tbl_1) + expect_equal(first_metric(met_2), tbl_2) +})