diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index e8c573a176..989a733c01 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -207,7 +207,7 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, return args, activity, (overwritten...,), actives, kwtup end -function enzyme_custom_setup_ret(gutils::GradientUtils, orig::LLVM.CallInst, mi, @nospecialize(RealRt)) +function enzyme_custom_setup_ret(gutils::GradientUtils, orig::LLVM.CallInst, mi, @nospecialize(RealRt), B) width = get_width(gutils) mode = get_mode(gutils) @@ -246,10 +246,14 @@ function enzyme_custom_setup_ret(gutils::GradientUtils, orig::LLVM.CallInst, mi, activep = API.DFT_DUP_NONEED end + if activep == API.DFT_CONSTANT RT = Const{RealRt} - elseif activep == API.DFT_OUT_DIFF || (mode != API.DEM_ForwardMode && active_reg(RealRt, world) ) + elseif activep == API.DFT_OUT_DIFF || (mode != API.DEM_ForwardMode && active_reg_inner(RealRt, (), world, #=justActive=#Val(true)) == ActiveState) + if active_reg_inner(RealRt, (), world, #=justActive=#Val(false)) == MixedState && B !== nothing + emit_error(B, orig, "Enzyme: Return type $RealRt has mixed internal activity types in evaluation of custom rule for $mi. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information") + end RT = Active{RealRt} elseif activep == API.DFT_DUP_ARG @@ -298,7 +302,7 @@ function enzyme_custom_fwd(B, orig, gutils, normalR, shadowR) # 2) Create activity, and annotate function spec args, activity, overwritten, actives, kwtup = enzyme_custom_setup_args(B, orig, gutils, mi, RealRt, #=reverse=#false, isKWCall) - RT, needsPrimal, needsShadow, origNeedsPrimal = enzyme_custom_setup_ret(gutils, orig, mi, RealRt) + RT, needsPrimal, needsShadow, origNeedsPrimal = enzyme_custom_setup_ret(gutils, orig, mi, RealRt, B) alloctx = LLVM.IRBuilder() position!(alloctx, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils))) @@ -511,7 +515,7 @@ end # 2) Create activity, and annotate function spec args, activity, overwritten, actives, kwtup = enzyme_custom_setup_args(B, orig, gutils, mi, RealRt, #=reverse=#!forward, isKWCall) - RT, needsPrimal, needsShadow, origNeedsPrimal = enzyme_custom_setup_ret(gutils, orig, mi, RealRt) + RT, needsPrimal, needsShadow, origNeedsPrimal = enzyme_custom_setup_ret(gutils, orig, mi, RealRt, B) needsShadowJL = if RT <: Active false diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index e5ce78aa02..4aaa1c813d 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -311,6 +311,44 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) MixedTypes = ntuple(i->:($(Symbol("active_ref_$i") == MixedState) ? Ref($(Symbol("type_$i"))) : $(Symbol("type_$i"))), Val(N)) + ending = if Width == 1 + quote + if active_reg_nothrow(resT, Val(nothing)) == MixedState && !(initShadow isa Base.RefValue) + shadow_return = Ref(initShadow) + tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + return ReturnType((origRet, shadow_return, tape)) + else + shadow_return = nothing + tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + return ReturnType((origRet, initShadow, tape)) + end + end + else + expr = :() + shads = Expr[] + for i in 1:Width + if i == 1 + expr = quote !(initShadow[$i] isa Base.RefValue) end + else + expr = quote $expr || !(initShadow[$i] isa Base.RefValue) end + end + push!(shads, quote + Ref(initShadow[$i]) + end) + end + quote + if active_reg_nothrow(resT, Val(nothing)) == MixedState && ($expr) + shadow_return = ($(shads...),) + tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + return ReturnType((origRet, shadow_return..., tape)) + else + shadow_return = nothing + tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + return ReturnType((origRet, initShadow..., tape)) + end + end + end + return quote $(active_refs...) args = ($(wrapped...),) @@ -384,13 +422,7 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) @assert annotation <: Duplicated || annotation <: DuplicatedNoNeed || annotation <: BatchDuplicated || annotation <: BatchDuplicatedNoNeed - shadow_return = nothing - tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) - if $Width == 1 - return ReturnType((origRet, initShadow, tape)) - else - return ReturnType((origRet, initShadow..., tape)) - end + $ending end end @@ -411,6 +443,31 @@ end return body_runtime_generic_augfwd(N, Width, wrapped, primtypes, active_refs) end +function nonzero_active_data(x::T) where T<: AbstractFloat + return x != zero(T) +end + +nonzero_active_data(::T) where T<: Base.RefValue = false +nonzero_active_data(::T) where T<: Array = false +nonzero_active_data(::T) where T<: Ptr = false + +function nonzero_active_data(x::T) where T + if guaranteed_const(T) + return false + end + if ismutable(x) + return false + end + + for f in fieldnames(T) + xi = getfield(x, f) + if nonzero_active_data(xi) + return true + end + end + return false +end + function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, active_refs) outs = [] for i in 1:N @@ -462,6 +519,10 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act false end + tt = Tuple{$(ElTypes...)} + rt = Core.Compiler.return_type(f, tt) + annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) + if any_mixed ttM = Tuple{Val{active_refs}, FT, $(ElTypes...)} rtM = Core.Compiler.return_type(runtime_mixed_call, ttM) @@ -479,16 +540,20 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act Const{typeof(runtime_mixed_call)}, annotationM, Tuple{Const{Val{active_refs}}, dupClosure0 ? Duplicated{FT} : Const{FT}, $(Types...)}, Val(API.DEM_ReverseModePrimal), width, ModifiedBetweenM, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + if tape.shadow_return !== nothing + if !(annotation0M <: Active) && nonzero_active_data(($shadowret,)) + ET = ($(ElTypes...),) + throw(AssertionError("Shadow value "*string(($shadowret,))*" returned from type unstable call to $f($(ET...)) has mixed internal activity types. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information")) + end + end + if annotation0M <: Active adjoint(Const(runtime_mixed_call), Const(Val(active_refs)), dupClosure0 ? Duplicated(f, df) : Const(f), args..., $shadowret, tape.internal_tape) else adjoint(Const(runtime_mixed_call), Const(Val(active_refs)), dupClosure0 ? Duplicated(f, df) : Const(f), args..., tape.internal_tape) end nothing else - tt = Tuple{$(ElTypes...)} - rt = Core.Compiler.return_type(f, tt) - annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) annotation = if $Width != 1 && annotation0 <: Duplicated BatchDuplicated{rt, $Width} @@ -502,7 +567,13 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act annotation, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width, ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) - tup = if tape.shadow_return !== nothing + if tape.shadow_return !== nothing + if !(annotation0 <: Active) && nonzero_active_data(($shadowret,)) + ET = ($(ElTypes...),) + throw(AssertionError("Shadow value "*string(($shadowret,))*" returned from type unstable call to $f($(ET...)) has mixed internal activity types. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information")) + end + end + tup = if annotation0 <: Active adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., $shadowret, tape.internal_tape)[1] else adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., tape.internal_tape)[1]