Skip to content

Commit

Permalink
implement support for twlss models in fitted_values and fix some issu…
Browse files Browse the repository at this point in the history
…es with general LSS family support in that function
  • Loading branch information
gavinsimpson committed Oct 17, 2023
1 parent bc3202d commit 73cd3ba
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 22 deletions.
6 changes: 4 additions & 2 deletions R/family-utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -269,14 +269,16 @@ family_type.family <- function(object, ...) {
} else {
# linfo is ordered; 1: location; 2: scale or sigma, 3: shape, power, etc
# (check pi is right greek letter for zero-inflation! - YES)
# but for twlss, eta2 is actually for p (power)
lobj <- switch(parameter,
location = linfo[[1L]],
mu = linfo[[1L]],
scale = linfo[[2L]],
sigma = linfo[[2L]],
phi = linfo[[2L]], # scale parameter for twlss() gammals()
phi = if (family_name(family) == "twlss") {linfo[[3L]]} else
{linfo[[2L]]}, # scale parameter for twlss() gammals()
shape = linfo[[3L]],
power = linfo[[3L]], # power for twlss()
power = linfo[[2L]], # power for twlss()
xi = linfo[[3L]], # xi for gevlss()
pi = linfo[[2L]], # pi for zero-inflation
epsilon = linfo[[3L]], # skewness for shash
Expand Down
107 changes: 87 additions & 20 deletions R/fitted_values.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@
# Handle everything up to and including the extended families, but not more
fn <- family_type(object)
if (inherits(family(object), "general.family")) {
allowed <- c("gaulss", "gammals", "gumbls", "gevlss", "shash", "ziplss")
allowed <- c("gaulss", "gammals", "gumbls", "gevlss", "shash", "ziplss",
"twlss")
if (!fn %in% allowed) {
stop("General likelihood GAMs not yet supported.")
}
Expand All @@ -85,6 +86,7 @@
"shash" = post_link_funs(scale = exp, kurtosis = exp),
"ziplss" = post_link_funs(location = exp,
pi = inv_link(binomial("cloglog"))),
"twlss" = post_link_funs(power = twlss_theta_2_power, scale = exp),
post_link_funs())
# compute fitted values
fit <- fit_vals_fun(object, data = data, ci_level = ci_level,
Expand Down Expand Up @@ -169,19 +171,18 @@

# convert to the response scale if requested
if (identical(scale, "response")) {
ilink_loc <- inv_link(object, parameter = "location")
ilink_scl <- inv_link(object, parameter = "scale")
il <- lss_links(object, inverse = TRUE)

fit <- fit |>
mutate(across(all_of(c(".fitted", ".lower_ci", ".upper_ci")),
.fns = ~ case_match(.data$.parameter,
"location" ~ extra_fns[["location"]](ilink_loc(.x)),
"scale" ~ extra_fns[["scale"]](ilink_scl(.x)),
"shape" ~ extra_fns[["shape"]](ilink_scl(.x)),
"skewness" ~ extra_fns[["skewness"]](ilink_scl(.x)),
"kurtosis" ~ extra_fns[["scale"]](ilink_scl(.x)),
"power" ~ extra_fns[["scale"]](ilink_scl(.x)),
"pi" ~ extra_fns[["pi"]](ilink_scl(.x)))))
"location" ~ extra_fns[["location"]](il[["location"]](.x)),
"scale" ~ extra_fns[["scale"]](il[["scale"]](.x)),
"shape" ~ extra_fns[["shape"]](il[["shape"]](.x)),
"skewness" ~ extra_fns[["skewness"]](il[["skewness"]](.x)),
"kurtosis" ~ extra_fns[["kurtosis"]](il[["kurtosis"]](.x)),
"power" ~ extra_fns[["power"]](il[["power"]](.x)),
"pi" ~ extra_fns[["pi"]](il[["pi"]](.x)))))
}

fit
Expand All @@ -199,7 +200,7 @@ post_link_funs <- function(location = identity_fun,
pi = identity_fun) {

list(location = location, scale = scale, shape = shape, skewness = skewness,
power = power, pi = pi)
kurtosis = kurtosis, power = power, pi = pi)
}

#' General names of LSS parameters for each GAM family
Expand All @@ -208,16 +209,33 @@ post_link_funs <- function(location = identity_fun,
lss_parameters <- function(object) {
fn <- family_type(object)
par_names <- switch(fn,
"gaulss" = c("location", "scale"),
"gammals" = c("location", "scale"),
"gumbls" = c("location", "scale"),
"gevlss" = c("location", "scale", "shape"),
"shash" = c("location", "scale", "skewness", "kurtosis"),
"ziplss" = c("location", "pi"),
"location") # <- default, for most GAM families that's all there is
"gaulss" = c("location", "scale"),
"gammals" = c("location", "scale"),
"gumbls" = c("location", "scale"),
"gevlss" = c("location", "scale", "shape"),
"shash" = c("location", "scale", "skewness", "kurtosis"),
"ziplss" = c("location", "pi"),
"twlss" = c("location", "power", "scale"),
"location") # <- default, for most GAM families that's all there is
par_names
}

