-
Notifications
You must be signed in to change notification settings - Fork 69
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Cache inference #2141
base: main
Are you sure you want to change the base?
Cache inference #2141
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -47,6 +47,25 @@ struct EnzymeInterpreter{T} <: AbstractInterpreter | |
handler::T | ||
end | ||
|
||
# Process-wide store of local inference caches, one per interpreter configuration
# `(world, forward_rules, reverse_rules, deferred_lower, broadcast_rewrite, handler)`.
# The configuration tuple itself is the key (not its `hash`), so two distinct
# configurations whose hashes happen to collide can never be handed the same cache;
# `Dict` resolves such collisions by comparing keys with `isequal`.
const inference_cache =
    Dict{Tuple{UInt,Bool,Bool,Bool,Bool,Any},Vector{InferenceResult}}()

# Guards every read and write of `inference_cache`.
const inference_lock = ReentrantLock()

"""
    get_or_create_inference_cache(world, forward_rules, reverse_rules,
                                  deferred_lower, broadcast_rewrite, handler)

Return the `Vector{InferenceResult}` registered in `inference_cache` for the given
configuration, creating and registering an empty vector on first use.

Repeated calls with `isequal` arguments return the identical (`===`) vector, so
anything pushed into it is visible to every later caller using the same
configuration. The lookup and the insertion both happen while `inference_lock` is
held, and the lock is released even if an exception is thrown.
"""
function get_or_create_inference_cache(
    world::UInt,
    forward_rules::Bool,
    reverse_rules::Bool,
    deferred_lower::Bool,
    broadcast_rewrite::Bool,
    handler,
)
    key = (world, forward_rules, reverse_rules, deferred_lower, broadcast_rewrite, handler)

    # NOTE: no use of lock(::Function)/@lock/get! to keep stack traces clean
    lock(inference_lock)
    try
        obj = get(inference_cache, key, nothing)
        if obj === nothing
            obj = Vector{InferenceResult}()
            inference_cache[key] = obj
        end
        return obj
    finally
        unlock(inference_lock)
    end
end
|
||
function EnzymeInterpreter( | ||
cache_or_token, | ||
mt::Union{Nothing,Core.MethodTable}, | ||
|
@@ -55,7 +74,8 @@ function EnzymeInterpreter( | |
reverse_rules::Bool, | ||
deferred_lower::Bool = true, | ||
broadcast_rewrite::Bool = true, | ||
handler = nothing | ||
handler = nothing, | ||
local_cache = get_or_create_inference_cache(world, forward_rules, reverse_rules, deferred_lower, broadcast_rewrite, handler) | ||
) | ||
@assert world <= Base.get_world_counter() | ||
|
||
|
@@ -70,7 +90,7 @@ function EnzymeInterpreter( | |
mt, | ||
|
||
# Initially empty cache | ||
Vector{InferenceResult}(), | ||
local_cache, | ||
|
||
# world age counter | ||
world, | ||
|
@@ -93,8 +113,9 @@ EnzymeInterpreter( | |
mode::API.CDerivativeMode, | ||
deferred_lower::Bool = true, | ||
broadcast_rewrite::Bool = true, | ||
handler = nothing | ||
) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, deferred_lower, broadcast_rewrite, handler) | ||
handler = nothing, | ||
local_cache = get_or_create_inference_cache(world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, deferred_lower, broadcast_rewrite, handler) | ||
) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, deferred_lower, broadcast_rewrite, handler, local_cache) | ||
|
||
Core.Compiler.InferenceParams(@nospecialize(interp::EnzymeInterpreter)) = interp.inf_params | ||
Core.Compiler.OptimizationParams(@nospecialize(interp::EnzymeInterpreter)) = interp.opt_params | ||
|
@@ -202,15 +223,7 @@ function Core.Compiler.abstract_call_gf_by_type( | |
sv::AbsIntState, | ||
max_methods::Int, | ||
) | ||
ret = @invoke Core.Compiler.abstract_call_gf_by_type( | ||
interp::AbstractInterpreter, | ||
f::Any, | ||
arginfo::ArgInfo, | ||
si::StmtInfo, | ||
atype::Any, | ||
sv::AbsIntState, | ||
max_methods::Int, | ||
) | ||
ret = Core.invoke(Core.Compiler.abstract_call_gf_by_type, Tuple{AbstractInterpreter, Any, ArgInfo, StmtInfo, Any, AbsIntState, Int}, interp, f, arginfo, si, atype, sv, max_methods) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why this change? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We were spending a huge amount of runtime on that function (which appears to compile as jl_invoke). Tried playing around with `@inline` to fix it, to no avail. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah this change doesn't do anything. |
||
callinfo = ret.info | ||
method_table = Core.Compiler.method_table(interp) | ||
specTypes = simplify_kw(atype) | ||
|
@@ -887,14 +900,8 @@ function abstract_call_known( | |
if interp.handler != nothing | ||
return interp.handler(interp, f, arginfo, si, sv, max_methods) | ||
end | ||
return Base.@invoke abstract_call_known( | ||
interp::AbstractInterpreter, | ||
f::Any, | ||
arginfo::ArgInfo, | ||
si::StmtInfo, | ||
sv::AbsIntState, | ||
max_methods::Int, | ||
) | ||
return Core.invoke(abstract_call_known, Tuple{AbstractInterpreter, Any, ArgInfo, StmtInfo, AbsIntState, Int}, interp, f, arginfo, si, sv, max_methods) | ||
#return @inline Core.invoke(abstract_call_known, Tuple{AbstractInterpreter, Any, ArgInfo, StmtInfo, AbsIntState, Int}, interp, f, arginfo, si, sv, max_methods) | ||
end | ||
|
||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If I understand correctly, the local_cache ought to be ephemeral and can't be shared across multiple abstract interpreter instances.
cc: @aviatesk
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah but in this case we actually are effectively caching the absint itself (since all data about it is in the key)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All caches that ought to survive the absint go into the global cache?