Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix interpreter caches #1698

Merged
merged 7 commits into from
Aug 4, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3361,20 +3361,38 @@ struct EnzymeCacheToken
always_inline
method_table::Core.MethodTable
param_type::Type
mode::API.CDerivativeMode
is_fwd::API.CDerivativeMode
end

GPUCompiler.ci_cache_token(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) =
EnzymeCacheToken(
typeof(job.config.target), job.config.always_inline, GPUCompiler.method_table(job),
typeof(job.config.params), job.config.params.mode,
typeof(job.config.params), job.config.params.mode == API.DEM_ForwardMode,
)

GPUCompiler.get_interpreter(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) =
Interpreter.EnzymeInterpreter(GPUCompiler.ci_cache_token(job), GPUCompiler.method_table(job), job.world, job.config.params.mode)
else

# the codeinstance cache to use -- should only be used for the constructor
# Note that the only way the interpreter modifies codegen is either not inlining a fwd mode
# rule or not inlining a rev mode rule. Otherwise, all caches can be re-used.
const GLOBAL_FWD_CACHE = GPUCompiler.CodeCache()
const GLOBAL_REV_CACHE = GPUCompiler.CodeCache()
function enzyme_ci_cache(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams})
return if job.config.params.mode == API.DEM_ForwardMode
GLOBAL_FWD_CACHE
else
GLOBAL_REV_CACHE
end
end

@static if VERSION < v"1.8"
GPUCompiler.ci_cache(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = enzyme_ci_cache(job)
end

GPUCompiler.get_interpreter(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) =
Interpreter.EnzymeInterpreter(GPUCompiler.ci_cache(job), GPUCompiler.method_table(job), job.world, job.config.params.mode)
Interpreter.EnzymeInterpreter(enzyme_ci_cache(job), GPUCompiler.method_table(job), job.world, job.config.params.mode)
end

include("compiler/passes.jl")
Expand Down Expand Up @@ -6952,7 +6970,7 @@ end
run_enzyme = false
Const
else
A
A
end

if run_enzyme && !(A2 <: Const) && guaranteed_const_nongen(rrt, World)
Expand Down
Loading