Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mixed activity for getfield #1535

Merged
merged 11 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Enzyme"
uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9"
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>"]
version = "0.12.13"
version = "0.12.14"

[deps]
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
Expand All @@ -20,7 +20,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
CEnum = "0.4, 0.5"
ChainRulesCore = "1"
EnzymeCore = "0.7.5"
Enzyme_jll = "0.0.121"
Enzyme_jll = "0.0.122"
GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26"
LLVM = "6.1, 7"
ObjectFile = "0.4"
Expand Down
14 changes: 14 additions & 0 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ end
arg = @inbounds args[i]
if arg isa Active
return true
elseif arg isa MixedDuplicated
return true
elseif arg isa BatchMixedDuplicated
return true
else
return false
end
Expand Down Expand Up @@ -95,6 +99,10 @@ end
end

@inline same_or_one_rec(current) = current
@inline same_or_one_rec(current, arg::BatchMixedDuplicated{T, N}, args...) where {T,N} =
same_or_one_rec(same_or_one_helper(current, N), args...)
@inline same_or_one_rec(current, arg::Type{BatchMixedDuplicated{T, N}}, args...) where {T,N} =
same_or_one_rec(same_or_one_helper(current, N), args...)
@inline same_or_one_rec(current, arg::BatchDuplicatedFunc{T, N}, args...) where {T,N} =
same_or_one_rec(same_or_one_helper(current, N), args...)
@inline same_or_one_rec(current, arg::Type{BatchDuplicatedFunc{T, N}}, args...) where {T,N} =
Expand Down Expand Up @@ -844,6 +852,12 @@ result, ∂v, ∂A
else
BatchDuplicatedNoNeed{eltype(A2), width}
end
elseif A2 <: MixedDuplicated && width != 1
if A2 isa UnionAll
BatchMixedDuplicated{T, width} where T
else
BatchMixedDuplicated{eltype(A2), width}
end
else
A2
end
Expand Down
137 changes: 102 additions & 35 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,13 @@ end
return res
end

# check if a value is guaranteed to be not contain active[register] data
# (aka not either mixed or active)
@inline function guaranteed_nonactive(::Type{T}) where T
rt = Enzyme.Compiler.active_reg_nothrow(T, Val(nothing))
return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState
end

@inline Enzyme.guess_activity(::Type{T}, mode::Enzyme.Mode) where {T} = guess_activity(T, convert(API.CDerivativeMode, mode))

@inline function Enzyme.guess_activity(::Type{T}, Mode::API.CDerivativeMode) where {T}
Expand All @@ -555,6 +562,8 @@ end
else
if ActReg == ActiveState
return Active{T}
elseif ActReg == MixedState
return MixedDuplicated{T}
else
return Duplicated{T}
end
Expand Down Expand Up @@ -2494,7 +2503,7 @@ function store_nonjl_types!(B, startval, p)
return
end

function get_julia_inner_types(B, p, startvals...; added=[])
function get_julia_inner_types(B, p, startvals...; added=LLVM.API.LLVMValueRef[])
T_jlvalue = LLVM.StructType(LLVMType[])
T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked)
vals = LLVM.Value[]
Expand Down Expand Up @@ -2547,8 +2556,20 @@ function get_julia_inner_types(B, p, startvals...; added=[])
end
continue
end
GPUCompiler.@safe_warn "Enzyme illegal subtype", ty, cur, SI, p, v
@assert false
if isa(ty, LLVM.IntegerType)
continue
end
if isa(ty, LLVM.FloatingPointType)
continue
end
msg = sprint() do io
println(io, "Enzyme illegal subtype")
println(io, "ty=", ty)
println(io, "cur=", cur)
println(io, "p=", p)
println(io, "startvals=", startvals)
end
throw(AssertionError(msg))
end
return vals
end
Expand Down Expand Up @@ -3474,7 +3495,11 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr
# If requested, the shadow return value of the function
# For each active (non duplicated) argument
# The adjoint of that argument
retType = convert(API.CDIFFE_TYPE, rt)
retType = if rt <: MixedDuplicated || rt <: BatchMixedDuplicated
API.DFT_OUT_DIFF
else
convert(API.CDIFFE_TYPE, rt)
end

