Skip to content

Commit

Permalink
Handle non-zero mixed return (#1529)
Browse files Browse the repository at this point in the history
* Handle non-zero mixed return

* improve mixed activity rule errors
  • Loading branch information
wsmoses authored Jun 10, 2024
1 parent b8f9beb commit df7dd87
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 15 deletions.
12 changes: 8 additions & 4 deletions src/rules/customrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils,
return args, activity, (overwritten...,), actives, kwtup
end

function enzyme_custom_setup_ret(gutils::GradientUtils, orig::LLVM.CallInst, mi, @nospecialize(RealRt))
function enzyme_custom_setup_ret(gutils::GradientUtils, orig::LLVM.CallInst, mi, @nospecialize(RealRt), B)
width = get_width(gutils)
mode = get_mode(gutils)

Expand Down Expand Up @@ -246,10 +246,14 @@ function enzyme_custom_setup_ret(gutils::GradientUtils, orig::LLVM.CallInst, mi,
activep = API.DFT_DUP_NONEED
end


if activep == API.DFT_CONSTANT
RT = Const{RealRt}

elseif activep == API.DFT_OUT_DIFF || (mode != API.DEM_ForwardMode && active_reg(RealRt, world) )
elseif activep == API.DFT_OUT_DIFF || (mode != API.DEM_ForwardMode && active_reg_inner(RealRt, (), world, #=justActive=#Val(true)) == ActiveState)
if active_reg_inner(RealRt, (), world, #=justActive=#Val(false)) == MixedState && B !== nothing
emit_error(B, orig, "Enzyme: Return type $RealRt has mixed internal activity types in evaluation of custom rule for $mi. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information")
end
RT = Active{RealRt}

elseif activep == API.DFT_DUP_ARG
Expand Down Expand Up @@ -298,7 +302,7 @@ function enzyme_custom_fwd(B, orig, gutils, normalR, shadowR)

# 2) Create activity, and annotate function spec
args, activity, overwritten, actives, kwtup = enzyme_custom_setup_args(B, orig, gutils, mi, RealRt, #=reverse=#false, isKWCall)
RT, needsPrimal, needsShadow, origNeedsPrimal = enzyme_custom_setup_ret(gutils, orig, mi, RealRt)
RT, needsPrimal, needsShadow, origNeedsPrimal = enzyme_custom_setup_ret(gutils, orig, mi, RealRt, B)

alloctx = LLVM.IRBuilder()
position!(alloctx, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils)))
Expand Down Expand Up @@ -511,7 +515,7 @@ end

# 2) Create activity, and annotate function spec
args, activity, overwritten, actives, kwtup = enzyme_custom_setup_args(B, orig, gutils, mi, RealRt, #=reverse=#!forward, isKWCall)
RT, needsPrimal, needsShadow, origNeedsPrimal = enzyme_custom_setup_ret(gutils, orig, mi, RealRt)
RT, needsPrimal, needsShadow, origNeedsPrimal = enzyme_custom_setup_ret(gutils, orig, mi, RealRt, B)

needsShadowJL = if RT <: Active
false
Expand Down
93 changes: 82 additions & 11 deletions src/rules/jitrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,44 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs)

MixedTypes = ntuple(i->:($(Symbol("active_ref_$i") == MixedState) ? Ref($(Symbol("type_$i"))) : $(Symbol("type_$i"))), Val(N))

ending = if Width == 1
quote
if active_reg_nothrow(resT, Val(nothing)) == MixedState && !(initShadow isa Base.RefValue)
shadow_return = Ref(initShadow)
tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return)
return ReturnType((origRet, shadow_return, tape))
else
shadow_return = nothing
tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return)
return ReturnType((origRet, initShadow, tape))
end
end
else
expr = :()
shads = Expr[]
for i in 1:Width
if i == 1
expr = quote !(initShadow[$i] isa Base.RefValue) end
else
expr = quote $expr || !(initShadow[$i] isa Base.RefValue) end
end
push!(shads, quote
Ref(initShadow[$i])
end)
end
quote
if active_reg_nothrow(resT, Val(nothing)) == MixedState && ($expr)
shadow_return = ($(shads...),)
tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return)
return ReturnType((origRet, shadow_return..., tape))
else
shadow_return = nothing
tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return)
return ReturnType((origRet, initShadow..., tape))
end
end
end

