diff --git a/R/plot_pred.R b/R/plot_pred.R index 2cc425c7..826b0156 100644 --- a/R/plot_pred.R +++ b/R/plot_pred.R @@ -1,6 +1,6 @@ #' Plot the predictor matrix of an imputation model #' -#' @param data A predictor matrix for `mice`, typically generated with [mice::make.predictorMatrix] or [mice::quickpred]. +#' @param data A predictor matrix for `mice`, typically generated with [mice::make.predictorMatrix] or [mice::quickpred], or an object of class [`mice::mids`]. #' @param vrb String, vector, or unquoted expression with variable name(s), default is "all". #' @param method Character string or vector with imputation methods. #' @param label Logical indicating whether predictor matrix values should be displayed. @@ -20,28 +20,38 @@ plot_pred <- label = TRUE, square = TRUE, rotate = FALSE) { - verify_data(data, pred = TRUE) + verify_data(data, pred = TRUE, imp = TRUE) + if (mice::is.mids(data)) { + method <- data$method + data <- data$predictorMatrix + } p <- nrow(data) - if (!is.null(method) && is.character(method)) { - if (length(method) == 1) { - method <- rep(method, p) + if (!mice::is.mids(data)) { + if (!is.null(method) && is.character(method)) { + if (length(method) == 1) { + method <- rep(method, p) + } + if (length(method) == p) { + ylabel <- "Imputation method" + } } - if (length(method) == p) { - ylabel <- "Imputation method" + if (is.null(method)) { + method <- rep("", p) + ylabel <- "" + } + if (!is.character(method) || length(method) != p) { + cli::cli_abort("Method should be NULL or a character string or vector (of length 1 or `ncol(data)`).") } - } - if (is.null(method)) { - method <- rep("", p) - ylabel <- "" - } - if (!is.character(method) || length(method) != p) { - cli::cli_abort("Method should be NULL or a character string or vector (of length 1 or `ncol(data)`).") } vrb <- substitute(vrb) if (vrb[1] == "all") { vrb <- names(data) } else { - vrb <- names(dplyr::select(as.data.frame(data), {{vrb}})) + vrb <- names(dplyr::select(as.data.frame(data), { + { + vrb + } + })) } vrbs <- row.names(data) long <- data.frame( diff --git a/R/utils.R b/R/utils.R index 67fcf484..076569ca 100644 --- a/R/utils.R +++ b/R/utils.R @@ -62,7 +62,7 @@ verify_data <- function(data, ) } } - if (imp && !df) { + if (imp && !df && !pred) { if (!mice::is.mids(data)) { cli::cli_abort( c( @@ -73,7 +73,18 @@ verify_data <- function(data, ) } } - if (pred) { + if (imp && pred){ + if (!(is.matrix(data) || mice::is.mids(data))) { + cli::cli_abort( + c( + "The 'data' argument requires an object of class 'matrix', or 'mids'.", + "i" = "Input object is of class {class(data)}." + ), + call. = FALSE + ) + } + } + if (pred && !imp) { if (!is.matrix(data)) { cli::cli_abort( c( diff --git a/man/plot_pred.Rd b/man/plot_pred.Rd index c71970a6..2d5c98f9 100644 --- a/man/plot_pred.Rd +++ b/man/plot_pred.Rd @@ -14,7 +14,7 @@ plot_pred( ) } \arguments{ -\item{data}{A predictor matrix for \code{mice}, typically generated with \link[mice:make.predictorMatrix]{mice::make.predictorMatrix} or \link[mice:quickpred]{mice::quickpred}.} +\item{data}{A predictor matrix for \code{mice}, typically generated with \link[mice:make.predictorMatrix]{mice::make.predictorMatrix} or \link[mice:quickpred]{mice::quickpred}, or an object of class \code{\link[mice:mids-class]{mice::mids}}.} \item{vrb}{String, vector, or unquoted expression with variable name(s), default is "all".} diff --git a/tests/testthat/test-plot_pred.R b/tests/testthat/test-plot_pred.R index 57d72649..ff0c9aa1 100644 --- a/tests/testthat/test-plot_pred.R +++ b/tests/testthat/test-plot_pred.R @@ -1,6 +1,7 @@ # create test objects dat <- mice::nhanes pred <- mice::quickpred(dat) +imp <- mice::mice(dat, printFlag = FALSE) # tests test_that("plot_pred creates ggplot object", { @@ -18,6 +19,7 @@ test_that("plot_pred creates ggplot object", { expect_s3_class(plot_pred(rbind( cbind(pred, "with space" = 0), "with space" = 0 )), "ggplot") + expect_s3_class(plot_pred(imp, vrb = c("age", "bmi")), "ggplot") }) test_that("plot_pred with incorrect argument(s)", {