#' @importFrom purrr map
lss_links <- function(object, inverse = FALSE, which_eta = NULL) {
params <- lss_parameters(object)
param_nms <- c("location", "scale", "shape", "skewness", "kurtosis",
"power", "pi")
out <- rep(list(identity_fun), length(param_nms)) |>
setNames(param_nms)
funs <- purrr::map(params, .f = function(p, model, inverse, which_eta) {
extract_link(family(model), parameter = p, inverse = inverse,
which_eta = which_eta)
}, model = object, inverse = inverse, which_eta = which_eta) |>
setNames(params)
out[params] <- funs
out
}

# an identity function that simply returns input
identity_fun <- function(eta) {
eta
Expand Down Expand Up @@ -258,17 +276,65 @@ identity_fun <- function(eta) {
# convert to the response scale if requested
if (identical(scale, "response")) {
ilink_loc <- inv_link(object, parameter = "location")
ilink_scl <- inv_link(object, parameter = "scale")
ilink_pi <- inv_link(object, parameter = "pi")

fit <- fit |>
mutate(across(all_of(c(".fitted", ".lower_ci", ".upper_ci")),
.fns = ~ case_match(.data$.parameter,
"location" ~ extra_fns[["location"]](ilink_loc(.x)),
"pi" ~ extra_fns[["pi"]](ilink_scl(.x)))))
"pi" ~ extra_fns[["pi"]](ilink_pi(.x)))))
}

fit
}

#' @importFrom dplyr mutate across case_match row_number
#' @importFrom tidyr pivot_longer
#' @importFrom tibble as_tibble add_column
`fit_vals_twlss` <- function(object, data, ci_level = 0.95,
scale = "response", extra_fns = post_link_funs(), ...) {

crit <- coverage_normal(ci_level)
# get the fitted values for data
fv <- predict(object, newdata = data, ..., type = "link",
se.fit = TRUE)
std_err <- fv[[2L]]
fv <- fv[[1]]
colnames(std_err) <- colnames(fv) <- lss_parameters(object)
# convert fv to tibble then long format
fv <- fv |>
as_tibble() |>
mutate(.row = row_number()) |>
relocate(".row", .before = 1L) |>
tidyr::pivot_longer(!matches("\\.row"), values_to = ".fitted",
names_to = ".parameter")
# convert fv to tibble then long format
std_err <- std_err |>
as_tibble() |>
tidyr::pivot_longer(everything(), values_to = ".std_err",
names_to = ".parameter")
# bind .std_err to fv...
fit <- fv |>
add_column(.se = pull(std_err, ".std_err")) |>
# ...and compute interval
mutate(.lower_ci = .data$.fitted + (crit * .data$.se),
.upper_ci = .data$.fitted - (crit * .data$.se))

# convert to the response scale if requested
if (identical(scale, "response")) {
il <- lss_links(object, inverse = TRUE)
bounds <- get_tw_bounds(object)

fit <- fit |>
mutate(across(all_of(c(".fitted", ".lower_ci", ".upper_ci")),
.fns = ~ case_match(.data$.parameter,
"location" ~ extra_fns[["location"]](il[["location"]](.x)),
"power" ~ extra_fns[["power"]](il[["power"]](.x),
a = bounds[1], b = bounds[2]),
"scale" ~ extra_fns[["scale"]](il[["scale"]](.x)))))
}

fit
}

#' @importFrom dplyr bind_rows relocate
Expand Down Expand Up @@ -345,6 +411,7 @@ identity_fun <- function(eta) {
fam == "gevlss" ~ "general_lss",
fam == "shash" ~ "general_lss",
fam == "ziplss" ~ "ziplss",
fam == "twlss" ~ "twlss",
.default = "default"
)
get(paste0("fit_vals_", fam), mode = "function")
Expand Down
23 changes: 23 additions & 0 deletions R/utililties.R
Original file line number Diff line number Diff line change
Expand Up @@ -1393,3 +1393,26 @@ vars_from_label <- function(label) {
}
model$null.deviance
}

# function to return the vector of boundary points for power parameter
get_tw_bounds <- function(model) {
fam <- family_name(model)
if (fam != "twlss") {
stop("'model' wasn't fitted with 'twlss()' family.",
call. = FALSE)
}
rfun <- family(model)$residuals
a <- get(".a", envir = environment(rfun))
b <- get(".b", envir = environment(rfun))
c(a, b)
}

# function to convert vector of theta values to power values for twlss family
twlss_theta_2_power <- function(theta, a, b) {
i <- theta > 0
exp_theta_pos <- exp(-theta[i])
exp_theta_neg <- exp(theta[!i])
theta[i] <- (b + a * exp_theta_pos) / (1 + exp_theta_pos)
theta[!i] <- (b * exp_theta_neg + a) / (1 + exp_theta_neg)
theta
}

0 comments on commit 73cd3ba

Please sign in to comment.