Skip to content

Commit

Permalink
Fixing sync collectives
Browse files Browse the repository at this point in the history
- adding example
  • Loading branch information
hkaiser committed Jan 7, 2025
1 parent a40e83e commit 141c921
Show file tree
Hide file tree
Showing 7 changed files with 141 additions and 21 deletions.
5 changes: 4 additions & 1 deletion libs/full/collectives/examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@ else()
return()
endif()

set(example_programs channel_communicator)
set(example_programs channel_communicator distributed_pi)

set(channel_communicator_PARAMETERS LOCALITIES 2 THREADS_PER_LOCALITY 2)
set(channel_communicator_FLAGS DEPENDENCIES iostreams_component)

set(distributed_pi_PARAMETERS LOCALITIES 2 THREADS_PER_LOCALITY 2)
set(distributed_pi_FLAGS COMPILE_FLAGS -DHPX_HAVE_RUN_MAIN_EVERYWHERE)

foreach(example_program ${example_programs})

set(sources ${example_program}.cpp)
Expand Down
47 changes: 47 additions & 0 deletions libs/full/collectives/examples/distributed_pi.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright (c) 2025 Hartmut Kaiser
//
// SPDX-License-Identifier: BSL-1.0
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)

#include <hpx/hpx.hpp>
#include <hpx/hpx_main.hpp>

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <string>

inline double sqr(double val)
{
return val * val;
}

int main(int argc, char* argv[])
{
std::size_t N = 1'000'000;
std::uint32_t num_localities = hpx::get_num_localities(hpx::launch::sync);
std::uint32_t locality_id = hpx::get_locality_id();

if (locality_id == 0 && argc > 1)
N = std::stol(argv[1]);

hpx::collectives::broadcast(hpx::collectives::get_world_communicator(), N);

std::size_t const blocksize = N / num_localities;
std::size_t const begin = blocksize * locality_id;
std::size_t const end = blocksize * (locality_id + 1);
double h = 1.0 / N;

double pi = 0.0;
for (std::size_t i = begin; i != end; ++i)
pi += h * 4.0 / (1 + sqr(i * h));

hpx::collectives::reduce(
hpx::collectives::get_world_communicator(), pi, std::plus{});

if (locality_id == 0)
std::cout << "pi: " << pi << std::endl;

return 0;
}
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ namespace hpx::collectives {

fid.wait(); // make sure communicator was created

if (this_site == fid.get_info().second)
if (this_site == std::get<2>(fid.get_info_ex()))
{
broadcast_to(
hpx::launch::sync, HPX_MOVE(fid), value, this_site, generation);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2020-2023 Hartmut Kaiser
// Copyright (c) 2020-2025 Hartmut Kaiser
//
// SPDX-License-Identifier: BSL-1.0
// Distributed under the Boost Software License, Version 1.0. (See accompanying
Expand Down Expand Up @@ -113,6 +113,7 @@ namespace hpx { namespace collectives {
#include <hpx/components/client_base.hpp>
#include <hpx/type_support/extra_data.hpp>

#include <tuple>
#include <utility>

///////////////////////////////////////////////////////////////////////////////
Expand All @@ -123,6 +124,7 @@ namespace hpx::collectives::detail {
{
num_sites_arg num_sites_;
this_site_arg this_site_;
root_site_arg root_site_;
};
} // namespace hpx::collectives::detail

Expand Down Expand Up @@ -173,8 +175,13 @@ namespace hpx::collectives {
{
}

HPX_EXPORT void set_info(
num_sites_arg num_sites, this_site_arg this_site) noexcept;
HPX_EXPORT void set_info(num_sites_arg num_sites,
this_site_arg this_site,
root_site_arg root_site = root_site_arg()) noexcept;

[[nodiscard]] HPX_EXPORT
std::tuple<num_sites_arg, this_site_arg, root_site_arg>
get_info_ex() const noexcept;

[[nodiscard]] HPX_EXPORT std::pair<num_sites_arg, this_site_arg>
get_info() const noexcept;
Expand All @@ -186,9 +193,14 @@ namespace hpx::collectives {
};

///////////////////////////////////////////////////////////////////////////
// Predefined global communicator
// Predefined global communicator (refers to all localities)
HPX_EXPORT communicator get_world_communicator();

///////////////////////////////////////////////////////////////////////////
// Predefined local communicator (refers to all threads on the calling
// locality)
HPX_EXPORT communicator get_local_communicator();

///////////////////////////////////////////////////////////////////////////
HPX_EXPORT communicator create_communicator(char const* basename,
num_sites_arg num_sites = num_sites_arg(),
Expand Down
2 changes: 1 addition & 1 deletion libs/full/collectives/include/hpx/collectives/reduce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ namespace hpx::collectives {

fid.wait(); // make sure communicator was created

if (this_site == fid.get_info().second)
if (this_site == std::get<2>(fid.get_info_ex()))
{
local_result = reduce_here(hpx::launch::sync, HPX_MOVE(fid),
HPX_FORWARD(T, local_result), HPX_FORWARD(F, op), this_site,
Expand Down
83 changes: 70 additions & 13 deletions libs/full/collectives/src/create_communicator.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2020-2023 Hartmut Kaiser
// Copyright (c) 2020-2025 Hartmut Kaiser
//
// SPDX-License-Identifier: BSL-1.0
// Distributed under the Boost Software License, Version 1.0. (See accompanying
Expand Down Expand Up @@ -68,14 +68,15 @@ namespace hpx::collectives {
} // namespace detail

///////////////////////////////////////////////////////////////////////////
void communicator::set_info(
num_sites_arg num_sites, this_site_arg this_site) noexcept
void communicator::set_info(num_sites_arg num_sites,
this_site_arg this_site, root_site_arg root_site) noexcept
{
auto& [num_sites_, this_site_] =
auto& [num_sites_, this_site_, root_site_] =
get_extra_data<detail::communicator_data>();

num_sites_ = num_sites;
this_site_ = this_site;
root_site_ = root_site;
}

std::pair<num_sites_arg, this_site_arg> communicator::get_info()
Expand All @@ -86,11 +87,27 @@ namespace hpx::collectives {

if (client_data != nullptr)
{
return std::make_pair(
return std::make_tuple(
client_data->num_sites_, client_data->this_site_);
}

return std::make_pair(num_sites_arg{}, this_site_arg{});
return std::make_tuple(num_sites_arg{}, this_site_arg{});
}

std::tuple<num_sites_arg, this_site_arg, root_site_arg>
communicator::get_info_ex() const noexcept
{
auto const* client_data =
try_get_extra_data<detail::communicator_data>();

if (client_data != nullptr)
{
return std::make_tuple(client_data->num_sites_,
client_data->this_site_, client_data->root_site_);
}

return std::make_tuple(
num_sites_arg{}, this_site_arg{}, root_site_arg());
}

///////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -141,13 +158,17 @@ namespace hpx::collectives {
"operation was already registered: {}",
target.registered_name());
}
target.set_info(num_sites, this_site);
target.set_info(num_sites, this_site, root_site);
return target;
});
}

// find existing communicator
return hpx::find_from_basename<communicator>(HPX_MOVE(name), root_site);
return hpx::find_from_basename<communicator>(HPX_MOVE(name), root_site)
.then(hpx::launch::sync, [=](communicator&& c) {
c.set_info(num_sites, this_site, root_site);
return c;
});
}

///////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -193,31 +214,67 @@ namespace hpx::collectives {
c.registered_name());
}

c.set_info(num_sites, this_site);
c.set_info(num_sites, this_site, root_site);
return c;
}

// find existing communicator
return hpx::find_from_basename<communicator>(HPX_MOVE(name), root_site);
return hpx::find_from_basename<communicator>(HPX_MOVE(name), root_site)
.then(hpx::launch::sync, [=](communicator&& c) {
c.set_info(num_sites, this_site, root_site);
return c;
});
}

