Skip to content

Commit

Permalink
Merge pull request #760 from tidymodels/augment-tibble
Browse files Browse the repository at this point in the history
  • Loading branch information
EmilHvitfeldt authored Nov 21, 2023
2 parents ef911aa + d76e58c commit 3509577
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 0 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

* Improves documentation related to the hyperparameters associated with extracted objects that are generated from submodels. See the "Extracting with submodels" section of `?collect_extracts` to learn more.

* `augment()` methods to `tune_results`, `resample_results`, and `last_fit` objects now always returns tibbles (#759).

# tune 1.1.2

* `last_fit()` now works with the 3-way validation split objects from `rsample::initial_validation_split()`. `last_fit()` and `fit_best()` now have a new argument `add_validation_set` to include or exclude the validation set in the dataset used to fit the model (#701).
Expand Down
5 changes: 5 additions & 0 deletions R/augment.R
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,11 @@ merge_pred <- function(dat, pred, y) {
)
)
}

if (!tibble::is_tibble(dat)) {
dat <- tibble::as_tibble(dat)
}

dat$.row <- 1:nrow(dat)
dat <- dplyr::left_join(dat, pred, by = ".row")
dat$.row <- NULL
Expand Down
7 changes: 7 additions & 0 deletions tests/testthat/test-augment.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ test_that("augment fit_resamples", {
lr_spec <- parsnip::logistic_reg() %>% parsnip::set_engine("glm")

set.seed(1)
two_class_dat <- as.data.frame(two_class_dat)
bt1 <- rsample::bootstraps(two_class_dat, times = 30)

set.seed(1)
Expand All @@ -23,6 +24,7 @@ test_that("augment fit_resamples", {
expect_true(sum(names(aug_1) == ".pred_Class1") == 1)
expect_true(sum(names(aug_1) == ".pred_Class2") == 1)
expect_true(sum(names(aug_1) == ".resid") == 0)
expect_s3_class_bare_tibble(aug_1)

expect_snapshot(error = TRUE, augment(fit_1, hey = "you"))
})
Expand All @@ -33,6 +35,7 @@ test_that("augment fit_resamples", {
lr_spec <- parsnip::logistic_reg() %>% parsnip::set_engine("glm")

set.seed(1)
two_class_dat <- as.data.frame(two_class_dat)
bt2 <- rsample::bootstraps(two_class_dat, times = 3)

set.seed(1)
Expand All @@ -50,6 +53,7 @@ test_that("augment fit_resamples", {
expect_true(sum(names(aug_2) == ".pred_class") == 1)
expect_true(sum(names(aug_2) == ".pred_Class1") == 1)
expect_true(sum(names(aug_2) == ".pred_Class2") == 1)
expect_s3_class_bare_tibble(aug_2)
})

# ------------------------------------------------------------------------------
Expand Down Expand Up @@ -77,6 +81,7 @@ test_that("augment tune_grid", {
expect_true(sum(!is.na(aug_1$.pred)) == nrow(mtcars))
expect_true(sum(names(aug_1) == ".pred") == 1)
expect_true(sum(names(aug_1) == ".resid") == 1)
expect_s3_class_bare_tibble(aug_1)

aug_2 <- augment(fit_1, parameters = data.frame(cost = 3))
expect_true(any(abs(aug_1$.pred - aug_2$.pred) > 1))
Expand Down Expand Up @@ -113,6 +118,7 @@ test_that("augment tune_grid", {
expect_true(sum(!is.na(aug_3$.pred)) == nrow(mtcars))
expect_true(sum(names(aug_3) == ".pred") == 1)
expect_true(sum(names(aug_3) == ".resid") == 1)
expect_s3_class_bare_tibble(aug_3)
})


Expand All @@ -131,6 +137,7 @@ test_that("augment last_fit", {
expect_true(sum(names(aug_1) == ".pred_class") == 1)
expect_true(sum(names(aug_1) == ".pred_Class1") == 1)
expect_true(sum(names(aug_1) == ".pred_Class2") == 1)
expect_s3_class_bare_tibble(aug_1)

expect_snapshot(error = TRUE, augment(fit_1, potato = TRUE))
})

0 comments on commit 3509577

Please sign in to comment.