diff --git a/src/api.jl b/src/api.jl index 402b67b712e..0a36f0e910a 100644 --- a/src/api.jl +++ b/src/api.jl @@ -260,7 +260,7 @@ EnzymeGradientUtilsTypeAnalyzer(gutils) = ccall((:EnzymeGradientUtilsTypeAnalyze EnzymeGradientUtilsAllocAndGetTypeTree(gutils, val) = ccall((:EnzymeGradientUtilsAllocAndGetTypeTree, libEnzyme), CTypeTreeRef, (EnzymeGradientUtilsRef,LLVMValueRef), gutils, val) -EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, size) = ccall((:EnzymeGradientUtilsGetUncacheableArgs, libEnzyme), Cvoid, (EnzymeGradientUtilsRef,LLVMValueRef, Ptr{UInt8}, UInt64), gutils, orig, uncacheable, size) +EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, size) = ccall((:EnzymeGradientUtilsGetUncacheableArgs, libEnzyme), UInt8, (EnzymeGradientUtilsRef,LLVMValueRef, Ptr{UInt8}, UInt64), gutils, orig, uncacheable, size) EnzymeGradientUtilsGetDiffeType(gutils, op, isforeign) = ccall((:EnzymeGradientUtilsGetDiffeType, libEnzyme), CDIFFE_TYPE, (EnzymeGradientUtilsRef,LLVMValueRef, UInt8), gutils, op, isforeign) diff --git a/src/gradientutils.jl b/src/gradientutils.jl index 67618e3a45f..655f794c8cd 100644 --- a/src/gradientutils.jl +++ b/src/gradientutils.jl @@ -24,7 +24,9 @@ function get_shadow_type(gutils::GradientUtils, T::LLVM.Type) end function get_uncacheable(gutils::GradientUtils, orig::LLVM.CallInst) uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig)))-1) - API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) + if API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) != 1 + uncacheable .= 1 + end return uncacheable end diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 7c9894b2892..77c5b170eb4 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -1,5 +1,5 @@ -function enzyme_custom_setup_args(B, orig, gutils, mi, RT, reverse, isKWCall) +function enzyme_custom_setup_args(B, orig::LLVM.CallInst, 
gutils::GradientUtils, mi, @nospecialize(RT), reverse::Bool, isKWCall::Bool) ops = collect(operands(orig)) called = ops[end] ops = ops[1:end-1] @@ -207,7 +207,7 @@ function enzyme_custom_setup_args(B, orig, gutils, mi, RT, reverse, isKWCall) return args, activity, (overwritten...,), actives, kwtup end -function enzyme_custom_setup_ret(gutils, orig, mi, RealRt) +function enzyme_custom_setup_ret(gutils::GradientUtils, orig::LLVM.CallInst, mi, @nospecialize(RealRt)) width = get_width(gutils) mode = get_mode(gutils) @@ -216,7 +216,21 @@ function enzyme_custom_setup_ret(gutils, orig, mi, RealRt) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) - activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, mode) + # Conditionally use the get return. This is done because EnzymeGradientUtilsGetReturnDiffeType + # calls differential use analysis to determine needsprimal/shadow. However, since now this function + # is used as part of differential use analysis, we need to avoid an infinite recursion. Thus use + # the version without differential use if actual uncacheable results are not available anyways. 
+ uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig)))-1) + activep = if API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) == 0 + API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, mode) + else + actv = API.EnzymeGradientUtilsGetDiffeType(gutils, orig, false) + needsPrimalP[] = 1 + if actv == API.DFT_DUP_ARG || actv == API.DFT_DUP_NONEED + needsShadowP[] = 1 + end + actv + end needsPrimal = needsPrimalP[] != 0 origNeedsPrimal = needsPrimal _, sret, _ = get_return_info(RealRt) @@ -479,7 +493,7 @@ function enzyme_custom_fwd(B, orig, gutils, normalR, shadowR) return false end -@inline function aug_fwd_mi(orig, gutils) +@inline function aug_fwd_mi(orig::LLVM.CallInst, gutils::GradientUtils) width = get_width(gutils) # 1) extract out the MI from attributes diff --git a/test/rrules.jl b/test/rrules.jl index 13228959241..28429831bc6 100644 --- a/test/rrules.jl +++ b/test/rrules.jl @@ -31,13 +31,13 @@ function reverse(config::ConfigWidth{1}, ::Const{typeof(f)}, dret::Active, tape, end end -function augmented_primal(::Config{false, false, 1}, func::Const{typeof(f_ip)}, ::Type{<:Const}, x::Duplicated) +function augmented_primal(::ConfigWidth{1}, func::Const{typeof(f_ip)}, ::Type{<:Const}, x::Duplicated) v = x.val[1] x.val[1] *= v return AugmentedReturn(nothing, nothing, v) end -function reverse(::Config{false, false, 1}, ::Const{typeof(f_ip)}, ::Type{<:Const}, tape, x::Duplicated) +function reverse(::ConfigWidth{1}, ::Const{typeof(f_ip)}, ::Type{<:Const}, tape, x::Duplicated) x.dval[1] = 100 + x.dval[1] * tape return (nothing,) end