From 37600d41c8cb45392c249d35f18f395dbd4d6ed8 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Wed, 29 May 2024 14:04:25 +0200 Subject: [PATCH 01/22] Reverse mode apply iterate --- deps/build_local.jl | 50 +++- src/Enzyme.jl | 7 + src/rules/jitrules.jl | 549 +++++++++++++++++++++++++++--------------- 3 files changed, 404 insertions(+), 202 deletions(-) diff --git a/deps/build_local.jl b/deps/build_local.jl index 5c67ac0477..21be66745b 100644 --- a/deps/build_local.jl +++ b/deps/build_local.jl @@ -6,7 +6,8 @@ Enzyme_jll = Base.UUID("7cc45869-7501-5eee-bdea-0790c847d4ef") using Pkg, Scratch, Preferences, Libdl -BUILD_TYPE = "RelWithDebInfo" +BUILD_TYPE = "RelWithDebInfo" +BCLoad = true # 1. Get a scratch directory scratch_dir = get_scratch!(Enzyme_jll, "build") @@ -14,12 +15,28 @@ isdir(scratch_dir) && rm(scratch_dir; recursive=true) source_dir = nothing branch = nothing -if length(ARGS) == 2 - @assert ARGS[1] == "--branch" - branch = ARGS[2] - source_dir = nothing -elseif length(ARGS) == 1 - source_dir = ARGS[1] + +args = (ARGS...,) +while length(args) > 0 + if length(args) >= 2 && args[1] == "--branch" + branch = args[2] + source_dir = nothing + args = (args[3:end]...,) + continue + end + if length(args) >= 1 && args[1] == "--debug" + BUILD_TYPE = "Debug" + args = (args[2:end]...,) + continue + end + if length(args) >= 1 && args[1] == "--nobcload" + BCLoad = false + args = (args[2:end]...,) + continue + end + @assert length(args) == 1 + source_dir = args[1] + break end if branch === nothing @@ -62,7 +79,12 @@ LLVM_VER_MAJOR = Base.libllvm_version.major # Build! 
@info "Building" source_dir scratch_dir LLVM_DIR run(`cmake -DLLVM_DIR=$(LLVM_DIR) -DCMAKE_BUILD_TYPE=$(BUILD_TYPE) -DENZYME_EXTERNAL_SHARED_LIB=ON -B$(scratch_dir) -S$(source_dir)`) -run(`cmake --build $(scratch_dir) --parallel $(Sys.CPU_THREADS) -t Enzyme-$(LLVM_VER_MAJOR) EnzymeBCLoad-$(LLVM_VER_MAJOR)`) + +if BCLoad + run(`cmake --build $(scratch_dir) --parallel $(Sys.CPU_THREADS) -t Enzyme-$(LLVM_VER_MAJOR) EnzymeBCLoad-$(LLVM_VER_MAJOR)`) +else + run(`cmake --build $(scratch_dir) --parallel $(Sys.CPU_THREADS) -t Enzyme-$(LLVM_VER_MAJOR)`) +end # Discover built libraries built_libs = filter(readdir(joinpath(scratch_dir, "Enzyme"))) do file @@ -72,18 +94,26 @@ end lib_path = joinpath(scratch_dir, "Enzyme", only(built_libs)) isfile(lib_path) || error("Could not find library $lib_path in build directory") +# Tell Enzyme_jll to load our library instead of the default artifact one +set_preferences!( + joinpath(dirname(@__DIR__), "LocalPreferences.toml"), + "Enzyme_jll", + "libEnzyme_path" => lib_path, + force=true, +) + +if BCLoad built_libs = filter(readdir(joinpath(scratch_dir, "BCLoad"))) do file endswith(file, ".$(Libdl.dlext)") && startswith(file, "lib") end libBC_path = joinpath(scratch_dir, "BCLoad", only(built_libs)) isfile(libBC_path) || error("Could not find library $libBC_path in build directory") - # Tell Enzyme_jll to load our library instead of the default artifact one set_preferences!( joinpath(dirname(@__DIR__), "LocalPreferences.toml"), "Enzyme_jll", - "libEnzyme_path" => lib_path, "libEnzymeBCLoad_path" => libBC_path; force=true, ) +end diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 911d1801ad..9bccd5b959 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -74,6 +74,13 @@ end end)...} end +@inline function vaEltypes(args::Type{Ty}) where {Ty <: Tuple} + return Tuple{(ntuple(Val(length(Ty.parameters))) do i + Base.@_inline_meta + eltype(Ty.parameters[i]) + end)...} +end + @inline function same_or_one_helper(current, next) if current == -1 return 
next diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index 76b72466c1..8e59fca5ec 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -1,5 +1,5 @@ -function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing) +function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, iterate=false) primargs = Union{Symbol,Expr}[] shadowargs = Union{Symbol,Expr}[] batchshadowargs = Vector{Union{Symbol,Expr}}[] @@ -59,8 +59,36 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing) @assert length(primargs) == N @assert length(primtypes) == N wrapped = Expr[] + modbetween = Expr[:(MB[1])] for i in 1:N - expr = :( + if iterate + push!(modbetween, quote + ntuple(Val(length($(primargs[i])))) do _ + Base.@_inline_meta + MB[$i] + end... + end) + end + expr = if iterate + :( + if ActivityTup[$i+1] && !guaranteed_const($(primtypes[i])) + @assert $(primtypes[i]) !== DataType + if !$forwardMode && active_reg($(primtypes[i])) + iterate_unwrap_augfwd_act($(primargs[i])...) + else + $((Width == 1) ? quote + iterate_unwrap_augfwd_dup($(primargs[i]), $(shadowargs[i])) + end : quote + iterate_unwrap_augfwd_batchdup(Val($Width), $(primargs[i]), $(shadowargs[i])) + end + ) + end + else + map(Const, $(primargs[i])...) 
+ end + ) + else + :( if ActivityTup[$i+1] && !guaranteed_const($(primtypes[i])) @assert $(primtypes[i]) !== DataType if !$forwardMode && active_reg($(primtypes[i])) @@ -73,9 +101,10 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing) end ) + end push!(wrapped, expr) end - return primargs, shadowargs, primtypes, allargs, typeargs, wrapped, batchshadowargs + return primargs, shadowargs, primtypes, allargs, typeargs, wrapped, batchshadowargs, modbetween end function body_runtime_generic_fwd(N, Width, wrapped, primtypes) @@ -131,7 +160,7 @@ function body_runtime_generic_fwd(N, Width, wrapped, primtypes) end function func_runtime_generic_fwd(N, Width) - _, _, primtypes, allargs, typeargs, wrapped, _ = setup_macro_wraps(true, N, Width) + _, _, primtypes, allargs, typeargs, wrapped, _, _ = setup_macro_wraps(true, N, Width) body = body_runtime_generic_fwd(N, Width, wrapped, primtypes) quote @@ -143,7 +172,7 @@ end @generated function runtime_generic_fwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, RT::Val{ReturnType}, f::F, df::DF, allargs...) where {ActivityTup, Width, ReturnType, F, DF} N = div(length(allargs)+2, Width+1)-1 - _, _, primtypes, _, _, wrapped, _ = setup_macro_wraps(true, N, Width, :allargs) + _, _, primtypes, _, _, wrapped, _, _ = setup_macro_wraps(true, N, Width, :allargs) return body_runtime_generic_fwd(N, Width, wrapped, primtypes) end @@ -209,7 +238,7 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes) end function func_runtime_generic_augfwd(N, Width) - _, _, primtypes, allargs, typeargs, wrapped, _ = setup_macro_wraps(false, N, Width) + _, _, primtypes, allargs, typeargs, wrapped, _, _ = setup_macro_wraps(false, N, Width) body = body_runtime_generic_augfwd(N, Width, wrapped, primtypes) quote @@ -221,7 +250,7 @@ end @generated function runtime_generic_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, allargs...) 
where {ActivityTup, MB, Width, ReturnType, F, DF} N = div(length(allargs)+2, Width+1)-1 - _, _, primtypes, _, _, wrapped, _ = setup_macro_wraps(false, N, Width, :allargs) + _, _, primtypes, _, _, wrapped, _, _= setup_macro_wraps(false, N, Width, :allargs) return body_runtime_generic_augfwd(N, Width, wrapped, primtypes) end @@ -290,7 +319,7 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs) end function func_runtime_generic_rev(N, Width) - _, _, primtypes, allargs, typeargs, wrapped, batchshadowargs = setup_macro_wraps(false, N, Width) + _, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _ = setup_macro_wraps(false, N, Width) body = body_runtime_generic_rev(N, Width, wrapped, primtypes, batchshadowargs) quote @@ -302,7 +331,7 @@ end @generated function runtime_generic_rev(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, allargs...) where {ActivityTup, MB, Width, TapeType, F, DF} N = div(length(allargs)+2, Width+1)-1 - _, _, primtypes, _, _, wrapped, batchshadowargs = setup_macro_wraps(false, N, Width, :allargs) + _, _, primtypes, _, _, wrapped, batchshadowargs, _ = setup_macro_wraps(false, N, Width, :allargs) return body_runtime_generic_rev(N, Width, wrapped, primtypes, batchshadowargs) end @@ -323,69 +352,118 @@ end end end +@inline function iterate_unwrap_augfwd_act(args...) 
+ ntuple(Val(length(args))) do i + Base.@_inline_meta + arg = args[i] + if guaranteed_const(Core.Typeof(arg)) + Const(arg) + else + Active(arg) + end + end +end + +@inline function iterate_unwrap_augfwd_dup(::Val{forwardMode}, args, dargs) where forwardMode + ntuple(Val(length(args))) do i + Base.@_inline_meta + arg = args[i] + ty = Core.Typeof(arg) + if guaranteed_const(ty) + Const(arg) + elseif !forwardMode && active_reg(ty) + Active(arg) + else + Duplicated(arg, dargs[i]) + end + end +end + +@inline function iterate_unwrap_augfwd_batchdup(::Val{forwardMode}, ::Val{Width}, args, dargs) where {forwardMode, Width} + ntuple(Val(length(args))) do i + Base.@_inline_meta + arg = args[i] + ty = Core.Typeof(arg) + if guaranteed_const(ty) + Const(arg) + elseif !forwardMode && active_reg(ty) + Active(arg) + else + BatchDuplicated(arg, ntuple(Val(Width)) do j + Base.@_inline_meta + dargs[j][i] + end) + end + end +end + +@inline function allFirst(::Val{Width}, res) where Width + ntuple(Val(Width)) do i + Base.@_inline_meta + res[1] + end +end + +@inline function allZero(::Val{Width}, res) where Width + ntuple(Val(Width)) do i + Base.@_inline_meta + make_zero(res) + end +end + # This is explicitly escaped here to be what is apply generic in total [and thus all the insides are stable] -function fwddiff_with_return(::Val{width}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation} where {width, Nargs} - tt′ = Enzyme.vaTypeof(args...) 
+function fwddiff_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType}, ::Type{FT}, ::Type{tt′}, f::FT, df::DF, args::Vararg{Annotation, Nargs})::ReturnType where {width, dupClosure0, ReturnType, FT, tt′, DF, Nargs} ReturnPrimal = Val(true) - RT = A ModifiedBetween = Val(Enzyme.falses_from_args(Nargs+1)) - - tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} - world = codegen_world_age(Core.Typeof(f.val), tt) - thunk(Val(world), FA, RT, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), - ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), FFIABI)(f, args...) -end + dupClosure = dupClosure0 && !guaranteed_const(FT) + FA = dupClosure ? Const{FT} : Duplicated{FT} -function body_runtime_iterate_fwd(N, Width, wrapped, primtypes) - nnothing = ntuple(i->nothing, Val(Width+1)) - nres = ntuple(i->:(res[1]), Val(Width+1)) - ModifiedBetween = ntuple(i->false, Val(N+1)) - ElTypes = ntuple(i->:(eltype(Core.Typeof(args[$i]))), Val(N)) - Types = ntuple(i->:(Core.Typeof(args[$i])), Val(N)) - return quote - args0 = ($(wrapped...),) - args = concat(iterate_unwrap_fwd(args0...)...) 
+ tt = Enzyme.vaEltypes(tt′) - dupClosure = ActivityTup[1] - FT = Core.Typeof(f) - if dupClosure && guaranteed_const(FT) - dupClosure = false - end - - tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} - rt = Core.Compiler.return_type(f, tt) - annotation0 = guess_activity(rt, API.DEM_ForwardMode) + rt = Core.Compiler.return_type(f, tt) + annotation0 = guess_activity(rt, API.DEM_ForwardMode) - annotation = @static if $Width != 1 - if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated - BatchDuplicated{rt, $Width} - else - Const{rt} - end + annotation = @static if width != 1 + if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated + BatchDuplicated{rt, width} else - if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated - Duplicated{rt} - else - Const{rt} - end + Const{rt} + end + else + if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated + Duplicated{rt} + else + Const{rt} end + end + + world = codegen_world_age(FRT, tt) - res = fwddiff_with_return(Val($Width), dupClosure ? Duplicated(f, df) : Const(f), annotation, args...) - return if annotation <: Const - ReturnType(($(nres...),)) + res = thunk(Val(world), FA, annotation, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), + ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), FFIABI)(f, args...) + return if annotation <: Const + ReturnType(allFirst(Val(width+1), res)) + else + if width == 1 + ReturnType((res[1], res[2])) else - if $Width == 1 - ReturnType((res[1], res[2])) - else - ReturnType((res[1], res[2]...)) - end + ReturnType((res[1], res[2]...)) end end end +function body_runtime_iterate_fwd(N, Width, wrapped, primtypes) + wrappedexexpand = ntuple(i->:($(wrapped[i])...), Val(N)) + return quote + args = ($(wrappedexexpand...),) + tt′ = Enzyme.vaTypeof(args...) 
+ fwddiff_with_return(Val($Width), Val(ActivityTup[1]), ReturnType, FT, tt′, f, df, args...)::ReturnType + end +end + function func_runtime_iterate_fwd(N, Width) - _, _, primtypes, allargs, typeargs, wrapped, _ = setup_macro_wraps(true, N, Width) + _, _, primtypes, allargs, typeargs, wrapped, _, _ = setup_macro_wraps(true, N, Width) body = body_runtime_iterate_fwd(N, Width, wrapped, primtypes) quote @@ -397,75 +475,89 @@ end @generated function runtime_iterate_fwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, RT::Val{ReturnType}, f::F, df::DF, allargs...) where {ActivityTup, Width, ReturnType, F, DF} N = div(length(allargs)+2, Width+1)-1 - _, _, primtypes, _, _, wrapped, _ = setup_macro_wraps(true, N, Width, :allargs) + _, _, primtypes, _, _, wrapped, _, _ = setup_macro_wraps(true, N, Width, :allargs) return body_runtime_iterate_fwd(N, Width, wrapped, primtypes) end -function body_runtime_iterate_augfwd(N, Width, wrapped, primttypes) - nnothing = ntuple(i->nothing, Val(Width+1)) - nres = ntuple(i->:(origRet), Val(Width+1)) - nzeros = ntuple(i->:(Ref(zero(resT))), Val(Width)) - nres3 = ntuple(i->:(res[3]), Val(Width)) - ElTypes = ntuple(i->:(eltype(Core.Typeof(args[$i]))), Val(N)) - Types = ntuple(i->:(Core.Typeof(args[$i])), Val(N)) - return quote - args = ($(wrapped...),) - throw(AssertionError("Runtime iterate augmented forward pass unhandled, f=$f df=$df args=$args")) - - # TODO: Annotation of return value - # tt0 = Tuple{$(primtypes...)} - tt′ = Tuple{$(Types...)} - rt = Core.Compiler.return_type(f, Tuple{$(ElTypes...)}) - annotation = guess_activity(rt, API.DEM_ReverseModePrimal) +# This is explicitly escaped here to be what is apply generic in total [and thus all the insides are stable] +function augfwd_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType}, ::Val{ModifiedBetween0}, ::Type{FT}, ::Type{tt′}, f::FT, df::DF, args::Vararg{Annotation, Nargs})::ReturnType where {width, dupClosure0, ReturnType, ModifiedBetween0, FT, tt′, DF, Nargs} 
+ ReturnPrimal = Val(true) + RT = A + ModifiedBetween = Val(ModifiedBetween0) - dupClosure = ActivityTup[1] - FT = Core.Typeof(f) - if dupClosure && guaranteed_const(FT) - dupClosure = false - end + dupClosure = dupClosure0 && !guaranteed_const(FT) + FA = dupClosure ? Const{FT} : Duplicated{FT} - world = codegen_world_age(FT, Tuple{$(ElTypes...)}) + tt = Enzyme.vaEltypes(tt′) - forward, adjoint = thunk(Val(world), (dupClosure ? Duplicated : Const){FT}, - annotation, tt′, Val(API.DEM_ReverseModePrimal), width, - ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + rt = Core.Compiler.return_type(f, tt) + annotation0 = guess_activity(rt) - internal_tape, origRet, initShadow = forward(dupClosure ? Duplicated(f, df) : Const(f), args...) - resT = typeof(origRet) - if annotation <: Const - shadow_return = nothing - tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) - return ReturnType(($(nres...), tape)) - elseif annotation <: Active - if $Width == 1 - shadow_return = Ref(make_zero(origRet)) - else - shadow_return = ($(nzeros...),) - end - tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) - if $Width == 1 - return ReturnType((origRet, shadow_return, tape)) - else - return ReturnType((origRet, shadow_return..., tape)) - end + annotation = @static if width != 1 + if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated + BatchDuplicated{rt, width} + else + Const{rt} end + else + if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated + Duplicated{rt} + else + Const{rt} + end + end - @assert annotation <: Duplicated || annotation <: DuplicatedNoNeed || annotation <: BatchDuplicated || annotation <: BatchDuplicatedNoNeed + world = codegen_world_age(FRT, tt) + forward, adjoint = thunk(Val(world), FA, + RT, tt′, Val(API.DEM_ReverseModePrimal), width, + ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + + internal_tape, 
origRet, initShadow = forward(dupClosure ? Duplicated(f, df) : Const(f), args...) + resT = typeof(origRet) + if annotation <: Const shadow_return = nothing tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) - if $Width == 1 - return ReturnType((origRet, initShadow, tape)) + return ReturnType((allFirst(Val(width+1), origRet)..., tape)) + elseif annotation <: Active + if width == 1 + shadow_return = Ref(make_zero(origRet)) else - return ReturnType((origRet, initShadow..., tape)) + shadow_return = allZero(Val(width), origRet) + end + tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + if width == 1 + return ReturnType((origRet, shadow_return, tape)) + else + return ReturnType((origRet, shadow_return..., tape)) end end + + @assert annotation <: Duplicated || annotation <: DuplicatedNoNeed || annotation <: BatchDuplicated || annotation <: BatchDuplicatedNoNeed + + shadow_return = nothing + tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + if width == 1 + return ReturnType((origRet, initShadow, tape)) + else + return ReturnType((origRet, initShadow..., tape)) + end + +end + +function body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes) + wrappedexexpand = ntuple(i->:($(wrapped[i])...), Val(N)) + return quote + args = ($(wrappedexexpand...),) + tt′ = Enzyme.vaTypeof(args...) 
+ augfwd_with_return(Val($Width), Val(ActivityTup[1]), ReturnType, FT, tt′, f, df, args...)::ReturnType + end end function func_runtime_iterate_augfwd(N, Width) - _, _, primtypes, allargs, typeargs, wrapped, _ = setup_macro_wraps(false, N, Width) - body = body_runtime_iterate_augfwd(N, Width, wrapped, primtypes) + _, _, primtypes, allargs, typeargs, wrapped, _, modbetween = setup_macro_wraps(false, N, Width, #=base=#nothing, #=iterate=#true) + body = body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes) quote function runtime_iterate_augfwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...)) where {ActivityTup, MB, ReturnType, F, DF, $(typeargs...)} @@ -476,11 +568,109 @@ end @generated function runtime_iterate_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, allargs...) where {ActivityTup, MB, Width, ReturnType, F, DF} N = div(length(allargs)+2, Width+1)-1 - _, _, primtypes, _, _, wrapped, _ = setup_macro_wraps(false, N, Width, :allargs) - return body_runtime_iterate_augfwd(N, Width, wrapped, primtypes) + _, _, primtypes, _, _, wrapped, _ , modbetween, = setup_macro_wraps(false, N, Width, :allargs, #=iterate=#true) + return body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes) end -function body_runtime_iterate_rev(N, Width, wrapped, primttypes, shadowargs) + + +# This is explicitly escaped here to be what is apply generic in total [and thus all the insides are stable] +function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween0}, ::Val{lengths}, ::Type{FT}, ::Type{tt′}, f::FT, df::DF, tape, shadowargs, args::Vararg{Annotation, Nargs})::Nothing where {width, dupClosure0, ModifiedBetween0, lengths, FT, tt′, DF, Nargs} + ReturnPrimal = Val(true) + RT = A + ModifiedBetween = Val(ModifiedBetween0) + + dupClosure = dupClosure0 && !guaranteed_const(FT) + FA = dupClosure 
? Const{FT} : Duplicated{FT} + + tt = Enzyme.vaEltypes(tt′) + + rt = Core.Compiler.return_type(f, tt) + annotation0 = guess_activity(rt) + + annotation = @static if width != 1 + if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated + BatchDuplicated{rt, width} + else + Const{rt} + end + else + if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated + Duplicated{rt} + else + Const{rt} + end + end + + world = codegen_world_age(FRT, tt) + + forward, adjoint = thunk(Val(world), FA, + RT, tt′, Val(API.DEM_ReverseModePrimal), width, + ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + + args2 = if tape.shadow_return !== nothing + if width == 1 + (args..., tape.shadow_return[]) + else + (args..., ntuple(Val(width)) do w + Base.@_inline_meta + tape.shadow_return[w][] + end) + end + else + args + end + + tup = adjoint(dupClosure ? Duplicated(f, df) : Const(f), args2..., tape.internal_tape)[1] + + ntuple(Val(Nargs)) do i + Base.@_inline_meta + + ntuple(Val(width)) do w + Base.@_inline_meta + + if tup[i] == nothing + else + expr = @static if width == 1 + tup[i] + else + tup[i][w] + end + idx_of_vec, idx_in_vec = lengths[i] + vec = @inbounds shadowargs[idx_of_vec] + if vec isa Base.RefValue + vecld = vec[] + T = Core.Typeof(vecld) + vec[] = splatnew(T, ntuple(Val(fieldcount(T))) do i + Base.@_inline_meta + prev = getfield(vecld, i) + if i == idx_in_vec + recursive_add(prev, expr) + else + prev + end + end) + else + val = @inbounds vec[idx_in_vec] + if val isa Base.RefValue + val[] = recursive_add(val[], expr) + elseif ismutable(vec) + vec[idx_in_vec] = recursive_add(val[], expr) + else + error("Enzyme Mutability Error: Cannot in place to immutable value vec[$idx_in_vec] = $val, vec=$vec") + end + end + end + + nothing + end + + nothing + end + nothing +end + +function body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, shadowargs) outs = [] for i in 1:N for w in 1:Width @@ -494,7 +684,7 @@ function 
body_runtime_iterate_rev(N, Width, wrapped, primttypes, shadowargs) elseif $shad isa Base.RefValue $shad[] = recursive_add($shad[], $expr) else - error("Enzyme Mutability Error: Cannot add one in place to immutable value "*string($shad)) + error("Enzyme Mutability Error: Cannot add in place to immutable value "*string($shad)) end ) push!(outs, out) @@ -514,40 +704,26 @@ function body_runtime_iterate_rev(N, Width, wrapped, primttypes, shadowargs) ElTypes = ntuple(i->:(eltype(Core.Typeof(args[$i]))), Val(N)) Types = ntuple(i->:(Core.Typeof(args[$i])), Val(N)) - quote - args = ($(wrapped...),) - throw(AssertionError("Runtime iterate reverse pass unhandled, f=$f df=$df args=$args")) + wrappedexexpand = ntuple(i->:($(wrapped[i])...), Val(N)) + lengths = ntuple(i->quote + (ntuple(Val(length($(primargs[i])))) do j + Base.@_inline_meta + ($i, j) + end) + end, Val(N)) - # TODO: Annotation of return value - # tt0 = Tuple{$(primtypes...)} - tt = Tuple{$(ElTypes...)} - tt′ = Tuple{$(Types...)} - rt = Core.Compiler.return_type(f, tt) - annotation = guess_activity(rt, API.DEM_ReverseModePrimal) - dupClosure = ActivityTup[1] - FT = Core.Typeof(f) - if dupClosure && guaranteed_const(FT) - dupClosure = false - end - world = codegen_world_age(FT, tt) - - forward, adjoint = thunk(Val(world), (dupClosure ? Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ReverseModePrimal), width, - ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) - if tape.shadow_return !== nothing - args = (args..., $shadowret) - end - - tup = adjoint(dupClosure ? Duplicated(f, df) : Const(f), args..., tape.internal_tape)[1] - - $(outs...) + quote + args = ($(wrappedexexpand...),) + tt′ = Enzyme.vaTypeof(args...) + rev_with_return(Val($Width), Val(ActivityTup[1]), Val(concat($(lengths...))), FT, tt′, f, df, tape, ($(shadowargs...),), args...) 
return nothing end end function func_runtime_iterate_rev(N, Width) - _, _, primtypes, allargs, typeargs, wrapped, batchshadowargs = setup_macro_wraps(false, N, Width) - body = body_runtime_iterate_rev(N, Width, wrapped, primtypes, batchshadowargs) + primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, modbetween = setup_macro_wraps(false, N, Width) + body = body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, batchshadowargs) quote function runtime_iterate_rev(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, $(allargs...)) where {ActivityTup, MB, TapeType, F, DF, $(typeargs...)} @@ -558,8 +734,8 @@ end @generated function runtime_iterate_rev(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, allargs...) where {ActivityTup, MB, Width, TapeType, F, DF} N = div(length(allargs)+2, Width+1)-1 - _, _, primtypes, _, _, wrapped, batchshadowargs = setup_macro_wraps(false, N, Width, :allargs) - return body_runtime_generic_rev(N, Width, wrapped, primtypes, batchshadowargs) + primargs, _, primtypes, _, _, wrapped, batchshadowargs, modbetween = setup_macro_wraps(false, N, Width, :allargs) + return body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, batchshadowargs) end # Create specializations @@ -1074,18 +1250,6 @@ function common_apply_iterate_fwd(offset, B, orig, gutils, normalR, shadowR) return false end -function error_if_active_iter(arg) - # check if it could contain an active - for v in arg - seen = () - T = Core.Typeof(v) - areg = active_reg_inner(T, seen, nothing, #=justActive=#Val(true)) - if areg == ActiveState - throw(AssertionError("Found unhandled active variable in tuple splat, jl_apply_iterate $T")) - end - end -end - function common_apply_iterate_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) @@ -1100,51 +1264,41 @@ function 
common_apply_iterate_augfwd(offset, B, orig, gutils, normalR, shadowR, width = get_width(gutils) - if v && v2 && isiter == Base.iterate && istup == Base.tuple && length(operands(orig)) >= offset+4 - origops = collect(operands(orig)[1:end-1]) - shadowins = [ invert_pointer(gutils, origops[i], B) for i in (offset+3):length(origops) ] - shadowres = if width == 1 - newops = LLVM.Value[] - newvals = API.CValueType[] - for (i, v) in enumerate(origops) - if i >= offset + 3 - shadowin2 = shadowins[i-offset-3+1] - emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(error_if_active_iter), shadowin2]) - push!(newops, shadowin2) - push!(newvals, API.VT_Shadow) - else - push!(newops, new_from_original(gutils, origops[i])) - push!(newvals, API.VT_Primal) - end - end - cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) - callconv!(cal, callconv(orig)) - cal - else - ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) - shadow = LLVM.UndefValue(ST) - for j in 1:width - newops = LLVM.Value[] - newvals = API.CValueType[] - for (i, v) in enumerate(origops) - if i >= offset + 3 - shadowin2 = extract_value!(B, shadowins[i-offset-3+1], j-1) - emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(error_if_active_iter), shadowin2]) - push!(newops, shadowin2) - push!(newvals, API.VT_Shadow) - else - push!(newops, new_from_original(gutils, origops[i])) - push!(newvals, API.VT_Primal) - end + if v && isiter == Base.iterate + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + + sret = generic_setup(orig, runtime_iterate_augfwd, AnyArray(2+Int(width)), gutils, #=start=#offset+2, B, false) + AT = LLVM.ArrayType(T_prjlvalue, 2+Int(width)) + + if unsafe_load(shadowR) != C_NULL + if width == 1 + gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) + shadow = LLVM.load!(B, T_prjlvalue, gep) + else + ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) + shadow 
= LLVM.UndefValue(ST) + for i in 1:width + gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(i)]) + ld = LLVM.load!(B, T_prjlvalue, gep) + shadow = insert_value!(B, shadow, ld, i-1) end - cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) - callconv!(cal, callconv(orig)) - shadow = insert_value!(B, shadow, cal, j-1) end - shadow + unsafe_store!(shadowR, shadow.ref) end - unsafe_store!(shadowR, shadowres.ref) + tape = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1+width)])) + unsafe_store!(tapeR, tape.ref) + + if normalR != C_NULL + normal = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)])) + unsafe_store!(normalR, normal.ref) + else + # Delete the primal code + ni = new_from_original(gutils, orig) + erase_with_placeholder(gutils, ni, orig) + end + return false return false end @@ -1155,6 +1309,17 @@ function common_apply_iterate_augfwd(offset, B, orig, gutils, normalR, shadowR, end function common_apply_iterate_rev(offset, B, orig, gutils, tape) + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + return nothing + end + + @assert tape !== C_NULL + width = get_width(gutils) + generic_setup(orig, runtime_iterate_rev, Nothing, gutils, #=start=#offset+3, B, true; tape) return nothing end From 4b2bdeea6866d47ac56947cf5c204039fec76c2d Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Wed, 29 May 2024 15:44:30 +0200 Subject: [PATCH 02/22] fixed --- src/rules/jitrules.jl | 70 ++++++++++++++++++++++++++++++------------- 1 file changed, 49 insertions(+), 21 deletions(-) diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index 8e59fca5ec..6ad2e1502b 100644 --- 
a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -77,14 +77,14 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, iterate_unwrap_augfwd_act($(primargs[i])...) else $((Width == 1) ? quote - iterate_unwrap_augfwd_dup($(primargs[i]), $(shadowargs[i])) + iterate_unwrap_augfwd_dup(Val($forwardMode), $(primargs[i]), $(shadowargs[i])) end : quote - iterate_unwrap_augfwd_batchdup(Val($Width), $(primargs[i]), $(shadowargs[i])) + iterate_unwrap_augfwd_batchdup(Val($forwardMode), Val($Width), $(primargs[i]), $(shadowargs[i])) end ) end else - map(Const, $(primargs[i])...) + map(Const, $(primargs[i])) end ) else @@ -139,7 +139,6 @@ function body_runtime_generic_fwd(N, Width, wrapped, primtypes) end world = codegen_world_age(FT, tt) - forward = thunk(Val(world), (dupClosure ? Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ForwardMode), width, #=ModifiedBetween=#Val($ModifiedBetween), #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) res = forward(dupClosure ? Duplicated(f, df) : Const(f), args...) @@ -417,14 +416,14 @@ function fwddiff_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType ModifiedBetween = Val(Enzyme.falses_from_args(Nargs+1)) dupClosure = dupClosure0 && !guaranteed_const(FT) - FA = dupClosure ? Const{FT} : Duplicated{FT} + FA = dupClosure ? 
Duplicated{FT} : Const{FT} tt = Enzyme.vaEltypes(tt′) rt = Core.Compiler.return_type(f, tt) annotation0 = guess_activity(rt, API.DEM_ForwardMode) - annotation = @static if width != 1 + annotation = if width != 1 if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated BatchDuplicated{rt, width} else @@ -438,10 +437,18 @@ function fwddiff_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType end end - world = codegen_world_age(FRT, tt) - + world = codegen_world_age(FT, tt) + fa = if dupClosure + if width == 1 + Duplicated(f, df) + else + BatchDuplicated(f, df) + end + else + Const(f) + end res = thunk(Val(world), FA, annotation, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), - ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), FFIABI)(f, args...) + ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), FFIABI)(fa, args...) return if annotation <: Const ReturnType(allFirst(Val(width+1), res)) else @@ -458,12 +465,13 @@ function body_runtime_iterate_fwd(N, Width, wrapped, primtypes) return quote args = ($(wrappedexexpand...),) tt′ = Enzyme.vaTypeof(args...) + FT = Core.Typeof(f) fwddiff_with_return(Val($Width), Val(ActivityTup[1]), ReturnType, FT, tt′, f, df, args...)::ReturnType end end function func_runtime_iterate_fwd(N, Width) - _, _, primtypes, allargs, typeargs, wrapped, _, _ = setup_macro_wraps(true, N, Width) + _, _, primtypes, allargs, typeargs, wrapped, _, _ = setup_macro_wraps(true, N, Width, #=base=#nothing, #=iterate=#true) body = body_runtime_iterate_fwd(N, Width, wrapped, primtypes) quote @@ -475,7 +483,7 @@ end @generated function runtime_iterate_fwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, RT::Val{ReturnType}, f::F, df::DF, allargs...) 
where {ActivityTup, Width, ReturnType, F, DF} N = div(length(allargs)+2, Width+1)-1 - _, _, primtypes, _, _, wrapped, _, _ = setup_macro_wraps(true, N, Width, :allargs) + _, _, primtypes, _, _, wrapped, _, _ = setup_macro_wraps(true, N, Width, :allargs, #=iterate=#true) return body_runtime_iterate_fwd(N, Width, wrapped, primtypes) end @@ -487,14 +495,14 @@ function augfwd_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType} ModifiedBetween = Val(ModifiedBetween0) dupClosure = dupClosure0 && !guaranteed_const(FT) - FA = dupClosure ? Const{FT} : Duplicated{FT} + FA = dupClosure ? Duplicated{FT} : Const{FT} tt = Enzyme.vaEltypes(tt′) rt = Core.Compiler.return_type(f, tt) annotation0 = guess_activity(rt) - annotation = @static if width != 1 + annotation = if width != 1 if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated BatchDuplicated{rt, width} else @@ -508,13 +516,22 @@ function augfwd_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType} end end - world = codegen_world_age(FRT, tt) + world = codegen_world_age(FT, tt) + fa = if dupClosure + if width == 1 + Duplicated(f, df) + else + BatchDuplicated(f, df) + end + else + Const(f) + end forward, adjoint = thunk(Val(world), FA, RT, tt′, Val(API.DEM_ReverseModePrimal), width, ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) - internal_tape, origRet, initShadow = forward(dupClosure ? Duplicated(f, df) : Const(f), args...) + internal_tape, origRet, initShadow = forward(fa, args...) resT = typeof(origRet) if annotation <: Const shadow_return = nothing @@ -551,6 +568,7 @@ function body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes) return quote args = ($(wrappedexexpand...),) tt′ = Enzyme.vaTypeof(args...) 
+ FT = Core.Typeof(f) augfwd_with_return(Val($Width), Val(ActivityTup[1]), ReturnType, FT, tt′, f, df, args...)::ReturnType end end @@ -581,14 +599,14 @@ function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween ModifiedBetween = Val(ModifiedBetween0) dupClosure = dupClosure0 && !guaranteed_const(FT) - FA = dupClosure ? Const{FT} : Duplicated{FT} + FA = dupClosure ? Duplicated{FT} : Const{FT} tt = Enzyme.vaEltypes(tt′) rt = Core.Compiler.return_type(f, tt) annotation0 = guess_activity(rt) - annotation = @static if width != 1 + annotation = if width != 1 if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated BatchDuplicated{rt, width} else @@ -602,13 +620,22 @@ function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween end end - world = codegen_world_age(FRT, tt) + world = codegen_world_age(FT, tt) + fa = if dupClosure + if width == 1 + Duplicated(f, df) + else + BatchDuplicated(f, df) + end + else + Const(f) + end forward, adjoint = thunk(Val(world), FA, RT, tt′, Val(API.DEM_ReverseModePrimal), width, ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) - args2 = if tape.shadow_return !== nothing + args = if tape.shadow_return !== nothing if width == 1 (args..., tape.shadow_return[]) else @@ -621,7 +648,7 @@ function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween args end - tup = adjoint(dupClosure ? Duplicated(f, df) : Const(f), args2..., tape.internal_tape)[1] + tup = adjoint(fa, args2..., tape.internal_tape)[1] ntuple(Val(Nargs)) do i Base.@_inline_meta @@ -631,7 +658,7 @@ function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween if tup[i] == nothing else - expr = @static if width == 1 + expr = if width == 1 tup[i] else tup[i][w] @@ -716,6 +743,7 @@ function body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, shado quote args = ($(wrappedexexpand...),) tt′ = Enzyme.vaTypeof(args...) 
+ FT = Core.Typeof(f) rev_with_return(Val($Width), Val(ActivityTup[1]), Val(concat($(lengths...))), FT, tt′, f, df, tape, ($(shadowargs...),), args...) return nothing end From fb719f16aef68a0e11a6007ce599099395e2f1cb Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Wed, 29 May 2024 19:25:16 +0200 Subject: [PATCH 03/22] fixup --- src/compiler/validation.jl | 9 +- src/rules/jitrules.jl | 168 ++++++++---- src/rules/typeunstablerules.jl | 56 +++- test/applyiter.jl | 487 +++++++++++++++++++++++++++++++++ test/runtests.jl | 262 +----------------- 5 files changed, 660 insertions(+), 322 deletions(-) create mode 100644 test/applyiter.jl diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index f8fa3a4cd2..b23d6be040 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -701,7 +701,7 @@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width end end - seen = Dict{LLVM.Value,Tuple}() + seen = Set{Tuple{LLVM.Value,Tuple}}() while length(todo) != 0 cur, off = pop!(todo) @@ -709,11 +709,10 @@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width cur = operands(cur)[1] end - if cur in keys(seen) - @assert seen[cur] == off + if (cur, off) in seen continue end - seen[cur] = off + push!(seen, (cur, off)) if isa(cur, LLVM.PHIInst) for (v, _) in LLVM.incoming(cur) @@ -739,7 +738,7 @@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width # if inserting at the current desired offset, we have found the value we need if ind == off[1] - push!(todo, (operands(cur)[2], -1)) + push!(todo, (operands(cur)[2], off[2:end])) # otherwise it must be inserted at a different point else push!(todo, (operands(cur)[1], off)) diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index 6ad2e1502b..ee70ff903b 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -66,7 +66,7 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing,
ntuple(Val(length($(primargs[i])))) do _ Base.@_inline_meta MB[$i] - end... + end end) end expr = if iterate @@ -178,7 +178,7 @@ end function body_runtime_generic_augfwd(N, Width, wrapped, primttypes) nnothing = ntuple(i->nothing, Val(Width+1)) nres = ntuple(i->:(origRet), Val(Width+1)) - nzeros = ntuple(i->:(Ref(zero(resT))), Val(Width)) + nzeros = ntuple(i->:(Ref(make_zero(resT))), Val(Width)) nres3 = ntuple(i->:(res[3]), Val(Width)) ElTypes = ntuple(i->:(eltype(Core.Typeof(args[$i]))), Val(N)) Types = ntuple(i->:(Core.Typeof(args[$i])), Val(N)) @@ -190,7 +190,13 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes) # tt0 = Tuple{$(primtypes...)} tt′ = Tuple{$(Types...)} rt = Core.Compiler.return_type(f, Tuple{$(ElTypes...)}) - annotation = guess_activity(rt, API.DEM_ReverseModePrimal) + annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) + + annotation = if $Width != 1 && annotation0 <: Duplicated + BatchDuplicated{rt, $Width} + else + annotation0 + end dupClosure = ActivityTup[1] FT = Core.Typeof(f) @@ -206,6 +212,10 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes) internal_tape, origRet, initShadow = forward(dupClosure ? Duplicated(f, df) : Const(f), args...)
resT = typeof(origRet) + @show "aug generic", f, args + @show "generic", origRet, resT, annotation + @show width + @show initShadow if annotation <: Const shadow_return = nothing tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) @@ -216,6 +226,7 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes) else shadow_return = ($(nzeros...),) end + @show "generic", shadow_return tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) if $Width == 1 return ReturnType((origRet, shadow_return, tape)) @@ -284,6 +295,12 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs) shadowret = :(($(shadowret...),)) end + shadowsplat = Expr[] + for s in shadowargs + for v in s + push!(shadowsplat, :(@show $v)) + end + end ElTypes = ntuple(i->:(eltype(Core.Typeof(args[$i]))), Val(N)) Types = ntuple(i->:(Core.Typeof(args[$i])), Val(N)) @@ -295,7 +312,13 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs) tt = Tuple{$(ElTypes...)} tt′ = Tuple{$(Types...)} rt = Core.Compiler.return_type(f, tt) - annotation = guess_activity(rt, API.DEM_ReverseModePrimal) + annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) + + annotation = if $Width != 1 && annotation0 <: Duplicated + BatchDuplicated{rt, $Width} + else + annotation0 + end dupClosure = ActivityTup[1] FT = Core.Typeof(f) @@ -306,11 +329,16 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs) forward, adjoint = thunk(Val(world), (dupClosure ? Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ReverseModePrimal), width, ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + if tape.shadow_return !== nothing args = (args..., $shadowret) end + @show "gen rev", f, args + @show tape.shadow_return + $(shadowsplat...) tup = adjoint(dupClosure ? Duplicated(f, df) : Const(f), args..., tape.internal_tape)[1] + @show tup $(outs...)
return nothing @@ -406,7 +434,7 @@ end @inline function allZero(::Val{Width}, res) where Width ntuple(Val(Width)) do i Base.@_inline_meta - make_zero(res) + Ref(make_zero(res)) end end @@ -487,51 +515,62 @@ end return body_runtime_iterate_fwd(N, Width, wrapped, primtypes) end +function primal_tuple(args::Vararg{Annotation, Nargs}) where Nargs + ntuple(Val(Nargs)) do i + Base.@_inline_meta + args[i].val + end +end # This is explicitly escaped here to be what is apply generic in total [and thus all the insides are stable] function augfwd_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType}, ::Val{ModifiedBetween0}, ::Type{FT}, ::Type{tt′}, f::FT, df::DF, args::Vararg{Annotation, Nargs})::ReturnType where {width, dupClosure0, ReturnType, ModifiedBetween0, FT, tt′, DF, Nargs} ReturnPrimal = Val(true) - RT = A ModifiedBetween = Val(ModifiedBetween0) - dupClosure = dupClosure0 && !guaranteed_const(FT) - FA = dupClosure ? Duplicated{FT} : Const{FT} - tt = Enzyme.vaEltypes(tt′) - rt = Core.Compiler.return_type(f, tt) - annotation0 = guess_activity(rt) + annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) annotation = if width != 1 if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated BatchDuplicated{rt, width} + elseif annotation0 <: Active + Active{rt} else Const{rt} end else if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated Duplicated{rt} + elseif annotation0 <: Active + Active{rt} else Const{rt} end end - world = codegen_world_age(FT, tt) + internal_tape, origRet, initShadow = if f != Base.tuple + dupClosure = dupClosure0 && !guaranteed_const(FT) + FA = dupClosure ? 
Duplicated{FT} : Const{FT} - fa = if dupClosure - if width == 1 - Duplicated(f, df) + fa = if dupClosure + if width == 1 + Duplicated(f, df) + else + BatchDuplicated(f, df) + end else - BatchDuplicated(f, df) + Const(f) end + world = codegen_world_age(FT, tt) + forward, adjoint = thunk(Val(world), FA, + annotation, tt′, Val(API.DEM_ReverseModePrimal), Val(width), + ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + forward(fa, args...) else - Const(f) + nothing, primal_tuple(args...), nothing end - forward, adjoint = thunk(Val(world), FA, - RT, tt′, Val(API.DEM_ReverseModePrimal), width, - ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) - internal_tape, origRet, initShadow = forward(fa, args...) resT = typeof(origRet) if annotation <: Const shadow_return = nothing @@ -560,7 +599,6 @@ function augfwd_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType} else return ReturnType((origRet, initShadow..., tape)) end - end function body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes) @@ -569,7 +607,7 @@ function body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes) args = ($(wrappedexexpand...),) tt′ = Enzyme.vaTypeof(args...) 
FT = Core.Typeof(f) - augfwd_with_return(Val($Width), Val(ActivityTup[1]), ReturnType, FT, tt′, f, df, args...)::ReturnType + augfwd_with_return(Val($Width), Val(ActivityTup[1]), ReturnType, Val(concat($(modbetween...))), FT, tt′, f, df, args...)::ReturnType end end @@ -595,7 +633,6 @@ end # This is explicitly escaped here to be what is apply generic in total [and thus all the insides are stable] function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween0}, ::Val{lengths}, ::Type{FT}, ::Type{tt′}, f::FT, df::DF, tape, shadowargs, args::Vararg{Annotation, Nargs})::Nothing where {width, dupClosure0, ModifiedBetween0, lengths, FT, tt′, DF, Nargs} ReturnPrimal = Val(true) - RT = A ModifiedBetween = Val(ModifiedBetween0) dupClosure = dupClosure0 && !guaranteed_const(FT) @@ -604,51 +641,75 @@ function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween tt = Enzyme.vaEltypes(tt′) rt = Core.Compiler.return_type(f, tt) - annotation0 = guess_activity(rt) + annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) annotation = if width != 1 if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated BatchDuplicated{rt, width} + elseif annotation0 <: Active + Active{rt} else Const{rt} end else if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated Duplicated{rt} + elseif annotation0 <: Active + Active{rt} else Const{rt} end end - world = codegen_world_age(FT, tt) + tup = if f != Base.tuple + world = codegen_world_age(FT, tt) - fa = if dupClosure - if width == 1 - Duplicated(f, df) + fa = if dupClosure + if width == 1 + Duplicated(f, df) + else + BatchDuplicated(f, df) + end else - BatchDuplicated(f, df) + Const(f) end - else - Const(f) - end - forward, adjoint = thunk(Val(world), FA, - RT, tt′, Val(API.DEM_ReverseModePrimal), width, - ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + forward, adjoint = thunk(Val(world), FA, + annotation, tt′, Val(API.DEM_ReverseModePrimal), Val(width), + 
ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) - args = if tape.shadow_return !== nothing - if width == 1 - (args..., tape.shadow_return[]) + args2 = if tape.shadow_return !== nothing + if width == 1 + (args..., tape.shadow_return[]) + else + (args..., ntuple(Val(width)) do w + Base.@_inline_meta + tape.shadow_return[w][] + end) + end else - (args..., ntuple(Val(width)) do w - Base.@_inline_meta - tape.shadow_return[w][] - end) + args end + + adjoint(fa, args2..., tape.internal_tape)[1] else - args + ntuple(Val(Nargs)) do i + Base.@_inline_meta + if args[i] isa Active + if width == 1 + tape.shadow_return[][i] + else + ntuple(Val(width)) do w + Base.@_inline_meta + tape.shadow_return[w][i] + end + end + else + nothing + end + end end - tup = adjoint(fa, args2..., tape.internal_tape)[1] + @show "idx rev", tup ntuple(Val(Nargs)) do i Base.@_inline_meta @@ -664,7 +725,7 @@ function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween tup[i][w] end idx_of_vec, idx_in_vec = lengths[i] - vec = @inbounds shadowargs[idx_of_vec] + vec = @inbounds shadowargs[idx_of_vec][w] if vec isa Base.RefValue vecld = vec[] T = Core.Typeof(vecld) @@ -682,7 +743,7 @@ function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween if val isa Base.RefValue val[] = recursive_add(val[], expr) elseif ismutable(vec) - vec[idx_in_vec] = recursive_add(val[], expr) + @inbounds vec[idx_in_vec] = recursive_add(val, expr) else error("Enzyme Mutability Error: Cannot in place to immutable value vec[$idx_in_vec] = $val, vec=$vec") end @@ -739,18 +800,21 @@ function body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, shado end) end, Val(N)) - + shadowsplat = Expr[] + for s in shadowargs + push!(shadowsplat, :(($(s...),))) + end quote args = ($(wrappedexexpand...),) tt′ = Enzyme.vaTypeof(args...) 
FT = Core.Typeof(f) - rev_with_return(Val($Width), Val(ActivityTup[1]), Val(concat($(lengths...))), FT, tt′, f, df, tape, ($(shadowargs...),), args...) + rev_with_return(Val($Width), Val(ActivityTup[1]), Val(concat($(modbetween...))), Val(concat($(lengths...))), FT, tt′, f, df, tape, ($(shadowsplat...),), args...) return nothing end end function func_runtime_iterate_rev(N, Width) - primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, modbetween = setup_macro_wraps(false, N, Width) + primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, modbetween = setup_macro_wraps(false, N, Width, #=body=#nothing, #=iterate=#true) body = body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, batchshadowargs) quote @@ -762,7 +826,7 @@ end @generated function runtime_iterate_rev(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, allargs...) where {ActivityTup, MB, Width, TapeType, F, DF} N = div(length(allargs)+2, Width+1)-1 - primargs, _, primtypes, _, _, wrapped, batchshadowargs, modbetween = setup_macro_wraps(false, N, Width, :allargs) + primargs, _, primtypes, _, _, wrapped, batchshadowargs, modbetween = setup_macro_wraps(false, N, Width, :allargs, #=iterate=#true) return body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, batchshadowargs) end @@ -1347,7 +1411,7 @@ function common_apply_iterate_rev(offset, B, orig, gutils, tape) @assert tape !== C_NULL width = get_width(gutils) - generic_setup(orig, runtime_iterate_rev, Nothing, gutils, #=start=#offset+3, B, true; tape) + generic_setup(orig, runtime_iterate_rev, Nothing, gutils, #=start=#offset+2, B, true; tape) return nothing end diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index f70e15c82c..05f0ab422f 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -288,7 +288,7 @@ function idx_jl_getfield_aug(dptr::T, ::Type{Val{symname}}, ::Val{isconst}, dptr if 
length(dptrs) == 0 return res else - return (res, (getfield(dv, symname) for dv in dptrs)...) + return (res, (getfield(dv isa Base.RefValue ? dv[] : dv, symname+1) for dv in dptrs)...) end end end @@ -323,11 +323,59 @@ function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst} RT = Core.Typeof(cur) if active_reg(RT) && !isconst if length(dptrs) == 0 - setfield!(dptr, symname+1, recursive_add(cur, dret[])) + @show dptr, cur, dret + if dptr isa Base.RefValue + vload = dptr[] + dRT = Core.Typeof(vload) + dptr[] = splatnew(dRT, ntuple(Val(fieldcount(dRT))) do i + Base.@_inline_meta + prev = getfield(vload, i) + if i == symname+1 + recursive_add(prev, dret[]) + else + prev + end + end) + else + setfield!(dptr, symname+1, recursive_add(cur, dret[])) + end else - setfield!(dptr, symname+1, recursive_add(cur, dret[1][])) + if dptr isa Base.RefValue + vload = dptr[] + dRT = Core.Typeof(vload) + dptr[] = splatnew(dRT, ntuple(Val(fieldcount(dRT))) do j + Base.@_inline_meta + prev = getfield(vload, j) + if j == symname+1 + recursive_add(prev, dret[1][]) + else + prev + end + end) + else + setfield!(dptr, symname+1, recursive_add(cur, dret[1][])) + end for i in 1:length(dptrs) - setfield!(dptrs[i], symname+1, recursive_add(cur, dret[1+i][])) + if dptrs[i] isa Base.RefValue + vload = dptrs[i][] + dRT = Core.Typeof(vload) + dptrs[i][] = splatnew(dRT, ntuple(Val(fieldcount(dRT))) do j + Base.@_inline_meta + prev = getfield(vload, j) + if j == symname+1 + recursive_add(prev, dret[1+i][]) + else + prev + end + end) + else + curi = if dptrs[i] isa Base.RefValue + Base.getfield(dptrs[i][], symname+1) + else + Base.getfield(dptrs[i], symname+1) + end + setfield!(dptrs[i], symname+1, recursive_add(curi, dret[1+i][])) + end end end end diff --git a/test/applyiter.jl b/test/applyiter.jl new file mode 100644 index 0000000000..78f5af7b89 --- /dev/null +++ b/test/applyiter.jl @@ -0,0 +1,487 @@ +using Enzyme, Test + +concat() = () +concat(a) = a +concat(a, b) = (a..., b...)
+concat(a, b, c...) = concat(concat(a, b), c...) + +metaconcat(x) = concat(x...) + +metaconcat2(x, y) = concat(x..., y...) + +midconcat(x, y) = (x, concat(y...)...) + +metaconcat3(x, y, z) = concat(x..., y..., z...) + +function metasumsq(f, args...) + res = 0.0 + x = f(args...) + for v in x + v = v::Float64 + res += v*v + end + return res +end + +function metasumsq2(f, args...) + res = 0.0 + x = f(args...) + for v in x + for v2 in v + v2 = v2::Float64 + res += v*v + end + end + return res +end + + +function metasumsq3(f, args...) + res = 0.0 + x = f(args...) + for v in x + v = v + res += v*v + end + return res +end + +function metasumsq4(f, args...) + res = 0.0 + x = f(args...) + for v in x + for v2 in v + v2 = v2 + res += v*v + end + end + return res +end + +function make_byref(out, fn, args...) + out[] = fn(args...) + nothing +end + +function tupapprox(a, b) + if a isa Tuple && b isa Tuple + if length(a) != length(b) + return false + end + for (aa, bb) in zip(a, b) + if !tupapprox(aa, bb) + return false + end + end + return true + end + if a isa Array && b isa Array + if size(a) != size(b) + return false + end + for i in eachindex(a) + if !tupapprox(a[i], b[i]) + return false + end + end + return true + end + return a ≈ b +end + +@testset "Reverse Apply iterate" begin + x = [(2.0, 3.0), (7.9, 11.2)] + dx = [(0.0, 0.0), (0.0, 0.0)] + res = Enzyme.autodiff(Reverse, metasumsq, Active, Const(metaconcat), Duplicated(x, dx)) + @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) + + dx = [(0.0, 0.0), (0.0, 0.0)] + res = Enzyme.autodiff(ReverseWithPrimal, metasumsq, Active, Const(metaconcat), Duplicated(x, dx)) + @test res[2] ≈ 200.84999999999997 + @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) + + x = [[2.0, 3.0], [7.9, 11.2]] + dx = [[0.0, 0.0], [0.0, 0.0]] + + res = Enzyme.autodiff(Reverse, metasumsq2, Active, Const(metaconcat), Duplicated(x, dx)) + @test dx ≈ [[4.0, 6.0], [15.8, 22.4]] + + dx = [[0.0, 0.0], [0.0, 0.0]] + + res = Enzyme.autodiff(ReverseWithPrimal,
metasumsq2, Active, Const(metaconcat), Duplicated(x, dx)) + + @test res[2] ≈ 200.84999999999997 + @test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]]) + + + x = [(2.0, 3.0), (7.9, 11.2)] + dx = [(0.0, 0.0), (0.0, 0.0)] + + y = [(13, 17), (25, 31)] + res = Enzyme.autodiff(Reverse, metasumsq3, Active, Const(metaconcat2), Duplicated(x, dx), Const(y)) + @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) + + + x = [(2.0, 3.0), (7.9, 11.2)] + dx = [(0.0, 0.0), (0.0, 0.0)] + y = [(13, 17), (25, 31)] + dy = [(0, 0), (0, 0)] + res = Enzyme.autodiff(Reverse, metasumsq3, Active, Const(metaconcat2), Duplicated(x, dx), Duplicated(y, dy)) + @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) + + + + x = [[2.0, 3.0], [7.9, 11.2]] + dx = [[0.0, 0.0], [0.0, 0.0]] + y = [[13, 17], [25, 31]] + res = Enzyme.autodiff(Reverse, metasumsq4, Active, Const(metaconcat2), Duplicated(x, dx), Const(y)) + @test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]]) + + + x = [[2.0, 3.0], [7.9, 11.2]] + dx = [[0.0, 0.0], [0.0, 0.0]] + y = [[13, 17], [25, 31]] + dy = [[0, 0], [0, 0]] + res = Enzyme.autodiff(Reverse, metasumsq4, Active, Const(metaconcat2), Duplicated(x, dx), Duplicated(y, dy)) + @test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]]) +end + + +@testset "BatchReverse Apply iterate" begin + x = [(2.0, 3.0), (7.9, 11.2)] + dx = [(0.0, 0.0), (0.0, 0.0)] + dx2 = [(0.0, 0.0), (0.0, 0.0)] + out = Ref(0.0) + dout = Ref(1.0) + dout2 = Ref(3.0) + res = Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) + @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) + @test tupapprox(dx2, [(3*4.0, 3*6.0), (3*15.8, 3*22.4)]) + + dx = [(0.0, 0.0), (0.0, 0.0)] + out = Ref(0.0) + dout = Ref(1.0) + dout2 = Ref(3.0) + res = Enzyme.autodiff(ReverseWithPrimal, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq), Active, Const(metaconcat), BatchDuplicated(x, (dx, dx2))) + @test res[2] ≈ 200.84999999999997 + @test tupapprox(dx, 
[(4.0, 6.0), (15.8, 22.4)]) + @test tupapprox(dx2, [(4.0, 6.0), (15.8, 22.4)]) + + x = [[2.0, 3.0], [7.9, 11.2]] + dx = [[0.0, 0.0], [0.0, 0.0]] + dx2 = [[0.0, 0.0], [0.0, 0.0]] + out = Ref(0.0) + dout = Ref(1.0) + dout2 = Ref(3.0) + + res = Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq2), Active, Const(metaconcat), BatchDuplicated(x, (dx, dx2))) + @test dx ≈ [[4.0, 6.0], [15.8, 22.4]] + @test dx2 ≈ [[4.0, 6.0], [15.8, 22.4]] + + dx = [[0.0, 0.0], [0.0, 0.0]] + dx2 = [[0.0, 0.0], [0.0, 0.0]] + out = Ref(0.0) + dout = Ref(1.0) + dout2 = Ref(3.0) + res = Enzyme.autodiff(ReverseWithPrimal, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq2), Active, Const(metaconcat), BatchDuplicated(x, (dx, dx2))) + + @test res[2] ≈ 200.84999999999997 + @test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]]) + + + x = [(2.0, 3.0), (7.9, 11.2)] + dx = [(0.0, 0.0), (0.0, 0.0)] + dx2 = [(0.0, 0.0), (0.0, 0.0)] + + y = [(13, 17), (25, 31)] + out = Ref(0.0) + dout = Ref(1.0) + dout2 = Ref(3.0) + res = Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq3), Active, Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), Const(y)) + @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) + + + x = [(2.0, 3.0), (7.9, 11.2)] + dx = [(0.0, 0.0), (0.0, 0.0)] + dx2 = [(0.0, 0.0), (0.0, 0.0)] + y = [(13, 17), (25, 31)] + dy = [(0, 0), (0, 0)] + dy2 = [(0, 0), (0, 0)] + out = Ref(0.0) + dout = Ref(1.0) + dout2 = Ref(3.0) + res = Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq3), Active, Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), BatchDuplicated(y, (dy, dy2))) + @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) + @test tupapprox(dx2, [(4.0, 6.0), (15.8, 22.4)]) + + + x = [[2.0, 3.0], [7.9, 11.2]] + dx = [[0.0, 0.0], [0.0, 0.0]] + y = [[13, 17], [25, 31]] + out = Ref(0.0) + dout = Ref(1.0) + dout2 = Ref(3.0) + res = Enzyme.autodiff(Reverse, make_byref,
Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq4), Active, Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), Const(y)) + @test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]]) + @test tupapprox(dx2, [(4.0, 6.0), (15.8, 22.4)]) + + x = [[2.0, 3.0], [7.9, 11.2]] + dx = [[0.0, 0.0], [0.0, 0.0]] + y = [[13, 17], [25, 31]] + dy = [[0, 0], [0, 0]] + dy2 = [[0, 0], [0, 0]] + out = Ref(0.0) + dout = Ref(1.0) + dout2 = Ref(3.0) + res = Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq4), Active, Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), BatchDuplicated(y, (dy, dy2))) + @test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]]) + @test tupapprox(dx2, [[4.0, 6.0], [15.8, 22.4]]) +end + +@testset "Forward Apply iterate" begin + x = [(2.0, 3.0), (7.9, 11.2)] + dx = [(13.7, 15.2), (100.02, 304.1)] + + dres, = Enzyme.autodiff(Forward, metaconcat, Duplicated(x, dx)) + @test length(dres) == 4 + @test dres[1] ≈ 13.7 + @test dres[2] ≈ 15.2 + @test dres[3] ≈ 100.02 + @test dres[4] ≈ 304.1 + + res, dres = Enzyme.autodiff(Forward, metaconcat, Duplicated, Duplicated(x, dx)) + @test length(res) == 4 + @test res[1] ≈ 2.0 + @test res[2] ≈ 3.0 + @test res[3] ≈ 7.9 + @test res[4] ≈ 11.2 + @test length(dres) == 4 + @test dres[1] ≈ 13.7 + @test dres[2] ≈ 15.2 + @test dres[3] ≈ 100.02 + @test dres[4] ≈ 304.1 + + + a = [("a", "b"), ("c", "d")] + da = [("e", "f"), ("g", "h")] + + dres, = Enzyme.autodiff(Forward, metaconcat, Duplicated(a, da)) + @test length(dres) == 4 + @test dres[1] == "a" + @test dres[2] == "b" + @test dres[3] == "c" + @test dres[4] == "d" + + res, dres = Enzyme.autodiff(Forward, metaconcat, Duplicated, Duplicated(a, da)) + @test length(res) == 4 + @test res[1] == "a" + @test res[2] == "b" + @test res[3] == "c" + @test res[4] == "d" + @test length(dres) == 4 + @test dres[1] == "a" + @test dres[2] == "b" + @test dres[3] == "c" + @test dres[4] == "d" + + + Enzyme.autodiff(Forward, metaconcat, Const(a)) + +@static if VERSION ≥ 
v"1.7-" + dres, = Enzyme.autodiff(Forward, midconcat, Duplicated(1.0, 7.0), Duplicated(a, da)) + @test length(dres) == 5 + @test dres[1] ≈ 7.0 + @test dres[2] == "a" + @test dres[3] == "b" + @test dres[4] == "c" + @test dres[5] == "d" + + res, dres = Enzyme.autodiff(Forward, midconcat, Duplicated, Duplicated(1.0, 7.0), Duplicated(a, da)) + @test length(res) == 5 + @test res[1] ≈ 1.0 + @test res[2] == "a" + @test res[3] == "b" + @test res[4] == "c" + @test res[5] == "d" + + @test length(dres) == 5 + @test dres[1] ≈ 7.0 + @test dres[2] == "a" + @test dres[3] == "b" + @test dres[4] == "c" + @test dres[5] == "d" + + + dres, = Enzyme.autodiff(Forward, midconcat, Duplicated(1.0, 7.0), Const(a)) + @test length(dres) == 5 + @test dres[1] ≈ 7.0 + @test dres[2] == "a" + @test dres[3] == "b" + @test dres[4] == "c" + @test dres[5] == "d" + + res, dres = Enzyme.autodiff(Forward, midconcat, Duplicated, Duplicated(1.0, 7.0), Const(a)) + @test length(res) == 5 + @test res[1] ≈ 1.0 + @test res[2] == "a" + @test res[3] == "b" + @test res[4] == "c" + @test res[5] == "d" + @test length(dres) == 5 + @test dres[1] ≈ 7.0 + @test dres[2] == "a" + @test dres[3] == "b" + @test dres[4] == "c" + @test dres[5] == "d" +end + + y = [(-92.0, -93.0), (-97.9, -911.2)] + dy = [(-913.7, -915.2), (-9100.02, -9304.1)] + + dres, = Enzyme.autodiff(Forward, metaconcat2, Duplicated(x, dx), Duplicated(y, dy)) + @test length(dres) == 8 + @test dres[1] ≈ 13.7 + @test dres[2] ≈ 15.2 + @test dres[3] ≈ 100.02 + @test dres[4] ≈ 304.1 + @test dres[5] ≈ -913.7 + @test dres[6] ≈ -915.2 + @test dres[7] ≈ -9100.02 + @test dres[8] ≈ -9304.1 + + res, dres = Enzyme.autodiff(Forward, metaconcat2, Duplicated, Duplicated(x, dx), Duplicated(y, dy)) + @test length(res) == 8 + @test res[1] ≈ 2.0 + @test res[2] ≈ 3.0 + @test res[3] ≈ 7.9 + @test res[4] ≈ 11.2 + @test res[5] ≈ -92.0 + @test res[6] ≈ -93.0 + @test res[7] ≈ -97.9 + @test res[8] ≈ -911.2 + @test length(dres) == 8 + @test dres[1] ≈ 13.7 + @test dres[2] ≈ 15.2 + 
@test dres[3] ≈ 100.02 + @test dres[4] ≈ 304.1 + @test dres[5] ≈ -913.7 + @test dres[6] ≈ -915.2 + @test dres[7] ≈ -9100.02 + @test dres[8] ≈ -9304.1 + + + dres, = Enzyme.autodiff(Forward, metaconcat3, Duplicated(x, dx), Const(a), Duplicated(y, dy)) + @test length(dres) == 12 + @test dres[1] ≈ 13.7 + @test dres[2] ≈ 15.2 + @test dres[3] ≈ 100.02 + @test dres[4] ≈ 304.1 + + @test dres[5] == "a" + @test dres[6] == "b" + @test dres[7] == "c" + @test dres[8] == "d" + + @test dres[9] ≈ -913.7 + @test dres[10] ≈ -915.2 + @test dres[11] ≈ -9100.02 + @test dres[12] ≈ -9304.1 + + res, dres = Enzyme.autodiff(Forward, metaconcat3, Duplicated, Duplicated(x, dx), Const(a), Duplicated(y, dy)) + @test length(res) == 12 + @test res[1] ≈ 2.0 + @test res[2] ≈ 3.0 + @test res[3] ≈ 7.9 + @test res[4] ≈ 11.2 + + @test res[5] == "a" + @test res[6] == "b" + @test res[7] == "c" + @test res[8] == "d" + + @test res[9] ≈ -92.0 + @test res[10] ≈ -93.0 + @test res[11] ≈ -97.9 + @test res[12] ≈ -911.2 + + @test length(dres) == 12 + @test dres[1] ≈ 13.7 + @test dres[2] ≈ 15.2 + @test dres[3] ≈ 100.02 + @test dres[4] ≈ 304.1 + + @test dres[5] == "a" + @test dres[6] == "b" + @test dres[7] == "c" + @test dres[8] == "d" + + @test dres[9] ≈ -913.7 + @test dres[10] ≈ -915.2 + @test dres[11] ≈ -9100.02 + @test dres[12] ≈ -9304.1 + + + dres, = Enzyme.autodiff(Forward, metaconcat, BatchDuplicated(x, (dx, dy))) + @test length(dres[1]) == 4 + @test dres[1][1] ≈ 13.7 + @test dres[1][2] ≈ 15.2 + @test dres[1][3] ≈ 100.02 + @test dres[1][4] ≈ 304.1 + @test length(dres[2]) == 4 + @test dres[2][1] ≈ -913.7 + @test dres[2][2] ≈ -915.2 + @test dres[2][3] ≈ -9100.02 + @test dres[2][4] ≈ -9304.1 + + res, dres = Enzyme.autodiff(Forward, metaconcat, Duplicated, BatchDuplicated(x, (dx, dy))) + @test length(res) == 4 + @test res[1] ≈ 2.0 + @test res[2] ≈ 3.0 + @test res[3] ≈ 7.9 + @test res[4] ≈ 11.2 + @test length(dres[1]) == 4 + @test dres[1][1] ≈ 13.7 + @test dres[1][2] ≈ 15.2 + @test dres[1][3] ≈ 100.02 + @test 
dres[1][4] ≈ 304.1 + @test length(dres[2]) == 4 + @test dres[2][1] ≈ -913.7 + @test dres[2][2] ≈ -915.2 + @test dres[2][3] ≈ -9100.02 + @test dres[2][4] ≈ -9304.1 +end + +@testset "legacy reverse apply iterate" begin + function mktup(v) + tup = tuple(v...) + return tup[1][1] * tup[3][1] + end + + data = [[3.0], nothing, [2.0]] + ddata = [[0.0], nothing, [0.0]] + + Enzyme.autodiff(Reverse, mktup, Duplicated(data, ddata)) + @test ddata[1][1] ≈ 2.0 + @test ddata[3][1] ≈ 3.0 + + function mktup2(v) + tup = tuple(v...) + return (tup[1][1] * tup[3])::Float64 + end + + data = [[3.0], nothing, 2.0] + ddata = [[0.0], nothing, 0.0] + + @test_throws AssertionError Enzyme.autodiff(Reverse, mktup2, Duplicated(data, ddata)) + + function mktup3(v) + tup = tuple(v..., v...) + return tup[1][1] * tup[1][1] + end + + data = [[3.0]] + ddata = [[0.0]] + + Enzyme.autodiff(Reverse, mktup3, Duplicated(data, ddata)) + @test ddata[1][1] ≈ 6.0 +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 94809a7f3e..640b94e4f9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1647,232 +1647,7 @@ end end - -concat() = () -concat(a) = a -concat(a, b) = (a..., b...) -concat(a, b, c...) = concat(concat(a, b), c...) - -metaconcat(x) = concat(x...) - -metaconcat2(x, y) = concat(x..., y...) - -midconcat(x, y) = (x, concat(y...)...) - -metaconcat3(x, y, z) = concat(x..., y..., z...) 
- -@testset "Forward Apply iterate" begin - x = [(2.0, 3.0), (7.9, 11.2)] - dx = [(13.7, 15.2), (100.02, 304.1)] - - dres, = Enzyme.autodiff(Forward, metaconcat, Duplicated(x, dx)) - @test length(dres) == 4 - @test dres[1] ≈ 13.7 - @test dres[2] ≈ 15.2 - @test dres[3] ≈ 100.02 - @test dres[4] ≈ 304.1 - - res, dres = Enzyme.autodiff(Forward, metaconcat, Duplicated, Duplicated(x, dx)) - @test length(res) == 4 - @test res[1] ≈ 2.0 - @test res[2] ≈ 3.0 - @test res[3] ≈ 7.9 - @test res[4] ≈ 11.2 - @test length(dres) == 4 - @test dres[1] ≈ 13.7 - @test dres[2] ≈ 15.2 - @test dres[3] ≈ 100.02 - @test dres[4] ≈ 304.1 - - - a = [("a", "b"), ("c", "d")] - da = [("e", "f"), ("g", "h")] - - dres, = Enzyme.autodiff(Forward, metaconcat, Duplicated(a, da)) - @test length(dres) == 4 - @test dres[1] == "a" - @test dres[2] == "b" - @test dres[3] == "c" - @test dres[4] == "d" - - res, dres = Enzyme.autodiff(Forward, metaconcat, Duplicated, Duplicated(a, da)) - @test length(res) == 4 - @test res[1] == "a" - @test res[2] == "b" - @test res[3] == "c" - @test res[4] == "d" - @test length(dres) == 4 - @test dres[1] == "a" - @test dres[2] == "b" - @test dres[3] == "c" - @test dres[4] == "d" - - - Enzyme.autodiff(Forward, metaconcat, Const(a)) - -@static if VERSION ≥ v"1.7-" - dres, = Enzyme.autodiff(Forward, midconcat, Duplicated(1.0, 7.0), Duplicated(a, da)) - @test length(dres) == 5 - @test dres[1] ≈ 7.0 - @test dres[2] == "a" - @test dres[3] == "b" - @test dres[4] == "c" - @test dres[5] == "d" - - res, dres = Enzyme.autodiff(Forward, midconcat, Duplicated, Duplicated(1.0, 7.0), Duplicated(a, da)) - @test length(res) == 5 - @test res[1] ≈ 1.0 - @test res[2] == "a" - @test res[3] == "b" - @test res[4] == "c" - @test res[5] == "d" - - @test length(dres) == 5 - @test dres[1] ≈ 7.0 - @test dres[2] == "a" - @test dres[3] == "b" - @test dres[4] == "c" - @test dres[5] == "d" - - - dres, = Enzyme.autodiff(Forward, midconcat, Duplicated(1.0, 7.0), Const(a)) - @test length(dres) == 5 - @test 
dres[1] ≈ 7.0 - @test dres[2] == "a" - @test dres[3] == "b" - @test dres[4] == "c" - @test dres[5] == "d" - - res, dres = Enzyme.autodiff(Forward, midconcat, Duplicated, Duplicated(1.0, 7.0), Const(a)) - @test length(res) == 5 - @test res[1] ≈ 1.0 - @test res[2] == "a" - @test res[3] == "b" - @test res[4] == "c" - @test res[5] == "d" - @test length(dres) == 5 - @test dres[1] ≈ 7.0 - @test dres[2] == "a" - @test dres[3] == "b" - @test dres[4] == "c" - @test dres[5] == "d" -end - - y = [(-92.0, -93.0), (-97.9, -911.2)] - dy = [(-913.7, -915.2), (-9100.02, -9304.1)] - - dres, = Enzyme.autodiff(Forward, metaconcat2, Duplicated(x, dx), Duplicated(y, dy)) - @test length(dres) == 8 - @test dres[1] ≈ 13.7 - @test dres[2] ≈ 15.2 - @test dres[3] ≈ 100.02 - @test dres[4] ≈ 304.1 - @test dres[5] ≈ -913.7 - @test dres[6] ≈ -915.2 - @test dres[7] ≈ -9100.02 - @test dres[8] ≈ -9304.1 - - res, dres = Enzyme.autodiff(Forward, metaconcat2, Duplicated, Duplicated(x, dx), Duplicated(y, dy)) - @test length(res) == 8 - @test res[1] ≈ 2.0 - @test res[2] ≈ 3.0 - @test res[3] ≈ 7.9 - @test res[4] ≈ 11.2 - @test res[5] ≈ -92.0 - @test res[6] ≈ -93.0 - @test res[7] ≈ -97.9 - @test res[8] ≈ -911.2 - @test length(dres) == 8 - @test dres[1] ≈ 13.7 - @test dres[2] ≈ 15.2 - @test dres[3] ≈ 100.02 - @test dres[4] ≈ 304.1 - @test dres[5] ≈ -913.7 - @test dres[6] ≈ -915.2 - @test dres[7] ≈ -9100.02 - @test dres[8] ≈ -9304.1 - - - dres, = Enzyme.autodiff(Forward, metaconcat3, Duplicated(x, dx), Const(a), Duplicated(y, dy)) - @test length(dres) == 12 - @test dres[1] ≈ 13.7 - @test dres[2] ≈ 15.2 - @test dres[3] ≈ 100.02 - @test dres[4] ≈ 304.1 - - @test dres[5] == "a" - @test dres[6] == "b" - @test dres[7] == "c" - @test dres[8] == "d" - - @test dres[9] ≈ -913.7 - @test dres[10] ≈ -915.2 - @test dres[11] ≈ -9100.02 - @test dres[12] ≈ -9304.1 - - res, dres = Enzyme.autodiff(Forward, metaconcat3, Duplicated, Duplicated(x, dx), Const(a), Duplicated(y, dy)) - @test length(res) == 12 - @test res[1] ≈ 2.0 - 
@test res[2] ≈ 3.0 - @test res[3] ≈ 7.9 - @test res[4] ≈ 11.2 - - @test res[5] == "a" - @test res[6] == "b" - @test res[7] == "c" - @test res[8] == "d" - - @test res[9] ≈ -92.0 - @test res[10] ≈ -93.0 - @test res[11] ≈ -97.9 - @test res[12] ≈ -911.2 - - @test length(dres) == 12 - @test dres[1] ≈ 13.7 - @test dres[2] ≈ 15.2 - @test dres[3] ≈ 100.02 - @test dres[4] ≈ 304.1 - - @test dres[5] == "a" - @test dres[6] == "b" - @test dres[7] == "c" - @test dres[8] == "d" - - @test dres[9] ≈ -913.7 - @test dres[10] ≈ -915.2 - @test dres[11] ≈ -9100.02 - @test dres[12] ≈ -9304.1 - - - dres, = Enzyme.autodiff(Forward, metaconcat, BatchDuplicated(x, (dx, dy))) - @test length(dres[1]) == 4 - @test dres[1][1] ≈ 13.7 - @test dres[1][2] ≈ 15.2 - @test dres[1][3] ≈ 100.02 - @test dres[1][4] ≈ 304.1 - @test length(dres[2]) == 4 - @test dres[2][1] ≈ -913.7 - @test dres[2][2] ≈ -915.2 - @test dres[2][3] ≈ -9100.02 - @test dres[2][4] ≈ -9304.1 - - res, dres = Enzyme.autodiff(Forward, metaconcat, Duplicated, BatchDuplicated(x, (dx, dy))) - @test length(res) == 4 - @test res[1] ≈ 2.0 - @test res[2] ≈ 3.0 - @test res[3] ≈ 7.9 - @test res[4] ≈ 11.2 - @test length(dres[1]) == 4 - @test dres[1][1] ≈ 13.7 - @test dres[1][2] ≈ 15.2 - @test dres[1][3] ≈ 100.02 - @test dres[1][4] ≈ 304.1 - @test length(dres[2]) == 4 - @test dres[2][1] ≈ -913.7 - @test dres[2][2] ≈ -915.2 - @test dres[2][3] ≈ -9100.02 - @test dres[2][4] ≈ -9304.1 -end +include("applyiter.jl") @testset "Dynamic Val Construction" begin @@ -2544,41 +2319,6 @@ end Enzyme.API.runtimeActivity!(false) end -@testset "apply iterate" begin - function mktup(v) - tup = tuple(v...) - return tup[1][1] * tup[3][1] - end - - data = [[3.0], nothing, [2.0]] - ddata = [[0.0], nothing, [0.0]] - - Enzyme.autodiff(Reverse, mktup, Duplicated(data, ddata)) - @test ddata[1][1] ≈ 2.0 - @test ddata[3][1] ≈ 3.0 - - function mktup2(v) - tup = tuple(v...) 
- return (tup[1][1] * tup[3])::Float64 - end - - data = [[3.0], nothing, 2.0] - ddata = [[0.0], nothing, 0.0] - - @test_throws AssertionError Enzyme.autodiff(Reverse, mktup2, Duplicated(data, ddata)) - - function mktup3(v) - tup = tuple(v..., v...) - return tup[1][1] * tup[1][1] - end - - data = [[3.0]] - ddata = [[0.0]] - - Enzyme.autodiff(Reverse, mktup3, Duplicated(data, ddata)) - @test ddata[1][1] ≈ 6.0 -end - @testset "BLAS" begin x = [2.0, 3.0] dx = [0.2,0.3] From 0e81ec560c4a3765f44f7e2286d11142dcbfb0d4 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Thu, 30 May 2024 00:45:00 +0200 Subject: [PATCH 04/22] cleanup --- src/rules/jitrules.jl | 21 ++------------------- test/applyiter.jl | 4 +++- 2 files changed, 5 insertions(+), 20 deletions(-) diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index ee70ff903b..72bd612ce8 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -192,7 +192,7 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes) rt = Core.Compiler.return_type(f, Tuple{$(ElTypes...)}) annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) - annotation = if $Width != 1 && annotation <: Duplicated + annotation = if $Width != 1 && annotation0 <: Duplicated BatchDuplicated{rt, $Width} else annotation0 @@ -212,10 +212,6 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes) internal_tape, origRet, initShadow = forward(dupClosure ? Duplicated(f, df) : Const(f), args...) 
resT = typeof(origRet) - @show "aug generic", f, args - @show "generic", origRet, resT, annotation - @show width - @show initShadow if annotation <: Const shadow_return = nothing tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) @@ -226,7 +222,6 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes) else shadow_return = ($(nzeros...),) end - @show "generic", shadow_return tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) if $Width == 1 return ReturnType((origRet, shadow_return, tape)) @@ -295,12 +290,6 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs) shadowret = :(($(shadowret...),)) end - shadowsplat = Expr[] - for s in shadowargs - for v in s - push!(shadowsplat, :(@show $v)) - end - end ElTypes = ntuple(i->:(eltype(Core.Typeof(args[$i]))), Val(N)) Types = ntuple(i->:(Core.Typeof(args[$i])), Val(N)) @@ -314,7 +303,7 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs) rt = Core.Compiler.return_type(f, tt) annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) - annotation = if $Width != 1 && annotation <: Duplicated + annotation = if $Width != 1 && annotation0 <: Duplicated BatchDuplicated{rt, $Width} else annotation0 @@ -333,12 +322,8 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs) if tape.shadow_return !== nothing args = (args..., $shadowret) end - @show "gen rev", f, args - @show tape.shadow_return - $(shadowsplat...) tup = adjoint(dupClosure ? Duplicated(f, df) : Const(f), args..., tape.internal_tape)[1] - @show tup $(outs...) 
return nothing @@ -709,8 +694,6 @@ function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween end end - @show "idx rev", tup - ntuple(Val(Nargs)) do i Base.@_inline_meta diff --git a/test/applyiter.jl b/test/applyiter.jl index 78f5af7b89..d43a0dd5b8 100644 --- a/test/applyiter.jl +++ b/test/applyiter.jl @@ -146,7 +146,8 @@ end @test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]]) end - +# GC Segfault .. needs investigation. Disabling now +@static if false @testset "BatchReverse Apply iterate" begin x = [(2.0, 3.0), (7.9, 11.2)] dx = [(0.0, 0.0), (0.0, 0.0)] @@ -237,6 +238,7 @@ end @test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]]) @test tupapprox(dx2, [[4.0, 6.0], [15.8, 22.4]]) end +end @testset "Forward Apply iterate" begin x = [(2.0, 3.0), (7.9, 11.2)] From ec3412b957e882c399ee9cbecd9e9c807a13b796 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 3 Jun 2024 21:24:51 -0400 Subject: [PATCH 05/22] debugging fixes --- deps/build_local.jl | 3 +++ src/compiler.jl | 1 + src/rules/typeunstablerules.jl | 30 +++++++++++++++++++++++------- 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/deps/build_local.jl b/deps/build_local.jl index 21be66745b..f7a1a5b9f8 100644 --- a/deps/build_local.jl +++ b/deps/build_local.jl @@ -18,6 +18,9 @@ branch = nothing args = (ARGS...,) while length(args) > 0 + global args + global branch + global source_dir if length(args) >= 2 && args[1] == "--branch" branch = args[2] source_dir = nothing diff --git a/src/compiler.jl b/src/compiler.jl index 8cbac14f94..95d49da86d 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -5985,6 +5985,7 @@ function _thunk(job, postopt::Bool=true) if postopt if job.config.params.ABI <: FFIABI post_optimze!(mod, JIT.get_tm()) + println(string(mod)) else propagate_returned!(mod) end diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index 05f0ab422f..cad8ae3ecf 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -271,24 
+271,36 @@ function rt_jl_getfield_aug(dptr::T, ::Type{Val{symname}}, ::Val{isconst}, dptrs end end -function idx_jl_getfield_aug(dptr::T, ::Type{Val{symname}}, ::Val{isconst}, dptrs...) where {T, symname, isconst} +function idx_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isconst}, dptrs...) where {NT, T, symname, isconst} res = if dptr isa Base.RefValue Base.getfield(dptr[], symname+1) else Base.getfield(dptr, symname+1) end RT = Core.Typeof(res) - if active_reg(RT) + actreg = active_reg(RT) + @show actreg, RT + if actreg if length(dptrs) == 0 - return Ref{RT}(make_zero(res)) + return Ref{RT}(make_zero(res))::Any else - return ( (Ref{RT}(make_zero(res)) for _ in 1:(1+length(dptrs)))..., ) + fval0 = NT(ntuple(Val(1+length(dptrs))) do i + Base.@_inline_meta + Ref{RT}(make_zero(res)) + end) + @show fval0 + return fval0 end else if length(dptrs) == 0 - return res + return res::Any else - return (res, (getfield(dv isa Base.RefValue ? dv[] : dv, symname+1) for dv in dptrs)...) + fval = NT((res, (ntuple(Val(length(dptrs))) do i + Base.@_inline_meta + dv = dptrs[i] + getfield(dv isa Base.RefValue ? 
dv[] : dv, symname+1) + end)...)) + return fval end end end @@ -587,7 +599,7 @@ function jl_nthfield_augfwd(B, orig, gutils, normalR, shadowR, tapeR) inps = [new_from_original(gutils, ops[1])] end - vals = LLVM.Value[] + vals = LLVM.Value[unsafe_to_llvm(Val(AnyArray(Int(width))))] push!(vals, inps[1]) sym = new_from_original(gutils, ops[2]) @@ -607,6 +619,9 @@ function jl_nthfield_augfwd(B, orig, gutils, normalR, shadowR, tapeR) debug_from_orig!(gutils, cal, orig) + emit_jl!(B, unsafe_to_llvm("Result of call to idx_jl_getfield_aug")) + emit_jl!(B, cal) + if width == 1 shadowres = cal else @@ -623,6 +638,7 @@ function jl_nthfield_augfwd(B, orig, gutils, normalR, shadowR, tapeR) if !is_constant_value(gutils, ops[1]) gep = LLVM.inbounds_gep!(B, AT, forgep, [LLVM.ConstantInt(0), LLVM.ConstantInt(i-1)]) ld = LLVM.load!(B, T_prjlvalue, gep) + emit_jl!(B, ld) else ld = forgep end From 3e621a18964783f480fab15b75f4654a4091536e Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 5 Jun 2024 04:34:07 -0400 Subject: [PATCH 06/22] fixup --- src/rules/jitrules.jl | 10 +++++----- src/rules/typeunstablerules.jl | 13 +++---------- 2 files changed, 8 insertions(+), 15 deletions(-) diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index 72bd612ce8..e5e729b198 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -178,7 +178,7 @@ end function body_runtime_generic_augfwd(N, Width, wrapped, primttypes) nnothing = ntuple(i->nothing, Val(Width+1)) nres = ntuple(i->:(origRet), Val(Width+1)) - nzeros = ntuple(i->:(Ref(make_zero(resT))), Val(Width)) + nzeros = ntuple(i->:(Ref(make_zero(origRet))), Val(Width)) nres3 = ntuple(i->:(res[3]), Val(Width)) ElTypes = ntuple(i->:(eltype(Core.Typeof(args[$i]))), Val(N)) Types = ntuple(i->:(Core.Typeof(args[$i])), Val(N)) @@ -247,13 +247,13 @@ function func_runtime_generic_augfwd(N, Width) body = body_runtime_generic_augfwd(N, Width, wrapped, primtypes) quote - function 
runtime_generic_augfwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...)) where {ActivityTup, MB, ReturnType, F, DF, $(typeargs...)} + function runtime_generic_augfwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...))::ReturnType where {ActivityTup, MB, ReturnType, F, DF, $(typeargs...)} $body end end end -@generated function runtime_generic_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, allargs...) where {ActivityTup, MB, Width, ReturnType, F, DF} +@generated function runtime_generic_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, allargs...)::ReturnType where {ActivityTup, MB, Width, ReturnType, F, DF} N = div(length(allargs)+2, Width+1)-1 _, _, primtypes, _, _, wrapped, _, _= setup_macro_wraps(false, N, Width, :allargs) return body_runtime_generic_augfwd(N, Width, wrapped, primtypes) @@ -948,7 +948,7 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, end debug_from_orig!(gutils, cal, orig) - + if tape === nothing llty = convert(LLVMType, ReturnType) cal = LLVM.addrspacecast!(B, cal, LLVM.PointerType(T_jlvalue, Derived)) @@ -1029,7 +1029,7 @@ function common_generic_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) width = get_width(gutils) sret = generic_setup(orig, runtime_generic_augfwd, AnyArray(2+Int(width)), gutils, #=start=#offset, B, false) AT = LLVM.ArrayType(T_prjlvalue, 2+Int(width)) - + if unsafe_load(shadowR) != C_NULL if width == 1 gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index cad8ae3ecf..d3dd8af65c 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -279,17 
+279,14 @@ function idx_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isc end RT = Core.Typeof(res) actreg = active_reg(RT) - @show actreg, RT if actreg if length(dptrs) == 0 return Ref{RT}(make_zero(res))::Any else - fval0 = NT(ntuple(Val(1+length(dptrs))) do i + return NT(ntuple(Val(1+length(dptrs))) do i Base.@_inline_meta Ref{RT}(make_zero(res)) end) - @show fval0 - return fval0 end else if length(dptrs) == 0 @@ -335,7 +332,6 @@ function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst} RT = Core.Typeof(cur) if active_reg(RT) && !isconst if length(dptrs) == 0 - @show dptr, cur, dret if dptr isa Base.RefValue vload = dptr[] dRT = Core.Typeof(vload) @@ -599,7 +595,8 @@ function jl_nthfield_augfwd(B, orig, gutils, normalR, shadowR, tapeR) inps = [new_from_original(gutils, ops[1])] end - vals = LLVM.Value[unsafe_to_llvm(Val(AnyArray(Int(width))))] + AA = Val(AnyArray(Int(width))) + vals = LLVM.Value[unsafe_to_llvm(AA)] push!(vals, inps[1]) sym = new_from_original(gutils, ops[2]) @@ -619,9 +616,6 @@ function jl_nthfield_augfwd(B, orig, gutils, normalR, shadowR, tapeR) debug_from_orig!(gutils, cal, orig) - emit_jl!(B, unsafe_to_llvm("Result of call to idx_jl_getfield_aug")) - emit_jl!(B, cal) - if width == 1 shadowres = cal else @@ -638,7 +632,6 @@ function jl_nthfield_augfwd(B, orig, gutils, normalR, shadowR, tapeR) if !is_constant_value(gutils, ops[1]) gep = LLVM.inbounds_gep!(B, AT, forgep, [LLVM.ConstantInt(0), LLVM.ConstantInt(i-1)]) ld = LLVM.load!(B, T_prjlvalue, gep) - emit_jl!(B, ld) else ld = forgep end From 5a14ec8c5d515a36f99536359887a09a732802d0 Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Wed, 5 Jun 2024 04:37:53 -0400 Subject: [PATCH 07/22] cleanup --- src/compiler.jl | 1 - test/applyiter.jl | 5 +---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 95d49da86d..8cbac14f94 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -5985,7 +5985,6 @@ function _thunk(job, postopt::Bool=true) if postopt if job.config.params.ABI <: FFIABI post_optimze!(mod, JIT.get_tm()) - println(string(mod)) else propagate_returned!(mod) end diff --git a/test/applyiter.jl b/test/applyiter.jl index d43a0dd5b8..56b9454b33 100644 --- a/test/applyiter.jl +++ b/test/applyiter.jl @@ -146,8 +146,6 @@ end @test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]]) end -# GC Segfault .. needs investigation. Disabling now -@static if false @testset "BatchReverse Apply iterate" begin x = [(2.0, 3.0), (7.9, 11.2)] dx = [(0.0, 0.0), (0.0, 0.0)] @@ -238,7 +236,6 @@ end @test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]]) @test tupapprox(dx2, [[4.0, 6.0], [15.8, 22.4]]) end -end @testset "Forward Apply iterate" begin x = [(2.0, 3.0), (7.9, 11.2)] @@ -486,4 +483,4 @@ end Enzyme.autodiff(Reverse, mktup3, Duplicated(data, ddata)) @test ddata[1][1] ≈ 6.0 -end \ No newline at end of file +end From e91997f1c02294a916acf9bb8e8e10e7d46a001d Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Wed, 5 Jun 2024 07:28:28 -0400 Subject: [PATCH 08/22] fix tests --- src/rules/jitrules.jl | 2 +- test/applyiter.jl | 21 +++++++++++---------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index e5e729b198..7e25f36a3f 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -685,7 +685,7 @@ function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween else ntuple(Val(width)) do w Base.@_inline_meta - tape.shadow_return[w][i] + tape.shadow_return[w][][i] end end else diff --git a/test/applyiter.jl b/test/applyiter.jl index 56b9454b33..37df9a2880 100644 --- a/test/applyiter.jl +++ b/test/applyiter.jl @@ -153,7 +153,8 @@ end out = Ref(0.0) dout = Ref(1.0) dout2 = Ref(3.0) - res = Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) + @show dx, dx2, dout, dout2 @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) @test tupapprox(dx2, [(3*4.0, 3*6.0), (3*15.8, 3*22.4)]) @@ -161,8 +162,8 @@ end out = Ref(0.0) dout = Ref(1.0) dout2 = Ref(3.0) - res = Enzyme.autodiff(ReverseWithPrimal, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq), Active, Const(metaconcat), BatchDuplicated(x, (dx, dx2))) - @test res[2] ≈ 200.84999999999997 + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) + @test out[] ≈ 200.84999999999997 @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) @test tupapprox(dx2, [(4.0, 6.0), (15.8, 22.4)]) @@ -173,7 +174,7 @@ end dout = Ref(1.0) dout2 = Ref(3.0) - res = Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq2), Active, Const(metaconcat), 
BatchDuplicated(x, (dx, dx2))) + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq2), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) @test dx ≈ [[4.0, 6.0], [15.8, 22.4]] @test dx2 ≈ [[4.0, 6.0], [15.8, 22.4]] @@ -182,9 +183,9 @@ end out = Ref(0.0) dout = Ref(1.0) dout2 = Ref(3.0) - res = Enzyme.autodiff(ReverseWithPrimal, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq2), Active, Const(metaconcat), BatchDuplicated(x, (dx, dx2))) + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq2), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) - @test res[2] ≈ 200.84999999999997 + @test out[] ≈ 200.84999999999997 @test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]]) @@ -196,7 +197,7 @@ end out = Ref(0.0) dout = Ref(1.0) dout2 = Ref(3.0) - res = Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq3), Active, Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), Const(y)) + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq3), Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), Const(y)) @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) @@ -209,7 +210,7 @@ end out = Ref(0.0) dout = Ref(1.0) dout2 = Ref(3.0) - res = Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq3), Active, Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), BatchDuplicated(y, (dy, dy2))) + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq3),Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), BatchDuplicated(y, (dy, dy2))) @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) @test tupapprox(dx2, [(4.0, 6.0), (15.8, 22.4)]) @@ -220,7 +221,7 @@ end out = Ref(0.0) dout = Ref(1.0) dout2 = Ref(3.0) - res = Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq4), Active, 
Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), Const(y)) + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq4), Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), Const(y)) @test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]]) @test tupapprox(dx2, [(4.0, 6.0), (15.8, 22.4)]) @@ -232,7 +233,7 @@ end out = Ref(0.0) dout = Ref(1.0) dout2 = Ref(3.0) - res = Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq4), Active, Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), BatchDuplicated(y, (dy, dy2))) + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq4), Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), BatchDuplicated(y, (dy, dy2))) @test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]]) @test tupapprox(dx2, [[4.0, 6.0], [15.8, 22.4]]) end From e5bea84d7a5b8e1d162525cec33df6cae26c4093 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 6 Jun 2024 16:00:24 -0400 Subject: [PATCH 09/22] fix batch getfield rev --- src/rules/typeunstablerules.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index d3dd8af65c..7ea2ca0d9d 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -365,9 +365,9 @@ function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst} end for i in 1:length(dptrs) if dptrs[i] isa Base.RefValue - vload = dptr[] + vload = dptrs[i][] dRT = Core.Typeof(vload) - dptr[] = splatnew(dRT, ntuple(Val(fieldcount(dRT))) do j + dptrs[i][] = splatnew(dRT, ntuple(Val(fieldcount(dRT))) do j Base.@_inline_meta prev = getfield(vload, j) if j == symname+1 From a09f6cb1ef5b13dda0550ef3a47c02f21613b2a7 Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Thu, 6 Jun 2024 16:23:06 -0400 Subject: [PATCH 10/22] fix tests --- test/applyiter.jl | 17 ++++++++++------- test/runtests.jl | 28 ++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 7 deletions(-) diff --git a/test/applyiter.jl b/test/applyiter.jl index 37df9a2880..1227b3cae7 100644 --- a/test/applyiter.jl +++ b/test/applyiter.jl @@ -154,18 +154,18 @@ end dout = Ref(1.0) dout2 = Ref(3.0) Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) - @show dx, dx2, dout, dout2 @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) @test tupapprox(dx2, [(3*4.0, 3*6.0), (3*15.8, 3*22.4)]) dx = [(0.0, 0.0), (0.0, 0.0)] + dx2 = [(0.0, 0.0), (0.0, 0.0)] out = Ref(0.0) dout = Ref(1.0) dout2 = Ref(3.0) Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) @test out[] ≈ 200.84999999999997 @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) - @test tupapprox(dx2, [(4.0, 6.0), (15.8, 22.4)]) + @test tupapprox(dx2, [(3*4.0, 3*6.0), (3*15.8, 3*22.4)]) x = [[2.0, 3.0], [7.9, 11.2]] dx = [[0.0, 0.0], [0.0, 0.0]] @@ -176,10 +176,10 @@ end Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq2), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) @test dx ≈ [[4.0, 6.0], [15.8, 22.4]] - @test dx2 ≈ [[4.0, 6.0], [15.8, 22.4]] + @test dx2 ≈ [[3*4.0, 3*6.0], [3*15.8, 3*22.4]] dx = [[0.0, 0.0], [0.0, 0.0]] - d2 = [[0.0, 0.0], [0.0, 0.0]] + dx2 = [[0.0, 0.0], [0.0, 0.0]] out = Ref(0.0) dout = Ref(1.0) dout2 = Ref(3.0) @@ -187,6 +187,7 @@ end @test out[] ≈ 200.84999999999997 @test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]]) + @test tupapprox(dx2, [[3*4.0, 3*6.0], [3*15.8, 3*22.4]]) x = [(2.0, 3.0), (7.9, 11.2)] @@ -199,6 +200,7 @@ end dout2 = Ref(3.0) Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), 
Const(metasumsq3), Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), Const(y)) @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) + @test tupapprox(dx2, [(3*4.0, 3*6.0), (3*15.8, 3*22.4)]) x = [(2.0, 3.0), (7.9, 11.2)] @@ -212,18 +214,19 @@ end dout2 = Ref(3.0) Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq3),Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), BatchDuplicated(y, (dy, dy2))) @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) - @test tupapprox(dx2, [(4.0, 6.0), (15.8, 22.4)]) + @test tupapprox(dx2, [(3*4.0, 3*6.0), (3*15.8, 3*22.4)]) x = [[2.0, 3.0], [7.9, 11.2]] dx = [[0.0, 0.0], [0.0, 0.0]] + dx2 = [[0.0, 0.0], [0.0, 0.0]] y = [[13, 17], [25, 31]] out = Ref(0.0) dout = Ref(1.0) dout2 = Ref(3.0) Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq4), Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), Const(y)) @test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]]) - @test tupapprox(dx2, [(4.0, 6.0), (15.8, 22.4)]) + @test tupapprox(dx2, [[3*4.0, 3*6.0], [3*15.8, 3*22.4]]) x = [[2.0, 3.0], [7.9, 11.2]] dx = [[0.0, 0.0], [0.0, 0.0]] @@ -235,7 +238,7 @@ end dout2 = Ref(3.0) Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq4), Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), BatchDuplicated(y, (dy, dy2))) @test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]]) - @test tupapprox(dx2, [[4.0, 6.0], [15.8, 22.4]]) + @test tupapprox(dx2, [[3*4.0, 3*6.0], [3*15.8, 3*22.4]]) end @testset "Forward Apply iterate" begin diff --git a/test/runtests.jl b/test/runtests.jl index 640b94e4f9..4783bb254c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1647,6 +1647,34 @@ end end + +function batchgf(out, args) + res = 0.0 + x = Base.inferencebarrier((args[1][1],)) + for v in x + v = v::Float64 + res += v + break + end + out[] = res + nothing +end + +@testset "Batch Getfield" begin + x = [(2.0, 3.0)] + dx = [(0.0, 0.0)] + dx2 = [(0.0, 
0.0)] + dx3 = [(0.0, 0.0)] + out = Ref(0.0) + dout = Ref(1.0) + dout2 = Ref(3.0) + dout3 = Ref(5.0) + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2, dout3)), BatchDuplicated(x, (dx, dx2, dx3))) + @test dx[1] ≈ (4.0, 6.0) + @test dx2[1] ≈ (3*4.0, 3*6.0) + @test dx3[1] ≈ (5*4.0, 5*6.0) +end + include("applyiter.jl") @testset "Dynamic Val Construction" begin From d03da279f524515e6b9307ed5d0a0ed5bc1c31b1 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 6 Jun 2024 16:25:55 -0400 Subject: [PATCH 11/22] more test fix --- test/applyiter.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/applyiter.jl b/test/applyiter.jl index 1227b3cae7..2518e2d829 100644 --- a/test/applyiter.jl +++ b/test/applyiter.jl @@ -230,6 +230,7 @@ end x = [[2.0, 3.0], [7.9, 11.2]] dx = [[0.0, 0.0], [0.0, 0.0]] + dx2 = [[0.0, 0.0], [0.0, 0.0]] y = [[13, 17], [25, 31]] dy = [[0, 0], [0, 0]] dy2 = [[0, 0], [0, 0]] From 19186c8b38dbbd21717d2db9ca57172343176871 Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Thu, 6 Jun 2024 16:42:11 -0400 Subject: [PATCH 12/22] fix tuple fast path --- src/rules/jitrules.jl | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index 7e25f36a3f..af12d2bfbc 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -507,6 +507,32 @@ function primal_tuple(args::Vararg{Annotation, Nargs}) where Nargs end end +function shadow_tuple(::Val{1}, args::Vararg{Annotation, Nargs}) where Nargs + ntuple(Val(Nargs)) do i + Base.@_inline_meta + @assert !(args[i] isa Active) + if args[i] isa Const + args[i].val + else + args[i].dval + end + end +end + +function shadow_tuple(::Val{width}, args::Vararg{Annotation, Nargs}) where {width, Nargs} + ntuple(Val(width)) do w + ntuple(Val(Nargs)) do i + Base.@_inline_meta + @assert !(args[i] isa Active) + if args[i] isa Const + args[i].val + else + args[i].dval[w] + end + end + end +end + # This is explicitly escaped here to be what is apply generic in total [and thus all the insides are stable] function augfwd_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType}, ::Val{ModifiedBetween0}, ::Type{FT}, ::Type{tt′}, f::FT, df::DF, args::Vararg{Annotation, Nargs})::ReturnType where {width, dupClosure0, ReturnType, ModifiedBetween0, FT, tt′, DF, Nargs} ReturnPrimal = Val(true) @@ -553,7 +579,7 @@ function augfwd_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType} ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) forward(fa, args...) else - nothing, primal_tuple(args...), nothing + nothing, primal_tuple(args...), annotation <: Active ? nothing : shadow_tuple(Val(width), args...) end resT = typeof(origRet) From 6995a9d4a4753273e92a755d0ca8d824e5c5288a Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Thu, 6 Jun 2024 19:43:40 -0400 Subject: [PATCH 13/22] fix --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index c1cde3aac3..6932a2170b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1669,7 +1669,7 @@ end dout = Ref(1.0) dout2 = Ref(3.0) dout3 = Ref(5.0) - Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2, dout3)), BatchDuplicated(x, (dx, dx2, dx3))) + Enzyme.autodiff(Reverse, batchgf, Const, BatchDuplicatedNoNeed(out, (dout, dout2, dout3)), BatchDuplicated(x, (dx, dx2, dx3))) @test dx[1] ≈ (4.0, 6.0) @test dx2[1] ≈ (3*4.0, 3*6.0) @test dx3[1] ≈ (5*4.0, 5*6.0) From 93bf415e47d1d5a272d8d5ab70e4811b84151b32 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 7 Jun 2024 10:03:34 +0200 Subject: [PATCH 14/22] Update Project.toml --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 0b19c3cead..243b3e5888 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.10" +version = "0.12.11" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -20,7 +20,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7.3" -Enzyme_jll = "0.0.119" +Enzyme_jll = "0.0.120" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7" ObjectFile = "0.4" From 8bee8b1ae0e2191faf77e0d55798bcf43ce1458a Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Fri, 7 Jun 2024 10:30:00 +0200 Subject: [PATCH 15/22] fix sym index rev --- src/rules/typeunstablerules.jl | 62 ++++++++++++++++++++++++++++++---- 1 file changed, 55 insertions(+), 7 deletions(-) diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index 52fd4aad3d..1c7e4d5156 100644 --- a/src/rules/typeunstablerules.jl +++ 
b/src/rules/typeunstablerules.jl @@ -249,7 +249,7 @@ function common_jl_getfield_fwd(offset, B, orig, gutils, normalR, shadowR) return false end -function rt_jl_getfield_aug(dptr::T, ::Type{Val{symname}}, ::Val{isconst}, dptrs...) where {T, symname, isconst} +function rt_jl_getfield_aug(dptr::T, ::Type{Val{symname}}, ::Val{isconst}, dptrs::Vararg{T2, Nargs}) where {T, T2, Nargs, symname, isconst} res = if dptr isa Base.RefValue Base.getfield(dptr[], symname) else @@ -271,7 +271,7 @@ function rt_jl_getfield_aug(dptr::T, ::Type{Val{symname}}, ::Val{isconst}, dptrs end end -function idx_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isconst}, dptrs...) where {NT, T, symname, isconst} +function idx_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isconst}, dptrs::Vararg{T2, Nargs}) where {NT, T, T2, Nargs, symname, isconst} res = if dptr isa Base.RefValue Base.getfield(dptr[], symname+1) else @@ -302,7 +302,7 @@ function idx_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isc end end -function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, dptrs...) 
where {T, symname, isconst} +function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, dptrs::Vararg{T2, Nargs}) where {T, T2, Nargs, symname, isconst} cur = if dptr isa Base.RefValue getfield(dptr[], symname) else @@ -312,17 +312,65 @@ function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, RT = Core.Typeof(cur) if active_reg(RT) && !isconst if length(dptrs) == 0 - setfield!(dptr, symname, recursive_add(cur, dret[])) + if dptr isa Base.RefValue + vload = dptr[] + dRT = Core.Typeof(vload) + dptr[] = splatnew(dRT, ntuple(Val(fieldcount(dRT))) do i + Base.@_inline_meta + prev = getfield(vload, i) + if fieldname(dRT, i) == symname + recursive_add(prev, dret[]) + else + prev + end + end) + else + setfield!(dptr, symname+1, recursive_add(cur, dret[])) + end else - setfield!(dptr, symname, recursive_add(cur, dret[1][])) + if dptr isa Base.RefValue + vload = dptr[] + dRT = Core.Typeof(vload) + dptr[] = splatnew(dRT, ntuple(Val(fieldcount(dRT))) do j + Base.@_inline_meta + prev = getfield(vload, j) + if fieldname(dRT, j) == symname + recursive_add(prev, dret[1][]) + else + prev + end + end) + else + setfield!(dptr, symname+1, recursive_add(cur, dret[1][])) + end for i in 1:length(dptrs) - setfield!(dptrs[i], symname, recursive_add(cur, dret[1+i][])) + if dptrs[i] isa Base.RefValue + vload = dptrs[i][] + dRT = Core.Typeof(vload) + dptrs[i][] = splatnew(dRT, ntuple(Val(fieldcount(dRT))) do j + Base.@_inline_meta + prev = getfield(vload, j) + if fieldname(dRT, j) == symname + recursive_add(prev, dret[1+i][]) + else + prev + end + end) + else + curi = if dptr isa Base.RefValue + Base.getfield(dptrs[i][], symname) + else + Base.getfield(dptrs[i], symname) + end + setfield!(dptrs[i], symname, recursive_add(curi, dret[1+i][])) + end end end end return nothing end -function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, dptrs...) 
where {T, symname, isconst} +function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, dptrs::Vararg{T2, Nargs}) where {T, T2, Nargs, symname, isconst} cur = if dptr isa Base.RefValue Base.getfield(dptr[], symname+1) else From 5f0ca77d75119db3a4267ecbdb151aa8027119cb Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Fri, 7 Jun 2024 10:33:00 +0200 Subject: [PATCH 16/22] fix test --- test/runtests.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 6932a2170b..fa3f8ea546 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1670,9 +1670,12 @@ end dout2 = Ref(3.0) dout3 = Ref(5.0) Enzyme.autodiff(Reverse, batchgf, Const, BatchDuplicatedNoNeed(out, (dout, dout2, dout3)), BatchDuplicated(x, (dx, dx2, dx3))) - @test dx[1] ≈ (4.0, 6.0) - @test dx2[1] ≈ (3*4.0, 3*6.0) - @test dx3[1] ≈ (5*4.0, 5*6.0) + @test dx[1][1] ≈ 1.0 + @test dx[1][2] ≈ 0.0 + @test dx2[1][1] ≈ 3.0 + @test dx2[1][2] ≈ 0.0 + @test dx3[1][1] ≈ 5.0 + @test dx3[1][2] ≈ 0.0 end include("applyiter.jl") From 8278e9b34955e6764471a4c95baae93e65b072a1 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Fri, 7 Jun 2024 12:15:42 +0200 Subject: [PATCH 17/22] fixup --- src/rules/typeunstablerules.jl | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index 1c7e4d5156..b9ef1e4fa4 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -249,7 +249,7 @@ function common_jl_getfield_fwd(offset, B, orig, gutils, normalR, shadowR) return false end -function rt_jl_getfield_aug(dptr::T, ::Type{Val{symname}}, ::Val{isconst}, dptrs::Vararg{T2, Nargs}) where {T, T2, Nargs, symname, isconst} +function rt_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isconst}, dptrs::Vararg{T2, Nargs}) where {NT, T, T2, Nargs, symname, isconst} res = if dptr isa Base.RefValue Base.getfield(dptr[], symname) else @@
-260,13 +260,21 @@ function rt_jl_getfield_aug(dptr::T, ::Type{Val{symname}}, ::Val{isconst}, dptrs if length(dptrs) == 0 return Ref{RT}(make_zero(res)) else - return ( (Ref{RT}(make_zero(res)) for _ in 1:(1+length(dptrs)))..., ) + return NT(ntuple(Val(1+length(dptrs))) do i + Base.@_inline_meta + Ref{RT}(make_zero(res)) + end) end else if length(dptrs) == 0 return res else - return (res, (getfield(dv, symname) for dv in dptrs)...) + fval = NT((res, (ntuple(Val(length(dptrs))) do i + Base.@_inline_meta + dv = dptrs[i] + getfield(dv isa Base.RefValue ? dv[] : dv, symname) + end)...)) + return fval end end end @@ -466,7 +474,8 @@ function common_jl_getfield_augfwd(offset, B, orig, gutils, normalR, shadowR, ta inps = [new_from_original(gutils, ops[2])] end - vals = LLVM.Value[] + AA = Val(AnyArray(Int(width))) + vals = LLVM.Value[unsafe_to_llvm(AA)] push!(vals, inps[1]) sym = new_from_original(gutils, ops[3]) From 016d7faa93f33bc77184ec5cec9209bf56486472 Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Fri, 7 Jun 2024 17:31:11 +0200 Subject: [PATCH 18/22] Fix unionall --- src/compiler.jl | 23 +++++++++++++++++------ src/utils.jl | 2 +- test/runtests.jl | 1 + 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 30bf6f0d9c..d14067a8a2 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -379,8 +379,15 @@ end return active_reg_inner(ST, seen, world, Val(justActive), Val(UnionSret)) end -@inline function active_reg_inner(::Type{T}, seen::ST, world::Union{Nothing, UInt}, ::Val{justActive}=Val(false), ::Val{UnionSret}=Val(false))::ActivityState where {ST,T, justActive, UnionSret} +@inline function unionall_body(::Type{T}) where T + if T isa UnionAll + unionall_body(T.body) + else + T + end +end +@inline function active_reg_inner(::Type{T}, seen::ST, world::Union{Nothing, UInt}, ::Val{justActive}=Val(false), ::Val{UnionSret}=Val(false))::ActivityState where {ST,T, justActive, UnionSret} if T === Any return DupState end @@ -422,7 +429,9 @@ end else inmi = GPUCompiler.methodinstance(typeof(EnzymeCore.EnzymeRules.inactive_type), Tuple{Type{T}}, world) args = Any[EnzymeCore.EnzymeRules.inactive_type, T]; - ccall(:jl_invoke, Any, (Any, Ptr{Any}, Cuint, Any), EnzymeCore.EnzymeRules.inactive_type, args, length(args), inmi) + GC.@preserve T begin + ccall(:jl_invoke, Any, (Any, Ptr{Any}, Cuint, Any), EnzymeCore.EnzymeRules.inactive_type, args, length(args), inmi) + end end if inactivety @@ -480,11 +489,13 @@ end @static if VERSION < v"1.7.0" nT = T else - nT = if is_concrete_tuple(T) - Tuple{(ntuple(length(T.parameters)) do i + nT = if T <: Tuple && T != Tuple + Tuple{(ntuple(length(unionall_body(T).parameters)) do i Base.@_inline_meta - sT = T.parameters[i] - if sT isa Core.TypeofVararg + sT = unionall_body(T.parameters[i]) + if sT isa TypeVar + Any + elseif sT isa Core.TypeofVararg Any else sT diff --git a/src/utils.jl b/src/utils.jl index a3268c6c94..916818181e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -8,7 
+8,7 @@ @inline unsafe_to_pointer(val::Type{T}) where T = ccall(Base.@cfunction(x->x, Ptr{Cvoid}, (Ptr{Cvoid},)), Ptr{Cvoid}, (Any,), val) export unsafe_to_pointer -@inline is_concrete_tuple(x::T2) where T2 = (x <: Tuple) && !(x === Tuple) && !(x isa UnionAll) +@inline is_concrete_tuple(x::Type{T2}) where T2 = (T2 <: Tuple) && !(T2 === Tuple) && !(T2 isa UnionAll) export is_concrete_tuple const Tracked = 10 diff --git a/test/runtests.jl b/test/runtests.jl index fa3f8ea546..ca1feaa8de 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -137,6 +137,7 @@ end @assert Enzyme.Compiler.active_reg_inner(Symbol, (), nothing) == Enzyme.Compiler.AnyState @assert Enzyme.Compiler.active_reg_inner(String, (), nothing) == Enzyme.Compiler.AnyState @assert Enzyme.Compiler.active_reg_inner(Tuple{Any,Int64}, (), nothing) == Enzyme.Compiler.DupState + @assert Enzyme.Compiler.active_reg_inner(Tuple{S,Int64} where S, (), Base.get_world_counter()) == Enzyme.Compiler.DupState @assert Enzyme.Compiler.active_reg_inner(Union{Float64,Nothing}, (), nothing) == Enzyme.Compiler.DupState @assert Enzyme.Compiler.active_reg_inner(Union{Float64,Nothing}, (), nothing, #=justActive=#Val(false), #=unionSret=#Val(true)) == Enzyme.Compiler.ActiveState world = codegen_world_age(typeof(f0), Tuple{Float64}) From 7689e499bbdc3f3b45687db9750cb574db40b51d Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Sat, 8 Jun 2024 10:40:19 -0400 Subject: [PATCH 19/22] fix --- src/compiler.jl | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index d14067a8a2..9f72c10399 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -379,14 +379,6 @@ end return active_reg_inner(ST, seen, world, Val(justActive), Val(UnionSret)) end -@inline function unionall_body(::Type{T}) where T - if T isa UnionAll - unionall_body(T.body) - else - T - end -end - @inline function active_reg_inner(::Type{T}, seen::ST, world::Union{Nothing, UInt}, ::Val{justActive}=Val(false), ::Val{UnionSret}=Val(false))::ActivityState where {ST,T, justActive, UnionSret} if T === Any return DupState @@ -489,10 +481,10 @@ end @static if VERSION < v"1.7.0" nT = T else - nT = if T <: Tuple && T != Tuple - Tuple{(ntuple(length(unionall_body(T).parameters)) do i + nT = if T <: Tuple && T != Tuple && !(T isa UnionAll) + Tuple{(ntuple(length(T.parameters)) do i Base.@_inline_meta - sT = unionall_body(T.parameters[i]) + sT = T.parameters[i] if sT isa TypeVar Any elseif sT isa Core.TypeofVararg From 9499ee4deb0908ec18eec4538bf41f3dd359efd6 Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Sat, 8 Jun 2024 11:35:45 -0400 Subject: [PATCH 20/22] fix sym offset --- src/rules/typeunstablerules.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index b9ef1e4fa4..1ee4f0d961 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -333,7 +333,7 @@ function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, end end) else - setfield!(dptr, symname+1, recursive_add(cur, dret[])) + setfield!(dptr, symname, recursive_add(cur, dret[])) end else if dptr isa Base.RefValue @@ -349,7 +349,7 @@ function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, end end) else - setfield!(dptr, symname+1, recursive_add(cur, dret[1][])) + setfield!(dptr, symname, recursive_add(cur, dret[1][])) end for i in 1:length(dptrs) if dptrs[i] isa Base.RefValue From a2a09ec345e844b118ae854b3dcc84e0aef08655 Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Sat, 8 Jun 2024 12:41:20 -0400 Subject: [PATCH 21/22] fix constantarray --- src/compiler/validation.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index 3b36cb2ccf..caf86cbc03 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -879,10 +879,15 @@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width end end + if isa(cur, LLVM.ConstantArray) + push!(todo, (cur[off[1]], off[2:end])) + continue + end + msg = sprint() do io::IO println(io, "Enzyme Internal Error (rewrite_union_returns_as_ref[2])") println(io, string(enzymefn)) - println(io, "cur=", cur) + println(io, "cur=", string(cur)) println(io, "off=", off) end throw(AssertionError(msg)) From 18c6fa18c064e6a8d395192fe645daadc92a2dd8 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 8 Jun 2024 16:11:13 -0400 Subject: [PATCH 22/22] Update Project.toml --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 90de295a9e..848c47e7ee 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.11" +version = "0.12.12" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -20,7 +20,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7.4" -Enzyme_jll = "0.0.120" +Enzyme_jll = "0.0.121" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7" ObjectFile = "0.4"