Skip to content

Commit

Permalink
[SYCL] No nullptr program_impl (#13155)
Browse files Browse the repository at this point in the history
The `kernel_impl` has been storing a `program_impl` pointer, but none of
the users of `kernel_impl` need it; they only need the `PiProgram` handle.
Conversely, some of the callers constructing `kernel_impl` don't have
a `program_impl` pointer, so nullptr was being passed instead, which leads
to crashes in some simple use cases.

This change replaces the `program_impl` pointer with a `PiProgram` member
variable and ensures that all callers provide one, with the least amount of
change to the `kernel_impl` interface. A simple test is added as well; it
crashes without this PR.
  • Loading branch information
cperkinsintel authored Mar 26, 2024
1 parent 838198d commit 9f296d8
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 18 deletions.
1 change: 0 additions & 1 deletion sycl/source/detail/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1268,7 +1268,6 @@ void exec_graph_impl::updateImpl(std::shared_ptr<node_impl> Node) {
EliminatedArgMask = SyclKernelImpl->getKernelArgMask();
} else if (Kernel != nullptr) {
PiKernel = Kernel->getHandleRef();
auto SyclProg = Kernel->getProgramImpl();
EliminatedArgMask = Kernel->getKernelArgMask();
} else {
std::tie(PiKernel, std::ignore, EliminatedArgMask, std::ignore) =
Expand Down
6 changes: 3 additions & 3 deletions sycl/source/detail/kernel_bundle_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -563,9 +563,9 @@ class kernel_bundle_impl {
MContext, KernelID.get_name(), /*PropList=*/{},
SelectedImage->get_program_ref());

std::shared_ptr<kernel_impl> KernelImpl =
std::make_shared<kernel_impl>(Kernel, detail::getSyclObjImpl(MContext),
SelectedImage, Self, ArgMask, CacheMutex);
std::shared_ptr<kernel_impl> KernelImpl = std::make_shared<kernel_impl>(
Kernel, detail::getSyclObjImpl(MContext), SelectedImage, Self, ArgMask,
SelectedImage->get_program_ref(), CacheMutex);

return detail::createSyclObjFromImpl<kernel>(KernelImpl);
}
Expand Down
11 changes: 6 additions & 5 deletions sycl/source/detail/kernel_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ kernel_impl::kernel_impl(sycl::detail::pi::PiKernel Kernel,
KernelBundleImplPtr KernelBundleImpl,
const KernelArgMask *ArgMask)
: MKernel(Kernel), MContext(ContextImpl),
MProgramImpl(std::move(ProgramImpl)),
MProgram(ProgramImpl->getHandleRef()),
MCreatedFromSource(IsCreatedFromSource),
MKernelBundleImpl(std::move(KernelBundleImpl)),
MKernelArgMaskPtr{ArgMask} {
Expand All @@ -55,23 +55,24 @@ kernel_impl::kernel_impl(sycl::detail::pi::PiKernel Kernel,
"Input context must be the same as the context of cl_kernel",
PI_ERROR_INVALID_CONTEXT);

MIsInterop = MProgramImpl->isInterop();
MIsInterop = ProgramImpl->isInterop();
}

kernel_impl::kernel_impl(sycl::detail::pi::PiKernel Kernel,
ContextImplPtr ContextImpl,
DeviceImageImplPtr DeviceImageImpl,
KernelBundleImplPtr KernelBundleImpl,
const KernelArgMask *ArgMask, std::mutex *CacheMutex)
: MKernel(Kernel), MContext(std::move(ContextImpl)), MProgramImpl(nullptr),
const KernelArgMask *ArgMask, PiProgram ProgramPI,
std::mutex *CacheMutex)
: MKernel(Kernel), MContext(std::move(ContextImpl)), MProgram(ProgramPI),
MCreatedFromSource(false), MDeviceImageImpl(std::move(DeviceImageImpl)),
MKernelBundleImpl(std::move(KernelBundleImpl)),
MKernelArgMaskPtr{ArgMask}, MCacheMutex{CacheMutex} {
MIsInterop = MKernelBundleImpl->isInterop();
}

