From d30f97991738407034f43c4d07b25e03bb0e34eb Mon Sep 17 00:00:00 2001
From: koparasy
Date: Tue, 3 Dec 2024 07:52:07 -0800
Subject: [PATCH] Add pass to replace BlockDim/GridDim/threadIDx accesses with constants on AMD

---
 lib/JitEngineDevice.hpp    |  1 -
 lib/JitEngineDeviceHIP.cpp | 65 +++++++++++++++++++++++++++++++++++++-
 2 files changed, 64 insertions(+), 2 deletions(-)

diff --git a/lib/JitEngineDevice.hpp b/lib/JitEngineDevice.hpp
index 21675997..a2289c7e 100644
--- a/lib/JitEngineDevice.hpp
+++ b/lib/JitEngineDevice.hpp
@@ -120,7 +120,6 @@ template class JitEngineDevice : protected JitEngine {
   }
 
   void setKernelDims(Module &M, Function &F, dim3 &GridDim, dim3 &BlockDim) {
-    std::cout << "I am here \n";
     static_cast(*this).setKernelDims(M, F, GridDim, BlockDim);
   }
 
diff --git a/lib/JitEngineDeviceHIP.cpp b/lib/JitEngineDeviceHIP.cpp
index 1692cf81..5b8a824e 100644
--- a/lib/JitEngineDeviceHIP.cpp
+++ b/lib/JitEngineDeviceHIP.cpp
@@ -135,7 +135,70 @@ void JitEngineDeviceHIP::setLaunchBoundsForKernel(Module &M, Function &F,
 
 void JitEngineDeviceHIP::setKernelDims(Module &M, Function &F, dim3 &GridDim,
                                        dim3 &BlockDim) {
-  std::abort();
+  // Replace every call to the accessor named IntrinsicName with the 32-bit
+  // constant `value`, then erase the call.
+  auto ReplaceDim = [&](StringRef IntrinsicName, uint32_t value) {
+    Function *IntrinsicFunction = M.getFunction(IntrinsicName);
+    if (!IntrinsicFunction)
+      return;
+
+    // U is advanced before the call is touched, because erasing the call
+    // removes the use U was pointing at.
+    for (auto U = IntrinsicFunction->use_begin(),
+              UE = IntrinsicFunction->use_end();
+         U != UE;) {
+      Use &Use = *U++;
+      if (auto *Call = dyn_cast<CallInst>(Use.getUser())) {
+        Value *ConstantValue =
+            ConstantInt::get(Type::getInt32Ty(M.getContext()), value);
+        Call->replaceAllUsesWith(ConstantValue);
+        Call->eraseFromParent();
+      }
+    }
+  };
+
+  // __HIP_GridDim accessors fold to the launch grid dimensions.
+  ReplaceDim("_ZNK17__HIP_CoordinatesI13__HIP_GridDimE3__XcvjEv", GridDim.x);
+  ReplaceDim("_ZNK17__HIP_CoordinatesI13__HIP_GridDimE3__YcvjEv", GridDim.y);
+  ReplaceDim("_ZNK17__HIP_CoordinatesI13__HIP_GridDimE3__ZcvjEv", GridDim.z);
+
+  // __HIP_BlockDim accessors fold to the launch block dimensions.
+  ReplaceDim("_ZNK17__HIP_CoordinatesI14__HIP_BlockDimE3__XcvjEv", BlockDim.x);
+  ReplaceDim("_ZNK17__HIP_CoordinatesI14__HIP_BlockDimE3__YcvjEv", BlockDim.y);
+  ReplaceDim("_ZNK17__HIP_CoordinatesI14__HIP_BlockDimE3__ZcvjEv", BlockDim.z);
+
+  // Emit llvm.assume(call <u DimSize) directly after every call to the
+  // accessor named IntrinsicName; the call itself is left in place.
+  auto InsertAssume = [&](StringRef IntrinsicName, uint32_t DimSize) {
+    Function *IntrinsicFunction = M.getFunction(IntrinsicName);
+    if (!IntrinsicFunction || IntrinsicFunction->use_empty())
+      return; // No modifications made if the intrinsic is not used
+
+    // Iterate over all uses of the intrinsic
+    for (auto U = IntrinsicFunction->use_begin(),
+              UE = IntrinsicFunction->use_end();
+         U != UE;) {
+      Use &Use = *U++;
+      if (auto *Call = dyn_cast<CallInst>(Use.getUser())) {
+        // Insert the llvm.assume intrinsic
+        IRBuilder<> Builder(Call->getNextNode());
+        Value *Bound = ConstantInt::get(Call->getType(), DimSize);
+        Value *Cmp = Builder.CreateICmpULT(Call, Bound);
+
+        Function *AssumeIntrinsic =
+            Intrinsic::getDeclaration(&M, Intrinsic::assume);
+        Builder.CreateCall(AssumeIntrinsic, Cmp);
+      }
+    }
+  };
+
+  // Each __HIP_ThreadIdx component is bounded by the matching block dimension.
+  InsertAssume("_ZNK17__HIP_CoordinatesI15__HIP_ThreadIdxE3__XcvjEv",
+               BlockDim.x);
+  InsertAssume("_ZNK17__HIP_CoordinatesI15__HIP_ThreadIdxE3__YcvjEv",
+               BlockDim.y);
+  InsertAssume("_ZNK17__HIP_CoordinatesI15__HIP_ThreadIdxE3__ZcvjEv",
+               BlockDim.z);
 }
 
 std::unique_ptr