Skip to content

Commit

Permalink
[jvm-packages] Supports external memory
Browse files Browse the repository at this point in the history
  • Loading branch information
wbo4958 committed Jan 7, 2025
1 parent bd92b1c commit 20a36f8
Show file tree
Hide file tree
Showing 15 changed files with 833 additions and 222 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright (c) 2021-2024 by Contributors
Copyright (c) 2021-2025 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -86,6 +86,16 @@ private List<CudfColumn> initializeCudfColumns(Table table) {
.collect(Collectors.toList());
}

// visible for testing
public Table getFeatureTable() {
return featureTable;
}

// visible for testing
public Table getLabelTable() {
return labelTable;
}

public List<CudfColumn> getFeatures() {
return features;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright (c) 2021-2024 by Contributors
Copyright (c) 2021-2025 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -28,14 +28,16 @@ public class QuantileDMatrix extends DMatrix {
* @param missing the missing value
* @param maxBin the max bin
* @param nthread the parallelism
* @param useExternalMemory whether to use external memory or not
* @throws XGBoostError
*/
public QuantileDMatrix(
Iterator<ColumnBatch> iter,
float missing,
int maxBin,
int nthread) throws XGBoostError {
this(iter, null, missing, maxBin, nthread);
int nthread,
boolean useExternalMemory) throws XGBoostError {
this(iter, null, missing, maxBin, nthread, useExternalMemory);
}

/**
Expand All @@ -50,17 +52,19 @@ public QuantileDMatrix(
* @param missing the missing value
* @param maxBin the max bin
* @param nthread the parallelism
* @param useExternalMemory whether to use external memory or not
* @throws XGBoostError
*/
public QuantileDMatrix(
Iterator<ColumnBatch> iter,
QuantileDMatrix refDMatrix,
float missing,
int maxBin,
int nthread) throws XGBoostError {
int nthread,
boolean useExternalMemory) throws XGBoostError {
super(0);
long[] out = new long[1];
String conf = getConfig(missing, maxBin, nthread);
String conf = getConfig(missing, maxBin, nthread, useExternalMemory);
long[] ref = null;
if (refDMatrix != null) {
ref = new long[1];
Expand Down Expand Up @@ -111,9 +115,9 @@ public void setGroup(int[] group) throws XGBoostError {
throw new XGBoostError("QuantileDMatrix does not support setGroup.");
}

private String getConfig(float missing, int maxBin, int nthread) {
return String.format("{\"missing\":%f,\"max_bin\":%d,\"nthread\":%d}",
missing, maxBin, nthread);
private String getConfig(float missing, int maxBin, int nthread, boolean useExternalMemory) {
return String.format("{\"missing\":%f,\"max_bin\":%d,\"nthread\":%d," +
"\"use_ext_mem\":%b}", missing, maxBin, nthread, useExternalMemory);
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright (c) 2021-2024 by Contributors
Copyright (c) 2021-2025 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -30,31 +30,39 @@ class QuantileDMatrix private[scala](
* @param missing the missing value
* @param maxBin the max bin
* @param nthread the parallelism
* @param useExternalMemory whether to use external memory or not
* @throws XGBoostError
*/
def this(iter: Iterator[ColumnBatch], missing: Float, maxBin: Int, nthread: Int) {
this(new JQuantileDMatrix(iter.asJava, missing, maxBin, nthread))
def this(iter: Iterator[ColumnBatch],
missing: Float,
maxBin: Int,
nthread: Int,
useExternalMemory: Boolean) {
this(new JQuantileDMatrix(iter.asJava, missing, maxBin, nthread, useExternalMemory))
}

/**
* Create QuantileDMatrix from iterator based on the array interface
*
* @param iter the XGBoost ColumnBatch batch to provide the corresponding array interface
* @param refDMatrix The reference QuantileDMatrix that provides quantile information, needed
* @param ref The reference QuantileDMatrix that provides quantile information, needed
* when creating validation/test dataset with QuantileDMatrix. Supplying the
* training DMatrix as a reference means that the same quantisation applied
* to the training data is applied to the validation/test data
* @param missing the missing value
* @param maxBin the max bin
* @param nthread the parallelism
* @param useExternalMemory whether to use external memory or not
* @throws XGBoostError
*/
def this(iter: Iterator[ColumnBatch],
ref: QuantileDMatrix,
missing: Float,
maxBin: Int,
nthread: Int) {
this(new JQuantileDMatrix(iter.asJava, ref.jDMatrix, missing, maxBin, nthread))
nthread: Int,
useExternalMemory: Boolean) {
this(new JQuantileDMatrix(iter.asJava, ref.jDMatrix, missing, maxBin, nthread,
useExternalMemory))
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
/*
Copyright (c) 2025 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package ml.dmlc.xgboost4j.scala.spark

import java.io.File
import java.nio.file.{Files, Paths}

import scala.collection.mutable.ArrayBuffer

import ai.rapids.cudf._

import ml.dmlc.xgboost4j.java.{ColumnBatch, CudfColumnBatch}
import ml.dmlc.xgboost4j.scala.spark.Utils.withResource

private[spark] trait ExternalMemory[T] extends Iterator[Table] with AutoCloseable {

protected val buffers = ArrayBuffer.empty[T]
private lazy val buffersIterator = buffers.toIterator

/**
* Convert the table to T which will be cached
*
* @param table to be converted
* @return the content
*/
def convertTable(table: Table): T

/**
* Load the content to the Table
*
* @param content to be loaded
* @return Table
*/
def loadTable(content: T): Table

// Cache the table
def cacheTable(table: Table): Unit = {
val content = convertTable(table)
buffers.append(content)
}

override def hasNext: Boolean = buffersIterator.hasNext

override def next(): Table = loadTable(buffersIterator.next())

override def close(): Unit = {}
}

// The data will be cached into disk.
private[spark] class DiskExternalMemoryIterator(val path: String) extends ExternalMemory[String] {

private lazy val root = {
val tmp = path + "/xgboost"
createDirectory(tmp)
tmp
}

private var counter = 0

private def createDirectory(dirPath: String): Unit = {
val path = Paths.get(dirPath)
if (!Files.exists(path)) {
Files.createDirectories(path)
} else {
}
}

/**
* Convert the table to file path which will be cached
*
* @param table to be converted
* @return the content
*/
override def convertTable(table: Table): String = {
val names = (1 to table.getNumberOfColumns).map(_.toString)
val options = ArrowIPCWriterOptions.builder().withColumnNames(names: _*).build()
val path = root + "/table_" + counter + "_" + System.nanoTime();
counter += 1
withResource(Table.writeArrowIPCChunked(options, new File(path))) { writer =>
writer.write(table)
}
path
}

private def closeOnExcept[T <: AutoCloseable, V](r: ArrayBuffer[T])
(block: ArrayBuffer[T] => V): V = {
try {
block(r)
} catch {
case t: Throwable =>
r.foreach(_.close())
throw t
}
}

/**
* Load the path from disk to the Table
*
* @param name to be loaded
* @return Table
*/
override def loadTable(name: String): Table = {
val file = new File(name)
if (!file.exists()) {
throw new RuntimeException(s"The cached file ${name} not exist" )
}
try {
withResource(Table.readArrowIPCChunked(file)) { reader =>
val tables = ArrayBuffer.empty[Table]
closeOnExcept(tables) { tables =>
var table = Option(reader.getNextIfAvailable())
while (table.isDefined) {
tables.append(table.get)
table = Option(reader.getNextIfAvailable())
}
}
if (tables.size > 1) {
closeOnExcept(tables) { tables =>
Table.concatenate(tables.toArray: _*)
}
} else {
tables(0)
}
}
} catch {
case e: Throwable =>
close()
throw e
} finally {
if (file.exists()) {
file.delete()
}
}
}

override def close(): Unit = {
buffers.foreach { path =>
val file = new File(path)
if (file.exists()) {
file.delete()
}
}
buffers.clear()
}
}

private[spark] object ExternalMemory {
def apply(path: Option[String] = None): ExternalMemory[_] = {
path.map(new DiskExternalMemoryIterator(_))
.getOrElse(throw new RuntimeException("No disk path provided"))
}
}

/**
* ExternalMemoryIterator supports iterating the data twice if the `swap` is called.
*
* The first round iteration gets the input batch that will be
* 1. cached in the external memory
* 2. fed in QuantilDmatrix
* The second round iteration returns the cached batch got from external memory.
*
* @param input the spark input iterator
* @param indices column index
*/
private[scala] class ExternalMemoryIterator(val input: Iterator[Table],
val indices: ColumnIndices,
val path: Option[String] = None)
extends Iterator[ColumnBatch] {

private var iter = input

// Flag to indicate the input has been consumed.
private var inputIsConsumed = false
// Flag to indicate the input.next has been called which is valid
private var inputNextIsCalled = false

// visible for testing
private[spark] val externalMemory = ExternalMemory(path)

override def hasNext: Boolean = {
val value = iter.hasNext
if (!value && inputIsConsumed && inputNextIsCalled) {
externalMemory.close()
}
if (!inputIsConsumed && !value && inputNextIsCalled) {
inputIsConsumed = true
iter = externalMemory
}
value
}

override def next(): ColumnBatch = {
inputNextIsCalled = true
withResource(new GpuColumnBatch(iter.next())) { batch =>
if (iter == input) {
externalMemory.cacheTable(batch.table)
}
new CudfColumnBatch(
batch.select(indices.featureIds.get),
batch.select(indices.labelId),
batch.select(indices.weightId.getOrElse(-1)),
batch.select(indices.marginId.getOrElse(-1)),
batch.select(indices.groupId.getOrElse(-1)));
}
}

}
Loading

0 comments on commit 20a36f8

Please sign in to comment.