Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Apr 1, 2024
1 parent 239a88c commit fc25b4f
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ EnzymeGradientUtilsTypeAnalyzer(gutils) = ccall((:EnzymeGradientUtilsTypeAnalyze

EnzymeGradientUtilsAllocAndGetTypeTree(gutils, val) = ccall((:EnzymeGradientUtilsAllocAndGetTypeTree, libEnzyme), CTypeTreeRef, (EnzymeGradientUtilsRef,LLVMValueRef), gutils, val)

EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, size) = ccall((:EnzymeGradientUtilsGetUncacheableArgs, libEnzyme), Cvoid, (EnzymeGradientUtilsRef,LLVMValueRef, Ptr{UInt8}, UInt64), gutils, orig, uncacheable, size)
EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, size) = ccall((:EnzymeGradientUtilsGetUncacheableArgs, libEnzyme), UInt8, (EnzymeGradientUtilsRef,LLVMValueRef, Ptr{UInt8}, UInt64), gutils, orig, uncacheable, size)

EnzymeGradientUtilsGetDiffeType(gutils, op, isforeign) = ccall((:EnzymeGradientUtilsGetDiffeType, libEnzyme), CDIFFE_TYPE, (EnzymeGradientUtilsRef,LLVMValueRef, UInt8), gutils, op, isforeign)

Expand Down
4 changes: 3 additions & 1 deletion src/gradientutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ function get_shadow_type(gutils::GradientUtils, T::LLVM.Type)
end
function get_uncacheable(gutils::GradientUtils, orig::LLVM.CallInst)
uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig)))-1)
API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable))
if API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) != 1
uncacheable .= 1
end
return uncacheable
end

Expand Down
22 changes: 18 additions & 4 deletions src/rules/customrules.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

function enzyme_custom_setup_args(B, orig, gutils, mi, RT, reverse, isKWCall)
function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, mi, @nospecialize(RT), reverse::Bool, isKWCall::Bool)
ops = collect(operands(orig))
called = ops[end]
ops = ops[1:end-1]
Expand Down Expand Up @@ -207,7 +207,7 @@ function enzyme_custom_setup_args(B, orig, gutils, mi, RT, reverse, isKWCall)
return args, activity, (overwritten...,), actives, kwtup
end

function enzyme_custom_setup_ret(gutils, orig, mi, RealRt)
function enzyme_custom_setup_ret(gutils::GradientUtils, orig::LLVM.CallInst, mi, @nospecialize(RealRt))
width = get_width(gutils)
mode = get_mode(gutils)

Expand All @@ -216,7 +216,21 @@ function enzyme_custom_setup_ret(gutils, orig, mi, RealRt)
needsShadowP = Ref{UInt8}(0)
needsPrimalP = Ref{UInt8}(0)

activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, mode)
# Conditionally use the get return. This is done because EnzymeGradientUtilsGetReturnDiffeType
# calls differential use analysis to determine needsprimal/shadow. However, since now this function
# is used as part of differential use analysis, we need to avoid an ininite recursion. Thus use
# the version without differential use if actual unreachable results are not available anyways.
uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig)))-1)
activep = if API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) == 0
API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, mode)
else
actv = API.EnzymeGradientUtilsGetDiffeType(gutils, orig, false)
needsPrimalP[] = 1
if actv == API.DFT_DUP_ARG || actv == API.DFT_DUP_NONEED
needsShadowP[] = 1
end
actv
end
needsPrimal = needsPrimalP[] != 0
origNeedsPrimal = needsPrimal
_, sret, _ = get_return_info(RealRt)
Expand Down Expand Up @@ -479,7 +493,7 @@ function enzyme_custom_fwd(B, orig, gutils, normalR, shadowR)
return false
end

@inline function aug_fwd_mi(orig, gutils)
@inline function aug_fwd_mi(orig::LLVM.CallInst, gutils::GradientUtils)
width = get_width(gutils)

# 1) extract out the MI from attributes
Expand Down
4 changes: 2 additions & 2 deletions test/rrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ function reverse(config::ConfigWidth{1}, ::Const{typeof(f)}, dret::Active, tape,
end
end

function augmented_primal(::Config{false, false, 1}, func::Const{typeof(f_ip)}, ::Type{<:Const}, x::Duplicated)
function augmented_primal(::ConfigWidth{1}, func::Const{typeof(f_ip)}, ::Type{<:Const}, x::Duplicated)
v = x.val[1]
x.val[1] *= v
return AugmentedReturn(nothing, nothing, v)
end

function reverse(::Config{false, false, 1}, ::Const{typeof(f_ip)}, ::Type{<:Const}, tape, x::Duplicated)
function reverse(::ConfigWidth{1}, ::Const{typeof(f_ip)}, ::Type{<:Const}, tape, x::Duplicated)
x.dval[1] = 100 + x.dval[1] * tape
return (nothing,)
end
Expand Down

0 comments on commit fc25b4f

Please sign in to comment.