diff --git a/NEWS.md b/NEWS.md index e173119a7..df61137e2 100644 --- a/NEWS.md +++ b/NEWS.md @@ -4,7 +4,7 @@ * `recipe()`, `prep()`, and `bake()` now work with sparse tibbles. (#1364, #1366) -* `recipe()` now works with sparse matrices. (#1364) +* `recipe()` and `prep()` now work with sparse matrices. (#1364, #1368) # recipes 1.1.0 diff --git a/R/misc.R b/R/misc.R index f11ebb4d3..a86195fb9 100644 --- a/R/misc.R +++ b/R/misc.R @@ -660,6 +660,9 @@ validate_training_data <- function(x, rec, fresh, call = rlang::caller_env()) { } x <- rec$template } else { + if (is_sparse_matrix(x)) { + x <- sparsevctrs::coerce_to_sparse_tibble(x) + } if (!is_tibble(x)) { x <- as_tibble(x) } diff --git a/tests/testthat/test-sparsevctrs.R b/tests/testthat/test-sparsevctrs.R index 6db6576a5..9caad9c82 100644 --- a/tests/testthat/test-sparsevctrs.R +++ b/tests/testthat/test-sparsevctrs.R @@ -110,3 +110,26 @@ test_that("recipe() accepts sparse matrices", { ) }) +test_that("prep() accepts sparse matrices", { + skip_if_not_installed("modeldata") + + hotel_data <- sparse_hotel_rates() + + rec_spec <- recipe(avg_price_per_room ~ ., data = hotel_data) + + expect_no_error( + rec <- prep(rec_spec) + ) + + expect_true( + is_sparse_tibble(rec$template) + ) + + expect_no_error( + rec <- prep(rec_spec, training = hotel_data) + ) + + expect_true( + is_sparse_tibble(rec$template) + ) +})