-
Notifications
You must be signed in to change notification settings - Fork 70
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add differential use handler #1377
Conversation
4a343b2
to
e30f388
Compare
EnzymeRegisterCallHandler(name, fwdhandle, revhandle) = ccall((:EnzymeRegisterCallHandler, libEnzyme), Cvoid, (Cstring, CustomAugmentedForwardPass, CustomReversePass), name, fwdhandle, revhandle) | ||
EnzymeRegisterFwdCallHandler(name, fwdhandle) = ccall((:EnzymeRegisterFwdCallHandler, libEnzyme), Cvoid, (Cstring, CustomForwardPass), name, fwdhandle) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
EnzymeRegisterCallHandler(name, fwdhandle, revhandle) = ccall((:EnzymeRegisterCallHandler, libEnzyme), Cvoid, (Cstring, CustomAugmentedForwardPass, CustomReversePass), name, fwdhandle, revhandle) | |
EnzymeRegisterFwdCallHandler(name, fwdhandle) = ccall((:EnzymeRegisterFwdCallHandler, libEnzyme), Cvoid, (Cstring, CustomForwardPass), name, fwdhandle) | |
function EnzymeRegisterCallHandler(name, fwdhandle, revhandle) | |
return ccall((:EnzymeRegisterCallHandler, libEnzyme), Cvoid, | |
(Cstring, CustomAugmentedForwardPass, CustomReversePass), name, fwdhandle, | |
revhandle) | |
end | |
function EnzymeRegisterFwdCallHandler(name, fwdhandle) | |
return ccall((:EnzymeRegisterFwdCallHandler, libEnzyme), Cvoid, | |
(Cstring, CustomForwardPass), name, fwdhandle) | |
end |
@@ -568,7 +568,7 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, | |||
end | |||
|
|||
C = EnzymeRules.Config{Bool(needsPrimal), Bool(needsShadowJL), Int(width), overwritten} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
C = EnzymeRules.Config{Bool(needsPrimal), Bool(needsShadowJL), Int(width), overwritten} | |
C = EnzymeRules.Config{Bool(needsPrimal),Bool(needsShadowJL),Int(width),overwritten} |
@@ -949,3 +949,12 @@ function enzyme_custom_rev(B, orig, gutils, tape) | |||
enzyme_custom_common_rev(#=forward=#false, B, orig, gutils, #=normalR=#C_NULL, #=shadowR=#C_NULL, #=tape=#tape) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
enzyme_custom_common_rev(#=forward=#false, B, orig, gutils, #=normalR=#C_NULL, #=shadowR=#C_NULL, #=tape=#tape) | |
enzyme_custom_common_rev(false, B, orig, gutils, C_NULL, C_NULL, tape) #=tape=# |
|
||
function enzyme_custom_diffuse(orig, gutils, val, isshadow, mode) | ||
# use default | ||
if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) && !has_aug_fwd_rule(orig, gutils) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) && !has_aug_fwd_rule(orig, gutils) | |
if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) && | |
!has_aug_fwd_rule(orig, gutils) |
end | ||
# don't use default and always require the arg | ||
return (true, false) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
end | |
end |
@@ -1155,7 +1155,18 @@ macro fwdfunc(f) | |||
)) | |||
end | |||
|
|||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
src/rules/llvmrules.jl
Outdated
:(@cfunction((OrigCI, gutils, val, shadow, mode, useDefault) -> begin | ||
res = $f(LLVM.CallInst(OrigCI), GradientUtils(gutils), LLVM.Value(val), shadow != 0, shadowR)::Tuple{Bool, Bool} | ||
unsafe_store(useDefault, UInt8(res[2])) | ||
res[1] | ||
end, UInt8, (LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, LLVM.API.LLVMValueRef, UInt8, API.CDerivativeMode, Ptr{UInt8}) | ||
)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
:(@cfunction((OrigCI, gutils, val, shadow, mode, useDefault) -> begin | |
res = $f(LLVM.CallInst(OrigCI), GradientUtils(gutils), LLVM.Value(val), shadow != 0, shadowR)::Tuple{Bool, Bool} | |
unsafe_store(useDefault, UInt8(res[2])) | |
res[1] | |
end, UInt8, (LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, LLVM.API.LLVMValueRef, UInt8, API.CDerivativeMode, Ptr{UInt8}) | |
)) | |
return :(@cfunction((OrigCI, gutils, val, shadow, mode, useDefault) -> begin | |
res = $f(LLVM.CallInst(OrigCI), GradientUtils(gutils), | |
LLVM.Value(val), shadow != 0, | |
shadowR)::Tuple{Bool,Bool} | |
unsafe_store(useDefault, UInt8(res[2])) | |
res[1] | |
end, UInt8, | |
(LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, | |
LLVM.API.LLVMValueRef, UInt8, API.CDerivativeMode, Ptr{UInt8}))) |
f42863e
to
1c93017
Compare
src/rules/llvmrules.jl
Outdated
:(@cfunction((OrigCI, gutils, val, shadow, mode, useDefault) -> begin | ||
res = $f(LLVM.CallInst(OrigCI), GradientUtils(gutils), LLVM.Value(val), shadow != 0, mode)::Tuple{Bool, Bool} | ||
unsafe_store!(useDefault, UInt8(res[2])) | ||
res[1] | ||
end, UInt8, (LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, LLVM.API.LLVMValueRef, UInt8, API.CDerivativeMode, Ptr{UInt8}) | ||
)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
:(@cfunction((OrigCI, gutils, val, shadow, mode, useDefault) -> begin | |
res = $f(LLVM.CallInst(OrigCI), GradientUtils(gutils), LLVM.Value(val), shadow != 0, mode)::Tuple{Bool, Bool} | |
unsafe_store!(useDefault, UInt8(res[2])) | |
res[1] | |
end, UInt8, (LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, LLVM.API.LLVMValueRef, UInt8, API.CDerivativeMode, Ptr{UInt8}) | |
)) | |
return :(@cfunction((OrigCI, gutils, val, shadow, mode, useDefault) -> begin | |
res = $f(LLVM.CallInst(OrigCI), GradientUtils(gutils), | |
LLVM.Value(val), shadow != 0, mode)::Tuple{Bool,Bool} | |
unsafe_store!(useDefault, UInt8(res[2])) | |
res[1] | |
end, UInt8, | |
(LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, | |
LLVM.API.LLVMValueRef, UInt8, API.CDerivativeMode, Ptr{UInt8}))) |
:(@cfunction((OrigCI, gutils, val, shadow, mode, useDefault) -> begin | ||
res = $f(LLVM.CallInst(OrigCI), GradientUtils(gutils), LLVM.Value(val), shadow != 0, mode)::Tuple{Bool, Bool} | ||
unsafe_store!(useDefault, UInt8(res[2])) | ||
UInt8(res[1]) | ||
end, UInt8, (LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, LLVM.API.LLVMValueRef, UInt8, API.CDerivativeMode, Ptr{UInt8}) | ||
)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
:(@cfunction((OrigCI, gutils, val, shadow, mode, useDefault) -> begin | |
res = $f(LLVM.CallInst(OrigCI), GradientUtils(gutils), LLVM.Value(val), shadow != 0, mode)::Tuple{Bool, Bool} | |
unsafe_store!(useDefault, UInt8(res[2])) | |
UInt8(res[1]) | |
end, UInt8, (LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, LLVM.API.LLVMValueRef, UInt8, API.CDerivativeMode, Ptr{UInt8}) | |
)) | |
return :(@cfunction((OrigCI, gutils, val, shadow, mode, useDefault) -> begin | |
res = $f(LLVM.CallInst(OrigCI), GradientUtils(gutils), | |
LLVM.Value(val), shadow != 0, mode)::Tuple{Bool,Bool} | |
unsafe_store!(useDefault, UInt8(res[2])) | |
UInt8(res[1]) | |
end, UInt8, | |
(LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, | |
LLVM.API.LLVMValueRef, UInt8, API.CDerivativeMode, Ptr{UInt8}))) |
function get_uncacheable(gutils::GradientUtils, orig::LLVM.CallInst) | ||
uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig)))-1) | ||
API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) | ||
if API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) != 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
function get_uncacheable(gutils::GradientUtils, orig::LLVM.CallInst) | |
uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig)))-1) | |
API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) | |
if API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) != 1 | |
function get_uncacheable(gutils::GradientUtils, orig::LLVM.CallInst) | |
uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig))) - 1) | |
if API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, | |
length(uncacheable)) != 1 |
@@ -1,5 +1,5 @@ | |||
|
|||
function enzyme_custom_setup_args(B, orig, gutils, mi, RT, reverse, isKWCall) | |||
function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, mi, @nospecialize(RT), reverse::Bool, isKWCall::Bool) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, mi, @nospecialize(RT), reverse::Bool, isKWCall::Bool) | |
function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, mi, | |
@nospecialize(RT), reverse::Bool, isKWCall::Bool) |
@@ -1,5 +1,5 @@ | |||
|
|||
function enzyme_custom_setup_args(B, orig, gutils, mi, RT, reverse, isKWCall) | |||
function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, mi, @nospecialize(RT), reverse::Bool, isKWCall::Bool) | |||
ops = collect(operands(orig)) | |||
called = ops[end] | |||
ops = ops[1:end-1] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
ops = ops[1:end-1] | |
ops = ops[1:(end - 1)] |
@@ -207,7 +207,7 @@ function enzyme_custom_setup_args(B, orig, gutils, mi, RT, reverse, isKWCall) | |||
return args, activity, (overwritten...,), actives, kwtup | |||
end | |||
|
|||
function enzyme_custom_setup_ret(gutils, orig, mi, RealRt) | |||
function enzyme_custom_setup_ret(gutils::GradientUtils, orig::LLVM.CallInst, mi, @nospecialize(RealRt)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
function enzyme_custom_setup_ret(gutils::GradientUtils, orig::LLVM.CallInst, mi, @nospecialize(RealRt)) | |
function enzyme_custom_setup_ret(gutils::GradientUtils, orig::LLVM.CallInst, mi, | |
@nospecialize(RealRt)) |
@@ -207,7 +207,7 @@ function enzyme_custom_setup_args(B, orig, gutils, mi, RT, reverse, isKWCall) | |||
return args, activity, (overwritten...,), actives, kwtup | |||
end | |||
|
|||
function enzyme_custom_setup_ret(gutils, orig, mi, RealRt) | |||
function enzyme_custom_setup_ret(gutils::GradientUtils, orig::LLVM.CallInst, mi, @nospecialize(RealRt)) | |||
width = get_width(gutils) | |||
mode = get_mode(gutils) | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
src/rules/customrules.jl
Outdated
uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig)))-1) | ||
activep = if API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) == 0 | ||
API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, mode) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig)))-1) | |
activep = if API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) == 0 | |
API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, mode) | |
uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig))) - 1) | |
activep = if API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, | |
length(uncacheable)) == 0 | |
API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, | |
mode) |
test/rrules.jl
Outdated
@@ -31,13 +31,13 @@ function reverse(config::ConfigWidth{1}, ::Const{typeof(f)}, dret::Active, tape, | |||
end | |||
end | |||
|
|||
function augmented_primal(::Config{false, false, 1}, func::Const{typeof(f_ip)}, ::Type{<:Const}, x::Duplicated) | |||
function augmented_primal(::ConfigWidth{1}, func::Const{typeof(f_ip)}, ::Type{<:Const}, x::Duplicated) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
function augmented_primal(::ConfigWidth{1}, func::Const{typeof(f_ip)}, ::Type{<:Const}, x::Duplicated) | |
function augmented_primal(::ConfigWidth{1}, func::Const{typeof(f_ip)}, ::Type{<:Const}, | |
x::Duplicated) |
test/rrules.jl
Outdated
v = x.val[1] | ||
x.val[1] *= v | ||
return AugmentedReturn(nothing, nothing, v) | ||
end | ||
|
||
function reverse(::Config{false, false, 1}, ::Const{typeof(f_ip)}, ::Type{<:Const}, tape, x::Duplicated) | ||
function reverse(::ConfigWidth{1}, ::Const{typeof(f_ip)}, ::Type{<:Const}, tape, x::Duplicated) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
function reverse(::ConfigWidth{1}, ::Const{typeof(f_ip)}, ::Type{<:Const}, tape, x::Duplicated) | |
function reverse(::ConfigWidth{1}, ::Const{typeof(f_ip)}, ::Type{<:Const}, tape, | |
x::Duplicated) |
src/rules/customrules.jl
Outdated
uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig)))-1) | ||
activep = if API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) == 1 | ||
API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, mode) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig)))-1) | |
activep = if API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) == 1 | |
API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, mode) | |
uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig))) - 1) | |
activep = if API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, | |
length(uncacheable)) == 1 | |
API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, | |
mode) |
uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig)))-1) | ||
activep = if mode == API.DEM_ForwardMode || API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) == 1 | ||
API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, mode) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig)))-1) | |
activep = if mode == API.DEM_ForwardMode || API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) == 1 | |
API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, mode) | |
uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig))) - 1) | |
activep = if mode == API.DEM_ForwardMode || | |
API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, | |
length(uncacheable)) == 1 | |
API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, | |
mode) |
No description provided.