Skip to content

Commit

Permalink
add plot_pred compatibility with mids object and add test
Browse files Browse the repository at this point in the history
Also edit verify_data function to work with this.
  • Loading branch information
pepijnvink committed Dec 7, 2023
1 parent 6a69f13 commit e367963
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 18 deletions.
40 changes: 25 additions & 15 deletions R/plot_pred.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#' Plot the predictor matrix of an imputation model
#'
#' @param data A predictor matrix for `mice`, typically generated with [mice::make.predictorMatrix] or [mice::quickpred].
#' @param data A predictor matrix for `mice`, typically generated with [mice::make.predictorMatrix] or [mice::quickpred], or an object of class [`mice::mids`].

Check warning on line 3 in R/plot_pred.R

View workflow job for this annotation

GitHub Actions / lint

file=R/plot_pred.R,line=3,col=151,[line_length_linter] Lines should not be more than 150 characters. This line is 159 characters.
#' @param vrb String, vector, or unquoted expression with variable name(s), default is "all".
#' @param method Character string or vector with imputation methods.
#' @param label Logical indicating whether predictor matrix values should be displayed.
Expand All @@ -20,28 +20,38 @@ plot_pred <-
label = TRUE,
square = TRUE,
rotate = FALSE) {
verify_data(data, pred = TRUE)
verify_data(data, pred = TRUE, imp = TRUE)
if (mice::is.mids(data)) {
method <- data$method
data <- data$predictorMatrix
}
p <- nrow(data)
if (!is.null(method) && is.character(method)) {
if (length(method) == 1) {
method <- rep(method, p)
if (!mice::is.mids(data)) {
if (!is.null(method) && is.character(method)) {
if (length(method) == 1) {
method <- rep(method, p)
}
if (length(method) == p) {
ylabel <- "Imputation method"
}
}
if (length(method) == p) {
ylabel <- "Imputation method"
if (is.null(method)) {
method <- rep("", p)
ylabel <- ""
}
if (!is.character(method) || length(method) != p) {
cli::cli_abort("Method should be NULL or a character string or vector (of length 1 or `ncol(data)`).")
}
}
if (is.null(method)) {
method <- rep("", p)
ylabel <- ""
}
if (!is.character(method) || length(method) != p) {
cli::cli_abort("Method should be NULL or a character string or vector (of length 1 or `ncol(data)`).")
}
vrb <- substitute(vrb)
if (vrb[1] == "all") {
vrb <- names(data)
} else {
vrb <- names(dplyr::select(as.data.frame(data), {{vrb}}))
vrb <- names(dplyr::select(as.data.frame(data), {
{

Check warning on line 51 in R/plot_pred.R

View workflow job for this annotation

GitHub Actions / lint

file=R/plot_pred.R,line=51,col=9,[brace_linter] Opening curly braces should never go on their own line and should always be followed by a new line.
vrb
}
}))
}
vrbs <- row.names(data)
long <- data.frame(
Expand Down
15 changes: 13 additions & 2 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ verify_data <- function(data,
)
}
}
if (imp && !df) {
if (imp && !df && !pred) {
if (!mice::is.mids(data)) {
cli::cli_abort(
c(
Expand All @@ -73,7 +73,18 @@ verify_data <- function(data,
)
}
}
if (pred) {
if (imp && pred){

Check warning on line 76 in R/utils.R

View workflow job for this annotation

GitHub Actions / lint

file=R/utils.R,line=76,col=19,[brace_linter] There should be a space before an opening curly brace.
if (!(is.matrix(data) || mice::is.mids(data))) {
cli::cli_abort(
c(
"The 'data' argument requires an object of class 'matrix', or 'mids'.",
"i" = "Input object is of class {class(data)}."
),
call. = FALSE
)
}
}
if (pred && !imp) {
if (!is.matrix(data)) {
cli::cli_abort(
c(
Expand Down
2 changes: 1 addition & 1 deletion man/plot_pred.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions tests/testthat/test-plot_pred.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# create test objects
dat <- mice::nhanes
pred <- mice::quickpred(dat)
imp <- mice::mice(dat, printFlag = FALSE)

# tests
test_that("plot_pred creates ggplot object", {
Expand All @@ -18,6 +19,7 @@ test_that("plot_pred creates ggplot object", {
expect_s3_class(plot_pred(rbind(
cbind(pred, "with space" = 0), "with space" = 0
)), "ggplot")
expect_s3_class(plot_pred(imp, vrb = c("age", "bmi")), "ggplot")
})

test_that("plot_pred with incorrect argument(s)", {
Expand Down

0 comments on commit e367963

Please sign in to comment.