Skip to content

Commit

Permalink
created partition wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
Angelyr committed Jan 29, 2025
1 parent 369834f commit 87ce843
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 50 deletions.
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ set(PCMS_HEADERS
pcms/inclusive_scan.h
pcms/profile.h
pcms/print.h
pcms/partition.h
)

set(PCMS_SOURCES
Expand Down
28 changes: 3 additions & 25 deletions src/pcms/omega_h_field.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "pcms/transfer_field.h"
#include "pcms/memory_spaces.h"
#include "pcms/profile.h"
#include "pcms/partition.h"

// FIXME add executtion spaces (don't use kokkos exe spaces directly)

Expand Down Expand Up @@ -85,29 +86,6 @@ Omega_h::Read<T> filter_array(Omega_h::Read<T> array,
});
return filtered_field;
}
struct GetRankOmegaH
{
GetRankOmegaH(int i, Omega_h::I8 dim, Omega_h::ClassId id, std::array<pcms::Real,3> & coord)
: i_(i), id_(id), dim_(dim), coord_(coord)
{
PCMS_FUNCTION_TIMER;
}
auto operator()(const redev::ClassPtn& ptn) const
{
PCMS_FUNCTION_TIMER;
const auto ent = redev::ClassPtn::ModelEnt({dim_, id_});
return ptn.GetRank(ent);
}
auto operator()(const redev::RCBPtn& ptn)
{
PCMS_FUNCTION_TIMER;
return ptn.GetRank(coord_);
}
int i_;
Omega_h::ClassId id_;
Omega_h::I8 dim_;
std::array<pcms::Real,3> coord_;
};
} // namespace detail

template <typename T,
Expand Down Expand Up @@ -571,12 +549,12 @@ class OmegaHFieldAdapter
std::array<pcms::Real, 3> coord;
pcms::ReversePartitionMap reverse_partition;
pcms::LO local_index = 0;
Partition part{partition};
for (auto i = 0; i < classIds_h.size(); i++) {
coord[0] = coords[i * dim];
coord[1] = coords[i * dim + 1];
coord[2] = (dim == 3) ? coords[i * dim + 2] : 0.0;
auto dr = std::visit(detail::GetRankOmegaH{i, classDims_h[i], classIds_h[i], coord},
partition);
auto dr = part.GetDr(classIds_h[i], classDims_h[i], coord);
reverse_partition[dr].emplace_back(local_index++);
}
return reverse_partition;
Expand Down
47 changes: 47 additions & 0 deletions src/pcms/partition.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#ifndef PCMS_PARTITION_H
#define PCMS_PARTITION_H
#include "pcms/common.h"


namespace pcms
{
struct GetRank
{
GetRank(LO id, LO dim, std::array<Real,3>& coord)
: id_(id), dim_(dim), coord_(coord)
{
PCMS_FUNCTION_TIMER;
}
auto operator()(const redev::ClassPtn& ptn) const
{
PCMS_FUNCTION_TIMER;
const auto ent = redev::ClassPtn::ModelEnt({dim_, id_});
return ptn.GetRank(ent);
}
auto operator()(const redev::RCBPtn& ptn)
{
PCMS_FUNCTION_TIMER;
return ptn.GetRank(coord_);
}
LO id_;
LO dim_;
std::array<Real,3> coord_;
};

struct Partition
{
Partition(const redev::Partition& partition) : partition_(partition)
{
PCMS_FUNCTION_TIMER;
}

auto GetDr(LO id, LO dim, std::array<Real,3> coord = {})
{
return std::visit(GetRank{id, dim, coord}, partition_);
}

redev::Partition partition_;
};
} // namespace pcms

#endif //PCMS_PARTITION_H
27 changes: 2 additions & 25 deletions src/pcms/xgc_field_adapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,30 +13,6 @@

namespace pcms
{
namespace detail
{
// Needed since NVHPC doesn't work with overloaded
struct GetRank
{
using GeomType = DimID;
GetRank(const GeomType& geom) : geom_(geom) {}
auto operator()(const redev::ClassPtn& ptn) const
{
PCMS_FUNCTION_TIMER;
const auto ent = redev::ClassPtn::ModelEnt({geom_.dim, geom_.id});
return ptn.GetRank(ent);
}
auto operator()(const redev::RCBPtn& /*unused*/) const
{
PCMS_FUNCTION_TIMER;
std::cerr << "RCB partition not handled yet\n";
std::terminate();
return 0;
}
const GeomType& geom_;
};
} // namespace detail

template <typename T, typename CoordinateElementType = Real>
class XGCFieldAdapter
{
Expand Down Expand Up @@ -144,11 +120,12 @@ class XGCFieldAdapter
pcms::ReversePartitionMap reverse_partition;
// in_overlap_ must contain a function!
PCMS_ALWAYS_ASSERT(static_cast<bool>(in_overlap_));
Partition part{partition};
for (const auto& geom : reverse_classification_) {
// if the geometry is in specified overlap region
if (in_overlap_(geom.first.dim, geom.first.id)) {

auto dr = std::visit(detail::GetRank{geom.first}, partition);
auto dr = part.GetDr(geom.first.id, geom.first.dim);
auto [it, inserted] = reverse_partition.try_emplace(dr);
// the map gives the local iteration order of the global ids
auto map = mask_.GetMap();
Expand Down

0 comments on commit 87ce843

Please sign in to comment.