///////////////////////////////////////////////////////////////////////////
// Predefined global communicator
namespace {

communicator world_communicator;
hpx::mutex world_communicator_mtx;
communicator local_communicator;
hpx::mutex communicator_mtx;
} // namespace

communicator get_world_communicator()
{
{
std::lock_guard<hpx::mutex> l(world_communicator_mtx);
std::lock_guard<hpx::mutex> l(communicator_mtx);
if (!world_communicator)
{
auto const num_sites =
num_sites_arg(agas::get_num_localities(hpx::launch::sync));
auto const this_site = this_site_arg(agas::get_locality_id());

world_communicator =
create_communicator("hpx::collectives::world_communicator");
create_communicator("/0/world_communicator", num_sites,
this_site, generation_arg(), root_site_arg(0));
world_communicator.set_info(
num_sites, this_site, root_site_arg(0));
}
}
return world_communicator;
}

communicator get_local_communicator()
{
{
std::lock_guard<hpx::mutex> l(communicator_mtx);
if (!local_communicator)
{
auto const num_sites =
num_sites_arg(hpx::get_num_worker_threads());
auto const this_site =
this_site_arg(hpx::get_worker_thread_num());

local_communicator =
create_local_communicator("/local_communicator", num_sites,
this_site, generation_arg(), root_site_arg(0));
local_communicator.set_info(
num_sites, this_site, root_site_arg(0));
}
}
return local_communicator;
}
} // namespace hpx::collectives

#endif // !HPX_COMPUTE_DEVICE_CODE
3 changes: 2 additions & 1 deletion libs/full/include/include/hpx/hpx.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2007-2023 Hartmut Kaiser
// Copyright (c) 2007-2025 Hartmut Kaiser
//
// SPDX-License-Identifier: BSL-1.0
// Distributed under the Boost Software License, Version 1.0. (See accompanying
Expand All @@ -9,6 +9,7 @@
#include <hpx/algorithm.hpp>
#include <hpx/any.hpp>
#include <hpx/chrono.hpp>
#include <hpx/collectives.hpp>
#include <hpx/execution.hpp>
#include <hpx/functional.hpp>
#include <hpx/future.hpp>
Expand Down

0 comments on commit 141c921

Please sign in to comment.