Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache inference #2141

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 28 additions & 21 deletions src/compiler/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,25 @@ struct EnzymeInterpreter{T} <: AbstractInterpreter
handler::T
end

# Process-wide table mapping a configuration hash to the vector of inference
# results handed out for that configuration.
const inference_cache = Dict{UInt,Vector{InferenceResult}}()
# Held around every lookup/insertion performed on `inference_cache` below.
const inference_lock = ReentrantLock()

"""
    get_or_create_inference_cache(world, forward_rules, reverse_rules,
                                  deferred_lower, broadcast_rewrite, handler)

Return the `Vector{InferenceResult}` registered for this combination of
arguments, inserting a fresh empty vector the first time the combination is
seen. Repeated calls with the same arguments return the very same vector.

The lookup and the insertion happen while `inference_lock` is held; the lock is
released before the vector is returned to the caller.

NOTE(review): entries are keyed by the `hash` of the argument tuple, not by the
tuple itself, so two different argument tuples with equal hashes would be given
the same vector.
"""
function get_or_create_inference_cache(
    world::UInt,
    forward_rules::Bool,
    reverse_rules::Bool,
    deferred_lower::Bool,
    broadcast_rewrite::Bool,
    handler,
)
    settings = (world, forward_rules, reverse_rules, deferred_lower, broadcast_rewrite, handler)
    slot = hash(settings)

    # NOTE: no use of lock(::Function)/@lock/get! to keep stack traces clean
    lock(inference_lock)
    try
        if !haskey(inference_cache, slot)
            inference_cache[slot] = InferenceResult[]
        end
        return inference_cache[slot]
    finally
        unlock(inference_lock)
    end
end

function EnzymeInterpreter(
cache_or_token,
mt::Union{Nothing,Core.MethodTable},
Expand All @@ -55,7 +74,8 @@ function EnzymeInterpreter(
reverse_rules::Bool,
deferred_lower::Bool = true,
broadcast_rewrite::Bool = true,
handler = nothing
handler = nothing,
local_cache = get_or_create_inference_cache(world, forward_rules, reverse_rules, deferred_lower, broadcast_rewrite, handler)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand correctly, the local_cache ought to be ephemeral and can't be shared across multiple abstract interpreter instances.

cc: @aviatesk

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah but in this case we actually are effectively caching the absint itself (since all data about it is in the key)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All caches that ought to survive the absint go into the global cache?

)
@assert world <= Base.get_world_counter()

Expand All @@ -70,7 +90,7 @@ function EnzymeInterpreter(
mt,

# Initially empty cache
Vector{InferenceResult}(),
local_cache,

# world age counter
world,
Expand All @@ -93,8 +113,9 @@ EnzymeInterpreter(
mode::API.CDerivativeMode,
deferred_lower::Bool = true,
broadcast_rewrite::Bool = true,
handler = nothing
) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, deferred_lower, broadcast_rewrite, handler)
handler = nothing,
local_cache = get_or_create_inference_cache(world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, deferred_lower, broadcast_rewrite, handler)
) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, deferred_lower, broadcast_rewrite, handler, local_cache)

Core.Compiler.InferenceParams(@nospecialize(interp::EnzymeInterpreter)) = interp.inf_params
Core.Compiler.OptimizationParams(@nospecialize(interp::EnzymeInterpreter)) = interp.opt_params
Expand Down Expand Up @@ -202,15 +223,7 @@ function Core.Compiler.abstract_call_gf_by_type(
sv::AbsIntState,
max_methods::Int,
)
ret = @invoke Core.Compiler.abstract_call_gf_by_type(
interp::AbstractInterpreter,
f::Any,
arginfo::ArgInfo,
si::StmtInfo,
atype::Any,
sv::AbsIntState,
max_methods::Int,
)
ret = Core.invoke(Core.Compiler.abstract_call_gf_by_type, Tuple{AbstractInterpreter, Any, ArgInfo, StmtInfo, Any, AbsIntState, Int}, interp, f, arginfo, si, atype, sv, max_methods)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this change?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We were spending a huge amount of runtime on that function (which appears to compile as jl_invoke). I tried playing around with it to fix this, to no avail.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah this change doesn't do anything.

callinfo = ret.info
method_table = Core.Compiler.method_table(interp)
specTypes = simplify_kw(atype)
Expand Down Expand Up @@ -887,14 +900,8 @@ function abstract_call_known(
if interp.handler != nothing
return interp.handler(interp, f, arginfo, si, sv, max_methods)
end
return Base.@invoke abstract_call_known(
interp::AbstractInterpreter,
f::Any,
arginfo::ArgInfo,
si::StmtInfo,
sv::AbsIntState,
max_methods::Int,
)
return Core.invoke(abstract_call_known, Tuple{AbstractInterpreter, Any, ArgInfo, StmtInfo, AbsIntState, Int}, interp, f, arginfo, si, sv, max_methods)
#return @inline Core.invoke(abstract_call_known, Tuple{AbstractInterpreter, Any, ArgInfo, StmtInfo, AbsIntState, Int}, interp, f, arginfo, si, sv, max_methods)
end

end
Loading