[doc] Update the iterator demo. #11222

Open
4 changes: 4 additions & 0 deletions demo/guide-python/external_memory.py
@@ -25,6 +25,10 @@
- rmm
- python-cuda

.. seealso::

:ref:`sphx_glr_python_examples_distributed_extmem_basic.py`

"""

import argparse
51 changes: 30 additions & 21 deletions demo/guide-python/quantile_data_iterator.py
@@ -5,18 +5,24 @@
.. versionadded:: 1.2.0

The demo that defines a customized iterator for passing batches of data into
:py:class:`xgboost.QuantileDMatrix` and use this ``QuantileDMatrix`` for
training. The feature is used primarily designed to reduce the required GPU
memory for training on distributed environment.
:py:class:`xgboost.QuantileDMatrix` and use this ``QuantileDMatrix`` for training. The
feature is primarily designed to reduce the required GPU memory for training in a
distributed environment.

Aftering going through the demo, one might ask why don't we use more native
Python iterator? That's because XGBoost requires a `reset` function, while
using `itertools.tee` might incur significant memory usage according to:
After going through the demo, one might ask why we don't use a more native Python
iterator. That's because XGBoost requires a `reset` function, while using
`itertools.tee` might incur significant memory usage according to:

https://docs.python.org/3/library/itertools.html#itertools.tee.

.. seealso::

:ref:`sphx_glr_python_examples_external_memory.py`

"""

from typing import Callable

import cupy
import numpy

@@ -35,7 +41,7 @@ class IterForDMatrixDemo(xgboost.core.DataIter):

"""

def __init__(self):
def __init__(self) -> None:
"""Generate some random data for demonstration.

Actual data can be anything that is currently supported by XGBoost.
@@ -50,41 +56,44 @@ def __init__(self):
self.it = 0 # set iterator to 0
super().__init__()

def as_array(self):
def as_array(self) -> cupy.ndarray:
return cupy.concatenate(self._data)

def as_array_labels(self):
def as_array_labels(self) -> cupy.ndarray:
return cupy.concatenate(self._labels)

def as_array_weights(self):
def as_array_weights(self) -> cupy.ndarray:
return cupy.concatenate(self._weights)

def data(self):
def data(self) -> cupy.ndarray:
"""Utility function for obtaining current batch of data."""
return self._data[self.it]

def labels(self):
def labels(self) -> cupy.ndarray:
"""Utility function for obtaining current batch of label."""
return self._labels[self.it]

def weights(self):
def weights(self) -> cupy.ndarray:
return self._weights[self.it]

def reset(self):
def reset(self) -> None:
"""Reset the iterator"""
self.it = 0

def next(self, input_data):
"""Yield next batch of data."""
def next(self, input_data: Callable) -> bool:
"""Yield the next batch of data."""
if self.it == len(self._data):
# Return 0 when there's no more batch.
return 0
# Return False to let XGBoost know this is the end of iteration
return False

# input_data is a keyword-only function passed in by XGBoost and has a signature
# similar to that of the ``DMatrix`` constructor.
input_data(data=self.data(), label=self.labels(), weight=self.weights())
self.it += 1
return 1
return True


def main():
def main() -> None:
rounds = 100
it = IterForDMatrixDemo()

@@ -103,7 +112,7 @@ def main():

assert m_with_it.num_col() == m.num_col()
assert m_with_it.num_row() == m.num_row()
# Tree meethod must be `hist`.
# Tree method must be `hist`.
reg_with_it = xgboost.train(
{"tree_method": "hist", "device": "cuda"},
m_with_it,
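
The `reset`/`next` contract that this diff annotates can be tried out without a GPU or XGBoost installed. The sketch below is illustrative only: `ListBatchIter` and `consume` are hypothetical names that do not appear in the PR, plain Python lists stand in for the `cupy` arrays, and `consume` is an assumed stand-in for a consumer that needs more than one pass over the batches. It is not XGBoost's implementation.

```python
from typing import Any, Callable, Dict, List


class ListBatchIter:
    """Hypothetical iterator with the same reset()/next() shape as the demo class.

    Plain lists replace the GPU arrays, and no XGBoost base class is used.
    """

    def __init__(self, batches: List[List[float]], labels: List[List[float]]) -> None:
        self._data = batches
        self._labels = labels
        self.it = 0  # index of the batch that next() will hand out

    def reset(self) -> None:
        """Rewind to the first batch so another pass can start."""
        self.it = 0

    def next(self, input_data: Callable[..., None]) -> bool:
        """Pass the current batch to the callback; return False when exhausted."""
        if self.it == len(self._data):
            return False
        input_data(data=self._data[self.it], label=self._labels[self.it])
        self.it += 1
        return True


def consume(it: ListBatchIter, n_passes: int = 2) -> List[List[Dict[str, Any]]]:
    """Assumed consumer (not XGBoost code): walk over all batches n_passes times."""
    passes = []
    for _ in range(n_passes):
        it.reset()
        seen: List[Dict[str, Any]] = []

        # Keyword-only callback, mirroring how the demo calls input_data.
        def input_data(*, data: List[float], label: List[float]) -> None:
            seen.append({"data": data, "label": label})

        while it.next(input_data):
            pass
        passes.append(seen)
    return passes


batches = [[1.0, 2.0], [3.0, 4.0], [5.0]]
labels = [[0.0, 1.0], [1.0, 0.0], [1.0]]
result = consume(ListBatchIter(batches, labels))
```

Because `reset` only rewinds an index, every pass sees the same batches without the iterator buffering anything itself, which is the property the docstring contrasts with `itertools.tee`.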
1 change: 1 addition & 0 deletions ops/script/lint_python.py
@@ -122,6 +122,7 @@ class LintersPaths:
"demo/guide-python/model_parser.py",
"demo/guide-python/individual_trees.py",
"demo/guide-python/quantile_regression.py",
"demo/guide-python/quantile_data_iterator.py",
"demo/guide-python/multioutput_regression.py",
"demo/guide-python/learning_to_rank.py",
"demo/aft_survival/aft_survival_viz_demo.py",