From da8ef50f85979125c2ca65a196308a95c5a3f47e Mon Sep 17 00:00:00 2001 From: rachaelvp Date: Wed, 22 Sep 2021 10:19:40 -0700 Subject: [PATCH] test with gam formula --- tests/testthat/test-gam.R | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/testthat/test-gam.R b/tests/testthat/test-gam.R index 6e0eeb06..e39df272 100644 --- a/tests/testthat/test-gam.R +++ b/tests/testthat/test-gam.R @@ -46,3 +46,20 @@ test_that("Lrnr_gam without specifying formula gives the predictions ## test equivalence of prediction from Lrnr_svm and svm::svm expect_equal(prd_lrnr_gam, prd_gam) }) + + +test_that("Lrnr_gam specifying complex formula gives the predictions that match those from gam", { + set.seed(256) + dat <- mgcv::gamSim(1, n = 400, dist = "normal", scale = 2) + task <- make_sl3_Task( + data = dat, outcome = "y", + covariates = c("x0", "x1", "x2", "x3", "f", "f0", "f1", "f2", "f3") + ) + lrnr_gam <- Lrnr_gam$new(formula = y ~ te(x0, x1, k = 7) + s(x2) + s(x3), method = "REML") + fit <- lrnr_gam$train(task) + pred_sl3 <- fit$predict(task) + + bt <- mgcv::gam(y ~ te(x0, x1, k = 7) + s(x2) + s(x3), data = dat, method = "REML") + pred_mgcv <- as.numeric(predict(bt)) + expect_equal(pred_sl3, pred_mgcv) +})