diff --git a/DESCRIPTION b/DESCRIPTION index 1ce588d1..00cd6cae 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -21,6 +21,7 @@ URL: https://github.com/amices/ggmice, https://amices.org/, https://amices.org/ggmice/ BugReports: https://github.com/amices/ggmice Imports: + broom, cli, dplyr, ggplot2, diff --git a/R/plot_variance.R b/R/plot_variance.R new file mode 100644 index 00000000..f332944d --- /dev/null +++ b/R/plot_variance.R @@ -0,0 +1,114 @@ +#' Plot the scaled between imputation variance for every cell as a heatmap +#' +#' This function plots the cell-level between imputation variance. The function +#' scales the variances column-wise, without centering cf. `base::scale(center = FALSE)` +#' and plots the data image as a heatmap. Darker red cells indicate more variance, +#' lighter cells indicate less variance. White cells represent observed cells or unobserved cells with zero between +#' imputation variance. +#' +#' @param data A multiply imputed object of class [`mice::mids`] or [`mice::mira`]. +#' @param grid Logical indicating whether grid lines should be displayed. +#' +#' @return An object of class `ggplot`. +#' @examples +#' imp <- mice::mice(mice::nhanes, printFlag = FALSE) +#' plot_variance(imp) +#' @export +plot_variance <- function(data, grid = TRUE) { + if (mice::is.mids(data)) { + if (data$m < 2) { + stop("The between inmputation variance cannot be computed if there are fewer than two imputations (m < 2).") + }} else if (mice::is.mira(data)) { + if(length(data$analyses) < 2) { + stop("The between inmputation variance cannot be computed if there are fewer than two imputations (m < 2).") + }} else { + if (!is.list(data)) stop("Argument 'data' not a list", call. = FALSE) + data <- mice::as.mira(data) + # stop( + # "Input is not a Multiply Imputed Data Set of class `mids`/ `mira`. \n + # Perhaps function mice::as.mids() can be of use?") + } + if (grid) { + gridcol <- "black" + } else { + gridcol <- NA + } + + if (mice::is.mids(data)) { + long <- mice::complete(data, "long") %>% + dplyr::mutate(dplyr::across(where(is.factor), as.numeric)) %>% + dplyr::select(-.imp) %>% + dplyr::group_by(.id) %>% + dplyr::summarise(dplyr::across(dplyr::everything(), stats::var)) %>% + dplyr::ungroup() %>% + dplyr::mutate(dplyr::across(.cols = -.id, ~ scale_above_zero(.))) %>% + tidyr::pivot_longer(cols = -.id) + + legend <- "Imputation variability*\n " + caption <- "*scaled cell-level between imputation variance" + + gg <- + ggplot2::ggplot(long, ggplot2::aes(name, .id, fill = value)) + + ggplot2::geom_tile(color = gridcol) + + ggplot2::scale_fill_gradient(low = "white", high = mice::mdc(2)) + + ggplot2::labs( + x = "Column name", + y = "Row number", + fill = legend, + caption = caption + ) + + ggplot2::scale_x_discrete(position = "top", expand = c(0, 0)) + + ggplot2::scale_y_continuous(trans = "reverse", expand = c(0, 0)) + + theme_minimice() + } + + if (mice::is.mira(data)) { + dv <- data[["analyses"]][[1]][["terms"]][[2]] + long <- purrr::map_dfr(1:length(data$analyses), ~ { + broom::augment(data$analyses[[.x]]) + }, .id="m") %>% + ### extract row numbers :/ :/ :/ + dplyr::mutate(row = rep(1:(nrow(.)/dplyr::n_distinct(m)), + dplyr::n_distinct(m))) %>% + tidyr::pivot_wider(id_cols = row, names_from = m, + values_from = c(dv, .fitted, .resid)) %>% + dplyr::rowwise() %>% + dplyr::summarize( + ### extract observed dat :/ :/ :/ + observed = ifelse(dplyr::n_distinct(dplyr::c_across( + dplyr::starts_with(rlang::as_string(dv))))==1, get(paste0(dv,"_1")), NA), + avg = mean(dplyr::c_across(dplyr::starts_with(".fitted"))), + vrn = stats::var(dplyr::c_across(dplyr::starts_with(".fitted"))) + ) + + legend <- "Imputation variability*\n " + caption <- + "*absolute prediction-level between imputation variance" + gg <- ggplot2::ggplot(long, ggplot2::aes(x = avg, y = observed, fill = vrn, size = vrn)) + + ggplot2::geom_point(color = gridcol, shape = 21) + + ggplot2::scale_fill_gradient(low = "white", high = mice::mdc(2), guide = "legend") + + ggplot2::labs( + x = paste("Average predicted", dv), + y = paste("Observed", dv), + fill = legend, + caption = caption + ) + + ggplot2::guides(size = FALSE, fill = "colorbar") + + theme_minimice() + } + + gg <- + gg + ggplot2::theme(panel.border = ggplot2::element_rect(fill = NA)) + + + # return the ggplot object + return(gg) +} + +# function to scale only non-zero values without centering +scale_above_zero <- function(x) { + x <- as.matrix(x) + x[x != 0] <- scale(x[x != 0], center = FALSE) + return(x) +} + diff --git a/man/plot_variance.Rd b/man/plot_variance.Rd new file mode 100644 index 00000000..1f245fca --- /dev/null +++ b/man/plot_variance.Rd @@ -0,0 +1,27 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/plot_variance.R +\name{plot_variance} +\alias{plot_variance} +\title{Plot the scaled between imputation variance for every cell as a heatmap} +\usage{ +plot_variance(data, grid = TRUE) +} +\arguments{ +\item{data}{A multiply imputed object of class \code{\link[mice:mids-class]{mice::mids}} or \code{\link[mice:mira-class]{mice::mira}}.} + +\item{grid}{Logical indicating whether grid lines should be displayed.} +} +\value{ +An object of class \code{ggplot}. +} +\description{ +This function plots the cell-level between imputation variance. The function +scales the variances column-wise, without centering cf. \code{base::scale(center = FALSE)} +and plots the data image as a heatmap. Darker red cells indicate more variance, +lighter cells indicate less variance. White cells represent observed cells or unobserved cells with zero between +imputation variance. +} +\examples{ +imp <- mice::mice(mice::nhanes, printFlag = FALSE) +plot_variance(imp) +}