diff --git a/lib/JitEngineDeviceCUDA.cpp b/lib/JitEngineDeviceCUDA.cpp index 8fa335d6..c69dab99 100644 --- a/lib/JitEngineDeviceCUDA.cpp +++ b/lib/JitEngineDeviceCUDA.cpp @@ -9,6 +9,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Bitcode/BitcodeWriter.h" +#include "llvm/IR/Metadata.h" #include "llvm/Support/TargetSelect.h" #include "llvm/Target/TargetMachine.h" #include @@ -118,10 +119,26 @@ void JitEngineDeviceCUDA::setLaunchBoundsForKernel(Module &M, Function &F, // properties. // TODO: set min GridSize. int MaxThreads = std::min(1024, BlockSize); - Metadata *MDVals[] = {ConstantAsMetadata::get(&F), - MDString::get(M.getContext(), "maxntidx"), - ConstantAsMetadata::get(ConstantInt::get( - Type::getInt32Ty(M.getContext()), MaxThreads))}; + auto *FuncMetadata = ConstantAsMetadata::get(&F); + auto *MaxntidxMetadata = MDString::get(M.getContext(), "maxntidx"); + auto *MaxThreadsMetadata = ConstantAsMetadata::get( + ConstantInt::get(Type::getInt32Ty(M.getContext()), MaxThreads)); + + // If F already has a "maxntidx" annotation, overwrite its value in place and return. + for (auto *MetadataNode : NvvmAnnotations->operands()) { + // Expecting 3 operands: ptr, desc, i32 value. NOTE(review): confirm no node carries extra (desc, value) pairs. + assert(MetadataNode->getNumOperands() == 3); + + auto *PtrMetadata = MetadataNode->getOperand(0).get(); + auto *DescMetadata = MetadataNode->getOperand(1).get(); + if (PtrMetadata == FuncMetadata && MaxntidxMetadata == DescMetadata) { + MetadataNode->replaceOperandWith(2, MaxThreadsMetadata); + return; + } + } + + // Otherwise create a new annotation node and append it to NvvmAnnotations. 
+ Metadata *MDVals[] = {FuncMetadata, MaxntidxMetadata, MaxThreadsMetadata}; NvvmAnnotations->addOperand(MDNode::get(M.getContext(), MDVals)); } diff --git a/lib/JitEngineDeviceHIP.cpp b/lib/JitEngineDeviceHIP.cpp index 9564a7ce..c618eddb 100644 --- a/lib/JitEngineDeviceHIP.cpp +++ b/lib/JitEngineDeviceHIP.cpp @@ -167,6 +167,7 @@ void JitEngineDeviceHIP::setLaunchBoundsForKernel(Module &M, Function &F, // TODO: find maximum (hardcoded 1024) from device info. // TODO: Setting as 1, BlockSize to replicate launch bounds settings // Does setting it as BlockSize, BlockSize help? + // Setting the attribute overrides any previous setting. F.addFnAttr("amdgpu-flat-work-group-size", "1," + std::to_string(std::min(1024, BlockSize))); // TODO: find warp size (hardcoded 64) from device info. diff --git a/tests/gpu/CMakeLists.txt b/tests/gpu/CMakeLists.txt index ae5ba84a..5d830c8b 100644 --- a/tests/gpu/CMakeLists.txt +++ b/tests/gpu/CMakeLists.txt @@ -174,6 +174,7 @@ CREATE_GPU_TEST(types types.cpp) CREATE_GPU_TEST(kernel_unused_gvar kernel_unused_gvar.cpp kernel_unused_gvar_def.cpp) CREATE_GPU_TEST(kernel_repeat kernel_repeat.cpp) CREATE_GPU_TEST(kernel_launch_exception kernel_launch_exception.cpp) +CREATE_GPU_TEST(kernel_preset_bounds kernel_preset_bounds.cpp) CREATE_GPU_TEST_RDC(kernel kernel.cpp) CREATE_GPU_TEST_RDC(kernel_cache kernel_cache.cpp) @@ -194,6 +195,7 @@ CREATE_GPU_TEST_RDC(types types.cpp) CREATE_GPU_TEST_RDC(kernel_calls_func kernel_calls_func.cpp device_func.cpp) CREATE_GPU_TEST_RDC(kernel_repeat kernel_repeat.cpp) CREATE_GPU_TEST_RDC(kernel_launch_exception kernel_launch_exception.cpp) +CREATE_GPU_TEST_RDC(kernel_preset_bounds kernel_preset_bounds.cpp) CREATE_GPU_LIBRARY(device_func_lib device_func.cpp) CREATE_GPU_TEST_RDC_LIBS(kernel_calls_func_lib device_func_lib kernel_calls_func_lib.cpp) diff --git a/tests/gpu/kernel_preset_bounds.cpp b/tests/gpu/kernel_preset_bounds.cpp new file mode 100644 index 00000000..afdeca5d --- /dev/null +++ 
b/tests/gpu/kernel_preset_bounds.cpp @@ -0,0 +1,26 @@ +// clang-format off +// RUN: ./kernel_preset_bounds.%ext | FileCheck %s --check-prefixes=CHECK,CHECK-FIRST +// Second run uses the object cache (it expects a JitStorageCache hit). +// RUN: ./kernel_preset_bounds.%ext | FileCheck %s --check-prefixes=CHECK,CHECK-SECOND +// clang-format on +#include +#include + +#include "gpu_common.h" + +__global__ __attribute__((annotate("jit"))) +__launch_bounds__(128, 4) void kernel() { + printf("Kernel\n"); +} + +int main() { + kernel<<<1, 1>>>(); + gpuErrCheck(gpuDeviceSynchronize()); + return 0; +} + +// CHECK: Kernel +// CHECK: JitCache hits 0 total 1 +// CHECK: HashValue {{[0-9]+}} NumExecs 1 NumHits 0 +// CHECK-FIRST: JitStorageCache hits 0 total 1 +// CHECK-SECOND: JitStorageCache hits 1 total 1