[jvm-packages] Supports external memory #11127

Open · wants to merge 5 commits into base: master
8 changes: 7 additions & 1 deletion jvm-packages/create_jni.py
@@ -95,11 +95,17 @@ def native_build(cli_args: argparse.Namespace) -> None:
CONFIG["USE_DLOPEN_NCCL"] = "OFF"

args = ["-D{0}:BOOL={1}".format(k, v) for k, v in CONFIG.items()]
if sys.platform != "win32":
try:
subprocess.check_call(["ninja", "--version"])
args.append("-GNinja")
except FileNotFoundError:
pass

# if enviorment set GPU_ARCH_FLAG
gpu_arch_flag = os.getenv("GPU_ARCH_FLAG", None)
if gpu_arch_flag is not None:
args.append("%s" % gpu_arch_flag)
args.append("-DCMAKE_CUDA_ARCHITECTURES=%s" % gpu_arch_flag)

with cd(build_dir):
lib_dir = os.path.join(os.pardir, "lib")
@@ -1,5 +1,5 @@
/*
- Copyright (c) 2021-2024 by Contributors
+ Copyright (c) 2021-2025 by Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -86,6 +86,16 @@ private List<CudfColumn> initializeCudfColumns(Table table) {
      .collect(Collectors.toList());
  }

  // visible for testing
  public Table getFeatureTable() {
    return featureTable;
  }

  // visible for testing
  public Table getLabelTable() {
    return labelTable;
  }

  public List<CudfColumn> getFeatures() {
    return features;
  }
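Since the two getters above are exposed only for testing, here is a hedged sketch of how a test might use them. The enclosing class is not named in this view; the sketch assumes it is the cuDF-backed CudfColumnBatch in ml.dmlc.xgboost4j.java, and the helper class and method names are illustrative only.

// Illustrative test helper, assuming the enclosing class is CudfColumnBatch
// (the class name is not visible in this diff view) and that this sketch lives
// in the same package.
package ml.dmlc.xgboost4j.java;

import ai.rapids.cudf.Table;

class CudfColumnBatchGetterCheck {
  // Verify that the batch hands back exactly the Tables it was built from.
  static void check(CudfColumnBatch batch, Table expectedFeatures, Table expectedLabels) {
    assert batch.getFeatureTable() == expectedFeatures;
    assert batch.getLabelTable() == expectedLabels;
  }
}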
@@ -1,5 +1,5 @@
/*
- Copyright (c) 2021-2024 by Contributors
+ Copyright (c) 2021-2025 by Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -15,7 +15,40 @@
*/
package ml.dmlc.xgboost4j.java;

import java.io.IOException;
import java.util.Iterator;
import java.util.Map;

import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonSerializer;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializerProvider;
import com.fasterxml.jackson.databind.module.SimpleModule;

class F64NaNSerializer extends JsonSerializer<Double> {
  @Override
  public void serialize(Double value, JsonGenerator gen,
                        SerializerProvider serializers) throws IOException {
    if (value.isNaN()) {
      gen.writeRawValue("NaN"); // Write NaN without quotes
    } else {
      gen.writeNumber(value);
    }
  }
}

class F32NaNSerializer extends JsonSerializer<Float> {
  @Override
  public void serialize(Float value, JsonGenerator gen,
                        SerializerProvider serializers) throws IOException {
    if (value.isNaN()) {
      gen.writeRawValue("NaN"); // Write NaN without quotes
    } else {
      gen.writeNumber(value);
    }
  }
}
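To make the effect of these serializers concrete, here is a small sketch that is not part of the PR: it serializes a config-like map with and without the module registered. The demo class name, the sample values, and its placement in the same package (needed to see the package-private serializers) are assumptions.

package ml.dmlc.xgboost4j.java; // assumed: same package as the package-private serializers

import java.util.LinkedHashMap;
import java.util.Map;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.module.SimpleModule;

class NaNSerializerDemo { // illustrative only
  public static void main(String[] args) throws Exception {
    Map<String, Object> conf = new LinkedHashMap<>(); // keep key order stable for the demo
    conf.put("missing", Float.NaN);
    conf.put("max_bin", 256);
    conf.put("nthread", 1);
    conf.put("use_ext_mem", true);

    // Default Jackson behaviour: non-finite numbers are quoted as strings.
    System.out.println(new ObjectMapper().writeValueAsString(conf));
    // {"missing":"NaN","max_bin":256,"nthread":1,"use_ext_mem":true}

    // With the serializers above registered, NaN is written as a bare token.
    ObjectMapper mapper = new ObjectMapper();
    SimpleModule module = new SimpleModule();
    module.addSerializer(Double.class, new F64NaNSerializer());
    module.addSerializer(Float.class, new F32NaNSerializer());
    mapper.registerModule(module);
    System.out.println(mapper.writeValueAsString(conf));
    // {"missing":NaN,"max_bin":256,"nthread":1,"use_ext_mem":true}
  }
}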

/**
* QuantileDMatrix will only be used to train
@@ -28,14 +61,16 @@ public class QuantileDMatrix extends DMatrix {
   * @param missing the missing value
   * @param maxBin the max bin
   * @param nthread the parallelism
   * @param useExternalMemory whether to use external memory or not
   * @throws XGBoostError
   */
  public QuantileDMatrix(
      Iterator<ColumnBatch> iter,
      float missing,
      int maxBin,
-     int nthread) throws XGBoostError {
-   this(iter, null, missing, maxBin, nthread);
+     int nthread,
+     boolean useExternalMemory) throws XGBoostError {

Review comment (Member): Can we have a new structure for external memory instead of reusing the class?

+   this(iter, null, missing, maxBin, nthread, useExternalMemory);
  }

/**
@@ -50,17 +85,19 @@ public QuantileDMatrix(
   * @param missing the missing value
   * @param maxBin the max bin
   * @param nthread the parallelism
   * @param useExternalMemory whether to use external memory or not
   * @throws XGBoostError
   */
  public QuantileDMatrix(
      Iterator<ColumnBatch> iter,
      QuantileDMatrix refDMatrix,
      float missing,
      int maxBin,
-     int nthread) throws XGBoostError {
+     int nthread,
+     boolean useExternalMemory) throws XGBoostError {
    super(0);
    long[] out = new long[1];
-   String conf = getConfig(missing, maxBin, nthread);
+   String conf = getConfig(missing, maxBin, nthread, useExternalMemory);
    long[] ref = null;
    if (refDMatrix != null) {
      ref = new long[1];
@@ -111,9 +148,25 @@ public void setGroup(int[] group) throws XGBoostError {
    throw new XGBoostError("QuantileDMatrix does not support setGroup.");
  }

- private String getConfig(float missing, int maxBin, int nthread) {
-   return String.format("{\"missing\":%f,\"max_bin\":%d,\"nthread\":%d}",
-     missing, maxBin, nthread);
+ private String getConfig(float missing, int maxBin, int nthread, boolean useExternalMemory) {
+   Map<String, Object> conf = new java.util.HashMap<>();
+   conf.put("missing", missing);
+   conf.put("max_bin", maxBin);
+   conf.put("nthread", nthread);
+   conf.put("use_ext_mem", useExternalMemory);
+   ObjectMapper mapper = new ObjectMapper();
+
+   // Handle NaN values. Jackson by default serializes NaN values into strings.
+   SimpleModule module = new SimpleModule();
+   module.addSerializer(Double.class, new F64NaNSerializer());
+   module.addSerializer(Float.class, new F32NaNSerializer());
+   mapper.registerModule(module);
+
+   try {
+     String config = mapper.writeValueAsString(conf);
+     return config;
+   } catch (JsonProcessingException e) {
+     throw new RuntimeException("Failed to serialize configuration", e);
+   }
  }

}
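For orientation, a hedged usage sketch of the new constructor: it shows how a caller might opt into external memory when building the training QuantileDMatrix. The batches iterator, the parameter values, and the helper class name are placeholders; how the GPU ColumnBatch iterator is produced is outside this diff.

import java.util.Iterator;

import ml.dmlc.xgboost4j.java.ColumnBatch;
import ml.dmlc.xgboost4j.java.QuantileDMatrix;
import ml.dmlc.xgboost4j.java.XGBoostError;

class ExternalMemoryUsageSketch { // illustrative only
  // 'batches' is assumed to be an existing iterator of cuDF-backed ColumnBatch objects.
  static QuantileDMatrix buildTrainMatrix(Iterator<ColumnBatch> batches) throws XGBoostError {
    float missing = Float.NaN;        // value treated as missing
    int maxBin = 256;                 // histogram bin count
    int nthread = 1;                  // parallelism for quantile construction
    boolean useExternalMemory = true; // enable the external-memory path added in this PR
    return new QuantileDMatrix(batches, missing, maxBin, nthread, useExternalMemory);
  }
}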
/*
- Copyright (c) 2021-2024 by Contributors
+ Copyright (c) 2021-2025 by Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -30,31 +30,39 @@ class QuantileDMatrix private[scala](
   * @param missing the missing value
   * @param maxBin the max bin
   * @param nthread the parallelism
   * @param useExternalMemory whether to use external memory or not
   * @throws XGBoostError
   */
- def this(iter: Iterator[ColumnBatch], missing: Float, maxBin: Int, nthread: Int) {
-   this(new JQuantileDMatrix(iter.asJava, missing, maxBin, nthread))
+ def this(iter: Iterator[ColumnBatch],
+          missing: Float,
+          maxBin: Int,
+          nthread: Int,
+          useExternalMemory: Boolean) {
+   this(new JQuantileDMatrix(iter.asJava, missing, maxBin, nthread, useExternalMemory))
  }

  /**
   * Create QuantileDMatrix from iterator based on the array interface
   *
   * @param iter the XGBoost ColumnBatch batch to provide the corresponding array interface
-  * @param refDMatrix The reference QuantileDMatrix that provides quantile information, needed
+  * @param ref The reference QuantileDMatrix that provides quantile information, needed
   *            when creating validation/test dataset with QuantileDMatrix. Supplying the
   *            training DMatrix as a reference means that the same quantisation applied
   *            to the training data is applied to the validation/test data
   * @param missing the missing value
   * @param maxBin the max bin
   * @param nthread the parallelism
   * @param useExternalMemory whether to use external memory or not
   * @throws XGBoostError
   */
  def this(iter: Iterator[ColumnBatch],
           ref: QuantileDMatrix,
           missing: Float,
           maxBin: Int,
-          nthread: Int) {
-   this(new JQuantileDMatrix(iter.asJava, ref.jDMatrix, missing, maxBin, nthread))
+          nthread: Int,
+          useExternalMemory: Boolean) {
+   this(new JQuantileDMatrix(iter.asJava, ref.jDMatrix, missing, maxBin, nthread,
+     useExternalMemory))
  }

/**