Skip to content

Commit

Permalink
feat: MpiWrapper::allReduce overload for arrays. (#3446)
Browse files Browse the repository at this point in the history
* feat: MpiWrapper::allReduce overload for arrays.
  • Loading branch information
CusiniM authored Jan 29, 2025
1 parent 5bfdb01 commit ac45ee2
Show file tree
Hide file tree
Showing 20 changed files with 270 additions and 108 deletions.
2 changes: 1 addition & 1 deletion host-configs/apple/macOS_base.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ set(ENABLE_CALIPER "OFF" CACHE PATH "" FORCE )
set( BLAS_LIBRARIES ${HOMEBREW_DIR}/opt/lapack/lib/libblas.dylib CACHE PATH "" FORCE )
set( LAPACK_LIBRARIES ${HOMEBREW_DIR}/opt/lapack/lib/liblapack.dylib CACHE PATH "" FORCE )

set(ENABLE_DOXYGEN ON CACHE BOOL "" FORCE)
set(ENABLE_DOXYGEN OFF CACHE BOOL "" FORCE)
set(ENABLE_SPHINX ON CACHE BOOL "" FORCE)
set(ENABLE_MATHPRESSO ON CACHE BOOL "" FORCE )

Expand Down
1 change: 1 addition & 0 deletions src/coreComponents/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ set( common_headers
Tensor.hpp
TimingMacros.hpp
TypeDispatch.hpp
TypesHelpers.hpp
initializeEnvironment.hpp
LifoStorage.hpp
LifoStorageCommon.hpp
Expand Down
88 changes: 76 additions & 12 deletions src/coreComponents/common/MpiWrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

#include "common/DataTypes.hpp"
#include "common/Span.hpp"
#include "common/TypesHelpers.hpp"

#if defined(GEOS_USE_MPI)
#include <mpi.h>
Expand Down Expand Up @@ -128,6 +129,8 @@ struct MpiWrapper
Min, //!< Min
Sum, //!< Sum
Prod, //!< Prod
LogicalAnd, //!< Logical and
LogicalOr, //!< Logical or
};

MpiWrapper() = delete;
Expand Down Expand Up @@ -351,18 +354,6 @@ struct MpiWrapper
array1d< T > & recvbuf,
MPI_Comm comm = MPI_COMM_GEOS );

/**
* @brief Strongly typed wrapper around MPI_Allreduce.
* @tparam T The type of the values being reduced.
* @param[in] sendbuf The pointer to the sending buffer.
* @param[out] recvbuf The pointer to the receive buffer.
* @param[in] count The number of values to send/receive.
* @param[in] op The MPI_Op to perform.
* @param[in] comm The MPI_Comm over which the reduction operates.
* @return The return value of the underlying call to MPI_Allreduce().
*/
template< typename T >
static int allReduce( T const * sendbuf, T * recvbuf, int count, MPI_Op op, MPI_Comm comm = MPI_COMM_GEOS );

/**
* @brief Convenience wrapper for the MPI_Allreduce function.
* @tparam T type of data to reduce. Must correspond to a valid MPI_Datatype.
Expand All @@ -385,6 +376,29 @@ struct MpiWrapper
template< typename T >
static void allReduce( Span< T const > src, Span< T > dst, Reduction const op, MPI_Comm comm = MPI_COMM_GEOS );

/**
* @brief Convenience wrapper for the MPI_Allreduce function. Version for arrays.
* @tparam SRC_CONTAINER_TYPE type of the source container. It must provide size() and data(), and expose
*   its element type as `value_type` or `ValueType`; that element type must be trivially copyable.
* @tparam DST_CONTAINER_TYPE type of the destination container. Same requirements as SRC_CONTAINER_TYPE,
*   and its element type must be identical to the one of the source container.
* @param[in] src The values to send to the reduction.
* @param[out] dst The resulting values. Must have the same size as @p src.
* @param[in] op The Reduction enum to perform.
* @param[in] comm The communicator.
*/
template< typename SRC_CONTAINER_TYPE, typename DST_CONTAINER_TYPE >
static void allReduce( SRC_CONTAINER_TYPE const & src, DST_CONTAINER_TYPE & dst, Reduction const op, MPI_Comm const comm = MPI_COMM_GEOS );

