Skip to content

Commit

Permalink
add mesh_ent_to_int
Browse files Browse the repository at this point in the history
  • Loading branch information
Fuad-HH committed Sep 18, 2024
1 parent cda6b88 commit 2ddbfae
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 40 deletions.
81 changes: 43 additions & 38 deletions src/pcms/omega_h_field.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "pcms/array_mask.h"
#include "pcms/point_search.h"
#include <redev_variant_tools.h>
#include <type_traits>
#include "pcms/transfer_field.h"
#include "pcms/memory_spaces.h"
#include "pcms/profile.h"
Expand All @@ -28,15 +29,22 @@ struct OmegaHMemorySpace
using type = typename Kokkos::DefaultExecutionSpace::memory_space;
};

namespace detail
{
enum class mesh_entity_type {
enum class mesh_entity_type : int{
VERTEX = 0,
EDGE = 1,
FACE = 2,
REGION = 3
};

inline int mesh_entity_to_int(mesh_entity_type entity_type)
{
static_assert(std::is_same<std::underlying_type_t<mesh_entity_type>, int>::value, "mesh_entity_type must be an int");
return static_cast<std::underlying_type_t<mesh_entity_type>>(entity_type);
}


namespace detail
{
template <typename T>
struct memory_space_selector<Omega_h::Read<T>, void>
{
Expand Down Expand Up @@ -115,11 +123,11 @@ class OmegaHField
OmegaHField(std::string name, Omega_h::Mesh& mesh,
std::string global_id_name = "", int search_nx = 10,
int search_ny = 10,
detail::mesh_entity_type entity_type = detail::mesh_entity_type::VERTEX)
mesh_entity_type entity_type = mesh_entity_type::VERTEX)
: name_(std::move(name)),
mesh_(mesh),
search_{mesh, search_nx, search_ny},
size_(mesh.nents(static_cast<int>(entity_type))),
size_(mesh.nents(mesh_entity_to_int(entity_type))),
global_id_name_(std::move(global_id_name)),
entity_type_(entity_type)
{
Expand All @@ -128,7 +136,7 @@ class OmegaHField
OmegaHField(std::string name, Omega_h::Mesh& mesh,
Omega_h::Read<Omega_h::I8> mask, std::string global_id_name = "",
int search_nx = 10, int search_ny = 10,
detail::mesh_entity_type entity_type = detail::mesh_entity_type::VERTEX)
mesh_entity_type entity_type = mesh_entity_type::VERTEX)
: name_(std::move(name)),
mesh_(mesh),
search_{mesh, search_nx, search_ny},
Expand All @@ -143,8 +151,7 @@ class OmegaHField
// we use a parallel scan to construct the mask mapping so that filtering
// can happen in parallel. This method gives us the index to fill into the
// filtered array
printf("name: %s, ent %d | %d, mask size %d, ent size %d\n", name_.c_str(), static_cast<int>(entity_type_), static_cast<int>(entity_type), mask.size(), mesh.nents(static_cast<int>(entity_type_)));
PCMS_ALWAYS_ASSERT(mesh.nents(static_cast<int>(entity_type_)) == mask.size());
PCMS_ALWAYS_ASSERT(mesh.nents(mesh_entity_to_int(entity_type_)) == mask.size());
Omega_h::Write<LO> index_mask(mask.size());
auto index_mask_view = make_array_view(index_mask);
auto mask_view = make_const_array_view(mask);
Expand All @@ -153,7 +160,7 @@ class OmegaHField
Kokkos::parallel_for(policy, detail::ScaleAV{index_mask_view, mask_view});
mask_ = index_mask;
} else {
size_ = mesh.nents(static_cast<int>(entity_type_));
size_ = mesh.nents(mesh_entity_to_int(entity_type_));
}
}

Expand All @@ -164,7 +171,7 @@ class OmegaHField
return mask_;
};
[[nodiscard]] bool HasMask() const noexcept { return mask_.exists(); };
[[nodiscard]] pcms::detail::mesh_entity_type GetEntityType() const noexcept
[[nodiscard]] mesh_entity_type GetEntityType() const noexcept
{
return entity_type_;
}
Expand All @@ -179,29 +186,29 @@ class OmegaHField
PCMS_FUNCTION_TIMER;
if (HasMask())
return detail::filter_array(
mesh_.get_array<Omega_h::ClassId>(static_cast<int>(entity_type_), "class_id"), GetMask(), Size());
return mesh_.get_array<Omega_h::ClassId>(static_cast<int>(entity_type_), "class_id");
mesh_.get_array<Omega_h::ClassId>(mesh_entity_to_int(entity_type_), "class_id"), GetMask(), Size());
return mesh_.get_array<Omega_h::ClassId>(mesh_entity_to_int(entity_type_), "class_id");
}
[[nodiscard]] Omega_h::Read<Omega_h::I8> GetClassDims() const
{
PCMS_FUNCTION_TIMER;
if (HasMask())
return detail::filter_array(mesh_.get_array<Omega_h::I8>(static_cast<int>(entity_type_), "class_dim"),
return detail::filter_array(mesh_.get_array<Omega_h::I8>(mesh_entity_to_int(entity_type_), "class_dim"),
GetMask(), Size());
return mesh_.get_array<Omega_h::I8>(static_cast<int>(entity_type_), "class_dim");
return mesh_.get_array<Omega_h::I8>(mesh_entity_to_int(entity_type_), "class_dim");
}
[[nodiscard]] Omega_h::Read<Omega_h::GO> GetGids() const
{
PCMS_FUNCTION_TIMER;
Omega_h::Read<Omega_h::GO> gid_array;
if (global_id_name_.empty()) {
gid_array = mesh_.globals(static_cast<int>(entity_type_));
gid_array = mesh_.globals(mesh_entity_to_int(entity_type_));
} else {
auto tag = mesh_.get_tagbase(static_cast<int>(entity_type_), global_id_name_);
auto tag = mesh_.get_tagbase(mesh_entity_to_int(entity_type_), global_id_name_);
if (Omega_h::is<GO>(tag)) {
gid_array = mesh_.get_array<Omega_h::GO>(static_cast<int>(entity_type_), global_id_name_);
gid_array = mesh_.get_array<Omega_h::GO>(mesh_entity_to_int(entity_type_), global_id_name_);
} else if (Omega_h::is<LO>(tag)) {
auto array = mesh_.get_array<Omega_h::LO>(static_cast<int>(entity_type_), global_id_name_);
auto array = mesh_.get_array<Omega_h::LO>(mesh_entity_to_int(entity_type_), global_id_name_);
Omega_h::Write<Omega_h::GO> globals(array.size());
Omega_h::parallel_for(
array.size(), OMEGA_H_LAMBDA(int i) { globals[i] = array[i]; });
Expand All @@ -224,7 +231,7 @@ class OmegaHField
Omega_h::Read<LO> mask_;
LO size_;
std::string global_id_name_;
detail::mesh_entity_type entity_type_;
mesh_entity_type entity_type_;
};

using InternalCoordinateElement = Real;
Expand All @@ -242,7 +249,7 @@ auto get_nodal_data(const OmegaHField<T, CoordinateElementType>& field)
-> Omega_h::Read<T>
{
PCMS_FUNCTION_TIMER;
auto full_field = field.GetMesh().template get_array<T>(static_cast<int>(field.GetEntityType()), field.GetName());
auto full_field = field.GetMesh().template get_array<T>(mesh_entity_to_int(field.GetEntityType()), field.GetName());
if (field.HasMask()) {
return detail::filter_array<T>(full_field, field.GetMask(), field.Size());
}
Expand All @@ -258,13 +265,13 @@ auto get_nodal_coordinates(const OmegaHField<T, CoordinateElementType>& field)
static constexpr auto coordinate_dimension = 2;
if constexpr (detail::HasCoordinateSystem<CoordinateElementType>::value) {
//const auto coords = field.GetMesh().coords();
const auto coords = get_ent_centroids(field.GetMesh(), static_cast<int>(field.GetEntityType()));
const auto coords = get_ent_centroids(field.GetMesh(), mesh_entity_to_int(field.GetEntityType()));
return MDArray<CoordinateElementType>{};
// FIXME implement copy to
throw;
} else {
//auto coords = Omega_h::Reals{field.GetMesh().coords()};
auto coords = get_ent_centroids(field.GetMesh(), static_cast<int>(field.GetEntityType()));
auto coords = get_ent_centroids(field.GetMesh(), mesh_entity_to_int(field.GetEntityType()));
if (field.HasMask()) {
return detail::filter_array<typename decltype(coords)::value_type,
coordinate_dimension>(coords, field.GetMask(),
Expand All @@ -289,38 +296,38 @@ auto set_nodal_data(const OmegaHField<T, CoordinateElementType>& field,
"must be able to convert nodal data into the field types data");
auto& mesh = field.GetMesh();
auto entity_type = field.GetEntityType();
const auto has_tag = mesh.has_tag(static_cast<int>(entity_type), field.GetName());
const auto has_tag = mesh.has_tag(mesh_entity_to_int(entity_type), field.GetName());
if (field.HasMask()) {
auto& mask = field.GetMask();
PCMS_ALWAYS_ASSERT(mask.size() == mesh.nents(static_cast<int>(entity_type)));
PCMS_ALWAYS_ASSERT(mask.size() == mesh.nents(mesh_entity_to_int(entity_type)));
Omega_h::Write<T> array(mask.size());
if (has_tag) {
auto original_data = mesh.template get_array<T>(static_cast<int>(entity_type), field.GetName());
auto original_data = mesh.template get_array<T>(mesh_entity_to_int(entity_type), field.GetName());
PCMS_ALWAYS_ASSERT(original_data.size() == mask.size());
Omega_h::parallel_for(
mask.size(), OMEGA_H_LAMBDA(size_t i) {
array[i] = mask[i] ? data(mask[i] - 1) : original_data[i];
});
mesh.set_tag(static_cast<int>(entity_type), field.GetName(), Omega_h::Read<T>(array));
mesh.set_tag(mesh_entity_to_int(entity_type), field.GetName(), Omega_h::Read<T>(array));
} else {
Omega_h::parallel_for(
mask.size(), OMEGA_H_LAMBDA(size_t i) {
array[i] = mask[i] ? data(mask[i] - 1) : 0;
});
mesh.add_tag(static_cast<int>(entity_type), field.GetName(), 1, Omega_h::Read<T>(array));
mesh.add_tag(mesh_entity_to_int(entity_type), field.GetName(), 1, Omega_h::Read<T>(array));
}
} else {
PCMS_ALWAYS_ASSERT(static_cast<LO>(data.size()) == mesh.nents(static_cast<int>(entity_type)));
PCMS_ALWAYS_ASSERT(static_cast<LO>(data.size()) == mesh.nents(mesh_entity_to_int(entity_type)));
Omega_h::Write<T> array(data.size());
Omega_h::parallel_for(
data.size(), OMEGA_H_LAMBDA(size_t i) { array[i] = data(i); });
if (has_tag) {
mesh.set_tag(static_cast<int>(entity_type), field.GetName(), Omega_h::Read<T>(array));
mesh.set_tag(mesh_entity_to_int(entity_type), field.GetName(), Omega_h::Read<T>(array));
} else {
mesh.add_tag(static_cast<int>(entity_type), field.GetName(), 1, Omega_h::Read<T>(array));
mesh.add_tag(mesh_entity_to_int(entity_type), field.GetName(), 1, Omega_h::Read<T>(array));
}
}
PCMS_ALWAYS_ASSERT(mesh.has_tag(static_cast<int>(entity_type), field.GetName()));
PCMS_ALWAYS_ASSERT(mesh.has_tag(mesh_entity_to_int(entity_type), field.GetName()));
}

// TODO abstract out repeat parts of lagrange/nearest neighbor evaluation
Expand Down Expand Up @@ -488,23 +495,21 @@ class OmegaHFieldAdapter
using coordinate_element_type = CoordinateElementType;
OmegaHFieldAdapter(std::string name, Omega_h::Mesh& mesh,
std::string global_id_name = "", int search_nx = 10,
int search_ny = 10, detail::mesh_entity_type entity_type = detail::mesh_entity_type::VERTEX)
int search_ny = 10, mesh_entity_type entity_type = mesh_entity_type::VERTEX)
: field_{std::move(name), mesh, std::move(global_id_name), search_nx,
search_ny, entity_type}, entity_type_{entity_type}
{
PCMS_FUNCTION_TIMER;
printf("name: %s, ent %d | %d\n", name.c_str(), static_cast<int>(entity_type_), static_cast<int>(entity_type));
}

OmegaHFieldAdapter(std::string name, Omega_h::Mesh& mesh,
Omega_h::Read<Omega_h::I8> mask,
std::string global_id_name = "", int search_nx = 10,
int search_ny = 10, detail::mesh_entity_type entity_type = detail::mesh_entity_type::VERTEX)
int search_ny = 10, mesh_entity_type entity_type = mesh_entity_type::VERTEX)
: field_{std::move(name), mesh, mask,
std::move(global_id_name), search_nx, search_ny, entity_type}, entity_type_{entity_type}
{
PCMS_FUNCTION_TIMER;
printf("name: %s, ent %d | %d\n", name.c_str(), static_cast<int>(entity_type_), static_cast<int>(entity_type));
}
[[nodiscard]] const std::string& GetName() const noexcept
{
Expand Down Expand Up @@ -558,7 +563,7 @@ class OmegaHFieldAdapter
auto classIds_h = Omega_h::HostRead<Omega_h::ClassId>(field_.GetClassIDs());
auto classDims_h = Omega_h::HostRead<Omega_h::I8>(field_.GetClassDims());
//const auto coords = Omega_h::HostRead(field_.GetMesh().coords());
const auto coords = Omega_h::HostRead(get_ent_centroids(field_.GetMesh(), static_cast<int>(entity_type_)));
const auto coords = Omega_h::HostRead(get_ent_centroids(field_.GetMesh(), mesh_entity_to_int(entity_type_)));
auto dim = field_.GetMesh().dim();

// local_index number of vertices going to each destination process by
Expand Down Expand Up @@ -588,14 +593,14 @@ class OmegaHFieldAdapter
return field_;
}

[[nodiscard]] pcms::detail::mesh_entity_type GetEntityType() const noexcept
[[nodiscard]] mesh_entity_type GetEntityType() const noexcept
{
return entity_type_;
}

private:
OmegaHField<T, CoordinateElementType> field_;
pcms::detail::mesh_entity_type entity_type_;
mesh_entity_type entity_type_;
};
template <typename FieldAdapter>
void ConvertFieldAdapterToOmegaH(const FieldAdapter& adapter,
Expand Down
4 changes: 2 additions & 2 deletions src/pcms/xgc_field_adapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,9 @@ class XGCFieldAdapter
return (plane_rank_ == plane_root_);
}

[[nodiscard]] pcms::detail::mesh_entity_type GetEntityType() const noexcept
[[nodiscard]] pcms::mesh_entity_type GetEntityType() const noexcept
{
return pcms::detail::mesh_entity_type::VERTEX;
return pcms::mesh_entity_type::VERTEX;
}

private:
Expand Down

0 comments on commit 2ddbfae

Please sign in to comment.