Skip to content

Commit

Permalink
Handle mixed return of unstable (#2104)
Browse files Browse the repository at this point in the history
* Handle mixed return of unstable

* with test

* Update mixed.jl

* Update mixed.jl
  • Loading branch information
wsmoses authored Nov 18, 2024
1 parent ba4c22a commit 9c6899c
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 68 deletions.
19 changes: 9 additions & 10 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -682,16 +682,15 @@ code, as well as high-order differentiation.

if A isa UnionAll
rt = Compiler.primal_return_type(rmode, Val(world), FTy, tt)
rt = Core.Compiler.return_type(f.val, tt)
A2 = A{rt}
if rt == Union{}
throw(ErrorException("Return type inferred to be Union{}. Giving up."))
end
else
@assert A isa DataType
rt = A
if rt == Union{}
throw(ErrorException("Return type inferred to be Union{}. Giving up."))
A2 = A{rt}
if rt == Union{}
throw(ErrorException("Return type inferred to be Union{}. Giving up."))
end
else
@assert A isa DataType
rt = A
if rt == Union{}
throw(ErrorException("Return type inferred to be Union{}. Giving up."))
end
end

Expand Down
65 changes: 47 additions & 18 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4491,12 +4491,17 @@ function create_abi_wrapper(
push!(sret_types, AnonymousStruct(NTuple{width,literal_rt}))
end
elseif rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated
rty = if Base.isconcretetype(literal_rt)
Base.RefValue{literal_rt}
else
(Base.RefValue{T} where T <: literal_rt)
end
if width == 1
push!(sret_types, Base.RefValue{literal_rt})
push!(sret_types, rty)
else
push!(
sret_types,
AnonymousStruct(NTuple{width,Base.RefValue{literal_rt}}),
AnonymousStruct(NTuple{width,rty}),
)
end
end
Expand Down Expand Up @@ -4633,6 +4638,7 @@ function create_abi_wrapper(
convty = convert(LLVMType, T′; allow_boxed = true)

if (T <: MixedDuplicated || T <: BatchMixedDuplicated) && !isboxed # && (isa(llty, LLVM.ArrayType) || isa(llty, LLVM.StructType))
@assert Base.isconcretetype(T′)
al0 = al = emit_allocobj!(builder, Base.RefValue{T′}, "mixedparameter")
al = bitcast!(builder, al, LLVM.PointerType(llty, addrspace(value_type(al))))
store!(builder, params[i], al)
Expand Down Expand Up @@ -4692,6 +4698,7 @@ function create_abi_wrapper(
parmsi = params[i]

if T <: BatchMixedDuplicated
@assert Base.isconcretetype(T′)
if GPUCompiler.deserves_argbox(NTuple{width,Base.RefValue{T′}})
njlvalue = LLVM.ArrayType(Int(width), T_prjlvalue)
parmsi = bitcast!(
Expand Down Expand Up @@ -4812,26 +4819,37 @@ function create_abi_wrapper(
for idx = 1:width
pv =
(width == 1) ? eval : extract_value!(builder, eval, idx - 1)
al0 =
irt = eltype(rettype)
ires = if Base.isconcretetype(irt)
al = emit_allocobj!(
builder,
Base.RefValue{eltype(rettype)},
"batchmixedret",
)
llty = value_type(pv)
al = bitcast!(
builder,
al,
LLVM.PointerType(llty, addrspace(value_type(al))),
)
store!(builder, pv, al)
emit_writebarrier!(
builder,
get_julia_inner_types(builder, al0, pv),
)
al0 = al
llty = value_type(pv)
al = bitcast!(
builder,
al,
LLVM.PointerType(llty, addrspace(value_type(al))),
)
store!(builder, pv, al)
emit_writebarrier!(
builder,
get_julia_inner_types(builder, al0, pv),
)
al0
else
# emit_allocobj!(
# builder,
# emit_apply_type!(builder, Base.RefValue, [emit_jltypeof!(builder, pv)]),
# "batchmixedret",
# )
pv
end
ival =
(width == 1) ? al0 :
insert_value!(builder, ival, al0, idx - 1)
(width == 1) ? ires :
insert_value!(builder, ival, ires, idx - 1)
end
eval = ival
end
Expand Down Expand Up @@ -8223,11 +8241,21 @@ end
if rettype <: Duplicated || rettype <: DuplicatedNoNeed
push!(sret_types, jlRT)
elseif rettype <: MixedDuplicated
push!(sret_types, Base.RefValue{jlRT})
rty = if Base.isconcretetype(jlRT)
Base.RefValue{jlRT}
else
(Base.RefValue{T} where T <: jlRT)
end
push!(sret_types, rty)
elseif rettype <: BatchDuplicated || rettype <: BatchDuplicatedNoNeed
push!(sret_types, AnonymousStruct(NTuple{width,jlRT}))
elseif rettype <: BatchMixedDuplicated
push!(sret_types, AnonymousStruct(NTuple{width,Base.RefValue{jlRT}}))
rty = if Base.isconcretetype(jlRT)
Base.RefValue{jlRT}
else
(Base.RefValue{T} where T <: jlRT)
end
push!(sret_types, AnonymousStruct(NTuple{width,rty}))
elseif CC <: AugmentedForwardThunk
push!(sret_types, Nothing)
elseif rettype <: Const
Expand Down Expand Up @@ -8363,6 +8391,7 @@ end

@assert length(types) == length(ccexprs)


if !(GPUCompiler.isghosttype(PT) || Core.Compiler.isconstType(PT))
return quote
Base.@_inline_meta
Expand Down
86 changes: 47 additions & 39 deletions src/rules/jitrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -438,32 +438,33 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs)
args = ($(wrapped...),)
$(MakeTypes...)

FT = Core.Typeof(f)
dupClosure0 = if ActivityTup[1]
!guaranteed_const(FT)
else
false
end

tt = Tuple{$(ElTypes...)}
rt = Core.Compiler.return_type(f, tt)
annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal)

annotationA = if $Width != 1 && annotation0 <: Duplicated
BatchDuplicated{rt,$Width}
elseif $Width != 1 && annotation0 <: MixedDuplicated
BatchMixedDuplicated{rt,$Width}
else
annotation0
end

internal_tape, origRet, initShadow, annotation = if f isa typeof(Core.getglobal)
gv = Core.getglobal(args[1].val, args[2].val)
@assert sizeof(gv) == 0
(nothing, gv, nothing, Const)
else
FT = Core.Typeof(f)
tt = Tuple{$(ElTypes...)}
world = codegen_world_age(FT, tt)

dupClosure0 = if ActivityTup[1]
!guaranteed_const(FT)
else
false
end

rt = Compiler.primal_return_type(Reverse, Val(world), FT, tt)

annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal)

annotationA = if $Width != 1 && annotation0 <: Duplicated
BatchDuplicated{rt,$Width}
elseif $Width != 1 && annotation0 <: MixedDuplicated
BatchMixedDuplicated{rt,$Width}
else
annotation0
end

opt_mi = Val(world)
forward, adjoint = thunk(
opt_mi,
Expand Down Expand Up @@ -492,7 +493,11 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs)
)
return ReturnType(($(nres...), tape))
elseif annotation <: Active
shadow_return = $shadowretinit
shadow_return = if Base.isconcretetype(rt)
$shadowretinit
else
initShadow
end
tape = Tape{typeof(internal_tape),typeof(shadow_return),resT}(
internal_tape,
shadow_return,
Expand Down Expand Up @@ -634,31 +639,33 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act
end

quote
$(active_refs...)
args = ($(wrapped...),)
$(MakeTypes...)

FT = Core.Typeof(f)
dupClosure0 = if ActivityTup[1]
!guaranteed_const(FT)
if f isa typeof(Core.getglobal)
else
false
end
$(active_refs...)
args = ($(wrapped...),)
$(MakeTypes...)

tt = Tuple{$(ElTypes...)}
rt = Core.Compiler.return_type(f, tt)
annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal)
FT = Core.Typeof(f)
dupClosure0 = if ActivityTup[1]
!guaranteed_const(FT)
else
false
end

annotation = if $Width != 1 && annotation0 <: Duplicated
BatchDuplicated{rt,$Width}
else
annotation0
end
tt = Tuple{$(ElTypes...)}

if f isa typeof(Core.getglobal)
else
world = codegen_world_age(FT, tt)

rt = Compiler.primal_return_type(Reverse, Val(world), FT, tt)

annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal)

annotation = if $Width != 1 && annotation0 <: Duplicated
BatchDuplicated{rt,$Width}
else
annotation0
end

opt_mi = Val(world)
_, adjoint = thunk(
opt_mi,
Expand Down Expand Up @@ -1488,6 +1495,7 @@ end
if vec isa Base.RefValue
vecld = vec[]
T = Core.Typeof(vecld)
@assert !(vecld isa Base.RefValue)
vec[] = recursive_index_add(T, vecld, Val(idx_in_vec), expr)
else
val = @inbounds vec[idx_in_vec]
Expand Down
41 changes: 40 additions & 1 deletion test/mixed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,43 @@ end

@testset "Mixed PrimalError" begin
@test_throws AssertionError autodiff(Reverse, bad_abi, MixedDuplicated(Foobar(2, 3, 4, 5, 6.0), Ref(Foobar(2, 3, 4, 5, 6.0))))
end
end



function flattened_unique_values(tupled)
flattened = flatten_tuple(tupled)

return nothing
end

@inline flatten_tuple(a::Tuple) = tuple(inner_flatten_tuple(a[1])..., inner_flatten_tuple(a[2:end])...)
@inline flatten_tuple(a::Tuple{<:Any}) = tuple(inner_flatten_tuple(a[1])...)

@inline inner_flatten_tuple(a) = tuple(a)
@inline inner_flatten_tuple(a::Tuple) = flatten_tuple(a)
@inline inner_flatten_tuple(a::Tuple{}) = ()


struct Center end

struct Field{LX}
grid :: Float64
data :: Float64
end

@testset "Mixed Unstable Return" begin
grid = 1.0
data = 2.0
f1 = Field{Center}(grid, data)
f2 = Field{Center}(grid, data)
f3 = Field{Center}(grid, data)
f4 = Field{Center}(grid, data)
f5 = Field{Nothing}(grid, data)
thing = (f1, f2, f3, f4, f5)
dthing = Enzyme.make_zero(thing)

dedC = autodiff(Enzyme.Reverse,
flattened_unique_values,
Duplicated(thing, dthing))
end

0 comments on commit 9c6899c

Please sign in to comment.