Skip to content

Commit

Permalink
Merge branch 'issue-79' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
gavinsimpson committed Mar 11, 2024
2 parents 4eb66c4 + a12081e commit 4ad2c92
Show file tree
Hide file tree
Showing 6 changed files with 173 additions and 20 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ export(load_mgcv)
export(lp_matrix)
export(mh_draws)
export(model_concurvity)
export(model_constant)
export(model_edf)
export(model_vars)
export(n_smooths)
Expand Down
11 changes: 10 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# gratia 0.8.9.7
# gratia 0.8.9.8

## Breaking changes

Expand Down Expand Up @@ -244,6 +244,15 @@ eventual 1.0.0 release. These functions will become defunct by version 0.11.0 or
* `penalty()` has a default method that works with `s()`, `te()`, `t2()`, and
`ti()`, which create a smooth specification.

* `transform_fun()` gains argument `constant` to allow for the addition of a
constant value to objects (e.g. the estimate and confidence interval). This
enables a single `obj |> transform_fun(fun = exp, constant = 5)` instead of
separate calls to `add_constant()` and then `transform_fun()`. Part of the
discussion of #79

* `model_constant()` is a new function that simply extracts the first
coefficient from the estimated model.

## Bug fixes

* `link()`, `inv_link()`, and related family functions for the `ocat()` weren't
Expand Down
108 changes: 96 additions & 12 deletions R/utililties.R
Original file line number Diff line number Diff line change
Expand Up @@ -1117,6 +1117,7 @@ vars_from_label <- function(label) {
#'
#' @param object an object to apply the transform function to.
#' @param fun the function to apply.
#' @param constant numeric; a constant to apply before transformation.
#' @param ... additional arguments passed to methods.
#' @param column character; for the `"tbl_df"` method, which column to
#' transform.
Expand All @@ -1134,7 +1135,16 @@ vars_from_label <- function(label) {
#' @export
#' @importFrom dplyr mutate across
#' @importFrom tidyselect any_of
`transform_fun.smooth_estimates` <- function(object, fun = NULL, ...) {
`transform_fun.smooth_estimates` <- function(object, fun = NULL,
constant = NULL, ...) {
# If constant supplied, use it to transform est and the upper and lower
# interval
if (!is.null(constant)) {
object <- object |>
mutate(across(any_of(c("est", "lower_ci", "upper_ci",
".estimate", ".upper_ci", ".lower_ci")),
.fns = \(x) x + constant))
}
## If fun supplied, use it to transform est and the upper and lower interval
if (!is.null(fun)) {
fun <- match.fun(fun)
Expand All @@ -1156,14 +1166,21 @@ vars_from_label <- function(label) {
#' @rdname transform_fun
#' @export
#' @importFrom dplyr mutate across
#' @importFrom tidyselect all_of
`transform_fun.smooth_samples` <- function(object, fun = NULL, ...) {
## If fun supplied, use it to transform est and the upper and lower interval
#' @importFrom tidyselect all_of any_of
`transform_fun.smooth_samples` <- function(object, fun = NULL,
constant = NULL, ...) {
# If constant supplied, use it to transform value
if (!is.null(constant)) {
object <- object |>
mutate(across(all_of(c(".value")),
.fns = \(x) x + constant))
}
# If fun supplied, use it to transform value
if (!is.null(fun)) {
fun <- match.fun(fun)
object <- mutate(
object,
across(all_of("value"),
across(all_of(".value"),
.fns = fun
)
)
Expand All @@ -1176,7 +1193,15 @@ vars_from_label <- function(label) {
#' @export
#' @importFrom dplyr mutate across
#' @importFrom tidyselect all_of
`transform_fun.mgcv_smooth` <- function(object, fun = NULL, ...) {
`transform_fun.mgcv_smooth` <- function(object, fun = NULL,
constant = NULL, ...) {
# If constant supplied, use it to transform est and the upper and lower
# interval
if (!is.null(constant)) {
object <- object |>
mutate(across(all_of(c(".estimate", ".upper_ci", ".lower_ci")),
.fns = \(x) x + constant))
}
if (!is.null(fun)) {
fun <- match.fun(fun)
object <- mutate(
Expand All @@ -1194,7 +1219,15 @@ vars_from_label <- function(label) {
#' @export
#' @importFrom dplyr mutate across
#' @importFrom tidyselect all_of
`transform_fun.evaluated_parametric_term` <- function(object, fun = NULL, ...) {
`transform_fun.evaluated_parametric_term` <- function(object, fun = NULL,
constant = NULL, ...) {
# If constant supplied, use it to transform est and the upper and lower
# interval
if (!is.null(constant)) {
object <- object |>
mutate(across(all_of(c("est", "lower", "upper")),
.fns = \(x) x + constant))
}
## If fun supplied, use it to transform est and the upper and lower interval
if (!is.null(fun)) {
fun <- match.fun(fun)
Expand All @@ -1212,9 +1245,17 @@ vars_from_label <- function(label) {
#' @rdname transform_fun
#' @export
#' @importFrom dplyr mutate across
#' @importFrom tidyselect all_of
`transform_fun.parametric_effects` <- function(object, fun = NULL, ...) {
## If fun supplied, use it to transform est and the upper and lower interval
#' @importFrom tidyselect all_of any_of
`transform_fun.parametric_effects` <- function(object, fun = NULL,
constant = NULL, ...) {
# If constant supplied, use it to transform est and the upper and lower
# interval
if (!is.null(constant)) {
object <- object |>
mutate(across(any_of(c(".partial", ".lower_ci", ".upper_ci")),
.fns = \(x) x + constant))
}
# If fun supplied, use it to transform est and the upper and lower interval
if (!is.null(fun)) {
fun <- match.fun(fun)
object <- mutate(
Expand All @@ -1232,11 +1273,19 @@ vars_from_label <- function(label) {
#' @export
#' @importFrom dplyr mutate across
#' @importFrom tidyselect all_of
`transform_fun.tbl_df` <- function(object, fun = NULL, column = NULL, ...) {
`transform_fun.tbl_df` <- function(object, fun = NULL, column = NULL,
constant = NULL, ...) {
if (is.null(column)) {
stop("'column' to modify must be supplied.")
}
## If fun supplied, use it to transform est and the upper and lower interval
# If constant supplied, use it to transform est and the upper and lower
# interval
if (!is.null(constant)) {
object <- object |>
mutate(across(all_of(column),
.fns = \(x) x + constant))
}
# If fun supplied, use it to transform est and the upper and lower interval
if (!is.null(fun)) {
fun <- match.fun(fun)
object <- mutate(
Expand Down Expand Up @@ -1765,3 +1814,38 @@ reclass_scam_smooth <- function(smooth) {

sm_vars
}

#' Extract the model constant term
#'
#' Extracts the model constant term, the model intercept, from a fitted model
#' object.
#'
#' @param model a fitted model for which a `coef()` method exists
#'
#' @export
#' @importFrom stats coef
#' @examples
#' \dontshow{
#' op <- options(digits = 4)
#' }
#' load_mgcv()
#'
#' # simulate a small example
#' df <- data_sim("eg1")
#'
#' # fit the GAM
#' m <- gam(y ~ s(x0) + s(x1) + s(x2) + s(x3), data = df, method = "REML")
#'
#' # extract the estimate of the constant term
#' model_constant(m)
#' # same as coef(m)[1L]
#' coef(m)[1L]
#'
#' \dontshow{
#' options(op)
#' }
`model_constant` <- function(model) {
b <- coef(model)
b[1L] |>
unname()
}
36 changes: 36 additions & 0 deletions man/model_constant.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 8 additions & 6 deletions man/transform_fun.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 22 additions & 1 deletion tests/testthat/test-utilities.R
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,21 @@ test_that("transform_fun works for tbl", {
expect_silent(tbl <- transform_fun(su_eg1, fun = abs, column = "y"))
})

test_that("transform_fun works for smooth_estimates with constant", {
expect_silent(sm <- smooth_estimates(m_gam, smooth = "s(x1)"))
expect_silent(sm <- transform_fun(sm, fun = exp, constant = coef(m_gam)[1]))
})

test_that("transform_fun works for smooth_samples with constant", {
expect_silent(sm <- smooth_samples(m_gam, term = "s(x1)", n = 5))
expect_silent(sm <- transform_fun(sm, fun = exp, constant = coef(m_gam)[1]))
})

test_that("transform_fun works for tbl with constant", {
expect_silent(tbl <- transform_fun(su_eg1, fun = abs, column = "y",
constant = 5))
})

test_that("involves_ranef_smooth works", {
sm <- smooths(su_m_trivar_t2)
expect_false(involves_ranef_smooth(get_smooth(su_m_trivar_t2, sm[1])))
Expand Down Expand Up @@ -480,7 +495,6 @@ test_that("norm_minus_one_to_one works", {
expect_identical(range(x), c(-1, 1))
})


test_that("norm_minus_one_to_one works with NA", {
expect_silent(x <- norm_minus_one_to_one(c(0:10, NA)))
expect_equal(c(seq(-1, 1, by = 0.2), NA), x)
Expand All @@ -489,3 +503,10 @@ test_that("norm_minus_one_to_one works with NA", {
expect_identical(length(x), length(c(0:10, NA)))
expect_identical(range(x, na.rm = TRUE), c(-1, 1))
})

test_that("model_constant returns the intercept estimate", {
expect_silent(b <- model_constant(m_gam))
expect_type(b, "double")
expect_identical(b, unname(coef(m_gam)[1L]))
expect_named(b, expected = NULL)
})

0 comments on commit 4ad2c92

Please sign in to comment.