Skip to content

Commit

Permalink
Allow custom rule for constant arg/ret in rev mode (#1371)
Browse files Browse the repository at this point in the history
* Allow custom rule for constant arg/ret in rev mode

* cse

* Add differential use handler

* fixup

* fix

* fix

* fixup

* fixup

* fixup
  • Loading branch information
wsmoses authored May 28, 2024
1 parent cf1851b commit 1e45f26
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 76 deletions.
4 changes: 3 additions & 1 deletion src/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,8 @@ EnzymeRegisterFwdCallHandler(name, fwdhandle) = ccall((:EnzymeRegisterFwdCallHan

# Forward to libEnzyme's `EnzymeInsertValue` using builder `B`, the two values `v` and `v2`,
# and the index list `insts` (handed over as pointer + length). The raw reference that comes
# back is wrapped in `LLVM.Value`.
function EnzymeInsertValue(B::LLVM.IRBuilder, v::LLVM.Value, v2::LLVM.Value, insts::Vector{Cuint}, name="")
    raw = ccall(
        (:EnzymeInsertValue, libEnzyme),
        LLVMValueRef,
        (LLVM.API.LLVMBuilderRef, LLVMValueRef, LLVMValueRef, Ptr{Cuint}, Int64, Cstring),
        B, v, v2, insts, length(insts), name
    )
    return LLVM.Value(raw)
end

# Handle type used for differential-use callbacks: an untyped C pointer.
const CustomDiffUse = Ptr{Cvoid}

# Hand `handle` to libEnzyme's `EnzymeRegisterDiffUseCallHandler` under the given `name`.
function EnzymeRegisterDiffUseCallHandler(name, handle)
    return ccall(
        (:EnzymeRegisterDiffUseCallHandler, libEnzyme),
        Cvoid,
        (Cstring, CustomDiffUse),
        name, handle
    )
end
# Forward to libEnzyme's `EnzymeSetCalledFunction` for call instruction `ci` and function
# `fn`; `toremove` is passed as a pointer to Int64 together with its length.
function EnzymeSetCalledFunction(ci::LLVM.CallInst, fn::LLVM.Function, toremove)
    return ccall(
        (:EnzymeSetCalledFunction, libEnzyme),
        Cvoid,
        (LLVMValueRef, LLVMValueRef, Ptr{Int64}, Int64),
        ci, fn, toremove, length(toremove)
    )
end
# Forward to libEnzyme's `EnzymeCloneFunctionWithoutReturnOrArgs`. `keepret` is passed as a
# UInt8 and `args` as a pointer to Int64 plus its length; the raw `LLVMValueRef` result is
# returned unwrapped.
function EnzymeCloneFunctionWithoutReturnOrArgs(fn::LLVM.Function, keepret, args)
    return ccall(
        (:EnzymeCloneFunctionWithoutReturnOrArgs, libEnzyme),
        LLVMValueRef,
        (LLVMValueRef, UInt8, Ptr{Int64}, Int64),
        fn, keepret, args, length(args)
    )
end
# Ask libEnzyme's `EnzymeGetShadowType` for the type associated with `T` at the given
# `width`; the raw `LLVMTypeRef` is returned unwrapped.
function EnzymeGetShadowType(width, T)
    return ccall(
        (:EnzymeGetShadowType, libEnzyme),
        LLVMTypeRef,
        (UInt64, LLVMTypeRef),
        width, T
    )
end
Expand Down Expand Up @@ -260,7 +262,7 @@ EnzymeGradientUtilsTypeAnalyzer(gutils) = ccall((:EnzymeGradientUtilsTypeAnalyze

# Forward to libEnzyme's `EnzymeGradientUtilsAllocAndGetTypeTree` for `val` within `gutils`;
# returns the resulting `CTypeTreeRef`.
function EnzymeGradientUtilsAllocAndGetTypeTree(gutils, val)
    return ccall(
        (:EnzymeGradientUtilsAllocAndGetTypeTree, libEnzyme),
        CTypeTreeRef,
        (EnzymeGradientUtilsRef, LLVMValueRef),
        gutils, val
    )
end

# Wrapper around libEnzyme's `EnzymeGradientUtilsGetUncacheableArgs`: passes the buffer
# `uncacheable` (as a pointer to UInt8) and `size` for call `orig` within `gutils`.
# NOTE(review): the two lines below define the same method and differ only in the declared
# ccall return type (`Cvoid` vs `UInt8`). This looks like the removed and added sides of a
# diff hunk — confirm that only the `UInt8` form is meant to remain.
EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, size) = ccall((:EnzymeGradientUtilsGetUncacheableArgs, libEnzyme), Cvoid, (EnzymeGradientUtilsRef,LLVMValueRef, Ptr{UInt8}, UInt64), gutils, orig, uncacheable, size)
EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, size) = ccall((:EnzymeGradientUtilsGetUncacheableArgs, libEnzyme), UInt8, (EnzymeGradientUtilsRef,LLVMValueRef, Ptr{UInt8}, UInt64), gutils, orig, uncacheable, size)

# Forward to libEnzyme's `EnzymeGradientUtilsGetDiffeType` for `op` within `gutils`;
# `isforeign` is passed as a UInt8 and the result is a `CDIFFE_TYPE`.
function EnzymeGradientUtilsGetDiffeType(gutils, op, isforeign)
    return ccall(
        (:EnzymeGradientUtilsGetDiffeType, libEnzyme),
        CDIFFE_TYPE,
        (EnzymeGradientUtilsRef, LLVMValueRef, UInt8),
        gutils, op, isforeign
    )
end

Expand Down
4 changes: 3 additions & 1 deletion src/gradientutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ function get_shadow_type(gutils::GradientUtils, T::LLVM.LLVMType)
end
# Return a `Vector{UInt8}` of uncacheable flags for the call `orig`, with one entry per
# operand of `orig` except one (the buffer length is the operand count minus 1).
# The buffer is filled by `API.EnzymeGradientUtilsGetUncacheableArgs`; when that call does
# not return 1, every entry is set to 1 instead.
# NOTE(review): the bare API call and the checked API call below are the same invocation;
# this looks like the removed and added sides of a diff hunk — confirm which should remain.
function get_uncacheable(gutils::GradientUtils, orig::LLVM.CallInst)
# Allocated uninitialized; contents come from the API call (or the fallback fill below).
uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig)))-1)
API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable))
if API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) != 1
# Fallback when the API did not report 1: mark every argument with 1.
uncacheable .= 1
end
return uncacheable
end

