Skip to content

Commit

Permalink
Fixing sync collectives
Browse files Browse the repository at this point in the history
- adding example
- adding sync overloads for set/get
- adding predefined channel communicator
  • Loading branch information
hkaiser committed Jan 8, 2025
1 parent 09ffd2f commit 03470ae
Show file tree
Hide file tree
Showing 14 changed files with 284 additions and 36 deletions.
1 change: 1 addition & 0 deletions .github/workflows/tests.examples.targets
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,4 @@ tests.examples.quickstart.partitioned_vector_spmd_foreach
tests.examples.quickstart.sort_by_key_demo
tests.examples.transpose.transpose_block_numa
tests.examples.modules.collectives.distributed.tcp.channel_communicator
tests.examples.modules.collectives.distributed.tcp.distributed_pi
5 changes: 4 additions & 1 deletion libs/full/collectives/examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@ else()
return()
endif()

set(example_programs channel_communicator)
set(example_programs channel_communicator distributed_pi)

set(channel_communicator_PARAMETERS LOCALITIES 2 THREADS_PER_LOCALITY 2)
set(channel_communicator_FLAGS DEPENDENCIES iostreams_component)

set(distributed_pi_PARAMETERS LOCALITIES 2 THREADS_PER_LOCALITY 2)
set(distributed_pi_FLAGS COMPILE_FLAGS -DHPX_HAVE_RUN_MAIN_EVERYWHERE)

foreach(example_program ${example_programs})

set(sources ${example_program}.cpp)
Expand Down
47 changes: 47 additions & 0 deletions libs/full/collectives/examples/distributed_pi.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright (c) 2025 Hartmut Kaiser
//
// SPDX-License-Identifier: BSL-1.0
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)

#include <hpx/hpx.hpp>
#include <hpx/hpx_main.hpp>

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <string>

inline double sqr(double val)
{
return val * val;
}