rules = Dict{String, API.CustomRuleType}(
"jl_array_copy" => @cfunction(inout_rule,
Expand Down Expand Up @@ -3513,7 +3538,7 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr

if mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient
returnUsed = !(isghostty(actualRetType) || Core.Compiler.isconstType(actualRetType))
shadowReturnUsed = returnUsed && (retType == API.DFT_DUP_ARG || retType == API.DFT_DUP_NONEED)
shadowReturnUsed = returnUsed && (retType == API.DFT_DUP_ARG || retType == API.DFT_DUP_NONEED || rt <: MixedDuplicated || rt <: BatchMixedDuplicated)
returnUsed &= returnPrimal
augmented = API.EnzymeCreateAugmentedPrimal(
logic, primalf, retType, args_activity, TA, #=returnUsed=# returnUsed,
Expand Down Expand Up @@ -3679,16 +3704,20 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
end

# API.DFT_OUT_DIFF
if is_adjoint && rettype <: Active
@assert !sret_union
if allocatedinline(actualRetType) != allocatedinline(literal_rt)
throw(AssertionError("Base.allocatedinline(actualRetType) != Base.allocatedinline(literal_rt): actualRetType = $(actualRetType), literal_rt = $(literal_rt), rettype = $(rettype)"))
end
if !allocatedinline(actualRetType)
throw(AssertionError("Base.allocatedinline(actualRetType) returns false: actualRetType = $(actualRetType), rettype = $(rettype)"))
if is_adjoint
if rettype <: Active || rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated
@assert !sret_union
if allocatedinline(actualRetType) != allocatedinline(literal_rt)
throw(AssertionError("Base.allocatedinline(actualRetType) != Base.allocatedinline(literal_rt): actualRetType = $(actualRetType), literal_rt = $(literal_rt), rettype = $(rettype)"))
end
if rettype <: Active
if !allocatedinline(actualRetType)
throw(AssertionError("Base.allocatedinline(actualRetType) returns false: actualRetType = $(actualRetType), rettype = $(rettype)"))
end
end
dretTy = LLVM.LLVMType(API.EnzymeGetShadowType(width, convert(LLVMType, actualRetType; allow_boxed=!(rettype <: Active))))
push!(T_wrapperargs, dretTy)
end
dretTy = LLVM.LLVMType(API.EnzymeGetShadowType(width, convert(LLVMType, actualRetType)))
push!(T_wrapperargs, dretTy)
end

data = Array{Int64}(undef, 3)
Expand Down Expand Up @@ -3730,6 +3759,12 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
else
push!(sret_types, AnonymousStruct(NTuple{width, literal_rt}))
end
elseif rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated
if width == 1
push!(sret_types, Base.RefValue{literal_rt})
else
push!(sret_types, AnonymousStruct(NTuple{width, Base.RefValue{literal_rt}}))
end
end
else
@assert rettype <: Const || rettype <: Active
Expand Down Expand Up @@ -3953,7 +3988,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
end
end

if is_adjoint && rettype <: Active
if is_adjoint && (rettype <: Active || rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated)
push!(realparms, params[i])
i += 1
end
Expand Down Expand Up @@ -3999,12 +4034,26 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
if data[i] != -1
eval = extract_value!(builder, val, data[i])
end
if i == 3
if rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated
ival = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, T_prjlvalue)))
for idx in 1:width
pv = (width == 1) ? eval : extract_value!(builder, eval, idx-1)
al0 = al = emit_allocobj!(builder, Base.RefValue{eltype(rettype)})
llty = value_type(pv)
al = bitcast!(builder, al, LLVM.PointerType(llty, addrspace(value_type(al))))
store!(builder, pv, al)
emit_writebarrier!(builder, get_julia_inner_types(builder, al0, pv))
ival = (width == 1 ) ? al0 : insert_value!(builder, ival, al0, idx-1)
end
eval = ival
end
end
eval = fixup_abi(i, eval)
ptr = inbounds_gep!(builder, jltype, sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), returnNum)])
ptr = pointercast!(builder, ptr, LLVM.PointerType(value_type(eval)))
si = store!(builder, eval, ptr)
returnNum+=1

