Skip to content

Commit

Permalink
Adding synchronous collective operations
Browse files Browse the repository at this point in the history
- adding predefined world_comunicator
  • Loading branch information
hkaiser committed Jan 6, 2025
1 parent 64b1c0d commit 1fea02f
Show file tree
Hide file tree
Showing 22 changed files with 2,553 additions and 12 deletions.
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1894,6 +1894,9 @@ if(WIN32)
# Silence C++20 deprecation warnings
hpx_add_config_cond_define(_SILENCE_ALL_CXX20_DEPRECATION_WARNINGS)

# Silence C++23 deprecation warnings
hpx_add_config_cond_define(_SILENCE_ALL_CXX23_DEPRECATION_WARNINGS)

# ASan is available in Visual Studion starting V16.8
if((MSVC_VERSION GREATER_EQUAL 1928) AND HPX_WITH_SANITIZERS)
hpx_add_target_compile_option(
Expand Down
36 changes: 36 additions & 0 deletions libs/full/collectives/include/hpx/collectives/all_gather.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,42 @@ namespace hpx::collectives {
generation, root_site),
HPX_FORWARD(T, local_result), this_site);
}

///////////////////////////////////////////////////////////////////////////
template <typename T>
std::vector<std::decay_t<T>> all_gather(hpx::launch::sync_policy,
communicator fid, T&& local_result,
this_site_arg this_site = this_site_arg(),
generation_arg generation = generation_arg())
{
return all_gather(
HPX_MOVE(fid), HPX_FORWARD(T, local_result), this_site, generation)
.get();
}

template <typename T>
std::vector<std::decay_t<T>> all_gather(hpx::launch::sync_policy,
communicator fid, T&& local_result, generation_arg generation,
this_site_arg this_site = this_site_arg())
{
return all_gather(
HPX_MOVE(fid), HPX_FORWARD(T, local_result), this_site, generation)
.get();
}

template <typename T>
std::vector<std::decay_t<T>> all_gather(hpx::launch::sync_policy,
char const* basename, T&& local_result,
num_sites_arg num_sites = num_sites_arg(),
this_site_arg this_site = this_site_arg(),
generation_arg generation = generation_arg(),
root_site_arg root_site = root_site_arg())
{
return all_gather(create_communicator(basename, num_sites, this_site,
generation, root_site),
HPX_FORWARD(T, local_result), this_site)
.get();
}
} // namespace hpx::collectives

////////////////////////////////////////////////////////////////////////////////
Expand Down
36 changes: 35 additions & 1 deletion libs/full/collectives/include/hpx/collectives/all_reduce.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2019-2024 Hartmut Kaiser
// Copyright (c) 2019-2025 Hartmut Kaiser
//
// SPDX-License-Identifier: BSL-1.0
// Distributed under the Boost Software License, Version 1.0. (See accompanying
Expand Down Expand Up @@ -272,6 +272,40 @@ namespace hpx::collectives {
generation, root_site),
HPX_FORWARD(T, local_result), HPX_FORWARD(F, op), this_site);
}

////////////////////////////////////////////////////////////////////////////
template <typename T, typename F>
decltype(auto) all_reduce(hpx::launch::sync_policy, communicator fid,
T&& local_result, F&& op, this_site_arg this_site = this_site_arg(),
generation_arg generation = generation_arg())
{
return all_reduce(HPX_MOVE(fid), HPX_FORWARD(T, local_result),
HPX_FORWARD(F, op), this_site, generation)
.get();
}

template <typename T, typename F>
decltype(auto) all_reduce(hpx::launch::sync_policy, communicator fid,
T&& local_result, F&& op, generation_arg generation,
this_site_arg this_site = this_site_arg())
{
return all_reduce(HPX_MOVE(fid), HPX_FORWARD(T, local_result),
HPX_FORWARD(F, op), this_site, generation)
.get();
}

template <typename T, typename F>
decltype(auto) all_reduce(hpx::launch::sync_policy, char const* basename,
T&& local_result, F&& op, num_sites_arg num_sites = num_sites_arg(),
this_site_arg this_site = this_site_arg(),
generation_arg generation = generation_arg(),
root_site_arg root_site = root_site_arg())
{
return all_reduce(create_communicator(basename, num_sites, this_site,
generation, root_site),
HPX_FORWARD(T, local_result), HPX_FORWARD(F, op), this_site)
.get();
}
} // namespace hpx::collectives

