diff --git a/Project.toml b/Project.toml index d92a4a5975..fac186963a 100644 --- a/Project.toml +++ b/Project.toml @@ -16,7 +16,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] Adapt = "3.3" CEnum = "0.4" -Enzyme_jll = "~0.0.23" +Enzyme_jll = "~0.0.25" GPUCompiler = "0.13" LLVM = "4.1" ObjectFile = "0.3" diff --git a/src/api.jl b/src/api.jl index 869ebd033c..c91e8b96a9 100644 --- a/src/api.jl +++ b/src/api.jl @@ -224,6 +224,11 @@ function EnzymeGetCLBool(name) end # void EnzymeSetCLInteger(void *, int64_t); +function zcache!(val) + ptr = cglobal((:EnzymeZeroCache, libEnzyme)) + ccall((:EnzymeSetCLBool, libEnzyme), Cvoid, (Ptr{Cvoid}, UInt8), ptr, val) +end + function printperf!(val) ptr = cglobal((:EnzymePrintPerf, libEnzyme)) ccall((:EnzymeSetCLBool, libEnzyme), Cvoid, (Ptr{Cvoid}, UInt8), ptr, val) @@ -276,6 +281,7 @@ function __init__() ptr = cglobal((:EnzymeJuliaAddrLoad, libEnzyme)) val = true ccall((:EnzymeSetCLBool, libEnzyme), Cvoid, (Ptr{Cvoid}, UInt8), ptr, val) + zcache!(true) end function moveBefore(i1, i2) diff --git a/src/compiler.jl b/src/compiler.jl index 92a6742bfd..a62b655615 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -662,7 +662,9 @@ function genericSetup(orig, gutils, start, ctx::LLVM.Context, B::LLVM.Builder, f pushfirst!(vals, ret) end - to_preserve = LLVM.Value[primal, shadow] + # to_preserve = LLVM.Value[primal, shadow] + to_preserve = LLVM.Value[] + for (i, op) in enumerate(ops) idx = LLVM.Value[LLVM.ConstantInt(0; ctx), LLVM.ConstantInt(i-1; ctx)] @@ -1953,6 +1955,9 @@ function create_abi_wrapper(enzymefn::LLVM.Function, F, argtypes, rettype, actua ret!(builder) end + # make sure that arguments are rooted if necessary + reinsert_gcmarker!(llvm_f) + return llvm_f end @@ -2223,7 +2228,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; llvmfn = primalf FT = eltype(llvmtype(llvmfn)::LLVM.PointerType)::LLVM.FunctionType - wrapper_f = LLVM.Function(mod, LLVM.name(llvmfn)*"wrap", FT) + wrapper_f = LLVM.Function(mod, LLVM.name(llvmfn)*"mustwrap", FT) let builder = Builder(ctx) entry = BasicBlock(wrapper_f, "entry"; ctx) diff --git a/test/runtests.jl b/test/runtests.jl index e7ca9b59c0..4a0426a584 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -534,3 +534,56 @@ end c = 5.0 @test 5.0 ≈ autodiff((A,)->c * A, Active, Active(2.0))[1] end + +@testset "Type-instable capture" begin + L = Array{Float64, 1}(undef, 2) + + F = [1.0, 0.0] + + function main() + t = 0.0 + + function cap(m) + t = m + end + + @noinline function inner(F, cond) + if cond + genericcall(F) + end + end + + function tobedifferentiated(F, cond) + inner(F, cond) + # Force an apply generic + -t + nothing + end + autodiff(tobedifferentiated, Duplicated(F, L), false) + end + + main() +end + +@testset "Arrays are double pointers" begin + @noinline function func_scalar(X) + return X + end + + function timsteploop_scalar(FH1) + G = Float64[FH1] + k1 = @inbounds func_scalar(G[1]) + return k1 + end + @test Enzyme.autodiff(timsteploop_scalar, Active(2.0))[1] ≈ 1.0 + + @noinline function func(X) + return @inbounds X[1] + end + function timsteploop(FH1) + G = Float64[FH1] + k1 = func(G) + return k1 + end + @test Enzyme.autodiff(timsteploop, Active(2.0))[1] ≈ 1.0 +end