Skip to content

Commit

Permalink
Allow variadic args in hash function (#87)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbeckingsale authored Dec 23, 2024
1 parent 25133d7 commit 84fa758
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
9 changes: 6 additions & 3 deletions lib/JitCache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,13 @@ inline hash_code hash_value(const proteus::RuntimeConstant &RC) {

template <typename Function_t> class JitCache {
public:
template <typename... Ts>
uint64_t hash(StringRef ModuleUniqueId, StringRef FnName,
const RuntimeConstant *RC, int NumRuntimeConstants) const {
const RuntimeConstant *RC, int NumRuntimeConstants,
Ts... args) const {
ArrayRef<RuntimeConstant> Data(RC, NumRuntimeConstants);
auto HashValue = hash_combine(ExePath, ModuleUniqueId, FnName, Data);
auto HashValue =
hash_combine(ExePath, ModuleUniqueId, FnName, Data, args...);
return HashValue;
}

Expand Down Expand Up @@ -122,4 +125,4 @@ template <typename Function_t> class JitCache {

} // namespace proteus

#endif
#endif
6 changes: 4 additions & 2 deletions lib/JitEngineDevice.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -439,8 +439,10 @@ JitEngineDevice<ImplT>::compileAndRun(

getRuntimeConstantValues(KernelArgs, RCIndices, RCTypes, RCsVec);

uint64_t HashValue = CodeCache.hash(ModuleUniqueId, KernelName, RCsVec.data(),
NumRuntimeConstants);
uint64_t HashValue = CodeCache.hash(
ModuleUniqueId, KernelName, RCsVec.data(), NumRuntimeConstants, GridDim.x,
GridDim.y, GridDim.z, BlockDim.x, BlockDim.y, BlockDim.z);

typename DeviceTraits<ImplT>::KernelFunction_t KernelFunc =
CodeCache.lookup(HashValue);
if (KernelFunc)
Expand Down

0 comments on commit 84fa758

Please sign in to comment.