Skip to content

Commit

Permalink
more mixed duplicated
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jun 12, 2024
1 parent 985986d commit 971194f
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 14 deletions.
3 changes: 3 additions & 0 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ export Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated,
import EnzymeCore: BatchDuplicatedFunc
export BatchDuplicatedFunc

import EnzymeCore: MixedDuplicated, BatchMixedDuplicated
export MixedDuplicated, BatchMixedDuplicated

import EnzymeCore: batch_size, get_func
export batch_size, get_func

Expand Down
44 changes: 44 additions & 0 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2450,6 +2450,50 @@ else
end
end

function store_nonjl_types!(B, p, startval)
T_jlvalue = LLVM.StructType(LLVMType[])
T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked)
vals = LLVM.Value[]
if p != nothing
push!(vals, p)
end
todo = Tuple{Tuple, LLVM.Value}[((), startval)]
while length(todo) != 0
path, cur = popfirst!(todo)
ty = value_type(cur)
if isa(ty, LLVM.PointerType)
if any_jltypes(ty)
continue
end
end
if isa(ty, LLVM.ArrayType)
if any_jltypes(ty)
for i=1:length(ty)
ev = extract_value!(B, cur, i-1)
push!(todo, ((path..., i-1), ev))
end
continue
end
end
if isa(ty, LLVM.StructType)
for (i, t) in enumerate(LLVM.elements(ty))
if any_jltypes(t)
ev = extract_value!(B, cur, i-1)
push!(todo, ((path..., i-1), ev))
end
continue
end
end
parray = LLVM.Value[LLVM.ConstantInt(LLVM.IntType(64), 0)]
for v in path
push!(parray, LLVM.ConstantInt(LLVM.IntType(32), v))
end
gptr = gep!(B, p, parray)
store!(B, cur, gptr)
end
return
end

function get_julia_inner_types(B, p, startvals...; added=[])
T_jlvalue = LLVM.StructType(LLVMType[])
T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked)
Expand Down
63 changes: 54 additions & 9 deletions src/rules/customrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils,

actives = LLVM.Value[]

mixeds = Tuple{LLVM.Value, Type, LLVM.Value}[]
uncacheable = get_uncacheable(gutils, orig)
mode = get_mode(gutils)

Expand Down Expand Up @@ -126,7 +127,7 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils,
Ty = Active{arg.typ}
llty = convert(LLVMType, Ty)
arty = convert(LLVMType, arg.typ; allow_boxed=true)
if B !== nothings
if B !== nothing
al0 = al = emit_allocobj!(B, Ty)
al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al))))
al = addrspacecast!(B, al, LLVM.PointerType(llty, Derived))
Expand Down Expand Up @@ -155,12 +156,14 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils,
end
end
shadowty = arg.typ
mixed = false
if width == 1

if active_reg_inner(arg.typ, (), world) == MixedState
# TODO mixedupnoneed
shadowty = Base.RefValue{shadowty}
Ty = MixedDuplicated{arg.typ}
mixed = true
else
if activep == API.DFT_DUP_ARG
Ty = Duplicated{arg.typ}
Expand All @@ -174,6 +177,7 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils,
# TODO batchmixedupnoneed
shadowty = Base.RefValue{shadowty}
Ty = BatchMixedDuplicated{arg.typ, Int(width)}
mixed = true
else
if activep == API.DFT_DUP_ARG
Ty = BatchDuplicated{arg.typ, Int(width)}
Expand All @@ -195,21 +199,46 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils,
al = addrspacecast!(B, al, LLVM.PointerType(llty, Derived))

ptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 0)])
needsload = false
if value_type(val) != eltype(value_type(ptr))
val = load!(B, arty, val)
if !mixed
ptr_val = ival
ival = UndefValue(siarty)
for idx in 1:width
ev = (width == 1) ? ptr_val : extract_value!(B, ptr_val, idx-1)
ld = load!(B, iarty, ev)
ival = (width == 1 ) ? ld : insert_value!(B, ival, ld, idx-1)
end
end
needsload = true
end
store!(B, val, ptr)

iptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 1)])

if value_type(ival) != eltype(value_type(iptr))
if mixed
RefTy = arg.typ
if width != 1
RefTy = NTuple{N, RefTy}
end
llrty = convert(LLVMType, RefTy)
RefTy = Base.RefValue{RefTy}
refal0 = refal = emit_allocobj!(B, RefTy)
refal = bitcast!(B, refal, LLVM.PointerType(llrty, addrspace(value_type(refal))))

@assert needsload
ptr_val = ival
ival = UndefValue(siarty)
ival = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, llrty)))
for idx in 1:width
ev = (width == 1) ? ptr_val : extract_value!(B, ptr_val, idx-1)
ld = load!(B, iarty, ev)
ld = load!(B, llrty, ev)
ival = (width == 1 ) ? ld : insert_value!(B, ival, ld, idx-1)
end
store!(B, ival, refal)
emit_writebarrier!(B, get_julia_inner_types(B, refal0, ival))
ival = refal0
push!(mixeds, (ptr_val, arg.typ, refal))
end

store!(B, ival, iptr)
Expand All @@ -224,7 +253,7 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils,
end

