Skip to content

Commit

Permalink
Interp: optionally disable inactive noinline (#2247)
Browse files Browse the repository at this point in the history
* Interp: optionally disable inactive noinline

* Update interpreter.jl

* fix

* fix
wsmoses authored Jan 5, 2025
1 parent 230f171 commit 1389de1
Showing 3 changed files with 31 additions and 28 deletions.
11 changes: 7 additions & 4 deletions src/compiler.jl
Original file line number Diff line number Diff line change
@@ -1313,14 +1313,14 @@ if VERSION >= v"1.11.0-DEV.1552"
param_type::Type
last_fwd_rule_world::Union{Nothing, Tuple}
last_rev_rule_world::Union{Nothing, Tuple}
last_ina_rule_world::Tuple
last_ina_rule_world::Union{Nothing, Tuple}
end

@inline EnzymeCacheToken(target_type::Type, always_inline::Any, method_table::Core.MethodTable, param_type::Type, world::UInt, is_forward::Bool, is_reverse::Bool) =
@inline EnzymeCacheToken(target_type::Type, always_inline::Any, method_table::Core.MethodTable, param_type::Type, world::UInt, is_forward::Bool, is_reverse::Bool, inactive_rule::Bool) =
EnzymeCacheToken(target_type, always_inline, method_table, param_type,
is_forward ? (Enzyme.Compiler.Interpreter.get_rule_signatures(EnzymeRules.forward, Tuple{<:EnzymeCore.EnzymeRules.FwdConfig, <:Annotation, Type{<:Annotation}, Vararg{Annotation}}, world)...,) : nothing,
is_reverse ? (Enzyme.Compiler.Interpreter.get_rule_signatures(EnzymeRules.augmented_primal, Tuple{<:EnzymeCore.EnzymeRules.RevConfig, <:Annotation, Type{<:Annotation}, Vararg{Annotation}}, world)...,) : nothing,
(Enzyme.Compiler.Interpreter.get_rule_signatures(EnzymeRules.inactive, Tuple{Vararg{Any}}, world)...,)
inactive_rule ? (Enzyme.Compiler.Interpreter.get_rule_signatures(EnzymeRules.inactive, Tuple{Vararg{Any}}, world)...,) : nothing
)

GPUCompiler.ci_cache_token(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) =
@@ -1331,7 +1331,8 @@ if VERSION >= v"1.11.0-DEV.1552"
typeof(job.config.params),
job.world,
job.config.params.mode == API.DEM_ForwardMode,
job.config.params.mode != API.DEM_ForwardMode
job.config.params.mode != API.DEM_ForwardMode,
true
)

GPUCompiler.get_interpreter(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) =
@@ -1340,6 +1341,7 @@ if VERSION >= v"1.11.0-DEV.1552"
GPUCompiler.method_table(job),
job.world,
job.config.params.mode,
true
)
else

@@ -1365,6 +1367,7 @@ else
GPUCompiler.method_table(job),
job.world,
job.config.params.mode,
true
)
end

40 changes: 19 additions & 21 deletions src/compiler/interpreter.jl
Original file line number Diff line number Diff line change
@@ -129,6 +129,7 @@ struct EnzymeInterpreter{T} <: AbstractInterpreter

forward_rules::Bool
reverse_rules::Bool
inactive_rules::Bool
broadcast_rewrite::Bool
handler::T
end
@@ -166,6 +167,7 @@ function EnzymeInterpreter(
world::UInt,
forward_rules::Bool,
reverse_rules::Bool,
inactive_rules::Bool,
broadcast_rewrite::Bool = true,
handler = nothing
)
@@ -197,10 +199,12 @@ function EnzymeInterpreter(
end
end

inarules = get_rule_signatures(EnzymeRules.inactive, Tuple{Vararg{Any}}, world)
if !rule_sigs_equal(inarules, LastInaWorld[])
LastInaWorld[] = inarules
invalid = true
if inactive_rules
inarules = get_rule_signatures(EnzymeRules.inactive, Tuple{Vararg{Any}}, world)
if !rule_sigs_equal(inarules, LastInaWorld[])
LastInaWorld[] = inarules
invalid = true
end
end

if invalid
@@ -221,9 +225,10 @@ function EnzymeInterpreter(
# parameters for inference and optimization
parms,
OptimizationParams(),
forward_rules,
reverse_rules,
broadcast_rewrite,
forward_rules::Bool,
reverse_rules::Bool,
inactive_rules::Bool,
broadcast_rewrite::Bool,
handler
)
end
@@ -233,9 +238,10 @@ EnzymeInterpreter(
mt::Union{Nothing,Core.MethodTable},
world::UInt,
mode::API.CDerivativeMode,
inactive_rules::Bool,
broadcast_rewrite::Bool = true,
handler = nothing
) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, broadcast_rewrite, handler)
) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, inactive_rules, broadcast_rewrite, handler)

Core.Compiler.InferenceParams(@nospecialize(interp::EnzymeInterpreter)) = interp.inf_params
Core.Compiler.OptimizationParams(@nospecialize(interp::EnzymeInterpreter)) = interp.opt_params
@@ -364,20 +370,12 @@ function Core.Compiler.abstract_call_gf_by_type(
callinfo = AlwaysInlineCallInfo(callinfo, atype)
else
method_table = Core.Compiler.method_table(interp)
if EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table)
if interp.inactive_rules && EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table)
callinfo = NoInlineCallInfo(callinfo, atype, :inactive)
else
if interp.forward_rules
if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table)
callinfo = NoInlineCallInfo(callinfo, atype, :frule)
end
end

if interp.reverse_rules
if EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table)
callinfo = NoInlineCallInfo(callinfo, atype, :rrule)
end
end
elseif interp.forward_rules && EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table)
callinfo = NoInlineCallInfo(callinfo, atype, :frule)
elseif interp.reverse_rules && EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table)
callinfo = NoInlineCallInfo(callinfo, atype, :rrule)
end
end

8 changes: 5 additions & 3 deletions src/typeutils/inference.jl
Original file line number Diff line number Diff line change
@@ -27,13 +27,14 @@ function primal_interp_world(
EnzymeCompilerParams,
world,
false,
true,
true
)
else
Enzyme.Compiler.GLOBAL_REV_CACHE
end

Enzyme.Compiler.Interpreter.EnzymeInterpreter(CT, nothing, world, mode)
Enzyme.Compiler.Interpreter.EnzymeInterpreter(CT, nothing, world, mode, true)
end

function primal_interp_world(
@@ -50,13 +51,14 @@ function primal_interp_world(
EnzymeCompilerParams,
world,
true,
false
false,
true
)
else
Enzyme.Compiler.GLOBAL_FWD_CACHE
end

Enzyme.Compiler.Interpreter.EnzymeInterpreter(CT, nothing, world, mode)
Enzyme.Compiler.Interpreter.EnzymeInterpreter(CT, nothing, world, mode, true)
end

@inline primal_interp_world(

0 comments on commit 1389de1

Please sign in to comment.