diff --git a/R/Lrnr_base.R b/R/Lrnr_base.R index 101b8eae..ebe07a02 100644 --- a/R/Lrnr_base.R +++ b/R/Lrnr_base.R @@ -177,21 +177,27 @@ Lrnr_base <- R6Class( )) } }, - base_predict = function(task = NULL) { + base_predict = function(task = NULL) { self$assert_trained() if (is.null(task)) { task <- private$.training_task } - + assert_that(is(task, "sl3_Task")) task <- self$subset_covariates(task) task <- self$process_formula(task) - + predictions <- private$.predict(task) - ncols <- ncol(predictions) - if (!is.null(ncols) && (ncols == 1)) { - predictions <- as.vector(predictions) + if(!is.null(ncols) && (ncols == 1)) { + if(is.data.table(predictions)) { + # if a data.table of packed predictions, return a matrix. + predictions <- as.matrix(predictions) + } + # if not packed predictions, return vector + if(!inherits(predictions[[1]], "packed_predictions")) { + predictions <- unlist(predictions) + } } return(predictions) }, diff --git a/R/Lrnr_cv.R b/R/Lrnr_cv.R index c626278e..282f8d9e 100644 --- a/R/Lrnr_cv.R +++ b/R/Lrnr_cv.R @@ -174,6 +174,8 @@ Lrnr_cv <- R6Class( predictions <- self$predict_fold(revere_task, fold_number) + # This might not be a matrix + predictions <- as.data.table(predictions) # TODO: make same fixes made to chain here if (nrow(revere_task$data) != nrow(predictions)) { # Gather validation indexes: @@ -371,7 +373,7 @@ Lrnr_cv <- R6Class( list( index = index, fold_index = rep(fold_index(), length(index)), - predictions = data.table(predictions) + predictions = as.data.table(predictions) ) } @@ -392,9 +394,14 @@ Lrnr_cv <- R6Class( predictions <- aorder(preds, order(results$index, results$fold_index)) - # don't convert to vector if learner is stack, as stack won't + + # don't convert to vector if learner is stack, as stack won't if ((ncol(predictions) == 1) && !inherits(self$params$learner, "Stack")) { - predictions <- unlist(predictions) + # if packed_predictions dont unlist + if(is.data.table(predictions)) predictions <- as.matrix(predictions) + if(!inherits(predictions[[1]], "packed_predictions")) { + predictions <- as.vector(predictions) + } } return(predictions) }, diff --git a/R/Lrnr_gam.R b/R/Lrnr_gam.R index c52ec44e..24515ca3 100644 --- a/R/Lrnr_gam.R +++ b/R/Lrnr_gam.R @@ -77,7 +77,7 @@ Lrnr_gam <- R6Class( } ), private = list( - .properties = c("continuous", "binomial"), + .properties = c("continuous", "binomial", "weights"), .train = function(task) { # load args args <- self$params @@ -87,6 +87,7 @@ Lrnr_gam <- R6Class( Y <- data.frame(outcome_type$format(task$Y)) colnames(Y) <- task$nodes$outcome args$data <- cbind(task$X, Y) + args$weights <- task$weights ## family if (is.null(args$family)) { if (outcome_type$type == "continuous") { diff --git a/R/survival_utils.R b/R/survival_utils.R index dd11429c..eb36e1ef 100644 --- a/R/survival_utils.R +++ b/R/survival_utils.R @@ -32,11 +32,16 @@ pooled_hazard_task <- function(task, trim = TRUE) { repeated_data <- underlying_data[index, ] new_folds <- origami::id_folds_to_folds(task$folds, index) - repeated_task <- task$next_in_chain( - column_names = column_names, - data = repeated_data, id = "id", - folds = new_folds - ) + nodes <- task$nodes + nodes$id <- "id" + repeated_task <- sl3_Task$new(repeated_data, column_names = column_names, nodes = task$nodes, folds = new_folds, outcome_levels = outcome_levels, outcome_type = task$outcome_type$type) + + # The below errors when used in CV due to the stored row index not being reset in next_in_chain. + #repeated_task <- task$next_in_chain( + # column_names = column_names, + #data = repeated_data, id = "id", + #folds = new_folds + #) # make bin indicators bin_number <- rep(level_index, each = task$nrow)