int main(int argc, char* argv[])
{
std::size_t N = 1'000'000;
std::uint32_t num_localities = hpx::get_num_localities(hpx::launch::sync);
std::uint32_t locality_id = hpx::get_locality_id();

if (locality_id == 0 && argc > 1)
N = std::stol(argv[1]);

hpx::collectives::broadcast(hpx::collectives::get_world_communicator(), N);

std::size_t const blocksize = N / num_localities;
std::size_t const begin = blocksize * locality_id;
std::size_t const end = blocksize * (locality_id + 1);
double h = 1.0 / N;

double pi = 0.0;
for (std::size_t i = begin; i != end; ++i)
pi += h * 4.0 / (1 + sqr(i * h));

hpx::collectives::reduce(
hpx::collectives::get_world_communicator(), pi, std::plus{});

if (locality_id == 0)
std::cout << "pi: " << pi << std::endl;

return 0;
}
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ namespace hpx::collectives {

fid.wait(); // make sure communicator was created

if (this_site == fid.get_info().second)
if (this_site == std::get<2>(fid.get_info_ex()))
{
broadcast_to(
hpx::launch::sync, HPX_MOVE(fid), value, this_site, generation);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,8 @@ namespace hpx { namespace collectives {
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

namespace hpx { namespace collectives {
namespace hpx::collectives {

// forward declarations
class channel_communicator;
Expand All @@ -126,10 +125,18 @@ namespace hpx { namespace collectives {
hpx::future<T> get(
channel_communicator, that_site_arg, tag_arg = tag_arg());

template <typename T>
T get(hpx::launch::sync_policy, channel_communicator, that_site_arg,
tag_arg = tag_arg());

template <typename T>
hpx::future<void> set(
channel_communicator, that_site_arg, T&&, tag_arg = tag_arg());

template <typename T>
void set(hpx::launch::sync_policy, channel_communicator, that_site_arg, T&&,
tag_arg = tag_arg());

class channel_communicator
{
private:
Expand All @@ -140,10 +147,18 @@ namespace hpx { namespace collectives {
template <typename T>
friend hpx::future<T> get(channel_communicator, that_site_arg, tag_arg);

template <typename T>
friend T get(hpx::launch::sync_policy, channel_communicator,
that_site_arg, tag_arg);

template <typename T>
friend hpx::future<void> set(
channel_communicator, that_site_arg, T&&, tag_arg);

template <typename T>
friend void set(hpx::launch::sync_policy, channel_communicator,
that_site_arg, T&&, tag_arg);

private:
HPX_EXPORT channel_communicator(char const* basename,
num_sites_arg num_sites, this_site_arg this_site,
Expand All @@ -163,6 +178,11 @@ namespace hpx { namespace collectives {

HPX_EXPORT void free();

explicit operator bool() const noexcept
{
return comm_.get() != nullptr;
}

private:
std::shared_ptr<detail::channel_communicator> comm_;
};
Expand All @@ -185,14 +205,41 @@ namespace hpx { namespace collectives {
return comm.comm_->template get<T>(site.argument_, tag.argument_);
}

template <typename T>
T get(hpx::launch::sync_policy, channel_communicator comm,
that_site_arg site, tag_arg tag)
{
return comm.comm_->template get<T>(site.argument_, tag.argument_).get();
}

///////////////////////////////////////////////////////////////////////////
template <typename T>
hpx::future<void> set(
channel_communicator comm, that_site_arg site, T&& value, tag_arg tag)
{
return comm.comm_->set(
site.argument_, HPX_FORWARD(T, value), tag.argument_);
}
}} // namespace hpx::collectives

template <typename T>
void set(hpx::launch::sync_policy, channel_communicator comm,
that_site_arg site, T&& value, tag_arg tag)
{
return comm.comm_
->set(site.argument_, HPX_FORWARD(T, value), tag.argument_)
.get();
}

///////////////////////////////////////////////////////////////////////////
// Predefined p2p communicator (refers to all localities)
HPX_EXPORT channel_communicator get_world_channel_communicator();

namespace detail {

HPX_EXPORT void create_world_channel_communicator();
HPX_EXPORT void reset_world_channel_communicator();
} // namespace detail
} // namespace hpx::collectives

#endif // !HPX_COMPUTE_DEVICE_CODE
#endif // DOXYGEN
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2020-2023 Hartmut Kaiser
// Copyright (c) 2020-2025 Hartmut Kaiser
//
// SPDX-License-Identifier: BSL-1.0
// Distributed under the Boost Software License, Version 1.0. (See accompanying
Expand Down Expand Up @@ -113,6 +113,7 @@ namespace hpx { namespace collectives {
#include <hpx/components/client_base.hpp>
#include <hpx/type_support/extra_data.hpp>

#include <tuple>
#include <utility>

///////////////////////////////////////////////////////////////////////////////
Expand All @@ -123,6 +124,7 @@ namespace hpx::collectives::detail {
{
num_sites_arg num_sites_;
this_site_arg this_site_;
root_site_arg root_site_;
};
} // namespace hpx::collectives::detail

Expand Down Expand Up @@ -173,8 +175,13 @@ namespace hpx::collectives {
{
}

HPX_EXPORT void set_info(
num_sites_arg num_sites, this_site_arg this_site) noexcept;
HPX_EXPORT void set_info(num_sites_arg num_sites,
this_site_arg this_site,
root_site_arg root_site = root_site_arg()) noexcept;

[[nodiscard]] HPX_EXPORT
std::tuple<num_sites_arg, this_site_arg, root_site_arg>
get_info_ex() const noexcept;

[[nodiscard]] HPX_EXPORT std::pair<num_sites_arg, this_site_arg>
get_info() const noexcept;
Expand All @@ -186,9 +193,26 @@ namespace hpx::collectives {
};

///////////////////////////////////////////////////////////////////////////
// Predefined global communicator
// Predefined global communicator (refers to all localities)
HPX_EXPORT communicator get_world_communicator();

namespace detail {

HPX_EXPORT void create_global_communicator();
HPX_EXPORT void reset_global_communicator();
} // namespace detail

///////////////////////////////////////////////////////////////////////////
// Predefined local communicator (refers to all threads on the calling
// locality)
HPX_EXPORT communicator get_local_communicator();

namespace detail {

HPX_EXPORT void create_local_communicator();
HPX_EXPORT void reset_local_communicator();
} // namespace detail

///////////////////////////////////////////////////////////////////////////
HPX_EXPORT communicator create_communicator(char const* basename,
num_sites_arg num_sites = num_sites_arg(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,14 @@
#include <hpx/lcos_local/channel.hpp>
#include <hpx/lock_registration/detail/register_locks.hpp>
#include <hpx/synchronization/spinlock.hpp>
#include <hpx/type_support/unused.hpp>

#include <cstddef>
#include <map>
#include <mutex>
#include <utility>
#include <vector>

namespace hpx { namespace collectives { namespace detail {
namespace hpx::collectives::detail {

///////////////////////////////////////////////////////////////////////////
class channel_communicator_server
Expand All @@ -39,7 +38,6 @@ namespace hpx { namespace collectives { namespace detail {

public:
channel_communicator_server() //-V730
: data_()
{
HPX_ASSERT(false); // shouldn't ever be called
}
Expand All @@ -57,8 +55,7 @@ namespace hpx { namespace collectives { namespace detail {

{
std::unique_lock l(data_[which].mtx_);
util::ignore_while_checking il(&l);
HPX_UNUSED(il);
[[maybe_unused]] util::ignore_while_checking il(&l);

channel_type& c = data_[which].channels_[tag];
f = c.get();
Expand All @@ -84,8 +81,7 @@ namespace hpx { namespace collectives { namespace detail {
void set(std::size_t which, T value, std::size_t tag)
{
std::unique_lock l(data_[which].mtx_);
util::ignore_while_checking il(&l);
HPX_UNUSED(il);
[[maybe_unused]] util::ignore_while_checking il(&l);

data_[which].channels_[tag].set(unique_any_nonser(HPX_MOVE(value)));
}
Expand Down Expand Up @@ -157,6 +153,6 @@ namespace hpx { namespace collectives { namespace detail {
std::size_t this_site_;
std::vector<client_type> clients_;
};
}}} // namespace hpx::collectives::detail
} // namespace hpx::collectives::detail

#endif // COMPUTE_HOST_CODE
2 changes: 1 addition & 1 deletion libs/full/collectives/include/hpx/collectives/reduce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ namespace hpx::collectives {

fid.wait(); // make sure communicator was created

if (this_site == fid.get_info().second)
if (this_site == std::get<2>(fid.get_info_ex()))
{
local_result = reduce_here(hpx::launch::sync, HPX_MOVE(fid),
HPX_FORWARD(T, local_result), HPX_FORWARD(F, op), this_site,
Expand Down
45 changes: 45 additions & 0 deletions libs/full/collectives/src/channel_communicator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@
#include <hpx/components_base/server/component.hpp>
#include <hpx/errors/exception.hpp>
#include <hpx/modules/futures.hpp>
#include <hpx/modules/lock_registration.hpp>
#include <hpx/runtime_components/new.hpp>
#include <hpx/synchronization/mutex.hpp>

#include <cstddef>
#include <memory>
#include <mutex>
#include <utility>

///////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -90,6 +93,48 @@ namespace hpx::collectives {
return create_channel_communicator(basename, num_sites, this_site)
.get();
}

///////////////////////////////////////////////////////////////////////////
// Predefined channel (p2p) communicator
namespace {

channel_communicator world_channel_communicator;
hpx::mutex world_channel_communicator_mtx;
} // namespace

channel_communicator get_world_channel_communicator()
{
detail::create_world_channel_communicator();
return world_channel_communicator;
}

namespace detail {

void create_world_channel_communicator()
{
std::unique_lock<hpx::mutex> l(world_channel_communicator_mtx);
[[maybe_unused]] util::ignore_while_checking il(&l);

if (!world_channel_communicator)
{
auto const num_sites =
num_sites_arg(agas::get_num_localities(hpx::launch::sync));
auto const this_site = this_site_arg(agas::get_locality_id());

world_channel_communicator =
collectives::create_channel_communicator(hpx::launch::sync,
"world_channel_communicator", num_sites, this_site);
}
}

void reset_world_channel_communicator()
{
if (world_channel_communicator)
{
world_channel_communicator.free();
}
}
} // namespace detail
} // namespace hpx::collectives

#endif // !HPX_COMPUTE_DEVICE_CODE
Loading

0 comments on commit 03470ae

Please sign in to comment.