From b0cd3ed9985780177060be70f3e1c28311eff37a Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Thu, 13 Jun 2024 08:21:19 -0700 Subject: [PATCH 01/11] Mixed activity for getfield --- src/rules/typeunstablerules.jl | 67 ++++++++++++++++++++++++++-------- 1 file changed, 51 insertions(+), 16 deletions(-) diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index 36f2798c0c..d22baed6f2 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -603,7 +603,9 @@ function rt_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isco Base.getfield(dptr, symname) end RT = Core.Typeof(res) - if active_reg(RT) + + actreg = active_reg_nothrow(RT, Val(nothing)) + if actreg == ActiveState if length(dptrs) == 0 return Ref{RT}(make_zero(res)) else @@ -612,6 +614,17 @@ function rt_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isco Ref{RT}(make_zero(res)) end) end + elseif actreg == MixedState + if length(dptrs) == 0 + return Ref{RT}(res) + else + fval = NT((res, (ntuple(Val(length(dptrs))) do i + Base.@_inline_meta + dv = dptrs[i] + Ref{RT}(getfield(dv isa Base.RefValue ? dv[] : dv, symname)) + end)...)) + return fval + end else if length(dptrs) == 0 return res @@ -633,8 +646,8 @@ function idx_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isc Base.getfield(dptr, symname+1) end RT = Core.Typeof(res) - actreg = active_reg(RT) - if actreg + actreg = active_reg_nothrow(RT, Val(nothing)) + if actreg == ActiveState if length(dptrs) == 0 return Ref{RT}(make_zero(res))::Any else @@ -643,6 +656,17 @@ function idx_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isc Ref{RT}(make_zero(res)) end) end + elseif actreg == MixedState + if length(dptrs) == 0 + return Ref{RT}(res)::Any + else + fval = NT((res, (ntuple(Val(length(dptrs))) do i + Base.@_inline_meta + dv = dptrs[i] + Ref{RT}(getfield(dv isa Base.RefValue ? dv[] : dv, symname+1)) + end)...)) + return fval + end else if length(dptrs) == 0 return res::Any @@ -657,6 +681,13 @@ function idx_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isc end end +# check if a value is guaranteed to be not contain active[register] data +# (aka not either mixed or active) +@inline function guaranteed_nonactive(::Type{T}) where T + rt = Enzyme.Compiler.active_reg_nothrow(T, Val(nothing)) + return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState +end + function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, dptrs::Vararg{T2, Nargs}) where {T, T2, Nargs, symname, isconst} cur = if dptr isa Base.RefValue getfield(dptr[], symname) @@ -665,7 +696,9 @@ function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, end RT = Core.Typeof(cur) - if active_reg(RT) && !isconst + + actreg = active_reg_nothrow(RT, Val(nothing)) + if (actreg == ActiveState || actreg == MixedState) && !isconst if length(dptrs) == 0 if dptr isa Base.RefValue vload = dptr[] @@ -674,13 +707,13 @@ function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, Base.@_inline_meta prev = getfield(vload, i) if fieldname(dRT, i) == symname - recursive_add(prev, dret[]) + recursive_add(prev, dret[], identity, guaranteed_nonactive) else prev end end) else - setfield!(dptr, symname, recursive_add(cur, dret[])) + setfield!(dptr, symname, recursive_add(cur, dret[], identity, guaranteed_nonactive)) end else if dptr isa Base.RefValue @@ -690,7 +723,7 @@ function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, Base.@_inline_meta prev = getfield(vload, j) if fieldname(dRT, j) == symname - recursive_add(prev, dret[1][]) + recursive_add(prev, dret[1][], identity, guaranteed_nonactive) else prev end @@ -706,7 +739,7 @@ function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, Base.@_inline_meta prev = getfield(vload, j) if fieldname(dRT, j) == symname - recursive_add(prev, dret[1+i][]) + recursive_add(prev, dret[1+i][], identity, guaranteed_nonactive) else prev end @@ -717,7 +750,7 @@ function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, else Base.getfield(dptrs[i], symname) end - setfield!(dptrs[i], symname, recursive_add(curi, dret[1+i][])) + setfield!(dptrs[i], symname, recursive_add(curi, dret[1+i][]), identity, guaranteed_nonactive) end end end @@ -733,7 +766,9 @@ function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst} end RT = Core.Typeof(cur) - if active_reg(RT) && !isconst + + actreg = active_reg_nothrow(RT, Val(nothing)) + if (actreg == ActiveState || actreg == MixedState) && !isconst if length(dptrs) == 0 if dptr isa Base.RefValue vload = dptr[] @@ -742,13 +777,13 @@ function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst} Base.@_inline_meta prev = getfield(vload, i) if i == symname+1 - recursive_add(prev, dret[]) + recursive_add(prev, dret[], identity, guaranteed_nonactive) else prev end end) else - setfield!(dptr, symname+1, recursive_add(cur, dret[])) + setfield!(dptr, symname+1, recursive_add(cur, dret[]), identity, guaranteed_nonactive) end else if dptr isa Base.RefValue @@ -758,13 +793,13 @@ function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst} Base.@_inline_meta prev = getfield(vload, j) if j == symname+1 - recursive_add(prev, dret[1][]) + recursive_add(prev, dret[1][], identity, guaranteed_nonactive) else prev end end) else - setfield!(dptr, symname+1, recursive_add(cur, dret[1][])) + setfield!(dptr, symname+1, recursive_add(cur, dret[1][], identity, guaranteed_nonactive)) end for i in 1:length(dptrs) if dptrs[i] isa Base.RefValue @@ -774,7 +809,7 @@ function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst} Base.@_inline_meta prev = getfield(vload, j) if j == symname+1 - recursive_add(prev, dret[1+i][]) + recursive_add(prev, dret[1+i][], identity, guaranteed_nonactive) else prev end @@ -785,7 +820,7 @@ function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst} else Base.getfield(dptrs[i], symname+1) end - setfield!(dptrs[i], symname+1, recursive_add(curi, dret[1+i][])) + setfield!(dptrs[i], symname+1, recursive_add(curi, dret[1+i][], identity, guaranteed_nonactive)) end end end From 283e3ccf2475a763da147bda4ed8587831de6ec3 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Thu, 13 Jun 2024 08:21:36 -0700 Subject: [PATCH 02/11] bump ver --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 66a06e1714..6cc31394ab 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.13" +version = "0.12.14" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" From 09fc316b46ffff9168acad7f5c74bfa5e1dbb529 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Thu, 13 Jun 2024 09:30:57 -0700 Subject: [PATCH 03/11] fixup runtime iterate for mixed --- src/rules/jitrules.jl | 133 ++++++++++++++++++++++++++++++------------ 1 file changed, 96 insertions(+), 37 deletions(-) diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index af8f83b80e..74b5a7e12b 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -76,23 +76,51 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, $aref = active_reg_nothrow($(primtypes[i]), Val(nothing)); end) expr = if iterate - :( - if ActivityTup[$i+1] && !guaranteed_const($(primtypes[i])) - @assert $(primtypes[i]) !== DataType - if !$forwardMode && active_reg($(primtypes[i])) - iterate_unwrap_augfwd_act($(primargs[i])...) - else - $((Width == 1) ? quote - iterate_unwrap_augfwd_dup(Val($forwardMode), $(primargs[i]), $(shadowargs[i])) - end : quote - iterate_unwrap_augfwd_batchdup(Val($forwardMode), Val($Width), $(primargs[i]), $(shadowargs[i])) + if forwardMode + dupexpr = if Width == 1 + quote + iterate_unwrap_fwd_dup($(primargs[i]), $(shadowargs[i])) + end + else + quote + iterate_unwrap_fwd_batchdup(Val($Width), $(primargs[i]), $(shadowargs[i])) + end + end + :( + if ActivityTup[$i+1] && !guaranteed_const($(primtypes[i])) + @assert $(primtypes[i]) !== DataType + $dupexpr + else + map(Const, $(primargs[i])) + end + ) + else + dupexpr = if Width == 1 + quote + iterate_unwrap_augfwd_dup($(primargs[i]), $(shadowargs[i])) end - ) - end - else - map(Const, $(primargs[i])) - end - ) + else + quote + iterate_unwrap_augfwd_batchdup(Val($Width), $(primargs[i]), $(shadowargs[i])) + end + end + :( + if ActivityTup[$i+1] && !guaranteed_const($(primtypes[i])) + @assert $(primtypes[i]) !== DataType + if $aref == ActiveState + Active($(primargs[i])) + iterate_unwrap_augfwd_act($(primargs[i])...) + elseif $aref == MixedState + T = $(primtypes[i]) + throw(AssertionError("Mixed State of type $T is unsupported in apply iterate")) + else + $dupexpr + end + else + map(Const, $(primargs[i])) + end + ) + end else if forwardMode quote @@ -131,16 +159,6 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, any_mixed = :($any_mixed || $aref == MixedState) end end - - if mixed_or_active - push!(active_refs, quote - active_refs = (false, $(collect(:($(Symbol("active_ref_$i")) == MixedState || $(Symbol("active_ref_$i")) == ActiveState) for i in 1:N)...)) - end) - else - push!(active_refs, quote - active_refs = (false, $(collect(:($(Symbol("active_ref_$i")) == MixedState) for i in 1:N)...)) - end) - end push!(active_refs, quote any_mixed = $any_mixed end) @@ -493,30 +511,69 @@ end end end -@inline function iterate_unwrap_augfwd_dup(::Val{forwardMode}, args, dargs) where forwardMode +@inline function iterate_unwrap_fwd_dup(args, dargs) ntuple(Val(length(args))) do i Base.@_inline_meta arg = args[i] ty = Core.Typeof(arg) if guaranteed_const(ty) Const(arg) - elseif !forwardMode && active_reg(ty) - Active(arg) else Duplicated(arg, dargs[i]) end end end -@inline function iterate_unwrap_augfwd_batchdup(::Val{forwardMode}, ::Val{Width}, args, dargs) where {forwardMode, Width} + +@inline function iterate_unwrap_fwd_batchdup(::Val{Width}, args, dargs) where {Width} ntuple(Val(length(args))) do i Base.@_inline_meta arg = args[i] ty = Core.Typeof(arg) if guaranteed_const(ty) Const(arg) - elseif !forwardMode && active_reg(ty) + else + BatchDuplicated(arg, ntuple(Val(Width)) do j + Base.@_inline_meta + dargs[j][i] + end) + end + end +end + +@inline function iterate_unwrap_augfwd_dup(args, dargs) + ntuple(Val(length(args))) do i + Base.@_inline_meta + arg = args[i] + ty = Core.Typeof(arg) + actreg = active_reg_nothrow(ty, Val(nothing)) + if actreg == AnyState + Const(arg) + elseif actreg == ActiveState + Active(arg) + elseif actreg == MixedState + MixedDuplicated(arg, dargs[i]) + else + Duplicated(arg, dargs[i]) + end + end +end + +@inline function iterate_unwrap_augfwd_batchdup(::Val{Width}, args, dargs) where {Width} + ntuple(Val(length(args))) do i + Base.@_inline_meta + arg = args[i] + ty = Core.Typeof(arg) + actreg = active_reg_nothrow(ty, Val(nothing)) + if actreg == AnyState + Const(arg) + elseif actreg == ActiveState Active(arg) + elseif actreg == MixedState + BatchMixedDuplicated(arg, ntuple(Val(Width)) do j + Base.@_inline_meta + dargs[j][i] + end) else BatchDuplicated(arg, ntuple(Val(Width)) do j Base.@_inline_meta @@ -597,9 +654,10 @@ function fwddiff_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType end end -function body_runtime_iterate_fwd(N, Width, wrapped, primtypes) +function body_runtime_iterate_fwd(N, Width, wrapped, primtypes, active_refs) wrappedexexpand = ntuple(i->:($(wrapped[i])...), Val(N)) return quote + $(active_refs...) args = ($(wrappedexexpand...),) tt′ = Enzyme.vaTypeof(args...) FT = Core.Typeof(f) @@ -609,7 +667,7 @@ end function func_runtime_iterate_fwd(N, Width) _, _, primtypes, allargs, typeargs, wrapped, _, _, active_refs = setup_macro_wraps(true, N, Width, #=base=#nothing, #=iterate=#true) - body = body_runtime_iterate_fwd(N, Width, wrapped, primtypes) + body = body_runtime_iterate_fwd(N, Width, wrapped, primtypes, active_refs) quote function runtime_iterate_fwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...)) where {ActivityTup, ReturnType, F, DF, $(typeargs...)} @@ -621,7 +679,7 @@ end @generated function runtime_iterate_fwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, RT::Val{ReturnType}, f::F, df::DF, allargs...) where {ActivityTup, Width, ReturnType, F, DF} N = div(length(allargs)+2, Width+1)-1 _, _, primtypes, _, _, wrapped, _, _, active_refs = setup_macro_wraps(true, N, Width, :allargs, #=iterate=#true) - return body_runtime_iterate_fwd(N, Width, wrapped, primtypes) + return body_runtime_iterate_fwd(N, Width, wrapped, primtypes, active_refs) end function primal_tuple(args::Vararg{Annotation, Nargs}) where Nargs @@ -736,9 +794,10 @@ function augfwd_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType} end end -function body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes) +function body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes, active_refs) wrappedexexpand = ntuple(i->:($(wrapped[i])...), Val(N)) return quote + $(active_refs...) args = ($(wrappedexexpand...),) tt′ = Enzyme.vaTypeof(args...) FT = Core.Typeof(f) @@ -748,7 +807,7 @@ end function func_runtime_iterate_augfwd(N, Width) _, _, primtypes, allargs, typeargs, wrapped, _, modbetween, active_refs = setup_macro_wraps(false, N, Width, #=base=#nothing, #=iterate=#true) - body = body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes) + body = body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes, active_refs) quote function runtime_iterate_augfwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...)) where {ActivityTup, MB, ReturnType, F, DF, $(typeargs...)} @@ -760,7 +819,7 @@ end @generated function runtime_iterate_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, allargs...) where {ActivityTup, MB, Width, ReturnType, F, DF} N = div(length(allargs)+2, Width+1)-1 _, _, primtypes, _, _, wrapped, _ , modbetween, active_refs = setup_macro_wraps(false, N, Width, :allargs, #=iterate=#true) - return body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes) + return body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes, active_refs) end From 6924d59c3eede73e96947aead984f72d47ea033b Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Thu, 13 Jun 2024 09:59:25 -0700 Subject: [PATCH 04/11] fix iter --- src/rules/jitrules.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index 74b5a7e12b..5073ec8a52 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -950,7 +950,7 @@ function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween nothing end -function body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, shadowargs) +function body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, shadowargs, active_refs) outs = [] for i in 1:N for w in 1:Width @@ -997,6 +997,7 @@ function body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, shado push!(shadowsplat, :(($(s...),))) end quote + $(active_refs...) args = ($(wrappedexexpand...),) tt′ = Enzyme.vaTypeof(args...) FT = Core.Typeof(f) @@ -1006,8 +1007,8 @@ function body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, shado end function func_runtime_iterate_rev(N, Width) - primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, modbetween = setup_macro_wraps(false, N, Width, #=body=#nothing, #=iterate=#true) - body = body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, batchshadowargs) + primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, modbetween, active_refs = setup_macro_wraps(false, N, Width, #=body=#nothing, #=iterate=#true) + body = body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, batchshadowargs, active_refs) quote function runtime_iterate_rev(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, $(allargs...)) where {ActivityTup, MB, TapeType, F, DF, $(typeargs...)} @@ -1019,7 +1020,7 @@ end @generated function runtime_iterate_rev(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, allargs...) where {ActivityTup, MB, Width, TapeType, F, DF} N = div(length(allargs)+2, Width+1)-1 primargs, _, primtypes, _, _, wrapped, batchshadowargs, modbetween, active_refs = setup_macro_wraps(false, N, Width, :allargs, #=iterate=#true) - return body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, batchshadowargs) + return body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, batchshadowargs, active_refs) end # Create specializations From c215cf309f716a9b0c2c24effa316a736244b451 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Thu, 13 Jun 2024 20:29:15 -0400 Subject: [PATCH 05/11] mixedduplicated return --- src/Enzyme.jl | 14 ++++ src/compiler.jl | 137 ++++++++++++++++++++++++--------- src/rules/jitrules.jl | 91 ++++++++++++---------- src/rules/typeunstablerules.jl | 7 -- test/applyiter.jl | 2 + test/mixedapplyiter.jl | 95 +++++++++++++++++++++++ test/usermixed.jl | 116 ++++++++++++++++++++++++++++ 7 files changed, 380 insertions(+), 82 deletions(-) create mode 100644 test/mixedapplyiter.jl diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 87b8e249e9..de694b04f3 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -64,6 +64,10 @@ end arg = @inbounds args[i] if arg isa Active return true + elseif arg isa MixedDuplicated + return true + elseif arg isa BatchMixedDuplicated + return true else return false end @@ -95,6 +99,10 @@ end end @inline same_or_one_rec(current) = current +@inline same_or_one_rec(current, arg::BatchMixedDuplicated{T, N}, args...) where {T,N} = + same_or_one_rec(same_or_one_helper(current, N), args...) +@inline same_or_one_rec(current, arg::Type{BatchMixedDuplicated{T, N}}, args...) where {T,N} = + same_or_one_rec(same_or_one_helper(current, N), args...) @inline same_or_one_rec(current, arg::BatchDuplicatedFunc{T, N}, args...) where {T,N} = same_or_one_rec(same_or_one_helper(current, N), args...) @inline same_or_one_rec(current, arg::Type{BatchDuplicatedFunc{T, N}}, args...) where {T,N} = @@ -844,6 +852,12 @@ result, ∂v, ∂A else BatchDuplicatedNoNeed{eltype(A2), width} end + elseif A2 <: MixedDuplicated && width != 1 + if A2 isa UnionAll + BatchMixedDuplicated{T, width} where T + else + BatchMixedDuplicated{eltype(A2), width} + end else A2 end diff --git a/src/compiler.jl b/src/compiler.jl index bdaacd05dd..7bbb1bbedd 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -543,6 +543,13 @@ end return res end +# check if a value is guaranteed to be not contain active[register] data +# (aka not either mixed or active) +@inline function guaranteed_nonactive(::Type{T}) where T + rt = Enzyme.Compiler.active_reg_nothrow(T, Val(nothing)) + return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState +end + @inline Enzyme.guess_activity(::Type{T}, mode::Enzyme.Mode) where {T} = guess_activity(T, convert(API.CDerivativeMode, mode)) @inline function Enzyme.guess_activity(::Type{T}, Mode::API.CDerivativeMode) where {T} @@ -555,6 +562,8 @@ end else if ActReg == ActiveState return Active{T} + elseif ActReg == MixedState + return MixedDuplicated{T} else return Duplicated{T} end @@ -2494,7 +2503,7 @@ function store_nonjl_types!(B, startval, p) return end -function get_julia_inner_types(B, p, startvals...; added=[]) +function get_julia_inner_types(B, p, startvals...; added=LLVM.API.LLVMValueRef[]) T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) vals = LLVM.Value[] @@ -2547,8 +2556,20 @@ function get_julia_inner_types(B, p, startvals...; added=[]) end continue end - GPUCompiler.@safe_warn "Enzyme illegal subtype", ty, cur, SI, p, v - @assert false + if isa(ty, LLVM.IntegerType) + continue + end + if isa(ty, LLVM.FloatingPointType) + continue + end + msg = sprint() do io + println(io, "Enzyme illegal subtype") + println(io, "ty=", ty) + println(io, "cur=", cur) + println(io, "p=", p) + println(io, "startvals=", startvals) + end + throw(AssertionError(msg)) end return vals end @@ -3474,7 +3495,11 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr # If requested, the shadow return value of the function # For each active (non duplicated) argument # The adjoint of that argument - retType = convert(API.CDIFFE_TYPE, rt) + retType = if rt <: MixedDuplicated || rt <: BatchMixedDuplicated + API.DFT_OUT_DIFF + else + convert(API.CDIFFE_TYPE, rt) + end rules = Dict{String, API.CustomRuleType}( "jl_array_copy" => @cfunction(inout_rule, @@ -3513,7 +3538,7 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr if mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient returnUsed = !(isghostty(actualRetType) || Core.Compiler.isconstType(actualRetType)) - shadowReturnUsed = returnUsed && (retType == API.DFT_DUP_ARG || retType == API.DFT_DUP_NONEED) + shadowReturnUsed = returnUsed && (retType == API.DFT_DUP_ARG || retType == API.DFT_DUP_NONEED || rt <: MixedDuplicated || rt <: BatchMixedDuplicated) returnUsed &= returnPrimal augmented = API.EnzymeCreateAugmentedPrimal( logic, primalf, retType, args_activity, TA, #=returnUsed=# returnUsed, @@ -3679,16 +3704,20 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, end # API.DFT_OUT_DIFF - if is_adjoint && rettype <: Active - @assert !sret_union - if allocatedinline(actualRetType) != allocatedinline(literal_rt) - throw(AssertionError("Base.allocatedinline(actualRetType) != Base.allocatedinline(literal_rt): actualRetType = $(actualRetType), literal_rt = $(literal_rt), rettype = $(rettype)")) - end - if !allocatedinline(actualRetType) - throw(AssertionError("Base.allocatedinline(actualRetType) returns false: actualRetType = $(actualRetType), rettype = $(rettype)")) + if is_adjoint + if rettype <: Active || rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated + @assert !sret_union + if allocatedinline(actualRetType) != allocatedinline(literal_rt) + throw(AssertionError("Base.allocatedinline(actualRetType) != Base.allocatedinline(literal_rt): actualRetType = $(actualRetType), literal_rt = $(literal_rt), rettype = $(rettype)")) + end + if rettype <: Active + if !allocatedinline(actualRetType) + throw(AssertionError("Base.allocatedinline(actualRetType) returns false: actualRetType = $(actualRetType), rettype = $(rettype)")) + end + end + dretTy = LLVM.LLVMType(API.EnzymeGetShadowType(width, convert(LLVMType, actualRetType; allow_boxed=!(rettype <: Active)))) + push!(T_wrapperargs, dretTy) end - dretTy = LLVM.LLVMType(API.EnzymeGetShadowType(width, convert(LLVMType, actualRetType))) - push!(T_wrapperargs, dretTy) end data = Array{Int64}(undef, 3) @@ -3730,6 +3759,12 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, else push!(sret_types, AnonymousStruct(NTuple{width, literal_rt})) end + elseif rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated + if width == 1 + push!(sret_types, Base.RefValue{literal_rt}) + else + push!(sret_types, AnonymousStruct(NTuple{width, Base.RefValue{literal_rt}})) + end end else @assert rettype <: Const || rettype <: Active @@ -3953,7 +3988,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, end end - if is_adjoint && rettype <: Active + if is_adjoint && (rettype <: Active || rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated) push!(realparms, params[i]) i += 1 end @@ -3999,12 +4034,26 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, if data[i] != -1 eval = extract_value!(builder, val, data[i]) end + if i == 3 + if rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated + ival = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, T_prjlvalue))) + for idx in 1:width + pv = (width == 1) ? eval : extract_value!(builder, eval, idx-1) + al0 = al = emit_allocobj!(builder, Base.RefValue{eltype(rettype)}) + llty = value_type(pv) + al = bitcast!(builder, al, LLVM.PointerType(llty, addrspace(value_type(al)))) + store!(builder, pv, al) + emit_writebarrier!(builder, get_julia_inner_types(builder, al0, pv)) + ival = (width == 1 ) ? al0 : insert_value!(builder, ival, al0, idx-1) + end + eval = ival + end + end eval = fixup_abi(i, eval) ptr = inbounds_gep!(builder, jltype, sret, [LLVM.ConstantInt(LLVM.IntType(64), 0), LLVM.ConstantInt(LLVM.IntType(32), returnNum)]) ptr = pointercast!(builder, ptr, LLVM.PointerType(value_type(eval))) si = store!(builder, eval, ptr) returnNum+=1 - if i == 3 && shadow_init shadows = LLVM.Value[] if width == 1 @@ -5943,22 +5992,28 @@ end end if !RawCall && !(CC <: PrimalErrorThunk) - if rettype <: Active + if rettype <: Active if length(argtypes) + is_adjoint + needs_tape != length(argexprs) return quote - throw(MethodError($CC(fptr), $args)) + throw(MethodError($CC(fptr), (fn, args...))) + end + end + elseif rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated + if length(argtypes) + is_adjoint * width + needs_tape != length(argexprs) + return quote + throw(MethodError($CC(fptr), (fn, args...))) end end elseif rettype <: Const if length(argtypes) + needs_tape != length(argexprs) return quote - throw(MethodError($CC(fptr), $args)) + throw(MethodError($CC(fptr), (fn, args...))) end end else if length(argtypes) + needs_tape != length(argexprs) return quote - throw(MethodError($CC(fptr), $args)) + throw(MethodError($CC(fptr), (fn, args...))) end end end @@ -5966,11 +6021,6 @@ end types = DataType[] - if eltype(rettype) === Union{} && false - return quote - error("Function to differentiate is guaranteed to return an error and doesn't make sense to autodiff. Giving up") - end - end if !(rettype <: Const) && (isghostty(eltype(rettype)) || Core.Compiler.isconstType(eltype(rettype)) || eltype(rettype) === DataType) rrt = eltype(rettype) error("Return type `$rrt` not marked Const, but is ghost or const type.") @@ -6133,17 +6183,28 @@ end end # API.DFT_OUT_DIFF - if is_adjoint && rettype <: Active - # TODO handle batch width - @assert allocatedinline(jlRT) - j_drT = if width == 1 - jlRT - else - NTuple{width, jlRT} + if is_adjoint + if rettype <: Active || rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated + # TODO handle batch width + if rettype <: Active + @assert allocatedinline(jlRT) + end + j_drT = if width == 1 + jlRT + else + NTuple{width, jlRT} + end + push!(types, j_drT) + if width == 1 || rettype <: Active + push!(ccexprs, argexprs[i]) + i+=1 + else + push!(ccexprs, quote + ($(argexprs[i:i+width-1]...),) + end) + i+=width + end end - push!(types, j_drT) - push!(ccexprs, argexprs[i]) - i+=1 end if needs_tape @@ -6181,8 +6242,12 @@ end end if rettype <: Duplicated || rettype <: DuplicatedNoNeed push!(sret_types, jlRT) + elseif rettype <: MixedDuplicated + push!(sret_types, Base.RefValue{jlRT}) elseif rettype <: BatchDuplicated || rettype <: BatchDuplicatedNoNeed push!(sret_types, AnonymousStruct(NTuple{width, jlRT})) + elseif rettype <: BatchMixedDuplicated + push!(sret_types, AnonymousStruct(NTuple{width, Base.RefValue{jlRT}})) elseif CC <: AugmentedForwardThunk push!(sret_types, Nothing) elseif rettype <: Const @@ -6406,6 +6471,8 @@ end @inline remove_innerty(::Type{<:DuplicatedNoNeed}) = DuplicatedNoNeed @inline remove_innerty(::Type{<:BatchDuplicated}) = Duplicated @inline remove_innerty(::Type{<:BatchDuplicatedNoNeed}) = DuplicatedNoNeed +@inline remove_innerty(::Type{<:MixedDuplicated}) = MixedDuplicated +@inline remove_innerty(::Type{<:BatchMixedDuplicated}) = MixedDuplicated @inline @generated function thunk(::Val{World}, ::Type{FA}, ::Type{A}, tt::Type{TT},::Val{Mode}, ::Val{width}, ::Val{ModifiedBetween}, ::Val{ReturnPrimal}, ::Val{ShadowInit}, ::Type{ABI}) where {FA<:Annotation, A<:Annotation, TT, Mode, ModifiedBetween, width, ReturnPrimal, ShadowInit, World, ABI} JuliaContext() do ctx diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index 5073ec8a52..f38f484ced 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -248,8 +248,8 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) ending = if Width == 1 quote - if active_reg_nothrow(resT, Val(nothing)) == MixedState && !(initShadow isa Base.RefValue) - shadow_return = Ref(initShadow) + if annotation <: MixedDuplicated + shadow_return = initShadow tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) return ReturnType((origRet, shadow_return, tape)) else @@ -259,23 +259,11 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) end end else - expr = :() - shads = Expr[] - for i in 1:Width - if i == 1 - expr = quote !(initShadow[$i] isa Base.RefValue) end - else - expr = quote $expr || !(initShadow[$i] isa Base.RefValue) end - end - push!(shads, quote - Ref(initShadow[$i]) - end) - end quote - if active_reg_nothrow(resT, Val(nothing)) == MixedState && ($expr) - shadow_return = ($(shads...),) + if annotation <: BatchMixedDuplicated + shadow_return = (initShadow...,) tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) - return ReturnType((origRet, shadow_return..., tape)) + return ReturnType((origRet, initShadow..., tape)) else shadow_return = nothing tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) @@ -302,6 +290,8 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) annotationA = if $Width != 1 && annotation0 <: Duplicated BatchDuplicated{rt, $Width} + elseif $Width != 1 && annotation0 <: MixedDuplicated + BatchMixedDuplicated{rt, $Width} else annotation0 end @@ -333,8 +323,6 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) end end - @assert annotation <: Duplicated || annotation <: DuplicatedNoNeed || annotation <: BatchDuplicated || annotation <: BatchDuplicatedNoNeed - $ending end end @@ -448,13 +436,7 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act annotation, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width, ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) - if tape.shadow_return !== nothing - if !(annotation0 <: Active) && nonzero_active_data(($shadowret,)) - ET = ($(ElTypes...),) - throw(AssertionError("Shadow value "*string(($shadowret,))*" returned from type unstable call to $f($(ET...)) has mixed internal activity types. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information")) - end - end - tup = if annotation0 <: Active + tup = if annotation0 <: Active || annotation0 <: MixedDuplicated || annotation0 <: BatchMixedDuplicated adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., $shadowret, tape.internal_tape)[1] else adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., tape.internal_tape)[1] @@ -552,7 +534,11 @@ end elseif actreg == ActiveState Active(arg) elseif actreg == MixedState - MixedDuplicated(arg, dargs[i]) + if dargs[i] isa Base.RefValue + MixedDuplicated(arg, dargs[i]) + else + MixedDuplicated(arg, Ref(dargs[i])) + end else Duplicated(arg, dargs[i]) end @@ -572,7 +558,11 @@ end elseif actreg == MixedState BatchMixedDuplicated(arg, ntuple(Val(Width)) do j Base.@_inline_meta - dargs[j][i] + if dargs[j][i] isa Base.RefValue + dargs[j][i] + else + Ref(dargs[j][i]) + end end) else BatchDuplicated(arg, ntuple(Val(Width)) do j @@ -727,6 +717,8 @@ function augfwd_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType} annotation = if width != 1 if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated BatchDuplicated{rt, width} + elseif annotation0 <: MixedDuplicated + BatchMixedDuplicated{rt, width} elseif annotation0 <: Active Active{rt} else @@ -735,6 +727,8 @@ function augfwd_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType} else if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated Duplicated{rt} + elseif annotation0 <: MixedDuplicated + MixedDuplicated{rt} elseif annotation0 <: Active Active{rt} else @@ -765,15 +759,16 @@ function augfwd_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType} end resT = typeof(origRet) + if annotation <: Const shadow_return = nothing tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) return ReturnType((allSame(Val(width+1), origRet)..., tape)) elseif annotation <: Active - if width == 1 - shadow_return = Ref(make_zero(origRet)) + shadow_return = if width == 1 + Ref(make_zero(origRet)) else - shadow_return = allZero(Val(width), origRet) + allZero(Val(width), origRet) end tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) if width == 1 @@ -783,14 +778,26 @@ function augfwd_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType} end end - @assert annotation <: Duplicated || annotation <: DuplicatedNoNeed || annotation <: BatchDuplicated || annotation <: BatchDuplicatedNoNeed - - shadow_return = nothing - tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) if width == 1 - return ReturnType((origRet, initShadow, tape)) + if annotation <: MixedDuplicated + shadow_return = initShadow + tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + return ReturnType((origRet, initShadow, tape)) + else + shadow_return = nothing + tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + return ReturnType((origRet, initShadow, tape)) + end else - return ReturnType((origRet, initShadow..., tape)) + if annotation <: BatchMixedDuplicated + shadow_return = initShadow + tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + return ReturnType((origRet, initShadow..., tape)) + else + shadow_return = nothing + tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + return ReturnType((origRet, initShadow..., tape)) + end end end @@ -840,6 +847,8 @@ function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween annotation = if width != 1 if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated BatchDuplicated{rt, width} + elseif annotation0 <: MixedDuplicated + BatchMixedDuplicated{rt, width} elseif annotation0 <: Active Active{rt} else @@ -848,6 +857,8 @@ function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween else if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated Duplicated{rt} + elseif annotation0 <: MixedDuplicated + MixedDuplicated{rt} elseif annotation0 <: Active Active{rt} else @@ -870,7 +881,7 @@ function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween forward, adjoint = thunk(Val(world), FA, annotation, tt′, Val(API.DEM_ReverseModePrimal), Val(width), ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) - + args2 = if tape.shadow_return !== nothing if width == 1 (args..., tape.shadow_return[]) @@ -925,7 +936,7 @@ function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween Base.@_inline_meta prev = getfield(vecld, i) if i == idx_in_vec - recursive_add(prev, expr) + recursive_add(prev, expr, identity, guaranteed_nonactive) else prev end @@ -935,7 +946,7 @@ function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween if val isa Base.RefValue val[] = recursive_add(val[], expr) elseif ismutable(vec) - @inbounds vec[idx_in_vec] = recursive_add(val, expr) + @inbounds vec[idx_in_vec] = recursive_add(val, expr, identity, guaranteed_nonactive) else error("Enzyme Mutability Error: Cannot in place to immutable value vec[$idx_in_vec] = $val, vec=$vec") end diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index d22baed6f2..db9759ec51 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -681,13 +681,6 @@ function idx_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isc end end -# check if a value is guaranteed to be not contain active[register] data -# (aka not either mixed or active) -@inline function guaranteed_nonactive(::Type{T}) where T - rt = Enzyme.Compiler.active_reg_nothrow(T, Val(nothing)) - return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState -end - function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, dptrs::Vararg{T2, Nargs}) where {T, T2, Nargs, symname, isconst} cur = if dptr isa Base.RefValue getfield(dptr[], symname) diff --git a/test/applyiter.jl b/test/applyiter.jl index b1a26e5f54..11e9ebf37c 100644 --- a/test/applyiter.jl +++ b/test/applyiter.jl @@ -503,3 +503,5 @@ end Enzyme.autodiff(Reverse, mktup3, Duplicated(data, ddata)) @test ddata[1][1] ≈ 6.0 end + +include("mixedapplyiter.jl") \ No newline at end of file diff --git a/test/mixedapplyiter.jl b/test/mixedapplyiter.jl new file mode 100644 index 0000000000..722f5832ab --- /dev/null +++ b/test/mixedapplyiter.jl @@ -0,0 +1,95 @@ +using Enzyme, Test + +concat() = () +concat(a) = a +concat(a, b) = (a..., b...) +concat(a, b, c...) = concat(concat(a, b), c...) + +metaconcat(x) = concat(x...) + +metaconcat2(x, y) = concat(x..., y...) + +midconcat(x, y) = (x, concat(y...)...) + +metaconcat3(x, y, z) = concat(x..., y..., z...) + +function mixed_metasumsq(f, args...) + res = 0.0 + x = f(args...) + for v in x + v = v::Tuple{Float64, Vector{Float64}} + res += v[1]*v[1] + v[2][1] * v[2][1] + end + return res +end + +function mixed_metasumsq3(f, args...) + res = 0.0 + x = f(args...) + for v in x + v = v + res += v*v + end + return res +end + +function make_byref(out, fn, args...) + out[] = fn(args...) + nothing +end + +function tupapprox(a, b) + if a isa Tuple && b isa Tuple + if length(a) != length(b) + return false + end + for (aa, bb) in zip(a, b) + if !tupapprox(aa, bb) + return false + end + end + return true + end + if a isa Array && b isa Array + if size(a) != size(b) + return false + end + for i in length(a) + if !tupapprox(a[i], b[i]) + return false + end + end + return true + end + return a ≈ b +end + + +@testset "Mixed Reverse Apply iterate" begin + x = [((2.0, [2.7]), (3.0, [3.14])), ((7.9, [47.0]), (11.2, [56.0]))] + dx = [((0.0, [0.0]), (0.0, [0.0])), ((0.0, [0.0]), (0.0, [0.0]))] + res = Enzyme.autodiff(Reverse, mixed_metasumsq, Active, Const(metaconcat), Duplicated(x, dx)) + @show dx + @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) + + dx = [((0.0, [0.0]), (0.0, [0.0])), ((0.0, [0.0]), (0.0, [0.0]))] + res = Enzyme.autodiff(ReverseWithPrimal, mixed_metasumsq, Active, Const(metaconcat), Duplicated(x, dx)) + @show dx + @show res + @test res[2] ≈ 200.84999999999997 + @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) + + x = [[(2.0, [2.7]), (3.0, [3.14])], [(7.9, [47.0]), (11.2, [56.0])]] + dx = [[(0.0, [0.0]), (0.0, [0.0])], [(0.0, [0.0]), (0.0, [0.0])]] + + res = Enzyme.autodiff(Reverse, mixed_metasumsq, Active, Const(metaconcat), Duplicated(x, dx)) + @show dx + @test dx ≈ [[4.0, 6.0], [15.8, 22.4]] + + dx = [[(0.0, [0.0]), (0.0, [0.0])], [(0.0, [0.0]), (0.0, [0.0])]] + + res = Enzyme.autodiff(ReverseWithPrimal, mixed_metasumsq, Active, Const(metaconcat), Duplicated(x, dx)) + + @test res[2] ≈ 200.84999999999997 + @test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]]) +end diff --git a/test/usermixed.jl b/test/usermixed.jl index b5cd0e158b..f97c5737ec 100644 --- a/test/usermixed.jl +++ b/test/usermixed.jl @@ -1,6 +1,122 @@ using Enzyme using Test +########## MixedDuplicated of Return + +function user_mixret(x, y) + return (x, y) +end + +@testset "MixedDuplicated struct return" begin + x = 2.7 + y = [3.14] + dy = [0.0] + + fwd, rev = autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(user_mixret)}, MixedDuplicated, Active{Float64}, Duplicated{Vector{Float64}}) + + tape, res, dres = fwd(Const(user_mixret), Active(x), Duplicated(y, dy)) + + @test res[1] ≈ x + @test res[2] === y + + @test dres[][1] ≈ 0.0 + @test dres[][2] === dy + + outs = rev(Const(user_mixret), Active(x), Duplicated(y, dy), (47.56, dy), tape) + + @test outs[1][1] ≈ 47.56 +end + +@testset "BatchMixedDuplicated struct return" begin + x = 2.7 + y = [3.14] + dy = [0.0] + dy2 = [0.0] + + fwd, rev = autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(user_mixret)}, BatchMixedDuplicated, Active{Float64}, BatchDuplicated{Vector{Float64}, 2}) + + tape, res, dres = fwd(Const(user_mixret), Active(x), BatchDuplicated(y, (dy, dy2))) + + @test res[1] ≈ x + @test res[2] === y + + @test dres[1][][1] ≈ 0.0 + @test dres[1][][2] === dy + @test dres[2][][1] ≈ 0.0 + @test dres[2][][2] === dy2 + + outs = rev(Const(user_mixret), Active(x), BatchDuplicated(y, (dy, dy2)), (47.0, dy), (56.0, dy), tape) + + @test outs[1][1][1] ≈ 47.0 + @test outs[1][1][2] ≈ 56.0 +end + + +function user_fltret(x, y) + return x +end + +@testset "MixedDuplicated float return" begin + x = 2.7 + + fwd, rev = autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(identity)}, MixedDuplicated, Active{Float64}) + + tape, res, dres = fwd(Const(identity), Active(x)) + + @test res ≈ x + @test dres[] ≈ 0.0 + + outs = rev(Const(identity), Active(x), 47.56, tape) + + @test outs[1][1] ≈ 47.56 +end + +@testset "BatchMixedDuplicated float return" begin + x = 2.7 + y = [3.14] + dy = [0.0] + dy2 = [0.0] + + fwd, rev = autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(user_fltret)}, BatchMixedDuplicated, Active{Float64}, BatchDuplicated{Vector{Float64}, 2}) + + tape, res, dres = fwd(Const(user_fltret), Active(x), BatchDuplicated(y, (dy, dy2))) + + @test res ≈ x + + @test dres[1][] ≈ 0.0 + @test dres[2][] ≈ 0.0 + + outs = rev(Const(user_fltret), Active(x), BatchDuplicated(y, (dy, dy2)), 47.0, 56.0, tape) + + @test outs[1][1][1] ≈ 47.0 + @test outs[1][1][2] ≈ 56.0 +end + +function vecsq(x) + x[2] = x[1] * x[1] + return x +end + +@testset "MixedDuplicated vector return" begin + y = [3.14, 0.0] + dy = [0.0, 2.7] + + fwd, rev = autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(vecsq)}, MixedDuplicated, Duplicated{Vector{Float64}}) + + tape, res, dres = fwd(Const(vecsq), Duplicated(y, dy)) + + @test res === y + + @test dres[] === dy + + outs = rev(Const(vecsq), Duplicated(y, dy), dy, tape) + + @test dy ≈ [3.14 * 2.7 * 2, 0.0] +end + + +########## MixedDuplicated of Argument + function user_mixfnc(tup) return tup[1] * tup[2][1] end From 4f1ac50d2783d4dc18c99b999153b164423b52e1 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Thu, 13 Jun 2024 22:26:46 -0400 Subject: [PATCH 06/11] fixup --- src/rules/jitrules.jl | 144 +++++++++++++++++++++------------ src/rules/typeunstablerules.jl | 4 +- test/mixedapplyiter.jl | 75 ++++++++++++++--- 3 files changed, 156 insertions(+), 67 deletions(-) diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index f38f484ced..f33c71f35e 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -1,4 +1,4 @@ -function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, iterate=false; func=true, mixed_or_active = false) +function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, iterate=false; func=true, mixed_or_active = false, reverse=false) primargs = Union{Symbol,Expr}[] shadowargs = Union{Symbol,Expr}[] batchshadowargs = Vector{Union{Symbol,Expr}}[] @@ -97,18 +97,17 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, else dupexpr = if Width == 1 quote - iterate_unwrap_augfwd_dup($(primargs[i]), $(shadowargs[i])) + iterate_unwrap_augfwd_dup(Val($reverse), refs, $(primargs[i]), $(shadowargs[i])) end else quote - iterate_unwrap_augfwd_batchdup(Val($Width), $(primargs[i]), $(shadowargs[i])) + iterate_unwrap_augfwd_batchdup(Val($reverse), refs, Val($Width), $(primargs[i]), $(shadowargs[i])) end end :( if ActivityTup[$i+1] && !guaranteed_const($(primtypes[i])) @assert $(primtypes[i]) !== DataType if $aref == ActiveState - Active($(primargs[i])) iterate_unwrap_augfwd_act($(primargs[i])...) elseif $aref == MixedState T = $(primtypes[i]) @@ -436,8 +435,14 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act annotation, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width, ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) - tup = if annotation0 <: Active || annotation0 <: MixedDuplicated || annotation0 <: BatchMixedDuplicated + tup = if annotation0 <: Active adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., $shadowret, tape.internal_tape)[1] + elseif annotation0 <: MixedDuplicated || annotation0 <: BatchMixedDuplicated + if $Width == 1 + adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., $shadowret, tape.internal_tape)[1] + else + adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., $shadowret..., tape.internal_tape)[1] + end else adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., tape.internal_tape)[1] end @@ -523,7 +528,7 @@ end end end -@inline function iterate_unwrap_augfwd_dup(args, dargs) +@inline function iterate_unwrap_augfwd_dup(::Val{reverse}, vals, args, dargs) where reverse ntuple(Val(length(args))) do i Base.@_inline_meta arg = args[i] @@ -537,7 +542,14 @@ end if dargs[i] isa Base.RefValue MixedDuplicated(arg, dargs[i]) else - MixedDuplicated(arg, Ref(dargs[i])) + rval = if reverse + popfirst!(vals) + else + tmp = Ref(dargs[i]) + push!(vals, tmp) + tmp + end + MixedDuplicated(arg, rval) end else Duplicated(arg, dargs[i]) @@ -545,7 +557,7 @@ end end end -@inline function iterate_unwrap_augfwd_batchdup(::Val{Width}, args, dargs) where {Width} +@inline function iterate_unwrap_augfwd_batchdup(::Val{reverse}, vals, ::Val{Width}, args, dargs) where {reverse, Width} ntuple(Val(length(args))) do i Base.@_inline_meta arg = args[i] @@ -561,7 +573,13 @@ end if dargs[j][i] isa Base.RefValue dargs[j][i] else - Ref(dargs[j][i]) + if reverse + popfirst!(vals) + else + tmp = Ref(dargs[j][i]) + push!(vals, tmp) + tmp + end end end) else @@ -679,29 +697,43 @@ function primal_tuple(args::Vararg{Annotation, Nargs}) where Nargs end end -function shadow_tuple(::Val{1}, args::Vararg{Annotation, Nargs}) where Nargs - ntuple(Val(Nargs)) do i +function shadow_tuple(::Type{Ann}, ::Val{1}, args::Vararg{Annotation, Nargs}) where {Ann, Nargs} + res = ntuple(Val(Nargs)) do i Base.@_inline_meta @assert !(args[i] isa Active) if args[i] isa Const args[i].val + elseif args[i] isa MixedDuplicated + args[i].dval[] else args[i].dval end end + if Ann <: MixedDuplicated + Ref(res) + else + res + end end -function shadow_tuple(::Val{width}, args::Vararg{Annotation, Nargs}) where {width, Nargs} +function shadow_tuple(::Type{Ann}, ::Val{width}, args::Vararg{Annotation, Nargs}) where {Ann, width, Nargs} ntuple(Val(width)) do w - ntuple(Val(Nargs)) do i - Base.@_inline_meta - @assert !(args[i] isa Active) - if args[i] isa Const - args[i].val - else - args[i].dval[w] + res = ntuple(Val(Nargs)) do i + Base.@_inline_meta + @assert !(args[i] isa Active) + if args[i] isa Const + args[i].val + elseif args[i] isa BatchMixedDuplicated + args[i].dval[w][] + else + args[i].dval[w] + end + end + if Ann <: BatchMixedDuplicated + Ref(res) + else + res end - end end end @@ -755,7 +787,7 @@ function augfwd_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType} ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) forward(fa, args...) else - nothing, primal_tuple(args...), annotation <: Active ? nothing : shadow_tuple(Val(width), args...) + nothing, primal_tuple(args...), annotation <: Active ? nothing : shadow_tuple(annotation, Val(width), args...) end resT = typeof(origRet) @@ -803,12 +835,18 @@ end function body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes, active_refs) wrappedexexpand = ntuple(i->:($(wrapped[i])...), Val(N)) + results = Expr[] + for i in 1:(Width+1) + push!(results, :(tmpvals[$i])) + end return quote + refs = Base.RefValue[] $(active_refs...) args = ($(wrappedexexpand...),) tt′ = Enzyme.vaTypeof(args...) FT = Core.Typeof(f) - augfwd_with_return(Val($Width), Val(ActivityTup[1]), ReturnType, Val(concat($(modbetween...))), FT, tt′, f, df, args...)::ReturnType + tmpvals = augfwd_with_return(Val($Width), Val(ActivityTup[1]), ReturnType, Val(concat($(modbetween...))), FT, tt′, f, df, args...)::ReturnType + ReturnType(($(results...), (tmpvals[$(Width+2)], refs))) end end @@ -886,10 +924,15 @@ function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween if width == 1 (args..., tape.shadow_return[]) else - (args..., ntuple(Val(width)) do w + shads = ntuple(Val(width)) do w Base.@_inline_meta tape.shadow_return[w][] - end) + end + if annotation <: MixedDuplicated || annotation <: BatchMixedDuplicated + (args..., shads...,) + else + (args..., shads) + end end else args @@ -908,6 +951,15 @@ function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween tape.shadow_return[w][][i] end end + elseif args[i] isa MixedDuplicated || args[i] isa BatchMixedDuplicated + if width == 1 + tape.shadow_return[][i] + else + ntuple(Val(width)) do w + Base.@_inline_meta + tape.shadow_return[w][][i] + end + end else nothing end @@ -919,14 +971,20 @@ function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween ntuple(Val(width)) do w Base.@_inline_meta - - if tup[i] == nothing - else - expr = if width == 1 - tup[i] + if args[i] isa Active || args[i] isa MixedDuplicated || args[i] isa BatchMixedDuplicated + expr = if args[i] isa Active || f == Base.tuple + if width == 1 + tup[i] + else + tup[i][w] + end + elseif args[i] isa MixedDuplicated + args[i].dval[] else - tup[i][w] + # if args[i] isa BatchMixedDuplicated + args[i].dval[w][] end + idx_of_vec, idx_in_vec = lengths[i] vec = @inbounds shadowargs[idx_of_vec][w] if vec isa Base.RefValue @@ -962,25 +1020,6 @@ function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween end function body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, shadowargs, active_refs) - outs = [] - for i in 1:N - for w in 1:Width - expr = if Width == 1 - :(tup[$i]) - else - :(tup[$i][$w]) - end - shad = shadowargs[i][w] - out = :(if tup[$i] === nothing - elseif $shad isa Base.RefValue - $shad[] = recursive_add($shad[], $expr) - else - error("Enzyme Mutability Error: Cannot add in place to immutable value "*string($shad)) - end - ) - push!(outs, out) - end - end shadow_ret = nothing if Width == 1 shadowret = :(tape.shadow_return[]) @@ -1008,17 +1047,18 @@ function body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, shado push!(shadowsplat, :(($(s...),))) end quote + (tape0, refs) = tape $(active_refs...) args = ($(wrappedexexpand...),) tt′ = Enzyme.vaTypeof(args...) FT = Core.Typeof(f) - rev_with_return(Val($Width), Val(ActivityTup[1]), Val(concat($(modbetween...))), Val(concat($(lengths...))), FT, tt′, f, df, tape, ($(shadowsplat...),), args...) + rev_with_return(Val($Width), Val(ActivityTup[1]), Val(concat($(modbetween...))), Val(concat($(lengths...))), FT, tt′, f, df, tape0, ($(shadowsplat...),), args...) return nothing end end function func_runtime_iterate_rev(N, Width) - primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, modbetween, active_refs = setup_macro_wraps(false, N, Width, #=body=#nothing, #=iterate=#true) + primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, modbetween, active_refs = setup_macro_wraps(false, N, Width, #=body=#nothing, #=iterate=#true; reverse=true) body = body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, batchshadowargs, active_refs) quote @@ -1030,7 +1070,7 @@ end @generated function runtime_iterate_rev(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, allargs...) where {ActivityTup, MB, Width, TapeType, F, DF} N = div(length(allargs)+2, Width+1)-1 - primargs, _, primtypes, _, _, wrapped, batchshadowargs, modbetween, active_refs = setup_macro_wraps(false, N, Width, :allargs, #=iterate=#true) + primargs, _, primtypes, _, _, wrapped, batchshadowargs, modbetween, active_refs = setup_macro_wraps(false, N, Width, :allargs, #=iterate=#true; reverse=true) return body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, batchshadowargs, active_refs) end diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index db9759ec51..dbd22275c8 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -618,7 +618,7 @@ function rt_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isco if length(dptrs) == 0 return Ref{RT}(res) else - fval = NT((res, (ntuple(Val(length(dptrs))) do i + fval = NT((Ref{RT}(res), (ntuple(Val(length(dptrs))) do i Base.@_inline_meta dv = dptrs[i] Ref{RT}(getfield(dv isa Base.RefValue ? dv[] : dv, symname)) @@ -660,7 +660,7 @@ function idx_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isc if length(dptrs) == 0 return Ref{RT}(res)::Any else - fval = NT((res, (ntuple(Val(length(dptrs))) do i + fval = NT((Ref{RT}(res), (ntuple(Val(length(dptrs))) do i Base.@_inline_meta dv = dptrs[i] Ref{RT}(getfield(dv isa Base.RefValue ? dv[] : dv, symname+1)) diff --git a/test/mixedapplyiter.jl b/test/mixedapplyiter.jl index 722f5832ab..bb7f18243c 100644 --- a/test/mixedapplyiter.jl +++ b/test/mixedapplyiter.jl @@ -64,32 +64,81 @@ function tupapprox(a, b) return a ≈ b end - -@testset "Mixed Reverse Apply iterate" begin +@testset "Mixed Reverse Apply iterate (tuple)" begin x = [((2.0, [2.7]), (3.0, [3.14])), ((7.9, [47.0]), (11.2, [56.0]))] dx = [((0.0, [0.0]), (0.0, [0.0])), ((0.0, [0.0]), (0.0, [0.0]))] res = Enzyme.autodiff(Reverse, mixed_metasumsq, Active, Const(metaconcat), Duplicated(x, dx)) - @show dx - @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) + @test tupapprox(dx, [((4.0, [5.4]), (6.0, [6.28])), ((15.8, [94.0]), (22.4, [112.0]))]) + + x = [((2.0, [2.7]), (3.0, [3.14])), ((7.9, [47.0]), (11.2, [56.0]))] dx = [((0.0, [0.0]), (0.0, [0.0])), ((0.0, [0.0]), (0.0, [0.0]))] res = Enzyme.autodiff(ReverseWithPrimal, mixed_metasumsq, Active, Const(metaconcat), Duplicated(x, dx)) - @show dx - @show res - @test res[2] ≈ 200.84999999999997 - @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) + @test res[2] ≈ 5562.9996 + @test tupapprox(dx, [((4.0, [5.4]), (6.0, [6.28])), ((15.8, [94.0]), (22.4, [112.0]))]) +end +@testset "BatchMixed Reverse Apply iterate (tuple)" begin + x = [((2.0, [2.7]), (3.0, [3.14])), ((7.9, [47.0]), (11.2, [56.0]))] + dx = [((0.0, [0.0]), (0.0, [0.0])), ((0.0, [0.0]), (0.0, [0.0]))] + dx2 = [((0.0, [0.0]), (0.0, [0.0])), ((0.0, [0.0]), (0.0, [0.0]))] + + out = Ref(0.0) + dout = Ref(1.0) + dout2 = Ref(3.0) + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(mixed_metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) + @test tupapprox(dx, [((4.0, [5.4]), (6.0, [6.28])), ((15.8, [94.0]), (22.4, [112.0]))]) + @test tupapprox(dx2, [((3*4.0, [3*5.4]), (3*6.0, [3*6.28])), ((3*15.8, [3*94.0]), (3*22.4, [3*112.0]))]) + + x = [((2.0, [2.7]), (3.0, [3.14])), ((7.9, [47.0]), (11.2, [56.0]))] + dx = [((0.0, [0.0]), (0.0, [0.0])), ((0.0, [0.0]), (0.0, [0.0]))] + dx2 = [((0.0, [0.0]), (0.0, [0.0])), ((0.0, [0.0]), (0.0, [0.0]))] + + out = Ref(0.0) + dout = Ref(1.0) + dout2 = Ref(3.0) + res = Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(mixed_metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) + @test out[] ≈ 5562.9996 + @test tupapprox(dx, [((4.0, [5.4]), (6.0, [6.28])), ((15.8, [94.0]), (22.4, [112.0]))]) + @test tupapprox(dx2, [((3*4.0, [3*5.4]), (3*6.0, [3*6.28])), ((3*15.8, [3*94.0]), (3*22.4, [3*112.0]))]) +end + + +@testset "Mixed Reverse Apply iterate (list)" begin x = [[(2.0, [2.7]), (3.0, [3.14])], [(7.9, [47.0]), (11.2, [56.0])]] dx = [[(0.0, [0.0]), (0.0, [0.0])], [(0.0, [0.0]), (0.0, [0.0])]] res = Enzyme.autodiff(Reverse, mixed_metasumsq, Active, Const(metaconcat), Duplicated(x, dx)) - @show dx - @test dx ≈ [[4.0, 6.0], [15.8, 22.4]] + @test tupapprox(dx, [[(4.0, [5.4]), (6.0, [6.28])], [(15.8, [94.0]), (22.4, [112.0])]]) dx = [[(0.0, [0.0]), (0.0, [0.0])], [(0.0, [0.0]), (0.0, [0.0])]] res = Enzyme.autodiff(ReverseWithPrimal, mixed_metasumsq, Active, Const(metaconcat), Duplicated(x, dx)) - - @test res[2] ≈ 200.84999999999997 - @test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]]) + @test res[2] ≈ 5562.9996 + @test tupapprox(dx, [[(4.0, [5.4]), (6.0, [6.28])], [(15.8, [94.0]), (22.4, [112.0])]]) end + +@testset "BatchMixed Reverse Apply iterate (list)" begin + x = [[(2.0, [2.7]), (3.0, [3.14])], [(7.9, [47.0]), (11.2, [56.0])]] + dx = [[(0.0, [0.0]), (0.0, [0.0])], [(0.0, [0.0]), (0.0, [0.0])]] + dx2 = [[(0.0, [0.0]), (0.0, [0.0])], [(0.0, [0.0]), (0.0, [0.0])]] + + out = Ref(0.0) + dout = Ref(1.0) + dout2 = Ref(3.0) + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(mixed_metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) + @test tupapprox(dx, [[(4.0, [5.4]), (6.0, [6.28])], [(15.8, [94.0]), (22.4, [112.0])]]) + @test tupapprox(dx2, [[(3*4.0, [3*5.4]), (3*6.0, [3*6.28])], [(3*15.8, [3*94.0]), (3*22.4, [3*112.0])]]) + + x = [[(2.0, [2.7]), (3.0, [3.14])], [(7.9, [47.0]), (11.2, [56.0])]] + dx = [[(0.0, [0.0]), (0.0, [0.0])], [(0.0, [0.0]), (0.0, [0.0])]] + dx2 = [[(0.0, [0.0]), (0.0, [0.0])], [(0.0, [0.0]), (0.0, [0.0])]] + + out = Ref(0.0) + dout = Ref(1.0) + dout2 = Ref(3.0) + res = Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(mixed_metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) + @test out[] ≈ 5562.9996 + @test tupapprox(dx, [[(4.0, [5.4]), (6.0, [6.28])], [(15.8, [94.0]), (22.4, [112.0])]]) + @test tupapprox(dx2, [[(3*4.0, [3*5.4]), (3*6.0, [3*6.28])], [(3*15.8, [3*94.0]), (3*22.4, [3*112.0])]]) +end \ No newline at end of file From b329abe2b90bfbd14219fccb301273ad84f8db5d Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Thu, 13 Jun 2024 22:38:37 -0400 Subject: [PATCH 07/11] fix --- src/rules/typeunstablerules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index dbd22275c8..b295403a99 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -776,7 +776,7 @@ function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst} end end) else - setfield!(dptr, symname+1, recursive_add(cur, dret[]), identity, guaranteed_nonactive) + setfield!(dptr, symname+1, recursive_add(cur, dret[], identity, guaranteed_nonactive)) end else if dptr isa Base.RefValue From 7f08afe53b8846dc7c77de2fb1722a7a1974423f Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 13 Jun 2024 23:01:23 -0400 Subject: [PATCH 08/11] try inference fix re ref --- src/rules/jitrules.jl | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index f33c71f35e..6ae7106131 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -539,13 +539,14 @@ end elseif actreg == ActiveState Active(arg) elseif actreg == MixedState - if dargs[i] isa Base.RefValue - MixedDuplicated(arg, dargs[i]) + darg = Base.inferencebarrier(dargs[i]) + if darg isa Base.RefValue + MixedDuplicated(arg, darg) else rval = if reverse popfirst!(vals) else - tmp = Ref(dargs[i]) + tmp = Ref(darg) push!(vals, tmp) tmp end @@ -570,13 +571,14 @@ end elseif actreg == MixedState BatchMixedDuplicated(arg, ntuple(Val(Width)) do j Base.@_inline_meta - if dargs[j][i] isa Base.RefValue - dargs[j][i] + darg = Base.inferencebarrier(dargs[j][i]) + if darg isa Base.RefValue + darg else if reverse popfirst!(vals) else - tmp = Ref(dargs[j][i]) + tmp = Ref(darg) push!(vals, tmp) tmp end From e0cc94f69144348ff0c2ca6fba66b78559cb5315 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 13 Jun 2024 23:08:33 -0400 Subject: [PATCH 09/11] try more --- src/rules/jitrules.jl | 39 ++++++++++++++++----------------------- 1 file changed, 16 insertions(+), 23 deletions(-) diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index 6ae7106131..fb8b34b831 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -528,6 +528,20 @@ end end end +function push_if_not_ref(::Val{reverse}, vals, darg, ::Type{T2}) where {reverse, T2} + if reverse + return popfirst!(vals) + else + tmp = Base.RefValue{T2}(darg) + push!(vals, tmp) + return tmp + end +end + +function push_if_not_ref(::Val{reverse}, vals, darg::Base.RefValue{T2}, ::Type{T2}) where {reverse, T2} + return darg +end + @inline function iterate_unwrap_augfwd_dup(::Val{reverse}, vals, args, dargs) where reverse ntuple(Val(length(args))) do i Base.@_inline_meta @@ -540,18 +554,7 @@ end Active(arg) elseif actreg == MixedState darg = Base.inferencebarrier(dargs[i]) - if darg isa Base.RefValue - MixedDuplicated(arg, darg) - else - rval = if reverse - popfirst!(vals) - else - tmp = Ref(darg) - push!(vals, tmp) - tmp - end - MixedDuplicated(arg, rval) - end + MixedDuplicated(arg, push_if_not_ref(Val(reverse), vals, darg, ty)::Base.RefValue{ty}) else Duplicated(arg, dargs[i]) end @@ -572,17 +575,7 @@ end BatchMixedDuplicated(arg, ntuple(Val(Width)) do j Base.@_inline_meta darg = Base.inferencebarrier(dargs[j][i]) - if darg isa Base.RefValue - darg - else - if reverse - popfirst!(vals) - else - tmp = Ref(darg) - push!(vals, tmp) - tmp - end - end + MixedDuplicated(arg, push_if_not_ref(Val(reverse), vals, darg, ty)::Base.RefValue{ty}) end) else BatchDuplicated(arg, ntuple(Val(Width)) do j From 2020e20ffc460ac1b30aac84bdc04b7f404757ad Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 14 Jun 2024 07:41:49 -0400 Subject: [PATCH 10/11] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 6cc31394ab..f72e22bd39 100644 --- a/Project.toml +++ b/Project.toml @@ -20,7 +20,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7.5" -Enzyme_jll = "0.0.121" +Enzyme_jll = "0.0.122" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7" ObjectFile = "0.4" From ab7cee5966c41df265d4727e59af909b20b517c1 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 14 Jun 2024 08:45:30 -0400 Subject: [PATCH 11/11] Update jitrules.jl --- src/rules/jitrules.jl | 2 +- src/rules/typeunstablerules.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index fb8b34b831..f04145f7bc 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -575,7 +575,7 @@ end BatchMixedDuplicated(arg, ntuple(Val(Width)) do j Base.@_inline_meta darg = Base.inferencebarrier(dargs[j][i]) - MixedDuplicated(arg, push_if_not_ref(Val(reverse), vals, darg, ty)::Base.RefValue{ty}) + push_if_not_ref(Val(reverse), vals, darg, ty)::Base.RefValue{ty} end) else BatchDuplicated(arg, ntuple(Val(Width)) do j diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index b295403a99..0b20dc77d4 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -743,7 +743,7 @@ function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, else Base.getfield(dptrs[i], symname) end - setfield!(dptrs[i], symname, recursive_add(curi, dret[1+i][]), identity, guaranteed_nonactive) + setfield!(dptrs[i], symname, recursive_add(curi, dret[1+i][], identity, guaranteed_nonactive)) end end end