Skip to content

Commit

Permalink
add option to make square
Browse files Browse the repository at this point in the history
  • Loading branch information
pepijnvink committed Nov 16, 2023
1 parent d49e9eb commit e8b0d0b
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 6 deletions.
11 changes: 9 additions & 2 deletions R/plot_miss.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#' @param border Logical indicating whether borders should be present between tiles.
#' @param row.breaks Optional numeric input specifying the number of breaks to be visualized on the y axis.
#' @param ordered Logical indicating whether rows should be ordered according to their pattern.
#' @param square Logical indicating whether the plot tiles should be squares, defaults to squares.
#'
#' @return An object of class [ggplot2::ggplot].
#'
Expand All @@ -17,6 +18,7 @@ plot_miss <-
vrb = "all",
border = FALSE,
row.breaks = nrow(data),

Check warning on line 20 in R/plot_miss.R

View workflow job for this annotation

GitHub Actions / lint

file=R/plot_miss.R,line=20,col=12,[object_name_linter] Variable and function name style should match snake_case or symbols.
square = TRUE,
ordered = FALSE) {
# input processing
if (is.matrix(data) && ncol(data) > 1) {
Expand All @@ -40,7 +42,7 @@ plot_miss <-
if(ordered){

Check warning on line 42 in R/plot_miss.R

View workflow job for this annotation

GitHub Actions / lint

file=R/plot_miss.R,line=42,col=7,[spaces_left_parentheses_linter] Place a space before left parenthesis, except in a function call.

Check warning on line 42 in R/plot_miss.R

View workflow job for this annotation

GitHub Actions / lint

file=R/plot_miss.R,line=42,col=16,[brace_linter] There should be a space before an opening curly brace.

Check warning on line 42 in R/plot_miss.R

View workflow job for this annotation

GitHub Actions / lint

file=R/plot_miss.R,line=42,col=16,[paren_body_linter] There should be a space between a right parenthesis and a body expression.
# extract md.pattern matrix
mdpat <- mice::md.pattern(data, plot = FALSE) %>%
head(., -1)
utils::head(., -1)

Check warning on line 45 in R/plot_miss.R

View workflow job for this annotation

GitHub Actions / lint

file=R/plot_miss.R,line=45,col=21,[object_usage_linter] no visible binding for global variable '.'
# save frequency of patterns
freq.pat <- rownames(mdpat) %>%
as.numeric()
Expand Down Expand Up @@ -103,13 +105,18 @@ plot_miss <-
fill = "",
alpha = ""
) +
ggplot2::coord_cartesian(expand = FALSE) +
theme_minimice()
# additional arguments
if(border){
gg <- gg + ggplot2::geom_tile(color = "black")
} else{
gg <- gg + ggplot2::geom_tile()
}
if (square) {
gg <- gg + ggplot2::coord_fixed(expand = FALSE)
} else {
gg <- gg + ggplot2::coord_cartesian(expand = FALSE)
}
if(ordered){
gg <- gg +
ggplot2::theme(axis.text.y = ggplot2::element_blank(),
Expand Down
71 changes: 71 additions & 0 deletions R/plot_variance.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#' Plot the scaled between imputation variance for every cell as a heatmap
#'
#' This function plots the cell-level between imputation variance. The function
#' scales the variances column-wise, without centering cf. `base::scale(center = FALSE)`
#' and plots the data image as a heatmap. Darker red cells indicate more variance,
#' lighter cells indicate less variance. White cells represent observed cells or unobserved cells with zero between
#' imputation variance.
#'
#' @param data A package `mice` generated multiply imputed data set of class
#' `mids`. Non-`mids` objects that have not been generated with `mice::mice()`
#' can be converted through a pipeline with `mice::as.mids()`.
#' @param grid Logical indicating whether grid lines should be displayed.
#'
#' @return An object of class `ggplot`.
#' @examples
#' imp <- mice::mice(mice::nhanes, printFlag = FALSE)
#' plot_variance(imp)
#' @export
plot_variance <- function(data, grid = TRUE) {
verify_data(data, imp = TRUE)
if (data$m < 2) {
cli::cli_abort(
c(
"The between imputation variance cannot be computed if there are fewer than two imputations (m < 2).",
"i" = "Please provide an object with 2 or more imputations"
)
)
}
if (grid) {
gridcol <- "black"
} else {
gridcol <- NA
}

gg <- mice::complete(data, "long") %>%
dplyr::mutate(dplyr::across(where(is.factor), as.numeric)) %>%
dplyr::select(-.imp) %>%
dplyr::group_by(.id) %>%
dplyr::summarise(dplyr::across(dplyr::everything(), stats::var)) %>%
dplyr::ungroup() %>%
dplyr::mutate(dplyr::across(.cols = -.id, ~ scale_above_zero(.))) %>%
tidyr::pivot_longer(cols = -.id) %>%
ggplot2::ggplot(ggplot2::aes(name, .id, fill = value)) +
ggplot2::geom_tile(color = gridcol) +
ggplot2::scale_fill_gradient(low = "white", high = mice::mdc(2)) +
ggplot2::labs(
x = "Column name",
y = "Row number",
fill = "Imputation variability*
",
caption = "*scaled cell-level between imputation variance"
) + # "Cell-level between imputation\nvariance (scaled)\n\n"
ggplot2::scale_x_discrete(position = "top", expand = c(0, 0)) +
ggplot2::scale_y_continuous(trans = "reverse", expand = c(0, 0)) +
theme_minimice()

if (!grid) {
gg <-
gg + ggplot2::theme(panel.border = ggplot2::element_rect(fill = NA))
}

# return the ggplot object
return(gg)
}

# function to scale only non-zero values without centering
scale_above_zero <- function(x) {
x <- as.matrix(x)
x[x != 0] <- scale(x[x != 0], center = FALSE)
return(x)
}
3 changes: 3 additions & 0 deletions man/plot_miss.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

29 changes: 29 additions & 0 deletions man/plot_variance.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 2 additions & 4 deletions tests/testthat/test-plot_miss.R.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ dat <- mice::nhanes
# tests
test_that("plot_miss produces plot", {
expect_s3_class(plot_miss(dat), "ggplot")
expect_s3_class(plot_miss(dat), "ggplot")
expect_s3_class(plot_miss(dat, border = TRUE, ordered = T, row.breaks = 25, square = TRUE), "ggplot")
expect_s3_class(plot_miss(cbind(dat, "testvar" = NA)), "ggplot")
})

Expand All @@ -17,10 +17,8 @@ test_that("plot_miss works with different inputs", {


test_that("plot_miss with incorrect argument(s)", {
expect_output(plot_miss(na.omit(dat)))
expect_s3_class(plot_miss(na.omit(dat)), "ggplot")
expect_error(plot_miss("test"))
expect_error(plot_miss(dat, vrb = "test"))
expect_error(plot_miss(dat, cluster = "test"))
expect_error(plot_miss(cbind(dat, .x = NA)))
expect_error(plot_miss(dat, npat = "test"))
})

0 comments on commit e8b0d0b

Please sign in to comment.