kernel_impl::kernel_impl(ContextImplPtr Context, ProgramImplPtr ProgramImpl)
: MContext(Context), MProgramImpl(std::move(ProgramImpl)) {}
: MContext(Context), MProgram(ProgramImpl->getHandleRef()) {}

kernel_impl::~kernel_impl() {
// TODO catch an exception and put it to list of asynchronous exceptions
Expand Down
8 changes: 5 additions & 3 deletions sycl/source/detail/kernel_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class kernel_bundle_impl;
using ContextImplPtr = std::shared_ptr<context_impl>;
using ProgramImplPtr = std::shared_ptr<program_impl>;
using KernelBundleImplPtr = std::shared_ptr<kernel_bundle_impl>;
using sycl::detail::pi::PiProgram;
class kernel_impl {
public:
/// Constructs a SYCL kernel instance from a PiKernel
Expand Down Expand Up @@ -74,7 +75,8 @@ class kernel_impl {
kernel_impl(sycl::detail::pi::PiKernel Kernel, ContextImplPtr ContextImpl,
DeviceImageImplPtr DeviceImageImpl,
KernelBundleImplPtr KernelBundleImpl,
const KernelArgMask *ArgMask, std::mutex *CacheMutex);
const KernelArgMask *ArgMask, PiProgram ProgramPI,
std::mutex *CacheMutex);

/// Constructs a SYCL kernel for host device
///
Expand Down Expand Up @@ -179,7 +181,7 @@ class kernel_impl {

bool isInterop() const { return MIsInterop; }

ProgramImplPtr getProgramImpl() const { return MProgramImpl; }
PiProgram getProgramRef() const { return MProgram; }
ContextImplPtr getContextImplPtr() const { return MContext; }

std::mutex &getNoncacheableEnqueueMutex() {
Expand All @@ -192,7 +194,7 @@ class kernel_impl {
private:
sycl::detail::pi::PiKernel MKernel;
const ContextImplPtr MContext;
const ProgramImplPtr MProgramImpl;
const PiProgram MProgram = nullptr;
bool MCreatedFromSource = true;
const DeviceImageImplPtr MDeviceImageImpl;
const KernelBundleImplPtr MKernelBundleImpl;
Expand Down
10 changes: 4 additions & 6 deletions sycl/source/detail/scheduler/commands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1993,8 +1993,7 @@ void instrumentationAddExtraKernelMetadata(
EliminatedArgMask = KernelImpl->getKernelArgMask();
Program = KernelImpl->getDeviceImage()->get_program_ref();
} else if (nullptr != SyclKernel) {
auto SyclProg = SyclKernel->getProgramImpl();
Program = SyclProg->getHandleRef();
Program = SyclKernel->getProgramRef();
if (!SyclKernel->isCreatedFromSource())
EliminatedArgMask = SyclKernel->getKernelArgMask();
} else {
Expand Down Expand Up @@ -2489,8 +2488,7 @@ pi_int32 enqueueImpCommandBufferKernel(
EliminatedArgMask = SyclKernelImpl->getKernelArgMask();
} else if (Kernel != nullptr) {
PiKernel = Kernel->getHandleRef();
auto SyclProg = Kernel->getProgramImpl();
PiProgram = SyclProg->getHandleRef();
PiProgram = Kernel->getProgramRef();
EliminatedArgMask = Kernel->getKernelArgMask();
} else {
std::tie(PiKernel, std::ignore, EliminatedArgMask, PiProgram) =
Expand Down Expand Up @@ -2603,8 +2601,8 @@ pi_int32 enqueueImpKernel(
assert(MSyclKernel->get_info<info::kernel::context>() ==
Queue->get_context());
Kernel = MSyclKernel->getHandleRef();
auto SyclProg = MSyclKernel->getProgramImpl();
Program = SyclProg->getHandleRef();
Program = MSyclKernel->getProgramRef();

// Non-cacheable kernels use mutexes from kernel_impls.
// TODO this can still result in a race condition if multiple SYCL
// kernels are created with the same native handle. To address this,
Expand Down
74 changes: 74 additions & 0 deletions sycl/test-e2e/KernelAndProgram/kernel-bundle-find-run.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// RUN: %{build} -o %t.out
// RUN: %{run} %t.out

// This test finds a known kernel and runs it.

#include <sycl/sycl.hpp>

#include <cstdlib>
#include <cstring>
#include <iostream>
#include <vector>

using namespace sycl;

// Kernel finder
class KernelFinder {
queue &Queue;
std::vector<sycl::kernel_id> AllKernelIDs;

public:
KernelFinder(queue &Q) : Queue(Q) {
// Obtain kernel bundle
kernel_bundle Bundle =
get_kernel_bundle<bundle_state::executable>(Queue.get_context());
std::cout << "Bundle obtained\n";
AllKernelIDs = sycl::get_kernel_ids();
std::cout << "Number of kernels = " << AllKernelIDs.size() << std::endl;
for (auto K : AllKernelIDs) {
std::cout << "Kernel obtained: " << K.get_name() << std::endl;
}
}

kernel get_kernel(const char *name) {
kernel_bundle Bundle =
get_kernel_bundle<bundle_state::executable>(Queue.get_context());
for (auto K : AllKernelIDs) {
auto Kname = K.get_name();
if (strcmp(name, Kname) == 0) {
kernel Kernel = Bundle.get_kernel(K);
std::cout << "Found kernel\n";
return Kernel;
}
}
std::cout << "No kernel found\n";
exit(1);
}
};

// Submits an empty kernel named KernelB over a one-element range on Queue and
// blocks until the queue has finished.
//
// NOTE(review): the mangled string looked up in test_sycl_kernel() appears to
// encode this function's name, its by-value queue parameter, the handler&
// lambda below and the KernelB name. Renaming or restructuring any of them
// presumably invalidates that string - confirm before refactoring.
void sycl_kernel(queue Queue) {
range<1> R1{1};
Queue.submit([&](handler &CGH) {
// The kernel body is intentionally empty; only the kernel's existence matters.
CGH.parallel_for<class KernelB>(R1, [=](id<1> WIid) {});
});
// Wait for the submitted work before returning to the caller.
Queue.wait();
}

// Looks up the KernelB kernel by its mangled name through KernelFinder and
// launches it over a one-element range on Queue, waiting for completion.
// Always returns 0; a failed lookup terminates inside KernelFinder instead.
int test_sycl_kernel(queue Queue) {
  // Mangled kernel name to search for among the registered kernel ids.
  const char *const MangledName =
      "_ZTSZZ11sycl_kernelN4sycl3_V15queueEENKUlRNS0_"
      "7handlerEE_clES3_E7KernelB";

  KernelFinder Finder(Queue);
  kernel Found = Finder.get_kernel(MangledName);

  // Run the pre-built kernel object rather than a kernel lambda.
  Queue.submit([&](handler &CGH) {
    CGH.parallel_for(range<1>{1}, Found);
  });
  Queue.wait();

  return 0;
}

// Test driver: first runs KernelB through the regular submission path
// (sycl_kernel), then looks the kernel up by name and runs it again
// (test_sycl_kernel). Progress messages are printed after each step.
int main() {
  queue Queue;

  sycl_kernel(Queue);
  std::cout << "sycl_kernel done\n";

  // Propagate the status instead of discarding it, so a non-zero result from
  // test_sycl_kernel() becomes the exit code of the test.
  const int Status = test_sycl_kernel(Queue);
  std::cout << "test_sycl_kernel done\n";

  return Status;
}

0 comments on commit 9f296d8

Please sign in to comment.