Skip to content

Commit

Permalink
Simplify workdiv creation
Browse files Browse the repository at this point in the history
  • Loading branch information
mehmetyusufoglu authored and bernhardmgruber committed Feb 27, 2024
1 parent 9116b75 commit e01c0d7
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 0 deletions.
21 changes: 21 additions & 0 deletions include/alpaka/workdiv/WorkDivMembers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ namespace alpaka
public:
ALPAKA_FN_HOST_ACC WorkDivMembers() = delete;

//! Accepts different alpaka vector types and takes the last TDim number of items.
ALPAKA_NO_HOST_ACC_WARNING
template<typename TGridBlockExtent, typename TBlockThreadExtent, typename TThreadElemExtent>
ALPAKA_FN_HOST_ACC explicit WorkDivMembers(
Expand All @@ -33,6 +34,18 @@ namespace alpaka
{
}

//! \brief Accepts single specific type and is called without explicit template parameters.
ALPAKA_NO_HOST_ACC_WARNING
ALPAKA_FN_HOST_ACC WorkDivMembers(
alpaka::Vec<TDim, TIdx> const& gridBlockExtent,
alpaka::Vec<TDim, TIdx> const& blockThreadExtent,
alpaka::Vec<TDim, TIdx> const& elemExtent)
: m_gridBlockExtent(gridBlockExtent)
, m_blockThreadExtent(blockThreadExtent)
, m_threadElemExtent(elemExtent)
{
}

ALPAKA_NO_HOST_ACC_WARNING
ALPAKA_FN_HOST_ACC WorkDivMembers(WorkDivMembers const& other)
: m_gridBlockExtent(other.m_gridBlockExtent)
Expand Down Expand Up @@ -83,6 +96,14 @@ namespace alpaka
Vec<TDim, TIdx> m_threadElemExtent;
};

//! Deduction guide for the constructor which can be called without explicit template type parameters
ALPAKA_NO_HOST_ACC_WARNING
template<typename TDim, typename TIdx>
ALPAKA_FN_HOST_ACC WorkDivMembers(
alpaka::Vec<TDim, TIdx> const& gridBlockExtent,
alpaka::Vec<TDim, TIdx> const& blockThreadExtent,
alpaka::Vec<TDim, TIdx> const& elemExtent) -> WorkDivMembers<TDim, TIdx>;

namespace trait
{
//! The WorkDivMembers dimension get trait specialization.
Expand Down
43 changes: 43 additions & 0 deletions test/unit/workDiv/src/WorkDivHelpersTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,46 @@ TEMPLATE_LIST_TEST_CASE("isValidWorkDiv", "[workDiv]", alpaka::test::TestAccs)
REQUIRE(alpaka::isValidWorkDiv(alpaka::getAccDevProps<Acc>(dev), workDiv));
REQUIRE(alpaka::isValidWorkDiv<Acc>(dev, workDiv));
}

//! Test the constructors of WorkDivMembers using 3D extent, 3D extent with zero elements and 2D extents
TEST_CASE("WorkDivMembers", "[workDiv]")
{
using Idx = std::size_t;
using Dim3D = alpaka::DimInt<3>;
using Vec3D = alpaka::Vec<Dim3D, Idx>;

auto const elementsPerThread3D = Vec3D::all(static_cast<Idx>(1u));
auto const threadsPerBlock3D = Vec3D{2u, 2u, 2u};
auto blocksPerGrid3D = Vec3D{1u, 1u, 1u};

auto ref3D = alpaka::WorkDivMembers<Dim3D, Idx>{blocksPerGrid3D, threadsPerBlock3D, elementsPerThread3D};
// call WorkDivMembers without explicit class template types
auto workDiv3D = alpaka::WorkDivMembers(blocksPerGrid3D, threadsPerBlock3D, elementsPerThread3D);
CHECK(workDiv3D == ref3D);

// change blocks per grid, assign zero to an element
blocksPerGrid3D = Vec3D{3u, 3u, 0u};
ref3D = alpaka::WorkDivMembers<Dim3D, Idx>{blocksPerGrid3D, threadsPerBlock3D, elementsPerThread3D};
// call without explicit template parameter types
workDiv3D = alpaka::WorkDivMembers(blocksPerGrid3D, threadsPerBlock3D, elementsPerThread3D);
CHECK(workDiv3D == ref3D);

// test using 2D vectors
using Dim2D = alpaka::DimInt<2>;
using Vec2D = alpaka::Vec<Dim2D, Idx>;

auto const threadsPerBlock2D = Vec2D{2u, 2u};
auto const blocksPerGrid2D = Vec2D{1u, 1u};
auto const elementsPerThread2D = Vec2D::all(static_cast<Idx>(1u));
auto const ref2D = alpaka::WorkDivMembers<Dim2D, Idx>{blocksPerGrid2D, threadsPerBlock2D, elementsPerThread2D};
auto const workDiv2D = alpaka::WorkDivMembers(blocksPerGrid2D, threadsPerBlock2D, elementsPerThread2D);
CHECK(workDiv2D == ref2D);

// Test using different input types, reduced to given explicit class template types
auto ref2DimUsingMixed
= alpaka::WorkDivMembers<Dim2D, Idx>{blocksPerGrid2D, threadsPerBlock3D, elementsPerThread3D};
CHECK(workDiv2D == ref2DimUsingMixed);

ref2DimUsingMixed = alpaka::WorkDivMembers<Dim2D, Idx>{blocksPerGrid2D, threadsPerBlock3D, elementsPerThread2D};
CHECK(workDiv2D == ref2DimUsingMixed);
}

0 comments on commit e01c0d7

Please sign in to comment.