From ffcc20c81b4fb1a1f1785f75f29c61df63f9f677 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 10 Jun 2024 11:02:26 -0500 Subject: [PATCH] Fix const-only apply iterate (#1526) * Fix const-only apply iterate * fix ct * Fix mixed activity for type unstable * Update jitrules.jl * Update jitrules.jl * wip tuple * fix batch tuple generation * Ensure runtime store error * fix * cleanup * ignore 1.8 * newstructv * ignore test --- src/compiler.jl | 43 +++- src/rules/jitrules.jl | 390 +++++++++++++++++++++-------- src/rules/typeunstablerules.jl | 445 +++++++++++++++++++++++++++++---- test/applyiter.jl | 14 ++ test/mixed.jl | 71 ++++++ test/runtests.jl | 1 + test/threads.jl | 3 +- 7 files changed, 808 insertions(+), 159 deletions(-) create mode 100644 test/mixed.jl diff --git a/src/compiler.jl b/src/compiler.jl index cca67bc874..ed44563ec4 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -474,7 +474,7 @@ end end @assert !Base.isabstracttype(T) - if !(Base.isconcretetype(T) || is_concrete_tuple(T) || T isa UnionAll) + if !(Base.isconcretetype(T) || (T <: Tuple && T != Tuple) || T isa UnionAll) throw(AssertionError("Type $T is not concrete type or concrete tuple")) end @@ -515,7 +515,7 @@ end return active_reg_inner(T, (), world) end -@inline function active_reg(::Type{T}, world::Union{Nothing, UInt}=nothing)::Bool where {T} +Base.@pure @inline function active_reg(::Type{T}, world::Union{Nothing, UInt}=nothing)::Bool where {T} seen = () # check if it could contain an active @@ -3342,6 +3342,8 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr world = job.world interp = GPUCompiler.get_interpreter(job) rt = job.config.params.rt + @assert eltype(rt) != Union{} + shadow_init = job.config.params.shadowInit ctx = context(mod) dl = string(LLVM.datalayout(mod)) @@ -3546,6 +3548,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, pactualRetType = actualRetType sret_union = is_sret_union(actualRetType) literal_rt = eltype(rettype) + @assert literal_rt != Union{} sret_union_rt = is_sret_union(literal_rt) @assert sret_union == sret_union_rt if sret_union @@ -3684,9 +3687,10 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, end end - combinedReturn = Tuple{sret_types...} - if any(any_jltypes(convert(LLVM.LLVMType, T; allow_boxed=true)) for T in sret_types) - combinedReturn = AnonymousStruct(combinedReturn) + combinedReturn = if any(any_jltypes(convert(LLVM.LLVMType, T; allow_boxed=true)) for T in sret_types) + AnonymousStruct(Tuple{sret_types...}) + else + Tuple{sret_types...} end uses_sret = is_sret(combinedReturn) @@ -4794,6 +4798,9 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; libraries::Bool=true, deferred_codegen::Bool=true, optimize::Bool=true, toplevel::Bool=true, strip::Bool=false, validate::Bool=true, only_entry::Bool=false, parent_job::Union{Nothing, CompilerJob} = nothing) params = job.config.params + if params.run_enzyme + @assert eltype(params.rt) != Union{} + end expectedTapeType = params.expectedTapeType mode = params.mode TT = params.TT @@ -4801,7 +4808,9 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; abiwrap = params.abiwrap primal = job.source modifiedBetween = params.modifiedBetween - @assert length(modifiedBetween) == length(TT.parameters) + if length(modifiedBetween) != length(TT.parameters) + throw(AssertionError("length(modifiedBetween) [aka $(length(modifiedBetween))] != length(TT.parameters) [aka $(length(TT.parameters))] at TT=$TT")) + end returnPrimal = params.returnPrimal if !(params.rt <: Const) @@ -5297,6 +5306,9 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; end @assert actualRetType !== nothing + if params.run_enzyme + @assert actualRetType != Union{} + end if must_wrap llvmfn = primalf @@ -5838,7 +5850,11 @@ end end push!(ccexprs, argexpr) - if !(FA <: Const) + if (FA <: Active) + return quote + error("Cannot have function with Active annotation, $FA") + end + elseif !(FA <: Const) argexpr = :(fn.dval) if isboxed push!(types, Any) @@ -6274,9 +6290,16 @@ end compile_result = cached_compilation(job) if !run_enzyme ErrT = PrimalErrorThunk{typeof(compile_result.adjoint), FA, rt2, TT, width, ReturnPrimal, World} - return quote - Base.@_inline_meta - $ErrT($(compile_result.adjoint)) + if Mode == API.DEM_ReverseModePrimal || Mode == API.DEM_ReverseModeGradient + return quote + Base.@_inline_meta + ($ErrT($(compile_result.adjoint)), $ErrT($(compile_result.adjoint))) + end + else + return quote + Base.@_inline_meta + $ErrT($(compile_result.adjoint)) + end end elseif Mode == API.DEM_ReverseModePrimal || Mode == API.DEM_ReverseModeGradient TapeType = compile_result.TapeType diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index af12d2bfbc..e5ce78aa02 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -1,5 +1,87 @@ +function func_mixed_call(N) + allargs = Expr[] + typeargs = Union{Symbol,Expr}[] + exprs2 = Union{Symbol,Expr}[] + for i in 1:N + arg = Symbol("arg_$i") + targ = Symbol("T$i") + e = :($arg::$targ) + push!(allargs, e) + push!(typeargs, targ) + + inarg = quote + if RefTypes[1+$i] + $arg[] + else + $arg + end + end + push!(exprs2, inarg) + end + + quote + @generated function runtime_mixed_call(::Val{RefTypes}, f::F, $(allargs...)) where {RefTypes, F, $(typeargs...)} + fexpr = :f + if RefTypes[1] + fexpr = :(($fexpr)[]) + end + exprs2 = Union{Symbol,Expr}[] + for i in 1:$N + arg = Symbol("arg_$i") + inarg = if RefTypes[1+i] + :($arg[]) + else + :($arg) + end + push!(exprs2, inarg) + end + @static if VERSION ≥ v"1.8-" + return quote + Base.@_inline_meta + @inline $fexpr($(exprs2...)) + end + else + return quote + Base.@_inline_meta + $fexpr($(exprs2...)) + end + end + end + end +end -function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, iterate=false) +@generated function runtime_mixed_call(::Val{RefTypes}, f::F, allargs::Vararg{Any, N}) where {RefTypes, F, N} + fexpr = :f + if RefTypes[1] + fexpr = :(($fexpr)[]) + end + exprs2 = Union{Symbol,Expr}[] + for i in 1:N + inarg = if RefTypes[1+i] + :(allargs[$i][]) + else + :(allargs[$i]) + end + push!(exprs2, inarg) + end + @static if VERSION ≥ v"1.8-" + return quote + Base.@_inline_meta + @inline $fexpr($(exprs2...)) + end + else + return quote + Base.@_inline_meta + $fexpr($(exprs2...)) + end + end +end + +for N in 0:10 + eval(func_mixed_call(N)) +end + +function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, iterate=false; func=true, mixed_or_active = false) primargs = Union{Symbol,Expr}[] shadowargs = Union{Symbol,Expr}[] batchshadowargs = Vector{Union{Symbol,Expr}}[] @@ -8,18 +90,20 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, typeargs = Symbol[] dfns = Union{Symbol,Expr}[:df] base_idx = 1 - for w in 2:Width - if base === nothing - shad = Symbol("df_$w") - t = Symbol("DF__$w*") - e = :($shad::$t) - push!(allargs, e) - push!(typeargs, t) - else - shad = :($base[$base_idx]) - base_idx += 1 + if func + for w in 2:Width + if base === nothing + shad = Symbol("df_$w") + t = Symbol("DF__$w*") + e = :($shad::$t) + push!(allargs, e) + push!(typeargs, t) + else + shad = :($base[$base_idx]) + base_idx += 1 + end + push!(dfns, shad) end - push!(dfns, shad) end for i in 1:N if base === nothing @@ -60,6 +144,7 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, @assert length(primtypes) == N wrapped = Expr[] modbetween = Expr[:(MB[1])] + active_refs = Expr[] for i in 1:N if iterate push!(modbetween, quote @@ -69,6 +154,10 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, end end) end + aref = Symbol("active_ref_$i") + push!(active_refs, quote + $aref = active_reg_nothrow($(primtypes[i]), Val(nothing)); + end) expr = if iterate :( if ActivityTup[$i+1] && !guaranteed_const($(primtypes[i])) @@ -88,23 +177,57 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, end ) else - :( - if ActivityTup[$i+1] && !guaranteed_const($(primtypes[i])) - @assert $(primtypes[i]) !== DataType - if !$forwardMode && active_reg($(primtypes[i])) - Active($(primargs[i])) - else - $((Width == 1) ? :Duplicated : :BatchDuplicated)($(primargs[i]), $(shadowargs[i])) - end - else - Const($(primargs[i])) - end - - ) + if forwardMode + quote + if ActivityTup[$i+1] && !guaranteed_const($(primtypes[i])) + $((Width == 1) ? :Duplicated : :BatchDuplicated)($(primargs[i]), $(shadowargs[i])) + else + Const($(primargs[i])) + end + end + else + quote + if ActivityTup[$i+1] && $aref != AnyState + @assert $(primtypes[i]) !== DataType + if $aref == ActiveState + Active($(primargs[i])) + elseif $aref == MixedState + $((Width == 1) ? :Duplicated : :BatchDuplicated)(Ref($(primargs[i])), $(shadowargs[i])) + else + $((Width == 1) ? :Duplicated : :BatchDuplicated)($(primargs[i]), $(shadowargs[i])) + end + else + Const($(primargs[i])) + end + end + end end push!(wrapped, expr) end - return primargs, shadowargs, primtypes, allargs, typeargs, wrapped, batchshadowargs, modbetween + + any_mixed = quote false end + for i in 1:N + aref = Symbol("active_ref_$i") + if mixed_or_active + any_mixed = :($any_mixed || $aref == MixedState || $aref == ActiveState) + else + any_mixed = :($any_mixed || $aref == MixedState) + end + end + + if mixed_or_active + push!(active_refs, quote + active_refs = (false, $(collect(:($(Symbol("active_ref_$i")) == MixedState || $(Symbol("active_ref_$i")) == ActiveState) for i in 1:N)...)) + end) + else + push!(active_refs, quote + active_refs = (false, $(collect(:($(Symbol("active_ref_$i")) == MixedState) for i in 1:N)...)) + end) + end + push!(active_refs, quote + any_mixed = $any_mixed + end) + return primargs, shadowargs, primtypes, allargs, typeargs, wrapped, batchshadowargs, modbetween, active_refs end function body_runtime_generic_fwd(N, Width, wrapped, primtypes) @@ -159,7 +282,7 @@ function body_runtime_generic_fwd(N, Width, wrapped, primtypes) end function func_runtime_generic_fwd(N, Width) - _, _, primtypes, allargs, typeargs, wrapped, _, _ = setup_macro_wraps(true, N, Width) + _, _, primtypes, allargs, typeargs, wrapped, _, _, _ = setup_macro_wraps(true, N, Width) body = body_runtime_generic_fwd(N, Width, wrapped, primtypes) quote @@ -171,46 +294,75 @@ end @generated function runtime_generic_fwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, RT::Val{ReturnType}, f::F, df::DF, allargs...) where {ActivityTup, Width, ReturnType, F, DF} N = div(length(allargs)+2, Width+1)-1 - _, _, primtypes, _, _, wrapped, _, _ = setup_macro_wraps(true, N, Width, :allargs) + _, _, primtypes, _, _, wrapped, _, _, _ = setup_macro_wraps(true, N, Width, :allargs) return body_runtime_generic_fwd(N, Width, wrapped, primtypes) end -function body_runtime_generic_augfwd(N, Width, wrapped, primttypes) +function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) nnothing = ntuple(i->nothing, Val(Width+1)) nres = ntuple(i->:(origRet), Val(Width+1)) nzeros = ntuple(i->:(Ref(make_zero(origRet))), Val(Width)) nres3 = ntuple(i->:(res[3]), Val(Width)) - ElTypes = ntuple(i->:(eltype(Core.Typeof(args[$i]))), Val(N)) - Types = ntuple(i->:(Core.Typeof(args[$i])), Val(N)) + ElTypes = ntuple(i->:(eltype($(Symbol("type_$i")))), Val(N)) + + MakeTypes = ntuple(i->:($(Symbol("type_$i")) = Core.Typeof(args[$i])), Val(N)) + + Types = ntuple(i->Symbol("type_$i"), Val(N)) + + MixedTypes = ntuple(i->:($(Symbol("active_ref_$i") == MixedState) ? Ref($(Symbol("type_$i"))) : $(Symbol("type_$i"))), Val(N)) return quote + $(active_refs...) args = ($(wrapped...),) + $(MakeTypes...) - # TODO: Annotation of return value - # tt0 = Tuple{$(primtypes...)} - tt′ = Tuple{$(Types...)} - rt = Core.Compiler.return_type(f, Tuple{$(ElTypes...)}) - annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) - - annotation = if $Width != 1 && annotation0 <: Duplicated - BatchDuplicated{rt, $Width} + FT = Core.Typeof(f) + dupClosure0 = if ActivityTup[1] + !guaranteed_const(FT) else - annotation0 + false end - dupClosure = ActivityTup[1] - FT = Core.Typeof(f) - if dupClosure && guaranteed_const(FT) - dupClosure = false - end - world = codegen_world_age(FT, Tuple{$(ElTypes...)}) + internal_tape, origRet, initShadow, annotation = if any_mixed + ttM = Tuple{Val{active_refs}, FT, $(ElTypes...)} + rtM = Core.Compiler.return_type(runtime_mixed_call, ttM) + annotation0M = guess_activity(rtM, API.DEM_ReverseModePrimal) - forward, adjoint = thunk(Val(world), (dupClosure ? Duplicated : Const){FT}, - annotation, tt′, Val(API.DEM_ReverseModePrimal), width, - ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + annotationM = if $Width != 1 && annotation0M <: Duplicated + BatchDuplicated{rt, $Width} + else + annotation0M + end + worldM = codegen_world_age(typeof(runtime_mixed_call), ttM) + ModifiedBetweenM = Val((false, false, element(ModifiedBetween)...)) + + forward, adjoint = thunk(Val(worldM), + Const{typeof(runtime_mixed_call)}, + annotationM, Tuple{Const{Val{active_refs}}, dupClosure0 ? Duplicated{FT} : Const{FT}, $(Types...)}, Val(API.DEM_ReverseModePrimal), width, + ModifiedBetweenM, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + + forward(Const(runtime_mixed_call), Const(Val(active_refs)), dupClosure0 ? Duplicated(f, df) : Const(f), args...)..., annotationM + + else + tt = Tuple{$(ElTypes...)} + rt = Core.Compiler.return_type(f, tt) + annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) + + annotationA = if $Width != 1 && annotation0 <: Duplicated + BatchDuplicated{rt, $Width} + else + annotation0 + end + world = codegen_world_age(FT, tt) + + forward, adjoint = thunk(Val(world), dupClosure0 ? Duplicated{FT} : Const{FT}, + annotationA, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width, + ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + + forward(dupClosure0 ? Duplicated(f, df) : Const(f), args...)..., annotationA + end - internal_tape, origRet, initShadow = forward(dupClosure ? Duplicated(f, df) : Const(f), args...) resT = typeof(origRet) if annotation <: Const shadow_return = nothing @@ -243,8 +395,8 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes) end function func_runtime_generic_augfwd(N, Width) - _, _, primtypes, allargs, typeargs, wrapped, _, _ = setup_macro_wraps(false, N, Width) - body = body_runtime_generic_augfwd(N, Width, wrapped, primtypes) + _, _, primtypes, allargs, typeargs, wrapped, _, _, active_refs = setup_macro_wraps(false, N, Width) + body = body_runtime_generic_augfwd(N, Width, wrapped, primtypes, active_refs) quote function runtime_generic_augfwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...))::ReturnType where {ActivityTup, MB, ReturnType, F, DF, $(typeargs...)} @@ -255,11 +407,11 @@ end @generated function runtime_generic_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, allargs...)::ReturnType where {ActivityTup, MB, Width, ReturnType, F, DF} N = div(length(allargs)+2, Width+1)-1 - _, _, primtypes, _, _, wrapped, _, _= setup_macro_wraps(false, N, Width, :allargs) - return body_runtime_generic_augfwd(N, Width, wrapped, primtypes) + _, _, primtypes, _, _, wrapped, _, _, active_refs = setup_macro_wraps(false, N, Width, :allargs) + return body_runtime_generic_augfwd(N, Width, wrapped, primtypes, active_refs) end -function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs) +function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, active_refs) outs = [] for i in 1:N for w in 1:Width @@ -273,7 +425,7 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs) elseif $shad isa Base.RefValue $shad[] = recursive_add($shad[], $expr) else - error("Enzyme Mutability Error: Cannot add one in place to immutable value "*string($shad)) + error("Enzyme Mutability Error: Cannot add one in place to immutable value "*string($shad)*" tup[i]="*string(tup[$i])*" i="*string($i)*" w="*string($w)*" tup="*string(tup)) end ) push!(outs, out) @@ -290,49 +442,81 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs) shadowret = :(($(shadowret...),)) end - ElTypes = ntuple(i->:(eltype(Core.Typeof(args[$i]))), Val(N)) - Types = ntuple(i->:(Core.Typeof(args[$i])), Val(N)) + ElTypes = ntuple(i->:(eltype($(Symbol("type_$i")))), Val(N)) + + MakeTypes = ntuple(i->:($(Symbol("type_$i")) = Core.Typeof(args[$i])), Val(N)) + + Types = ntuple(i->Symbol("type_$i"), Val(N)) + + MixedTypes = ntuple(i->:($(Symbol("active_ref_$i") == MixedState) ? Ref($(Symbol("type_$i"))) : $(Symbol("type_$i"))), Val(N)) quote + $(active_refs...) args = ($(wrapped...),) + $(MakeTypes...) + + FT = Core.Typeof(f) + dupClosure0 = if ActivityTup[1] + !guaranteed_const(FT) + else + false + end - # TODO: Annotation of return value - # tt0 = Tuple{$(primtypes...)} - tt = Tuple{$(ElTypes...)} - tt′ = Tuple{$(Types...)} - rt = Core.Compiler.return_type(f, tt) - annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) + if any_mixed + ttM = Tuple{Val{active_refs}, FT, $(ElTypes...)} + rtM = Core.Compiler.return_type(runtime_mixed_call, ttM) + annotation0M = guess_activity(rtM, API.DEM_ReverseModePrimal) - annotation = if $Width != 1 && annotation0 <: Duplicated - BatchDuplicated{rt, $Width} + annotationM = if $Width != 1 && annotation0M <: Duplicated + BatchDuplicated{rt, $Width} + else + annotation0M + end + worldM = codegen_world_age(typeof(runtime_mixed_call), ttM) + ModifiedBetweenM = Val((false, false, element(ModifiedBetween)...)) + + _, adjoint = thunk(Val(worldM), + Const{typeof(runtime_mixed_call)}, + annotationM, Tuple{Const{Val{active_refs}}, dupClosure0 ? Duplicated{FT} : Const{FT}, $(Types...)}, Val(API.DEM_ReverseModePrimal), width, + ModifiedBetweenM, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + if tape.shadow_return !== nothing + adjoint(Const(runtime_mixed_call), Const(Val(active_refs)), dupClosure0 ? Duplicated(f, df) : Const(f), args..., $shadowret, tape.internal_tape) + else + adjoint(Const(runtime_mixed_call), Const(Val(active_refs)), dupClosure0 ? Duplicated(f, df) : Const(f), args..., tape.internal_tape) + end + nothing else - annotation0 - end + tt = Tuple{$(ElTypes...)} + rt = Core.Compiler.return_type(f, tt) + annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) - dupClosure = ActivityTup[1] - FT = Core.Typeof(f) - if dupClosure && guaranteed_const(FT) - dupClosure = false - end - world = codegen_world_age(FT, tt) + annotation = if $Width != 1 && annotation0 <: Duplicated + BatchDuplicated{rt, $Width} + else + annotation0 + end - forward, adjoint = thunk(Val(world), (dupClosure ? Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ReverseModePrimal), width, - ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + world = codegen_world_age(FT, tt) - if tape.shadow_return !== nothing - args = (args..., $shadowret) - end + _, adjoint = thunk(Val(world), dupClosure0 ? Duplicated{FT} : Const{FT}, + annotation, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width, + ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) - tup = adjoint(dupClosure ? Duplicated(f, df) : Const(f), args..., tape.internal_tape)[1] + tup = if tape.shadow_return !== nothing + adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., $shadowret, tape.internal_tape)[1] + else + adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., tape.internal_tape)[1] + end - $(outs...) + $(outs...) + end return nothing end end function func_runtime_generic_rev(N, Width) - _, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _ = setup_macro_wraps(false, N, Width) - body = body_runtime_generic_rev(N, Width, wrapped, primtypes, batchshadowargs) + _, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width) + body = body_runtime_generic_rev(N, Width, wrapped, primtypes, batchshadowargs, active_refs) quote function runtime_generic_rev(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, $(allargs...)) where {ActivityTup, MB, TapeType, F, DF, $(typeargs...)} @@ -343,8 +527,8 @@ end @generated function runtime_generic_rev(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, allargs...) where {ActivityTup, MB, Width, TapeType, F, DF} N = div(length(allargs)+2, Width+1)-1 - _, _, primtypes, _, _, wrapped, batchshadowargs, _ = setup_macro_wraps(false, N, Width, :allargs) - return body_runtime_generic_rev(N, Width, wrapped, primtypes, batchshadowargs) + _, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width, :allargs) + return body_runtime_generic_rev(N, Width, wrapped, primtypes, batchshadowargs, active_refs) end @inline concat() = () @@ -416,6 +600,13 @@ end end end +@inline function allSame(::Val{Width}, res) where Width + ntuple(Val(Width)) do i + Base.@_inline_meta + res + end +end + @inline function allZero(::Val{Width}, res) where Width ntuple(Val(Width)) do i Base.@_inline_meta @@ -484,7 +675,7 @@ function body_runtime_iterate_fwd(N, Width, wrapped, primtypes) end function func_runtime_iterate_fwd(N, Width) - _, _, primtypes, allargs, typeargs, wrapped, _, _ = setup_macro_wraps(true, N, Width, #=base=#nothing, #=iterate=#true) + _, _, primtypes, allargs, typeargs, wrapped, _, _, active_refs = setup_macro_wraps(true, N, Width, #=base=#nothing, #=iterate=#true) body = body_runtime_iterate_fwd(N, Width, wrapped, primtypes) quote @@ -496,7 +687,7 @@ end @generated function runtime_iterate_fwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, RT::Val{ReturnType}, f::F, df::DF, allargs...) where {ActivityTup, Width, ReturnType, F, DF} N = div(length(allargs)+2, Width+1)-1 - _, _, primtypes, _, _, wrapped, _, _ = setup_macro_wraps(true, N, Width, :allargs, #=iterate=#true) + _, _, primtypes, _, _, wrapped, _, _, active_refs = setup_macro_wraps(true, N, Width, :allargs, #=iterate=#true) return body_runtime_iterate_fwd(N, Width, wrapped, primtypes) end @@ -586,7 +777,7 @@ function augfwd_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType} if annotation <: Const shadow_return = nothing tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) - return ReturnType((allFirst(Val(width+1), origRet)..., tape)) + return ReturnType((allSame(Val(width+1), origRet)..., tape)) elseif annotation <: Active if width == 1 shadow_return = Ref(make_zero(origRet)) @@ -623,7 +814,7 @@ function body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes) end function func_runtime_iterate_augfwd(N, Width) - _, _, primtypes, allargs, typeargs, wrapped, _, modbetween = setup_macro_wraps(false, N, Width, #=base=#nothing, #=iterate=#true) + _, _, primtypes, allargs, typeargs, wrapped, _, modbetween, active_refs = setup_macro_wraps(false, N, Width, #=base=#nothing, #=iterate=#true) body = body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes) quote @@ -635,7 +826,7 @@ end @generated function runtime_iterate_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, allargs...) where {ActivityTup, MB, Width, ReturnType, F, DF} N = div(length(allargs)+2, Width+1)-1 - _, _, primtypes, _, _, wrapped, _ , modbetween, = setup_macro_wraps(false, N, Width, :allargs, #=iterate=#true) + _, _, primtypes, _, _, wrapped, _ , modbetween, active_refs = setup_macro_wraps(false, N, Width, :allargs, #=iterate=#true) return body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes) end @@ -835,7 +1026,7 @@ end @generated function runtime_iterate_rev(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, allargs...) where {ActivityTup, MB, Width, TapeType, F, DF} N = div(length(allargs)+2, Width+1)-1 - primargs, _, primtypes, _, _, wrapped, batchshadowargs, modbetween = setup_macro_wraps(false, N, Width, :allargs, #=iterate=#true) + primargs, _, primtypes, _, _, wrapped, batchshadowargs, modbetween, active_refs = setup_macro_wraps(false, N, Width, :allargs, #=iterate=#true) return body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, batchshadowargs) end @@ -849,7 +1040,7 @@ for (N, Width) in Iterators.product(0:30, 1:10) eval(func_runtime_iterate_rev(N, Width)) end -function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, lookup; sret=nothing, tape=nothing, firstconst=false) +function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, lookup; sret=nothing, tape=nothing, firstconst=false, endcast=true) width = get_width(gutils) mode = get_mode(gutils) mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) @@ -862,8 +1053,6 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, ActivityList = LLVM.Value[] - to_preserve = LLVM.Value[] - @assert length(ops) != 0 fill_val = unsafe_to_llvm(nothing) @@ -918,9 +1107,6 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, else ev = extract_value!(B, inverted, w-1) end - if tape !== nothing - push!(to_preserve, ev) - end end push!(vals, ev) @@ -929,7 +1115,13 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, @assert length(ActivityList) == length(ops) if tape !== nothing - pushfirst!(vals, tape) + if tape isa Vector + for t in reverse(tape) + pushfirst!(vals, t) + end + else + pushfirst!(vals, tape) + end else pushfirst!(vals, unsafe_to_llvm(Val(ReturnType))) end @@ -975,7 +1167,7 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, debug_from_orig!(gutils, cal, orig) - if tape === nothing + if tape === nothing && endcast llty = convert(LLVMType, ReturnType) cal = LLVM.addrspacecast!(B, cal, LLVM.PointerType(T_jlvalue, Derived)) cal = LLVM.pointercast!(B, cal, LLVM.PointerType(llty, Derived)) diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index 1ee4f0d961..101796401f 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -1,8 +1,229 @@ +function body_construct_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs, tuple) + shadow_rets = Vector{Expr}[] + results = quote + $(active_refs...) + end + @assert length(primtypes) == N + @assert length(primargs) == N + @assert length(batchshadowargs) == N + for i in 1:N + @assert length(batchshadowargs[i]) == Width + shadow_rets_i = Expr[] + aref = Symbol("active_ref_$i") + for w in 1:Width + sref = Symbol("shadow_"*string(i)*"_"*string(w)) + push!(shadow_rets_i, quote + $sref = if $aref == AnyState + $(primargs[i]); + else + if !ActivityTup[$i] + if $aref == DupState || $aref == MixedState + prim = $(primargs[i]) + throw("Error cannot store inactive but differentiable variable $prim into active tuple") + end + end + if $aref == DupState + $(batchshadowargs[i][w]) + else + $(batchshadowargs[i][w])[] + end + end + end) + end + push!(shadow_rets, shadow_rets_i) + end -function common_newstructv_fwd(offset, B, orig, gutils, normalR, shadowR) - if is_constant_value(gutils, orig) || unsafe_load(shadowR) == C_NULL - return true + refs = Expr[] + ref_syms = Symbol[] + res_syms = Symbol[] + for w in 1:Width + sres = Symbol("result_$w") + ref_res = Symbol("ref_result_$w") + combined = Expr[] + for i in 1:N + push!(combined, shadow_rets[i][w]) + end + if tuple + results = quote + $results + $sres = ($(combined...),) + end + else + results = quote + $results + $sres = $(Expr(:new, :NewType, combined...)) + end + end + push!(refs, quote + $ref_res = Ref($sres) + end) + push!(ref_syms, ref_res) + push!(res_syms, sres) end + + if Width == 1 + return quote + $results + if any_mixed + $(refs...) + $(ref_syms[1]) + else + $(res_syms[1]) + end + end + else + return quote + $results + if any_mixed + $(refs...) + ReturnType(($(ref_syms...),)) + else + ReturnType(($(res_syms...),)) + end + end + end +end + + +function body_construct_rev(N, Width, primtypes, active_refs, primargs, batchshadowargs, tuple) + outs = [] + for i in 1:N + for w in 1:Width + tsym = Symbol("tval_$w") + expr = if tuple + :($tsym[$i]) + else + :(getfield($tsym, $i)) + end + shad = batchshadowargs[i][w] + out = :(if $(Symbol("active_ref_$i")) == MixedState || $(Symbol("active_ref_$i")) == ActiveState + if $shad isa Base.RefValue + $shad[] = recursive_add($shad[], $expr) + else + error("Enzyme Mutability Error: Cannot add one in place to immutable value "*string($shad)) + end + end + ) + push!(outs, out) + end + end + + tapes = Expr[:(tval_1 = tape[])] + for w in 2:Width + sym = Symbol("tval_$w") + df = Symbol("df_$w") + push!(tapes, :($sym = $df[])) + end + + quote + $(active_refs...) + + if any_mixed + $(tapes...) + $(outs...) + end + return nothing + end +end + + +function body_runtime_tuple_rev(N, Width, primtypes, active_refs, primargs, batchshadowargs) + body_construct_rev(N, Width, primtypes, active_refs, primargs, batchshadowargs, true) +end + +function body_runtime_newstruct_rev(N, Width, primtypes, active_refs, primargs, batchshadowargs) + body_construct_rev(N, Width, primtypes, active_refs, primargs, batchshadowargs, false) +end + + +function body_runtime_tuple_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs) + body_construct_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs, true) +end + +function func_runtime_tuple_augfwd(N, Width) + primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width; func=false, mixed_or_active=true) + body = body_runtime_tuple_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs) + + quote + function runtime_tuple_augfwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, $(allargs...))::ReturnType where {ActivityTup, MB, ReturnType, $(typeargs...)} + $body + end + end +end + +@generated function runtime_tuple_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, allargs...)::ReturnType where {ActivityTup, MB, Width, ReturnType} + N = div(length(allargs), Width) + primargs, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width, :allargs; func=false, mixed_or_active=true) + return body_runtime_tuple_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs) +end + + +function func_runtime_tuple_rev(N, Width) + primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width; mixed_or_active=true) + body = body_runtime_tuple_rev(N, Width, primtypes, active_refs, primargs, batchshadowargs) + + quote + function runtime_tuple_rev(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, tape::TapeType, $(allargs...)) where {ActivityTup, MB, TapeType, $(typeargs...)} + $body + end + end +end + +@generated function runtime_tuple_rev(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, tape::TapeType, allargs...) where {ActivityTup, MB, Width, TapeType} + N = div(length(allargs)-(Width-1), Width) + primargs, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width, :allargs; mixed_or_active=true) + return body_runtime_tuple_rev(N, Width, primtypes, active_refs, primargs, batchshadowargs) +end + + +function body_runtime_newstruct_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs) + body_construct_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs, false) +end + +function func_runtime_newstruct_augfwd(N, Width) + primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width) + body = body_runtime_newstruct_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs) + + quote + function runtime_newstruct_augfwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, ::Type{NewType}, $(allargs...))::ReturnType where {ActivityTup, MB, ReturnType, NewType, $(typeargs...)} + $body + end + end +end + +@generated function runtime_newstruct_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, ::Type{NewType}, allargs...)::ReturnType where {ActivityTup, MB, Width, ReturnType, NewType} + N = div(length(allargs)+2, Width+1)-1 + primargs, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width, :allargs) + return body_runtime_newstruct_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs) +end + +function func_runtime_newstruct_rev(N, Width) + primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width; mixed_or_active=true) + body = body_runtime_newstruct_rev(N, Width, primtypes, active_refs, primargs, batchshadowargs) + + quote + function runtime_newstruct_rev(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, ::Type{NewStruct}, tape::TapeType, $(allargs...)) where {ActivityTup, MB, NewStruct, TapeType, $(typeargs...)} + $body + end + end +end + +@generated function runtime_newstruct_rev(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, ::Type{NewStruct}, tape::TapeType, allargs...) where {ActivityTup, MB, Width, NewStruct, TapeType} + N = div(length(allargs)-(Width-1), Width) + primargs, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width, :allargs; mixed_or_active=true) + return body_runtime_newstruct_rev(N, Width, primtypes, active_refs, primargs, batchshadowargs) +end + +for (N, Width) in Iterators.product(0:30, 1:10) + eval(func_runtime_newstruct_augfwd(N, Width)) + eval(func_runtime_newstruct_rev(N, Width)) + eval(func_runtime_tuple_augfwd(N, Width)) + eval(func_runtime_tuple_rev(N, Width)) +end + + +# returns if legal and completed +function newstruct_common(fwd, run, offset, B, orig, gutils, normalR, shadowR) origops = collect(operands(orig)) width = get_width(gutils) @@ -10,34 +231,35 @@ function common_newstructv_fwd(offset, B, orig, gutils, normalR, shadowR) @assert is_constant_value(gutils, origops[offset]) icvs = [is_constant_value(gutils, v) for v in origops[offset+1:end-1]] - abs = [abs_typeof(v, true) for v in origops[offset+1:end-1]] + abs_partial = [abs_typeof(v, true) for v in origops[offset+1:end-1]] + abs = [abs_typeof(v) for v in origops[offset+1:end-1]] - legal = true - for (icv, (found, typ)) in zip(icvs, abs) + @assert length(icvs) == length(abs) + for (icv, (found_partial, typ_partial), (found, typ)) in zip(icvs, abs_partial, abs) + # Constants not handled unless known inactive from type if icv - if found - if guaranteed_const_nongen(typ, world) - continue - end + if !found_partial + return false + end + if !guaranteed_const_nongen(typ_partial, world) + return false + end + end + # if any active [e.g. ActiveState / MixedState] data could exist + # err + if !fwd + if !found + return false + end + act = active_reg_inner(typ, (), world) + if act == MixedState || act == ActiveState + return false end - legal = false end end - # if all(icvs) - # shadowres = new_from_original(gutils, orig) - # if width != 1 - # shadowres2 = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(shadowres)))) - # for idx in 1:width - # shadowres2 = insert_value!(B, shadowres2, shadowres, idx-1) - # end - # shadowres = shadowres2 - # end - # unsafe_store!(shadowR, shadowres.ref) - # return false - # end - if !legal - emit_error(B, orig, "Enzyme: Not yet implemented, mixed activity for jl_new_struct constants="*string(icvs)*" "*string(orig)*" "*string(abs)*" "*string([v for v in origops[offset+1:end-1]])) + if !run + return true end shadowsin = LLVM.Value[invert_pointer(gutils, o, B) for o in origops[offset:end-1] ] @@ -62,19 +284,72 @@ function common_newstructv_fwd(offset, B, orig, gutils, normalR, shadowR) end end unsafe_store!(shadowR, shadowres.ref) + return true +end + + +function common_newstructv_fwd(offset, B, orig, gutils, normalR, shadowR) + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + return true + end + + if !newstruct_common(#=fwd=#true, #=run=#true, offset, B, orig, gutils, normalR, shadowR) + abs_partial = [abs_typeof(v, true) for v in origops[offset+1:end-1]] + origops = collect(operands(orig)) + emit_error(B, orig, "Enzyme: Not yet implemented, mixed activity for jl_new_struct constants="*string(icvs)*" "*string(orig)*" "*string(abs)*" "*string([v for v in origops[offset+1:end-1]])) + end + return false end + function common_newstructv_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) - common_newstructv_fwd(offset, B, orig, gutils, normalR, shadowR) -end + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) -function error_if_active_newstruct(::Type{T}, ::Type{Y}) where {T, Y} - seen = () - areg = active_reg_inner(T, seen, nothing, #=justActive=#Val(true)) - if areg == ActiveState - throw(AssertionError("Found unhandled active variable ($T) in reverse mode of jl_newstruct constructor for $Y")) + if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + return true end - nothing + + if !newstruct_common(#=fwd=#false, #=run=#true, offset, B, orig, gutils, normalR, shadowR) + normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + shadow = (unsafe_load(shadowR) != C_NULL) ? LLVM.Instruction(unsafe_load(shadowR)) : nothing + + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + + + width = get_width(gutils) + + sret = generic_setup(orig, runtime_newstruct_augfwd, width == 1 ? Any : AnyArray(Int(width)), gutils, #=start=#offset, B, false; firstconst=true, endcast = false) + + if width == 1 + shadow = sret + else + AT = LLVM.ArrayType(T_prjlvalue, Int(width)) + llty = convert(LLVMType, AnyArray(Int(width))) + cal = sret + cal = LLVM.addrspacecast!(B, cal, LLVM.PointerType(T_jlvalue, Derived)) + cal = LLVM.pointercast!(B, cal, LLVM.PointerType(llty, Derived)) + ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) + shadow = LLVM.UndefValue(ST) + for i in 1:width + gep = LLVM.inbounds_gep!(B, AT, cal, [LLVM.ConstantInt(0), LLVM.ConstantInt(i-1)]) + ld = LLVM.load!(B, T_prjlvalue, gep) + shadow = insert_value!(B, shadow, ld, i-1) + end + end + unsafe_store!(shadowR, shadow.ref) + + unsafe_store!(tapeR, sret.ref) + return false + end + + return false end function common_newstructv_rev(offset, B, orig, gutils, tape) @@ -90,20 +365,11 @@ function common_newstructv_rev(offset, B, orig, gutils, tape) if !needsShadow return end - - origops = collect(operands(orig)) - width = get_width(gutils) - - world = enzyme_extract_world(LLVM.parent(position(B))) - @assert is_constant_value(gutils, origops[offset]) - icvs = [is_constant_value(gutils, v) for v in origops[offset+1:end-1]] - abs = [abs_typeof(v, true) for v in origops[offset+1:end-1]] - - - ty = lookup_value(gutils, new_from_original(gutils, origops[offset]), B) - for v in origops[offset+1:end-1] - emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(error_if_active_newstruct), emit_jltypeof!(B, lookup_value(gutils, new_from_original(gutils, v), B)), ty]) + if !newstruct_common(#=fwd=#false, #=run=#false, offset, B, orig, gutils, #=normalR=#nothing, #=shadowR=#nothing) + @assert tape !== C_NULL + width = get_width(gutils) + generic_setup(orig, runtime_newstruct_rev, Nothing, gutils, #=start=#offset, B, true; firstconst=true, tape) end return nothing @@ -112,13 +378,94 @@ end function common_f_tuple_fwd(offset, B, orig, gutils, normalR, shadowR) common_newstructv_fwd(offset, B, orig, gutils, normalR, shadowR) end + function common_f_tuple_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) - common_f_tuple_fwd(offset, B, orig, gutils, normalR, shadowR) + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + + if is_constant_value(gutils, orig) || needsShadowP[] == 0 + return true + end + + if !newstruct_common(#=fwd=#false, #=run=#true, offset, B, orig, gutils, normalR, shadowR) + normal = (unsafe_load(normalR) != C_NULL) ? LLVM.Instruction(unsafe_load(normalR)) : nothing + shadow = (unsafe_load(shadowR) != C_NULL) ? LLVM.Instruction(unsafe_load(shadowR)) : nothing + + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + + + width = get_width(gutils) + + sret = generic_setup(orig, runtime_tuple_augfwd, width == 1 ? Any : AnyArray(Int(width)), gutils, #=start=#offset+1, B, false; endcast = false) + + if width == 1 + shadow = sret + else + AT = LLVM.ArrayType(T_prjlvalue, Int(width)) + llty = convert(LLVMType, AnyArray(Int(width))) + cal = sret + cal = LLVM.addrspacecast!(B, cal, LLVM.PointerType(T_jlvalue, Derived)) + cal = LLVM.pointercast!(B, cal, LLVM.PointerType(llty, Derived)) + ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) + shadow = LLVM.UndefValue(ST) + for i in 1:width + gep = LLVM.inbounds_gep!(B, AT, cal, [LLVM.ConstantInt(0), LLVM.ConstantInt(i-1)]) + ld = LLVM.load!(B, T_prjlvalue, gep) + shadow = insert_value!(B, shadow, ld, i-1) + end + end + unsafe_store!(shadowR, shadow.ref) + + unsafe_store!(tapeR, sret.ref) + + return false + end end function common_f_tuple_rev(offset, B, orig, gutils, tape) - # This function allocates a new return which returns a pointer, thus this instruction itself cannot transfer - # derivative info, only create a shadow pointer, which is handled by the forward pass. + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, API.DEM_ReverseModePrimal) + needsPrimal = needsPrimalP[] != 0 + needsShadow = needsShadowP[] != 0 + + if !needsShadow + return + end + + if is_constant_value(gutils, orig) + return true + end + + if !newstruct_common(#=fwd=#false, #=run=#false, offset, B, orig, gutils, #=normalR=#nothing, #=shadowR=#nothing) + @assert tape !== C_NULL + width = get_width(gutils) + tape2 = if width != 1 + res = LLVM.Value[] + + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + + AT = LLVM.ArrayType(T_prjlvalue, Int(width)) + llty = convert(LLVMType, AnyArray(Int(width))) + cal = tape + cal = LLVM.addrspacecast!(B, cal, LLVM.PointerType(T_jlvalue, Derived)) + cal = LLVM.pointercast!(B, cal, LLVM.PointerType(llty, Derived)) + ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) + + for i in 1:width + gep = LLVM.inbounds_gep!(B, AT, cal, [LLVM.ConstantInt(0), LLVM.ConstantInt(i-1)]) + ld = LLVM.load!(B, T_prjlvalue, gep) + push!(res, ld) + end + res + else + tape + end + generic_setup(orig, runtime_tuple_rev, Nothing, gutils, #=start=#offset+1, B, true; tape=tape2) + end return nothing end diff --git a/test/applyiter.jl b/test/applyiter.jl index 2518e2d829..b1a26e5f54 100644 --- a/test/applyiter.jl +++ b/test/applyiter.jl @@ -89,6 +89,20 @@ function tupapprox(a, b) return a ≈ b end +@testset "Const Apply iterate" begin + function extiter() + vals = Any[3,] + extracted = Tuple(vals) + return extracted + end + + fwd, rev = Enzyme.autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(extiter)}, Duplicated) + + tape, res, dres = fwd(Const(extiter)) + @test res == (3,) + @test dres == (3,) +end + @testset "Reverse Apply iterate" begin x = [(2.0, 3.0), (7.9, 11.2)] dx = [(0.0, 0.0), (0.0, 0.0)] diff --git a/test/mixed.jl b/test/mixed.jl new file mode 100644 index 0000000000..dae0623073 --- /dev/null +++ b/test/mixed.jl @@ -0,0 +1,71 @@ +using Enzyme, Test + +@noinline function mixedmul(tup::T) where T + return tup[1] * tup[2][1] +end + +function outmixedmul(x::Float64) + vec = [x] + tup = (x, vec) + Base.inferencebarrier(mixedmul)(tup)::Float64 +end + +function outmixedmul2(res, x::Float64) + vec = [x] + tup = (x, vec) + res[] = Base.inferencebarrier(mixedmul)(tup)::Float64 +end + +@testset "Basic Mixed Activity" begin + @test 6.2 ≈ Enzyme.autodiff(Reverse, outmixedmul, Active, Active(3.1))[1][1] +end + +@testset "Byref Mixed Activity" begin + res = Ref(4.7) + dres = Ref(1.0) + @test 6.2 ≈ Enzyme.autodiff(Reverse, outmixedmul2, Const, Duplicated(res, dres), Active(3.1))[1][2] +end + +@static if VERSION >= v"1.8-" +@testset "Batched Byref Mixed Activity" begin + res = Ref(4.7) + dres = Ref(1.0) + dres2 = Ref(3.0) + sig = Enzyme.autodiff(Reverse, outmixedmul2, Const, BatchDuplicated(res, (dres, dres2)), Active(3.1)) + @test 6.2 ≈ sig[1][2][1] + @test 3*6.2 ≈ sig[1][2][2] +end +end + +function tupmixedmul(x::Float64) + vec = [x] + tup = (x, Base.inferencebarrier(vec)) + Base.inferencebarrier(mixedmul)(tup)::Float64 +end + +@testset "Tuple Mixed Activity" begin + @test 6.2 ≈ Enzyme.autodiff(Reverse, tupmixedmul, Active, Active(3.1))[1][1] +end + +function outtupmixedmul(res, x::Float64) + vec = [x] + tup = (x, Base.inferencebarrier(vec)) + res[] = Base.inferencebarrier(mixedmul)(tup)::Float64 +end + +@testset "Byref Tuple Mixed Activity" begin + res = Ref(4.7) + dres = Ref(1.0) + @test 6.2 ≈ Enzyme.autodiff(Reverse, outtupmixedmul, Const, Duplicated(res, dres), Active(3.1))[1][2] +end + +@static if VERSION >= v"1.8-" +@testset "Batched Byref Tuple Mixed Activity" begin + res = Ref(4.7) + dres = Ref(1.0) + dres2 = Ref(3.0) + sig = Enzyme.autodiff(Reverse, outtupmixedmul, Const, BatchDuplicated(res, (dres, dres2)), Active(3.1)) + @test 6.2 ≈ sig[1][2][1] + @test 3*6.2 ≈ sig[1][2][2] +end +end diff --git a/test/runtests.jl b/test/runtests.jl index e931666f90..ca05883c13 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1701,6 +1701,7 @@ end @test dx2[1][2] ≈ 0.0 end +include("mixed.jl") include("applyiter.jl") @testset "Dynamic Val Construction" begin diff --git a/test/threads.jl b/test/threads.jl index 5fe80916d3..6899d8d2d6 100644 --- a/test/threads.jl +++ b/test/threads.jl @@ -74,7 +74,8 @@ end out = [1.0, 2.0] dout = [1.0, 1.0] @static if VERSION < v"1.8" - @test_throws AssertionError autodiff(Reverse, f_multi, Const, Duplicated(out, dout), Active(2.0)) + # GPUCompiler causes a stack overflow due to https://github.com/JuliaGPU/GPUCompiler.jl/issues/587 + # @test_throws AssertionError autodiff(Reverse, f_multi, Const, Duplicated(out, dout), Active(2.0)) else res = autodiff(Reverse, f_multi, Const, Duplicated(out, dout), Active(2.0)) @test res[1][2] ≈ 2.0