diff --git a/Project.toml b/Project.toml index eb689480c5..7bcf3a8438 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.8" +version = "0.12.9" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -29,8 +29,8 @@ EnzymeStaticArraysExt = "StaticArrays" [compat] CEnum = "0.4, 0.5" ChainRulesCore = "1" -EnzymeCore = "0.7" -Enzyme_jll = "0.0.113" +EnzymeCore = "0.7.3" +Enzyme_jll = "0.0.117" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7" ObjectFile = "0.4" diff --git a/lib/EnzymeCore/Project.toml b/lib/EnzymeCore/Project.toml index 5249f78945..670e1f3014 100644 --- a/lib/EnzymeCore/Project.toml +++ b/lib/EnzymeCore/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeCore" uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" authors = ["William Moses ", "Valentin Churavy "] -version = "0.7.2" +version = "0.7.3" [compat] Adapt = "3, 4" diff --git a/lib/EnzymeCore/src/rules.jl b/lib/EnzymeCore/src/rules.jl index 727ee1b178..398c790087 100644 --- a/lib/EnzymeCore/src/rules.jl +++ b/lib/EnzymeCore/src/rules.jl @@ -219,6 +219,21 @@ function is_inactive_noinl_from_sig(@nospecialize(TT); return isapplicable(inactive_noinl, TT; world, method_table, caller) end +""" + noalias(func::typeof(f), args...) + +Mark a particular function as always returning a fresh allocation which does not alias any other +accessible memory. 
+""" +function noalias end + +function noalias_from_sig(@nospecialize(TT); + world::UInt=Base.get_world_counter(), + method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing, + caller::Union{Nothing,Core.MethodInstance}=nothing) + return isapplicable(noalias, TT; world, method_table, caller) +end + """ inactive_type(::Type{Ty}) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index c75508cd77..911d1801ad 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -230,7 +230,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) end if A <: Active - if !allocatedinline(rt) || rt isa Union + if (!allocatedinline(rt) || rt isa Union) && rt != Union{} forward, adjoint = Enzyme.Compiler.thunk(Val(world), FA, Duplicated{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(true), RABI) res = forward(f, args...) tape = res[1] @@ -244,7 +244,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) throw(ErrorException("Duplicated Returns not yet handled")) end - if A <: Active && rt <: Complex + if (A <: Active && rt <: Complex) && rt != Union{} if Holomorphic seen = IdDict() seen2 = IdDict() diff --git a/src/absint.jl b/src/absint.jl index 36c1689832..ae9c35a09b 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -113,12 +113,11 @@ function absint(arg::LLVM.Value, partial::Bool=false) end ptr = unsafe_load(reinterpret(Ptr{Ptr{Cvoid}}, convert(UInt, ce))) if ptr == C_NULL - # XXX: Is this correct? 
- bt = GPUCompiler.backtrace(arg) - btstr = sprint() do io - Base.show_backtrace(io, bt) - end - @error "Found null pointer at\n $btstr" arg + # bt = GPUCompiler.backtrace(arg) + # btstr = sprint() do io + # Base.show_backtrace(io, bt) + # end + # @error "Found null pointer at\n $btstr" arg return (false, nothing) end typ = Base.unsafe_pointer_to_objref(ptr) @@ -144,6 +143,7 @@ function abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, Typ ("jl_box_uint64", UInt64), ("ijl_box_uint64", UInt64), ("jl_box_int32", Int32), ("ijl_box_int32", Int32), ("jl_box_uint32", UInt32), ("ijl_box_uint32", UInt32), + ("jl_box_float32", Float32), ("ijl_box_float32", Float32), ) if nm == fname return (true, ty) @@ -221,7 +221,11 @@ function abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, Typ end if nm == "jl_array_copy" || nm == "ijl_array_copy" - return abs_typeof(operands(arg)[1], partial) + legal, RT = abs_typeof(operands(arg)[1], partial) + if legal + @assert RT <: Array + end + return (legal, RT) end _, RT = enzyme_custom_extract_mi(arg, false) @@ -284,6 +288,9 @@ function abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, Typ fieldoffset(typ, i+1) end - offset if fsize == llsz(value_type(larg)) + if Base.isconcretetype(subT) && is_concrete_tuple(subT) && length(subT.parameters) == 1 + subT = subT.parameters[1] + end return (true, subT) end end diff --git a/src/api.jl b/src/api.jl index c0a971266a..c0632d2600 100644 --- a/src/api.jl +++ b/src/api.jl @@ -469,6 +469,25 @@ function maxtypeoffset!(val) ccall((:EnzymeSetCLInteger, libEnzyme), Cvoid, (Ptr{Cvoid}, Int64), ptr, val) end +""" + maxtypedepth!(val::Integer) + +Enzyme runs a type analysis to deduce the corresponding types of all values being +differentiated. This is necessary to compute correct derivatives of various values. +To ensure this analysis terminates, it operates on a finite lattice of possible +states. 
This function sets the maximum depth into a type that Enzyme will consider. +A smaller value will cause type analysis to run faster, but may result in some +necessary types not being found and result in unknown type errors. A larger value +may result in unknown type errors being resolved by searching a larger space, but +may run longer. The default setting is 6. +""" +function maxtypedepth!(val) + ptr = cglobal((:EnzymeMaxTypeDepth, libEnzyme)) + ccall((:EnzymeSetCLInteger, libEnzyme), Cvoid, (Ptr{Cvoid}, Int64), ptr, val) +end + + + """ looseTypeAnalysis!(val::Bool) diff --git a/src/compiler.jl b/src/compiler.jl index 769451f29c..8cbac14f94 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -103,6 +103,7 @@ Dict{DataType, Tuple{Symbol, Int, Union{Nothing, Tuple{Symbol, DataType}}}}( end const nofreefns = Set{String}(( + "ijl_array_grow_at", "jl_array_grow_at", "ijl_try_substrtod", "jl_try_substrtod", "jl_f__apply_iterate", "ijl_field_index", "jl_field_index", @@ -371,6 +372,13 @@ end end) end +@inline function active_reg_recur(::Type{ST}, seen::Seen, world, ::Val{justActive}, ::Val{UnionSret}) where {ST, Seen, justActive, UnionSret} + if ST isa Union + return forcefold(Val(active_reg_recur(ST.a, seen, world, Val(justActive), Val(UnionSret))), Val(active_reg_recur(ST.b, seen, world, Val(justActive), Val(UnionSret)))) + end + return active_reg_inner(ST, seen, world, Val(justActive), Val(UnionSret)) +end + @inline function active_reg_inner(::Type{T}, seen::ST, world::Union{Nothing, UInt}, ::Val{justActive}=Val(false), ::Val{UnionSret}=Val(false))::ActivityState where {ST,T, justActive, UnionSret} if T === Any @@ -436,13 +444,7 @@ end # if sret union, the data is stored in a stack memory location and is therefore # not unique'd preventing the boxing of the union in the default case if UnionSret && is_sret_union(T) - @inline function recur(::Type{ST}) where ST - if ST isa Union - return forcefold(Val(recur(ST.a)), Val(recur(ST.b))) - end - return 
active_reg_inner(ST, seen, world, Val(justActive), Val(UnionSret)) - end - return recur(T) + return active_reg_recur(T, seen, world, Val(justActive), Val(UnionSret)) else if justActive return AnyState @@ -575,6 +577,10 @@ struct AdjointThunk{PT, FA, RT, TT, Width, TapeType} <: AbstractThunk{FA, RT, TT adjoint::PT end +struct PrimalErrorThunk{PT, FA, RT, TT, Width, ReturnPrimal, World} <: AbstractThunk{FA, RT, TT, Width} + adjoint::PT +end + @inline return_type(::AbstractThunk{FA, RT}) where {FA, RT} = RT @inline return_type(::Type{AugmentedForwardThunk{PT, FA, RT, TT, Width, ReturnPrimal, TapeType}}) where {PT, FA, RT, TT, Width, ReturnPrimal, TapeType} = RT @@ -1352,8 +1358,13 @@ function emit_error(B::LLVM.IRBuilder, orig, string) string*=sprint(io->Base.show_backtrace(io, bt)) end + ct = if occursin("ptx", LLVM.triple(mod)) + GPUCompiler.emit_exception!(B, string, orig) + else + call!(B, funcT, func, LLVM.Value[globalstring_ptr!(B, string)]) + end + # 2. Call error function and insert unreachable - ct = call!(B, funcT, func, LLVM.Value[globalstring_ptr!(B, string)]) LLVM.API.LLVMAddCallSiteAttribute(ct, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), EnumAttribute("noreturn")) LLVM.API.LLVMAddCallSiteAttribute(ct, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), StringAttribute("enzyme_error")) return ct @@ -1477,7 +1488,7 @@ end function Base.showerror(io::IO, ece::NoDerivativeException) print(io, "Enzyme compilation failed.\n") - if ece.ir !== nothing + if ece.ir !== nothing && !occursin("No create nofree of empty function", ece.msg) print(io, "Current scope: \n") print(io, ece.ir) end @@ -2107,7 +2118,12 @@ function to_tape_type(Type::LLVM.API.LLVMTypeRef)::Tuple{DataType,Bool} return Any, true else e = LLVM.API.LLVMGetElementType(Type) - return Core.LLVMPtr{to_tape_type(e)[1], Int(addrspace)}, false + tkind2 = LLVM.API.LLVMGetTypeKind(e) + if tkind2 == LLVM.API.LLVMFunctionTypeKind + return 
Core.LLVMPtr{Cvoid, Int(addrspace)}, false + else + return Core.LLVMPtr{to_tape_type(e)[1], Int(addrspace)}, false + end end end if tkind == LLVM.API.LLVMArrayTypeKind @@ -2170,7 +2186,7 @@ function to_tape_type(Type::LLVM.API.LLVMTypeRef)::Tuple{DataType,Bool} if tkind == LLVM.API.LLVMFP128TypeKind return Float128, false end - error("Can't construct tape type for $Type") + error("Can't construct tape type for $Type $(string(Type)) $tkind") end function tape_type(LLVMType::LLVM.LLVMType) @@ -2236,7 +2252,10 @@ end function get_julia_inner_types(B, p, startvals...; added=[]) T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - vals = LLVM.Value[p] + vals = LLVM.Value[] + if p != nothing + push!(vals, p) + end todo = LLVM.Value[startvals...] while length(todo) != 0 cur = popfirst!(todo) @@ -3192,70 +3211,15 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr retType = convert(API.CDIFFE_TYPE, rt) rules = Dict{String, API.CustomRuleType}( - "jl_apply_generic" => @cfunction(ptr_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "ijl_apply_generic" => @cfunction(ptr_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "julia.gc_alloc_obj" => @cfunction(alloc_obj_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "jl_box_float32" => @cfunction(f32_box_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "ijl_box_float32" => @cfunction(f32_box_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "jl_box_int64" => @cfunction(i64_box_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "ijl_box_int64" => 
@cfunction(i64_box_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "jl_box_uint64" => @cfunction(i64_box_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "ijl_box_uint64" => @cfunction(i64_box_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), "jl_array_copy" => @cfunction(inout_rule, UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), "ijl_array_copy" => @cfunction(inout_rule, UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "jl_alloc_array_1d" => @cfunction(alloc_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "ijl_alloc_array_1d" => @cfunction(alloc_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "jl_alloc_array_2d" => @cfunction(alloc_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "ijl_alloc_array_2d" => @cfunction(alloc_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "jl_alloc_array_3d" => @cfunction(alloc_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "ijl_alloc_array_3d" => @cfunction(alloc_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), "julia.pointer_from_objref" => @cfunction(inout_rule, UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - "jl_wait" => @cfunction(noop_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, 
LLVM.API.LLVMValueRef)), - "jl_enq_work" => @cfunction(noop_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), - - "enz_noop" => @cfunction(noop_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), "jl_inactive_inout" => @cfunction(inout_rule, UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), @@ -4907,7 +4871,23 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; if llvmfn == primalf actualRetType = k.ci.rettype end + + if EnzymeRules.noalias_from_sig(mi.specTypes; world, method_table, caller) + push!(return_attributes(llvmfn), EnumAttribute("noalias")) + for u in LLVM.uses(llvmfn) + c = LLVM.user(u) + if !isa(c, LLVM.CallInst) + continue + end + cf = LLVM.called_operand(c) + if cf == llvmfn + LLVM.API.LLVMAddCallSiteAttribute(c, LLVM.API.LLVMAttributeReturnIndex, LLVM.EnumAttribute("noalias", 0)) + end + end + end + func = mi.specTypes.parameters[1] + meth = mi.def name = meth.name jlmod = meth.module @@ -4935,7 +4915,6 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; continue end - func = mi.specTypes.parameters[1] sparam_vals = mi.specTypes.parameters[2:end] # mi.sparam_vals if func == typeof(Base.eps) || func == typeof(Base.nextfloat) || func == typeof(Base.prevfloat) @@ -4956,6 +4935,17 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; ]) continue end + if func == typeof(Base.mightalias) + handleCustom(llvmfn, "jl_mightalias", + [EnumAttribute("readonly", 0), + StringAttribute("enzyme_shouldrecompute"), + StringAttribute("enzyme_inactive"), + StringAttribute("enzyme_no_escaping_allocation"), + EnumAttribute("nofree"), + StringAttribute("enzyme_ta_norecur"), + ], true, false) + continue + end if func == typeof(Base.Threads.threadid) || func == typeof(Base.Threads.nthreads) name = 
(func == typeof(Base.Threads.threadid)) ? "jl_threadid" : "jl_nthreads" handleCustom(llvmfn, name, @@ -4974,15 +4964,15 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; # fn, but it doesn't presently so for now we will ensure this by hand if func == typeof(Base.Checked.throw_overflowerr_binaryop) llvmfn = functions(mod)[k.specfunc] - handleCustom(llvmfn, "enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("readonly")]) + handleCustom(llvmfn, "enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("readonly"), StringAttribute("enzyme_ta_norecur")]) continue end if EnzymeRules.is_inactive_from_sig(mi.specTypes; world, method_table, caller) - handleCustom(llvmfn, "enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("nofree"), StringAttribute("enzyme_no_escaping_allocation")]) + handleCustom(llvmfn, "enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("nofree"), StringAttribute("enzyme_no_escaping_allocation"), StringAttribute("enzyme_ta_norecur")]) continue end if EnzymeRules.is_inactive_noinl_from_sig(mi.specTypes; world, method_table, caller) - handleCustom(llvmfn, "enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("nofree"), StringAttribute("enzyme_no_escaping_allocation")], false, false) + handleCustom(llvmfn, "enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("nofree"), StringAttribute("enzyme_no_escaping_allocation"), StringAttribute("enzyme_ta_norecur")], false, false) for bb in blocks(llvmfn) for inst in instructions(bb) if isa(inst, LLVM.CallInst) @@ -5008,12 +4998,12 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; continue end if func == typeof(Base.enq_work) && length(sparam_vals) == 1 && first(sparam_vals) <: Task - handleCustom(llvmfn, "jl_enq_work") + handleCustom(llvmfn, "jl_enq_work", [StringAttribute("enzyme_ta_norecur")]) continue end if func == typeof(Base.wait) || func == typeof(Base._wait) if length(sparam_vals) == 1 && 
first(sparam_vals) <: Task - handleCustom(llvmfn, "jl_wait") + handleCustom(llvmfn, "jl_wait", [StringAttribute("enzyme_ta_norecur")]) end continue end @@ -5186,11 +5176,50 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; GPUCompiler.optimize_module!(parent_job, mod) end + seen = TypeTreeTable() + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + dl = string(LLVM.datalayout(mod)) + ctx = LLVM.context(mod) for f in functions(mod), bb in blocks(f), inst in instructions(bb) if !isa(inst, LLVM.CallInst) continue end + fn = LLVM.called_operand(inst) + + if !API.HasFromStack(inst) && (!isa(fn, LLVM.Function) || isempty(blocks(fn))) + legal, source_typ = abs_typeof(inst) + codegen_typ = value_type(inst) + if legal + typ = if codegen_typ isa LLVM.PointerType + llvm_source_typ = convert(LLVMType, source_typ; allow_boxed=true) + # pointers are used for multiple kinds of arguments + # - literal pointer values + if source_typ <: Ptr || source_typ <: Core.LLVMPtr + source_typ + elseif llvm_source_typ isa LLVM.PointerType + #if llvm_source_typ != codegen_typ + # throw(AssertionError("llvmtype ($llvm_source_typ) is not codegen_typ ($codegen_typ), source_typ = $source_typ within $(string(inst))")) + #end + # push!(args, (cc=MUT_REF, typ=source_typ, name=source_name, idx=codegen_i)) + Ptr{source_typ} + # - references to aggregates + else + @assert llvm_source_typ != codegen_typ + # push!(args, (cc=BITS_REF, typ=source_typ, name=source_name, idx=codegen_i)) + Ptr{source_typ} + end + else + codegen_typ + end + + LLVM.API.LLVMAddCallSiteAttribute(inst, LLVM.API.LLVMAttributeReturnIndex, StringAttribute("enzyme_type", string(typetree(typ, ctx, dl, seen)))) + elseif codegen_typ == T_prjlvalue + LLVM.API.LLVMAddCallSiteAttribute(inst, LLVM.API.LLVMAttributeReturnIndex, StringAttribute("enzyme_type", "{[-1]:Pointer}")) + end + end + if !isa(fn, LLVM.Function) continue end @@ -5252,7 +5281,7 @@ function 
GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; cf = LLVM.called_operand(tmp) if isa(cf, LLVM.Function) nm = LLVM.name(cf) - if nm == "gpu_signal_exception" || nm == "gpu_report_exception" + if nm == "gpu_signal_exception" || nm == "gpu_report_exception" || nm == "ijl_throw" || nm == "jl_throw" shouldemit = false break end @@ -5408,6 +5437,9 @@ struct CompileResult{AT, PT} TapeType::Type end +@inline (thunk::PrimalErrorThunk{PT, FA, RT, TT, Width, ReturnPrimal, World})(fn::FA, args...) where {PT, FA, RT, TT, Width, ReturnPrimal, World} = +enzyme_call(Val(false), thunk.adjoint, PrimalErrorThunk{PT, FA, RT, TT, Width, ReturnPrimal, World}, Val(Width), Val(ReturnPrimal), TT, RT, fn, Cvoid, args...) + @inline (thunk::CombinedAdjointThunk{PT, FA, RT, TT, Width, ReturnPrimal})(fn::FA, args...) where {PT, FA, Width, RT, TT, ReturnPrimal} = enzyme_call(Val(false), thunk.adjoint, CombinedAdjointThunk{PT, FA, RT, TT, Width, ReturnPrimal}, Val(Width), Val(ReturnPrimal), TT, RT, fn, Cvoid, args...) @@ -5511,7 +5543,9 @@ end end @inline function default_adjoint(T) - if T <: AbstractFloat + if T == Union{} + return nothing + elseif T <: AbstractFloat return one(T) elseif T <: Complex error("Attempted to use automatic pullback (differential return value) deduction on a either a type unstable function returning an active complex number, or autodiff_deferred returning an active complex number. For the first case, please type stabilize your code, e.g. by specifying autodiff(Reverse, f->f(x)::Complex, ...). 
For the second case, please use regular non-deferred autodiff") @@ -5534,7 +5568,7 @@ end JuliaContext() do ctx F = eltype(FA) - is_forward = CC <: AugmentedForwardThunk || CC <: ForwardModeThunk + is_forward = CC <: AugmentedForwardThunk || CC <: ForwardModeThunk || CC <: PrimalErrorThunk is_adjoint = CC <: AdjointThunk || CC <: CombinedAdjointThunk is_split = CC <: AdjointThunk || CC <: AugmentedForwardThunk needs_tape = CC <: AdjointThunk @@ -5544,23 +5578,33 @@ end argtypes = DataType[argtt.parameters...] argexprs = Union{Expr, Symbol}[:(args[$i]) for i in 1:N] - if !RawCall + if false && CC <: PrimalErrorThunk + primargs = [quote + convert($(eltype(T)), $(argexprs[i]).val) + end for (i, T) in enumerate(argtypes)] + return quote + fn.val($(primargs...)) + error("Function to differentiate is guaranteed to return an error and doesn't make sense to autodiff. Giving up") + end + end + + if !RawCall && !(CC <: PrimalErrorThunk) if rettype <: Active if length(argtypes) + is_adjoint + needs_tape != length(argexprs) return quote - throw(MethodError($CC($fptr), $args)) + throw(MethodError($CC(fptr), $args)) end end elseif rettype <: Const if length(argtypes) + needs_tape != length(argexprs) return quote - throw(MethodError($CC($fptr), $args)) + throw(MethodError($CC(fptr), $args)) end end else if length(argtypes) + needs_tape != length(argexprs) return quote - throw(MethodError($CC($fptr), $args)) + throw(MethodError($CC(fptr), $args)) end end end @@ -5568,8 +5612,10 @@ end types = DataType[] - if eltype(rettype) === Union{} - error("Function to differentiate is guaranteed to return an error and doesn't make sense to autodiff. Giving up") + if eltype(rettype) === Union{} && false + return quote + error("Function to differentiate is guaranteed to return an error and doesn't make sense to autodiff. 
Giving up") + end end if !(rettype <: Const) && (isghostty(eltype(rettype)) || Core.Compiler.isconstType(eltype(rettype)) || eltype(rettype) === DataType) rrt = eltype(rettype) @@ -5640,7 +5686,9 @@ end end continue end - + if CC <: PrimalErrorThunk + continue + end if T <: Active if is_adjoint if width == 1 @@ -5727,8 +5775,10 @@ end end push!(sret_types, NT) end - - @assert i == length(argexprs)+1 + + if !(CC <: PrimalErrorThunk) + @assert i == length(argexprs)+1 + end # Tape if CC <: AugmentedForwardThunk @@ -5760,7 +5810,7 @@ end T_void = convert(LLVMType, Nothing) - combinedReturn = Tuple{sret_types...} + combinedReturn = (CC <: PrimalErrorThunk && eltype(rettype) == Union{}) ? Union{} : Tuple{sret_types...} if any(any_jltypes(convert(LLVM.LLVMType, T; allow_boxed=true)) for T in sret_types) combinedReturn = AnonymousStruct(combinedReturn) end @@ -5978,29 +6028,30 @@ end params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI) tmp_job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) - sig = Tuple{eltype(FA), map(eltype, TT.parameters)...} - interp = GPUCompiler.get_interpreter(tmp_job) # TODO check compile return here, early # rrt = Core.Compiler.return_type(f, primal.tt) # nothing rrt = something(Core.Compiler.typeinf_type(interp, mi.def, mi.specTypes, mi.sparam_vals), Any) + rrt = Core.Compiler.typeinf_ext_toplevel(interp, mi).rettype + + run_enzyme = true if rrt == Union{} - estr = "Function to differentiate `$mi` is guaranteed to return an error and doesn't make sense to autodiff. 
Giving up" - return quote - error($estr) - end + run_enzyme = false + A = Const end - if !(A <: Const) && guaranteed_const_nongen(rrt, World) + if run_enzyme && !(A <: Const) && guaranteed_const_nongen(rrt, World) estr = "Return type `$rrt` not marked Const, but type is guaranteed to be constant" return quote error($estr) end end - rt2 = if A isa UnionAll + rt2 = if !run_enzyme + Const{rrt} + elseif A isa UnionAll A{rrt} else @assert A isa DataType @@ -6009,7 +6060,7 @@ end A end - params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, rt2, true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI) + params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, rt2, run_enzyme, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI) job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) # We need to use primal as the key, to lookup the right method @@ -6020,7 +6071,13 @@ end compile_result = cached_compilation(job) - if Mode == API.DEM_ReverseModePrimal || Mode == API.DEM_ReverseModeGradient + if !run_enzyme + ErrT = PrimalErrorThunk{typeof(compile_result.adjoint), FA, rt2, TT, width, ReturnPrimal, World} + return quote + Base.@_inline_meta + $ErrT($(compile_result.adjoint)) + end + elseif Mode == API.DEM_ReverseModePrimal || Mode == API.DEM_ReverseModeGradient TapeType = compile_result.TapeType AugT = AugmentedForwardThunk{typeof(compile_result.primal), FA, rt2, Tuple{params.TT.parameters[2:end]...}, width, ReturnPrimal, TapeType} AdjT = AdjointThunk{typeof(compile_result.adjoint), FA, rt2, Tuple{params.TT.parameters[2:end]...}, width, TapeType} @@ -6061,7 +6118,6 @@ import GPUCompiler: deferred_codegen_jobs params = EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType, FFIABI) tmp_job = Compiler.CompilerJob(mi, 
CompilerConfig(target, params; kernel=false), World) - sig = Tuple{eltype(FA), map(eltype, TT.parameters)...} interp = GPUCompiler.get_interpreter(tmp_job) rrt = something(Core.Compiler.typeinf_type(interp, mi.def, mi.specTypes, mi.sparam_vals), Any) diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index 03e7a4458e..f8fa3a4cd2 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -161,6 +161,78 @@ function check_ir(job, mod::LLVM.Module) end end +# Rewrite calls with "jl_roots" to only have the jl_value_t attached and not { { {} addrspace(10)*, [1 x [2 x i64]], i64, i64 }, [2 x i64] } %unbox110183_replacementA +function rewrite_ccalls!(mod::LLVM.Module) + for f in collect(functions(mod)) + replaceAndErase = Tuple{Instruction, Instruction}[] + for bb in blocks(f), inst in instructions(bb) + if isa(inst, LLVM.CallInst) + changed = false + newbundles = OperandBundleDef[] + B = IRBuilder() + position!(B, inst) + for bunduse in operand_bundles(inst) + bunduse = LLVM.OperandBundleDef(bunduse) + if LLVM.tag_name(bunduse) != "jl_roots" + push!(newbundles, bunduse) + continue + end + uservals = LLVM.Value[] + subchanged = false + for lval in LLVM.inputs(bunduse) + llty = value_type(lval) + if !isa(llty, LLVM.PointerType) || LLVM.addrspace(llty) != 10 + push!(uservals, lval) + continue + end + vals = get_julia_inner_types(B, nothing, lval) + for v in vals + if isa(v, LLVM.PointerNull) + subchanged = true + continue + end + push!(uservals, v) + end + if length(vals) == 1 && vals[1] == lval + continue + end + subchanged = true + end + if !subchanged + push!(newbundles, bunduse) + continue + end + changed = true + push!(newbundles, OperandBundleDef(LLVM.tag_name(bunduse), uservals)) + end + changed = false + if changed + prevname = LLVM.name(inst) + LLVM.name!(inst, "") + newinst = call!(B, called_type(inst), called_operand(inst), collect(arguments(inst)), newbundles, prevname) + for idx = [LLVM.API.LLVMAttributeFunctionIndex, 
LLVM.API.LLVMAttributeReturnIndex, [LLVM.API.LLVMAttributeIndex(i) for i in 1:(length(arguments(inst)))]...] + idx = reinterpret(LLVM.API.LLVMAttributeIndex, idx) + count = LLVM.API.LLVMGetCallSiteAttributeCount(inst, idx); + Attrs = Base.unsafe_convert(Ptr{LLVM.API.LLVMAttributeRef}, Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef)*count)) + LLVM.API.LLVMGetCallSiteAttributes(inst, idx, Attrs) + for j in 1:count + LLVM.API.LLVMAddCallSiteAttribute(newinst, idx, unsafe_load(Attrs, j)) + end + Libc.free(Attrs) + end + API.EnzymeCopyMetadata(newinst, inst) + callconv!(newinst, callconv(inst)) + push!(replaceAndErase, (inst, newinst)) + end + end + end + for (inst, newinst) in replaceAndErase + replace_uses!(inst, newinst) + LLVM.API.LLVMInstructionEraseFromParent(inst) + end + end +end + function check_ir!(job, errors, mod::LLVM.Module) imported = Set(String[]) if haskey(functions(mod), "malloc") @@ -174,6 +246,7 @@ function check_ir!(job, errors, mod::LLVM.Module) replace_uses!(f, LLVM.Value(LLVM.API.LLVMConstPointerCast(mfn, value_type(f)))) unsafe_delete!(mod, f) end + rewrite_ccalls!(mod) for f in collect(functions(mod)) check_ir!(job, errors, imported, f) end diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index f22070b18a..5accfb24ca 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -269,6 +269,10 @@ function enzyme_custom_setup_ret(gutils::GradientUtils, orig::LLVM.CallInst, mi, return RT, needsPrimal, needsShadowP[] != 0, origNeedsPrimal end +function custom_rule_method_error(world, fn, args...) 
+ throw(MethodError(fn, (args...,), world)) +end + function enzyme_custom_fwd(B, orig, gutils, normalR, shadowR) if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) return true @@ -331,20 +335,24 @@ function enzyme_custom_fwd(B, orig, gutils, normalR, shadowR) @safe_debug "Applying custom forward rule (kwcall)" TT llvmf = nested_codegen!(mode, mod, kwfunc, TT, world) fwd_RT = Core.Compiler.return_type(kwfunc, TT, world) + else + TT = Tuple{typeof(world), typeof(kwfunc), TT.parameters...} + llvmf = nested_codegen!(mode, mod, custom_rule_method_error, TT, world) + pushfirst!(args, LLVM.ConstantInt(world)) + fwd_RT = Union{} end else if EnzymeRules.isapplicable(EnzymeRules.forward, TT; world) @safe_debug "Applying custom forward rule" TT llvmf = nested_codegen!(mode, mod, EnzymeRules.forward, TT, world) fwd_RT = Core.Compiler.return_type(EnzymeRules.forward, TT, world) + else + TT = Tuple{typeof(world), typeof(EnzymeRules.forward), TT.parameters...} + llvmf = nested_codegen!(mode, mod, custom_rule_method_error, TT, world) + pushfirst!(args, LLVM.ConstantInt(world)) + fwd_RT = Union{} end end - - if llvmf === nothing - @safe_debug "No custom forward rule is applicable for" TT - emit_error(B, orig, "Enzyme: No custom rule was applicable for " * string(TT)) - return false - end push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0)) @@ -366,7 +374,6 @@ function enzyme_custom_fwd(B, orig, gutils, normalR, shadowR) sret = nothing end - if length(args) != length(parameters(llvmf)) GPUCompiler.@safe_error "Calling convention mismatch", args, llvmf, orig, isKWCall, kwtup, TT, sret, returnRoots return false @@ -536,6 +543,11 @@ end ami = GPUCompiler.methodinstance(Core.Typeof(kwfunc), augprimal_TT, world) @safe_debug "Applying custom augmented_primal rule (kwcall)" TT=augprimal_TT catch e + augprimal_TT = Tuple{typeof(world), typeof(kwfunc), augprimal_TT.parameters...} + ami = GPUCompiler.methodinstance(typeof(custom_rule_method_error), 
augprimal_TT, world) + if forward + pushfirst!(args, LLVM.ConstantInt(world)) + end end else @assert kwtup === nothing @@ -547,6 +559,11 @@ end ami = GPUCompiler.methodinstance(Core.Typeof(EnzymeRules.augmented_primal), augprimal_TT, world) @safe_debug "Applying custom augmented_primal rule" TT=augprimal_TT catch e + augprimal_TT = Tuple{typeof(world), typeof(EnzymeRules.augmented_primal), augprimal_TT.parameters...} + ami = GPUCompiler.methodinstance(typeof(custom_rule_method_error), augprimal_TT, world) + if forward + pushfirst!(args, LLVM.ConstantInt(world)) + end end end return ami, augprimal_TT @@ -654,20 +671,24 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, @safe_debug "Applying custom reverse rule (kwcall)" TT=rev_TT llvmf = nested_codegen!(mode, mod, rkwfunc, rev_TT, world) rev_RT = Core.Compiler.return_type(rkwfunc, rev_TT, world) + else + rev_TT = Tuple{typeof(world), typeof(rkwfunc), rev_TT.parameters...} + llvmf = nested_codegen!(mode, mod, custom_rule_method_error, rev_TT, world) + pushfirst!(args, LLVM.ConstantInt(world)) + rev_RT = Union{} end else if EnzymeRules.isapplicable(EnzymeRules.reverse, rev_TT; world) @safe_debug "Applying custom reverse rule" TT=rev_TT llvmf = nested_codegen!(mode, mod, EnzymeRules.reverse, rev_TT, world) rev_RT = Core.Compiler.return_type(EnzymeRules.reverse, rev_TT, world) + else + rev_TT = Tuple{typeof(world), typeof(EnzymeRules.reverse), rev_TT.parameters...} + llvmf = nested_codegen!(mode, mod, custom_rule_method_error, rev_TT, world) + pushfirst!(args, LLVM.ConstantInt(world)) + rev_RT = Union{} end end - - if llvmf == nothing - @safe_debug "No custom reverse rule is applicable for" rev_TT - emit_error(B, orig, "Enzyme: No custom reverse rule was applicable for " * string(rev_TT)) - return C_NULL - end end push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0)) diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index 2f818df4a0..76b72466c1 100644 --- 
a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -306,11 +306,270 @@ end return body_runtime_generic_rev(N, Width, wrapped, primtypes, batchshadowargs) end +@inline concat() = () +@inline concat(a) = a +@inline concat(a, b) = (a..., b...) +@inline concat(a, b, c...) = concat(concat(a, b), c...) + +@inline iterate_unwrap_inner_fwd(x::Const) = (map(Const, x.val)...,) +@inline iterate_unwrap_inner_fwd(x::Duplicated) = (map(Duplicated, x.val, x.dval)...,) +@inline batch_dup_tuple(x, vals...) = BatchDuplicated(x, (vals...,)) +@inline iterate_unwrap_inner_fwd(x::BatchDuplicated) = (map(batch_dup_tuple, x.val, x.dval...)...,) + +@inline function iterate_unwrap_fwd(args...) + ntuple(Val(length(args))) do i + Base.@_inline_meta + iterate_unwrap_inner_fwd(args[i]) + end +end + +# This is explicitly escaped here to be what is apply generic in total [and thus all the insides are stable] +function fwddiff_with_return(::Val{width}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation} where {width, Nargs} + tt′ = Enzyme.vaTypeof(args...) + ReturnPrimal = Val(true) + RT = A + ModifiedBetween = Val(Enzyme.falses_from_args(Nargs+1)) + + tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} + world = codegen_world_age(Core.Typeof(f.val), tt) + + thunk(Val(world), FA, RT, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), + ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), FFIABI)(f, args...) +end + +function body_runtime_iterate_fwd(N, Width, wrapped, primtypes) + nnothing = ntuple(i->nothing, Val(Width+1)) + nres = ntuple(i->:(res[1]), Val(Width+1)) + ModifiedBetween = ntuple(i->false, Val(N+1)) + ElTypes = ntuple(i->:(eltype(Core.Typeof(args[$i]))), Val(N)) + Types = ntuple(i->:(Core.Typeof(args[$i])), Val(N)) + return quote + args0 = ($(wrapped...),) + args = concat(iterate_unwrap_fwd(args0...)...) 
+ + dupClosure = ActivityTup[1] + FT = Core.Typeof(f) + if dupClosure && guaranteed_const(FT) + dupClosure = false + end + + tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} + rt = Core.Compiler.return_type(f, tt) + annotation0 = guess_activity(rt, API.DEM_ForwardMode) + + annotation = @static if $Width != 1 + if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated + BatchDuplicated{rt, $Width} + else + Const{rt} + end + else + if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated + Duplicated{rt} + else + Const{rt} + end + end + + res = fwddiff_with_return(Val($Width), dupClosure ? Duplicated(f, df) : Const(f), annotation, args...) + return if annotation <: Const + ReturnType(($(nres...),)) + else + if $Width == 1 + ReturnType((res[1], res[2])) + else + ReturnType((res[1], res[2]...)) + end + end + end +end + +function func_runtime_iterate_fwd(N, Width) + _, _, primtypes, allargs, typeargs, wrapped, _ = setup_macro_wraps(true, N, Width) + body = body_runtime_iterate_fwd(N, Width, wrapped, primtypes) + + quote + function runtime_iterate_fwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...)) where {ActivityTup, ReturnType, F, DF, $(typeargs...)} + $body + end + end +end + +@generated function runtime_iterate_fwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, RT::Val{ReturnType}, f::F, df::DF, allargs...) 
where {ActivityTup, Width, ReturnType, F, DF} + N = div(length(allargs)+2, Width+1)-1 + _, _, primtypes, _, _, wrapped, _ = setup_macro_wraps(true, N, Width, :allargs) + return body_runtime_iterate_fwd(N, Width, wrapped, primtypes) +end + +function body_runtime_iterate_augfwd(N, Width, wrapped, primttypes) + nnothing = ntuple(i->nothing, Val(Width+1)) + nres = ntuple(i->:(origRet), Val(Width+1)) + nzeros = ntuple(i->:(Ref(zero(resT))), Val(Width)) + nres3 = ntuple(i->:(res[3]), Val(Width)) + ElTypes = ntuple(i->:(eltype(Core.Typeof(args[$i]))), Val(N)) + Types = ntuple(i->:(Core.Typeof(args[$i])), Val(N)) + + return quote + args = ($(wrapped...),) + throw(AssertionError("Runtime iterate augmented forward pass unhandled, f=$f df=$df args=$args")) + + # TODO: Annotation of return value + # tt0 = Tuple{$(primtypes...)} + tt′ = Tuple{$(Types...)} + rt = Core.Compiler.return_type(f, Tuple{$(ElTypes...)}) + annotation = guess_activity(rt, API.DEM_ReverseModePrimal) + + dupClosure = ActivityTup[1] + FT = Core.Typeof(f) + if dupClosure && guaranteed_const(FT) + dupClosure = false + end + + world = codegen_world_age(FT, Tuple{$(ElTypes...)}) + + forward, adjoint = thunk(Val(world), (dupClosure ? Duplicated : Const){FT}, + annotation, tt′, Val(API.DEM_ReverseModePrimal), width, + ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + + internal_tape, origRet, initShadow = forward(dupClosure ? Duplicated(f, df) : Const(f), args...) 
+ resT = typeof(origRet) + if annotation <: Const + shadow_return = nothing + tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + return ReturnType(($(nres...), tape)) + elseif annotation <: Active + if $Width == 1 + shadow_return = Ref(make_zero(origRet)) + else + shadow_return = ($(nzeros...),) + end + tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + if $Width == 1 + return ReturnType((origRet, shadow_return, tape)) + else + return ReturnType((origRet, shadow_return..., tape)) + end + end + + @assert annotation <: Duplicated || annotation <: DuplicatedNoNeed || annotation <: BatchDuplicated || annotation <: BatchDuplicatedNoNeed + + shadow_return = nothing + tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + if $Width == 1 + return ReturnType((origRet, initShadow, tape)) + else + return ReturnType((origRet, initShadow..., tape)) + end + end +end + +function func_runtime_iterate_augfwd(N, Width) + _, _, primtypes, allargs, typeargs, wrapped, _ = setup_macro_wraps(false, N, Width) + body = body_runtime_iterate_augfwd(N, Width, wrapped, primtypes) + + quote + function runtime_iterate_augfwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...)) where {ActivityTup, MB, ReturnType, F, DF, $(typeargs...)} + $body + end + end +end + +@generated function runtime_iterate_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, allargs...) 
where {ActivityTup, MB, Width, ReturnType, F, DF} + N = div(length(allargs)+2, Width+1)-1 + _, _, primtypes, _, _, wrapped, _ = setup_macro_wraps(false, N, Width, :allargs) + return body_runtime_iterate_augfwd(N, Width, wrapped, primtypes) +end + +function body_runtime_iterate_rev(N, Width, wrapped, primttypes, shadowargs) + outs = [] + for i in 1:N + for w in 1:Width + expr = if Width == 1 + :(tup[$i]) + else + :(tup[$i][$w]) + end + shad = shadowargs[i][w] + out = :(if tup[$i] === nothing + elseif $shad isa Base.RefValue + $shad[] = recursive_add($shad[], $expr) + else + error("Enzyme Mutability Error: Cannot add one in place to immutable value "*string($shad)) + end + ) + push!(outs, out) + end + end + shadow_ret = nothing + if Width == 1 + shadowret = :(tape.shadow_return[]) + else + shadowret = [] + for w in 1:Width + push!(shadowret, :(tape.shadow_return[$w][])) + end + shadowret = :(($(shadowret...),)) + end + + ElTypes = ntuple(i->:(eltype(Core.Typeof(args[$i]))), Val(N)) + Types = ntuple(i->:(Core.Typeof(args[$i])), Val(N)) + + quote + args = ($(wrapped...),) + throw(AssertionError("Runtime iterate reverse pass unhandled, f=$f df=$df args=$args")) + + # TODO: Annotation of return value + # tt0 = Tuple{$(primtypes...)} + tt = Tuple{$(ElTypes...)} + tt′ = Tuple{$(Types...)} + rt = Core.Compiler.return_type(f, tt) + annotation = guess_activity(rt, API.DEM_ReverseModePrimal) + + dupClosure = ActivityTup[1] + FT = Core.Typeof(f) + if dupClosure && guaranteed_const(FT) + dupClosure = false + end + world = codegen_world_age(FT, tt) + + forward, adjoint = thunk(Val(world), (dupClosure ? Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ReverseModePrimal), width, + ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + if tape.shadow_return !== nothing + args = (args..., $shadowret) + end + + tup = adjoint(dupClosure ? Duplicated(f, df) : Const(f), args..., tape.internal_tape)[1] + + $(outs...) 
+ return nothing + end +end + +function func_runtime_iterate_rev(N, Width) + _, _, primtypes, allargs, typeargs, wrapped, batchshadowargs = setup_macro_wraps(false, N, Width) + body = body_runtime_iterate_rev(N, Width, wrapped, primtypes, batchshadowargs) + + quote + function runtime_iterate_rev(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, $(allargs...)) where {ActivityTup, MB, TapeType, F, DF, $(typeargs...)} + $body + end + end +end + +@generated function runtime_iterate_rev(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, allargs...) where {ActivityTup, MB, Width, TapeType, F, DF} + N = div(length(allargs)+2, Width+1)-1 + _, _, primtypes, _, _, wrapped, batchshadowargs = setup_macro_wraps(false, N, Width, :allargs) + return body_runtime_iterate_rev(N, Width, wrapped, primtypes, batchshadowargs) +end + # Create specializations for (N, Width) in Iterators.product(0:30, 1:10) eval(func_runtime_generic_fwd(N, Width)) eval(func_runtime_generic_augfwd(N, Width)) eval(func_runtime_generic_rev(N, Width)) + eval(func_runtime_iterate_fwd(N, Width)) + eval(func_runtime_iterate_augfwd(N, Width)) + eval(func_runtime_iterate_rev(N, Width)) end function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, lookup; sret=nothing, tape=nothing, firstconst=false) @@ -455,7 +714,11 @@ function common_generic_fwd(offset, B, orig, gutils, normalR, shadowR) T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) return true end @@ -504,7 +767,11 @@ function
common_generic_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) return true end @@ -552,12 +819,17 @@ function generic_augfwd(B, orig, gutils, normalR, shadowR, tapeR) end function common_generic_rev(offset, B, orig, gutils, tape)::Cvoid - if !is_constant_value(gutils, orig) || !is_constant_inst(gutils, orig) + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) - @assert tape !== C_NULL - width = get_width(gutils) - generic_setup(orig, runtime_generic_rev, Nothing, gutils, #=start=#offset, B, true; tape) + if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + return nothing end + + @assert tape !== C_NULL + width = get_width(gutils) + generic_setup(orig, runtime_generic_rev, Nothing, gutils, #=start=#offset, B, true; tape) return nothing end @@ -572,7 +844,11 @@ function generic_rev(B, orig, gutils, tape)::Cvoid end function common_apply_latest_fwd(offset, B, orig, gutils, normalR, shadowR) - if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) return true end mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) @@ -613,7 +889,11 @@ function common_apply_latest_fwd(offset, B, orig, 
gutils, normalR, shadowR) end function common_apply_latest_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) - if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) return true end @@ -656,6 +936,13 @@ function common_apply_latest_augfwd(offset, B, orig, gutils, normalR, shadowR, t end function common_apply_latest_rev(offset, B, orig, gutils, tape)::Cvoid + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + return nothing + end if !is_constant_value(gutils, orig) || !is_constant_inst(gutils, orig) width = get_width(gutils) generic_setup(orig, runtime_generic_rev, Nothing, gutils, #=start=#offset+1, B, true; tape) @@ -690,10 +977,14 @@ function apply_latest_rev(B, orig, gutils, tape) end function common_apply_iterate_fwd(offset, B, orig, gutils, normalR, shadowR) - if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) return true end - + v, isiter = absint(operands(orig)[offset+1]) v2, istup = absint(operands(orig)[offset+2]) @@ -744,24 +1035,42 @@ function common_apply_iterate_fwd(offset, B, orig, gutils, normalR, shadowR) unsafe_store!(shadowR, shadowres.ref) return false end - emit_error(B, orig, "Enzyme: Not yet implemented, forward for jl_f__apply_iterate") - if 
unsafe_load(shadowR) != C_NULL - cal = new_from_original(gutils, orig) - width = get_width(gutils) - if width == 1 - shadow = cal - else - ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) - shadow = LLVM.UndefValue(ST) - for i in 1:width - shadow = insert_value!(B, shadow, cal, i-1) - if i == 1 - API.moveBefore(cal, shadow, B) + + if v && isiter == Base.iterate + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + + sret = generic_setup(orig, runtime_iterate_fwd, AnyArray(1+Int(width)), gutils, #=start=#offset+2, B, false) + AT = LLVM.ArrayType(T_prjlvalue, 1+Int(width)) + if unsafe_load(shadowR) != C_NULL + if width == 1 + gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) + shadow = LLVM.load!(B, T_prjlvalue, gep) + else + ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) + shadow = LLVM.UndefValue(ST) + for i in 1:width + gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(i)]) + ld = LLVM.load!(B, T_prjlvalue, gep) + shadow = insert_value!(B, shadow, ld, i-1) end end + unsafe_store!(shadowR, shadow.ref) end - unsafe_store!(shadowR, shadow.ref) + + if unsafe_load(normalR) != C_NULL + normal = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)])) + unsafe_store!(normalR, normal.ref) + else + # Delete the primal code + ni = new_from_original(gutils, orig) + erase_with_placeholder(gutils, ni, orig) + end + return false end + + emit_error(B, orig, "Enzyme: Not yet implemented augmented forward for jl_f__apply_iterate "*string((v, v2, isiter, istup, length(operands(orig)), offset+4))) + return false end @@ -778,7 +1087,11 @@ function error_if_active_iter(arg) end function common_apply_iterate_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) - if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + 
activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) return true end @@ -859,7 +1172,11 @@ function apply_iterate_rev(B, orig, gutils, tape) end function common_invoke_fwd(offset, B, orig, gutils, normalR, shadowR) - if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) return true end @@ -899,7 +1216,11 @@ function common_invoke_fwd(offset, B, orig, gutils, normalR, shadowR) end function common_invoke_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) - if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) return true end normal = (unsafe_load(normalR) != C_NULL) ? 
LLVM.Instruction(unsafe_load(normalR)) : nothing @@ -946,10 +1267,16 @@ function common_invoke_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) end function common_invoke_rev(offset, B, orig, gutils, tape) - if !is_constant_value(gutils, orig) || !is_constant_inst(gutils, orig) - width = get_width(gutils) - generic_setup(orig, runtime_generic_rev, Nothing, gutils, #=start=#offset+1, B, true; tape) + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + return nothing end + + width = get_width(gutils) + generic_setup(orig, runtime_generic_rev, Nothing, gutils, #=start=#offset+1, B, true; tape) return nothing end diff --git a/src/rules/typerules.jl b/src/rules/typerules.jl index 4730db8654..569ef87323 100644 --- a/src/rules/typerules.jl +++ b/src/rules/typerules.jl @@ -1,28 +1,4 @@ -function noop_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeTreeRef}, known_values::Ptr{API.IntList}, numArgs::Csize_t, val::LLVM.API.LLVMValueRef)::UInt8 - return UInt8(false) -end - -function alloc_obj_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeTreeRef}, known_values::Ptr{API.IntList}, numArgs::Csize_t, val::LLVM.API.LLVMValueRef)::UInt8 - inst = LLVM.Instruction(val) - if API.HasFromStack(inst) - return UInt8(false) - end - legal, typ = abs_typeof(inst) - if !legal - return UInt8(false) - throw(AssertionError("Cannot deduce type of alloc obj, $(string(inst)) of $(string(LLVM.parent(LLVM.parent(inst))))")) - end - - ctx = LLVM.context(LLVM.Value(val)) - dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(inst))))) - - rest = typetree(typ, ctx, dl) # copy unecessary since only user of `rest` - only!(rest, -1) - API.EnzymeMergeTypeTree(ret, rest) - return UInt8(false) -end - function 
int_return_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeTreeRef}, known_values::Ptr{API.IntList}, numArgs::Csize_t, val::LLVM.API.LLVMValueRef)::UInt8 TT = TypeTree(API.DT_Integer, LLVM.context(LLVM.Value(val))) only!(TT, -1) @@ -30,41 +6,6 @@ function int_return_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.C return UInt8(false) end -function i64_box_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeTreeRef}, known_values::Ptr{API.IntList}, numArgs::Csize_t, val::LLVM.API.LLVMValueRef)::UInt8 - val = LLVM.Instruction(val) - TT = TypeTree(API.DT_Pointer, LLVM.context(val)) - if (direction & API.DOWN) != 0 - sub = TypeTree(unsafe_load(args)) - ctx = LLVM.context(val) - dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(val))))) - maxSize = div(width(value_type(operands(val)[1]))+7, 8) - shift!(sub, dl, 0, maxSize, 0) - API.EnzymeMergeTypeTree(TT, sub) - end - only!(TT, -1) - API.EnzymeMergeTypeTree(ret, TT) - return UInt8(false) -end - - -function f32_box_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeTreeRef}, known_values::Ptr{API.IntList}, numArgs::Csize_t, val::LLVM.API.LLVMValueRef)::UInt8 - TT = TypeTree(API.DT_Float, LLVM.context(LLVM.Value(val))) - only!(TT, -1) - API.EnzymeMergeTypeTree(unsafe_load(args), TT) - - API.EnzymeMergeTypeTree(TT, TypeTree(API.DT_Pointer,LLVM.context(LLVM.Value(val)))) - only!(TT, -1) - API.EnzymeMergeTypeTree(ret, TT) - return UInt8(false) -end - -function ptr_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeTreeRef}, known_values::Ptr{API.IntList}, numArgs::Csize_t, val::LLVM.API.LLVMValueRef)::UInt8 - TT = TypeTree(API.DT_Pointer, LLVM.context(LLVM.Value(val))) - only!(TT, -1) - API.EnzymeSetTypeTree(ret, TT) - return UInt8(false) -end - function inout_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeTreeRef}, known_values::Ptr{API.IntList}, numArgs::Csize_t, val::LLVM.API.LLVMValueRef)::UInt8 if numArgs != 1 return 
UInt8(false) @@ -97,22 +38,3 @@ function inout_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeT end return UInt8(false) end - -function alloc_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeTreeRef}, known_values::Ptr{API.IntList}, numArgs::Csize_t, val::LLVM.API.LLVMValueRef)::UInt8 - inst = LLVM.Instruction(val) - - legal, typ = abs_typeof(inst) - @assert legal - - ctx = LLVM.context(LLVM.Value(val)) - dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(inst))))) - - rest = typetree(typ, ctx, dl) # copy unecessary since only user of `rest` - only!(rest, -1) - API.EnzymeMergeTypeTree(ret, rest) - - for i = 1:numArgs - API.EnzymeMergeTypeTree(unsafe_load(args, i), TypeTree(API.DT_Integer, -1, ctx)) - end - return UInt8(false) -end diff --git a/test/kwrrules.jl b/test/kwrrules.jl index 72708d993b..a62ba94608 100644 --- a/test/kwrrules.jl +++ b/test/kwrrules.jl @@ -61,7 +61,7 @@ end # Test that this errors due to missing kwargs in rule definition g2(x, y) = f_kw2(x; val=y) -@test_throws Enzyme.Compiler.EnzymeRuntimeException autodiff(Reverse, g2, Active(2.0), Const(42.0))[1][1] +@test_throws MethodError autodiff(Reverse, g2, Active(2.0), Const(42.0))[1][1] function f_kw3(x; val=nothing) diff --git a/test/kwrules.jl b/test/kwrules.jl index 13f916c65d..91d3dc859d 100644 --- a/test/kwrules.jl +++ b/test/kwrules.jl @@ -31,7 +31,7 @@ end # Test that this errors due to missing kwargs in rule definition g2(x, y) = f_kw2(x; val=y) -@test_throws Enzyme.Compiler.EnzymeRuntimeException autodiff(Forward, g2, Duplicated(2.0, 1.0), Const(42.0))[1] ≈ 14.0 +@test_throws MethodError autodiff(Forward, g2, Duplicated(2.0, 1.0), Const(42.0))[1] ≈ 14.0 function f_kw3(x; val=nothing) x^2 diff --git a/test/rules.jl b/test/rules.jl index 4c2db62bf1..b6644d8c55 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -87,11 +87,11 @@ function forward(func::Const{typeof(g)}, ::Type{<:Const}, x::Const) end @testset "Registry" begin - @test_throws 
Enzyme.Compiler.EnzymeRuntimeException Enzyme.autodiff(Forward, g, Duplicated(1.0, 1.0)) + @test_throws MethodError Enzyme.autodiff(Forward, g, Duplicated(1.0, 1.0)) rh(cond, x) = cond ? g(x) : x @test Enzyme.autodiff(Forward, rh, Const(false), Duplicated(1.0, 1.0)) == (1.0,) - @test_throws Enzyme.Compiler.EnzymeRuntimeException Enzyme.autodiff(Forward, rh, Const(true), Duplicated(1.0, 1.0)) + @test_throws MethodError Enzyme.autodiff(Forward, rh, Const(true), Duplicated(1.0, 1.0)) end function alloc_sq(x) diff --git a/test/runtests.jl b/test/runtests.jl index 47d9debbff..3db032fefd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1647,6 +1647,233 @@ end end + +concat() = () +concat(a) = a +concat(a, b) = (a..., b...) +concat(a, b, c...) = concat(concat(a, b), c...) + +metaconcat(x) = concat(x...) + +metaconcat2(x, y) = concat(x..., y...) + +midconcat(x, y) = (x, concat(y...)...) + +metaconcat3(x, y, z) = concat(x..., y..., z...) + +@testset "Forward Apply iterate" begin + x = [(2.0, 3.0), (7.9, 11.2)] + dx = [(13.7, 15.2), (100.02, 304.1)] + + dres, = Enzyme.autodiff(Forward, metaconcat, Duplicated(x, dx)) + @test length(dres) == 4 + @test dres[1] ≈ 13.7 + @test dres[2] ≈ 15.2 + @test dres[3] ≈ 100.02 + @test dres[4] ≈ 304.1 + + res, dres = Enzyme.autodiff(Forward, metaconcat, Duplicated, Duplicated(x, dx)) + @test length(res) == 4 + @test res[1] ≈ 2.0 + @test res[2] ≈ 3.0 + @test res[3] ≈ 7.9 + @test res[4] ≈ 11.2 + @test length(dres) == 4 + @test dres[1] ≈ 13.7 + @test dres[2] ≈ 15.2 + @test dres[3] ≈ 100.02 + @test dres[4] ≈ 304.1 + + + a = [("a", "b"), ("c", "d")] + da = [("e", "f"), ("g", "h")] + + dres, = Enzyme.autodiff(Forward, metaconcat, Duplicated(a, da)) + @test length(dres) == 4 + @test dres[1] == "a" + @test dres[2] == "b" + @test dres[3] == "c" + @test dres[4] == "d" + + res, dres = Enzyme.autodiff(Forward, metaconcat, Duplicated, Duplicated(a, da)) + @test length(res) == 4 + @test res[1] == "a" + @test res[2] == "b" + @test res[3] == "c" + 
@test res[4] == "d" + @test length(dres) == 4 + @test dres[1] == "a" + @test dres[2] == "b" + @test dres[3] == "c" + @test dres[4] == "d" + + + Enzyme.autodiff(Forward, metaconcat, Const(a)) + +@static if VERSION ≥ v"1.7-" + dres, = Enzyme.autodiff(Forward, midconcat, Duplicated(1.0, 7.0), Duplicated(a, da)) + @test length(dres) == 5 + @test dres[1] ≈ 7.0 + @test dres[2] == "a" + @test dres[3] == "b" + @test dres[4] == "c" + @test dres[5] == "d" + + res, dres = Enzyme.autodiff(Forward, midconcat, Duplicated, Duplicated(1.0, 7.0), Duplicated(a, da)) + @test length(res) == 5 + @test res[1] ≈ 1.0 + @test res[2] == "a" + @test res[3] == "b" + @test res[4] == "c" + @test res[5] == "d" + + @test length(dres) == 5 + @test dres[1] ≈ 7.0 + @test dres[2] == "a" + @test dres[3] == "b" + @test dres[4] == "c" + @test dres[5] == "d" + + + dres, = Enzyme.autodiff(Forward, midconcat, Duplicated(1.0, 7.0), Const(a)) + @test length(dres) == 5 + @test dres[1] ≈ 7.0 + @test dres[2] == "a" + @test dres[3] == "b" + @test dres[4] == "c" + @test dres[5] == "d" + + res, dres = Enzyme.autodiff(Forward, midconcat, Duplicated, Duplicated(1.0, 7.0), Const(a)) + @test length(res) == 5 + @test res[1] ≈ 1.0 + @test res[2] == "a" + @test res[3] == "b" + @test res[4] == "c" + @test res[5] == "d" + @test length(dres) == 5 + @test dres[1] ≈ 7.0 + @test dres[2] == "a" + @test dres[3] == "b" + @test dres[4] == "c" + @test dres[5] == "d" +end + + y = [(-92.0, -93.0), (-97.9, -911.2)] + dy = [(-913.7, -915.2), (-9100.02, -9304.1)] + + dres, = Enzyme.autodiff(Forward, metaconcat2, Duplicated(x, dx), Duplicated(y, dy)) + @test length(dres) == 8 + @test dres[1] ≈ 13.7 + @test dres[2] ≈ 15.2 + @test dres[3] ≈ 100.02 + @test dres[4] ≈ 304.1 + @test dres[5] ≈ -913.7 + @test dres[6] ≈ -915.2 + @test dres[7] ≈ -9100.02 + @test dres[8] ≈ -9304.1 + + res, dres = Enzyme.autodiff(Forward, metaconcat2, Duplicated, Duplicated(x, dx), Duplicated(y, dy)) + @test length(res) == 8 + @test res[1] ≈ 2.0 + @test res[2] ≈ 3.0 
+ @test res[3] ≈ 7.9 + @test res[4] ≈ 11.2 + @test res[5] ≈ -92.0 + @test res[6] ≈ -93.0 + @test res[7] ≈ -97.9 + @test res[8] ≈ -911.2 + @test length(dres) == 8 + @test dres[1] ≈ 13.7 + @test dres[2] ≈ 15.2 + @test dres[3] ≈ 100.02 + @test dres[4] ≈ 304.1 + @test dres[5] ≈ -913.7 + @test dres[6] ≈ -915.2 + @test dres[7] ≈ -9100.02 + @test dres[8] ≈ -9304.1 + + + dres, = Enzyme.autodiff(Forward, metaconcat3, Duplicated(x, dx), Const(a), Duplicated(y, dy)) + @test length(dres) == 12 + @test dres[1] ≈ 13.7 + @test dres[2] ≈ 15.2 + @test dres[3] ≈ 100.02 + @test dres[4] ≈ 304.1 + + @test dres[5] == "a" + @test dres[6] == "b" + @test dres[7] == "c" + @test dres[8] == "d" + + @test dres[9] ≈ -913.7 + @test dres[10] ≈ -915.2 + @test dres[11] ≈ -9100.02 + @test dres[12] ≈ -9304.1 + + res, dres = Enzyme.autodiff(Forward, metaconcat3, Duplicated, Duplicated(x, dx), Const(a), Duplicated(y, dy)) + @test length(res) == 12 + @test res[1] ≈ 2.0 + @test res[2] ≈ 3.0 + @test res[3] ≈ 7.9 + @test res[4] ≈ 11.2 + + @test res[5] == "a" + @test res[6] == "b" + @test res[7] == "c" + @test res[8] == "d" + + @test res[9] ≈ -92.0 + @test res[10] ≈ -93.0 + @test res[11] ≈ -97.9 + @test res[12] ≈ -911.2 + + @test length(dres) == 12 + @test dres[1] ≈ 13.7 + @test dres[2] ≈ 15.2 + @test dres[3] ≈ 100.02 + @test dres[4] ≈ 304.1 + + @test dres[5] == "a" + @test dres[6] == "b" + @test dres[7] == "c" + @test dres[8] == "d" + + @test dres[9] ≈ -913.7 + @test dres[10] ≈ -915.2 + @test dres[11] ≈ -9100.02 + @test dres[12] ≈ -9304.1 + + + dres, = Enzyme.autodiff(Forward, metaconcat, BatchDuplicated(x, (dx, dy))) + @test length(dres[1]) == 4 + @test dres[1][1] ≈ 13.7 + @test dres[1][2] ≈ 15.2 + @test dres[1][3] ≈ 100.02 + @test dres[1][4] ≈ 304.1 + @test length(dres[2]) == 4 + @test dres[2][1] ≈ -913.7 + @test dres[2][2] ≈ -915.2 + @test dres[2][3] ≈ -9100.02 + @test dres[2][4] ≈ -9304.1 + + res, dres = Enzyme.autodiff(Forward, metaconcat, Duplicated, BatchDuplicated(x, (dx, dy))) + @test length(res) 
== 4 + @test res[1] ≈ 2.0 + @test res[2] ≈ 3.0 + @test res[3] ≈ 7.9 + @test res[4] ≈ 11.2 + @test length(dres[1]) == 4 + @test dres[1][1] ≈ 13.7 + @test dres[1][2] ≈ 15.2 + @test dres[1][3] ≈ 100.02 + @test dres[1][4] ≈ 304.1 + @test length(dres[2]) == 4 + @test dres[2][1] ≈ -913.7 + @test dres[2][2] ≈ -915.2 + @test dres[2][3] ≈ -9100.02 + @test dres[2][4] ≈ -9304.1 +end + @testset "Dynamic Val Construction" begin dyn_f(::Val{D}) where D = prod(D) @@ -2062,6 +2289,63 @@ end end end +function bc0_test_function(ps) + z = view(ps, 26:30) + C = Matrix{Float64}(undef, 5, 1) + C .= z + return C[1] +end + +@noinline function bc1_bcs2(x, y) + x != y && error(2) + return x +end + +@noinline function bc1_affine_normalize(x::AbstractArray) + _axes = bc1_bcs2(axes(x), axes(x)) + dest = similar(Array{Float32}, _axes) + bc = convert(Broadcast.Broadcasted{Nothing}, Broadcast.instantiate(Base.broadcasted(+, x, x))) + copyto!(dest, bc) + return x +end + +function bc1_loss_function(x) + return bc1_affine_normalize(x)[1] +end + +function bc2_affine_normalize(::typeof(identity), x::AbstractArray, xmean, xvar, + scale::AbstractArray, bias::AbstractArray, epsilon::Real) + _scale = @. scale / sqrt(xvar + epsilon) + _bias = @. bias - xmean * _scale + return @. 
x * _scale + _bias +end + +function bc2_loss_function(x, scale, bias) + x_ = reshape(x, 6, 6, 3, 2, 2) + scale_ = reshape(scale, 1, 1, 3, 2, 1) + bias_ = reshape(bias, 1, 1, 3, 2, 1) + + xmean = mean(x_, dims=(1, 2, 5)) + xvar = var(x_, corrected=false, mean=xmean, dims=(1, 2, 5)) + + return sum(abs2, bc2_affine_normalize(identity, x_, xmean, xvar, scale_, bias_, 1e-5)) +end + +@testset "Broadcast noalias" begin + + x = ones(30) + autodiff(Reverse, bc0_test_function, Active, Const(x)) + + x = rand(Float32, 2, 3) + Enzyme.autodiff(Reverse, bc1_loss_function, Duplicated(x, zero(x))) + + x = rand(Float32, 6, 6, 6, 2) + sc = rand(Float32, 6) + bi = rand(Float32, 6) + Enzyme.autodiff(Reverse, bc2_loss_function, Active, Duplicated(x, Enzyme.make_zero(x)), + Duplicated(sc, Enzyme.make_zero(sc)), Duplicated(bi, Enzyme.make_zero(bi))) +end + @testset "GetField" begin mutable struct MyType x::Float64 @@ -2375,6 +2659,15 @@ end @test 2.0 ≈ Enzyme.autodiff(Reverse, unionret, Active, Active(2.0), Duplicated(out, dout), Const(true))[1][1] end + +function assured_err(x) + throw(AssertionError("foo")) +end + +@testset "UnionAll" begin + @test_throws AssertionError Enzyme.autodiff(Reverse, assured_err, Active, Active(2.0)) +end + struct MyFlux end @@ -3051,6 +3344,79 @@ end end end +const CUmemoryPool2 = Ptr{Float64} + +struct CUmemPoolProps2 + reserved::NTuple{31,Char} +end + +mutable struct CuMemoryPool2 + handle::CUmemoryPool2 +end + +function ccall_macro_lower(func, rettype, types, args, nreq) + # instead of re-using ccall or Expr(:foreigncall) to perform argument conversion, + # we need to do so ourselves in order to insert a jl_gc_safe_enter|leave + # just around the inner ccall + + cconvert_exprs = [] + cconvert_args = [] + for (typ, arg) in zip(types, args) + var = gensym("$(func)_cconvert") + push!(cconvert_args, var) + push!(cconvert_exprs, quote + $var = Base.cconvert($(esc(typ)), $(esc(arg))) + end) + end + + unsafe_convert_exprs = [] + unsafe_convert_args = [] + for 
(typ, arg) in zip(types, cconvert_args) + var = gensym("$(func)_unsafe_convert") + push!(unsafe_convert_args, var) + push!(unsafe_convert_exprs, quote + $var = Base.unsafe_convert($(esc(typ)), $arg) + end) + end + + quote + $(cconvert_exprs...) + + $(unsafe_convert_exprs...) + + ret = ccall($(esc(func)), $(esc(rettype)), $(Expr(:tuple, map(esc, types)...)), + $(unsafe_convert_args...)) + end +end + +macro gcsafe_ccall(expr) + ccall_macro_lower(Base.ccall_macro_parse(expr)...) +end + +function cuMemPoolCreate2(pool, poolProps) + # CUDA.initialize_context() + #CUDA. + gc_state = @ccall(jl_gc_safe_enter()::Int8) + @gcsafe_ccall cuMemPoolCreate(pool::Ptr{CUmemoryPool2}, + poolProps::Ptr{CUmemPoolProps2})::Cvoid + @ccall(jl_gc_safe_leave(gc_state::Int8)::Cvoid) +end + +function cual() + props = Ref(CUmemPoolProps2( + ntuple(i->Char(0), 31) + )) + handle_ref = Ref{CUmemoryPool2}() + cuMemPoolCreate2(handle_ref, props) + + CuMemoryPool2(handle_ref[]) +end + +@testset "Unused shadow phi rev" begin + fwd, rev = Enzyme.autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(cual)}, Duplicated) +end + + const SEED = 42 const N_SAMPLES = 500 const N_COMPONENTS = 4