Skip to content

Commit

Permalink
Add pass to replace BlockDim/GridDim/threadIDx accesses with constants on AMD
Browse files Browse the repository at this point in the history
  • Loading branch information
koparasy committed Dec 3, 2024
1 parent b98b9f4 commit d30f979
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 2 deletions.
1 change: 0 additions & 1 deletion lib/JitEngineDevice.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ template <typename ImplT> class JitEngineDevice : protected JitEngine {
}

/// Forwards kernel-dimension specialization to the concrete backend.
///
/// Statically dispatches to ImplT::setKernelDims (CRTP), passing all
/// arguments through unchanged. No output is produced here; the previous
/// debug print to stdout has been dropped.
void setKernelDims(Module &M, Function &F, dim3 &GridDim, dim3 &BlockDim) {
  static_cast<ImplT &>(*this).setKernelDims(M, F, GridDim, BlockDim);
}

Expand Down
56 changes: 55 additions & 1 deletion lib/JitEngineDeviceHIP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,61 @@ void JitEngineDeviceHIP::setLaunchBoundsForKernel(Module &M, Function &F,

/// Specializes the launch dimensions inside module \p M.
///
/// Two rewrites are applied to calls of the HIP coordinate accessor
/// functions (looked up in \p M by their mangled names):
///   1. Calls to the __HIP_GridDim / __HIP_BlockDim accessors are replaced by
///      the corresponding constant taken from \p GridDim / \p BlockDim and
///      then erased.
///   2. After every call to a __HIP_ThreadIdx accessor an
///      `llvm.assume(call <u BlockDim.{x,y,z})` is inserted.
///
/// \param M        Module whose accessor calls are rewritten.
/// \param F        Not referenced by this implementation; the rewrite visits
///                 every call in \p M, not only those inside \p F.
/// \param GridDim  Grid dimensions used for the __HIP_GridDim accessors.
/// \param BlockDim Block dimensions used for the __HIP_BlockDim accessors and
///                 as the exclusive upper bound for the thread index.
void JitEngineDeviceHIP::setKernelDims(Module &M, Function &F, dim3 &GridDim,
                                       dim3 &BlockDim) {
  // Replace every direct call to the named accessor with a constant.
  auto ReplaceDim = [&](StringRef IntrinsicName, uint32_t DimValue) {
    Function *IntrinsicFunction = M.getFunction(IntrinsicName);
    if (!IntrinsicFunction)
      return;

    for (auto U = IntrinsicFunction->use_begin(),
              UE = IntrinsicFunction->use_end();
         U != UE;) {
      // Advance before touching the user: erasing the call removes this use
      // from the list being walked.
      Use &CurrentUse = *U++;
      auto *Call = dyn_cast<CallInst>(CurrentUse.getUser());
      // Only rewrite calls *to* the accessor, not calls that merely take it
      // as an operand (their result is unrelated to the dimension).
      if (!Call || Call->getCalledFunction() != IntrinsicFunction)
        continue;

      Value *ConstantValue = ConstantInt::get(Call->getType(), DimValue);
      Call->replaceAllUsesWith(ConstantValue);
      Call->eraseFromParent();
    }
  };

  // Grid-dimension accessors must be folded to the grid dimensions.
  ReplaceDim("_ZNK17__HIP_CoordinatesI13__HIP_GridDimE3__XcvjEv", GridDim.x);
  ReplaceDim("_ZNK17__HIP_CoordinatesI13__HIP_GridDimE3__YcvjEv", GridDim.y);
  ReplaceDim("_ZNK17__HIP_CoordinatesI13__HIP_GridDimE3__ZcvjEv", GridDim.z);

  ReplaceDim("_ZNK17__HIP_CoordinatesI14__HIP_BlockDimE3__XcvjEv", BlockDim.x);
  ReplaceDim("_ZNK17__HIP_CoordinatesI14__HIP_BlockDimE3__YcvjEv", BlockDim.y);
  ReplaceDim("_ZNK17__HIP_CoordinatesI14__HIP_BlockDimE3__ZcvjEv", BlockDim.z);

  // Tell the optimizer that the value returned by the named accessor is
  // strictly below UpperBound.
  auto InsertAssume = [&](StringRef IntrinsicName, uint32_t UpperBound) {
    Function *IntrinsicFunction = M.getFunction(IntrinsicName);
    if (!IntrinsicFunction || IntrinsicFunction->use_empty())
      return; // No modifications made if the intrinsic is not used

    for (auto U = IntrinsicFunction->use_begin(),
              UE = IntrinsicFunction->use_end();
         U != UE;) {
      Use &CurrentUse = *U++;
      auto *Call = dyn_cast<CallInst>(CurrentUse.getUser());
      if (!Call || Call->getCalledFunction() != IntrinsicFunction)
        continue;

      // Emit the comparison and the assume right after the call so that the
      // call's result dominates them.
      IRBuilder<> Builder(Call->getNextNode());
      Value *Bound = ConstantInt::get(Call->getType(), UpperBound);
      Value *Cmp = Builder.CreateICmpULT(Call, Bound);

      Function *AssumeIntrinsic =
          Intrinsic::getDeclaration(&M, Intrinsic::assume);
      Builder.CreateCall(AssumeIntrinsic, Cmp);
    }
  };

  InsertAssume("_ZNK17__HIP_CoordinatesI15__HIP_ThreadIdxE3__XcvjEv",
               BlockDim.x);
  InsertAssume("_ZNK17__HIP_CoordinatesI15__HIP_ThreadIdxE3__YcvjEv",
               BlockDim.y);
  InsertAssume("_ZNK17__HIP_CoordinatesI15__HIP_ThreadIdxE3__ZcvjEv",
               BlockDim.z);
}

std::unique_ptr<MemoryBuffer>
Expand Down

0 comments on commit d30f979

Please sign in to comment.