Skip to content

Commit

Permalink
Add common reference semantics to sycl graphs
Browse files Browse the repository at this point in the history
Adds missing common reference semantic functionality such
as operator==, operator!= and hash functions to all
sycl graph related classes.
  • Loading branch information
fabiomestre committed Jan 27, 2025
1 parent 6965142 commit e786630
Show file tree
Hide file tree
Showing 5 changed files with 323 additions and 6 deletions.
84 changes: 83 additions & 1 deletion sycl/include/sycl/ext/oneapi/experimental/graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
#include <sycl/detail/string_view.hpp>
#endif
#include <sycl/device.hpp> // for device
#include <sycl/device.hpp> // for device
#include <sycl/ext/oneapi/experimental/detail/properties/graph_properties.hpp> // for graph properties classes
#include <sycl/nd_range.hpp> // for range, nd_range
#include <sycl/properties/property_traits.hpp> // for is_property, is_property_of
Expand Down Expand Up @@ -142,6 +142,14 @@ class __SYCL_EXPORT node {
/// Update the Range of this node if it is a kernel execution node
template <int Dimensions> void update_range(range<Dimensions> executionRange);

/// Common Reference Semantics
friend bool operator==(const node &LHS, const node &RHS) {
return LHS.impl == RHS.impl;
}
friend bool operator!=(const node &LHS, const node &RHS) {
return LHS.impl != RHS.impl;
}

private:
node(const std::shared_ptr<detail::node_impl> &Impl) : impl(Impl) {}

Expand Down Expand Up @@ -181,6 +189,16 @@ class __SYCL_EXPORT dynamic_command_group {
size_t get_active_index() const;
void set_active_index(size_t Index);

/// Common Reference Semantics
friend bool operator==(const dynamic_command_group &LHS,
const dynamic_command_group &RHS) {
return LHS.impl == RHS.impl;
}
friend bool operator!=(const dynamic_command_group &LHS,
const dynamic_command_group &RHS) {
return LHS.impl != RHS.impl;
}

private:
template <class Obj>
friend const decltype(Obj::impl) &
Expand Down Expand Up @@ -307,6 +325,16 @@ class __SYCL_EXPORT modifiable_command_graph
/// Get a list of all root nodes (nodes without dependencies) in this graph.
std::vector<node> get_root_nodes() const;

/// Common Reference Semantics
friend bool operator==(const modifiable_command_graph &LHS,
const modifiable_command_graph &RHS) {
return LHS.impl == RHS.impl;
}
friend bool operator!=(const modifiable_command_graph &LHS,
const modifiable_command_graph &RHS) {
return LHS.impl != RHS.impl;
}

protected:
/// Constructor used internally by the runtime.
/// @param Impl Detail implementation class to construct object with.
Expand Down Expand Up @@ -386,6 +414,16 @@ class __SYCL_EXPORT executable_command_graph
/// @param Nodes The nodes to use for updating the graph.
void update(const std::vector<node> &Nodes);

/// Common Reference Semantics
friend bool operator==(const executable_command_graph &LHS,
const executable_command_graph &RHS) {
return LHS.impl == RHS.impl;
}
friend bool operator!=(const executable_command_graph &LHS,
const executable_command_graph &RHS) {
return LHS.impl != RHS.impl;
}

protected:
/// Constructor used by internal runtime.
/// @param Graph Detail implementation class to construct with.
Expand Down Expand Up @@ -452,6 +490,16 @@ class __SYCL_EXPORT dynamic_parameter_base {
Graph,
size_t ParamSize, const void *Data);

/// Common Reference Semantics
friend bool operator==(const dynamic_parameter_base &LHS,
const dynamic_parameter_base &RHS) {
return LHS.impl == RHS.impl;
}
friend bool operator!=(const dynamic_parameter_base &LHS,
const dynamic_parameter_base &RHS) {
return LHS.impl != RHS.impl;
}

protected:
void updateValue(const void *NewValue, size_t Size);

Expand Down Expand Up @@ -512,3 +560,37 @@ command_graph(const context &SyclContext, const device &SyclDevice,

} // namespace _V1
} // namespace sycl

namespace std {
template <> struct __SYCL_EXPORT hash<sycl::ext::oneapi::experimental::node> {
size_t operator()(const sycl::ext::oneapi::experimental::node &Node) const;
};

template <>
struct __SYCL_EXPORT
hash<sycl::ext::oneapi::experimental::dynamic_command_group> {
size_t operator()(const sycl::ext::oneapi::experimental::dynamic_command_group
&DynamicCGH) const;
};

template <sycl::ext::oneapi::experimental::graph_state State>
struct __SYCL_EXPORT
hash<sycl::ext::oneapi::experimental::command_graph<State>> {
size_t operator()(const sycl::ext::oneapi::experimental::command_graph<State>
&Graph) const {
auto ID = sycl::detail::getSyclObjImpl(Graph)->getID();
return std::hash<decltype(ID)>()(ID);
}
};

template <typename ValueT>
struct __SYCL_EXPORT
hash<sycl::ext::oneapi::experimental::dynamic_parameter<ValueT>> {
size_t
operator()(const sycl::ext::oneapi::experimental::dynamic_parameter<ValueT>
&DynamicParam) const {
auto ID = sycl::detail::getSyclObjImpl(DynamicParam)->getID();
return std::hash<decltype(ID)>()(ID);
}
};
} // namespace std
23 changes: 20 additions & 3 deletions sycl/source/detail/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,8 @@ graph_impl::graph_impl(const sycl::context &SyclContext,
const sycl::device &SyclDevice,
const sycl::property_list &PropList)
: MContext(SyclContext), MDevice(SyclDevice), MRecordingQueues(),
MEventsMap(), MInorderQueueMap() {
MEventsMap(), MInorderQueueMap(),
MID(NextAvailableID.fetch_add(1, std::memory_order_relaxed)) {
checkGraphPropertiesAndThrow(PropList);
if (PropList.has_property<property::graph::no_cycle_check>()) {
MSkipCycleChecks = true;
Expand Down Expand Up @@ -913,7 +914,8 @@ exec_graph_impl::exec_graph_impl(sycl::context Context,
MExecutionEvents(),
MIsUpdatable(PropList.has_property<property::graph::updatable>()),
MEnableProfiling(
PropList.has_property<property::graph::enable_profiling>()) {
PropList.has_property<property::graph::enable_profiling>()),
MID(NextAvailableID.fetch_add(1, std::memory_order_relaxed)) {
checkGraphPropertiesAndThrow(PropList);
// If the graph has been marked as updatable then check if the backend
// actually supports that. Devices supporting aspect::ext_oneapi_graph must
Expand Down Expand Up @@ -2035,7 +2037,8 @@ void dynamic_parameter_impl::updateCGAccessor(

dynamic_command_group_impl::dynamic_command_group_impl(
const command_graph<graph_state::modifiable> &Graph)
: MGraph{sycl::detail::getSyclObjImpl(Graph)}, MActiveCGF(0) {}
: MGraph{sycl::detail::getSyclObjImpl(Graph)}, MActiveCGF(0),
MID(NextAvailableID.fetch_add(1, std::memory_order_relaxed)) {}

void dynamic_command_group_impl::finalizeCGFList(
const std::vector<std::function<void(handler &)>> &CGFList) {
Expand Down Expand Up @@ -2159,3 +2162,17 @@ void dynamic_command_group::set_active_index(size_t Index) {
} // namespace ext
} // namespace _V1
} // namespace sycl

size_t std::hash<sycl::ext::oneapi::experimental::node>::operator()(
const sycl::ext::oneapi::experimental::node &Node) const {
auto ID = sycl::detail::getSyclObjImpl(Node)->getID();
return std::hash<decltype(ID)>()(ID);
}

size_t
std::hash<sycl::ext::oneapi::experimental::dynamic_command_group>::operator()(
const sycl::ext::oneapi::experimental::dynamic_command_group &DynamicCGH)
const {
auto ID = sycl::detail::getSyclObjImpl(DynamicCGH)->getID();
return std::hash<decltype(ID)>()(ID);
}
28 changes: 26 additions & 2 deletions sycl/source/detail/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1120,6 +1120,8 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
return MBarrierDependencyMap[Queue];
}

unsigned long long getID() { return MID; }

private:
/// Iterate over the graph depth-first and run \p NodeFunc on each node.
/// @param NodeFunc A function which receives as input a node in the graph to
Expand Down Expand Up @@ -1198,6 +1200,9 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
std::map<std::weak_ptr<sycl::detail::queue_impl>, std::shared_ptr<node_impl>,
std::owner_less<std::weak_ptr<sycl::detail::queue_impl>>>
MBarrierDependencyMap;

unsigned long long MID;
inline static std::atomic<unsigned long long> NextAvailableID = 0;
};

/// Class representing the implementation of command_graph<executable>.
Expand Down Expand Up @@ -1297,6 +1302,8 @@ class exec_graph_impl {

void updateImpl(std::shared_ptr<node_impl> NodeImpl);

unsigned long long getID() { return MID; }

private:
/// Create a command-group for the node and add it to command-buffer by going
/// through the scheduler.
Expand Down Expand Up @@ -1408,21 +1415,26 @@ class exec_graph_impl {
// Stores a cache of node ids from modifiable graph nodes to the companion
// node(s) in this graph. Used for quick access when updating this graph.
std::multimap<node_impl::id_type, std::shared_ptr<node_impl>> MIDCache;

unsigned long long MID;
inline static std::atomic<unsigned long long> NextAvailableID = 0;
};

class dynamic_parameter_impl {
public:
dynamic_parameter_impl(std::shared_ptr<graph_impl> GraphImpl,
size_t ParamSize, const void *Data)
: MGraph(GraphImpl), MValueStorage(ParamSize) {
: MGraph(GraphImpl), MValueStorage(ParamSize),
MID(NextAvailableID.fetch_add(1, std::memory_order_relaxed)) {
std::memcpy(MValueStorage.data(), Data, ParamSize);
}

/// sycl_ext_oneapi_raw_kernel_arg constructor
/// Parameter size is taken from member of raw_kernel_arg object.
dynamic_parameter_impl(std::shared_ptr<graph_impl> GraphImpl, size_t,
raw_kernel_arg *Data)
: MGraph(GraphImpl) {
: MGraph(GraphImpl),
MID(NextAvailableID.fetch_add(1, std::memory_order_relaxed)) {
size_t RawArgSize = Data->MArgSize;
const void *RawArgData = Data->MArgData;
MValueStorage.reserve(RawArgSize);
Expand Down Expand Up @@ -1493,13 +1505,19 @@ class dynamic_parameter_impl {
int ArgIndex,
const sycl::detail::AccessorBaseHost *Acc);

unsigned long long getID() { return MID; }

// Weak ptrs to node_impls which will be updated
std::vector<std::pair<std::weak_ptr<node_impl>, int>> MNodes;
// Dynamic command-groups which will be updated
std::vector<DynamicCGInfo> MDynCGs;

std::shared_ptr<graph_impl> MGraph;
std::vector<std::byte> MValueStorage;

private:
unsigned long long MID;
inline static std::atomic<unsigned long long> NextAvailableID = 0;
};

class dynamic_command_group_impl
Expand Down Expand Up @@ -1540,6 +1558,12 @@ class dynamic_command_group_impl

/// List of nodes using this dynamic command-group.
std::vector<std::weak_ptr<node_impl>> MNodes;

unsigned long long getID() { return MID; }

private:
unsigned long long MID;
inline static std::atomic<unsigned long long> NextAvailableID = 0;
};
} // namespace detail
} // namespace experimental
Expand Down
1 change: 1 addition & 0 deletions sycl/unittests/Extensions/CommandGraph/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ set(CMAKE_CXX_EXTENSIONS OFF)
add_sycl_unittest(CommandGraphExtensionTests OBJECT
Barrier.cpp
CommandGraph.cpp
CommonReferenceSemantics.cpp
Exceptions.cpp
InOrderQueue.cpp
MultiThreaded.cpp
Expand Down
Loading

0 comments on commit e786630

Please sign in to comment.