Skip to content

Commit

Permalink
Fix view_reduction in case the result var comes from shared mem
Browse files Browse the repository at this point in the history
  • Loading branch information
bartgol committed Oct 12, 2022
1 parent 9deb8f8 commit 38176bd
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions src/ekat/kokkos/ekat_kokkos_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,17 @@ void view_reduction (const TeamMember& team,
using PackType = typename std::remove_reference<decltype(input(0))>::type;
constexpr int vector_size = PackType::n;

// We need to use a temporary, since we don't know whether result refers to a thread-local
// variable (e.g., an automatic variable) or to shared memory (e.g., an entry of a view).
// Hence, we perform the calculations on a local variable, then copy it back into result.
ValueType temp = result;

// Note: this team barrier is needed in some extreme cases. Without it, if result is a
// reference to shared memory (e.g., an entry of a view) rather than a thread-local variable,
// one team member could reach the end of the function (and hence update result) *before*
// another team member has had the chance to initialize temp.
team.team_barrier();

// Perform a packed reduction over scalar indices
const bool has_garbage_begin = begin%vector_size != 0;
const bool has_garbage_end = end%vector_size != 0;
Expand All @@ -130,7 +141,7 @@ void view_reduction (const TeamMember& team,
const int first_indx = begin%vector_size;
Kokkos::single(Kokkos::PerThread(team),[&] {
for (int j=first_indx; j<vector_size; ++j) {
result += temp_input[j];
temp += temp_input[j];
}
});
}
Expand All @@ -143,7 +154,7 @@ void view_reduction (const TeamMember& team,
[&](const int k, ValueType& local_sum) {
// Sum over pack entries and add to local_sum
ekat::reduce_sum<Serialize>(input(k),local_sum);
}, result);
}, temp);
} else {
PackType packed_result(0);
impl::parallel_reduce<Serialize>(team, pack_loop_begin, pack_loop_end,
Expand All @@ -152,7 +163,7 @@ void view_reduction (const TeamMember& team,
local_packed_sum += input(k);
}, packed_result);

result += ekat::reduce_sum<Serialize>(packed_result);
temp += ekat::reduce_sum<Serialize>(packed_result);
}
}

Expand All @@ -165,10 +176,11 @@ void view_reduction (const TeamMember& team,
ConstExceptGnu int last_indx = end%vector_size;
Kokkos::single(Kokkos::PerThread(team),[&] {
for (int j=0; j<last_indx; ++j) {
result += temp_input[j];
temp += temp_input[j];
}
});
}
result = temp;
}
} //namespace impl

Expand Down

0 comments on commit 38176bd

Please sign in to comment.