Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add differential use handler #1377

Merged
merged 3 commits into from
Apr 6, 2024
Merged

Add differential use handler #1377

merged 3 commits into from
Apr 6, 2024

Conversation

wsmoses
Copy link
Member

@wsmoses wsmoses commented Apr 1, 2024

No description provided.

@wsmoses wsmoses requested a review from vchuravy April 1, 2024 02:13
@wsmoses wsmoses force-pushed the diffusehandler branch 2 times, most recently from 4a343b2 to e30f388 Compare April 1, 2024 03:18
Comment on lines 230 to 231
EnzymeRegisterCallHandler(name, fwdhandle, revhandle) = ccall((:EnzymeRegisterCallHandler, libEnzyme), Cvoid, (Cstring, CustomAugmentedForwardPass, CustomReversePass), name, fwdhandle, revhandle)
EnzymeRegisterFwdCallHandler(name, fwdhandle) = ccall((:EnzymeRegisterFwdCallHandler, libEnzyme), Cvoid, (Cstring, CustomForwardPass), name, fwdhandle)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
EnzymeRegisterCallHandler(name, fwdhandle, revhandle) = ccall((:EnzymeRegisterCallHandler, libEnzyme), Cvoid, (Cstring, CustomAugmentedForwardPass, CustomReversePass), name, fwdhandle, revhandle)
EnzymeRegisterFwdCallHandler(name, fwdhandle) = ccall((:EnzymeRegisterFwdCallHandler, libEnzyme), Cvoid, (Cstring, CustomForwardPass), name, fwdhandle)
function EnzymeRegisterCallHandler(name, fwdhandle, revhandle)
return ccall((:EnzymeRegisterCallHandler, libEnzyme), Cvoid,
(Cstring, CustomAugmentedForwardPass, CustomReversePass), name, fwdhandle,
revhandle)
end
function EnzymeRegisterFwdCallHandler(name, fwdhandle)
return ccall((:EnzymeRegisterFwdCallHandler, libEnzyme), Cvoid,
(Cstring, CustomForwardPass), name, fwdhandle)
end

@@ -568,7 +568,7 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils,
end

C = EnzymeRules.Config{Bool(needsPrimal), Bool(needsShadowJL), Int(width), overwritten}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
C = EnzymeRules.Config{Bool(needsPrimal), Bool(needsShadowJL), Int(width), overwritten}
C = EnzymeRules.Config{Bool(needsPrimal),Bool(needsShadowJL),Int(width),overwritten}

@@ -949,3 +949,12 @@ function enzyme_custom_rev(B, orig, gutils, tape)
enzyme_custom_common_rev(#=forward=#false, B, orig, gutils, #=normalR=#C_NULL, #=shadowR=#C_NULL, #=tape=#tape)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
enzyme_custom_common_rev(#=forward=#false, B, orig, gutils, #=normalR=#C_NULL, #=shadowR=#C_NULL, #=tape=#tape)
enzyme_custom_common_rev(false, B, orig, gutils, C_NULL, C_NULL, tape) #=tape=#


function enzyme_custom_diffuse(orig, gutils, val, isshadow, mode)
# use default
if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) && !has_aug_fwd_rule(orig, gutils)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) && !has_aug_fwd_rule(orig, gutils)
if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) &&
!has_aug_fwd_rule(orig, gutils)

end
# don't use default and always require the arg
return (true, false)
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
end
end

@@ -1155,7 +1155,18 @@ macro fwdfunc(f)
))
end


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

Comment on lines 1160 to 1165
:(@cfunction((OrigCI, gutils, val, shadow, mode, useDefault) -> begin
res = $f(LLVM.CallInst(OrigCI), GradientUtils(gutils), LLVM.Value(val), shadow != 0, shadowR)::Tuple{Bool, Bool}
unsafe_store(useDefault, UInt8(res[2]))
res[1]
end, UInt8, (LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, LLVM.API.LLVMValueRef, UInt8, API.CDerivativeMode, Ptr{UInt8})
))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
:(@cfunction((OrigCI, gutils, val, shadow, mode, useDefault) -> begin
res = $f(LLVM.CallInst(OrigCI), GradientUtils(gutils), LLVM.Value(val), shadow != 0, shadowR)::Tuple{Bool, Bool}
unsafe_store(useDefault, UInt8(res[2]))
res[1]
end, UInt8, (LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, LLVM.API.LLVMValueRef, UInt8, API.CDerivativeMode, Ptr{UInt8})
))
return :(@cfunction((OrigCI, gutils, val, shadow, mode, useDefault) -> begin
res = $f(LLVM.CallInst(OrigCI), GradientUtils(gutils),
LLVM.Value(val), shadow != 0,
shadowR)::Tuple{Bool,Bool}
unsafe_store(useDefault, UInt8(res[2]))
res[1]
end, UInt8,
(LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef,
LLVM.API.LLVMValueRef, UInt8, API.CDerivativeMode, Ptr{UInt8})))

