Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[R] Move gc data protection to R side #11104

Merged
merged 3 commits into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions R-package/R/xgb.DMatrix.R
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,9 @@ xgb.QuantileDMatrix <- function(
)
data_iterator <- .single.data.iterator(iterator_env)

env_keep_alive <- new.env()
env_keep_alive$keepalive <- NULL

# Note: the ProxyDMatrix has its finalizer assigned in the R externalptr
# object, but that finalizer will only be called once the object is
# garbage-collected, which doesn't happen immediately after it goes out
Expand All @@ -363,9 +366,10 @@ xgb.QuantileDMatrix <- function(
.Call(XGDMatrixFree_R, proxy_handle)
})
iterator_next <- function() {
return(xgb.ProxyDMatrix(proxy_handle, data_iterator))
return(xgb.ProxyDMatrix(proxy_handle, data_iterator, env_keep_alive))
}
iterator_reset <- function() {
env_keep_alive$keepalive <- NULL
return(data_iterator$f_reset(iterator_env))
}
calling_env <- environment()
Expand Down Expand Up @@ -553,7 +557,8 @@ xgb.DataBatch <- function(
}

# This is only for internal usage, class is not exposed to the user.
xgb.ProxyDMatrix <- function(proxy_handle, data_iterator) {
xgb.ProxyDMatrix <- function(proxy_handle, data_iterator, env_keep_alive) {
env_keep_alive$keepalive <- NULL
lst <- data_iterator$f_next(data_iterator$env)
if (is.null(lst)) {
return(0L)
Expand All @@ -566,13 +571,19 @@ xgb.ProxyDMatrix <- function(proxy_handle, data_iterator) {
stop("Either one of 'group' or 'qid' should be NULL")
}
if (is.data.frame(lst$data)) {
tmp <- .process.df.for.dmatrix(lst$data, lst$feature_types)
data <- lst$data
lst$data <- NULL
tmp <- .process.df.for.dmatrix(data, lst$feature_types)
lst$feature_types <- tmp$feature_types
data <- NULL
env_keep_alive$keepalive <- tmp
.Call(XGProxyDMatrixSetDataColumnar_R, proxy_handle, tmp$lst)
} else if (is.matrix(lst$data)) {
env_keep_alive$keepalive <- lst
.Call(XGProxyDMatrixSetDataDense_R, proxy_handle, lst$data)
} else if (inherits(lst$data, "dgRMatrix")) {
tmp <- list(p = lst$data@p, j = lst$data@j, x = lst$data@x, ncol = ncol(lst$data))
env_keep_alive$keepalive <- tmp
.Call(XGProxyDMatrixSetDataCSR_R, proxy_handle, tmp)
} else {
stop("'data' has unsupported type.")
Expand Down Expand Up @@ -707,14 +718,23 @@ xgb.ExtMemDMatrix <- function(
cache_prefix <- path.expand(cache_prefix)
nthread <- as.integer(NVL(nthread, -1L))

# The purpose of this environment is to keep data alive (protected from the
# garbage collector) after setting the data in the proxy dmatrix. The data
# held here (under name 'keepalive') should be unset (leaving it unprotected
# for garbage collection) before the start of each data iteration batch and
# during each iterator reset.
env_keep_alive <- new.env()
env_keep_alive$keepalive <- NULL

proxy_handle <- .make.proxy.handle()
on.exit({
.Call(XGDMatrixFree_R, proxy_handle)
})
iterator_next <- function() {
return(xgb.ProxyDMatrix(proxy_handle, data_iterator))
return(xgb.ProxyDMatrix(proxy_handle, data_iterator, env_keep_alive))
}
iterator_reset <- function() {
env_keep_alive$keepalive <- NULL
return(data_iterator$f_reset(data_iterator$env))
}
calling_env <- environment()
Expand Down Expand Up @@ -774,14 +794,17 @@ xgb.QuantileDMatrix.from_iterator <- function( # nolint

nthread <- as.integer(NVL(nthread, -1L))

env_keep_alive <- new.env()
env_keep_alive$keepalive <- NULL
proxy_handle <- .make.proxy.handle()
on.exit({
.Call(XGDMatrixFree_R, proxy_handle)
})
iterator_next <- function() {
return(xgb.ProxyDMatrix(proxy_handle, data_iterator))
return(xgb.ProxyDMatrix(proxy_handle, data_iterator, env_keep_alive))
}
iterator_reset <- function() {
env_keep_alive$keepalive <- NULL
return(data_iterator$f_reset(data_iterator$env))
}
calling_env <- environment()
Expand Down
12 changes: 3 additions & 9 deletions R-package/src/xgboost_R.cc
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,6 @@ XGB_DLL SEXP XGProxyDMatrixSetDataDense_R(SEXP handle, SEXP R_mat) {
{
std::string array_str = MakeArrayInterfaceFromRMat(R_mat);
res_code = XGProxyDMatrixSetDataDense(proxy_dmat, array_str.c_str());
R_SetExternalPtrProtected(handle, R_mat);
}
CHECK_CALL(res_code);
R_API_END();
Expand All @@ -708,7 +707,6 @@ XGB_DLL SEXP XGProxyDMatrixSetDataCSR_R(SEXP handle, SEXP lst) {
array_str_indices.c_str(),
array_str_data.c_str(),
ncol);
R_SetExternalPtrProtected(handle, lst);
}
CHECK_CALL(res_code);
R_API_END();
Expand All @@ -722,7 +720,6 @@ XGB_DLL SEXP XGProxyDMatrixSetDataColumnar_R(SEXP handle, SEXP lst) {
{
std::string sinterface = MakeArrayInterfaceFromRDataFrame(lst);
res_code = XGProxyDMatrixSetDataColumnar(proxy_dmat, sinterface.c_str());
R_SetExternalPtrProtected(handle, lst);
}
CHECK_CALL(res_code);
R_API_END();
Expand All @@ -736,20 +733,17 @@ struct _RDataIterator {
SEXP f_reset;
SEXP calling_env;
SEXP continuation_token;
SEXP proxy_dmat;

_RDataIterator(
SEXP f_next, SEXP f_reset, SEXP calling_env, SEXP continuation_token, SEXP proxy_dmat) :
SEXP f_next, SEXP f_reset, SEXP calling_env, SEXP continuation_token) :
f_next(f_next), f_reset(f_reset), calling_env(calling_env),
continuation_token(continuation_token), proxy_dmat(proxy_dmat) {}
continuation_token(continuation_token) {}

void reset() {
R_SetExternalPtrProtected(this->proxy_dmat, R_NilValue);
SafeExecFun(this->f_reset, this->calling_env, this->continuation_token);
}

int next() {
R_SetExternalPtrProtected(this->proxy_dmat, R_NilValue);
SEXP R_res = Rf_protect(
SafeExecFun(this->f_next, this->calling_env, this->continuation_token));
int res = Rf_asInteger(R_res);
Expand Down Expand Up @@ -777,7 +771,7 @@ SEXP XGDMatrixCreateFromCallbackGeneric_R(

int res_code;
try {
_RDataIterator data_iterator(f_next, f_reset, calling_env, continuation_token, proxy_dmat);
_RDataIterator data_iterator(f_next, f_reset, calling_env, continuation_token);

std::string str_cache_prefix;
xgboost::Json jconfig{xgboost::Object{}};
Expand Down
Loading