return quote
$(active_refs...)
args = ($(wrapped...),)
Expand Down Expand Up @@ -384,13 +422,7 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs)

@assert annotation <: Duplicated || annotation <: DuplicatedNoNeed || annotation <: BatchDuplicated || annotation <: BatchDuplicatedNoNeed

shadow_return = nothing
tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return)
if $Width == 1
return ReturnType((origRet, initShadow, tape))
else
return ReturnType((origRet, initShadow..., tape))
end
$ending
end
end

Expand All @@ -411,6 +443,31 @@ end
return body_runtime_generic_augfwd(N, Width, wrapped, primtypes, active_refs)
end

function nonzero_active_data(x::T) where T<: AbstractFloat
return x != zero(T)
end

nonzero_active_data(::T) where T<: Base.RefValue = false
nonzero_active_data(::T) where T<: Array = false
nonzero_active_data(::T) where T<: Ptr = false

function nonzero_active_data(x::T) where T
if guaranteed_const(T)
return false
end
if ismutable(x)
return false
end

for f in fieldnames(T)
xi = getfield(x, f)
if nonzero_active_data(xi)
return true
end
end
return false
end

function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, active_refs)
outs = []
for i in 1:N
Expand Down Expand Up @@ -462,6 +519,10 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act
false
end

tt = Tuple{$(ElTypes...)}
rt = Core.Compiler.return_type(f, tt)
annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal)

if any_mixed
ttM = Tuple{Val{active_refs}, FT, $(ElTypes...)}
rtM = Core.Compiler.return_type(runtime_mixed_call, ttM)
Expand All @@ -479,16 +540,20 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act
Const{typeof(runtime_mixed_call)},
annotationM, Tuple{Const{Val{active_refs}}, dupClosure0 ? Duplicated{FT} : Const{FT}, $(Types...)}, Val(API.DEM_ReverseModePrimal), width,
ModifiedBetweenM, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI)

if tape.shadow_return !== nothing
if !(annotation0M <: Active) && nonzero_active_data(($shadowret,))
ET = ($(ElTypes...),)
throw(AssertionError("Shadow value "*string(($shadowret,))*" returned from type unstable call to $f($(ET...)) has mixed internal activity types. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information"))
end
end
if annotation0M <: Active
adjoint(Const(runtime_mixed_call), Const(Val(active_refs)), dupClosure0 ? Duplicated(f, df) : Const(f), args..., $shadowret, tape.internal_tape)
else
adjoint(Const(runtime_mixed_call), Const(Val(active_refs)), dupClosure0 ? Duplicated(f, df) : Const(f), args..., tape.internal_tape)
end
nothing
else
tt = Tuple{$(ElTypes...)}
rt = Core.Compiler.return_type(f, tt)
annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal)

annotation = if $Width != 1 && annotation0 <: Duplicated
BatchDuplicated{rt, $Width}
Expand All @@ -502,7 +567,13 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act
annotation, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width,
ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI)

tup = if tape.shadow_return !== nothing
if tape.shadow_return !== nothing
if !(annotation0 <: Active) && nonzero_active_data(($shadowret,))
ET = ($(ElTypes...),)
throw(AssertionError("Shadow value "*string(($shadowret,))*" returned from type unstable call to $f($(ET...)) has mixed internal activity types. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information"))
end
end
tup = if annotation0 <: Active
adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., $shadowret, tape.internal_tape)[1]
else
adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., tape.internal_tape)[1]
Expand Down

0 comments on commit df7dd87

Please sign in to comment.