Skip to content

Commit

Permalink
[R] Move gc data protection to R side (#11104)
Browse files Browse the repository at this point in the history
  • Loading branch information
david-cortes authored Jan 2, 2025
1 parent 57ce062 commit 92d1bfe
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 14 deletions.
33 changes: 28 additions & 5 deletions R-package/R/xgb.DMatrix.R
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,9 @@ xgb.QuantileDMatrix <- function(
)
data_iterator <- .single.data.iterator(iterator_env)

env_keep_alive <- new.env()
env_keep_alive$keepalive <- NULL

# Note: the ProxyDMatrix has its finalizer assigned in the R externalptr
# object, but that finalizer will only be called once the object is
# garbage-collected, which doesn't happen immediately after it goes out
Expand All @@ -363,9 +366,10 @@ xgb.QuantileDMatrix <- function(
.Call(XGDMatrixFree_R, proxy_handle)
})
iterator_next <- function() {
return(xgb.ProxyDMatrix(proxy_handle, data_iterator))
return(xgb.ProxyDMatrix(proxy_handle, data_iterator, env_keep_alive))
}
iterator_reset <- function() {
env_keep_alive$keepalive <- NULL
return(data_iterator$f_reset(iterator_env))
}
calling_env <- environment()
Expand Down Expand Up @@ -553,7 +557,8 @@ xgb.DataBatch <- function(
}

# This is only for internal usage, class is not exposed to the user.
xgb.ProxyDMatrix <- function(proxy_handle, data_iterator) {
xgb.ProxyDMatrix <- function(proxy_handle, data_iterator, env_keep_alive) {
env_keep_alive$keepalive <- NULL
lst <- data_iterator$f_next(data_iterator$env)
if (is.null(lst)) {
return(0L)
Expand All @@ -566,13 +571,19 @@ xgb.ProxyDMatrix <- function(proxy_handle, data_iterator) {
stop("Either one of 'group' or 'qid' should be NULL")
}
if (is.data.frame(lst$data)) {
tmp <- .process.df.for.dmatrix(lst$data, lst$feature_types)
data <- lst$data
lst$data <- NULL
tmp <- .process.df.for.dmatrix(data, lst$feature_types)
lst$feature_types <- tmp$feature_types
data <- NULL
env_keep_alive$keepalive <- tmp
.Call(XGProxyDMatrixSetDataColumnar_R, proxy_handle, tmp$lst)
} else if (is.matrix(lst$data)) {
env_keep_alive$keepalive <- lst
.Call(XGProxyDMatrixSetDataDense_R, proxy_handle, lst$data)
} else if (inherits(lst$data, "dgRMatrix")) {
tmp <- list(p = lst$data@p, j = lst$data@j, x = lst$data@x, ncol = ncol(lst$data))
env_keep_alive$keepalive <- tmp
.Call(XGProxyDMatrixSetDataCSR_R, proxy_handle, tmp)
} else {
stop("'data' has unsupported type.")
Expand Down Expand Up @@ -712,14 +723,23 @@ xgb.ExtMemDMatrix <- function(
cache_prefix <- path.expand(cache_prefix)
nthread <- as.integer(NVL(nthread, -1L))

# This environment keeps the batch data alive (protected from the garbage
# collector) after it has been set in the proxy DMatrix. The data held here
# (under the name 'keepalive') must be unset, making it eligible for garbage
# collection again, before the start of each data iteration batch and
# during each iterator reset.
env_keep_alive <- new.env()
env_keep_alive$keepalive <- NULL

proxy_handle <- .make.proxy.handle()
on.exit({
.Call(XGDMatrixFree_R, proxy_handle)
})
iterator_next <- function() {
return(xgb.ProxyDMatrix(proxy_handle, data_iterator))
return(xgb.ProxyDMatrix(proxy_handle, data_iterator, env_keep_alive))
}
iterator_reset <- function() {
env_keep_alive$keepalive <- NULL
return(data_iterator$f_reset(data_iterator$env))
}
calling_env <- environment()
Expand Down Expand Up @@ -779,14 +799,17 @@ xgb.QuantileDMatrix.from_iterator <- function( # nolint

nthread <- as.integer(NVL(nthread, -1L))

env_keep_alive <- new.env()
env_keep_alive$keepalive <- NULL
proxy_handle <- .make.proxy.handle()
on.exit({
.Call(XGDMatrixFree_R, proxy_handle)
})
iterator_next <- function() {
return(xgb.ProxyDMatrix(proxy_handle, data_iterator))
return(xgb.ProxyDMatrix(proxy_handle, data_iterator, env_keep_alive))
}
iterator_reset <- function() {
env_keep_alive$keepalive <- NULL
return(data_iterator$f_reset(data_iterator$env))
}
calling_env <- environment()
Expand Down
12 changes: 3 additions & 9 deletions R-package/src/xgboost_R.cc
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,6 @@ XGB_DLL SEXP XGProxyDMatrixSetDataDense_R(SEXP handle, SEXP R_mat) {
{
std::string array_str = MakeArrayInterfaceFromRMat(R_mat);
res_code = XGProxyDMatrixSetDataDense(proxy_dmat, array_str.c_str());
R_SetExternalPtrProtected(handle, R_mat);
}
CHECK_CALL(res_code);
R_API_END();
Expand All @@ -708,7 +707,6 @@ XGB_DLL SEXP XGProxyDMatrixSetDataCSR_R(SEXP handle, SEXP lst) {
array_str_indices.c_str(),
array_str_data.c_str(),
ncol);
R_SetExternalPtrProtected(handle, lst);
}
CHECK_CALL(res_code);
R_API_END();
Expand All @@ -722,7 +720,6 @@ XGB_DLL SEXP XGProxyDMatrixSetDataColumnar_R(SEXP handle, SEXP lst) {
{
std::string sinterface = MakeArrayInterfaceFromRDataFrame(lst);
res_code = XGProxyDMatrixSetDataColumnar(proxy_dmat, sinterface.c_str());
R_SetExternalPtrProtected(handle, lst);
}
CHECK_CALL(res_code);
R_API_END();
Expand All @@ -736,20 +733,17 @@ struct _RDataIterator {
SEXP f_reset;
SEXP calling_env;
SEXP continuation_token;
SEXP proxy_dmat;

_RDataIterator(
SEXP f_next, SEXP f_reset, SEXP calling_env, SEXP continuation_token, SEXP proxy_dmat) :
SEXP f_next, SEXP f_reset, SEXP calling_env, SEXP continuation_token) :
f_next(f_next), f_reset(f_reset), calling_env(calling_env),
continuation_token(continuation_token), proxy_dmat(proxy_dmat) {}
continuation_token(continuation_token) {}

void reset() {
R_SetExternalPtrProtected(this->proxy_dmat, R_NilValue);
SafeExecFun(this->f_reset, this->calling_env, this->continuation_token);
}

int next() {
R_SetExternalPtrProtected(this->proxy_dmat, R_NilValue);
SEXP R_res = Rf_protect(
SafeExecFun(this->f_next, this->calling_env, this->continuation_token));
int res = Rf_asInteger(R_res);
Expand Down Expand Up @@ -777,7 +771,7 @@ SEXP XGDMatrixCreateFromCallbackGeneric_R(

int res_code;
try {
_RDataIterator data_iterator(f_next, f_reset, calling_env, continuation_token, proxy_dmat);
_RDataIterator data_iterator(f_next, f_reset, calling_env, continuation_token);

std::string str_cache_prefix;
xgboost::Json jconfig{xgboost::Object{}};
Expand Down

0 comments on commit 92d1bfe

Please sign in to comment.