Skip to content

Commit

Permalink
!defer_within_autodiff -> within_autodiff_rewrite
Browse files Browse the repository at this point in the history
  • Loading branch information
jumerckx committed Jan 8, 2025
1 parent 7524d26 commit 3e37885
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions src/compiler/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ struct EnzymeInterpreter{T} <: AbstractInterpreter
inactive_rules::Bool
broadcast_rewrite::Bool

# When true, leave the check for within_autodiff to the handler.
defer_within_autodiff::Bool
# When false, leave the check for within_autodiff to the handler.
within_autodiff_rewrite::Bool

handler::T
end
Expand Down Expand Up @@ -173,7 +173,7 @@ function EnzymeInterpreter(
reverse_rules::Bool,
inactive_rules::Bool,
broadcast_rewrite::Bool = true,
defer_within_autodiff::Bool = false,
within_autodiff_rewrite::Bool = true,
handler = nothing
)
@assert world <= Base.get_world_counter()
Expand Down Expand Up @@ -234,7 +234,7 @@ function EnzymeInterpreter(
reverse_rules::Bool,
inactive_rules::Bool,
broadcast_rewrite::Bool,
defer_within_autodiff::Bool,
within_autodiff_rewrite::Bool,
handler
)
end
Expand All @@ -246,9 +246,9 @@ EnzymeInterpreter(
mode::API.CDerivativeMode,
inactive_rules::Bool,
broadcast_rewrite::Bool = true,
defer_within_autodiff::Bool = false,
within_autodiff_rewrite::Bool = true,
handler = nothing
) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, inactive_rules, broadcast_rewrite, defer_within_autodiff, handler)
) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, inactive_rules, broadcast_rewrite, within_autodiff_rewrite, handler)

function EnzymeInterpreter(interp::EnzymeInterpreter;
cache_or_token = (@static if HAS_INTEGRATED_CACHE
Expand All @@ -265,7 +265,7 @@ function EnzymeInterpreter(interp::EnzymeInterpreter;
reverse_rules = interp.reverse_rules,
inactive_rules = interp.inactive_rules,
broadcast_rewrite = interp.broadcast_rewrite,
defer_within_autodiff = interp.defer_within_autodiff,
within_autodiff_rewrite = interp.within_autodiff_rewrite,
handler = interp.handler)
return EnzymeInterpreter(
cache_or_token,
Expand All @@ -278,7 +278,7 @@ function EnzymeInterpreter(interp::EnzymeInterpreter;
reverse_rules,
inactive_rules,
broadcast_rewrite,
defer_within_autodiff,
within_autodiff_rewrite,
handler
)
end
Expand Down Expand Up @@ -949,7 +949,7 @@ function abstract_call_known(

(; fargs, argtypes) = arginfo

if !(interp.defer_within_autodiff) && f === Enzyme.within_autodiff
if interp.within_autodiff_rewrite && f === Enzyme.within_autodiff
if length(argtypes) != 1
@static if VERSION < v"1.11.0-"
return CallMeta(Union{}, Effects(), NoCallInfo())
Expand Down

0 comments on commit 3e37885

Please sign in to comment.