Expand Down
211 changes: 137 additions & 74 deletions src/rules/customrules.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

function enzyme_custom_setup_args(B, orig, gutils, mi, RT, reverse, isKWCall)
function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, mi, @nospecialize(RT), reverse::Bool, isKWCall::Bool)
ops = collect(operands(orig))
called = ops[end]
ops = ops[1:end-1]
Expand Down Expand Up @@ -46,6 +46,7 @@ function enzyme_custom_setup_args(B, orig, gutils, mi, RT, reverse, isKWCall)
if !(isKWCall && arg.arg_i == 1)
push!(overwritten, false)
end
if B !== nothing
if Core.Compiler.isconstType(arg.typ) && !Core.Compiler.isconstType(Const{arg.typ})
llty = convert(LLVMType, Const{arg.typ})
al0 = al = emit_allocobj!(B, Const{arg.typ})
Expand All @@ -63,6 +64,7 @@ function enzyme_custom_setup_args(B, orig, gutils, mi, RT, reverse, isKWCall)
else
@assert isghostty(Const{arg.typ}) || Core.Compiler.isconstType(Const{arg.typ})
end
end
continue
end
@assert !(isghostty(arg.typ) || Core.Compiler.isconstType(arg.typ))
Expand All @@ -74,7 +76,7 @@ function enzyme_custom_setup_args(B, orig, gutils, mi, RT, reverse, isKWCall)
end

val = new_from_original(gutils, op)
if reverse
if reverse && B !== nothing
val = lookup_value(gutils, val, B)
end

Expand All @@ -100,51 +102,57 @@ function enzyme_custom_setup_args(B, orig, gutils, mi, RT, reverse, isKWCall)
Ty = Const{arg.typ}
llty = convert(LLVMType, Ty)
arty = convert(LLVMType, arg.typ; allow_boxed=true)
al0 = al = emit_allocobj!(B, Ty)
al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al))))
al = addrspacecast!(B, al, LLVM.PointerType(llty, Derived))
if B !== nothing
al0 = al = emit_allocobj!(B, Ty)
al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al))))
al = addrspacecast!(B, al, LLVM.PointerType(llty, Derived))

ptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 0)])
if value_type(val) != eltype(value_type(ptr))
val = load!(B, arty, val)
end
store!(B, val, ptr)
ptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 0)])
if value_type(val) != eltype(value_type(ptr))
val = load!(B, arty, val)
end
store!(B, val, ptr)

if any_jltypes(llty)
emit_writebarrier!(B, get_julia_inner_types(B, al0, val))
end
if any_jltypes(llty)
emit_writebarrier!(B, get_julia_inner_types(B, al0, val))
end

push!(args, al)
push!(args, al)
end

push!(activity, Ty)

elseif activep == API.DFT_OUT_DIFF || (mode != API.DEM_ForwardMode && active_reg(arg.typ, world) )
Ty = Active{arg.typ}
llty = convert(LLVMType, Ty)
arty = convert(LLVMType, arg.typ; allow_boxed=true)
al0 = al = emit_allocobj!(B, Ty)
al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al))))
al = addrspacecast!(B, al, LLVM.PointerType(llty, Derived))
if B !== nothing
al0 = al = emit_allocobj!(B, Ty)
al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al))))
al = addrspacecast!(B, al, LLVM.PointerType(llty, Derived))

ptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 0)])
if value_type(val) != eltype(value_type(ptr))
@assert !overwritten[end]
val = load!(B, arty, val)
end
store!(B, val, ptr)
ptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 0)])
if value_type(val) != eltype(value_type(ptr))
@assert !overwritten[end]
val = load!(B, arty, val)
end
store!(B, val, ptr)

if any_jltypes(llty)
emit_writebarrier!(B, get_julia_inner_types(B, al0, val))
end
if any_jltypes(llty)
emit_writebarrier!(B, get_julia_inner_types(B, al0, val))
end

push!(args, al)
push!(args, al)
end

push!(activity, Ty)
push!(actives, op)
else
ival = invert_pointer(gutils, op, B)
if reverse
ival = lookup_value(gutils, ival, B)
if B !== nothing
ival = invert_pointer(gutils, op, B)
if reverse
ival = lookup_value(gutils, ival, B)
end
end
if width == 1
if activep == API.DFT_DUP_ARG
Expand All @@ -165,39 +173,41 @@ function enzyme_custom_setup_args(B, orig, gutils, mi, RT, reverse, isKWCall)
llty = convert(LLVMType, Ty)
arty = convert(LLVMType, arg.typ; allow_boxed=true)
sarty = LLVM.LLVMType(API.EnzymeGetShadowType(width, arty))
al0 = al = emit_allocobj!(B, Ty)
al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al))))
al = addrspacecast!(B, al, LLVM.PointerType(llty, Derived))
if B !== nothing
al0 = al = emit_allocobj!(B, Ty)
al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al))))
al = addrspacecast!(B, al, LLVM.PointerType(llty, Derived))

ptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 0)])
if value_type(val) != eltype(value_type(ptr))
val = load!(B, arty, val)
ptr_val = ival
ival = UndefValue(sarty)
for idx in 1:width
ev = (width == 1) ? ptr_val : extract_value!(B, ptr_val, idx-1)
ld = load!(B, arty, ev)
ival = (width == 1 ) ? ld : insert_value!(B, ival, ld, idx-1)
ptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 0)])
if value_type(val) != eltype(value_type(ptr))
val = load!(B, arty, val)
ptr_val = ival
ival = UndefValue(sarty)
for idx in 1:width
ev = (width == 1) ? ptr_val : extract_value!(B, ptr_val, idx-1)
ld = load!(B, arty, ev)
ival = (width == 1 ) ? ld : insert_value!(B, ival, ld, idx-1)
end
end
end
store!(B, val, ptr)
store!(B, val, ptr)

iptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 1)])
store!(B, ival, iptr)
iptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 1)])
store!(B, ival, iptr)

if any_jltypes(llty)
emit_writebarrier!(B, get_julia_inner_types(B, al0, val, ival))
end
if any_jltypes(llty)
emit_writebarrier!(B, get_julia_inner_types(B, al0, val, ival))
end

push!(args, al)
push!(args, al)
end
push!(activity, Ty)
end

end
return args, activity, (overwritten...,), actives, kwtup
end

function enzyme_custom_setup_ret(gutils, orig, mi, RealRt)
function enzyme_custom_setup_ret(gutils::GradientUtils, orig::LLVM.CallInst, mi, @nospecialize(RealRt))
width = get_width(gutils)
mode = get_mode(gutils)