end
return args, activity, (overwritten...,), actives, kwtup
return args, activity, (overwritten...,), actives, kwtup, mixeds
end

function enzyme_custom_setup_ret(gutils::GradientUtils, orig::LLVM.CallInst, mi, @nospecialize(RealRt), B)
Expand Down Expand Up @@ -321,7 +350,7 @@ function enzyme_custom_fwd(B, orig, gutils, normalR, shadowR)
end

# 2) Create activity, and annotate function spec
args, activity, overwritten, actives, kwtup = enzyme_custom_setup_args(B, orig, gutils, mi, RealRt, #=reverse=#false, isKWCall)
args, activity, overwritten, actives, kwtup, _ = enzyme_custom_setup_args(B, orig, gutils, mi, RealRt, #=reverse=#false, isKWCall)
RT, needsPrimal, needsShadow, origNeedsPrimal = enzyme_custom_setup_ret(gutils, orig, mi, RealRt, B)

alloctx = LLVM.IRBuilder()
Expand Down Expand Up @@ -534,7 +563,7 @@ end
isKWCall = isKWCallSignature(mi.specTypes)

# 2) Create activity, and annotate function spec
args, activity, overwritten, actives, kwtup = enzyme_custom_setup_args(B, orig, gutils, mi, RealRt, #=reverse=#!forward, isKWCall)
args, activity, overwritten, actives, kwtup, mixeds = enzyme_custom_setup_args(B, orig, gutils, mi, RealRt, #=reverse=#!forward, isKWCall)
RT, needsPrimal, needsShadow, origNeedsPrimal = enzyme_custom_setup_ret(gutils, orig, mi, RealRt, B)

needsShadowJL = if RT <: Active
Expand Down Expand Up @@ -590,7 +619,7 @@ end
end
end
end
return ami, augprimal_TT, (args, activity, overwritten, actives, kwtup, RT, needsPrimal, needsShadow, origNeedsPrimal)
return ami, augprimal_TT, (args, activity, overwritten, actives, kwtup, RT, needsPrimal, needsShadow, origNeedsPrimal, mixeds)
end

@inline function has_aug_fwd_rule(orig, gutils)
Expand All @@ -616,7 +645,7 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils,

# 2) Create activity, and annotate function spec
ami, augprimal_TT, setup = aug_fwd_mi(orig, gutils, forward, B)
args, activity, overwritten, actives, kwtup, RT, needsPrimal, needsShadow, origNeedsPrimal = setup
args, activity, overwritten, actives, kwtup, RT, needsPrimal, needsShadow, origNeedsPrimal, mixeds = setup

needsShadowJL = if RT <: Active
false
Expand Down Expand Up @@ -987,6 +1016,22 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils,
end
idx+=1
end

# @show mixeds
for (ptr_val, argTyp, refal) in mixeds
RefTy = argTyp
if width != 1
RefTy = NTuple{N, RefTy}
end
curs = load!(B, convert(LLVMType, RefTy), refal)

for idx in 1:width
evp = (width == 1) ? ptr_val : extract_value!(B, ptr_val, idx-1)
evcur = (width == 1) ? curs : extract_value!(B, curs, idx-1)
store_nonjl_types!(B, evcur, evp)
end
@show curs, ptr_val, argTyp, refal
end
end

if forward
Expand Down
35 changes: 30 additions & 5 deletions test/mixedrrule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ using Enzyme
using Enzyme: EnzymeRules
using Test

Enzyme.API.printall!(true)

import .EnzymeRules: augmented_primal, reverse, Annotation, has_rrule_from_sig
using .EnzymeRules

function mixfnc(tup)
return tup[1] * tup[2][1]
end
Expand All @@ -27,18 +32,38 @@ function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{typeof
return AugmentedReturn(primal, nothing, vec)
end

function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{Closure},
@inline function guaranteed_nonactive(::Type{T}) where T
rt = Enzyme.Compiler.active_reg_inner(T, (), nothing)
return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState
end

function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{typeof(mixfnc)},
dret::Active, tape, tup)
dargs = 7 * tup[1].val * dret.val + tape[1] * 1000
return (dres,)
prev = tup.dval[]
dRT = typeof(prev)
@show "rev", tup
@show dRT, fieldcount(dRT)
tup.dval[] = Enzyme.Compiler.splatnew(dRT, ntuple(Val(fieldcount(dRT))) do i
Base.@_inline_meta
pv = getfield(prev, i)
if i == 1
next = 7 * tape[1] * dret.val
Enzyme.Compiler.recursive_add(pv, next, identity, guaranteed_nonactive)
else
pv
end
end)
prev[2][1] = 1000 * dret.val * prev[1]
return (nothing,)
end

@testset "Mixed activity rule" begin
x = [3.14]
dx = [0.0]
res = autodiff(Reverse, mixouter, Active, Active(2.7), Duplicated(x, dx))[1][1]
@test res 7 * 2.7 + 3.14 * 1000
@test cl.v[1] 0.0
@test res 7 * 3.14
@test dx[1] 1000 * 2.7
@test x[1] 0.0
end

end # ReverseMixedRules

0 comments on commit 971194f

Please sign in to comment.