/**
* @brief Convenience wrapper for the MPI_Allreduce function. Version for arrays, reducing only the leading entries.
* @tparam SRC_CONTAINER_TYPE type of the source container. It must provide size() and data(), and expose
*   its element type as `value_type` or `ValueType`; that element type must be trivially copyable.
* @tparam DST_CONTAINER_TYPE type of the destination container. Same requirements as SRC_CONTAINER_TYPE,
*   and its element type must be identical to the one of the source container.
* @param[in] src The values to send to the reduction.
* @param[out] dst The resulting values.
* @param[in] count The number of contiguous elements, starting at data(), to perform the reduction on
*   (must be less than or equal to the size of both @p src and @p dst).
* @param[in] op The Reduction enum to perform.
* @param[in] comm The communicator (no default value is provided for this overload).
*/
template< typename SRC_CONTAINER_TYPE, typename DST_CONTAINER_TYPE >
static void allReduce( SRC_CONTAINER_TYPE const & src, DST_CONTAINER_TYPE & dst, int const count, Reduction const op, MPI_Comm const comm );


/**
* @brief Strongly typed wrapper around MPI_Reduce.
Expand Down Expand Up @@ -639,6 +653,19 @@ struct MpiWrapper
*/
template< typename T > static T maxValLoc( T localValueLocation, MPI_Comm comm = MPI_COMM_GEOS );

private:

/**
* @brief Strongly typed wrapper around MPI_Allreduce.
* @tparam T The type of the values being reduced.
* @param[in] sendbuf The pointer to the sending buffer.
* @param[out] recvbuf The pointer to the receive buffer.
* @param[in] count The number of values to send/receive.
* @param[in] op The MPI_Op to perform.
* @param[in] comm The MPI_Comm over which the reduction operates.
* @return The return value of the underlying call to MPI_Allreduce().
*/
template< typename T >
static int allReduce( T const * sendbuf, T * recvbuf, int count, MPI_Op op, MPI_Comm comm = MPI_COMM_GEOS );
};

namespace internal
Expand Down Expand Up @@ -701,6 +728,14 @@ inline MPI_Op MpiWrapper::getMpiOp( Reduction const op )
{
return MPI_PROD;
}
case Reduction::LogicalAnd:
{
return MPI_LAND;
}
case Reduction::LogicalOr:
{
return MPI_LOR;
}
default:
GEOS_ERROR( "Unsupported reduction operation" );
return MPI_NO_OP;
Expand Down Expand Up @@ -1113,6 +1148,35 @@ void MpiWrapper::allReduce( Span< T const > const src, Span< T > const dst, Redu
allReduce( src.data(), dst.data(), LvArray::integerConversion< int >( src.size() ), getMpiOp( op ), comm );
}

template< typename SRC_CONTAINER_TYPE, typename DST_CONTAINER_TYPE >
void MpiWrapper::allReduce( SRC_CONTAINER_TYPE const & src, DST_CONTAINER_TYPE & dst, int const count, Reduction const op, MPI_Comm const comm )
{
static_assert( std::is_trivially_copyable< typename get_value_type< SRC_CONTAINER_TYPE >::type >::value,
"The type in the source container must be trivially copyable." );
static_assert( std::is_trivially_copyable< typename get_value_type< DST_CONTAINER_TYPE >::type >::value,
"The type in the destination container must be trivially copyable." );
static_assert( std::is_same< typename get_value_type< SRC_CONTAINER_TYPE >::type,
typename get_value_type< DST_CONTAINER_TYPE >::type >::value,
"Source and destination containers must have the same value type." );
GEOS_ASSERT_GE( src.size(), count );
GEOS_ASSERT_GE( dst.size(), count );
allReduce( src.data(), dst.data(), count, getMpiOp( op ), comm );
}

