diff --git a/lib/JitCache.hpp b/lib/JitCache.hpp index 26361b38..ddb46194 100644 --- a/lib/JitCache.hpp +++ b/lib/JitCache.hpp @@ -34,10 +34,11 @@ inline hash_code hash_value(const proteus::RuntimeConstant &RC) { template class JitCache { public: + template uint64_t hash(StringRef ModuleUniqueId, StringRef FnName, - const RuntimeConstant *RC, int NumRuntimeConstants) const { + const RuntimeConstant *RC, int NumRuntimeConstants, Ts... args) const { ArrayRef Data(RC, NumRuntimeConstants); - auto HashValue = hash_combine(ExePath, ModuleUniqueId, FnName, Data); + auto HashValue = hash_combine(ExePath, ModuleUniqueId, FnName, Data, args...); return HashValue; } @@ -122,4 +123,4 @@ template class JitCache { } // namespace proteus -#endif \ No newline at end of file +#endif diff --git a/lib/JitEngineDevice.hpp b/lib/JitEngineDevice.hpp index aa253171..8ed93d67 100644 --- a/lib/JitEngineDevice.hpp +++ b/lib/JitEngineDevice.hpp @@ -439,8 +439,10 @@ JitEngineDevice::compileAndRun( getRuntimeConstantValues(KernelArgs, RCIndices, RCTypes, RCsVec); - uint64_t HashValue = CodeCache.hash(ModuleUniqueId, KernelName, RCsVec.data(), - NumRuntimeConstants); + uint64_t HashValue = CodeCache.hash(ModuleUniqueId, + KernelName, RCsVec.data(), NumRuntimeConstants, GridDim.x, GridDim.y, + GridDim.z, BlockDim.x, BlockDim.y, BlockDim.z); + typename DeviceTraits::KernelFunction_t KernelFunc = CodeCache.lookup(HashValue); if (KernelFunc)