Skip to content

Commit

Permalink
initial support for fit size reduction
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremyrcoyle committed Oct 1, 2021
1 parent 20834ae commit be37523
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 5 deletions.
51 changes: 51 additions & 0 deletions R/Lrnr_base.R
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,9 @@ Lrnr_base <- R6Class(
}
new_object <- self$clone() # copy parameters, and whatever else
new_object$set_train(fit_object, task)
if (getOption("sl3.reduce_fit")) {
new_object$reduce_fit(check_preds = FALSE)
}
return(new_object)
},
set_train = function(fit_object, training_task) {
Expand Down Expand Up @@ -335,6 +338,53 @@ Lrnr_base <- R6Class(
} else {
return(task)
}
},
reduce_fit = function(fit_object = NULL, check_preds = TRUE, set_train = TRUE) {
if (is.null(fit_object)) {
fit_object <- self$fit_object
}
if (check_preds) {
preds_full <- self$predict(task)
}


# try reducing the size
size_full <- true_obj_size(fit_object)

# see what's taking up the space
# element_sizes <- sapply(fo, true_obj_size)
# ranked <- sort(element_sizes/size_full, decreasing = TRUE)

# by default, drop out call
# within(fit_object, rm(private$.fit_can_remove))
keep <- setdiff(names(fit_object), private$.fit_can_remove)

# gotta preserve the attributes (not sure why they're getting dropped)
attrs <- attributes(fit_object)
attrs$names <- keep
reduced <- fit_object[keep]
attributes(reduced) <- attrs
fit_object <- reduced
size_reduced <- true_obj_size(fit_object)
reduction_percent <- 1 - size_reduced / size_full

if (getOption("sl3.verbose")) {
message(sprintf("Fit object size reduced %0.0f%%", 100 * reduction_percent))
}


if (set_train) {
self$set_train(fit_object, self$training_task)
}


# verify prediction still works
if (check_preds) {
preds_reduced <- self$predict(task)
assert_that(all.equal(preds_full, preds_reduced))
}

return(fit_object)
}
),
active = list(
Expand Down Expand Up @@ -399,6 +449,7 @@ Lrnr_base <- R6Class(
.required_packages = NULL,
.properties = list(),
.custom_chain = NULL,
.fit_can_remove = c("call"),
.train_sublearners = function(task) {
# train sublearners here
return(NULL)
Expand Down
1 change: 1 addition & 0 deletions R/Lrnr_glm_fast.R
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ Lrnr_glm_fast <- R6Class(
}
return(predictions)
},
.fit_can_remove = c("XTX"),
.required_packages = c("speedglm")
)
)
3 changes: 1 addition & 2 deletions R/Lrnr_hal9001.R
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ Lrnr_hal9001 <- R6Class(

return(fit_object)
},

.predict = function(task = NULL) {
predictions <- stats::predict(
self$fit_object,
Expand All @@ -111,7 +110,7 @@ Lrnr_hal9001 <- R6Class(
}
return(predictions)
},

.fit_can_remove = c("lasso_fit", "x_basis"),
.required_packages = c("hal9001", "glmnet")
)
)
1 change: 1 addition & 0 deletions R/Lrnr_xgboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ Lrnr_xgboost <- R6Class(

return(predictions)
},
.fit_can_remove = c("raw", "call"),
.required_packages = c("xgboost")
)
)
10 changes: 10 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,16 @@ true_obj_size <- function(obj) {
length(serialize(obj, NULL))
}

#' @keywords internal
check_fit_sizes <- function(fit) {
fo <- fit$fit_object
# see what's taking up the space
element_sizes <- sapply(fo, true_obj_size)
ranked <- sort(element_sizes / sum(element_sizes), decreasing = TRUE)

return(ranked)
}

################################################################################

#' Drop components from learner fits
Expand Down
6 changes: 3 additions & 3 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ sl3Options <- function(o, value) {
}
if (is.null(value)) {
res[o] <- list(NULL)
}
else {
} else {
res[[o]] <- value
}
options(res[o])
Expand All @@ -62,7 +61,8 @@ sl3Options <- function(o, value) {
"sl3.pcontinuous" = 0.05,
"sl3.max_p_missing" = 0.5,
"sl3.transform.offset" = TRUE,
"sl3.enable.future" = TRUE
"sl3.enable.future" = TRUE,
"sl3.reduce_fit" = FALSE
)
# for (i in setdiff(names(opts),names(options()))) {
# browser()
Expand Down
44 changes: 44 additions & 0 deletions tests/testthat/test-reduce_fit.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@

set.seed(1234)

# TODO: maybe check storage at different n to get rate
n <- 1e3
p <- 100
# these two define the DGP (randomly)
p_X <- runif(p, 0.2, 0.8)
beta <- rnorm(p)

# simulate from the DGP
X <- sapply(p_X, function(p_Xi) rbinom(n, 1, p_Xi))
p_Yx <- plogis(X %*% beta)
Y <- rbinom(n, 1, p_Yx)
data <- data.table(X, Y)

# generate the sl3 task and learner
outcome <- "Y"
covariates <- setdiff(names(data), outcome)
task <- make_sl3_Task(data, covariates, outcome)

options(sl3.verbose = TRUE)
options(sl3.reduce_fit = TRUE)
test_reduce_fit <- function(learner) {
fit <- learner$train(task)
print(sl3:::check_fit_sizes(fit))
if (!getOption("sl3.reduce_fit")) {
# if we aren't automatically reducing, do it manually
fit_object <- fit$reduce_fit()
}

still_present <- intersect(
names(fit$fit_object),
fit$.__enclos_env__$private$.fit_can_remove
)

expect_equal(length(still_present), 0)
}

test_reduce_fit(make_learner(Lrnr_glmnet))
test_reduce_fit(make_learner(Lrnr_ranger))
test_reduce_fit(make_learner(Lrnr_glm_fast))
test_reduce_fit(make_learner(Lrnr_xgboost))
test_reduce_fit(make_learner(Lrnr_hal9001))

0 comments on commit be37523

Please sign in to comment.