Skip to content

Commit

Permalink
transition from add_tailor(prop) and method (#945)
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpcouch authored Oct 1, 2024
1 parent 18442b2 commit f8d734a
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 19 deletions.
7 changes: 1 addition & 6 deletions R/grid_code_paths.R
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ tune_grid_loop_iter <- function(split,
assessment_rows <- as.integer(split, data = "assessment")
assessment <- vctrs::vec_slice(split$data, assessment_rows)

if (workflows::.should_inner_split(workflow)) {
if (workflows::.workflow_includes_calibration(workflow)) {
# if the workflow has a postprocessor that needs training (i.e. calibration),
# further split the analysis data into an "inner" analysis and
# assessment set.
Expand All @@ -397,11 +397,6 @@ tune_grid_loop_iter <- function(split,
# calibration set
# * the model (including the post-processor) generates predictions on the
# assessment set and those predictions are assessed with performance metrics
# todo: check if workflow's `method` is incompatible with `class(split)`?
# todo: workflow's `method` is currently ignored in favor of the one
# automatically dispatched to from `split`. consider this is combination
# with above todo.
split_args <- c(split_args, list(prop = workflow$post$actions$tailor$prop))
split <- rsample::inner_split(split, split_args = split_args)
analysis <- rsample::analysis(split)

Expand Down
20 changes: 12 additions & 8 deletions tests/testthat/test-last-fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -246,9 +246,7 @@ test_that("can use `last_fit()` with a workflow - postprocessor (requires traini
parsnip::linear_reg()
) %>%
workflows::add_tailor(
tailor::tailor("regression") %>% tailor::adjust_numeric_calibration("linear"),
prop = 2/3,
method = class(split)
tailor::tailor() %>% tailor::adjust_numeric_calibration("linear")
)

set.seed(1)
Expand All @@ -261,13 +259,21 @@ test_that("can use `last_fit()` with a workflow - postprocessor (requires traini
last_fit_preds <- collect_predictions(last_fit_res)

set.seed(1)
wflow_res <- generics::fit(wflow, rsample::analysis(split))
inner_split <- rsample::inner_split(split, split_args = list())

set.seed(1)
wflow_res <-
generics::fit(
wflow,
rsample::analysis(inner_split),
calibration = rsample::assessment(inner_split)
)
wflow_preds <- predict(wflow_res, rsample::assessment(split))

expect_equal(last_fit_preds[".pred"], wflow_preds)
})

test_that("can use `last_fit()` with a workflow - postprocessor (requires training)", {
test_that("can use `last_fit()` with a workflow - postprocessor (does not require training)", {
skip_if_not_installed("tailor")

y <- seq(0, 7, .001)
Expand All @@ -284,9 +290,7 @@ test_that("can use `last_fit()` with a workflow - postprocessor (requires traini
parsnip::linear_reg()
) %>%
workflows::add_tailor(
tailor::tailor("regression") %>% tailor::adjust_numeric_range(lower_limit = 1),
prop = 2/3,
method = class(split)
tailor::tailor() %>% tailor::adjust_numeric_range(lower_limit = 1)
)

set.seed(1)
Expand Down
20 changes: 15 additions & 5 deletions tests/testthat/test-resample.R
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,7 @@ test_that("can use `fit_resamples()` with a workflow - postprocessor (requires t
parsnip::linear_reg()
) %>%
workflows::add_tailor(
tailor::tailor("regression") %>% tailor::adjust_numeric_calibration("linear"),
prop = 2/3,
method = class(folds$splits[[1]])
tailor::tailor() %>% tailor::adjust_numeric_calibration("linear")
)

set.seed(1)
Expand All @@ -178,8 +176,20 @@ test_that("can use `fit_resamples()` with a workflow - postprocessor (requires t
seed <- generate_seeds(TRUE, 1)[[1]]
old_kind <- RNGkind()[[1]]
assign(".Random.seed", seed, envir = globalenv())
withr::defer(RNGkind(kind = old_kind))

wflow_res <- generics::fit(wflow, rsample::analysis(folds$splits[[1]]))
inner_split_1 <-
rsample::inner_split(
folds$splits[[1]],
split_args = list(v = 2, repeats = 1, breaks = 4, pool = 0.1)
)

wflow_res <-
generics::fit(
wflow,
rsample::analysis(inner_split_1),
calibration = rsample::assessment(inner_split_1)
)
wflow_preds <- predict(wflow_res, rsample::assessment(folds$splits[[1]]))

tune_wflow$fit$fit$elapsed$elapsed <- wflow_res$fit$fit$elapsed$elapsed
Expand All @@ -201,7 +211,7 @@ test_that("can use `fit_resamples()` with a workflow - postprocessor (no trainin
parsnip::linear_reg()
) %>%
workflows::add_tailor(
tailor::tailor("regression") %>% tailor::adjust_numeric_range(lower_limit = 1)
tailor::tailor() %>% tailor::adjust_numeric_range(lower_limit = 1)
)

set.seed(1)
Expand Down

0 comments on commit f8d734a

Please sign in to comment.