////////////////////////////////////////////////////////////////////////////////
Expand Down
40 changes: 37 additions & 3 deletions libs/full/collectives/include/hpx/collectives/all_to_all.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2019-2024 Hartmut Kaiser
// Copyright (c) 2019-2025 Hartmut Kaiser
//
// SPDX-License-Identifier: BSL-1.0
// Distributed under the Boost Software License, Version 1.0. (See accompanying
Expand Down Expand Up @@ -222,7 +222,6 @@ namespace hpx::collectives {
return fid.then(hpx::launch::sync, HPX_MOVE(all_to_all_data));
}

///////////////////////////////////////////////////////////////////////////
template <typename T>
hpx::future<std::vector<T>> all_to_all(communicator fid,
std::vector<T>&& local_result, generation_arg generation,
Expand All @@ -232,7 +231,6 @@ namespace hpx::collectives {
HPX_MOVE(fid), HPX_MOVE(local_result), this_site, generation);
}

///////////////////////////////////////////////////////////////////////////
template <typename T>
hpx::future<std::vector<T>> all_to_all(char const* basename,
std::vector<T>&& local_result,
Expand All @@ -245,6 +243,42 @@ namespace hpx::collectives {
generation, root_site),
HPX_MOVE(local_result), this_site);
}

///////////////////////////////////////////////////////////////////////////
template <typename T>
std::vector<T> all_to_all(hpx::launch::sync_policy, communicator fid,
std::vector<T>&& local_result,
this_site_arg this_site = this_site_arg(),
generation_arg generation = generation_arg())
{
return all_to_all(
HPX_MOVE(fid), HPX_MOVE(local_result), this_site, generation)
.get();
}

template <typename T>
std::vector<T> all_to_all(hpx::launch::sync_policy, communicator fid,
std::vector<T>&& local_result, generation_arg generation,
this_site_arg this_site = this_site_arg())
{
return all_to_all(
HPX_MOVE(fid), HPX_MOVE(local_result), this_site, generation)
.get();
}

template <typename T>
std::vector<T> all_to_all(hpx::launch::sync_policy, char const* basename,
std::vector<T>&& local_result,
num_sites_arg num_sites = num_sites_arg(),
this_site_arg this_site = this_site_arg(),
generation_arg generation = generation_arg(),
root_site_arg root_site = root_site_arg())
{
return all_to_all(create_communicator(basename, num_sites, this_site,
generation, root_site),
HPX_MOVE(local_result), this_site)
.get();
}
} // namespace hpx::collectives

////////////////////////////////////////////////////////////////////////////////
Expand Down
90 changes: 88 additions & 2 deletions libs/full/collectives/include/hpx/collectives/broadcast.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2020-2024 Hartmut Kaiser
// Copyright (c) 2020-2025 Hartmut Kaiser
//
// SPDX-License-Identifier: BSL-1.0
// Distributed under the Boost Software License, Version 1.0. (See accompanying
Expand Down Expand Up @@ -200,7 +200,6 @@ namespace hpx { namespace collectives {
#include <hpx/assert.hpp>
#include <hpx/async_base/launch_policy.hpp>
#include <hpx/async_distributed/async.hpp>
#include <hpx/async_local/dataflow.hpp>
#include <hpx/collectives/argument_types.hpp>
#include <hpx/collectives/create_communicator.hpp>
#include <hpx/components_base/agas_interface.hpp>
Expand Down Expand Up @@ -334,6 +333,39 @@ namespace hpx::collectives {
HPX_FORWARD(T, local_result), this_site);
}

////////////////////////////////////////////////////////////////////////////
template <typename T>
decltype(auto) broadcast_to(hpx::launch::sync_policy, communicator fid,
T&& local_result, this_site_arg this_site = this_site_arg(),
generation_arg generation = generation_arg())
{
return broadcast_to(
HPX_MOVE(fid), HPX_FORWARD(T, local_result), this_site, generation)
.get();
}

template <typename T>
decltype(auto) broadcast_to(hpx::launch::sync_policy, communicator fid,
T&& local_result, generation_arg generation,
this_site_arg this_site = this_site_arg())
{
return broadcast_to(
HPX_MOVE(fid), HPX_FORWARD(T, local_result), this_site, generation)
.get();
}

template <typename T>
decltype(auto) broadcast_to(hpx::launch::sync_policy, char const* basename,
T&& local_result, num_sites_arg num_sites = num_sites_arg(),
this_site_arg this_site = this_site_arg(),
generation_arg generation = generation_arg())
{
return broadcast_to(hpx::launch::sync,
create_communicator(basename, num_sites, this_site, generation,
root_site_arg(this_site.argument_)),
HPX_FORWARD(T, local_result), this_site);
}

///////////////////////////////////////////////////////////////////////////
template <typename T>
hpx::future<T> broadcast_from(communicator fid,
Expand Down Expand Up @@ -392,6 +424,60 @@ namespace hpx::collectives {
this_site, generation, root_site),
this_site);
}