@wsmoses wsmoses force-pushed the diffusehandler branch 2 times, most recently from f42863e to 1c93017 Compare April 1, 2024 03:29
Comment on lines 1160 to 1165
:(@cfunction((OrigCI, gutils, val, shadow, mode, useDefault) -> begin
res = $f(LLVM.CallInst(OrigCI), GradientUtils(gutils), LLVM.Value(val), shadow != 0, mode)::Tuple{Bool, Bool}
unsafe_store!(useDefault, UInt8(res[2]))
res[1]
end, UInt8, (LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, LLVM.API.LLVMValueRef, UInt8, API.CDerivativeMode, Ptr{UInt8})
))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
:(@cfunction((OrigCI, gutils, val, shadow, mode, useDefault) -> begin
res = $f(LLVM.CallInst(OrigCI), GradientUtils(gutils), LLVM.Value(val), shadow != 0, mode)::Tuple{Bool, Bool}
unsafe_store!(useDefault, UInt8(res[2]))
res[1]
end, UInt8, (LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, LLVM.API.LLVMValueRef, UInt8, API.CDerivativeMode, Ptr{UInt8})
))
return :(@cfunction((OrigCI, gutils, val, shadow, mode, useDefault) -> begin
res = $f(LLVM.CallInst(OrigCI), GradientUtils(gutils),
LLVM.Value(val), shadow != 0, mode)::Tuple{Bool,Bool}
unsafe_store!(useDefault, UInt8(res[2]))
res[1]
end, UInt8,
(LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef,
LLVM.API.LLVMValueRef, UInt8, API.CDerivativeMode, Ptr{UInt8})))

Comment on lines +1160 to +1165
:(@cfunction((OrigCI, gutils, val, shadow, mode, useDefault) -> begin
res = $f(LLVM.CallInst(OrigCI), GradientUtils(gutils), LLVM.Value(val), shadow != 0, mode)::Tuple{Bool, Bool}
unsafe_store!(useDefault, UInt8(res[2]))
UInt8(res[1])
end, UInt8, (LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, LLVM.API.LLVMValueRef, UInt8, API.CDerivativeMode, Ptr{UInt8})
))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
:(@cfunction((OrigCI, gutils, val, shadow, mode, useDefault) -> begin
res = $f(LLVM.CallInst(OrigCI), GradientUtils(gutils), LLVM.Value(val), shadow != 0, mode)::Tuple{Bool, Bool}
unsafe_store!(useDefault, UInt8(res[2]))
UInt8(res[1])
end, UInt8, (LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, LLVM.API.LLVMValueRef, UInt8, API.CDerivativeMode, Ptr{UInt8})
))
return :(@cfunction((OrigCI, gutils, val, shadow, mode, useDefault) -> begin
res = $f(LLVM.CallInst(OrigCI), GradientUtils(gutils),
LLVM.Value(val), shadow != 0, mode)::Tuple{Bool,Bool}
unsafe_store!(useDefault, UInt8(res[2]))
UInt8(res[1])
end, UInt8,
(LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef,
LLVM.API.LLVMValueRef, UInt8, API.CDerivativeMode, Ptr{UInt8})))

Comment on lines 25 to +27
function get_uncacheable(gutils::GradientUtils, orig::LLVM.CallInst)
uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig)))-1)
API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable))
if API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) != 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
function get_uncacheable(gutils::GradientUtils, orig::LLVM.CallInst)
uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig)))-1)
API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable))
if API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) != 1
function get_uncacheable(gutils::GradientUtils, orig::LLVM.CallInst)
uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig))) - 1)
if API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable,
length(uncacheable)) != 1

@@ -1,5 +1,5 @@