if i == 3 && shadow_init
shadows = LLVM.Value[]
if width == 1
Expand Down Expand Up @@ -5943,34 +5992,35 @@ end
end

if !RawCall && !(CC <: PrimalErrorThunk)
if rettype <: Active
if rettype <: Active
if length(argtypes) + is_adjoint + needs_tape != length(argexprs)
return quote
throw(MethodError($CC(fptr), $args))
throw(MethodError($CC(fptr), (fn, args...)))
end
end
elseif rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated
if length(argtypes) + is_adjoint * width + needs_tape != length(argexprs)
return quote
throw(MethodError($CC(fptr), (fn, args...)))
end
end
elseif rettype <: Const
if length(argtypes) + needs_tape != length(argexprs)
return quote
throw(MethodError($CC(fptr), $args))
throw(MethodError($CC(fptr), (fn, args...)))
end
end
else
if length(argtypes) + needs_tape != length(argexprs)
return quote
throw(MethodError($CC(fptr), $args))
throw(MethodError($CC(fptr), (fn, args...)))
end
end
end
end

types = DataType[]

if eltype(rettype) === Union{} && false
return quote
error("Function to differentiate is guaranteed to return an error and doesn't make sense to autodiff. Giving up")
end
end
if !(rettype <: Const) && (isghostty(eltype(rettype)) || Core.Compiler.isconstType(eltype(rettype)) || eltype(rettype) === DataType)
rrt = eltype(rettype)
error("Return type `$rrt` not marked Const, but is ghost or const type.")
Expand Down Expand Up @@ -6133,17 +6183,28 @@ end
end

# API.DFT_OUT_DIFF
if is_adjoint && rettype <: Active
# TODO handle batch width
@assert allocatedinline(jlRT)
j_drT = if width == 1
jlRT
else
NTuple{width, jlRT}
if is_adjoint
if rettype <: Active || rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated
# TODO handle batch width
if rettype <: Active
@assert allocatedinline(jlRT)
end
j_drT = if width == 1
jlRT
else
NTuple{width, jlRT}
end
push!(types, j_drT)
if width == 1 || rettype <: Active
push!(ccexprs, argexprs[i])
i+=1
else
push!(ccexprs, quote
($(argexprs[i:i+width-1]...),)
end)
i+=width
end
end
push!(types, j_drT)
push!(ccexprs, argexprs[i])
i+=1
end

if needs_tape
Expand Down Expand Up @@ -6181,8 +6242,12 @@ end
end
if rettype <: Duplicated || rettype <: DuplicatedNoNeed
push!(sret_types, jlRT)
elseif rettype <: MixedDuplicated
push!(sret_types, Base.RefValue{jlRT})
elseif rettype <: BatchDuplicated || rettype <: BatchDuplicatedNoNeed
push!(sret_types, AnonymousStruct(NTuple{width, jlRT}))
elseif rettype <: BatchMixedDuplicated
push!(sret_types, AnonymousStruct(NTuple{width, Base.RefValue{jlRT}}))
elseif CC <: AugmentedForwardThunk
push!(sret_types, Nothing)
elseif rettype <: Const
Expand Down Expand Up @@ -6406,6 +6471,8 @@ end
@inline remove_innerty(::Type{<:DuplicatedNoNeed}) = DuplicatedNoNeed
@inline remove_innerty(::Type{<:BatchDuplicated}) = Duplicated
@inline remove_innerty(::Type{<:BatchDuplicatedNoNeed}) = DuplicatedNoNeed
@inline remove_innerty(::Type{<:MixedDuplicated}) = MixedDuplicated
@inline remove_innerty(::Type{<:BatchMixedDuplicated}) = MixedDuplicated

@inline @generated function thunk(::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, ABI}
JuliaContext() do ctx
Expand Down
Loading
Loading