template< typename SRC_CONTAINER_TYPE, typename DST_CONTAINER_TYPE >
void MpiWrapper::allReduce( SRC_CONTAINER_TYPE const & src, DST_CONTAINER_TYPE & dst, Reduction const op, MPI_Comm const comm )
{
static_assert( std::is_trivially_copyable< typename get_value_type< SRC_CONTAINER_TYPE >::type >::value,
"The type in the source container must be trivially copyable." );
static_assert( std::is_trivially_copyable< typename get_value_type< DST_CONTAINER_TYPE >::type >::value,
"The type in the destination container must be trivially copyable." );
static_assert( std::is_same< typename get_value_type< SRC_CONTAINER_TYPE >::type,
typename get_value_type< DST_CONTAINER_TYPE >::type >::value,
"Source and destination containers must have the same value type." );
GEOS_ASSERT_EQ( src.size(), dst.size() );
allReduce( src.data(), dst.data(), LvArray::integerConversion< int >( src.size() ), getMpiOp( op ), comm );
}

template< typename T >
T MpiWrapper::sum( T const & value, MPI_Comm comm )
{
Expand Down
124 changes: 124 additions & 0 deletions src/coreComponents/common/TypesHelpers.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
/*
* ------------------------------------------------------------------------------------------------------------
* SPDX-License-Identifier: LGPL-2.1-only
*
* Copyright (c) 2016-2024 Lawrence Livermore National Security LLC
* Copyright (c) TotalEnergies
* Copyright (c) 2018-2024 The Board of Trustees of the Leland Stanford Junior University
* Copyright (c) 2023-2024 Chevron
* Copyright (c) 2019- GEOS/GEOSX Contributors
* All rights reserved
*
* See top level LICENSE, COPYRIGHT, CONTRIBUTORS, NOTICE, and ACKNOWLEDGEMENTS files for details.
* ------------------------------------------------------------------------------------------------------------
*/

/**
* @file TypesHelpers.hpp
*
*/

#ifndef TYPES_HELPERS_HPP
#define TYPES_HELPERS_HPP

#include <type_traits>

namespace geos
{

namespace internal
{

/**
 * @brief Detects whether a type exposes a nested `value_type` alias.
 *
 * Primary template, selected when `T::value_type` does not name a type;
 * it is equivalent to `std::false_type`.
 *
 * @tparam T The type being inspected.
 * @tparam void Unnamed SFINAE slot matched by the partial specialization.
 */
template< typename T, typename = void >
struct has_value_type : std::bool_constant< false > {};

/**
 * @brief Partial specialization chosen when `T::value_type` is a valid type.
 *
 * It is equivalent to `std::true_type`.
 *
 * @tparam T The type being inspected.
 */
template< typename T >
struct has_value_type< T, std::void_t< typename T::value_type > > : std::bool_constant< true > {};

/**
 * @brief Detects whether a type exposes a nested `ValueType` alias.
 *
 * Primary template, selected when `T::ValueType` does not name a type;
 * it is equivalent to `std::false_type`.
 *
 * @tparam T The type being inspected.
 * @tparam void Unnamed SFINAE slot matched by the partial specialization.
 */
template< typename T, typename = void >
struct has_ValueType : std::bool_constant< false > {};

/**
 * @brief Partial specialization chosen when `T::ValueType` is a valid type.
 *
 * It is equivalent to `std::true_type`.
 *
 * @tparam T The type being inspected.
 */
template< typename T >
struct has_ValueType< T, std::void_t< typename T::ValueType > > : std::bool_constant< true > {};

} // namespace internal

/**
 * @brief Extracts the element type alias (`value_type` or `ValueType`) of a type `T`.
 *
 * This primary template is only instantiated when `T` provides neither alias,
 * in which case compilation stops on the static assertion below.
 *
 * @tparam T The type from which the alias is extracted.
 * @tparam Enable SFINAE slot used to select the partial specializations.
 */
template< typename T, typename Enable = void >
struct get_value_type
{
  static_assert( sizeof(T) == 0, "T must define either value_type or ValueType." );
};

/**
 * @brief Partial specialization for types providing `value_type`.
 *
 * `value_type` takes precedence: this specialization is selected whenever it exists,
 * whether or not `ValueType` is also defined.
 *
 * @tparam T The type from which `value_type` is extracted.
 */
template< typename T >
struct get_value_type< T, typename std::enable_if< internal::has_value_type< T >::value >::type >
{
  using type = typename T::value_type;
};

/**
 * @brief Partial specialization for types providing `ValueType` but not `value_type`.
 *
 * @tparam T The type from which `ValueType` is extracted.
 */
template< typename T >
struct get_value_type< T, typename std::enable_if< internal::has_ValueType< T >::value && !internal::has_value_type< T >::value >::type >
{
  using type = typename T::ValueType;
};

} // namespace geos

#endif /* TYPES_HELPERS_HPP */
2 changes: 1 addition & 1 deletion src/coreComponents/common/initializeEnvironment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ static void addUmpireHighWaterMarks()
string allocatorNameMinChars = string( MAX_NAME_LENGTH, '\0' );

// Make sure that each rank is looking at the same allocator.
MpiWrapper::allReduce( allocatorNameFixedSize.c_str(), &allocatorNameMinChars.front(), MAX_NAME_LENGTH, MPI_MIN, MPI_COMM_GEOS );
MpiWrapper::allReduce( allocatorNameFixedSize, allocatorNameMinChars, MpiWrapper::Reduction::Min, MPI_COMM_GEOS );
if( allocatorNameFixedSize != allocatorNameMinChars )
{
GEOS_WARNING( "Not all ranks have an allocator named " << allocatorNameFixedSize << ", cannot compute high water mark." );
Expand Down
5 changes: 3 additions & 2 deletions src/coreComponents/fileIO/timeHistory/HDFHistoryIO.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,9 @@ void HDFHistoryIO::init( bool existsOkay )
void HDFHistoryIO::write()
{
// check if the size has changed on any process in the primary comm
int anyChanged = false;
MpiWrapper::allReduce( &m_sizeChanged, &anyChanged, 1, MPI_LOR, m_comm );
int const anyChanged = MpiWrapper::allReduce( m_sizeChanged,
MpiWrapper::Reduction::LogicalOr,
m_comm );
m_sizeChanged = anyChanged;

// this will set the first dim large enough to hold all the rows we're about to write
Expand Down
7 changes: 3 additions & 4 deletions src/coreComponents/finiteVolume/TwoPointFluxApproximation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -951,10 +951,9 @@ void TwoPointFluxApproximation::computeAquiferStencil( DomainPartition & domain,
localSumFaceAreasView[aquiferIndex] += targetSetSumFaceAreas.get();
} );

MpiWrapper::allReduce( localSumFaceAreas.data(),
globalSumFaceAreas.data(),
localSumFaceAreas.size(),
MpiWrapper::getMpiOp( MpiWrapper::Reduction::Sum ),
MpiWrapper::allReduce( localSumFaceAreas,
globalSumFaceAreas,
MpiWrapper::Reduction::Sum,
MPI_COMM_GEOS );

// Step 3: compute the face area fraction for each connection, and insert into boundary stencil
Expand Down
8 changes: 3 additions & 5 deletions src/coreComponents/mesh/ParticleManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,9 @@ void ParticleManager::setMaxGlobalIndex()
m_localMaxGlobalIndex = std::max( m_localMaxGlobalIndex, subRegion.maxGlobalIndex() );
} );

MpiWrapper::allReduce( &m_localMaxGlobalIndex,
&m_maxGlobalIndex,
1,
MPI_MAX,
MPI_COMM_GEOS );
m_maxGlobalIndex = MpiWrapper::allReduce( m_localMaxGlobalIndex,
MpiWrapper::Reduction::Max,
MPI_COMM_GEOS );
}

