Skip to content

Commit

Permalink
MixedDuplicated for custom rules
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jun 11, 2024
1 parent bd60907 commit 985986d
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 17 deletions.
26 changes: 26 additions & 0 deletions lib/EnzymeCore/src/EnzymeCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,32 @@ end
@inline batch_size(::Type{BatchDuplicatedNoNeed{T,N}}) where {T,N} = N


"""
MixedDuplicated(x, ∂f_∂x)
Like [`Duplicated`](@ref), except x may contain both active [immutable] and duplicated [mutable]
data which is differentiable. Only used within custom rules.
"""
struct MixedDuplicated{T} <: Annotation{T}
val::T
dval::Base.RefValue{T}
@inline MixedDuplicated(x::T1, dx::Base.RefValue{T1}, check::Bool=true) where {T1} = new{T1}(x, dx)
end

"""
BatchMixedDuplicated(x, ∂f_∂xs)
Like [`MixedDuplicated`](@ref), except contains several shadows to compute derivatives
for all at once. Only used within custom rules.
"""
struct BatchMixedDuplicated{T,N} <: Annotation{T}
val::T
dval::NTuple{N,Base.RefValue{T}}
@inline BatchMixedDuplicated(x::T1, dx::NTuple{N,Base.RefValue{T1}}, check::Bool=true) where {T1, N} = new{T1, N}(x, dx)
end
@inline batch_size(::BatchMixedDuplicated{T,N}) where {T,N} = N
@inline batch_size(::Type{BatchMixedDuplicated{T,N}}) where {T,N} = N

"""
abstract type ABI
Expand Down
51 changes: 34 additions & 17 deletions src/rules/customrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,11 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils,

push!(activity, Ty)

elseif activep == API.DFT_OUT_DIFF || (mode != API.DEM_ForwardMode && active_reg_inner(arg.typ, (), world, #=justActive=#Val(true)) == ActiveState)
elseif activep == API.DFT_OUT_DIFF || (mode != API.DEM_ForwardMode && active_reg_inner(arg.typ, (), world) == ActiveState)
Ty = Active{arg.typ}
llty = convert(LLVMType, Ty)
arty = convert(LLVMType, arg.typ; allow_boxed=true)
if B !== nothing
if active_reg_inner(arg.typ, (), world, #=justActive=#Val(false)) == MixedState
emit_error(B, orig, "Enzyme: Argument type $(arg.typ) has mixed internal activity types in evaluation of custom rule for $mi. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information")
end
if B !== nothings
al0 = al = emit_allocobj!(B, Ty)
al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al))))
al = addrspacecast!(B, al, LLVM.PointerType(llty, Derived))
Expand Down Expand Up @@ -157,25 +154,41 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils,
ival = lookup_value(gutils, ival, B)
end
end
shadowty = arg.typ
if width == 1
if activep == API.DFT_DUP_ARG
Ty = Duplicated{arg.typ}

if active_reg_inner(arg.typ, (), world) == MixedState
# TODO mixedupnoneed
shadowty = Base.RefValue{shadowty}
Ty = MixedDuplicated{arg.typ}
else
@assert activep == API.DFT_DUP_NONEED
Ty = DuplicatedNoNeed{arg.typ}
if activep == API.DFT_DUP_ARG
Ty = Duplicated{arg.typ}
else
@assert activep == API.DFT_DUP_NONEED
Ty = DuplicatedNoNeed{arg.typ}
end
end
else
if activep == API.DFT_DUP_ARG
Ty = BatchDuplicated{arg.typ, Int(width)}
if active_reg_inner(arg.typ, (), world) == MixedState
# TODO batchmixedupnoneed
shadowty = Base.RefValue{shadowty}
Ty = BatchMixedDuplicated{arg.typ, Int(width)}
else
@assert activep == API.DFT_DUP_NONEED
Ty = BatchDuplicatedNoNeed{arg.typ, Int(width)}
if activep == API.DFT_DUP_ARG
Ty = BatchDuplicated{arg.typ, Int(width)}
else
@assert activep == API.DFT_DUP_NONEED
Ty = BatchDuplicatedNoNeed{arg.typ, Int(width)}
end
end
end

llty = convert(LLVMType, Ty)
arty = convert(LLVMType, arg.typ; allow_boxed=true)
iarty = convert(LLVMType, shadowty; allow_boxed=true)
sarty = LLVM.LLVMType(API.EnzymeGetShadowType(width, arty))
siarty = LLVM.LLVMType(API.EnzymeGetShadowType(width, iarty))
if B !== nothing
al0 = al = emit_allocobj!(B, Ty)
al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al))))
Expand All @@ -184,17 +197,21 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils,
ptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 0)])
if value_type(val) != eltype(value_type(ptr))
val = load!(B, arty, val)
end
store!(B, val, ptr)

iptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 1)])

if value_type(ival) != eltype(value_type(iptr))
ptr_val = ival
ival = UndefValue(sarty)
ival = UndefValue(siarty)
for idx in 1:width
ev = (width == 1) ? ptr_val : extract_value!(B, ptr_val, idx-1)
ld = load!(B, arty, ev)
ld = load!(B, iarty, ev)
ival = (width == 1 ) ? ld : insert_value!(B, ival, ld, idx-1)
end
end
store!(B, val, ptr)

iptr = inbounds_gep!(B, llty, al, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), 1)])
store!(B, ival, iptr)

if any_jltypes(llty)
Expand Down
44 changes: 44 additions & 0 deletions test/mixedrrule.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
module ReverseMixedRules

using Enzyme
using Enzyme: EnzymeRules
using Test

function mixfnc(tup)
return tup[1] * tup[2][1]
end

function mixouter(x, y)
res = mixfnc((x, y))
fill!(y, 0.0)
return res
end

function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{typeof(mixfnc)},
::Type{<:Active}, tup)
@show tup
pval = func.val(tup.val)
vec = copy(tup.val[2])
primal = if EnzymeRules.needs_primal(config)
pval
else
nothing
end
return AugmentedReturn(primal, nothing, vec)
end

function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{Closure},
dret::Active, tape, tup)
dargs = 7 * tup[1].val * dret.val + tape[1] * 1000
return (dres,)
end

@testset "Mixed activity rule" begin
x = [3.14]
dx = [0.0]
res = autodiff(Reverse, mixouter, Active, Active(2.7), Duplicated(x, dx))[1][1]
@test res 7 * 2.7 + 3.14 * 1000
@test cl.v[1] 0.0
end

end # ReverseMixedRules
1 change: 1 addition & 0 deletions test/rrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -345,4 +345,5 @@ end
@test cl.v[1] 0.0
end

include("mixedrrule.jl")
end # ReverseRules

0 comments on commit 985986d

Please sign in to comment.