Skip to content

Commit

Permalink
🎨 proposal #1
Browse files Browse the repository at this point in the history
  • Loading branch information
MelReyCG committed Jan 31, 2025
1 parent 2ccefa7 commit cfb41e3
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 29 deletions.
51 changes: 51 additions & 0 deletions src/coreComponents/common/MpiWrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,57 @@ int MpiWrapper::nodeCommSize()
return nodeCommSize;
}

namespace internal
{

template< typename FIRST, typename SECOND >
MPI_Datatype getMpiCustomPairType()
{
static auto const createTypeHolder = [] () {
using PAIR_T = MpiWrapper::PairType< FIRST, SECOND >;
static_assert( std::is_standard_layout_v< PAIR_T > );
static_assert( std::is_trivially_copyable_v< PAIR_T > );
MPI_Datatype types[2] = { getMpiType< FIRST >(), getMpiType< SECOND >() };
MPI_Aint offsets[2] = { offsetof( PAIR_T, first ), offsetof( PAIR_T, second ) };
int blocksCount[2] = { 1, 1 };
MPI_Datatype mpiType;
GEOS_ERROR_IF_NE( MPI_Type_create_struct( 2, blocksCount, offsets, types, &mpiType ), MPI_SUCCESS );
GEOS_ERROR_IF_NE( MPI_Type_commit( &mpiType ), MPI_SUCCESS );
return mpiType;
};
static MPI_Datatype mpiType{ createTypeHolder() };
return mpiType;
}

template<> MPI_Datatype getMpiPairType< int, int >()
{ return MPI_2INT; }

template<> MPI_Datatype getMpiPairType< long int, int >()
{ return MPI_LONG_INT; }

template<> MPI_Datatype getMpiPairType< long int, long int >()
{ return getMpiCustomPairType< long int, long int >(); }

template<> MPI_Datatype getMpiPairType< long long int, long long int >()
{ return getMpiCustomPairType< long long int, long long int >(); }

template<> MPI_Datatype getMpiPairType< float, int >()
{ return MPI_FLOAT_INT; }

template<> MPI_Datatype getMpiPairType< double, int >()
{ return MPI_DOUBLE_INT; }

template<> MPI_Datatype getMpiPairType< double, long int >()
{ return getMpiCustomPairType< double, long int >(); }

template<> MPI_Datatype getMpiPairType< double, long long int >()
{ return getMpiCustomPairType< double, long long int >(); }

template<> MPI_Datatype getMpiPairType< double, double >()
{ return getMpiCustomPairType< double, double >(); }

} /* namespace internal */

} /* namespace geos */

#if defined(__clang__)
Expand Down
42 changes: 13 additions & 29 deletions src/coreComponents/common/MpiWrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -784,36 +784,21 @@ MPI_Datatype getMpiType()
}

template< typename FIRST, typename SECOND >
MPI_Datatype getMpiCustomPairType()
MPI_Datatype getMpiPairType()
{
static auto const createTypeHolder = [] () {
using PAIR_T = MpiWrapper::PairType< FIRST, SECOND >;
static_assert( std::is_standard_layout< PAIR_T >::value );
MPI_Datatype types[2] = { getMpiType< FIRST >(), getMpiType< SECOND >() };
MPI_Aint offsets[2] = { offsetof( PAIR_T, first ), offsetof( PAIR_T, second ) };
int blocksCount[2] = { 1, 1 };
MPI_Datatype mpiType;
GEOS_ERROR_IF_NE( MPI_Type_create_struct( 2, blocksCount, offsets, types, &mpiType ), MPI_SUCCESS );
GEOS_ERROR_IF_NE( MPI_Type_commit( &mpiType ), MPI_SUCCESS );
return mpiType;
};
static MPI_Datatype mpiType{ createTypeHolder() };
return mpiType;
static_assert("no default implementation, please add a template specialization and add it in the \"testMpiWrapper\" unit test.");
return {};
}

/* no default get() implementation, please add a template specialization and add it in the "testMpiWrapper" unit test. */
template< typename FIRST, typename SECOND >
MPI_Datatype const mpiPairType;

template<> MPI_Datatype const mpiPairType< float, int > = MPI_FLOAT_INT;
template<> MPI_Datatype const mpiPairType< double, int > = MPI_DOUBLE_INT;
template<> MPI_Datatype const mpiPairType< int, int > = MPI_2INT;
template<> MPI_Datatype const mpiPairType< long int, int > = MPI_LONG_INT;
template<> MPI_Datatype const mpiPairType< long int, long int > = getMpiCustomPairType< long int, long int >();
template<> MPI_Datatype const mpiPairType< long long int, long long int > = getMpiCustomPairType< long long int, long long int >();
template<> MPI_Datatype const mpiPairType< double, long int > = getMpiCustomPairType< double, long int >();
template<> MPI_Datatype const mpiPairType< double, long long int > = getMpiCustomPairType< double, long long int >();
template<> MPI_Datatype const mpiPairType< double, double > = getMpiCustomPairType< double, double >();
template<> MPI_Datatype getMpiPairType< int, int >();
template<> MPI_Datatype getMpiPairType< long int, int >();
template<> MPI_Datatype getMpiPairType< long int, long int >();
template<> MPI_Datatype getMpiPairType< long long int, long long int >();
template<> MPI_Datatype getMpiPairType< float, int >();
template<> MPI_Datatype getMpiPairType< double, int >();
template<> MPI_Datatype getMpiPairType< double, long int >();
template<> MPI_Datatype getMpiPairType< double, long long int >();
template<> MPI_Datatype getMpiPairType< double, double >();

// It is advised to always use this custom operator for pairs as MPI_MAXLOC is not a true lexicographical comparator.
template< typename FIRST, typename SECOND, MpiWrapper::PairReduction OP >
Expand Down Expand Up @@ -1373,13 +1358,12 @@ void MpiWrapper::reduce( Span< T const > const src, Span< T > const dst, Reducti
reduce( src.data(), dst.data(), LvArray::integerConversion< int >( src.size() ), getMpiOp( op ), root, comm );
}


template< typename FIRST, typename SECOND, MpiWrapper::PairReduction const OP >
MpiWrapper::PairType< FIRST, SECOND >
MpiWrapper::allReduce( PairType< FIRST, SECOND > const & localPair, MPI_Comm comm )
{
#ifdef GEOS_USE_MPI
auto const type = internal::mpiPairType< FIRST, SECOND >;
auto const type = internal::getMpiPairType< FIRST, SECOND >();
auto const mpiOp = internal::getMpiPairReductionOp< FIRST, SECOND, OP >();
PairType< FIRST, SECOND > pair{ localPair.first, localPair.second };
MPI_Allreduce( MPI_IN_PLACE, &pair, 1, type, mpiOp, comm );
Expand Down

0 comments on commit cfb41e3

Please sign in to comment.