Expand All @@ -206,7 +216,23 @@ function enzyme_custom_setup_ret(gutils, orig, mi, RealRt)
needsShadowP = Ref{UInt8}(0)
needsPrimalP = Ref{UInt8}(0)

activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, mode)
# Conditionally use the get-return helper. EnzymeGradientUtilsGetReturnDiffeType calls
# differential use analysis to determine needsPrimal/needsShadow. However, since this function
# is now itself used as part of differential use analysis, we need to avoid an infinite recursion.
# Thus use the version without differential use if actual uncacheable results are not available anyway.
uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig)))-1)
activep = if mode == API.DEM_ForwardMode || API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) == 1
API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, mode)
else
actv = API.EnzymeGradientUtilsGetDiffeType(gutils, orig, false)
if !isghostty(RealRt)
needsPrimalP[] = 1
if actv == API.DFT_DUP_ARG || actv == API.DFT_DUP_NONEED
needsShadowP[] = 1
end
end
actv
end
needsPrimal = needsPrimalP[] != 0
origNeedsPrimal = needsPrimal
_, sret, _ = get_return_info(RealRt)
Expand Down Expand Up @@ -349,7 +375,7 @@ function enzyme_custom_fwd(B, orig, gutils, normalR, shadowR)
end

if length(args) != length(parameters(llvmf))
GPUCompiler.@safe_error "Calling convention mismatch", args, llvmf, orig, isKWCall, kwtup, TT, sret, returnRoots
GPUCompiler.@safe_error "Calling convention mismatch", args, llvmf, string(value_type(llvmf)), orig, isKWCall, kwtup, TT, sret, returnRoots
return false
end

Expand Down Expand Up @@ -476,19 +502,9 @@ function enzyme_custom_fwd(B, orig, gutils, normalR, shadowR)
return false
end

function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, normalR, shadowR, tape)::LLVM.API.LLVMValueRef

ctx = LLVM.context(orig)

@inline function aug_fwd_mi(orig::LLVM.CallInst, gutils::GradientUtils, forward=false, B=nothing)
width = get_width(gutils)

shadowType = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))
if shadowR != C_NULL
unsafe_store!(shadowR,UndefValue(shadowType).ref)
end

# TODO: don't inject the code multiple times for multiple calls

# 1) extract out the MI from attributes
mi, RealRt = enzyme_custom_extract_mi(orig)
isKWCall = isKWCallSignature(mi.specTypes)
Expand All @@ -503,11 +519,7 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils,
needsShadow
end

alloctx = LLVM.IRBuilder()
position!(alloctx, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils)))

curent_bb = position(B)
fn = LLVM.parent(curent_bb)
fn = LLVM.parent(LLVM.parent(orig))
world = enzyme_extract_world(fn)

C = EnzymeRules.Config{Bool(needsPrimal), Bool(needsShadowJL), Int(width), overwritten}
Expand Down Expand Up @@ -554,13 +566,55 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils,
end
end
end
return ami, augprimal_TT, (args, activity, overwritten, actives, kwtup, RT, needsPrimal, needsShadow, origNeedsPrimal)
end

# Report whether `aug_fwd_mi(orig, gutils)` yields a first result other than `nothing`,
# i.e. whether an augmented-forward rule was found for `orig`.
@inline function has_aug_fwd_rule(orig, gutils)
    rule = first(aug_fwd_mi(orig, gutils))
    return !isnothing(rule)
end

function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, normalR, shadowR, tape)::LLVM.API.LLVMValueRef

ctx = LLVM.context(orig)

width = get_width(gutils)

shadowType = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))
if shadowR != C_NULL
unsafe_store!(shadowR,UndefValue(shadowType).ref)
end

# TODO: don't inject the code multiple times for multiple calls

# 1) extract out the MI from attributes
mi, RealRt = enzyme_custom_extract_mi(orig)
isKWCall = isKWCallSignature(mi.specTypes)

# 2) Create activity, and annotate function spec
ami, augprimal_TT, setup = aug_fwd_mi(orig, gutils, forward, B)
args, activity, overwritten, actives, kwtup, RT, needsPrimal, needsShadow, origNeedsPrimal = setup

needsShadowJL = if RT <: Active
false
else
needsShadow
end

C = EnzymeRules.Config{Bool(needsPrimal), Bool(needsShadowJL), Int(width), overwritten}

alloctx = LLVM.IRBuilder()
position!(alloctx, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils)))

