diff --git a/NEWS.md b/NEWS.md index ab9e4d46e..76375c8e0 100644 --- a/NEWS.md +++ b/NEWS.md @@ -18,6 +18,8 @@ * Improves documentation related to the hyperparameters associated with extracted objects that are generated from submodels. See the "Extracting with submodels" section of `?collect_extracts` to learn more. +* `augment()` methods to `tune_results`, `resample_results`, and `last_fit` objects now always returns tibbles (#759). + # tune 1.1.2 * `last_fit()` now works with the 3-way validation split objects from `rsample::initial_validation_split()`. `last_fit()` and `fit_best()` now have a new argument `add_validation_set` to include or exclude the validation set in the dataset used to fit the model (#701). diff --git a/R/augment.R b/R/augment.R index 5a249f294..d1d57a137 100644 --- a/R/augment.R +++ b/R/augment.R @@ -112,6 +112,11 @@ merge_pred <- function(dat, pred, y) { ) ) } + + if (!tibble::is_tibble(dat)) { + dat <- tibble::as_tibble(dat) + } + dat$.row <- 1:nrow(dat) dat <- dplyr::left_join(dat, pred, by = ".row") dat$.row <- NULL diff --git a/tests/testthat/test-augment.R b/tests/testthat/test-augment.R index ff170397b..cebc08521 100644 --- a/tests/testthat/test-augment.R +++ b/tests/testthat/test-augment.R @@ -6,6 +6,7 @@ test_that("augment fit_resamples", { lr_spec <- parsnip::logistic_reg() %>% parsnip::set_engine("glm") set.seed(1) + two_class_dat <- as.data.frame(two_class_dat) bt1 <- rsample::bootstraps(two_class_dat, times = 30) set.seed(1) @@ -23,6 +24,7 @@ test_that("augment fit_resamples", { expect_true(sum(names(aug_1) == ".pred_Class1") == 1) expect_true(sum(names(aug_1) == ".pred_Class2") == 1) expect_true(sum(names(aug_1) == ".resid") == 0) + expect_s3_class_bare_tibble(aug_1) expect_snapshot(error = TRUE, augment(fit_1, hey = "you")) }) @@ -33,6 +35,7 @@ test_that("augment fit_resamples", { lr_spec <- parsnip::logistic_reg() %>% parsnip::set_engine("glm") set.seed(1) + two_class_dat <- as.data.frame(two_class_dat) bt2 <- rsample::bootstraps(two_class_dat, times = 3) set.seed(1) @@ -50,6 +53,7 @@ test_that("augment fit_resamples", { expect_true(sum(names(aug_2) == ".pred_class") == 1) expect_true(sum(names(aug_2) == ".pred_Class1") == 1) expect_true(sum(names(aug_2) == ".pred_Class2") == 1) + expect_s3_class_bare_tibble(aug_2) }) # ------------------------------------------------------------------------------ @@ -77,6 +81,7 @@ test_that("augment tune_grid", { expect_true(sum(!is.na(aug_1$.pred)) == nrow(mtcars)) expect_true(sum(names(aug_1) == ".pred") == 1) expect_true(sum(names(aug_1) == ".resid") == 1) + expect_s3_class_bare_tibble(aug_1) aug_2 <- augment(fit_1, parameters = data.frame(cost = 3)) expect_true(any(abs(aug_1$.pred - aug_2$.pred) > 1)) @@ -113,6 +118,7 @@ test_that("augment tune_grid", { expect_true(sum(!is.na(aug_3$.pred)) == nrow(mtcars)) expect_true(sum(names(aug_3) == ".pred") == 1) expect_true(sum(names(aug_3) == ".resid") == 1) + expect_s3_class_bare_tibble(aug_3) }) @@ -131,6 +137,7 @@ test_that("augment last_fit", { expect_true(sum(names(aug_1) == ".pred_class") == 1) expect_true(sum(names(aug_1) == ".pred_Class1") == 1) expect_true(sum(names(aug_1) == ".pred_Class2") == 1) + expect_s3_class_bare_tibble(aug_1) expect_snapshot(error = TRUE, augment(fit_1, potato = TRUE)) })