///////////////////////////////////////////////////////////////////////////
template <typename T>
T broadcast_from(hpx::launch::sync_policy, communicator fid,
this_site_arg this_site = this_site_arg(),
generation_arg generation = generation_arg())
{
return broadcast_from<T>(HPX_MOVE(fid), this_site, generation).get();
}

template <typename T>
T broadcast_from(hpx::launch::sync_policy, communicator fid,
generation_arg generation, this_site_arg this_site = this_site_arg())
{
return broadcast_from<T>(HPX_MOVE(fid), this_site, generation).get();
}

template <typename T>
T broadcast_from(hpx::launch::sync_policy, char const* basename,
this_site_arg this_site = this_site_arg(),
generation_arg generation = generation_arg(),
root_site_arg root_site = root_site_arg())
{
HPX_ASSERT(this_site != root_site);
return broadcast_from<T>(create_communicator(basename, num_sites_arg(),
this_site, generation, root_site),
this_site)
.get();
}

///////////////////////////////////////////////////////////////////////////
template <typename T>
void broadcast(communicator fid, T& value,
this_site_arg this_site = this_site_arg(),
generation_arg generation = generation_arg())
{
if (this_site == static_cast<std::size_t>(-1))
{
this_site = static_cast<std::size_t>(agas::get_locality_id());
}

fid.wait(); // make sure communicator was created

if (this_site == fid.get_info().second)
{
broadcast_to(
hpx::launch::sync, HPX_MOVE(fid), value, this_site, generation);
}
else
{
value = broadcast_from<T>(
hpx::launch::sync, HPX_MOVE(fid), this_site, generation);
}
}
} // namespace hpx::collectives

////////////////////////////////////////////////////////////////////////////////
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,10 @@ namespace hpx::collectives {
}
};

///////////////////////////////////////////////////////////////////////////
// Predefined global communicator
HPX_EXPORT communicator get_world_communicator();

///////////////////////////////////////////////////////////////////////////
HPX_EXPORT communicator create_communicator(char const* basename,
num_sites_arg num_sites = num_sites_arg(),
Expand Down
37 changes: 36 additions & 1 deletion libs/full/collectives/include/hpx/collectives/exclusive_scan.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2019-2024 Hartmut Kaiser
// Copyright (c) 2019-2025 Hartmut Kaiser
//
// SPDX-License-Identifier: BSL-1.0
// Distributed under the Boost Software License, Version 1.0. (See accompanying
Expand Down Expand Up @@ -283,6 +283,41 @@ namespace hpx::collectives {
this_site, generation, root_site),
HPX_FORWARD(T, local_result), HPX_FORWARD(F, op), this_site);
}

////////////////////////////////////////////////////////////////////////////
template <typename T, typename F>
decltype(auto) exclusive_scan(hpx::launch::sync_policy, communicator fid,
T&& local_result, F&& op, this_site_arg this_site = this_site_arg(),
generation_arg generation = generation_arg())
{
return exclusive_scan(HPX_MOVE(fid), HPX_FORWARD(T, local_result),
HPX_FORWARD(F, op), this_site, generation)
.get();
}

template <typename T, typename F>
decltype(auto) exclusive_scan(hpx::launch::sync_policy, communicator fid,
T&& local_result, F&& op, generation_arg generation,
this_site_arg this_site = this_site_arg())
{
return exclusive_scan(HPX_MOVE(fid), HPX_FORWARD(T, local_result),
HPX_FORWARD(F, op), this_site, generation)
.get();
}

template <typename T, typename F>
decltype(auto) exclusive_scan(hpx::launch::sync_policy,
char const* basename, T&& local_result, F&& op,
num_sites_arg num_sites = num_sites_arg(),
this_site_arg this_site = this_site_arg(),
generation_arg generation = generation_arg(),
root_site_arg root_site = root_site_arg())
{
return exclusive_scan(create_communicator(basename, num_sites,
this_site, generation, root_site),
HPX_FORWARD(T, local_result), HPX_FORWARD(F, op), this_site)
.get();
}
} // namespace hpx::collectives

#endif // !HPX_COMPUTE_DEVICE_CODE
Expand Down
Loading

0 comments on commit 1fea02f

Please sign in to comment.