curent_bb = position(B)
fn = LLVM.parent(curent_bb)
world = enzyme_extract_world(fn)

mode = get_mode(gutils)

@assert ami !== nothing
target = DefaultCompilerTarget()
params = PrimalCompilerParams(mode)
aug_RT = something(Core.Compiler.typeinf_type(GPUCompiler.get_interpreter(CompilerJob(ami, CompilerConfig(target, params; kernel=false), world)), ami.def, ami.specTypes, ami.sparam_vals), Any)

@assert ami !== nothing

if kwtup !== nothing && kwtup <: Duplicated
@safe_debug "Non-constant keyword argument found for " augprimal_TT
emit_error(B, orig, "Enzyme: Non-constant keyword argument found for " * string(augprimal_TT))
Expand Down Expand Up @@ -904,7 +958,7 @@ end


function enzyme_custom_augfwd(B, orig, gutils, normalR, shadowR, tapeR)
if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig)
if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) && !has_aug_fwd_rule(orig, gutils)
return true
end
tape = enzyme_custom_common_rev(#=forward=#true, B, orig, gutils, normalR, shadowR, #=tape=#nothing)
Expand All @@ -916,9 +970,18 @@ end


# Reverse-pass handler: returns early (with `nothing`) when `orig` is both a constant value
# and a constant instruction; otherwise delegates to `enzyme_custom_common_rev` with
# `forward=false`, null `normalR`/`shadowR`, and the supplied `tape`, then returns `nothing`.
# NOTE(review): two `if` headers appear back to back but only one `end` closes them. This
# looks like the removed and added sides of a diff hunk — the second header additionally
# requires `!has_aug_fwd_rule(orig, gutils)`; confirm that only that one should remain.
function enzyme_custom_rev(B, orig, gutils, tape)
if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig)
if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) && !has_aug_fwd_rule(orig, gutils)
return
end
enzyme_custom_common_rev(#=forward=#false, B, orig, gutils, #=normalR=#C_NULL, #=shadowR=#C_NULL, #=tape=#tape)
return nothing
end

# Differential-use callback for custom-rule calls. Returns a `(result, use_default)` pair of
# Bools. `val`, `isshadow` and `mode` are accepted but not inspected here.
function enzyme_custom_diffuse(orig, gutils, val, isshadow, mode)
    inactive = is_constant_value(gutils, orig) && is_constant_inst(gutils, orig)
    # The rule lookup only runs when the call is otherwise inactive (short-circuit).
    if !inactive || has_aug_fwd_rule(orig, gutils)
        # don't use default and always require the arg
        return (true, false)
    end
    # use default
    return (false, true)
end
11 changes: 11 additions & 0 deletions src/rules/llvmrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1172,7 +1172,18 @@ macro fwdfunc(f)
))
end


# Build a C-callable pointer (via `@cfunction`) that adapts the Julia handler `f` to a
# callback taking (call ref, gradient-utils ref, value ref, UInt8 shadow flag, derivative
# mode, Ptr{UInt8} out-parameter) and returning UInt8.
# The generated callback wraps the raw refs as `LLVM.CallInst`, `GradientUtils` and
# `LLVM.Value`, converts the shadow flag to Bool with `shadow != 0`, and calls `f`, whose
# result is asserted to be a `Tuple{Bool, Bool}`. The second tuple element is written through
# the `useDefault` pointer as a UInt8; the first is returned as a UInt8.
macro diffusefunc(f)
:(@cfunction((OrigCI, gutils, val, shadow, mode, useDefault) -> begin
res = $f(LLVM.CallInst(OrigCI), GradientUtils(gutils), LLVM.Value(val), shadow != 0, mode)::Tuple{Bool, Bool}
unsafe_store!(useDefault, UInt8(res[2]))
UInt8(res[1])
end, UInt8, (LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, LLVM.API.LLVMValueRef, UInt8, API.CDerivativeMode, Ptr{UInt8})
))
end

@noinline function register_llvm_rules()
API.EnzymeRegisterDiffUseCallHandler("enzyme_custom", @diffusefunc(enzyme_custom_diffuse))
register_handler!(
("julia.call",),
@augfunc(jlcall_augfwd),
Expand Down

2 comments on commit 1e45f26

@wsmoses
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/107816

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.12.9 -m "<description of version>" 1e45f264dbd2dacd79686c891f2d8c42ead33fce
git push origin v0.12.9

Please sign in to comment.