Group * ParticleManager::createChild( string const & childKey, string const & childName )
Expand Down
7 changes: 3 additions & 4 deletions src/coreComponents/mesh/generators/InternalMeshGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -603,10 +603,9 @@ void InternalMeshGenerator::fillCellBlockManager( CellBlockManager & cellBlockMa
{
elemCenterCoordsLocal[k] = m_min[dim] + ( m_max[dim] - m_min[dim] ) * ( k + 0.5 ) / m_numElemsTotal[dim];
}
MpiWrapper::allReduce( elemCenterCoordsLocal.data(),
elemCenterCoords[dim].data(),
m_numElemsTotal[dim],
MPI_MAX,
MpiWrapper::allReduce( elemCenterCoordsLocal,
elemCenterCoords[dim],
MpiWrapper::Reduction::Max,
MPI_COMM_GEOS );
}

Expand Down
3 changes: 1 addition & 2 deletions src/coreComponents/mesh/generators/VTKUtilities.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -746,8 +746,7 @@ vtkSmartPointer< vtkDataSet > manageGlobalIds( vtkSmartPointer< vtkDataSet > mes
{
// Add global ids on the fly if needed
int const me = hasGlobalIds( mesh );
int everyone;
MpiWrapper::allReduce( &me, &everyone, 1, MPI_MAX, MPI_COMM_GEOS );
int const everyone = MpiWrapper::allReduce( me, MpiWrapper::Reduction::Max, MPI_COMM_GEOS );

if( everyone and not me )
{
Expand Down
24 changes: 12 additions & 12 deletions src/coreComponents/physicsSolvers/PhysicsSolverBaseKernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,10 +267,9 @@ class LinfResidualNormHelper
static void computeGlobalNorm( array1d< real64 > const & localResidualNorm,
array1d< real64 > & globalResidualNorm )
{
MpiWrapper::allReduce( localResidualNorm.data(),
globalResidualNorm.data(),
localResidualNorm.size(),
MpiWrapper::getMpiOp( MpiWrapper::Reduction::Max ),
MpiWrapper::allReduce( localResidualNorm,
globalResidualNorm,
MpiWrapper::Reduction::Max,
MPI_COMM_GEOS );
}
};
Expand Down Expand Up @@ -309,16 +308,17 @@ class L2ResidualNormHelper
{
array1d< real64 > sumLocalResidualNorm( localResidualNorm.size() );
array1d< real64 > sumLocalResidualNormalizer( localResidualNormalizer.size() );
MpiWrapper::allReduce( localResidualNorm.data(),
sumLocalResidualNorm.data(),
localResidualNorm.size(),
MpiWrapper::getMpiOp( MpiWrapper::Reduction::Sum ),

MpiWrapper::allReduce( localResidualNorm,
sumLocalResidualNorm,
MpiWrapper::Reduction::Sum,
MPI_COMM_GEOS );
MpiWrapper::allReduce( localResidualNormalizer.data(),
sumLocalResidualNormalizer.data(),
localResidualNormalizer.size(),
MpiWrapper::getMpiOp( MpiWrapper::Reduction::Sum ),

MpiWrapper::allReduce( localResidualNormalizer,
sumLocalResidualNormalizer,
MpiWrapper::Reduction::Sum,
MPI_COMM_GEOS );

for( integer i = 0; i < localResidualNorm.size(); ++i )
{
globalResidualNorm[i] = sqrt( sumLocalResidualNorm[i] ) / sqrt( sumLocalResidualNormalizer[i] );
Expand Down
Loading

0 comments on commit ac45ee2

Please sign in to comment.