function enzyme_custom_setup_args(B, orig, gutils, mi, RT, reverse, isKWCall)
function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, mi, @nospecialize(RT), reverse::Bool, isKWCall::Bool)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, mi, @nospecialize(RT), reverse::Bool, isKWCall::Bool)
function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, mi,
@nospecialize(RT), reverse::Bool, isKWCall::Bool)

@@ -1,5 +1,5 @@

function enzyme_custom_setup_args(B, orig, gutils, mi, RT, reverse, isKWCall)
function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, mi, @nospecialize(RT), reverse::Bool, isKWCall::Bool)
ops = collect(operands(orig))
called = ops[end]
ops = ops[1:end-1]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
ops = ops[1:end-1]
ops = ops[1:(end - 1)]

@@ -207,7 +207,7 @@ function enzyme_custom_setup_args(B, orig, gutils, mi, RT, reverse, isKWCall)
return args, activity, (overwritten...,), actives, kwtup
end

function enzyme_custom_setup_ret(gutils, orig, mi, RealRt)
function enzyme_custom_setup_ret(gutils::GradientUtils, orig::LLVM.CallInst, mi, @nospecialize(RealRt))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
function enzyme_custom_setup_ret(gutils::GradientUtils, orig::LLVM.CallInst, mi, @nospecialize(RealRt))
function enzyme_custom_setup_ret(gutils::GradientUtils, orig::LLVM.CallInst, mi,
@nospecialize(RealRt))

@@ -207,7 +207,7 @@ function enzyme_custom_setup_args(B, orig, gutils, mi, RT, reverse, isKWCall)
return args, activity, (overwritten...,), actives, kwtup
end

function enzyme_custom_setup_ret(gutils, orig, mi, RealRt)
function enzyme_custom_setup_ret(gutils::GradientUtils, orig::LLVM.CallInst, mi, @nospecialize(RealRt))
width = get_width(gutils)
mode = get_mode(gutils)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

Comment on lines 223 to 225
uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig)))-1)
activep = if API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) == 0
API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, mode)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig)))-1)
activep = if API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) == 0
API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, mode)
uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig))) - 1)
activep = if API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable,
length(uncacheable)) == 0
API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP,
mode)

test/rrules.jl Outdated
@@ -31,13 +31,13 @@ function reverse(config::ConfigWidth{1}, ::Const{typeof(f)}, dret::Active, tape,
end
end

function augmented_primal(::Config{false, false, 1}, func::Const{typeof(f_ip)}, ::Type{<:Const}, x::Duplicated)
function augmented_primal(::ConfigWidth{1}, func::Const{typeof(f_ip)}, ::Type{<:Const}, x::Duplicated)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
function augmented_primal(::ConfigWidth{1}, func::Const{typeof(f_ip)}, ::Type{<:Const}, x::Duplicated)
function augmented_primal(::ConfigWidth{1}, func::Const{typeof(f_ip)}, ::Type{<:Const},
x::Duplicated)

test/rrules.jl Outdated
v = x.val[1]
x.val[1] *= v
return AugmentedReturn(nothing, nothing, v)
end

function reverse(::Config{false, false, 1}, ::Const{typeof(f_ip)}, ::Type{<:Const}, tape, x::Duplicated)
function reverse(::ConfigWidth{1}, ::Const{typeof(f_ip)}, ::Type{<:Const}, tape, x::Duplicated)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
function reverse(::ConfigWidth{1}, ::Const{typeof(f_ip)}, ::Type{<:Const}, tape, x::Duplicated)
function reverse(::ConfigWidth{1}, ::Const{typeof(f_ip)}, ::Type{<:Const}, tape,
x::Duplicated)

Comment on lines 223 to 225
uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig)))-1)
activep = if API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) == 1
API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, mode)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig)))-1)
activep = if API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) == 1
API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, mode)
uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig))) - 1)
activep = if API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable,
length(uncacheable)) == 1
API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP,
mode)

Comment on lines +223 to +225
uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig)))-1)
activep = if mode == API.DEM_ForwardMode || API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) == 1
API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, mode)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig)))-1)
activep = if mode == API.DEM_ForwardMode || API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) == 1
API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, mode)
uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig))) - 1)
activep = if mode == API.DEM_ForwardMode ||
API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable,
length(uncacheable)) == 1
API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP,
mode)

@wsmoses wsmoses merged commit 981be77 into customconst Apr 6, 2024
6 of 47 checks passed
@wsmoses wsmoses deleted the diffusehandler branch April 6, 2024 14:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant