cuda.parallel: Add documentation for the current iterators along with examples and tests (NVIDIA#3311)

* Add tests demonstrating usage of different iterators

* Update documentation of reduce_into by merging import code snippet with the rest of the example

* Add documentation for current iterators

* Run pre-commit checks and update accordingly

* Fix comments to refer to the proper lines in the code snippets in the docs
NaderAlAwar authored and davebayer committed Jan 18, 2025
1 parent 6b82e05 commit 3a70e49
Showing 3 changed files with 228 additions and 21 deletions.
@@ -157,16 +157,7 @@ def reduce_into(
"""Computes a device-wide reduction using the specified binary ``op`` functor and initial value ``init``.
Example:
The code snippet below illustrates a user-defined min-reduction of a
device vector of ``int`` data elements.
.. literalinclude:: ../../python/cuda_parallel/tests/test_reduce_api.py
:language: python
:dedent:
:start-after: example-begin imports
:end-before: example-end imports
Below is the code snippet that demonstrates the usage of the ``reduce_into`` API:
The code snippet below demonstrates the usage of the ``reduce_into`` API:
.. literalinclude:: ../../python/cuda_parallel/tests/test_reduce_api.py
:language: python
@@ -9,6 +9,22 @@ def CacheModifiedInputIterator(device_array, modifier):
Similar to https://nvidia.github.io/cccl/cub/api/classcub_1_1CacheModifiedInputIterator.html
Currently the only supported modifier is "stream" (LOAD_CS).
Example:
The code snippet below demonstrates the usage of a ``CacheModifiedInputIterator``:
.. literalinclude:: ../../python/cuda_parallel/tests/test_reduce_api.py
:language: python
:dedent:
:start-after: example-begin cache-iterator
:end-before: example-end cache-iterator
Args:
device_array: CUDA device array storing the input sequence of data items
modifier: The PTX cache load modifier
Returns:
A ``CacheModifiedInputIterator`` object initialized with ``device_array``
"""
if modifier != "stream":
raise NotImplementedError("Only stream modifier is supported")
@@ -19,15 +35,74 @@ def CacheModifiedInputIterator(device_array, modifier):


def ConstantIterator(value):
"""Returns an Iterator representing a sequence of constant values."""
"""Returns an Iterator representing a sequence of constant values.
Similar to https://nvidia.github.io/cccl/thrust/api/classthrust_1_1constant__iterator.html
Example:
The code snippet below demonstrates the usage of a ``ConstantIterator``
representing the sequence ``[10, 10, 10]``:
.. literalinclude:: ../../python/cuda_parallel/tests/test_reduce_api.py
:language: python
:dedent:
:start-after: example-begin constant-iterator
:end-before: example-end constant-iterator
Args:
value: The value of every item in the sequence
Returns:
A ``ConstantIterator`` object initialized to ``value``
"""
return _iterators.ConstantIterator(value)


def CountingIterator(offset):
"""Returns an Iterator representing a sequence of incrementing values."""
"""Returns an Iterator representing a sequence of incrementing values.
Similar to https://nvidia.github.io/cccl/thrust/api/classthrust_1_1counting__iterator.html
Example:
The code snippet below demonstrates the usage of a ``CountingIterator``
representing the sequence ``[10, 11, 12]``:
.. literalinclude:: ../../python/cuda_parallel/tests/test_reduce_api.py
:language: python
:dedent:
:start-after: example-begin counting-iterator
:end-before: example-end counting-iterator
Args:
offset: The initial value of the sequence
Returns:
A ``CountingIterator`` object initialized to ``offset``
"""
return _iterators.CountingIterator(offset)


def TransformIterator(it, op):
"""Returns an Iterator representing a transformed sequence of values."""
"""Returns an Iterator representing a transformed sequence of values.
Similar to https://nvidia.github.io/cccl/cub/api/classcub_1_1TransformInputIterator.html
Example:
The code snippet below demonstrates the usage of a ``TransformIterator``
composed with a ``CountingIterator``, transforming the sequence ``[10, 11, 12]``
by squaring each item before reducing the output:
.. literalinclude:: ../../python/cuda_parallel/tests/test_reduce_api.py
:language: python
:dedent:
:start-after: example-begin transform-iterator
:end-before: example-end transform-iterator
Args:
it: The iterator object to be transformed
op: The transform operation
Returns:
A ``TransformIterator`` object to transform the items in ``it`` using ``op``
"""
return _iterators.make_transform_iterator(it, op)
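
The docstring for ``CacheModifiedInputIterator`` above states that "stream" (LOAD_CS) is currently the only supported modifier, and the visible ``if modifier != "stream"`` check rejects everything else at construction time. Below is a minimal sketch of that behavior, using the same import path as the tests later in this diff; the "ca" string is just an arbitrary unsupported value chosen for illustration:

import cupy as cp
import numpy as np

import cuda.parallel.experimental.iterators as iterators

d_input = cp.array([8, 6, 7, 5, 3, 0, 9], dtype=np.int32)

# "stream" (LOAD_CS) is accepted and wraps the device array
stream_it = iterators.CacheModifiedInputIterator(d_input, modifier="stream")

# Any other modifier string currently raises NotImplementedError
try:
    iterators.CacheModifiedInputIterator(d_input, modifier="ca")  # "ca" is an illustrative, unsupported value
except NotImplementedError:
    pass  # expected: only "stream" is supported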
python/cuda_parallel/tests/test_reduce_api.py: 157 changes (149 additions, 8 deletions)
@@ -2,17 +2,14 @@
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# example-begin imports
import cupy as cp
import numpy as np

import cuda.parallel.experimental.algorithms as algorithms

# example-end imports


def test_device_reduce():
# example-begin reduce-min
import cupy as cp
import numpy as np

import cuda.parallel.experimental.algorithms as algorithms

def min_op(a, b):
return a if a < b else b

@@ -37,3 +34,147 @@ def min_op(a, b):
expected_output = 0
assert (d_output == expected_output).all()
# example-end reduce-min
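
The middle of the ``reduce-min`` example above is collapsed in this diff view. Below is a minimal sketch of what a complete min-reduction looks like, following the same two-phase pattern (size query, then run) used by the iterator examples that follow; the input values and the initial value of 42 are illustrative assumptions chosen so that the minimum is 0, matching the assertion shown above:

import cupy as cp
import numpy as np

import cuda.parallel.experimental.algorithms as algorithms


def min_op(a, b):
    return a if a < b else b


dtype = np.int32
h_init = np.array([42], dtype=dtype)  # Initial value for the reduction (assumed)
d_input = cp.array([8, 6, 7, 5, 3, 0, 9], dtype=dtype)  # Input sequence (assumed values)
d_output = cp.empty(1, dtype=dtype)  # Storage for output

# Instantiate reduction, determine storage requirements, and allocate storage
reduce_into = algorithms.reduce_into(d_input, d_output, min_op, h_init)
temp_storage_size = reduce_into(None, d_input, d_output, len(d_input), h_init)
d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)

# Run reduction
reduce_into(d_temp_storage, d_input, d_output, len(d_input), h_init)

# The minimum of the assumed input (and of the init value 42) is 0
expected_output = 0
assert (d_output == expected_output).all()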


def test_cache_modified_input_iterator():
# example-begin cache-iterator
import functools

import cupy as cp
import numpy as np

import cuda.parallel.experimental.algorithms as algorithms
import cuda.parallel.experimental.iterators as iterators

def add_op(a, b):
return a + b

values = [8, 6, 7, 5, 3, 0, 9]
d_input = cp.array(values, dtype=np.int32)
d_output = cp.empty(1, dtype=np.int32)

iterator = iterators.CacheModifiedInputIterator(
d_input, modifier="stream"
) # Input sequence
h_init = np.array([0], dtype=np.int32) # Initial value for the reduction
d_output = cp.empty(1, dtype=np.int32) # Storage for output

# Instantiate reduction, determine storage requirements, and allocate storage
reduce_into = algorithms.reduce_into(iterator, d_output, add_op, h_init)
temp_storage_size = reduce_into(None, iterator, d_output, len(values), h_init)
d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)

# Run reduction
reduce_into(d_temp_storage, iterator, d_output, len(values), h_init)

expected_output = functools.reduce(lambda a, b: a + b, values)
assert (d_output == expected_output).all()
# example-end cache-iterator


def test_constant_iterator():
# example-begin constant-iterator
import functools

import cupy as cp
import numpy as np

import cuda.parallel.experimental.algorithms as algorithms
import cuda.parallel.experimental.iterators as iterators

def add_op(a, b):
return a + b

value = 10
num_items = 3

constant_it = iterators.ConstantIterator(np.int32(value)) # Input sequence
h_init = np.array([0], dtype=np.int32) # Initial value for the reduction
d_output = cp.empty(1, dtype=np.int32) # Storage for output

# Instantiate reduction, determine storage requirements, and allocate storage
reduce_into = algorithms.reduce_into(constant_it, d_output, add_op, h_init)
temp_storage_size = reduce_into(None, constant_it, d_output, num_items, h_init)
d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)

# Run reduction
reduce_into(d_temp_storage, constant_it, d_output, num_items, h_init)

expected_output = functools.reduce(lambda a, b: a + b, [value] * num_items)
assert (d_output == expected_output).all()
# example-end constant-iterator


def test_counting_iterator():
# example-begin counting-iterator
import functools

import cupy as cp
import numpy as np

import cuda.parallel.experimental.algorithms as algorithms
import cuda.parallel.experimental.iterators as iterators

def add_op(a, b):
return a + b

first_item = 10
num_items = 3

first_it = iterators.CountingIterator(np.int32(first_item)) # Input sequence
h_init = np.array([0], dtype=np.int32) # Initial value for the reduction
d_output = cp.empty(1, dtype=np.int32) # Storage for output

# Instantiate reduction, determine storage requirements, and allocate storage
reduce_into = algorithms.reduce_into(first_it, d_output, add_op, h_init)
temp_storage_size = reduce_into(None, first_it, d_output, num_items, h_init)
d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)

# Run reduction
reduce_into(d_temp_storage, first_it, d_output, num_items, h_init)

expected_output = functools.reduce(
lambda a, b: a + b, range(first_item, first_item + num_items)
)
assert (d_output == expected_output).all()
# example-end counting-iterator


def test_transform_iterator():
# example-begin transform-iterator
import functools

import cupy as cp
import numpy as np

import cuda.parallel.experimental.algorithms as algorithms
import cuda.parallel.experimental.iterators as iterators

def add_op(a, b):
return a + b

def square_op(a):
return a**2

first_item = 10
num_items = 3

transform_it = iterators.TransformIterator(
iterators.CountingIterator(np.int32(first_item)), square_op
) # Input sequence
h_init = np.array([0], dtype=np.int32) # Initial value for the reduction
d_output = cp.empty(1, dtype=np.int32) # Storage for output

# Instantiate reduction, determine storage requirements, and allocate storage
reduce_into = algorithms.reduce_into(transform_it, d_output, add_op, h_init)
temp_storage_size = reduce_into(None, transform_it, d_output, num_items, h_init)
d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)

# Run reduction
reduce_into(d_temp_storage, transform_it, d_output, num_items, h_init)

expected_output = functools.reduce(
lambda a, b: a + b, [a**2 for a in range(first_item, first_item + num_items)]
)
assert (d_output == expected_output).all()
# example-end transform-iterator
